diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index a718a27a..b9bdc983 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "mime/multipart" "net" "net/http" "os" @@ -57,8 +58,10 @@ type callbackForwarder struct { } var ( - callbackForwardersMu sync.Mutex - callbackForwarders = make(map[int]*callbackForwarder) + callbackForwardersMu sync.Mutex + callbackForwarders = make(map[int]*callbackForwarder) + errAuthFileMustBeJSON = errors.New("auth file must be .json") + errAuthFileNotFound = errors.New("auth file not found") ) func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) { @@ -570,32 +573,57 @@ func (h *Handler) UploadAuthFile(c *gin.Context) { return } ctx := c.Request.Context() - if file, err := c.FormFile("file"); err == nil && file != nil { - name := filepath.Base(file.Filename) - if !strings.HasSuffix(strings.ToLower(name), ".json") { - c.JSON(400, gin.H{"error": "file must be .json"}) - return - } - dst := filepath.Join(h.cfg.AuthDir, name) - if !filepath.IsAbs(dst) { - if abs, errAbs := filepath.Abs(dst); errAbs == nil { - dst = abs + + fileHeaders, errMultipart := h.multipartAuthFileHeaders(c) + if errMultipart != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid multipart form: %v", errMultipart)}) + return + } + if len(fileHeaders) == 1 { + if _, errUpload := h.storeUploadedAuthFile(ctx, fileHeaders[0]); errUpload != nil { + if errors.Is(errUpload, errAuthFileMustBeJSON) { + c.JSON(http.StatusBadRequest, gin.H{"error": "file must be .json"}) + return } - } - if errSave := c.SaveUploadedFile(file, dst); errSave != nil { - c.JSON(500, gin.H{"error": fmt.Sprintf("failed to save file: %v", errSave)}) + c.JSON(http.StatusInternalServerError, gin.H{"error": errUpload.Error()}) return } - data, errRead := os.ReadFile(dst) - if errRead != nil { - c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read saved file: %v", errRead)}) + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + return + } + if len(fileHeaders) > 1 { + uploaded := make([]string, 0, len(fileHeaders)) + failed := make([]gin.H, 0) + for _, file := range fileHeaders { + name, errUpload := h.storeUploadedAuthFile(ctx, file) + if errUpload != nil { + failureName := "" + if file != nil { + failureName = filepath.Base(file.Filename) + } + msg := errUpload.Error() + if errors.Is(errUpload, errAuthFileMustBeJSON) { + msg = "file must be .json" + } + failed = append(failed, gin.H{"name": failureName, "error": msg}) + continue + } + uploaded = append(uploaded, name) + } + if len(failed) > 0 { + c.JSON(http.StatusMultiStatus, gin.H{ + "status": "partial", + "uploaded": len(uploaded), + "files": uploaded, + "failed": failed, + }) return } - if errReg := h.registerAuthFromFile(ctx, dst, data); errReg != nil { - c.JSON(500, gin.H{"error": errReg.Error()}) - return - } - c.JSON(200, gin.H{"status": "ok"}) + c.JSON(http.StatusOK, gin.H{"status": "ok", "uploaded": len(uploaded), "files": uploaded}) + return + } + if c.ContentType() == "multipart/form-data" { + c.JSON(http.StatusBadRequest, gin.H{"error": "no files uploaded"}) return } name := c.Query("name") @@ -612,17 +640,7 @@ func (h *Handler) UploadAuthFile(c *gin.Context) { c.JSON(400, gin.H{"error": "failed to read body"}) return } - dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name)) - if !filepath.IsAbs(dst) { - if abs, errAbs := filepath.Abs(dst); errAbs == nil { - dst = abs - } - } - if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil { - c.JSON(500, gin.H{"error": fmt.Sprintf("failed to write file: %v", errWrite)}) - return - } - if err = h.registerAuthFromFile(ctx, dst, data); err != nil { + if err = h.writeAuthFile(ctx, filepath.Base(name), data); err != nil { c.JSON(500, gin.H{"error": err.Error()}) return } @@ -669,11 +687,182 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) { c.JSON(200, gin.H{"status": "ok", "deleted": deleted}) return } - name := c.Query("name") - if name == "" || strings.Contains(name, string(os.PathSeparator)) { + + names, errNames := requestedAuthFileNamesForDelete(c) + if errNames != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": errNames.Error()}) + return + } + if len(names) == 0 { c.JSON(400, gin.H{"error": "invalid name"}) return } + if len(names) == 1 { + if _, status, errDelete := h.deleteAuthFileByName(ctx, names[0]); errDelete != nil { + c.JSON(status, gin.H{"error": errDelete.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + return + } + + deletedFiles := make([]string, 0, len(names)) + failed := make([]gin.H, 0) + for _, name := range names { + deletedName, _, errDelete := h.deleteAuthFileByName(ctx, name) + if errDelete != nil { + failed = append(failed, gin.H{"name": name, "error": errDelete.Error()}) + continue + } + deletedFiles = append(deletedFiles, deletedName) + } + if len(failed) > 0 { + c.JSON(http.StatusMultiStatus, gin.H{ + "status": "partial", + "deleted": len(deletedFiles), + "files": deletedFiles, + "failed": failed, + }) + return + } + c.JSON(http.StatusOK, gin.H{"status": "ok", "deleted": len(deletedFiles), "files": deletedFiles}) +} + +func (h *Handler) multipartAuthFileHeaders(c *gin.Context) ([]*multipart.FileHeader, error) { + if h == nil || c == nil || c.ContentType() != "multipart/form-data" { + return nil, nil + } + form, err := c.MultipartForm() + if err != nil { + return nil, err + } + if form == nil || len(form.File) == 0 { + return nil, nil + } + + keys := make([]string, 0, len(form.File)) + for key := range form.File { + keys = append(keys, key) + } + sort.Strings(keys) + + headers := make([]*multipart.FileHeader, 0) + for _, key := range keys { + headers = append(headers, form.File[key]...) + } + return headers, nil +} + +func (h *Handler) storeUploadedAuthFile(ctx context.Context, file *multipart.FileHeader) (string, error) { + if file == nil { + return "", fmt.Errorf("no file uploaded") + } + name := filepath.Base(strings.TrimSpace(file.Filename)) + if !strings.HasSuffix(strings.ToLower(name), ".json") { + return "", errAuthFileMustBeJSON + } + src, err := file.Open() + if err != nil { + return "", fmt.Errorf("failed to open uploaded file: %w", err) + } + defer src.Close() + + data, err := io.ReadAll(src) + if err != nil { + return "", fmt.Errorf("failed to read uploaded file: %w", err) + } + if err := h.writeAuthFile(ctx, name, data); err != nil { + return "", err + } + return name, nil +} + +func (h *Handler) writeAuthFile(ctx context.Context, name string, data []byte) error { + dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name)) + if !filepath.IsAbs(dst) { + if abs, errAbs := filepath.Abs(dst); errAbs == nil { + dst = abs + } + } + auth, err := h.buildAuthFromFileData(dst, data) + if err != nil { + return err + } + if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil { + return fmt.Errorf("failed to write file: %w", errWrite) + } + if err := h.upsertAuthRecord(ctx, auth); err != nil { + return err + } + return nil +} + +func requestedAuthFileNamesForDelete(c *gin.Context) ([]string, error) { + if c == nil { + return nil, nil + } + names := uniqueAuthFileNames(c.QueryArray("name")) + if len(names) > 0 { + return names, nil + } + + body, err := io.ReadAll(c.Request.Body) + if err != nil { + return nil, fmt.Errorf("failed to read body") + } + body = bytes.TrimSpace(body) + if len(body) == 0 { + return nil, nil + } + + var objectBody struct { + Name string `json:"name"` + Names []string `json:"names"` + } + if body[0] == '[' { + var arrayBody []string + if err := json.Unmarshal(body, &arrayBody); err != nil { + return nil, fmt.Errorf("invalid request body") + } + return uniqueAuthFileNames(arrayBody), nil + } + if err := json.Unmarshal(body, &objectBody); err != nil { + return nil, fmt.Errorf("invalid request body") + } + + out := make([]string, 0, len(objectBody.Names)+1) + if strings.TrimSpace(objectBody.Name) != "" { + out = append(out, objectBody.Name) + } + out = append(out, objectBody.Names...) + return uniqueAuthFileNames(out), nil +} + +func uniqueAuthFileNames(names []string) []string { + if len(names) == 0 { + return nil + } + seen := make(map[string]struct{}, len(names)) + out := make([]string, 0, len(names)) + for _, name := range names { + name = strings.TrimSpace(name) + if name == "" { + continue + } + if _, ok := seen[name]; ok { + continue + } + seen[name] = struct{}{} + out = append(out, name) + } + return out +} + +func (h *Handler) deleteAuthFileByName(ctx context.Context, name string) (string, int, error) { + name = strings.TrimSpace(name) + if name == "" || strings.Contains(name, string(os.PathSeparator)) { + return "", http.StatusBadRequest, fmt.Errorf("invalid name") + } targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name)) targetID := "" @@ -690,22 +879,19 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) { } if errRemove := os.Remove(targetPath); errRemove != nil { if os.IsNotExist(errRemove) { - c.JSON(404, gin.H{"error": "file not found"}) - } else { - c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", errRemove)}) + return filepath.Base(name), http.StatusNotFound, errAuthFileNotFound } - return + return filepath.Base(name), http.StatusInternalServerError, fmt.Errorf("failed to remove file: %w", errRemove) } if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil { - c.JSON(500, gin.H{"error": errDeleteRecord.Error()}) - return + return filepath.Base(name), http.StatusInternalServerError, errDeleteRecord } if targetID != "" { h.disableAuth(ctx, targetID) } else { h.disableAuth(ctx, targetPath) } - c.JSON(200, gin.H{"status": "ok"}) + return filepath.Base(name), http.StatusOK, nil } func (h *Handler) findAuthForDelete(name string) *coreauth.Auth { @@ -774,19 +960,27 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data [] if h.authManager == nil { return nil } + auth, err := h.buildAuthFromFileData(path, data) + if err != nil { + return err + } + return h.upsertAuthRecord(ctx, auth) +} + +func (h *Handler) buildAuthFromFileData(path string, data []byte) (*coreauth.Auth, error) { if path == "" { - return fmt.Errorf("auth path is empty") + return nil, fmt.Errorf("auth path is empty") } if data == nil { var err error data, err = os.ReadFile(path) if err != nil { - return fmt.Errorf("failed to read auth file: %w", err) + return nil, fmt.Errorf("failed to read auth file: %w", err) } } metadata := make(map[string]any) if err := json.Unmarshal(data, &metadata); err != nil { - return fmt.Errorf("invalid auth file: %w", err) + return nil, fmt.Errorf("invalid auth file: %w", err) } provider, _ := metadata["type"].(string) if provider == "" { @@ -820,13 +1014,25 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data [] if hasLastRefresh { auth.LastRefreshedAt = lastRefresh } - if existing, ok := h.authManager.GetByID(authID); ok { - auth.CreatedAt = existing.CreatedAt - if !hasLastRefresh { - auth.LastRefreshedAt = existing.LastRefreshedAt + if h != nil && h.authManager != nil { + if existing, ok := h.authManager.GetByID(authID); ok { + auth.CreatedAt = existing.CreatedAt + if !hasLastRefresh { + auth.LastRefreshedAt = existing.LastRefreshedAt + } + auth.NextRefreshAfter = existing.NextRefreshAfter + auth.Runtime = existing.Runtime } - auth.NextRefreshAfter = existing.NextRefreshAfter - auth.Runtime = existing.Runtime + } + return auth, nil +} + +func (h *Handler) upsertAuthRecord(ctx context.Context, auth *coreauth.Auth) error { + if h == nil || h.authManager == nil || auth == nil { + return nil + } + if existing, ok := h.authManager.GetByID(auth.ID); ok { + auth.CreatedAt = existing.CreatedAt _, err := h.authManager.Update(ctx, auth) return err } diff --git a/internal/api/handlers/management/auth_files_batch_test.go b/internal/api/handlers/management/auth_files_batch_test.go new file mode 100644 index 00000000..44cdbd5b --- /dev/null +++ b/internal/api/handlers/management/auth_files_batch_test.go @@ -0,0 +1,197 @@ +package management + +import ( + "bytes" + "encoding/json" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestUploadAuthFile_BatchMultipart(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + authDir := t.TempDir() + manager := coreauth.NewManager(nil, nil, nil) + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager) + + files := []struct { + name string + content string + }{ + {name: "alpha.json", content: `{"type":"codex","email":"alpha@example.com"}`}, + {name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`}, + } + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + for _, file := range files { + part, err := writer.CreateFormFile("file", file.name) + if err != nil { + t.Fatalf("failed to create multipart file: %v", err) + } + if _, err = part.Write([]byte(file.content)); err != nil { + t.Fatalf("failed to write multipart content: %v", err) + } + } + if err := writer.Close(); err != nil { + t.Fatalf("failed to close multipart writer: %v", err) + } + + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + ctx.Request = req + + h.UploadAuthFile(ctx) + + if rec.Code != http.StatusOK { + t.Fatalf("expected upload status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + var payload map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if got, ok := payload["uploaded"].(float64); !ok || int(got) != len(files) { + t.Fatalf("expected uploaded=%d, got %#v", len(files), payload["uploaded"]) + } + + for _, file := range files { + fullPath := filepath.Join(authDir, file.name) + data, err := os.ReadFile(fullPath) + if err != nil { + t.Fatalf("expected uploaded file %s to exist: %v", file.name, err) + } + if string(data) != file.content { + t.Fatalf("expected file %s content %q, got %q", file.name, file.content, string(data)) + } + } + + auths := manager.List() + if len(auths) != len(files) { + t.Fatalf("expected %d auth entries, got %d", len(files), len(auths)) + } +} + +func TestUploadAuthFile_BatchMultipart_InvalidJSONDoesNotOverwriteExistingFile(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + authDir := t.TempDir() + manager := coreauth.NewManager(nil, nil, nil) + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager) + + existingName := "alpha.json" + existingContent := `{"type":"codex","email":"alpha@example.com"}` + if err := os.WriteFile(filepath.Join(authDir, existingName), []byte(existingContent), 0o600); err != nil { + t.Fatalf("failed to seed existing auth file: %v", err) + } + + files := []struct { + name string + content string + }{ + {name: existingName, content: `{"type":"codex"`}, + {name: "beta.json", content: `{"type":"claude","email":"beta@example.com"}`}, + } + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + for _, file := range files { + part, err := writer.CreateFormFile("file", file.name) + if err != nil { + t.Fatalf("failed to create multipart file: %v", err) + } + if _, err = part.Write([]byte(file.content)); err != nil { + t.Fatalf("failed to write multipart content: %v", err) + } + } + if err := writer.Close(); err != nil { + t.Fatalf("failed to close multipart writer: %v", err) + } + + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPost, "/v0/management/auth-files", &body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + ctx.Request = req + + h.UploadAuthFile(ctx) + + if rec.Code != http.StatusMultiStatus { + t.Fatalf("expected upload status %d, got %d with body %s", http.StatusMultiStatus, rec.Code, rec.Body.String()) + } + + data, err := os.ReadFile(filepath.Join(authDir, existingName)) + if err != nil { + t.Fatalf("expected existing auth file to remain readable: %v", err) + } + if string(data) != existingContent { + t.Fatalf("expected existing auth file to remain %q, got %q", existingContent, string(data)) + } + + betaData, err := os.ReadFile(filepath.Join(authDir, "beta.json")) + if err != nil { + t.Fatalf("expected valid auth file to be created: %v", err) + } + if string(betaData) != files[1].content { + t.Fatalf("expected beta auth file content %q, got %q", files[1].content, string(betaData)) + } +} + +func TestDeleteAuthFile_BatchQuery(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + authDir := t.TempDir() + files := []string{"alpha.json", "beta.json"} + for _, name := range files { + if err := os.WriteFile(filepath.Join(authDir, name), []byte(`{"type":"codex"}`), 0o600); err != nil { + t.Fatalf("failed to write auth file %s: %v", name, err) + } + } + + manager := coreauth.NewManager(nil, nil, nil) + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager) + h.tokenStore = &memoryAuthStore{} + + rec := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest( + http.MethodDelete, + "/v0/management/auth-files?name="+url.QueryEscape(files[0])+"&name="+url.QueryEscape(files[1]), + nil, + ) + ctx.Request = req + + h.DeleteAuthFile(ctx) + + if rec.Code != http.StatusOK { + t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String()) + } + + var payload map[string]any + if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if got, ok := payload["deleted"].(float64); !ok || int(got) != len(files) { + t.Fatalf("expected deleted=%d, got %#v", len(files), payload["deleted"]) + } + + for _, name := range files { + if _, err := os.Stat(filepath.Join(authDir, name)); !os.IsNotExist(err) { + t.Fatalf("expected auth file %s to be removed, stat err: %v", name, err) + } + } +}