Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d6bd6f3fb9 | ||
|
|
8ae8a5c296 | ||
|
|
dc804e96fb | ||
|
|
ab76cb3662 | ||
|
|
2965bdadc1 | ||
|
|
40f7061b04 |
@@ -508,6 +508,10 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if err = os.Remove(full); err == nil {
|
||||
if errDel := h.deleteTokenRecord(ctx, full); errDel != nil {
|
||||
c.JSON(500, gin.H{"error": errDel.Error()})
|
||||
return
|
||||
}
|
||||
deleted++
|
||||
h.disableAuth(ctx, full)
|
||||
}
|
||||
@@ -534,10 +538,32 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.deleteTokenRecord(ctx, full); err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
h.disableAuth(ctx, full)
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
func (h *Handler) authIDForPath(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if h == nil || h.cfg == nil {
|
||||
return path
|
||||
}
|
||||
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
||||
if authDir == "" {
|
||||
return path
|
||||
}
|
||||
if rel, err := filepath.Rel(authDir, path); err == nil && rel != "" {
|
||||
return rel
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
|
||||
if h.authManager == nil {
|
||||
return nil
|
||||
@@ -566,13 +592,18 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
}
|
||||
lastRefresh, hasLastRefresh := extractLastRefreshTimestamp(metadata)
|
||||
|
||||
authID := h.authIDForPath(path)
|
||||
if authID == "" {
|
||||
authID = path
|
||||
}
|
||||
attr := map[string]string{
|
||||
"path": path,
|
||||
"source": path,
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: path,
|
||||
ID: authID,
|
||||
Provider: provider,
|
||||
FileName: filepath.Base(path),
|
||||
Label: label,
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: attr,
|
||||
@@ -583,7 +614,7 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
if hasLastRefresh {
|
||||
auth.LastRefreshedAt = lastRefresh
|
||||
}
|
||||
if existing, ok := h.authManager.GetByID(path); ok {
|
||||
if existing, ok := h.authManager.GetByID(authID); ok {
|
||||
auth.CreatedAt = existing.CreatedAt
|
||||
if !hasLastRefresh {
|
||||
auth.LastRefreshedAt = existing.LastRefreshedAt
|
||||
@@ -598,10 +629,17 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
}
|
||||
|
||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
if h.authManager == nil || id == "" {
|
||||
if h == nil || h.authManager == nil {
|
||||
return
|
||||
}
|
||||
if auth, ok := h.authManager.GetByID(id); ok {
|
||||
authID := h.authIDForPath(id)
|
||||
if authID == "" {
|
||||
authID = strings.TrimSpace(id)
|
||||
}
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
if auth, ok := h.authManager.GetByID(authID); ok {
|
||||
auth.Disabled = true
|
||||
auth.Status = coreauth.StatusDisabled
|
||||
auth.StatusMessage = "removed via management API"
|
||||
@@ -610,9 +648,20 @@ func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (string, error) {
|
||||
if record == nil {
|
||||
return "", fmt.Errorf("token record is nil")
|
||||
func (h *Handler) deleteTokenRecord(ctx context.Context, path string) error {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return fmt.Errorf("auth path is empty")
|
||||
}
|
||||
store := h.tokenStoreWithBaseDir()
|
||||
if store == nil {
|
||||
return fmt.Errorf("token store unavailable")
|
||||
}
|
||||
return store.Delete(ctx, path)
|
||||
}
|
||||
|
||||
func (h *Handler) tokenStoreWithBaseDir() coreauth.Store {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
store := h.tokenStore
|
||||
if store == nil {
|
||||
@@ -624,6 +673,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
||||
dirSetter.SetBaseDir(h.cfg.AuthDir)
|
||||
}
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (string, error) {
|
||||
if record == nil {
|
||||
return "", fmt.Errorf("token record is nil")
|
||||
}
|
||||
store := h.tokenStoreWithBaseDir()
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("token store unavailable")
|
||||
}
|
||||
return store.Save(ctx, record)
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ func (h *Handler) GetConfigYAML(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
var node yaml.Node
|
||||
if err := yaml.Unmarshal(data, &node); err != nil {
|
||||
if err = yaml.Unmarshal(data, &node); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "parse_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -41,17 +41,18 @@ func (h *Handler) GetConfigYAML(c *gin.Context) {
|
||||
}
|
||||
|
||||
func WriteConfig(path string, data []byte) error {
|
||||
data = config.NormalizeCommentIndentation(data)
|
||||
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := f.Write(data); err != nil {
|
||||
f.Close()
|
||||
return err
|
||||
if _, errWrite := f.Write(data); errWrite != nil {
|
||||
_ = f.Close()
|
||||
return errWrite
|
||||
}
|
||||
if err := f.Sync(); err != nil {
|
||||
f.Close()
|
||||
return err
|
||||
if errSync := f.Sync(); errSync != nil {
|
||||
_ = f.Close()
|
||||
return errSync
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
@@ -63,7 +64,7 @@ func (h *Handler) PutConfigYAML(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
var cfg config.Config
|
||||
if err := yaml.Unmarshal(body, &cfg); err != nil {
|
||||
if err = yaml.Unmarshal(body, &cfg); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -75,18 +76,20 @@ func (h *Handler) PutConfigYAML(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
tempFile := tmpFile.Name()
|
||||
if _, err := tmpFile.Write(body); err != nil {
|
||||
tmpFile.Close()
|
||||
os.Remove(tempFile)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()})
|
||||
if _, errWrite := tmpFile.Write(body); errWrite != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tempFile)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errWrite.Error()})
|
||||
return
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
os.Remove(tempFile)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()})
|
||||
if errClose := tmpFile.Close(); errClose != nil {
|
||||
_ = os.Remove(tempFile)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errClose.Error()})
|
||||
return
|
||||
}
|
||||
defer os.Remove(tempFile)
|
||||
defer func() {
|
||||
_ = os.Remove(tempFile)
|
||||
}()
|
||||
_, err = config.LoadConfigOptional(tempFile, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": "invalid_config", "message": err.Error()})
|
||||
@@ -153,6 +156,14 @@ func (h *Handler) PutRequestLog(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.RequestLog = v })
|
||||
}
|
||||
|
||||
// Websocket auth
|
||||
func (h *Handler) GetWebsocketAuth(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"ws-auth": h.cfg.WebsocketAuth})
|
||||
}
|
||||
func (h *Handler) PutWebsocketAuth(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.WebsocketAuth = v })
|
||||
}
|
||||
|
||||
// Request retry
|
||||
func (h *Handler) GetRequestRetry(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"request-retry": h.cfg.RequestRetry})
|
||||
|
||||
156
internal/api/handlers/management/vertex_import.go
Normal file
156
internal/api/handlers/management/vertex_import.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// ImportVertexCredential handles uploading a Vertex service account JSON and saving it as an auth record.
|
||||
func (h *Handler) ImportVertexCredential(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "config unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg.AuthDir == "" {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
fileHeader, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "file required"})
|
||||
return
|
||||
}
|
||||
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
var serviceAccount map[string]any
|
||||
if err := json.Unmarshal(data, &serviceAccount); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
normalizedSA, err := vertex.NormalizeServiceAccountMap(serviceAccount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid service account", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
serviceAccount = normalizedSA
|
||||
|
||||
projectID := strings.TrimSpace(valueAsString(serviceAccount["project_id"]))
|
||||
if projectID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "project_id missing"})
|
||||
return
|
||||
}
|
||||
email := strings.TrimSpace(valueAsString(serviceAccount["client_email"]))
|
||||
|
||||
location := strings.TrimSpace(c.PostForm("location"))
|
||||
if location == "" {
|
||||
location = strings.TrimSpace(c.Query("location"))
|
||||
}
|
||||
if location == "" {
|
||||
location = "us-central1"
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("vertex-%s.json", sanitizeVertexFilePart(projectID))
|
||||
label := labelForVertex(projectID, email)
|
||||
storage := &vertex.VertexCredentialStorage{
|
||||
ServiceAccount: serviceAccount,
|
||||
ProjectID: projectID,
|
||||
Email: email,
|
||||
Location: location,
|
||||
Type: "vertex",
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"service_account": serviceAccount,
|
||||
"project_id": projectID,
|
||||
"email": email,
|
||||
"location": location,
|
||||
"type": "vertex",
|
||||
"label": label,
|
||||
}
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "vertex",
|
||||
FileName: fileName,
|
||||
Storage: storage,
|
||||
Label: label,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if reqCtx := c.Request.Context(); reqCtx != nil {
|
||||
ctx = reqCtx
|
||||
}
|
||||
savedPath, err := h.saveTokenRecord(ctx, record)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "save_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ok",
|
||||
"auth-file": savedPath,
|
||||
"project_id": projectID,
|
||||
"email": email,
|
||||
"location": location,
|
||||
})
|
||||
}
|
||||
|
||||
func valueAsString(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return t
|
||||
default:
|
||||
return fmt.Sprint(t)
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeVertexFilePart(s string) string {
|
||||
out := strings.TrimSpace(s)
|
||||
replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"}
|
||||
for i := 0; i < len(replacers); i += 2 {
|
||||
out = strings.ReplaceAll(out, replacers[i], replacers[i+1])
|
||||
}
|
||||
if out == "" {
|
||||
return "vertex"
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func labelForVertex(projectID, email string) string {
|
||||
p := strings.TrimSpace(projectID)
|
||||
e := strings.TrimSpace(email)
|
||||
if p != "" && e != "" {
|
||||
return fmt.Sprintf("%s (%s)", p, e)
|
||||
}
|
||||
if p != "" {
|
||||
return p
|
||||
}
|
||||
if e != "" {
|
||||
return e
|
||||
}
|
||||
return "vertex"
|
||||
}
|
||||
@@ -484,6 +484,9 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
||||
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
||||
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
||||
mgmt.GET("/ws-auth", s.mgmt.GetWebsocketAuth)
|
||||
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
|
||||
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
|
||||
|
||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||
@@ -508,6 +511,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -462,13 +463,19 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
enc := yaml.NewEncoder(f)
|
||||
var buf bytes.Buffer
|
||||
enc := yaml.NewEncoder(&buf)
|
||||
enc.SetIndent(2)
|
||||
if err = enc.Encode(&original); err != nil {
|
||||
_ = enc.Close()
|
||||
return err
|
||||
}
|
||||
return enc.Close()
|
||||
if err = enc.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
data = NormalizeCommentIndentation(buf.Bytes())
|
||||
_, err = f.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
@@ -518,13 +525,40 @@ func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []stri
|
||||
return err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
enc := yaml.NewEncoder(f)
|
||||
var buf bytes.Buffer
|
||||
enc := yaml.NewEncoder(&buf)
|
||||
enc.SetIndent(2)
|
||||
if err = enc.Encode(&root); err != nil {
|
||||
_ = enc.Close()
|
||||
return err
|
||||
}
|
||||
return enc.Close()
|
||||
if err = enc.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
data = NormalizeCommentIndentation(buf.Bytes())
|
||||
_, err = f.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
// NormalizeCommentIndentation removes indentation from standalone YAML comment lines to keep them left aligned.
|
||||
func NormalizeCommentIndentation(data []byte) []byte {
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
changed := false
|
||||
for i, line := range lines {
|
||||
trimmed := bytes.TrimLeft(line, " \t")
|
||||
if len(trimmed) == 0 || trimmed[0] != '#' {
|
||||
continue
|
||||
}
|
||||
if len(trimmed) == len(line) {
|
||||
continue
|
||||
}
|
||||
lines[i] = append([]byte(nil), trimmed...)
|
||||
changed = true
|
||||
}
|
||||
if !changed {
|
||||
return data
|
||||
}
|
||||
return bytes.Join(lines, []byte("\n"))
|
||||
}
|
||||
|
||||
// getOrCreateMapValue finds the value node for a given key in a mapping node.
|
||||
@@ -766,6 +800,7 @@ func matchSequenceElement(original []*yaml.Node, used []bool, target *yaml.Node)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
}
|
||||
// Fallback to structural equality to preserve nodes lacking explicit identifiers.
|
||||
for i := range original {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
@@ -89,7 +90,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
return resp, errNewReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if token, errTok := vertexAccessToken(ctx, saJSON); errTok == nil && token != "" {
|
||||
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
} else if errTok != nil {
|
||||
log.Errorf("vertex executor: access token error: %v", errTok)
|
||||
@@ -184,7 +185,7 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
return nil, errNewReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if token, errTok := vertexAccessToken(ctx, saJSON); errTok == nil && token != "" {
|
||||
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
} else if errTok != nil {
|
||||
log.Errorf("vertex executor: access token error: %v", errTok)
|
||||
@@ -295,7 +296,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
return cliproxyexecutor.Response{}, errNewReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if token, errTok := vertexAccessToken(ctx, saJSON); errTok == nil && token != "" {
|
||||
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
} else if errTok != nil {
|
||||
log.Errorf("vertex executor: access token error: %v", errTok)
|
||||
@@ -407,7 +408,10 @@ func vertexBaseURL(location string) string {
|
||||
return fmt.Sprintf("https://%s-aiplatform.googleapis.com", loc)
|
||||
}
|
||||
|
||||
func vertexAccessToken(ctx context.Context, saJSON []byte) (string, error) {
|
||||
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
|
||||
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
}
|
||||
// Use cloud-platform scope for Vertex AI.
|
||||
creds, errCreds := google.CredentialsFromJSON(ctx, saJSON, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if errCreds != nil {
|
||||
|
||||
@@ -159,7 +159,6 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Printf("11111")
|
||||
|
||||
for i := 0; i < len(arr); i++ {
|
||||
m := arr[i]
|
||||
|
||||
@@ -41,24 +41,26 @@ type authDirProvider interface {
|
||||
|
||||
// Watcher manages file watching for configuration and authentication files
|
||||
type Watcher struct {
|
||||
configPath string
|
||||
authDir string
|
||||
config *config.Config
|
||||
clientsMutex sync.RWMutex
|
||||
reloadCallback func(*config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
currentAuths map[string]*coreauth.Auth
|
||||
dispatchMu sync.Mutex
|
||||
dispatchCond *sync.Cond
|
||||
pendingUpdates map[string]AuthUpdate
|
||||
pendingOrder []string
|
||||
dispatchCancel context.CancelFunc
|
||||
storePersister storePersister
|
||||
mirroredAuthDir string
|
||||
oldConfigYaml []byte
|
||||
configPath string
|
||||
authDir string
|
||||
config *config.Config
|
||||
clientsMutex sync.RWMutex
|
||||
configReloadMu sync.Mutex
|
||||
configReloadTimer *time.Timer
|
||||
reloadCallback func(*config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
currentAuths map[string]*coreauth.Auth
|
||||
dispatchMu sync.Mutex
|
||||
dispatchCond *sync.Cond
|
||||
pendingUpdates map[string]AuthUpdate
|
||||
pendingOrder []string
|
||||
dispatchCancel context.CancelFunc
|
||||
storePersister storePersister
|
||||
mirroredAuthDir string
|
||||
oldConfigYaml []byte
|
||||
}
|
||||
|
||||
type stableIDGenerator struct {
|
||||
@@ -113,7 +115,8 @@ type AuthUpdate struct {
|
||||
const (
|
||||
// replaceCheckDelay is a short delay to allow atomic replace (rename) to settle
|
||||
// before deciding whether a Remove event indicates a real deletion.
|
||||
replaceCheckDelay = 50 * time.Millisecond
|
||||
replaceCheckDelay = 50 * time.Millisecond
|
||||
configReloadDebounce = 150 * time.Millisecond
|
||||
)
|
||||
|
||||
// NewWatcher creates a new file watcher instance
|
||||
@@ -172,9 +175,19 @@ func (w *Watcher) Start(ctx context.Context) error {
|
||||
// Stop stops the file watcher
|
||||
func (w *Watcher) Stop() error {
|
||||
w.stopDispatch()
|
||||
w.stopConfigReloadTimer()
|
||||
return w.watcher.Close()
|
||||
}
|
||||
|
||||
func (w *Watcher) stopConfigReloadTimer() {
|
||||
w.configReloadMu.Lock()
|
||||
if w.configReloadTimer != nil {
|
||||
w.configReloadTimer.Stop()
|
||||
w.configReloadTimer = nil
|
||||
}
|
||||
w.configReloadMu.Unlock()
|
||||
}
|
||||
|
||||
// SetConfig updates the current configuration
|
||||
func (w *Watcher) SetConfig(cfg *config.Config) {
|
||||
w.clientsMutex.Lock()
|
||||
@@ -476,40 +489,7 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
// Handle config file changes
|
||||
if isConfigEvent {
|
||||
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
|
||||
data, err := os.ReadFile(w.configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read config file for hash check: %v", err)
|
||||
return
|
||||
}
|
||||
if len(data) == 0 {
|
||||
log.Debugf("ignoring empty config file write event")
|
||||
return
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
newHash := hex.EncodeToString(sum[:])
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
currentHash := w.lastConfigHash
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
if currentHash != "" && currentHash == newHash {
|
||||
log.Debugf("config file content unchanged (hash match), skipping reload")
|
||||
return
|
||||
}
|
||||
fmt.Printf("config file changed, reloading: %s\n", w.configPath)
|
||||
if w.reloadConfig() {
|
||||
finalHash := newHash
|
||||
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
|
||||
sumUpdated := sha256.Sum256(updatedData)
|
||||
finalHash = hex.EncodeToString(sumUpdated[:])
|
||||
} else if errRead != nil {
|
||||
log.WithError(errRead).Debug("failed to compute updated config hash after reload")
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
w.lastConfigHash = finalHash
|
||||
w.clientsMutex.Unlock()
|
||||
w.persistConfigAsync()
|
||||
}
|
||||
w.scheduleConfigReload()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -530,6 +510,57 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) scheduleConfigReload() {
|
||||
w.configReloadMu.Lock()
|
||||
defer w.configReloadMu.Unlock()
|
||||
if w.configReloadTimer != nil {
|
||||
w.configReloadTimer.Stop()
|
||||
}
|
||||
w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() {
|
||||
w.configReloadMu.Lock()
|
||||
w.configReloadTimer = nil
|
||||
w.configReloadMu.Unlock()
|
||||
w.reloadConfigIfChanged()
|
||||
})
|
||||
}
|
||||
|
||||
func (w *Watcher) reloadConfigIfChanged() {
|
||||
data, err := os.ReadFile(w.configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read config file for hash check: %v", err)
|
||||
return
|
||||
}
|
||||
if len(data) == 0 {
|
||||
log.Debugf("ignoring empty config file write event")
|
||||
return
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
newHash := hex.EncodeToString(sum[:])
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
currentHash := w.lastConfigHash
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
if currentHash != "" && currentHash == newHash {
|
||||
log.Debugf("config file content unchanged (hash match), skipping reload")
|
||||
return
|
||||
}
|
||||
fmt.Printf("config file changed, reloading: %s\n", w.configPath)
|
||||
if w.reloadConfig() {
|
||||
finalHash := newHash
|
||||
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
|
||||
sumUpdated := sha256.Sum256(updatedData)
|
||||
finalHash = hex.EncodeToString(sumUpdated[:])
|
||||
} else if errRead != nil {
|
||||
log.WithError(errRead).Debug("failed to compute updated config hash after reload")
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
w.lastConfigHash = finalHash
|
||||
w.clientsMutex.Unlock()
|
||||
w.persistConfigAsync()
|
||||
}
|
||||
}
|
||||
|
||||
// reloadConfig reloads the configuration and triggers a full reload
|
||||
func (w *Watcher) reloadConfig() bool {
|
||||
log.Debug("=========================== CONFIG RELOAD ============================")
|
||||
|
||||
Reference in New Issue
Block a user