diff --git a/.dockerignore b/.dockerignore index ef021aea..843c7e04 100644 --- a/.dockerignore +++ b/.dockerignore @@ -31,6 +31,7 @@ bin/* .agent/* .agents/* .opencode/* +.idea/* .bmad/* _bmad/* _bmad-output/* diff --git a/.gitignore b/.gitignore index 183138f9..90ff3a94 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,7 @@ GEMINI.md .agents/* .agents/* .opencode/* +.idea/* .bmad/* _bmad/* _bmad-output/* diff --git a/README.md b/README.md index d15e4196..8491b97c 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,11 @@ So you can use local or multi-account CLI access with OpenAI(include Responses)/ ## Sponsor -[![z.ai](https://assets.router-for.me/english-4.7.png)](https://z.ai/subscribe?ic=8JVLJQFSKB) +[![z.ai](https://assets.router-for.me/english-5-0.jpg)](https://z.ai/subscribe?ic=8JVLJQFSKB) This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN. -GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.7 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences. +GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & (GLM-5 Only Available for Pro Users)model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences. Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB diff --git a/README_CN.md b/README_CN.md index 8be15461..6e987fdf 100644 --- a/README_CN.md +++ b/README_CN.md @@ -10,13 +10,13 @@ ## 赞助商 -[![bigmodel.cn](https://assets.router-for.me/chinese-4.7.png)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII) +[![bigmodel.cn](https://assets.router-for.me/chinese-5-0.jpg)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII) 本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。 -GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7,为开发者提供顶尖的编码体验。 +GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7(受限于算力,目前仅限Pro用户开放),为开发者提供顶尖的编码体验。 -智谱AI为本软件提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII +智谱AI为本产品提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII --- diff --git a/config.example.yaml b/config.example.yaml index 7a3265b4..40bb8721 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -201,6 +201,9 @@ nonstream-keepalive-interval: 0 # alias: "vertex-flash" # client-visible alias # - name: "gemini-2.5-pro" # alias: "vertex-pro" +# excluded-models: # optional: models to exclude from listing +# - "imagen-3.0-generate-002" +# - "imagen-*" # Amp Integration # ampcode: diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 9f03c5b4..e0a16377 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -13,6 +13,7 @@ import ( "net/http" "os" "path/filepath" + "runtime" "sort" "strconv" "strings" @@ -42,14 +43,11 @@ import ( var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} const ( - anthropicCallbackPort = 54545 - geminiCallbackPort = 8085 - codexCallbackPort = 1455 - geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com" - geminiCLIVersion = "v1internal" - geminiCLIUserAgent = "google-api-nodejs-client/9.15.1" - geminiCLIApiClient = "gl-node/22.17.0" - geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI" + anthropicCallbackPort = 54545 + geminiCallbackPort = 8085 + codexCallbackPort = 1455 + geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com" + geminiCLIVersion = "v1internal" ) type callbackForwarder struct { @@ -189,17 +187,6 @@ func startCallbackForwarder(port int, provider, targetBase string) (*callbackFor return forwarder, nil } -func stopCallbackForwarder(port int) { - callbackForwardersMu.Lock() - forwarder := callbackForwarders[port] - if forwarder != nil { - delete(callbackForwarders, port) - } - callbackForwardersMu.Unlock() - - stopForwarderInstance(port, forwarder) -} - func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) { if forwarder == nil { return @@ -641,44 +628,85 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) { c.JSON(400, gin.H{"error": "invalid name"}) return } - full := filepath.Join(h.cfg.AuthDir, filepath.Base(name)) - if !filepath.IsAbs(full) { - if abs, errAbs := filepath.Abs(full); errAbs == nil { - full = abs + + targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name)) + targetID := "" + if targetAuth := h.findAuthForDelete(name); targetAuth != nil { + targetID = strings.TrimSpace(targetAuth.ID) + if path := strings.TrimSpace(authAttribute(targetAuth, "path")); path != "" { + targetPath = path } } - if err := os.Remove(full); err != nil { - if os.IsNotExist(err) { + if !filepath.IsAbs(targetPath) { + if abs, errAbs := filepath.Abs(targetPath); errAbs == nil { + targetPath = abs + } + } + if errRemove := os.Remove(targetPath); errRemove != nil { + if os.IsNotExist(errRemove) { c.JSON(404, gin.H{"error": "file not found"}) } else { - c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", err)}) + c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", errRemove)}) } return } - if err := h.deleteTokenRecord(ctx, full); err != nil { - c.JSON(500, gin.H{"error": err.Error()}) + if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil { + c.JSON(500, gin.H{"error": errDeleteRecord.Error()}) return } - h.disableAuth(ctx, full) + if targetID != "" { + h.disableAuth(ctx, targetID) + } else { + h.disableAuth(ctx, targetPath) + } c.JSON(200, gin.H{"status": "ok"}) } +func (h *Handler) findAuthForDelete(name string) *coreauth.Auth { + if h == nil || h.authManager == nil { + return nil + } + name = strings.TrimSpace(name) + if name == "" { + return nil + } + if auth, ok := h.authManager.GetByID(name); ok { + return auth + } + auths := h.authManager.List() + for _, auth := range auths { + if auth == nil { + continue + } + if strings.TrimSpace(auth.FileName) == name { + return auth + } + if filepath.Base(strings.TrimSpace(authAttribute(auth, "path"))) == name { + return auth + } + } + return nil +} + func (h *Handler) authIDForPath(path string) string { path = strings.TrimSpace(path) if path == "" { return "" } - if h == nil || h.cfg == nil { - return path + id := path + if h != nil && h.cfg != nil { + authDir := strings.TrimSpace(h.cfg.AuthDir) + if authDir != "" { + if rel, errRel := filepath.Rel(authDir, path); errRel == nil && rel != "" { + id = rel + } + } } - authDir := strings.TrimSpace(h.cfg.AuthDir) - if authDir == "" { - return path + // On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths. + if runtime.GOOS == "windows" { + id = strings.ToLower(id) } - if rel, err := filepath.Rel(authDir, path); err == nil && rel != "" { - return rel - } - return path + return id } func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error { @@ -896,10 +924,19 @@ func (h *Handler) disableAuth(ctx context.Context, id string) { if h == nil || h.authManager == nil { return } - authID := h.authIDForPath(id) - if authID == "" { - authID = strings.TrimSpace(id) + id = strings.TrimSpace(id) + if id == "" { + return } + if auth, ok := h.authManager.GetByID(id); ok { + auth.Disabled = true + auth.Status = coreauth.StatusDisabled + auth.StatusMessage = "removed via management API" + auth.UpdatedAt = time.Now() + _, _ = h.authManager.Update(ctx, auth) + return + } + authID := h.authIDForPath(id) if authID == "" { return } @@ -2285,9 +2322,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string return fmt.Errorf("create request: %w", errRequest) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", geminiCLIUserAgent) - req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient) - req.Header.Set("Client-Metadata", geminiCLIClientMetadata) + req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) resp, errDo := httpClient.Do(req) if errDo != nil { @@ -2357,7 +2392,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec return false, fmt.Errorf("failed to create request: %w", errRequest) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", geminiCLIUserAgent) + req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) resp, errDo := httpClient.Do(req) if errDo != nil { return false, fmt.Errorf("failed to execute request: %w", errDo) @@ -2378,7 +2413,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec return false, fmt.Errorf("failed to create request: %w", errRequest) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", geminiCLIUserAgent) + req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) resp, errDo = httpClient.Do(req) if errDo != nil { return false, fmt.Errorf("failed to execute request: %w", errDo) diff --git a/internal/api/handlers/management/auth_files_delete_test.go b/internal/api/handlers/management/auth_files_delete_test.go new file mode 100644 index 00000000..7b7b888c --- /dev/null +++ b/internal/api/handlers/management/auth_files_delete_test.go @@ -0,0 +1,129 @@ +package management + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "os" + "path/filepath" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestDeleteAuthFile_UsesAuthPathFromManager(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + tempDir := t.TempDir() + authDir := filepath.Join(tempDir, "auth") + externalDir := filepath.Join(tempDir, "external") + if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil { + t.Fatalf("failed to create auth dir: %v", errMkdirAuth) + } + if errMkdirExternal := os.MkdirAll(externalDir, 0o700); errMkdirExternal != nil { + t.Fatalf("failed to create external dir: %v", errMkdirExternal) + } + + fileName := "codex-user@example.com-plus.json" + shadowPath := filepath.Join(authDir, fileName) + realPath := filepath.Join(externalDir, fileName) + if errWriteShadow := os.WriteFile(shadowPath, []byte(`{"type":"codex","email":"shadow@example.com"}`), 0o600); errWriteShadow != nil { + t.Fatalf("failed to write shadow file: %v", errWriteShadow) + } + if errWriteReal := os.WriteFile(realPath, []byte(`{"type":"codex","email":"real@example.com"}`), 0o600); errWriteReal != nil { + t.Fatalf("failed to write real file: %v", errWriteReal) + } + + manager := coreauth.NewManager(nil, nil, nil) + record := &coreauth.Auth{ + ID: "legacy/" + fileName, + FileName: fileName, + Provider: "codex", + Status: coreauth.StatusError, + Unavailable: true, + Attributes: map[string]string{ + "path": realPath, + }, + Metadata: map[string]any{ + "type": "codex", + "email": "real@example.com", + }, + } + if _, errRegister := manager.Register(context.Background(), record); errRegister != nil { + t.Fatalf("failed to register auth record: %v", errRegister) + } + + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager) + h.tokenStore = &memoryAuthStore{} + + deleteRec := httptest.NewRecorder() + deleteCtx, _ := gin.CreateTestContext(deleteRec) + deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil) + deleteCtx.Request = deleteReq + h.DeleteAuthFile(deleteCtx) + + if deleteRec.Code != http.StatusOK { + t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String()) + } + if _, errStatReal := os.Stat(realPath); !os.IsNotExist(errStatReal) { + t.Fatalf("expected managed auth file to be removed, stat err: %v", errStatReal) + } + if _, errStatShadow := os.Stat(shadowPath); errStatShadow != nil { + t.Fatalf("expected shadow auth file to remain, stat err: %v", errStatShadow) + } + + listRec := httptest.NewRecorder() + listCtx, _ := gin.CreateTestContext(listRec) + listReq := httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil) + listCtx.Request = listReq + h.ListAuthFiles(listCtx) + + if listRec.Code != http.StatusOK { + t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, listRec.Code, listRec.Body.String()) + } + var listPayload map[string]any + if errUnmarshal := json.Unmarshal(listRec.Body.Bytes(), &listPayload); errUnmarshal != nil { + t.Fatalf("failed to decode list payload: %v", errUnmarshal) + } + filesRaw, ok := listPayload["files"].([]any) + if !ok { + t.Fatalf("expected files array, payload: %#v", listPayload) + } + if len(filesRaw) != 0 { + t.Fatalf("expected removed auth to be hidden from list, got %d entries", len(filesRaw)) + } +} + +func TestDeleteAuthFile_FallbackToAuthDirPath(t *testing.T) { + t.Setenv("MANAGEMENT_PASSWORD", "") + gin.SetMode(gin.TestMode) + + authDir := t.TempDir() + fileName := "fallback-user.json" + filePath := filepath.Join(authDir, fileName) + if errWrite := os.WriteFile(filePath, []byte(`{"type":"codex"}`), 0o600); errWrite != nil { + t.Fatalf("failed to write auth file: %v", errWrite) + } + + manager := coreauth.NewManager(nil, nil, nil) + h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager) + h.tokenStore = &memoryAuthStore{} + + deleteRec := httptest.NewRecorder() + deleteCtx, _ := gin.CreateTestContext(deleteRec) + deleteReq := httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil) + deleteCtx.Request = deleteReq + h.DeleteAuthFile(deleteCtx) + + if deleteRec.Code != http.StatusOK { + t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String()) + } + if _, errStat := os.Stat(filePath); !os.IsNotExist(errStat) { + t.Fatalf("expected auth file to be removed from auth dir, stat err: %v", errStat) + } +} diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index 66e89992..503179c1 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -516,12 +516,13 @@ func (h *Handler) PutVertexCompatKeys(c *gin.Context) { } func (h *Handler) PatchVertexCompatKey(c *gin.Context) { type vertexCompatPatch struct { - APIKey *string `json:"api-key"` - Prefix *string `json:"prefix"` - BaseURL *string `json:"base-url"` - ProxyURL *string `json:"proxy-url"` - Headers *map[string]string `json:"headers"` - Models *[]config.VertexCompatModel `json:"models"` + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Headers *map[string]string `json:"headers"` + Models *[]config.VertexCompatModel `json:"models"` + ExcludedModels *[]string `json:"excluded-models"` } var body struct { Index *int `json:"index"` @@ -585,6 +586,9 @@ func (h *Handler) PatchVertexCompatKey(c *gin.Context) { if body.Value.Models != nil { entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...) } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } normalizeVertexCompatKey(&entry) h.cfg.VertexCompatAPIKey[targetIndex] = entry h.cfg.SanitizeVertexCompatKeys() @@ -1025,6 +1029,7 @@ func normalizeVertexCompatKey(entry *config.VertexCompatKey) { entry.BaseURL = strings.TrimSpace(entry.BaseURL) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.Headers = config.NormalizeHeaders(entry.Headers) + entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels) if len(entry.Models) == 0 { return } diff --git a/internal/api/modules/amp/proxy.go b/internal/api/modules/amp/proxy.go index b7d10760..ecc9da77 100644 --- a/internal/api/modules/amp/proxy.go +++ b/internal/api/modules/amp/proxy.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" log "github.com/sirupsen/logrus" ) @@ -76,6 +77,9 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi req.Header.Del("X-Api-Key") req.Header.Del("X-Goog-Api-Key") + // Remove proxy, client identity, and browser fingerprint headers + misc.ScrubProxyAndFingerprintHeaders(req) + // Remove query-based credentials if they match the authenticated client API key. // This prevents leaking client auth material to the Amp upstream while avoiding // breaking unrelated upstream query parameters. diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 1d8a1ae3..16af718e 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -20,6 +20,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" log "github.com/sirupsen/logrus" @@ -27,11 +28,8 @@ import ( ) const ( - geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com" - geminiCLIVersion = "v1internal" - geminiCLIUserAgent = "google-api-nodejs-client/9.15.1" - geminiCLIApiClient = "gl-node/22.17.0" - geminiCLIClientMetadata = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI" + geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com" + geminiCLIVersion = "v1internal" ) type projectSelectionRequiredError struct{} @@ -409,9 +407,7 @@ func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string return fmt.Errorf("create request: %w", errRequest) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", geminiCLIUserAgent) - req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient) - req.Header.Set("Client-Metadata", geminiCLIClientMetadata) + req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) resp, errDo := httpClient.Do(req) if errDo != nil { @@ -630,7 +626,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec return false, fmt.Errorf("failed to create request: %w", errRequest) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", geminiCLIUserAgent) + req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) resp, errDo := httpClient.Do(req) if errDo != nil { return false, fmt.Errorf("failed to execute request: %w", errDo) @@ -651,7 +647,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec return false, fmt.Errorf("failed to create request: %w", errRequest) } req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", geminiCLIUserAgent) + req.Header.Set("User-Agent", misc.GeminiCLIUserAgent("")) resp, errDo = httpClient.Do(req) if errDo != nil { return false, fmt.Errorf("failed to execute request: %w", errDo) diff --git a/internal/config/config.go b/internal/config/config.go index d6e2bdc8..5a6595f7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -516,16 +516,6 @@ func LoadConfig(configFile string) (*Config, error) { // If optional is true and the file is missing, it returns an empty Config. // If optional is true and the file is empty or invalid, it returns an empty Config. func LoadConfigOptional(configFile string, optional bool) (*Config, error) { - // NOTE: Startup oauth-model-alias migration is intentionally disabled. - // Reason: avoid mutating config.yaml during server startup. - // Re-enable the block below if automatic startup migration is needed again. - // if migrated, err := MigrateOAuthModelAlias(configFile); err != nil { - // // Log warning but don't fail - config loading should still work - // fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err) - // } else if migrated { - // fmt.Println("Migrated oauth-model-mappings to oauth-model-alias") - // } - // Read the entire configuration file into memory. data, err := os.ReadFile(configFile) if err != nil { @@ -1560,9 +1550,6 @@ func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) { srcIdx := findMapKeyIndex(srcRoot, key) if srcIdx < 0 { // Keep an explicit empty mapping for oauth-model-alias when it was previously present. - // - // Rationale: LoadConfig runs MigrateOAuthModelAlias before unmarshalling. If the - // oauth-model-alias key is missing, migration will add the default antigravity aliases. // When users delete the last channel from oauth-model-alias via the management API, // we want that deletion to persist across hot reloads and restarts. if key == "oauth-model-alias" { diff --git a/internal/config/oauth_model_alias_migration.go b/internal/config/oauth_model_alias_migration.go deleted file mode 100644 index f52df27a..00000000 --- a/internal/config/oauth_model_alias_migration.go +++ /dev/null @@ -1,277 +0,0 @@ -package config - -import ( - "os" - "strings" - - "gopkg.in/yaml.v3" -) - -// antigravityModelConversionTable maps old built-in aliases to actual model names -// for the antigravity channel during migration. -var antigravityModelConversionTable = map[string]string{ - "gemini-2.5-computer-use-preview-10-2025": "rev19-uic3-1p", - "gemini-3-pro-image-preview": "gemini-3-pro-image", - "gemini-3-pro-preview": "gemini-3-pro-high", - "gemini-3-flash-preview": "gemini-3-flash", - "gemini-claude-sonnet-4-5": "claude-sonnet-4-5", - "gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", - "gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking", - "gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking", -} - -// defaultAntigravityAliases returns the default oauth-model-alias configuration -// for the antigravity channel when neither field exists. -func defaultAntigravityAliases() []OAuthModelAlias { - return []OAuthModelAlias{ - {Name: "rev19-uic3-1p", Alias: "gemini-2.5-computer-use-preview-10-2025"}, - {Name: "gemini-3-pro-image", Alias: "gemini-3-pro-image-preview"}, - {Name: "gemini-3-pro-high", Alias: "gemini-3-pro-preview"}, - {Name: "gemini-3-flash", Alias: "gemini-3-flash-preview"}, - {Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"}, - {Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"}, - {Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"}, - {Name: "claude-opus-4-6-thinking", Alias: "gemini-claude-opus-4-6-thinking"}, - } -} - -// MigrateOAuthModelAlias checks for and performs migration from oauth-model-mappings -// to oauth-model-alias at startup. Returns true if migration was performed. -// -// Migration flow: -// 1. Check if oauth-model-alias exists -> skip migration -// 2. Check if oauth-model-mappings exists -> convert and migrate -// - For antigravity channel, convert old built-in aliases to actual model names -// -// 3. Neither exists -> add default antigravity config -func MigrateOAuthModelAlias(configFile string) (bool, error) { - data, err := os.ReadFile(configFile) - if err != nil { - if os.IsNotExist(err) { - return false, nil - } - return false, err - } - if len(data) == 0 { - return false, nil - } - - // Parse YAML into node tree to preserve structure - var root yaml.Node - if err := yaml.Unmarshal(data, &root); err != nil { - return false, nil - } - if root.Kind != yaml.DocumentNode || len(root.Content) == 0 { - return false, nil - } - rootMap := root.Content[0] - if rootMap == nil || rootMap.Kind != yaml.MappingNode { - return false, nil - } - - // Check if oauth-model-alias already exists - if findMapKeyIndex(rootMap, "oauth-model-alias") >= 0 { - return false, nil - } - - // Check if oauth-model-mappings exists - oldIdx := findMapKeyIndex(rootMap, "oauth-model-mappings") - if oldIdx >= 0 { - // Migrate from old field - return migrateFromOldField(configFile, &root, rootMap, oldIdx) - } - - // Neither field exists - add default antigravity config - return addDefaultAntigravityConfig(configFile, &root, rootMap) -} - -// migrateFromOldField converts oauth-model-mappings to oauth-model-alias -func migrateFromOldField(configFile string, root *yaml.Node, rootMap *yaml.Node, oldIdx int) (bool, error) { - if oldIdx+1 >= len(rootMap.Content) { - return false, nil - } - oldValue := rootMap.Content[oldIdx+1] - if oldValue == nil || oldValue.Kind != yaml.MappingNode { - return false, nil - } - - // Parse the old aliases - oldAliases := parseOldAliasNode(oldValue) - if len(oldAliases) == 0 { - // Remove the old field and write - removeMapKeyByIndex(rootMap, oldIdx) - return writeYAMLNode(configFile, root) - } - - // Convert model names for antigravity channel - newAliases := make(map[string][]OAuthModelAlias, len(oldAliases)) - for channel, entries := range oldAliases { - converted := make([]OAuthModelAlias, 0, len(entries)) - for _, entry := range entries { - newEntry := OAuthModelAlias{ - Name: entry.Name, - Alias: entry.Alias, - Fork: entry.Fork, - } - // Convert model names for antigravity channel - if strings.EqualFold(channel, "antigravity") { - if actual, ok := antigravityModelConversionTable[entry.Name]; ok { - newEntry.Name = actual - } - } - converted = append(converted, newEntry) - } - newAliases[channel] = converted - } - - // For antigravity channel, supplement missing default aliases - if antigravityEntries, exists := newAliases["antigravity"]; exists { - // Build a set of already configured model names (upstream names) - configuredModels := make(map[string]bool, len(antigravityEntries)) - for _, entry := range antigravityEntries { - configuredModels[entry.Name] = true - } - - // Add missing default aliases - for _, defaultAlias := range defaultAntigravityAliases() { - if !configuredModels[defaultAlias.Name] { - antigravityEntries = append(antigravityEntries, defaultAlias) - } - } - newAliases["antigravity"] = antigravityEntries - } - - // Build new node - newNode := buildOAuthModelAliasNode(newAliases) - - // Replace old key with new key and value - rootMap.Content[oldIdx].Value = "oauth-model-alias" - rootMap.Content[oldIdx+1] = newNode - - return writeYAMLNode(configFile, root) -} - -// addDefaultAntigravityConfig adds the default antigravity configuration -func addDefaultAntigravityConfig(configFile string, root *yaml.Node, rootMap *yaml.Node) (bool, error) { - defaults := map[string][]OAuthModelAlias{ - "antigravity": defaultAntigravityAliases(), - } - newNode := buildOAuthModelAliasNode(defaults) - - // Add new key-value pair - keyNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "oauth-model-alias"} - rootMap.Content = append(rootMap.Content, keyNode, newNode) - - return writeYAMLNode(configFile, root) -} - -// parseOldAliasNode parses the old oauth-model-mappings node structure -func parseOldAliasNode(node *yaml.Node) map[string][]OAuthModelAlias { - if node == nil || node.Kind != yaml.MappingNode { - return nil - } - result := make(map[string][]OAuthModelAlias) - for i := 0; i+1 < len(node.Content); i += 2 { - channelNode := node.Content[i] - entriesNode := node.Content[i+1] - if channelNode == nil || entriesNode == nil { - continue - } - channel := strings.ToLower(strings.TrimSpace(channelNode.Value)) - if channel == "" || entriesNode.Kind != yaml.SequenceNode { - continue - } - entries := make([]OAuthModelAlias, 0, len(entriesNode.Content)) - for _, entryNode := range entriesNode.Content { - if entryNode == nil || entryNode.Kind != yaml.MappingNode { - continue - } - entry := parseAliasEntry(entryNode) - if entry.Name != "" && entry.Alias != "" { - entries = append(entries, entry) - } - } - if len(entries) > 0 { - result[channel] = entries - } - } - return result -} - -// parseAliasEntry parses a single alias entry node -func parseAliasEntry(node *yaml.Node) OAuthModelAlias { - var entry OAuthModelAlias - for i := 0; i+1 < len(node.Content); i += 2 { - keyNode := node.Content[i] - valNode := node.Content[i+1] - if keyNode == nil || valNode == nil { - continue - } - switch strings.ToLower(strings.TrimSpace(keyNode.Value)) { - case "name": - entry.Name = strings.TrimSpace(valNode.Value) - case "alias": - entry.Alias = strings.TrimSpace(valNode.Value) - case "fork": - entry.Fork = strings.ToLower(strings.TrimSpace(valNode.Value)) == "true" - } - } - return entry -} - -// buildOAuthModelAliasNode creates a YAML node for oauth-model-alias -func buildOAuthModelAliasNode(aliases map[string][]OAuthModelAlias) *yaml.Node { - node := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} - for channel, entries := range aliases { - channelNode := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: channel} - entriesNode := &yaml.Node{Kind: yaml.SequenceNode, Tag: "!!seq"} - for _, entry := range entries { - entryNode := &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"} - entryNode.Content = append(entryNode.Content, - &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "name"}, - &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Name}, - &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "alias"}, - &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: entry.Alias}, - ) - if entry.Fork { - entryNode.Content = append(entryNode.Content, - &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "fork"}, - &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!bool", Value: "true"}, - ) - } - entriesNode.Content = append(entriesNode.Content, entryNode) - } - node.Content = append(node.Content, channelNode, entriesNode) - } - return node -} - -// removeMapKeyByIndex removes a key-value pair from a mapping node by index -func removeMapKeyByIndex(mapNode *yaml.Node, keyIdx int) { - if mapNode == nil || mapNode.Kind != yaml.MappingNode { - return - } - if keyIdx < 0 || keyIdx+1 >= len(mapNode.Content) { - return - } - mapNode.Content = append(mapNode.Content[:keyIdx], mapNode.Content[keyIdx+2:]...) -} - -// writeYAMLNode writes the YAML node tree back to file -func writeYAMLNode(configFile string, root *yaml.Node) (bool, error) { - f, err := os.Create(configFile) - if err != nil { - return false, err - } - defer f.Close() - - enc := yaml.NewEncoder(f) - enc.SetIndent(2) - if err := enc.Encode(root); err != nil { - return false, err - } - if err := enc.Close(); err != nil { - return false, err - } - return true, nil -} diff --git a/internal/config/oauth_model_alias_migration_test.go b/internal/config/oauth_model_alias_migration_test.go deleted file mode 100644 index cd73b9d5..00000000 --- a/internal/config/oauth_model_alias_migration_test.go +++ /dev/null @@ -1,245 +0,0 @@ -package config - -import ( - "os" - "path/filepath" - "strings" - "testing" - - "gopkg.in/yaml.v3" -) - -func TestMigrateOAuthModelAlias_SkipsIfNewFieldExists(t *testing.T) { - t.Parallel() - - dir := t.TempDir() - configFile := filepath.Join(dir, "config.yaml") - - content := `oauth-model-alias: - gemini-cli: - - name: "gemini-2.5-pro" - alias: "g2.5p" -` - if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { - t.Fatal(err) - } - - migrated, err := MigrateOAuthModelAlias(configFile) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if migrated { - t.Fatal("expected no migration when oauth-model-alias already exists") - } - - // Verify file unchanged - data, _ := os.ReadFile(configFile) - if !strings.Contains(string(data), "oauth-model-alias:") { - t.Fatal("file should still contain oauth-model-alias") - } -} - -func TestMigrateOAuthModelAlias_MigratesOldField(t *testing.T) { - t.Parallel() - - dir := t.TempDir() - configFile := filepath.Join(dir, "config.yaml") - - content := `oauth-model-mappings: - gemini-cli: - - name: "gemini-2.5-pro" - alias: "g2.5p" - fork: true -` - if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { - t.Fatal(err) - } - - migrated, err := MigrateOAuthModelAlias(configFile) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !migrated { - t.Fatal("expected migration to occur") - } - - // Verify new field exists and old field removed - data, _ := os.ReadFile(configFile) - if strings.Contains(string(data), "oauth-model-mappings:") { - t.Fatal("old field should be removed") - } - if !strings.Contains(string(data), "oauth-model-alias:") { - t.Fatal("new field should exist") - } - - // Parse and verify structure - var root yaml.Node - if err := yaml.Unmarshal(data, &root); err != nil { - t.Fatal(err) - } -} - -func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) { - t.Parallel() - - dir := t.TempDir() - configFile := filepath.Join(dir, "config.yaml") - - // Use old model names that should be converted - content := `oauth-model-mappings: - antigravity: - - name: "gemini-2.5-computer-use-preview-10-2025" - alias: "computer-use" - - name: "gemini-3-pro-preview" - alias: "g3p" -` - if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { - t.Fatal(err) - } - - migrated, err := MigrateOAuthModelAlias(configFile) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !migrated { - t.Fatal("expected migration to occur") - } - - // Verify model names were converted - data, _ := os.ReadFile(configFile) - content = string(data) - if !strings.Contains(content, "rev19-uic3-1p") { - t.Fatal("expected gemini-2.5-computer-use-preview-10-2025 to be converted to rev19-uic3-1p") - } - if !strings.Contains(content, "gemini-3-pro-high") { - t.Fatal("expected gemini-3-pro-preview to be converted to gemini-3-pro-high") - } - - // Verify missing default aliases were supplemented - if !strings.Contains(content, "gemini-3-pro-image") { - t.Fatal("expected missing default alias gemini-3-pro-image to be added") - } - if !strings.Contains(content, "gemini-3-flash") { - t.Fatal("expected missing default alias gemini-3-flash to be added") - } - if !strings.Contains(content, "claude-sonnet-4-5") { - t.Fatal("expected missing default alias claude-sonnet-4-5 to be added") - } - if !strings.Contains(content, "claude-sonnet-4-5-thinking") { - t.Fatal("expected missing default alias claude-sonnet-4-5-thinking to be added") - } - if !strings.Contains(content, "claude-opus-4-5-thinking") { - t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added") - } - if !strings.Contains(content, "claude-opus-4-6-thinking") { - t.Fatal("expected missing default alias claude-opus-4-6-thinking to be added") - } -} - -func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) { - t.Parallel() - - dir := t.TempDir() - configFile := filepath.Join(dir, "config.yaml") - - content := `debug: true -port: 8080 -` - if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { - t.Fatal(err) - } - - migrated, err := MigrateOAuthModelAlias(configFile) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !migrated { - t.Fatal("expected migration to add default config") - } - - // Verify default antigravity config was added - data, _ := os.ReadFile(configFile) - content = string(data) - if !strings.Contains(content, "oauth-model-alias:") { - t.Fatal("expected oauth-model-alias to be added") - } - if !strings.Contains(content, "antigravity:") { - t.Fatal("expected antigravity channel to be added") - } - if !strings.Contains(content, "rev19-uic3-1p") { - t.Fatal("expected default antigravity aliases to include rev19-uic3-1p") - } -} - -func TestMigrateOAuthModelAlias_PreservesOtherConfig(t *testing.T) { - t.Parallel() - - dir := t.TempDir() - configFile := filepath.Join(dir, "config.yaml") - - content := `debug: true -port: 8080 -oauth-model-mappings: - gemini-cli: - - name: "test" - alias: "t" -api-keys: - - "key1" - - "key2" -` - if err := os.WriteFile(configFile, []byte(content), 0644); err != nil { - t.Fatal(err) - } - - migrated, err := MigrateOAuthModelAlias(configFile) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !migrated { - t.Fatal("expected migration to occur") - } - - // Verify other config preserved - data, _ := os.ReadFile(configFile) - content = string(data) - if !strings.Contains(content, "debug: true") { - t.Fatal("expected debug field to be preserved") - } - if !strings.Contains(content, "port: 8080") { - t.Fatal("expected port field to be preserved") - } - if !strings.Contains(content, "api-keys:") { - t.Fatal("expected api-keys field to be preserved") - } -} - -func TestMigrateOAuthModelAlias_NonexistentFile(t *testing.T) { - t.Parallel() - - migrated, err := MigrateOAuthModelAlias("/nonexistent/path/config.yaml") - if err != nil { - t.Fatalf("unexpected error for nonexistent file: %v", err) - } - if migrated { - t.Fatal("expected no migration for nonexistent file") - } -} - -func TestMigrateOAuthModelAlias_EmptyFile(t *testing.T) { - t.Parallel() - - dir := t.TempDir() - configFile := filepath.Join(dir, "config.yaml") - - if err := os.WriteFile(configFile, []byte(""), 0644); err != nil { - t.Fatal(err) - } - - migrated, err := MigrateOAuthModelAlias(configFile) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if migrated { - t.Fatal("expected no migration for empty file") - } -} diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go index 786c5318..5f6c7c88 100644 --- a/internal/config/vertex_compat.go +++ b/internal/config/vertex_compat.go @@ -34,6 +34,9 @@ type VertexCompatKey struct { // Models defines the model configurations including aliases for routing. Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"` + + // ExcludedModels lists model IDs that should be excluded for this provider. + ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } func (k VertexCompatKey) GetAPIKey() string { return k.APIKey } @@ -74,6 +77,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() { } entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.Headers = NormalizeHeaders(entry.Headers) + entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels) // Sanitize models: remove entries without valid alias sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models)) diff --git a/internal/misc/header_utils.go b/internal/misc/header_utils.go index c6279a4c..5752a269 100644 --- a/internal/misc/header_utils.go +++ b/internal/misc/header_utils.go @@ -4,10 +4,98 @@ package misc import ( + "fmt" "net/http" + "runtime" "strings" ) +const ( + // GeminiCLIVersion is the version string reported in the User-Agent for upstream requests. + GeminiCLIVersion = "0.31.0" + + // GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream. + GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0" +) + +// geminiCLIOS maps Go runtime OS names to the Node.js-style platform strings used by Gemini CLI. +func geminiCLIOS() string { + switch runtime.GOOS { + case "windows": + return "win32" + default: + return runtime.GOOS + } +} + +// geminiCLIArch maps Go runtime architecture names to the Node.js-style arch strings used by Gemini CLI. +func geminiCLIArch() string { + switch runtime.GOARCH { + case "amd64": + return "x64" + case "386": + return "x86" + default: + return runtime.GOARCH + } +} + +// GeminiCLIUserAgent returns a User-Agent string that matches the Gemini CLI format. +// The model parameter is included in the UA; pass "" or "unknown" when the model is not applicable. +func GeminiCLIUserAgent(model string) string { + if model == "" { + model = "unknown" + } + return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch()) +} + +// ScrubProxyAndFingerprintHeaders removes all headers that could reveal +// proxy infrastructure, client identity, or browser fingerprints from an +// outgoing request. This ensures requests to upstream services look like they +// originate directly from a native client rather than a third-party client +// behind a reverse proxy. +func ScrubProxyAndFingerprintHeaders(req *http.Request) { + if req == nil { + return + } + + // --- Proxy tracing headers --- + req.Header.Del("X-Forwarded-For") + req.Header.Del("X-Forwarded-Host") + req.Header.Del("X-Forwarded-Proto") + req.Header.Del("X-Forwarded-Port") + req.Header.Del("X-Real-IP") + req.Header.Del("Forwarded") + req.Header.Del("Via") + + // --- Client identity headers --- + req.Header.Del("X-Title") + req.Header.Del("X-Stainless-Lang") + req.Header.Del("X-Stainless-Package-Version") + req.Header.Del("X-Stainless-Os") + req.Header.Del("X-Stainless-Arch") + req.Header.Del("X-Stainless-Runtime") + req.Header.Del("X-Stainless-Runtime-Version") + req.Header.Del("Http-Referer") + req.Header.Del("Referer") + + // --- Browser / Chromium fingerprint headers --- + // These are sent by Electron-based clients (e.g. CherryStudio) using the + // Fetch API, but NOT by Node.js https module (which Antigravity uses). + req.Header.Del("Sec-Ch-Ua") + req.Header.Del("Sec-Ch-Ua-Mobile") + req.Header.Del("Sec-Ch-Ua-Platform") + req.Header.Del("Sec-Fetch-Mode") + req.Header.Del("Sec-Fetch-Site") + req.Header.Del("Sec-Fetch-Dest") + req.Header.Del("Priority") + + // --- Encoding negotiation --- + // Antigravity (Node.js) sends "gzip, deflate, br" by default; + // Electron-based clients may add "zstd" which is a fingerprint mismatch. + req.Header.Del("Accept-Encoding") +} + // EnsureHeader ensures that a header exists in the target header map by checking // multiple sources in order of priority: source headers, existing target headers, // and finally the default value. It only sets the header if it's not already present diff --git a/internal/registry/model_definitions_static_data.go b/internal/registry/model_definitions_static_data.go index 7cfe15db..750aa4b4 100644 --- a/internal/registry/model_definitions_static_data.go +++ b/internal/registry/model_definitions_static_data.go @@ -37,7 +37,7 @@ func GetClaudeModels() []*ModelInfo { DisplayName: "Claude 4.6 Sonnet", ContextLength: 200000, MaxCompletionTokens: 64000, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false}, + Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high"}}, }, { ID: "claude-opus-4-6", @@ -49,7 +49,7 @@ func GetClaudeModels() []*ModelInfo { Description: "Premium model combining maximum intelligence with practical performance", ContextLength: 1000000, MaxCompletionTokens: 128000, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false}, + Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high", "max"}}, }, { ID: "claude-opus-4-5-20251101", @@ -208,12 +208,27 @@ func GetGeminiModels() []*ModelInfo { Name: "models/gemini-3-flash-preview", Version: "3.0", DisplayName: "Gemini 3 Flash Preview", - Description: "Gemini 3 Flash Preview", + Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.", InputTokenLimit: 1048576, OutputTokenLimit: 65536, SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, }, + { + ID: "gemini-3.1-flash-lite-preview", + Object: "model", + Created: 1776288000, + OwnedBy: "google", + Type: "gemini", + Name: "models/gemini-3.1-flash-lite-preview", + Version: "3.1", + DisplayName: "Gemini 3.1 Flash Lite Preview", + Description: "Our smallest and most cost effective model, built for at scale usage.", + InputTokenLimit: 1048576, + OutputTokenLimit: 65536, + SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, + Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}, + }, { ID: "gemini-3-pro-image-preview", Object: "model", @@ -324,6 +339,21 @@ func GetGeminiVertexModels() []*ModelInfo { SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, }, + { + ID: "gemini-3.1-flash-lite-preview", + Object: "model", + Created: 1776288000, + OwnedBy: "google", + Type: "gemini", + Name: "models/gemini-3.1-flash-lite-preview", + Version: "3.1", + DisplayName: "Gemini 3.1 Flash Lite Preview", + Description: "Our smallest and most cost effective model, built for at scale usage.", + InputTokenLimit: 1048576, + OutputTokenLimit: 65536, + SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, + Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}, + }, { ID: "gemini-3-pro-image-preview", Object: "model", @@ -496,6 +526,21 @@ func GetGeminiCLIModels() []*ModelInfo { SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, }, + { + ID: "gemini-3.1-flash-lite-preview", + Object: "model", + Created: 1776288000, + OwnedBy: "google", + Type: "gemini", + Name: "models/gemini-3.1-flash-lite-preview", + Version: "3.1", + DisplayName: "Gemini 3.1 Flash Lite Preview", + Description: "Our smallest and most cost effective model, built for at scale usage.", + InputTokenLimit: 1048576, + OutputTokenLimit: 65536, + SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, + Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}, + }, } } @@ -592,6 +637,21 @@ func GetAIStudioModels() []*ModelInfo { SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, }, + { + ID: "gemini-3.1-flash-lite-preview", + Object: "model", + Created: 1776288000, + OwnedBy: "google", + Type: "gemini", + Name: "models/gemini-3.1-flash-lite-preview", + Version: "3.1", + DisplayName: "Gemini 3.1 Flash Lite Preview", + Description: "Our smallest and most cost effective model, built for at scale usage.", + InputTokenLimit: 1048576, + OutputTokenLimit: 65536, + SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, + Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}, + }, { ID: "gemini-pro-latest", Object: "model", @@ -827,6 +887,20 @@ func GetOpenAIModels() []*ModelInfo { SupportedParameters: []string{"tools"}, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}}, }, + { + ID: "gpt-5.4", + Object: "model", + Created: 1772668800, + OwnedBy: "openai", + Type: "openai", + Version: "gpt-5.4", + DisplayName: "GPT 5.4", + Description: "Stable version of GPT 5.4", + ContextLength: 1_050_000, + MaxCompletionTokens: 128000, + SupportedParameters: []string{"tools"}, + Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}}, + }, } } @@ -947,18 +1021,18 @@ type AntigravityModelConfig struct { // Keys use upstream model names returned by the Antigravity models endpoint. func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { return map[string]*AntigravityModelConfig{ - // "rev19-uic3-1p": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}}, "gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, "gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, "gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, - "gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, + "gemini-3-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, "gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, + "gemini-3.1-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, "gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}}, + "gemini-3.1-flash-lite-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}}, "gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}}, - "claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}}, - "claude-sonnet-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}}, + "claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000}, + "claude-sonnet-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000}, "gpt-oss-120b-medium": {}, - "tab_flash_lite_preview": {}, } } diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index 7b8b262e..e036a04f 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -47,6 +47,10 @@ type ModelInfo struct { MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` // SupportedParameters lists supported parameters SupportedParameters []string `json:"supported_parameters,omitempty"` + // SupportedInputModalities lists supported input modalities (e.g., TEXT, IMAGE, VIDEO, AUDIO) + SupportedInputModalities []string `json:"supportedInputModalities,omitempty"` + // SupportedOutputModalities lists supported output modalities (e.g., TEXT, IMAGE) + SupportedOutputModalities []string `json:"supportedOutputModalities,omitempty"` // Thinking holds provider-specific reasoning/thinking budget capabilities. // This is optional and currently used for Gemini thinking budget normalization. @@ -499,6 +503,12 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo { if len(model.SupportedParameters) > 0 { copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...) } + if len(model.SupportedInputModalities) > 0 { + copyModel.SupportedInputModalities = append([]string(nil), model.SupportedInputModalities...) + } + if len(model.SupportedOutputModalities) > 0 { + copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...) + } return ©Model } @@ -1067,6 +1077,12 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) if len(model.SupportedGenerationMethods) > 0 { result["supportedGenerationMethods"] = model.SupportedGenerationMethods } + if len(model.SupportedInputModalities) > 0 { + result["supportedInputModalities"] = model.SupportedInputModalities + } + if len(model.SupportedOutputModalities) > 0 { + result["supportedOutputModalities"] = model.SupportedOutputModalities + } return result default: diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 00959a22..f3a052bf 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "crypto/sha256" + "crypto/tls" "encoding/binary" "encoding/json" "errors" @@ -45,10 +46,10 @@ const ( antigravityModelsPath = "/v1internal:fetchAvailableModels" antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" antigravityClientSecret = "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" - defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64" + defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64" antigravityAuthType = "antigravity" refreshSkew = 3000 * time.Second - systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**" + // systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**" ) var ( @@ -142,6 +143,62 @@ func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor { return &AntigravityExecutor{cfg: cfg} } +// antigravityTransport is a singleton HTTP/1.1 transport shared by all Antigravity requests. +// It is initialized once via antigravityTransportOnce to avoid leaking a new connection pool +// (and the goroutines managing it) on every request. +var ( + antigravityTransport *http.Transport + antigravityTransportOnce sync.Once +) + +func cloneTransportWithHTTP11(base *http.Transport) *http.Transport { + if base == nil { + return nil + } + + clone := base.Clone() + clone.ForceAttemptHTTP2 = false + // Wipe TLSNextProto to prevent implicit HTTP/2 upgrade. + clone.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper) + if clone.TLSClientConfig == nil { + clone.TLSClientConfig = &tls.Config{} + } else { + clone.TLSClientConfig = clone.TLSClientConfig.Clone() + } + // Actively advertise only HTTP/1.1 in the ALPN handshake. + clone.TLSClientConfig.NextProtos = []string{"http/1.1"} + return clone +} + +// initAntigravityTransport creates the shared HTTP/1.1 transport exactly once. +func initAntigravityTransport() { + base, ok := http.DefaultTransport.(*http.Transport) + if !ok { + base = &http.Transport{} + } + antigravityTransport = cloneTransportWithHTTP11(base) +} + +// newAntigravityHTTPClient creates an HTTP client specifically for Antigravity, +// enforcing HTTP/1.1 by disabling HTTP/2 to perfectly mimic Node.js https defaults. +// The underlying Transport is a singleton to avoid leaking connection pools. +func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client { + antigravityTransportOnce.Do(initAntigravityTransport) + + client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout) + // If no transport is set, use the shared HTTP/1.1 transport. + if client.Transport == nil { + client.Transport = antigravityTransport + return client + } + + // Preserve proxy settings from proxy-aware transports while forcing HTTP/1.1. + if transport, ok := client.Transport.(*http.Transport); ok { + client.Transport = cloneTransportWithHTTP11(transport) + } + return client +} + // Identifier returns the executor identifier. func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType } @@ -162,6 +219,8 @@ func (e *AntigravityExecutor) PrepareRequest(req *http.Request, auth *cliproxyau } // HttpRequest injects Antigravity credentials into the request and executes it. +// It uses a whitelist approach: all incoming headers are stripped and only +// the minimum set required by the Antigravity protocol is explicitly set. func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { if req == nil { return nil, fmt.Errorf("antigravity executor: request is nil") @@ -170,10 +229,29 @@ func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyaut ctx = req.Context() } httpReq := req.WithContext(ctx) + + // --- Whitelist: save only the headers we need from the original request --- + contentType := httpReq.Header.Get("Content-Type") + + // Wipe ALL incoming headers + for k := range httpReq.Header { + delete(httpReq.Header, k) + } + + // --- Set only the headers Antigravity actually sends --- + if contentType != "" { + httpReq.Header.Set("Content-Type", contentType) + } + // Content-Length is managed automatically by Go's http.Client from the Body + httpReq.Header.Set("User-Agent", resolveUserAgent(auth)) + httpReq.Close = true // sends Connection: close + + // Inject Authorization: Bearer if err := e.PrepareRequest(httpReq, auth); err != nil { return nil, err } - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + + httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) return httpClient.Do(httpReq) } @@ -185,7 +263,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au baseModel := thinking.ParseSuffix(req.Model).ModelName isClaude := strings.Contains(strings.ToLower(baseModel), "claude") - if isClaude || strings.Contains(baseModel, "gemini-3-pro") { + if isClaude || strings.Contains(baseModel, "gemini-3-pro") || strings.Contains(baseModel, "gemini-3.1-flash-image") { return e.executeClaudeNonStream(ctx, auth, req, opts) } @@ -220,7 +298,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) baseURLs := antigravityBaseURLFallbackOrder(auth) - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) attempts := antigravityRetryAttempts(auth, e.cfg) @@ -362,7 +440,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) baseURLs := antigravityBaseURLFallbackOrder(auth) - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) attempts := antigravityRetryAttempts(auth, e.cfg) @@ -754,7 +832,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel) baseURLs := antigravityBaseURLFallbackOrder(auth) - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) attempts := antigravityRetryAttempts(auth, e.cfg) @@ -956,7 +1034,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut payload = deleteJSONField(payload, "request.safetySettings") baseURLs := antigravityBaseURLFallbackOrder(auth) - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) var authID, authLabel, authType, authValue string if auth != nil { @@ -987,10 +1065,10 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut if errReq != nil { return cliproxyexecutor.Response{}, errReq } + httpReq.Close = true httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+token) httpReq.Header.Set("User-Agent", resolveUserAgent(auth)) - httpReq.Header.Set("Accept", "application/json") if host := resolveHost(base); host != "" { httpReq.Host = host } @@ -1084,14 +1162,26 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c } baseURLs := antigravityBaseURLFallbackOrder(auth) - httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0) + httpClient := newAntigravityHTTPClient(ctx, cfg, auth, 0) for idx, baseURL := range baseURLs { modelsURL := baseURL + antigravityModelsPath - httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`))) + + var payload []byte + if auth != nil && auth.Metadata != nil { + if pid, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(pid) != "" { + payload = []byte(fmt.Sprintf(`{"project": "%s"}`, strings.TrimSpace(pid))) + } + } + if len(payload) == 0 { + payload = []byte(`{}`) + } + + httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload)) if errReq != nil { return fallbackAntigravityPrimaryModels() } + httpReq.Close = true httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+token) httpReq.Header.Set("User-Agent", resolveUserAgent(auth)) @@ -1152,7 +1242,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c continue } switch modelID { - case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro": + case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro": continue } modelCfg := modelConfig[modelID] @@ -1174,6 +1264,29 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c OwnedBy: antigravityAuthType, Type: antigravityAuthType, } + + // Build input modalities from upstream capability flags. + inputModalities := []string{"TEXT"} + if modelData.Get("supportsImages").Bool() { + inputModalities = append(inputModalities, "IMAGE") + } + if modelData.Get("supportsVideo").Bool() { + inputModalities = append(inputModalities, "VIDEO") + } + modelInfo.SupportedInputModalities = inputModalities + modelInfo.SupportedOutputModalities = []string{"TEXT"} + + // Token limits from upstream. + if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 { + modelInfo.InputTokenLimit = int(maxTok) + } + if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 { + modelInfo.OutputTokenLimit = int(maxOut) + } + + // Supported generation methods (Gemini v1beta convention). + modelInfo.SupportedGenerationMethods = []string{"generateContent", "countTokens"} + // Look up Thinking support from static config using upstream model name. if modelCfg != nil { if modelCfg.Thinking != nil { @@ -1241,10 +1354,11 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau return auth, errReq } httpReq.Header.Set("Host", "oauth2.googleapis.com") - httpReq.Header.Set("User-Agent", defaultAntigravityAgent) httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + // Real Antigravity uses Go's default User-Agent for OAuth token refresh + httpReq.Header.Set("User-Agent", "Go-http-client/2.0") - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) httpResp, errDo := httpClient.Do(httpReq) if errDo != nil { return auth, errDo @@ -1315,7 +1429,7 @@ func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, au return nil } - httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0) projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient) if errFetch != nil { return errFetch @@ -1369,7 +1483,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau payload = geminiToAntigravity(modelName, payload, projectID) payload, _ = sjson.SetBytes(payload, "model", modelName) - useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") + useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro") payloadStr := string(payload) paths := make([]string, 0) util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths) @@ -1383,18 +1497,18 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau payloadStr = util.CleanJSONSchemaForGemini(payloadStr) } - if useAntigravitySchema { - systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts") - payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user") - payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction) - payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction)) + // if useAntigravitySchema { + // systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts") + // payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user") + // payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction) + // payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction)) - if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() { - for _, partResult := range systemInstructionPartsResult.Array() { - payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw) - } - } - } + // if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() { + // for _, partResult := range systemInstructionPartsResult.Array() { + // payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw) + // } + // } + // } if strings.Contains(modelName, "claude") { payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") @@ -1406,14 +1520,10 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau if errReq != nil { return nil, errReq } + httpReq.Close = true httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+token) httpReq.Header.Set("User-Agent", resolveUserAgent(auth)) - if stream { - httpReq.Header.Set("Accept", "text/event-stream") - } else { - httpReq.Header.Set("Accept", "application/json") - } if host := resolveHost(base); host != "" { httpReq.Host = host } @@ -1625,7 +1735,16 @@ func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string { func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte { template, _ := sjson.Set(string(payload), "model", modelName) template, _ = sjson.Set(template, "userAgent", "antigravity") - template, _ = sjson.Set(template, "requestType", "agent") + + isImageModel := strings.Contains(modelName, "image") + + var reqType string + if isImageModel { + reqType = "image_gen" + } else { + reqType = "agent" + } + template, _ = sjson.Set(template, "requestType", reqType) // Use real project ID from auth if available, otherwise generate random (legacy fallback) if projectID != "" { @@ -1633,8 +1752,13 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b } else { template, _ = sjson.Set(template, "project", generateProjectID()) } - template, _ = sjson.Set(template, "requestId", generateRequestID()) - template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload)) + + if isImageModel { + template, _ = sjson.Set(template, "requestId", generateImageGenRequestID()) + } else { + template, _ = sjson.Set(template, "requestId", generateRequestID()) + template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload)) + } template, _ = sjson.Delete(template, "request.safetySettings") if toolConfig := gjson.Get(template, "toolConfig"); toolConfig.Exists() && !gjson.Get(template, "request.toolConfig").Exists() { @@ -1648,6 +1772,10 @@ func generateRequestID() string { return "agent-" + uuid.NewString() } +func generateImageGenRequestID() string { + return fmt.Sprintf("image_gen/%d/%s/12", time.Now().UnixMilli(), uuid.NewString()) +} + func generateSessionID() string { randSourceMutex.Lock() n := randSource.Int63n(9_000_000_000_000_000_000) diff --git a/internal/runtime/executor/antigravity_executor_buildrequest_test.go b/internal/runtime/executor/antigravity_executor_buildrequest_test.go index c5cba4ee..27dbeca4 100644 --- a/internal/runtime/executor/antigravity_executor_buildrequest_test.go +++ b/internal/runtime/executor/antigravity_executor_buildrequest_test.go @@ -59,6 +59,7 @@ func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any "properties": { "mode": { "type": "string", + "deprecated": true, "enum": ["a", "b"], "enumTitles": ["A", "B"] } @@ -156,4 +157,7 @@ func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]a if _, ok := mode["enumTitles"]; ok { t.Fatalf("enumTitles should be removed from nested schema") } + if _, ok := mode["deprecated"]; ok { + t.Fatalf("deprecated should be removed from nested schema") + } } diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index fcb3a9c9..7d0ddcf2 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -9,9 +9,11 @@ import ( "crypto/rand" "crypto/sha256" "encoding/hex" + "encoding/json" "fmt" "io" "net/http" + "net/textproto" "runtime" "strings" "time" @@ -135,6 +137,15 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r body = ensureCacheControl(body) } + // Enforce Anthropic's cache_control block limit (max 4 breakpoints per request). + // Cloaking and ensureCacheControl may push the total over 4 when the client + // (e.g. Amp CLI) already sends multiple cache_control blocks. + body = enforceCacheControlLimit(body, 4) + + // Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05. + // A 1h-TTL block must not appear after a 5m-TTL block in evaluation order (tools→system→messages). + body = normalizeCacheControlTTL(body) + // Extract betas from body and convert to header var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) @@ -176,11 +187,27 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) + // Decompress error responses — pass the Content-Encoding value (may be empty) + // and let decodeResponseBody handle both header-declared and magic-byte-detected + // compression. This keeps error-path behaviour consistent with the success path. + errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding")) + if decErr != nil { + recordAPIResponseError(ctx, e.cfg, decErr) + msg := fmt.Sprintf("failed to decode error response body: %v", decErr) + logWithRequestID(ctx).Warn(msg) + return resp, statusErr{code: httpResp.StatusCode, msg: msg} + } + b, readErr := io.ReadAll(errBody) + if readErr != nil { + recordAPIResponseError(ctx, e.cfg, readErr) + msg := fmt.Sprintf("failed to read error response body: %v", readErr) + logWithRequestID(ctx).Warn(msg) + b = []byte(msg) + } appendAPIResponseChunk(ctx, e.cfg, b) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) err = statusErr{code: httpResp.StatusCode, msg: string(b)} - if errClose := httpResp.Body.Close(); errClose != nil { + if errClose := errBody.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) } return resp, err @@ -276,6 +303,12 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A body = ensureCacheControl(body) } + // Enforce Anthropic's cache_control block limit (max 4 breakpoints per request). + body = enforceCacheControlLimit(body, 4) + + // Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05. + body = normalizeCacheControlTTL(body) + // Extract betas from body and convert to header var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) @@ -317,10 +350,26 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - b, _ := io.ReadAll(httpResp.Body) + // Decompress error responses — pass the Content-Encoding value (may be empty) + // and let decodeResponseBody handle both header-declared and magic-byte-detected + // compression. This keeps error-path behaviour consistent with the success path. + errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding")) + if decErr != nil { + recordAPIResponseError(ctx, e.cfg, decErr) + msg := fmt.Sprintf("failed to decode error response body: %v", decErr) + logWithRequestID(ctx).Warn(msg) + return nil, statusErr{code: httpResp.StatusCode, msg: msg} + } + b, readErr := io.ReadAll(errBody) + if readErr != nil { + recordAPIResponseError(ctx, e.cfg, readErr) + msg := fmt.Sprintf("failed to read error response body: %v", readErr) + logWithRequestID(ctx).Warn(msg) + b = []byte(msg) + } appendAPIResponseChunk(ctx, e.cfg, b) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - if errClose := httpResp.Body.Close(); errClose != nil { + if errClose := errBody.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) } err = statusErr{code: httpResp.StatusCode, msg: string(b)} @@ -425,6 +474,10 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut body = checkSystemInstructions(body) } + // Keep count_tokens requests compatible with Anthropic cache-control constraints too. + body = enforceCacheControlLimit(body, 4) + body = normalizeCacheControlTTL(body) + // Extract betas from body and convert to header (for count_tokens too) var extraBetas []string extraBetas, body = extractAndRemoveBetas(body) @@ -464,9 +517,25 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut } recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone()) if resp.StatusCode < 200 || resp.StatusCode >= 300 { - b, _ := io.ReadAll(resp.Body) + // Decompress error responses — pass the Content-Encoding value (may be empty) + // and let decodeResponseBody handle both header-declared and magic-byte-detected + // compression. This keeps error-path behaviour consistent with the success path. + errBody, decErr := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding")) + if decErr != nil { + recordAPIResponseError(ctx, e.cfg, decErr) + msg := fmt.Sprintf("failed to decode error response body: %v", decErr) + logWithRequestID(ctx).Warn(msg) + return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg} + } + b, readErr := io.ReadAll(errBody) + if readErr != nil { + recordAPIResponseError(ctx, e.cfg, readErr) + msg := fmt.Sprintf("failed to read error response body: %v", readErr) + logWithRequestID(ctx).Warn(msg) + b = []byte(msg) + } appendAPIResponseChunk(ctx, e.cfg, b) - if errClose := resp.Body.Close(); errClose != nil { + if errClose := errBody.Close(); errClose != nil { log.Errorf("response body close error: %v", errClose) } return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)} @@ -559,6 +628,12 @@ func disableThinkingIfToolChoiceForced(body []byte) []byte { if toolChoiceType == "any" || toolChoiceType == "tool" { // Remove thinking configuration entirely to avoid API error body, _ = sjson.DeleteBytes(body, "thinking") + // Adaptive thinking may also set output_config.effort; remove it to avoid + // leaking thinking controls when tool_choice forces tool use. + body, _ = sjson.DeleteBytes(body, "output_config.effort") + if oc := gjson.GetBytes(body, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 { + body, _ = sjson.DeleteBytes(body, "output_config") + } } return body } @@ -581,12 +656,61 @@ func (c *compositeReadCloser) Close() error { return firstErr } +// peekableBody wraps a bufio.Reader around the original ReadCloser so that +// magic bytes can be inspected without consuming them from the stream. +type peekableBody struct { + *bufio.Reader + closer io.Closer +} + +func (p *peekableBody) Close() error { + return p.closer.Close() +} + func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadCloser, error) { if body == nil { return nil, fmt.Errorf("response body is nil") } if contentEncoding == "" { - return body, nil + // No Content-Encoding header. Attempt best-effort magic-byte detection to + // handle misbehaving upstreams that compress without setting the header. + // Only gzip (1f 8b) and zstd (28 b5 2f fd) have reliable magic sequences; + // br and deflate have none and are left as-is. + // The bufio wrapper preserves unread bytes so callers always see the full + // stream regardless of whether decompression was applied. + pb := &peekableBody{Reader: bufio.NewReader(body), closer: body} + magic, peekErr := pb.Peek(4) + if peekErr == nil || (peekErr == io.EOF && len(magic) >= 2) { + switch { + case len(magic) >= 2 && magic[0] == 0x1f && magic[1] == 0x8b: + gzipReader, gzErr := gzip.NewReader(pb) + if gzErr != nil { + _ = pb.Close() + return nil, fmt.Errorf("magic-byte gzip: failed to create reader: %w", gzErr) + } + return &compositeReadCloser{ + Reader: gzipReader, + closers: []func() error{ + gzipReader.Close, + pb.Close, + }, + }, nil + case len(magic) >= 4 && magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd: + decoder, zdErr := zstd.NewReader(pb) + if zdErr != nil { + _ = pb.Close() + return nil, fmt.Errorf("magic-byte zstd: failed to create reader: %w", zdErr) + } + return &compositeReadCloser{ + Reader: decoder, + closers: []func() error{ + func() error { decoder.Close(); return nil }, + pb.Close, + }, + }, nil + } + } + return pb, nil } encodings := strings.Split(contentEncoding, ",") for _, raw := range encodings { @@ -709,11 +833,21 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, } } - // Merge extra betas from request body - if len(extraBetas) > 0 { + hasClaude1MHeader := false + if ginHeaders != nil { + if _, ok := ginHeaders[textproto.CanonicalMIMEHeaderKey("X-CPA-CLAUDE-1M")]; ok { + hasClaude1MHeader = true + } + } + + // Merge extra betas from request body and request flags. + if len(extraBetas) > 0 || hasClaude1MHeader { existingSet := make(map[string]bool) for _, b := range strings.Split(baseBetas, ",") { - existingSet[strings.TrimSpace(b)] = true + betaName := strings.TrimSpace(b) + if betaName != "" { + existingSet[betaName] = true + } } for _, beta := range extraBetas { beta = strings.TrimSpace(beta) @@ -722,6 +856,9 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, existingSet[beta] = true } } + if hasClaude1MHeader && !existingSet["context-1m-2025-08-07"] { + baseBetas += ",context-1m-2025-08-07" + } } r.Header.Set("Anthropic-Beta", baseBetas) @@ -750,11 +887,15 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)")) } r.Header.Set("Connection", "keep-alive") - r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd") if stream { r.Header.Set("Accept", "text/event-stream") + // SSE streams must not be compressed: the downstream scanner reads + // line-delimited text and cannot parse compressed bytes. Using + // "identity" tells the upstream to send an uncompressed stream. + r.Header.Set("Accept-Encoding", "identity") } else { r.Header.Set("Accept", "application/json") + r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd") } // Keep OS/Arch mapping dynamic (not configurable). // They intentionally continue to derive from runtime.GOOS/runtime.GOARCH. @@ -763,6 +904,12 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, attrs = auth.Attributes } util.ApplyCustomHeadersFromAttrs(r, attrs) + // Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which + // may override it with a user-configured value. Compressed SSE breaks the line + // scanner regardless of user preference, so this is non-negotiable for streams. + if stream { + r.Header.Set("Accept-Encoding", "identity") + } } func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) { @@ -1073,17 +1220,22 @@ func generateBillingHeader(payload []byte) string { return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", buildHash, cch) } -// checkSystemInstructionsWithMode injects Claude Code system prompt to match -// the real Claude Code request format: -// system[0]: billing header (no cache_control) -// system[1]: "You are a Claude agent, built on Anthropic's Claude Agent SDK." (with cache_control) -// system[2..]: user's system messages (with cache_control on last) +// checkSystemInstructionsWithMode injects Claude Code-style system blocks: +// +// system[0]: billing header (no cache_control) +// system[1]: agent identifier (no cache_control) +// system[2..]: user system messages (cache_control added when missing) func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte { system := gjson.GetBytes(payload, "system") billingText := generateBillingHeader(payload) billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, billingText) - agentBlock := `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK.","cache_control":{"type":"ephemeral","ttl":"1h"}}` + // No cache_control on the agent block. It is a cloaking artifact with zero cache + // value (the last system block is what actually triggers caching of all system content). + // Including any cache_control here creates an intra-system TTL ordering violation + // when the client's system blocks use ttl='1h' (prompt-caching-scope-2026-01-05 beta + // forbids 1h blocks after 5m blocks, and a no-TTL block defaults to 5m). + agentBlock := `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK."}` if strictMode { // Strict mode: billing header + agent identifier only @@ -1103,11 +1255,12 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte { if system.IsArray() { system.ForEach(func(_, part gjson.Result) bool { if part.Get("type").String() == "text" { - // Add cache_control with ttl to user system messages if not present + // Add cache_control to user system messages if not present. + // Do NOT add ttl — let it inherit the default (5m) to avoid + // TTL ordering violations with the prompt-caching-scope-2026-01-05 beta. partJSON := part.Raw if !part.Get("cache_control").Exists() { partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral") - partJSON, _ = sjson.Set(partJSON, "cache_control.ttl", "1h") } result += "," + partJSON } @@ -1254,6 +1407,313 @@ func countCacheControls(payload []byte) int { return count } +func parsePayloadObject(payload []byte) (map[string]any, bool) { + if len(payload) == 0 { + return nil, false + } + var root map[string]any + if err := json.Unmarshal(payload, &root); err != nil { + return nil, false + } + return root, true +} + +func marshalPayloadObject(original []byte, root map[string]any) []byte { + if root == nil { + return original + } + out, err := json.Marshal(root) + if err != nil { + return original + } + return out +} + +func asObject(v any) (map[string]any, bool) { + obj, ok := v.(map[string]any) + return obj, ok +} + +func asArray(v any) ([]any, bool) { + arr, ok := v.([]any) + return arr, ok +} + +func countCacheControlsMap(root map[string]any) int { + count := 0 + + if system, ok := asArray(root["system"]); ok { + for _, item := range system { + if obj, ok := asObject(item); ok { + if _, exists := obj["cache_control"]; exists { + count++ + } + } + } + } + + if tools, ok := asArray(root["tools"]); ok { + for _, item := range tools { + if obj, ok := asObject(item); ok { + if _, exists := obj["cache_control"]; exists { + count++ + } + } + } + } + + if messages, ok := asArray(root["messages"]); ok { + for _, msg := range messages { + msgObj, ok := asObject(msg) + if !ok { + continue + } + content, ok := asArray(msgObj["content"]) + if !ok { + continue + } + for _, item := range content { + if obj, ok := asObject(item); ok { + if _, exists := obj["cache_control"]; exists { + count++ + } + } + } + } + } + + return count +} + +func normalizeTTLForBlock(obj map[string]any, seen5m *bool) { + ccRaw, exists := obj["cache_control"] + if !exists { + return + } + cc, ok := asObject(ccRaw) + if !ok { + *seen5m = true + return + } + ttlRaw, ttlExists := cc["ttl"] + ttl, ttlIsString := ttlRaw.(string) + if !ttlExists || !ttlIsString || ttl != "1h" { + *seen5m = true + return + } + if *seen5m { + delete(cc, "ttl") + } +} + +func findLastCacheControlIndex(arr []any) int { + last := -1 + for idx, item := range arr { + obj, ok := asObject(item) + if !ok { + continue + } + if _, exists := obj["cache_control"]; exists { + last = idx + } + } + return last +} + +func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) { + for idx, item := range arr { + if *excess <= 0 { + return + } + obj, ok := asObject(item) + if !ok { + continue + } + if _, exists := obj["cache_control"]; exists && idx != preserveIdx { + delete(obj, "cache_control") + *excess-- + } + } +} + +func stripAllCacheControl(arr []any, excess *int) { + for _, item := range arr { + if *excess <= 0 { + return + } + obj, ok := asObject(item) + if !ok { + continue + } + if _, exists := obj["cache_control"]; exists { + delete(obj, "cache_control") + *excess-- + } + } +} + +func stripMessageCacheControl(messages []any, excess *int) { + for _, msg := range messages { + if *excess <= 0 { + return + } + msgObj, ok := asObject(msg) + if !ok { + continue + } + content, ok := asArray(msgObj["content"]) + if !ok { + continue + } + for _, item := range content { + if *excess <= 0 { + return + } + obj, ok := asObject(item) + if !ok { + continue + } + if _, exists := obj["cache_control"]; exists { + delete(obj, "cache_control") + *excess-- + } + } + } +} + +// normalizeCacheControlTTL ensures cache_control TTL values don't violate the +// prompt-caching-scope-2026-01-05 ordering constraint: a 1h-TTL block must not +// appear after a 5m-TTL block anywhere in the evaluation order. +// +// Anthropic evaluates blocks in order: tools → system (index 0..N) → messages. +// Within each section, blocks are evaluated in array order. A 5m (default) block +// followed by a 1h block at ANY later position is an error — including within +// the same section (e.g. system[1]=5m then system[3]=1h). +// +// Strategy: walk all cache_control blocks in evaluation order. Once a 5m block +// is seen, strip ttl from ALL subsequent 1h blocks (downgrading them to 5m). +func normalizeCacheControlTTL(payload []byte) []byte { + root, ok := parsePayloadObject(payload) + if !ok { + return payload + } + + seen5m := false + + if tools, ok := asArray(root["tools"]); ok { + for _, tool := range tools { + if obj, ok := asObject(tool); ok { + normalizeTTLForBlock(obj, &seen5m) + } + } + } + + if system, ok := asArray(root["system"]); ok { + for _, item := range system { + if obj, ok := asObject(item); ok { + normalizeTTLForBlock(obj, &seen5m) + } + } + } + + if messages, ok := asArray(root["messages"]); ok { + for _, msg := range messages { + msgObj, ok := asObject(msg) + if !ok { + continue + } + content, ok := asArray(msgObj["content"]) + if !ok { + continue + } + for _, item := range content { + if obj, ok := asObject(item); ok { + normalizeTTLForBlock(obj, &seen5m) + } + } + } + } + + return marshalPayloadObject(payload, root) +} + +// enforceCacheControlLimit removes excess cache_control blocks from a payload +// so the total does not exceed the Anthropic API limit (currently 4). +// +// Anthropic evaluates cache breakpoints in order: tools → system → messages. +// The most valuable breakpoints are: +// 1. Last tool — caches ALL tool definitions +// 2. Last system block — caches ALL system content +// 3. Recent messages — cache conversation context +// +// Removal priority (strip lowest-value first): +// +// Phase 1: system blocks earliest-first, preserving the last one. +// Phase 2: tool blocks earliest-first, preserving the last one. +// Phase 3: message content blocks earliest-first. +// Phase 4: remaining system blocks (last system). +// Phase 5: remaining tool blocks (last tool). +func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte { + root, ok := parsePayloadObject(payload) + if !ok { + return payload + } + + total := countCacheControlsMap(root) + if total <= maxBlocks { + return payload + } + + excess := total - maxBlocks + + var system []any + if arr, ok := asArray(root["system"]); ok { + system = arr + } + var tools []any + if arr, ok := asArray(root["tools"]); ok { + tools = arr + } + var messages []any + if arr, ok := asArray(root["messages"]); ok { + messages = arr + } + + if len(system) > 0 { + stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess) + } + if excess <= 0 { + return marshalPayloadObject(payload, root) + } + + if len(tools) > 0 { + stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess) + } + if excess <= 0 { + return marshalPayloadObject(payload, root) + } + + if len(messages) > 0 { + stripMessageCacheControl(messages, &excess) + } + if excess <= 0 { + return marshalPayloadObject(payload, root) + } + + if len(system) > 0 { + stripAllCacheControl(system, &excess) + } + if excess <= 0 { + return marshalPayloadObject(payload, root) + } + + if len(tools) > 0 { + stripAllCacheControl(tools, &excess) + } + + return marshalPayloadObject(payload, root) +} + // injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching. // Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache." // This enables caching of conversation history, which is especially beneficial for long multi-turn conversations. diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index dd29ed8a..c4a4d644 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -2,12 +2,15 @@ package executor import ( "bytes" + "compress/gzip" "context" "io" "net/http" "net/http/httptest" + "strings" "testing" + "github.com/klauspost/compress/zstd" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -348,3 +351,619 @@ func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) { t.Fatalf("built-in tool_reference should not be prefixed, got %q", got) } } + +func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) { + payload := []byte(`{ + "tools": [{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}], + "system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}], + "messages": [{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}] + }`) + + out := normalizeCacheControlTTL(payload) + + if got := gjson.GetBytes(out, "tools.0.cache_control.ttl").String(); got != "1h" { + t.Fatalf("tools.0.cache_control.ttl = %q, want %q", got, "1h") + } + if gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl").Exists() { + t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block") + } +} + +func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) { + payload := []byte(`{ + "tools": [ + {"name":"t1","cache_control":{"type":"ephemeral"}}, + {"name":"t2","cache_control":{"type":"ephemeral"}} + ], + "system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}], + "messages": [ + {"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}}]}, + {"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]} + ] + }`) + + out := enforceCacheControlLimit(payload, 4) + + if got := countCacheControls(out); got != 4 { + t.Fatalf("cache_control count = %d, want 4", got) + } + if gjson.GetBytes(out, "tools.0.cache_control").Exists() { + t.Fatalf("tools.0.cache_control should be removed first (non-last tool)") + } + if !gjson.GetBytes(out, "tools.1.cache_control").Exists() { + t.Fatalf("tools.1.cache_control (last tool) should be preserved") + } + if !gjson.GetBytes(out, "messages.0.content.0.cache_control").Exists() || !gjson.GetBytes(out, "messages.1.content.0.cache_control").Exists() { + t.Fatalf("message cache_control blocks should be preserved when non-last tool removal is enough") + } +} + +func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) { + payload := []byte(`{ + "tools": [ + {"name":"t1","cache_control":{"type":"ephemeral"}}, + {"name":"t2","cache_control":{"type":"ephemeral"}}, + {"name":"t3","cache_control":{"type":"ephemeral"}}, + {"name":"t4","cache_control":{"type":"ephemeral"}}, + {"name":"t5","cache_control":{"type":"ephemeral"}} + ] + }`) + + out := enforceCacheControlLimit(payload, 4) + + if got := countCacheControls(out); got != 4 { + t.Fatalf("cache_control count = %d, want 4", got) + } + if gjson.GetBytes(out, "tools.0.cache_control").Exists() { + t.Fatalf("tools.0.cache_control should be removed to satisfy max=4") + } + if !gjson.GetBytes(out, "tools.4.cache_control").Exists() { + t.Fatalf("last tool cache_control should be preserved when possible") + } +} + +func TestClaudeExecutor_CountTokens_AppliesCacheControlGuards(t *testing.T) { + var seenBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + seenBody = bytes.Clone(body) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"input_tokens":42}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + + payload := []byte(`{ + "tools": [ + {"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}, + {"name":"t2","cache_control":{"type":"ephemeral"}} + ], + "system": [ + {"type":"text","text":"s1","cache_control":{"type":"ephemeral","ttl":"1h"}}, + {"type":"text","text":"s2","cache_control":{"type":"ephemeral","ttl":"1h"}} + ], + "messages": [ + {"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}, + {"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral","ttl":"1h"}}]} + ] + }`) + + _, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-haiku-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + if err != nil { + t.Fatalf("CountTokens error: %v", err) + } + + if len(seenBody) == 0 { + t.Fatal("expected count_tokens request body to be captured") + } + if got := countCacheControls(seenBody); got > 4 { + t.Fatalf("count_tokens body has %d cache_control blocks, want <= 4", got) + } + if hasTTLOrderingViolation(seenBody) { + t.Fatalf("count_tokens body still has ttl ordering violations: %s", string(seenBody)) + } +} + +func hasTTLOrderingViolation(payload []byte) bool { + seen5m := false + violates := false + + checkCC := func(cc gjson.Result) { + if !cc.Exists() || violates { + return + } + ttl := cc.Get("ttl").String() + if ttl != "1h" { + seen5m = true + return + } + if seen5m { + violates = true + } + } + + tools := gjson.GetBytes(payload, "tools") + if tools.IsArray() { + tools.ForEach(func(_, tool gjson.Result) bool { + checkCC(tool.Get("cache_control")) + return !violates + }) + } + + system := gjson.GetBytes(payload, "system") + if system.IsArray() { + system.ForEach(func(_, item gjson.Result) bool { + checkCC(item.Get("cache_control")) + return !violates + }) + } + + messages := gjson.GetBytes(payload, "messages") + if messages.IsArray() { + messages.ForEach(func(_, msg gjson.Result) bool { + content := msg.Get("content") + if content.IsArray() { + content.ForEach(func(_, item gjson.Result) bool { + checkCC(item.Get("cache_control")) + return !violates + }) + } + return !violates + }) + } + + return violates +} + +func TestClaudeExecutor_Execute_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) { + testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error { + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + return err + }) +} + +func TestClaudeExecutor_ExecuteStream_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) { + testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error { + _, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + return err + }) +} + +func TestClaudeExecutor_CountTokens_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) { + testClaudeExecutorInvalidCompressedErrorBody(t, func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error { + _, err := executor.CountTokens(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}) + return err + }) +} + +func testClaudeExecutorInvalidCompressedErrorBody( + t *testing.T, + invoke func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error, +) { + t.Helper() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Encoding", "gzip") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("not-a-valid-gzip-stream")) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + err := invoke(executor, auth, payload) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "failed to decode error response body") { + t.Fatalf("expected decode failure message, got: %v", err) + } + if statusProvider, ok := err.(interface{ StatusCode() int }); !ok || statusProvider.StatusCode() != http.StatusBadRequest { + t.Fatalf("expected status code 400, got: %v", err) + } +} + +// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming +// requests use Accept-Encoding: identity so the upstream cannot respond with a +// compressed SSE body that would silently break the line scanner. +func TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding(t *testing.T) { + var gotEncoding, gotAccept string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotEncoding = r.Header.Get("Accept-Encoding") + gotAccept = r.Header.Get("Accept") + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n")) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected chunk error: %v", chunk.Err) + } + } + + if gotEncoding != "identity" { + t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "identity") + } + if gotAccept != "text/event-stream" { + t.Errorf("Accept = %q, want %q", gotAccept, "text/event-stream") + } +} + +// TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding verifies that non-streaming +// requests keep the full accept-encoding to allow response compression (which +// decodeResponseBody handles correctly). +func TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding(t *testing.T) { + var gotEncoding, gotAccept string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotEncoding = r.Header.Get("Accept-Encoding") + gotAccept = r.Header.Get("Accept") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet-20241022","role":"assistant","content":[{"type":"text","text":"hi"}],"usage":{"input_tokens":1,"output_tokens":1}}`)) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err != nil { + t.Fatalf("Execute error: %v", err) + } + + if gotEncoding != "gzip, deflate, br, zstd" { + t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "gzip, deflate, br, zstd") + } + if gotAccept != "application/json" { + t.Errorf("Accept = %q, want %q", gotAccept, "application/json") + } +} + +// TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded verifies that a streaming +// HTTP 200 response with Content-Encoding: gzip is correctly decompressed before +// the line scanner runs, so SSE chunks are not silently dropped. +func TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded(t *testing.T) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n")) + _ = gz.Close() + compressedBody := buf.Bytes() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Content-Encoding", "gzip") + _, _ = w.Write(compressedBody) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + + var combined strings.Builder + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("chunk error: %v", chunk.Err) + } + combined.Write(chunk.Payload) + } + + if combined.Len() == 0 { + t.Fatal("expected at least one chunk from gzip-encoded SSE body, got none (body was not decompressed)") + } + if !strings.Contains(combined.String(), "message_stop") { + t.Errorf("expected SSE content in chunks, got: %q", combined.String()) + } +} + +// TestDecodeResponseBody_MagicByteGzipNoHeader verifies that decodeResponseBody +// detects gzip-compressed content via magic bytes even when Content-Encoding is absent. +func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) { + const plaintext = "data: {\"type\":\"message_stop\"}\n" + + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte(plaintext)) + _ = gz.Close() + + rc := io.NopCloser(&buf) + decoded, err := decodeResponseBody(rc, "") + if err != nil { + t.Fatalf("decodeResponseBody error: %v", err) + } + defer decoded.Close() + + got, err := io.ReadAll(decoded) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if string(got) != plaintext { + t.Errorf("decoded = %q, want %q", got, plaintext) + } +} + +// TestDecodeResponseBody_PlainTextNoHeader verifies that decodeResponseBody returns +// plain text untouched when Content-Encoding is absent and no magic bytes match. +func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) { + const plaintext = "data: {\"type\":\"message_stop\"}\n" + rc := io.NopCloser(strings.NewReader(plaintext)) + decoded, err := decodeResponseBody(rc, "") + if err != nil { + t.Fatalf("decodeResponseBody error: %v", err) + } + defer decoded.Close() + + got, err := io.ReadAll(decoded) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if string(got) != plaintext { + t.Errorf("decoded = %q, want %q", got, plaintext) + } +} + +// TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader verifies the full +// pipeline: when the upstream returns a gzip-compressed SSE body WITHOUT setting +// Content-Encoding (a misbehaving upstream), the magic-byte sniff in +// decodeResponseBody still decompresses it, so chunks reach the caller. +func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T) { + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte("data: {\"type\":\"message_stop\"}\n")) + _ = gz.Close() + compressedBody := buf.Bytes() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + // Intentionally omit Content-Encoding to simulate misbehaving upstream. + _, _ = w.Write(compressedBody) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + + var combined strings.Builder + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("chunk error: %v", chunk.Err) + } + combined.Write(chunk.Payload) + } + + if combined.Len() == 0 { + t.Fatal("expected chunks from gzip body without Content-Encoding header, got none (magic-byte sniff failed)") + } + if !strings.Contains(combined.String(), "message_stop") { + t.Errorf("unexpected chunk content: %q", combined.String()) + } +} + +// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies +// that injecting Accept-Encoding via auth.Attributes cannot override the stream +// path's enforced identity encoding. +func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) { + var gotEncoding string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotEncoding = r.Header.Get("Accept-Encoding") + w.Header().Set("Content-Type", "text/event-stream") + _, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n")) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + // Inject Accept-Encoding via the custom header attribute mechanism. + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + "header:Accept-Encoding": "gzip, deflate, br, zstd", + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + for chunk := range result.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected chunk error: %v", chunk.Err) + } + } + + if gotEncoding != "identity" { + t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding) + } +} + +// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody +// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when +// Content-Encoding is absent. +func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) { + const plaintext = "data: {\"type\":\"message_stop\"}\n" + + var buf bytes.Buffer + enc, err := zstd.NewWriter(&buf) + if err != nil { + t.Fatalf("zstd.NewWriter: %v", err) + } + _, _ = enc.Write([]byte(plaintext)) + _ = enc.Close() + + rc := io.NopCloser(&buf) + decoded, err := decodeResponseBody(rc, "") + if err != nil { + t.Fatalf("decodeResponseBody error: %v", err) + } + defer decoded.Close() + + got, err := io.ReadAll(decoded) + if err != nil { + t.Fatalf("ReadAll error: %v", err) + } + if string(got) != plaintext { + t.Errorf("decoded = %q, want %q", got, plaintext) + } +} + +// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the +// error path (4xx) correctly decompresses a gzip body even when the upstream omits +// the Content-Encoding header. This closes the gap left by PR #1771, which only +// fixed header-declared compression on the error path. +func TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader(t *testing.T) { + const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"test error"}}` + + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte(errJSON)) + _ = gz.Close() + compressedBody := buf.Bytes() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + // Intentionally omit Content-Encoding to simulate misbehaving upstream. + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write(compressedBody) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err == nil { + t.Fatal("expected an error for 400 response, got nil") + } + if !strings.Contains(err.Error(), "test error") { + t.Errorf("error message should contain decompressed JSON, got: %q", err.Error()) + } +} + +// TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader verifies +// the same for the streaming executor: 4xx gzip body without Content-Encoding is +// decoded and the error message is readable. +func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *testing.T) { + const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"stream test error"}}` + + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + _, _ = gz.Write([]byte(errJSON)) + _ = gz.Close() + compressedBody := buf.Bytes() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + // Intentionally omit Content-Encoding to simulate misbehaving upstream. + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write(compressedBody) + })) + defer server.Close() + + executor := NewClaudeExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{ + "api_key": "key-123", + "base_url": server.URL, + }} + payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) + + _, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{ + Model: "claude-3-5-sonnet-20241022", + Payload: payload, + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("claude"), + }) + if err == nil { + t.Fatal("expected an error for 400 response, got nil") + } + if !strings.Contains(err.Error(), "stream test error") { + t.Errorf("error message should contain decompressed JSON, got: %q", err.Error()) + } +} diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index a0cbc0d5..30092ec7 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -616,6 +616,10 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form if promptCacheKey.Exists() { cache.ID = promptCacheKey.String() } + } else if from == "openai" { + if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" { + cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String() + } } if cache.ID != "" { diff --git a/internal/runtime/executor/codex_executor_cache_test.go b/internal/runtime/executor/codex_executor_cache_test.go new file mode 100644 index 00000000..d6dca031 --- /dev/null +++ b/internal/runtime/executor/codex_executor_cache_test.go @@ -0,0 +1,64 @@ +package executor + +import ( + "context" + "io" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + "github.com/tidwall/gjson" +) + +func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFromAPIKey(t *testing.T) { + recorder := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(recorder) + ginCtx.Set("apiKey", "test-api-key") + + ctx := context.WithValue(context.Background(), "gin", ginCtx) + executor := &CodexExecutor{} + rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true}`) + req := cliproxyexecutor.Request{ + Model: "gpt-5.3-codex", + Payload: []byte(`{"model":"gpt-5.3-codex"}`), + } + url := "https://example.com/responses" + + httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON) + if err != nil { + t.Fatalf("cacheHelper error: %v", err) + } + + body, errRead := io.ReadAll(httpReq.Body) + if errRead != nil { + t.Fatalf("read request body: %v", errRead) + } + + expectedKey := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:test-api-key")).String() + gotKey := gjson.GetBytes(body, "prompt_cache_key").String() + if gotKey != expectedKey { + t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey) + } + if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey { + t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey) + } + if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey { + t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey) + } + + httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON) + if err != nil { + t.Fatalf("cacheHelper error (second call): %v", err) + } + body2, errRead2 := io.ReadAll(httpReq2.Body) + if errRead2 != nil { + t.Fatalf("read request body (second call): %v", errRead2) + } + gotKey2 := gjson.GetBytes(body2, "prompt_cache_key").String() + if gotKey2 != expectedKey { + t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey) + } +} diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 7c887221..1f340050 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -31,7 +31,7 @@ import ( ) const ( - codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04" + codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06" codexResponsesWebsocketIdleTimeout = 5 * time.Minute codexResponsesWebsocketHandshakeTO = 30 * time.Second ) @@ -57,11 +57,6 @@ type codexWebsocketSession struct { wsURL string authID string - // connCreateSent tracks whether a `response.create` message has been successfully sent - // on the current websocket connection. The upstream expects the first message on each - // connection to be `response.create`. - connCreateSent bool - writeMu sync.Mutex activeMu sync.Mutex @@ -212,13 +207,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut defer sess.reqMu.Unlock() } - allowAppend := true - if sess != nil { - sess.connMu.Lock() - allowAppend = sess.connCreateSent - sess.connMu.Unlock() - } - wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend) + wsReqBody := buildCodexWebsocketRequestBody(body) recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", @@ -280,10 +269,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut // execution session. connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) if errDialRetry == nil && connRetry != nil { - sess.connMu.Lock() - allowAppend = sess.connCreateSent - sess.connMu.Unlock() - wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend) + wsReqBodyRetry := buildCodexWebsocketRequestBody(body) recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", @@ -312,7 +298,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut return resp, errSend } } - markCodexWebsocketCreateSent(sess, conn, wsReqBody) for { if ctx != nil && ctx.Err() != nil { @@ -403,26 +388,20 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey) var authID, authLabel, authType, authValue string - if auth != nil { - authID = auth.ID - authLabel = auth.Label - authType, authValue = auth.AccountInfo() - } + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() executionSessionID := executionSessionIDFromOptions(opts) var sess *codexWebsocketSession if executionSessionID != "" { sess = e.getOrCreateSession(executionSessionID) - sess.reqMu.Lock() + if sess != nil { + sess.reqMu.Lock() + } } - allowAppend := true - if sess != nil { - sess.connMu.Lock() - allowAppend = sess.connCreateSent - sess.connMu.Unlock() - } - wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend) + wsReqBody := buildCodexWebsocketRequestBody(body) recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", @@ -483,10 +462,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr sess.reqMu.Unlock() return nil, errDialRetry } - sess.connMu.Lock() - allowAppend = sess.connCreateSent - sess.connMu.Unlock() - wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend) + wsReqBodyRetry := buildCodexWebsocketRequestBody(body) recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: wsURL, Method: "WEBSOCKET", @@ -515,7 +491,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr return nil, errSend } } - markCodexWebsocketCreateSent(sess, conn, wsReqBody) out := make(chan cliproxyexecutor.StreamChunk) go func() { @@ -657,31 +632,14 @@ func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Con return conn.WriteMessage(websocket.TextMessage, payload) } -func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte { +func buildCodexWebsocketRequestBody(body []byte) []byte { if len(body) == 0 { return nil } - // Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns. - // The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation). - // Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive. - // - // NOTE: The upstream expects the first websocket event on each connection to be `response.create`, - // so we only use `response.append` after we have initialized the current connection. - if allowAppend { - if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" { - inputNode := gjson.GetBytes(body, "input") - wsReqBody := []byte(`{}`) - wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append") - if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" { - wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw)) - return wsReqBody - } - wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]")) - return wsReqBody - } - } - + // Match codex-rs websocket v2 semantics: every request is `response.create`. + // Incremental follow-up turns continue on the same websocket using + // `previous_response_id` + incremental `input`, not `response.append`. wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create") if errSet == nil && len(wsReqBody) > 0 { return wsReqBody @@ -725,21 +683,6 @@ func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession, } } -func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) { - if sess == nil || conn == nil || len(payload) == 0 { - return - } - if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" { - return - } - - sess.connMu.Lock() - if sess.conn == conn { - sess.connCreateSent = true - } - sess.connMu.Unlock() -} - func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer { dialer := &websocket.Dialer{ Proxy: http.ProxyFromEnvironment, @@ -1017,36 +960,6 @@ func closeHTTPResponseBody(resp *http.Response, logPrefix string) { } } -func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} { - done := make(chan struct{}) - if ctx == nil || conn == nil { - return done - } - go func() { - select { - case <-done: - case <-ctx.Done(): - _ = conn.Close() - } - }() - return done -} - -func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} { - done := make(chan struct{}) - if ctx == nil || conn == nil { - return done - } - go func() { - select { - case <-done: - case <-ctx.Done(): - _ = conn.SetReadDeadline(time.Now()) - } - }() - return done -} - func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string { if len(opts.Metadata) == 0 { return "" @@ -1120,7 +1033,6 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth * sess.conn = conn sess.wsURL = wsURL sess.authID = authID - sess.connCreateSent = false sess.readerConn = conn sess.connMu.Unlock() @@ -1206,7 +1118,6 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes return } sess.conn = nil - sess.connCreateSent = false if sess.readerConn == conn { sess.readerConn = nil } @@ -1273,7 +1184,6 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess authID := sess.authID wsURL := sess.wsURL sess.conn = nil - sess.connCreateSent = false if sess.readerConn == conn { sess.readerConn = nil } diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go new file mode 100644 index 00000000..1fd68513 --- /dev/null +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -0,0 +1,36 @@ +package executor + +import ( + "context" + "net/http" + "testing" + + "github.com/tidwall/gjson" +) + +func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) { + body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`) + + wsReqBody := buildCodexWebsocketRequestBody(body) + + if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" { + t.Fatalf("type = %s, want response.create", got) + } + if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" { + t.Fatalf("previous_response_id = %s, want resp-1", got) + } + if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" { + t.Fatalf("input item id mismatch") + } + if got := gjson.GetBytes(wsReqBody, "type").String(); got == "response.append" { + t.Fatalf("unexpected websocket request type: %s", got) + } +} + +func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) { + headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "") + + if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue { + t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue) + } +} diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index cb3ffb59..1be245b7 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -16,7 +16,6 @@ import ( "strings" "time" - "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" @@ -81,7 +80,7 @@ func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth return statusErr{code: http.StatusUnauthorized, msg: "missing access token"} } req.Header.Set("Authorization", "Bearer "+tok.AccessToken) - applyGeminiCLIHeaders(req) + applyGeminiCLIHeaders(req, "unknown") return nil } @@ -189,7 +188,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth } reqHTTP.Header.Set("Content-Type", "application/json") reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) - applyGeminiCLIHeaders(reqHTTP) + applyGeminiCLIHeaders(reqHTTP, attemptModel) reqHTTP.Header.Set("Accept", "application/json") recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: url, @@ -334,7 +333,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut } reqHTTP.Header.Set("Content-Type", "application/json") reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) - applyGeminiCLIHeaders(reqHTTP) + applyGeminiCLIHeaders(reqHTTP, attemptModel) reqHTTP.Header.Set("Accept", "text/event-stream") recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: url, @@ -515,7 +514,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. } reqHTTP.Header.Set("Content-Type", "application/json") reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) - applyGeminiCLIHeaders(reqHTTP) + applyGeminiCLIHeaders(reqHTTP, baseModel) reqHTTP.Header.Set("Accept", "application/json") recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: url, @@ -738,21 +737,11 @@ func stringValue(m map[string]any, key string) string { } // applyGeminiCLIHeaders sets required headers for the Gemini CLI upstream. -func applyGeminiCLIHeaders(r *http.Request) { - var ginHeaders http.Header - if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil { - ginHeaders = ginCtx.Request.Header - } - - misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "google-api-nodejs-client/9.15.1") - misc.EnsureHeader(r.Header, ginHeaders, "X-Goog-Api-Client", "gl-node/22.17.0") - misc.EnsureHeader(r.Header, ginHeaders, "Client-Metadata", geminiCLIClientMetadata()) -} - -// geminiCLIClientMetadata returns a compact metadata string required by upstream. -func geminiCLIClientMetadata() string { - // Keep parity with CLI client defaults - return "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI" +// User-Agent is always forced to the GeminiCLI format regardless of the client's value, +// so that upstream identifies the request as a native GeminiCLI client. +func applyGeminiCLIHeaders(r *http.Request, model string) { + r.Header.Set("User-Agent", misc.GeminiCLIUserAgent(model)) + r.Header.Set("X-Goog-Api-Client", misc.GeminiCLIApiClientHeader) } // cliPreviewFallbackOrder returns preview model candidates for a base model. diff --git a/internal/thinking/apply.go b/internal/thinking/apply.go index 8a5a1d7d..b8a0fcae 100644 --- a/internal/thinking/apply.go +++ b/internal/thinking/apply.go @@ -293,7 +293,7 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri if config.Mode != ModeLevel { return config } - if !isBudgetBasedProvider(toFormat) || !isLevelBasedProvider(fromFormat) { + if !isBudgetCapableProvider(toFormat) { return config } budget, ok := ConvertLevelToBudget(string(config.Level)) @@ -353,6 +353,26 @@ func extractClaudeConfig(body []byte) ThinkingConfig { if thinkingType == "disabled" { return ThinkingConfig{Mode: ModeNone, Budget: 0} } + if thinkingType == "adaptive" || thinkingType == "auto" { + // Claude adaptive thinking uses output_config.effort (low/medium/high/max). + // We only treat it as a thinking config when effort is explicitly present; + // otherwise we passthrough and let upstream defaults apply. + if effort := gjson.GetBytes(body, "output_config.effort"); effort.Exists() && effort.Type == gjson.String { + value := strings.ToLower(strings.TrimSpace(effort.String())) + if value == "" { + return ThinkingConfig{} + } + switch value { + case "none": + return ThinkingConfig{Mode: ModeNone, Budget: 0} + case "auto": + return ThinkingConfig{Mode: ModeAuto, Budget: -1} + default: + return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)} + } + } + return ThinkingConfig{} + } // Check budget_tokens if budget := gjson.GetBytes(body, "thinking.budget_tokens"); budget.Exists() { diff --git a/internal/thinking/convert.go b/internal/thinking/convert.go index 776ccef6..89db7745 100644 --- a/internal/thinking/convert.go +++ b/internal/thinking/convert.go @@ -16,6 +16,9 @@ var levelToBudgetMap = map[string]int{ "medium": 8192, "high": 24576, "xhigh": 32768, + // "max" is used by Claude adaptive thinking effort. We map it to a large budget + // and rely on per-model clamping when converting to budget-only providers. + "max": 128000, } // ConvertLevelToBudget converts a thinking level to a budget value. @@ -31,6 +34,7 @@ var levelToBudgetMap = map[string]int{ // - medium → 8192 // - high → 24576 // - xhigh → 32768 +// - max → 128000 // // Returns: // - budget: The converted budget value @@ -92,6 +96,43 @@ func ConvertBudgetToLevel(budget int) (string, bool) { } } +// HasLevel reports whether the given target level exists in the levels slice. +// Matching is case-insensitive with leading/trailing whitespace trimmed. +func HasLevel(levels []string, target string) bool { + for _, level := range levels { + if strings.EqualFold(strings.TrimSpace(level), target) { + return true + } + } + return false +} + +// MapToClaudeEffort maps a generic thinking level string to a Claude adaptive +// thinking effort value (low/medium/high/max). +// +// supportsMax indicates whether the target model supports "max" effort. +// Returns the mapped effort and true if the level is valid, or ("", false) otherwise. +func MapToClaudeEffort(level string, supportsMax bool) (string, bool) { + level = strings.ToLower(strings.TrimSpace(level)) + switch level { + case "": + return "", false + case "minimal": + return "low", true + case "low", "medium", "high": + return level, true + case "xhigh", "max": + if supportsMax { + return "max", true + } + return "high", true + case "auto": + return "high", true + default: + return "", false + } +} + // ModelCapability describes the thinking format support of a model. type ModelCapability int diff --git a/internal/thinking/provider/claude/apply.go b/internal/thinking/provider/claude/apply.go index 3c74d514..275be469 100644 --- a/internal/thinking/provider/claude/apply.go +++ b/internal/thinking/provider/claude/apply.go @@ -1,8 +1,10 @@ // Package claude implements thinking configuration scaffolding for Claude models. // -// Claude models use the thinking.budget_tokens format with values in the range -// 1024-128000. Some Claude models support ZeroAllowed (sonnet-4-5, opus-4-5), -// while older models do not. +// Claude models support two thinking control styles: +// - Manual thinking: thinking.type="enabled" with thinking.budget_tokens (token budget) +// - Adaptive thinking (Claude 4.6): thinking.type="adaptive" with output_config.effort (low/medium/high/max) +// +// Some Claude models support ZeroAllowed (sonnet-4-5, opus-4-5), while older models do not. // See: _bmad-output/planning-artifacts/architecture.md#Epic-6 package claude @@ -34,7 +36,11 @@ func init() { // - Budget clamping to model range // - ZeroAllowed constraint enforcement // -// Apply only processes ModeBudget and ModeNone; other modes are passed through unchanged. +// Apply processes: +// - ModeBudget: manual thinking budget_tokens +// - ModeLevel: adaptive thinking effort (Claude 4.6) +// - ModeAuto: provider default adaptive/manual behavior +// - ModeNone: disabled // // Expected output format when enabled: // @@ -45,6 +51,17 @@ func init() { // } // } // +// Expected output format for adaptive: +// +// { +// "thinking": { +// "type": "adaptive" +// }, +// "output_config": { +// "effort": "high" +// } +// } +// // Expected output format when disabled: // // { @@ -60,30 +77,91 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo * return body, nil } - // Only process ModeBudget and ModeNone; other modes pass through - // (caller should use ValidateConfig first to normalize modes) - if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone { - return body, nil - } - if len(body) == 0 || !gjson.ValidBytes(body) { body = []byte(`{}`) } - // Budget is expected to be pre-validated by ValidateConfig (clamped, ZeroAllowed enforced) - // Decide enabled/disabled based on budget value - if config.Budget == 0 { + supportsAdaptive := modelInfo != nil && modelInfo.Thinking != nil && len(modelInfo.Thinking.Levels) > 0 + + switch config.Mode { + case thinking.ModeNone: result, _ := sjson.SetBytes(body, "thinking.type", "disabled") result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens") + result, _ = sjson.DeleteBytes(result, "output_config.effort") + if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 { + result, _ = sjson.DeleteBytes(result, "output_config") + } return result, nil + + case thinking.ModeLevel: + // Adaptive thinking effort is only valid when the model advertises discrete levels. + // (Claude 4.6 uses output_config.effort.) + if supportsAdaptive && config.Level != "" { + result, _ := sjson.SetBytes(body, "thinking.type", "adaptive") + result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens") + result, _ = sjson.SetBytes(result, "output_config.effort", string(config.Level)) + return result, nil + } + + // Fallback for non-adaptive Claude models: convert level to budget_tokens. + if budget, ok := thinking.ConvertLevelToBudget(string(config.Level)); ok { + config.Mode = thinking.ModeBudget + config.Budget = budget + config.Level = "" + } else { + return body, nil + } + fallthrough + + case thinking.ModeBudget: + // Budget is expected to be pre-validated by ValidateConfig (clamped, ZeroAllowed enforced). + // Decide enabled/disabled based on budget value. + if config.Budget == 0 { + result, _ := sjson.SetBytes(body, "thinking.type", "disabled") + result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens") + result, _ = sjson.DeleteBytes(result, "output_config.effort") + if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 { + result, _ = sjson.DeleteBytes(result, "output_config") + } + return result, nil + } + + result, _ := sjson.SetBytes(body, "thinking.type", "enabled") + result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget) + result, _ = sjson.DeleteBytes(result, "output_config.effort") + if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 { + result, _ = sjson.DeleteBytes(result, "output_config") + } + + // Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint). + result = a.normalizeClaudeBudget(result, config.Budget, modelInfo) + return result, nil + + case thinking.ModeAuto: + // For Claude 4.6 models, auto maps to adaptive thinking with upstream defaults. + if supportsAdaptive { + result, _ := sjson.SetBytes(body, "thinking.type", "adaptive") + result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens") + // Explicit effort is optional for adaptive thinking; omit it to allow upstream default. + result, _ = sjson.DeleteBytes(result, "output_config.effort") + if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 { + result, _ = sjson.DeleteBytes(result, "output_config") + } + return result, nil + } + + // Legacy fallback: enable thinking without specifying budget_tokens. + result, _ := sjson.SetBytes(body, "thinking.type", "enabled") + result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens") + result, _ = sjson.DeleteBytes(result, "output_config.effort") + if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 { + result, _ = sjson.DeleteBytes(result, "output_config") + } + return result, nil + + default: + return body, nil } - - result, _ := sjson.SetBytes(body, "thinking.type", "enabled") - result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget) - - // Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint) - result = a.normalizeClaudeBudget(result, config.Budget, modelInfo) - return result, nil } // normalizeClaudeBudget applies Claude-specific constraints to ensure max_tokens > budget_tokens. @@ -141,7 +219,7 @@ func (a *Applier) effectiveMaxTokens(body []byte, modelInfo *registry.ModelInfo) } func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) { - if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto { + if config.Mode != thinking.ModeBudget && config.Mode != thinking.ModeNone && config.Mode != thinking.ModeAuto && config.Mode != thinking.ModeLevel { return body, nil } @@ -153,14 +231,36 @@ func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, case thinking.ModeNone: result, _ := sjson.SetBytes(body, "thinking.type", "disabled") result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens") + result, _ = sjson.DeleteBytes(result, "output_config.effort") + if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 { + result, _ = sjson.DeleteBytes(result, "output_config") + } return result, nil case thinking.ModeAuto: result, _ := sjson.SetBytes(body, "thinking.type", "enabled") result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens") + result, _ = sjson.DeleteBytes(result, "output_config.effort") + if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 { + result, _ = sjson.DeleteBytes(result, "output_config") + } + return result, nil + case thinking.ModeLevel: + // For user-defined models, interpret ModeLevel as Claude adaptive thinking effort. + // Upstream is responsible for validating whether the target model supports it. + if config.Level == "" { + return body, nil + } + result, _ := sjson.SetBytes(body, "thinking.type", "adaptive") + result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens") + result, _ = sjson.SetBytes(result, "output_config.effort", string(config.Level)) return result, nil default: result, _ := sjson.SetBytes(body, "thinking.type", "enabled") result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget) + result, _ = sjson.DeleteBytes(result, "output_config.effort") + if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 { + result, _ = sjson.DeleteBytes(result, "output_config") + } return result, nil } } diff --git a/internal/thinking/provider/codex/apply.go b/internal/thinking/provider/codex/apply.go index 3bed318b..0f336359 100644 --- a/internal/thinking/provider/codex/apply.go +++ b/internal/thinking/provider/codex/apply.go @@ -7,8 +7,6 @@ package codex import ( - "strings" - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/tidwall/gjson" @@ -68,7 +66,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo * effort := "" support := modelInfo.Thinking if config.Budget == 0 { - if support.ZeroAllowed || hasLevel(support.Levels, string(thinking.LevelNone)) { + if support.ZeroAllowed || thinking.HasLevel(support.Levels, string(thinking.LevelNone)) { effort = string(thinking.LevelNone) } } @@ -120,12 +118,3 @@ func applyCompatibleCodex(body []byte, config thinking.ThinkingConfig) ([]byte, result, _ := sjson.SetBytes(body, "reasoning.effort", effort) return result, nil } - -func hasLevel(levels []string, target string) bool { - for _, level := range levels { - if strings.EqualFold(strings.TrimSpace(level), target) { - return true - } - } - return false -} diff --git a/internal/thinking/provider/openai/apply.go b/internal/thinking/provider/openai/apply.go index eaad30ee..c77c1ab8 100644 --- a/internal/thinking/provider/openai/apply.go +++ b/internal/thinking/provider/openai/apply.go @@ -6,8 +6,6 @@ package openai import ( - "strings" - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/tidwall/gjson" @@ -65,7 +63,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo * effort := "" support := modelInfo.Thinking if config.Budget == 0 { - if support.ZeroAllowed || hasLevel(support.Levels, string(thinking.LevelNone)) { + if support.ZeroAllowed || thinking.HasLevel(support.Levels, string(thinking.LevelNone)) { effort = string(thinking.LevelNone) } } @@ -117,12 +115,3 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte, result, _ := sjson.SetBytes(body, "reasoning_effort", effort) return result, nil } - -func hasLevel(levels []string, target string) bool { - for _, level := range levels { - if strings.EqualFold(strings.TrimSpace(level), target) { - return true - } - } - return false -} diff --git a/internal/thinking/strip.go b/internal/thinking/strip.go index 514ab3f8..85498c01 100644 --- a/internal/thinking/strip.go +++ b/internal/thinking/strip.go @@ -30,7 +30,7 @@ func StripThinkingConfig(body []byte, provider string) []byte { var paths []string switch provider { case "claude": - paths = []string{"thinking"} + paths = []string{"thinking", "output_config.effort"} case "gemini": paths = []string{"generationConfig.thinkingConfig"} case "gemini-cli", "antigravity": @@ -59,5 +59,12 @@ func StripThinkingConfig(body []byte, provider string) []byte { for _, path := range paths { result, _ = sjson.DeleteBytes(result, path) } + + // Avoid leaving an empty output_config object for Claude when effort was the only field. + if provider == "claude" { + if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 { + result, _ = sjson.DeleteBytes(result, "output_config") + } + } return result } diff --git a/internal/thinking/suffix.go b/internal/thinking/suffix.go index 275c0856..7f2959da 100644 --- a/internal/thinking/suffix.go +++ b/internal/thinking/suffix.go @@ -109,7 +109,7 @@ func ParseSpecialSuffix(rawSuffix string) (mode ThinkingMode, ok bool) { // ParseLevelSuffix attempts to parse a raw suffix as a discrete thinking level. // // This function parses the raw suffix content (from ParseSuffix.RawSuffix) as a level. -// Only discrete effort levels are valid: minimal, low, medium, high, xhigh. +// Only discrete effort levels are valid: minimal, low, medium, high, xhigh, max. // Level matching is case-insensitive. // // Special values (none, auto) are NOT handled by this function; use ParseSpecialSuffix @@ -140,6 +140,8 @@ func ParseLevelSuffix(rawSuffix string) (level ThinkingLevel, ok bool) { return LevelHigh, true case "xhigh": return LevelXHigh, true + case "max": + return LevelMax, true default: return "", false } diff --git a/internal/thinking/types.go b/internal/thinking/types.go index 6ae1e088..5e45fc6b 100644 --- a/internal/thinking/types.go +++ b/internal/thinking/types.go @@ -54,6 +54,9 @@ const ( LevelHigh ThinkingLevel = "high" // LevelXHigh sets extra-high thinking effort LevelXHigh ThinkingLevel = "xhigh" + // LevelMax sets maximum thinking effort. + // This is currently used by Claude 4.6 adaptive thinking (opus supports "max"). + LevelMax ThinkingLevel = "max" ) // ThinkingConfig represents a unified thinking configuration. diff --git a/internal/thinking/validate.go b/internal/thinking/validate.go index f082ad56..4a3ca97c 100644 --- a/internal/thinking/validate.go +++ b/internal/thinking/validate.go @@ -53,7 +53,17 @@ func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, fromFo return &config, nil } - allowClampUnsupported := isBudgetBasedProvider(fromFormat) && isLevelBasedProvider(toFormat) + // allowClampUnsupported determines whether to clamp unsupported levels instead of returning an error. + // This applies when crossing provider families (e.g., openai→gemini, claude→gemini) and the target + // model supports discrete levels. Same-family conversions require strict validation. + toCapability := detectModelCapability(modelInfo) + toHasLevelSupport := toCapability == CapabilityLevelOnly || toCapability == CapabilityHybrid + allowClampUnsupported := toHasLevelSupport && !isSameProviderFamily(fromFormat, toFormat) + + // strictBudget determines whether to enforce strict budget range validation. + // This applies when: (1) config comes from request body (not suffix), (2) source format is known, + // and (3) source and target are in the same provider family. Cross-family or suffix-based configs + // are clamped instead of rejected to improve interoperability. strictBudget := !fromSuffix && fromFormat != "" && isSameProviderFamily(fromFormat, toFormat) budgetDerivedFromLevel := false @@ -201,7 +211,7 @@ func convertAutoToMidRange(config ThinkingConfig, support *registry.ThinkingSupp } // standardLevelOrder defines the canonical ordering of thinking levels from lowest to highest. -var standardLevelOrder = []ThinkingLevel{LevelMinimal, LevelLow, LevelMedium, LevelHigh, LevelXHigh} +var standardLevelOrder = []ThinkingLevel{LevelMinimal, LevelLow, LevelMedium, LevelHigh, LevelXHigh, LevelMax} // clampLevel clamps the given level to the nearest supported level. // On tie, prefers the lower level. @@ -325,7 +335,9 @@ func normalizeLevels(levels []string) []string { return out } -func isBudgetBasedProvider(provider string) bool { +// isBudgetCapableProvider returns true if the provider supports budget-based thinking. +// These providers may also support level-based thinking (hybrid models). +func isBudgetCapableProvider(provider string) bool { switch provider { case "gemini", "gemini-cli", "antigravity", "claude": return true @@ -334,15 +346,6 @@ func isBudgetBasedProvider(provider string) bool { } } -func isLevelBasedProvider(provider string) bool { - switch provider { - case "openai", "openai-response", "codex": - return true - default: - return false - } -} - func isGeminiFamily(provider string) bool { switch provider { case "gemini", "gemini-cli", "antigravity": @@ -352,11 +355,21 @@ func isGeminiFamily(provider string) bool { } } +func isOpenAIFamily(provider string) bool { + switch provider { + case "openai", "openai-response", "codex": + return true + default: + return false + } +} + func isSameProviderFamily(from, to string) bool { if from == to { return true } - return isGeminiFamily(from) && isGeminiFamily(to) + return (isGeminiFamily(from) && isGeminiFamily(to)) || + (isOpenAIFamily(from) && isOpenAIFamily(to)) } func abs(x int) int { diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index a3f9fa48..8c1a38c5 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -431,6 +431,33 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.SetRaw(out, "request.tools", toolsJSON) } + // tool_choice + toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice") + if toolChoiceResult.Exists() { + toolChoiceType := "" + toolChoiceName := "" + if toolChoiceResult.IsObject() { + toolChoiceType = toolChoiceResult.Get("type").String() + toolChoiceName = toolChoiceResult.Get("name").String() + } else if toolChoiceResult.Type == gjson.String { + toolChoiceType = toolChoiceResult.String() + } + + switch toolChoiceType { + case "auto": + out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO") + case "none": + out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE") + case "any": + out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY") + case "tool": + out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY") + if toolChoiceName != "" { + out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName}) + } + } + } + // Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() { switch t.Get("type").String() { @@ -440,16 +467,23 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true) } - case "auto": - // Amp sends thinking.type="auto" — use max budget from model config - // Antigravity API for Claude models requires a concrete positive budget, - // not -1. Use a high default that ApplyThinking will cap to model max. - out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", 64000) - out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true) - case "adaptive": - // Keep adaptive as a high level sentinel; ApplyThinking resolves it - // to model-specific max capability. - out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high") + case "adaptive", "auto": + // For adaptive thinking: + // - If output_config.effort is explicitly present, pass through as thinkingLevel. + // - Otherwise, treat it as "enabled with target-model maximum" and emit high. + // ApplyThinking handles clamping to target model's supported levels. + effort := "" + if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String { + effort = strings.ToLower(strings.TrimSpace(v.String())) + } + if effort != "" { + if effort == "max" { + effort = "high" + } + out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort) + } else { + out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high") + } out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true) } } diff --git a/internal/translator/antigravity/claude/antigravity_claude_request_test.go b/internal/translator/antigravity/claude/antigravity_claude_request_test.go index 865db668..39dc493d 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request_test.go @@ -193,6 +193,42 @@ func TestConvertClaudeRequestToAntigravity_ToolDeclarations(t *testing.T) { } } +func TestConvertClaudeRequestToAntigravity_ToolChoice_SpecificTool(t *testing.T) { + inputJSON := []byte(`{ + "model": "gemini-3-flash-preview", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hi"} + ] + } + ], + "tools": [ + { + "name": "json", + "description": "A JSON tool", + "input_schema": { + "type": "object", + "properties": {} + } + } + ], + "tool_choice": {"type": "tool", "name": "json"} + }`) + + output := ConvertClaudeRequestToAntigravity("gemini-3-flash-preview", inputJSON, false) + outputStr := string(output) + + if got := gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.mode").String(); got != "ANY" { + t.Fatalf("Expected toolConfig.functionCallingConfig.mode 'ANY', got '%s'", got) + } + allowed := gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Array() + if len(allowed) != 1 || allowed[0].String() != "json" { + t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Raw) + } +} + func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) { inputJSON := []byte(`{ "model": "claude-3-5-sonnet-20240620", @@ -1199,3 +1235,64 @@ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *t t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw) } } + +func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels(t *testing.T) { + tests := []struct { + name string + effort string + expected string + }{ + {"low", "low", "low"}, + {"medium", "medium", "medium"}, + {"high", "high", "high"}, + {"max", "max", "high"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + inputJSON := []byte(`{ + "model": "claude-opus-4-6-thinking", + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], + "thinking": {"type": "adaptive"}, + "output_config": {"effort": "` + tt.effort + `"} + }`) + + output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false) + outputStr := string(output) + + thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig") + if !thinkingConfig.Exists() { + t.Fatal("thinkingConfig should exist for adaptive thinking") + } + if thinkingConfig.Get("thinkingLevel").String() != tt.expected { + t.Errorf("Expected thinkingLevel %q, got %q", tt.expected, thinkingConfig.Get("thinkingLevel").String()) + } + if !thinkingConfig.Get("includeThoughts").Bool() { + t.Error("includeThoughts should be true") + } + }) + } +} + +func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort(t *testing.T) { + inputJSON := []byte(`{ + "model": "claude-opus-4-6-thinking", + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], + "thinking": {"type": "adaptive"} + }`) + + output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false) + outputStr := string(output) + + thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig") + if !thinkingConfig.Exists() { + t.Fatal("thinkingConfig should exist for adaptive thinking without effort") + } + if thinkingConfig.Get("thinkingLevel").String() != "high" { + t.Errorf("Expected default thinkingLevel \"high\", got %q", thinkingConfig.Get("thinkingLevel").String()) + } + if !thinkingConfig.Get("includeThoughts").Bool() { + t.Error("includeThoughts should be true") + } +} diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index 85b28b8b..7fb25b2a 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -34,6 +34,11 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ // Model out, _ = sjson.SetBytes(out, "model", modelName) + // Let user-provided generationConfig pass through + if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() { + out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw)) + } + // Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig. // Inline translation-only mapping; capability checks happen later in ApplyThinking. re := gjson.GetBytes(rawJSON, "reasoning_effort") @@ -207,6 +212,33 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ } else { log.Warnf("Unknown file name extension '%s' in user message, skip", ext) } + case "input_audio": + audioData := item.Get("input_audio.data").String() + audioFormat := item.Get("input_audio.format").String() + if audioData != "" { + audioMimeMap := map[string]string{ + "mp3": "audio/mpeg", + "wav": "audio/wav", + "ogg": "audio/ogg", + "flac": "audio/flac", + "aac": "audio/aac", + "webm": "audio/webm", + "pcm16": "audio/pcm", + "g711_ulaw": "audio/basic", + "g711_alaw": "audio/basic", + } + mimeType := "audio/wav" + if audioFormat != "" { + if mapped, ok := audioMimeMap[audioFormat]; ok { + mimeType = mapped + } else { + mimeType = "audio/" + audioFormat + } + } + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", audioData) + p++ + } } } } diff --git a/internal/translator/claude/gemini/claude_gemini_request.go b/internal/translator/claude/gemini/claude_gemini_request.go index ea53da05..a8d97b9d 100644 --- a/internal/translator/claude/gemini/claude_gemini_request.go +++ b/internal/translator/claude/gemini/claude_gemini_request.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/tidwall/gjson" @@ -115,24 +116,47 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream // Include thoughts configuration for reasoning process visibility // Translator only does format conversion, ApplyThinking handles model capability validation. if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() { + mi := registry.LookupModelInfo(modelName, "claude") + supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0 + supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax)) + + // MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid + // validation errors since validate treats same-provider unsupported levels as errors. thinkingLevel := thinkingConfig.Get("thinkingLevel") if !thinkingLevel.Exists() { thinkingLevel = thinkingConfig.Get("thinking_level") } if thinkingLevel.Exists() { level := strings.ToLower(strings.TrimSpace(thinkingLevel.String())) - switch level { - case "": - case "none": - out, _ = sjson.Set(out, "thinking.type", "disabled") - out, _ = sjson.Delete(out, "thinking.budget_tokens") - case "auto": - out, _ = sjson.Set(out, "thinking.type", "enabled") - out, _ = sjson.Delete(out, "thinking.budget_tokens") - default: - if budget, ok := thinking.ConvertLevelToBudget(level); ok { + if supportsAdaptive { + switch level { + case "": + case "none": + out, _ = sjson.Set(out, "thinking.type", "disabled") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + out, _ = sjson.Delete(out, "output_config.effort") + default: + if mapped, ok := thinking.MapToClaudeEffort(level, supportsMax); ok { + level = mapped + } + out, _ = sjson.Set(out, "thinking.type", "adaptive") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + out, _ = sjson.Set(out, "output_config.effort", level) + } + } else { + switch level { + case "": + case "none": + out, _ = sjson.Set(out, "thinking.type", "disabled") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + case "auto": out, _ = sjson.Set(out, "thinking.type", "enabled") - out, _ = sjson.Set(out, "thinking.budget_tokens", budget) + out, _ = sjson.Delete(out, "thinking.budget_tokens") + default: + if budget, ok := thinking.ConvertLevelToBudget(level); ok { + out, _ = sjson.Set(out, "thinking.type", "enabled") + out, _ = sjson.Set(out, "thinking.budget_tokens", budget) + } } } } else { @@ -142,16 +166,35 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream } if thinkingBudget.Exists() { budget := int(thinkingBudget.Int()) - switch budget { - case 0: - out, _ = sjson.Set(out, "thinking.type", "disabled") - out, _ = sjson.Delete(out, "thinking.budget_tokens") - case -1: - out, _ = sjson.Set(out, "thinking.type", "enabled") - out, _ = sjson.Delete(out, "thinking.budget_tokens") - default: - out, _ = sjson.Set(out, "thinking.type", "enabled") - out, _ = sjson.Set(out, "thinking.budget_tokens", budget) + if supportsAdaptive { + switch budget { + case 0: + out, _ = sjson.Set(out, "thinking.type", "disabled") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + out, _ = sjson.Delete(out, "output_config.effort") + default: + level, ok := thinking.ConvertBudgetToLevel(budget) + if ok { + if mapped, okM := thinking.MapToClaudeEffort(level, supportsMax); okM { + level = mapped + } + out, _ = sjson.Set(out, "thinking.type", "adaptive") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + out, _ = sjson.Set(out, "output_config.effort", level) + } + } + } else { + switch budget { + case 0: + out, _ = sjson.Set(out, "thinking.type", "disabled") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + case -1: + out, _ = sjson.Set(out, "thinking.type", "enabled") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + default: + out, _ = sjson.Set(out, "thinking.type", "enabled") + out, _ = sjson.Set(out, "thinking.budget_tokens", budget) + } } } else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True { out, _ = sjson.Set(out, "thinking.type", "enabled") diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_request.go b/internal/translator/claude/openai/chat-completions/claude_openai_request.go index f94825b2..ef01bb94 100644 --- a/internal/translator/claude/openai/chat-completions/claude_openai_request.go +++ b/internal/translator/claude/openai/chat-completions/claude_openai_request.go @@ -14,6 +14,7 @@ import ( "strings" "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -68,17 +69,45 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream if v := root.Get("reasoning_effort"); v.Exists() { effort := strings.ToLower(strings.TrimSpace(v.String())) if effort != "" { - budget, ok := thinking.ConvertLevelToBudget(effort) - if ok { - switch budget { - case 0: + mi := registry.LookupModelInfo(modelName, "claude") + supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0 + supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax)) + + // Claude 4.6 supports adaptive thinking with output_config.effort. + // MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid + // validation errors since validate treats same-provider unsupported levels as errors. + if supportsAdaptive { + switch effort { + case "none": out, _ = sjson.Set(out, "thinking.type", "disabled") - case -1: - out, _ = sjson.Set(out, "thinking.type", "enabled") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + out, _ = sjson.Delete(out, "output_config.effort") + case "auto": + out, _ = sjson.Set(out, "thinking.type", "adaptive") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + out, _ = sjson.Delete(out, "output_config.effort") default: - if budget > 0 { + if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok { + effort = mapped + } + out, _ = sjson.Set(out, "thinking.type", "adaptive") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + out, _ = sjson.Set(out, "output_config.effort", effort) + } + } else { + // Legacy/manual thinking (budget_tokens). + budget, ok := thinking.ConvertLevelToBudget(effort) + if ok { + switch budget { + case 0: + out, _ = sjson.Set(out, "thinking.type", "disabled") + case -1: out, _ = sjson.Set(out, "thinking.type", "enabled") - out, _ = sjson.Set(out, "thinking.budget_tokens", budget) + default: + if budget > 0 { + out, _ = sjson.Set(out, "thinking.type", "enabled") + out, _ = sjson.Set(out, "thinking.budget_tokens", budget) + } } } } @@ -174,46 +203,9 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream msg, _ = sjson.SetRaw(msg, "content.-1", part) } else if contentResult.Exists() && contentResult.IsArray() { contentResult.ForEach(func(_, part gjson.Result) bool { - partType := part.Get("type").String() - - switch partType { - case "text": - textPart := `{"type":"text","text":""}` - textPart, _ = sjson.Set(textPart, "text", part.Get("text").String()) - msg, _ = sjson.SetRaw(msg, "content.-1", textPart) - - case "image_url": - // Convert OpenAI image format to Claude Code format - imageURL := part.Get("image_url.url").String() - if strings.HasPrefix(imageURL, "data:") { - // Extract base64 data and media type from data URL - parts := strings.Split(imageURL, ",") - if len(parts) == 2 { - mediaTypePart := strings.Split(parts[0], ";")[0] - mediaType := strings.TrimPrefix(mediaTypePart, "data:") - data := parts[1] - - imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}` - imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType) - imagePart, _ = sjson.Set(imagePart, "source.data", data) - msg, _ = sjson.SetRaw(msg, "content.-1", imagePart) - } - } - - case "file": - fileData := part.Get("file.file_data").String() - if strings.HasPrefix(fileData, "data:") { - semicolonIdx := strings.Index(fileData, ";") - commaIdx := strings.Index(fileData, ",") - if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx { - mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:") - data := fileData[commaIdx+1:] - docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}` - docPart, _ = sjson.Set(docPart, "source.media_type", mediaType) - docPart, _ = sjson.Set(docPart, "source.data", data) - msg, _ = sjson.SetRaw(msg, "content.-1", docPart) - } - } + claudePart := convertOpenAIContentPartToClaudePart(part) + if claudePart != "" { + msg, _ = sjson.SetRaw(msg, "content.-1", claudePart) } return true }) @@ -262,11 +254,16 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream case "tool": // Handle tool result messages conversion toolCallID := message.Get("tool_call_id").String() - content := message.Get("content").String() + toolContentResult := message.Get("content") msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}` msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID) - msg, _ = sjson.Set(msg, "content.0.content", content) + toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult) + if toolResultContentRaw { + msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent) + } else { + msg, _ = sjson.Set(msg, "content.0.content", toolResultContent) + } out, _ = sjson.SetRaw(out, "messages.-1", msg) messageIndex++ } @@ -329,3 +326,110 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream return []byte(out) } + +func convertOpenAIContentPartToClaudePart(part gjson.Result) string { + switch part.Get("type").String() { + case "text": + textPart := `{"type":"text","text":""}` + textPart, _ = sjson.Set(textPart, "text", part.Get("text").String()) + return textPart + + case "image_url": + return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String()) + + case "file": + fileData := part.Get("file.file_data").String() + if strings.HasPrefix(fileData, "data:") { + semicolonIdx := strings.Index(fileData, ";") + commaIdx := strings.Index(fileData, ",") + if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx { + mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:") + data := fileData[commaIdx+1:] + docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}` + docPart, _ = sjson.Set(docPart, "source.media_type", mediaType) + docPart, _ = sjson.Set(docPart, "source.data", data) + return docPart + } + } + } + + return "" +} + +func convertOpenAIImageURLToClaudePart(imageURL string) string { + if imageURL == "" { + return "" + } + + if strings.HasPrefix(imageURL, "data:") { + parts := strings.SplitN(imageURL, ",", 2) + if len(parts) != 2 { + return "" + } + + mediaTypePart := strings.SplitN(parts[0], ";", 2)[0] + mediaType := strings.TrimPrefix(mediaTypePart, "data:") + if mediaType == "" { + mediaType = "application/octet-stream" + } + + imagePart := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}` + imagePart, _ = sjson.Set(imagePart, "source.media_type", mediaType) + imagePart, _ = sjson.Set(imagePart, "source.data", parts[1]) + return imagePart + } + + imagePart := `{"type":"image","source":{"type":"url","url":""}}` + imagePart, _ = sjson.Set(imagePart, "source.url", imageURL) + return imagePart +} + +func convertOpenAIToolResultContent(content gjson.Result) (string, bool) { + if !content.Exists() { + return "", false + } + + if content.Type == gjson.String { + return content.String(), false + } + + if content.IsArray() { + claudeContent := "[]" + partCount := 0 + + content.ForEach(func(_, part gjson.Result) bool { + if part.Type == gjson.String { + textPart := `{"type":"text","text":""}` + textPart, _ = sjson.Set(textPart, "text", part.String()) + claudeContent, _ = sjson.SetRaw(claudeContent, "-1", textPart) + partCount++ + return true + } + + claudePart := convertOpenAIContentPartToClaudePart(part) + if claudePart != "" { + claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart) + partCount++ + } + return true + }) + + if partCount > 0 || len(content.Array()) == 0 { + return claudeContent, true + } + + return content.Raw, false + } + + if content.IsObject() { + claudePart := convertOpenAIContentPartToClaudePart(content) + if claudePart != "" { + claudeContent := "[]" + claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart) + return claudeContent, true + } + return content.Raw, false + } + + return content.Raw, false +} diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_request_test.go b/internal/translator/claude/openai/chat-completions/claude_openai_request_test.go new file mode 100644 index 00000000..ed84661d --- /dev/null +++ b/internal/translator/claude/openai/chat-completions/claude_openai_request_test.go @@ -0,0 +1,137 @@ +package chat_completions + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertOpenAIRequestToClaude_ToolResultTextAndBase64Image(t *testing.T) { + inputJSON := `{ + "model": "gpt-4.1", + "messages": [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "do_work", + "arguments": "{\"a\":1}" + } + } + ] + }, + { + "role": "tool", + "tool_call_id": "call_1", + "content": [ + {"type": "text", "text": "tool ok"}, + { + "type": "image_url", + "image_url": { + "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==" + } + } + ] + } + ] + }` + + result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false) + resultJSON := gjson.ParseBytes(result) + messages := resultJSON.Get("messages").Array() + + if len(messages) != 2 { + t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) + } + + toolResult := messages[1].Get("content.0") + if got := toolResult.Get("type").String(); got != "tool_result" { + t.Fatalf("Expected content[0].type %q, got %q", "tool_result", got) + } + if got := toolResult.Get("tool_use_id").String(); got != "call_1" { + t.Fatalf("Expected tool_use_id %q, got %q", "call_1", got) + } + + toolContent := toolResult.Get("content") + if !toolContent.IsArray() { + t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw) + } + if got := toolContent.Get("0.type").String(); got != "text" { + t.Fatalf("Expected first tool_result part type %q, got %q", "text", got) + } + if got := toolContent.Get("0.text").String(); got != "tool ok" { + t.Fatalf("Expected first tool_result part text %q, got %q", "tool ok", got) + } + if got := toolContent.Get("1.type").String(); got != "image" { + t.Fatalf("Expected second tool_result part type %q, got %q", "image", got) + } + if got := toolContent.Get("1.source.type").String(); got != "base64" { + t.Fatalf("Expected image source type %q, got %q", "base64", got) + } + if got := toolContent.Get("1.source.media_type").String(); got != "image/png" { + t.Fatalf("Expected image media type %q, got %q", "image/png", got) + } + if got := toolContent.Get("1.source.data").String(); got != "iVBORw0KGgoAAAANSUhEUg==" { + t.Fatalf("Unexpected base64 image data: %q", got) + } +} + +func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) { + inputJSON := `{ + "model": "gpt-4.1", + "messages": [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "do_work", + "arguments": "{\"a\":1}" + } + } + ] + }, + { + "role": "tool", + "tool_call_id": "call_1", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://example.com/tool.png" + } + } + ] + } + ] + }` + + result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false) + resultJSON := gjson.ParseBytes(result) + messages := resultJSON.Get("messages").Array() + + if len(messages) != 2 { + t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) + } + + toolContent := messages[1].Get("content.0.content") + if !toolContent.IsArray() { + t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw) + } + if got := toolContent.Get("0.type").String(); got != "image" { + t.Fatalf("Expected tool_result part type %q, got %q", "image", got) + } + if got := toolContent.Get("0.source.type").String(); got != "url" { + t.Fatalf("Expected image source type %q, got %q", "url", got) + } + if got := toolContent.Get("0.source.url").String(); got != "https://example.com/tool.png" { + t.Fatalf("Unexpected image URL: %q", got) + } +} diff --git a/internal/translator/claude/openai/responses/claude_openai-responses_request.go b/internal/translator/claude/openai/responses/claude_openai-responses_request.go index 33a81124..cb550b09 100644 --- a/internal/translator/claude/openai/responses/claude_openai-responses_request.go +++ b/internal/translator/claude/openai/responses/claude_openai-responses_request.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -56,17 +57,45 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte if v := root.Get("reasoning.effort"); v.Exists() { effort := strings.ToLower(strings.TrimSpace(v.String())) if effort != "" { - budget, ok := thinking.ConvertLevelToBudget(effort) - if ok { - switch budget { - case 0: + mi := registry.LookupModelInfo(modelName, "claude") + supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0 + supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax)) + + // Claude 4.6 supports adaptive thinking with output_config.effort. + // MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid + // validation errors since validate treats same-provider unsupported levels as errors. + if supportsAdaptive { + switch effort { + case "none": out, _ = sjson.Set(out, "thinking.type", "disabled") - case -1: - out, _ = sjson.Set(out, "thinking.type", "enabled") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + out, _ = sjson.Delete(out, "output_config.effort") + case "auto": + out, _ = sjson.Set(out, "thinking.type", "adaptive") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + out, _ = sjson.Delete(out, "output_config.effort") default: - if budget > 0 { + if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok { + effort = mapped + } + out, _ = sjson.Set(out, "thinking.type", "adaptive") + out, _ = sjson.Delete(out, "thinking.budget_tokens") + out, _ = sjson.Set(out, "output_config.effort", effort) + } + } else { + // Legacy/manual thinking (budget_tokens). + budget, ok := thinking.ConvertLevelToBudget(effort) + if ok { + switch budget { + case 0: + out, _ = sjson.Set(out, "thinking.type", "disabled") + case -1: out, _ = sjson.Set(out, "thinking.type", "enabled") - out, _ = sjson.Set(out, "thinking.budget_tokens", budget) + default: + if budget > 0 { + out, _ = sjson.Set(out, "thinking.type", "enabled") + out, _ = sjson.Set(out, "thinking.budget_tokens", budget) + } } } } diff --git a/internal/translator/codex/claude/codex_claude_request.go b/internal/translator/codex/claude/codex_claude_request.go index 64e41fb5..6373e693 100644 --- a/internal/translator/codex/claude/codex_claude_request.go +++ b/internal/translator/codex/claude/codex_claude_request.go @@ -160,7 +160,51 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) flushMessage() functionCallOutputMessage := `{"type":"function_call_output"}` functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String()) - functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String()) + + contentResult := messageContentResult.Get("content") + if contentResult.IsArray() { + toolResultContentIndex := 0 + toolResultContent := `[]` + contentResults := contentResult.Array() + for k := 0; k < len(contentResults); k++ { + toolResultContentType := contentResults[k].Get("type").String() + if toolResultContentType == "image" { + sourceResult := contentResults[k].Get("source") + if sourceResult.Exists() { + data := sourceResult.Get("data").String() + if data == "" { + data = sourceResult.Get("base64").String() + } + if data != "" { + mediaType := sourceResult.Get("media_type").String() + if mediaType == "" { + mediaType = sourceResult.Get("mime_type").String() + } + if mediaType == "" { + mediaType = "application/octet-stream" + } + dataURL := fmt.Sprintf("data:%s;base64,%s", mediaType, data) + + toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image") + toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL) + toolResultContentIndex++ + } + } + } else if toolResultContentType == "text" { + toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text") + toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), contentResults[k].Get("text").String()) + toolResultContentIndex++ + } + } + if toolResultContent != `[]` { + functionCallOutputMessage, _ = sjson.SetRaw(functionCallOutputMessage, "output", toolResultContent) + } else { + functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String()) + } + } else { + functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String()) + } + template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage) } } @@ -211,6 +255,8 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) tool, _ = sjson.SetRaw(tool, "parameters", normalizeToolParameters(toolResult.Get("input_schema").Raw)) tool, _ = sjson.Delete(tool, "input_schema") tool, _ = sjson.Delete(tool, "parameters.$schema") + tool, _ = sjson.Delete(tool, "cache_control") + tool, _ = sjson.Delete(tool, "defer_loading") tool, _ = sjson.Set(tool, "strict", false) template, _ = sjson.SetRaw(template, "tools.-1", tool) } @@ -230,10 +276,18 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) reasoningEffort = effort } } - case "adaptive": - // Claude adaptive means "enable with max capacity"; keep it as highest level - // and let ApplyThinking normalize per target model capability. - reasoningEffort = string(thinking.LevelXHigh) + case "adaptive", "auto": + // Adaptive thinking can carry an explicit effort in output_config.effort (Claude 4.6). + // Pass through directly; ApplyThinking handles clamping to target model's levels. + effort := "" + if v := rootResult.Get("output_config.effort"); v.Exists() && v.Type == gjson.String { + effort = strings.ToLower(strings.TrimSpace(v.String())) + } + if effort != "" { + reasoningEffort = effort + } else { + reasoningEffort = string(thinking.LevelXHigh) + } case "disabled": if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" { reasoningEffort = effort diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index cdcf2e4f..7f597062 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -22,8 +22,8 @@ var ( // ConvertCodexResponseToClaudeParams holds parameters for response conversion. type ConvertCodexResponseToClaudeParams struct { - HasToolCall bool - BlockIndex int + HasToolCall bool + BlockIndex int HasReceivedArgumentsDelta bool } diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_response.go b/internal/translator/codex/openai/chat-completions/codex_openai_response.go index f0e264c8..0054d995 100644 --- a/internal/translator/codex/openai/chat-completions/codex_openai_response.go +++ b/internal/translator/codex/openai/chat-completions/codex_openai_response.go @@ -74,8 +74,13 @@ func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalR } // Extract and set the model version. + cachedModel := (*param).(*ConvertCliToOpenAIParams).Model if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() { template, _ = sjson.Set(template, "model", modelResult.String()) + } else if cachedModel != "" { + template, _ = sjson.Set(template, "model", cachedModel) + } else if modelName != "" { + template, _ = sjson.Set(template, "model", modelName) } template, _ = sjson.Set(template, "created", (*param).(*ConvertCliToOpenAIParams).CreatedAt) diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_response_test.go b/internal/translator/codex/openai/chat-completions/codex_openai_response_test.go new file mode 100644 index 00000000..70aaea06 --- /dev/null +++ b/internal/translator/codex/openai/chat-completions/codex_openai_response_test.go @@ -0,0 +1,47 @@ +package chat_completions + +import ( + "context" + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertCodexResponseToOpenAI_StreamSetsModelFromResponseCreated(t *testing.T) { + ctx := context.Background() + var param any + + modelName := "gpt-5.3-codex" + + out := ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.created","response":{"id":"resp_123","created_at":1700000000,"model":"gpt-5.3-codex"}}`), ¶m) + if len(out) != 0 { + t.Fatalf("expected no output for response.created, got %d chunks", len(out)) + } + + out = ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.output_text.delta","delta":"hello"}`), ¶m) + if len(out) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(out)) + } + + gotModel := gjson.Get(out[0], "model").String() + if gotModel != modelName { + t.Fatalf("expected model %q, got %q", modelName, gotModel) + } +} + +func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.T) { + ctx := context.Background() + var param any + + modelName := "gpt-5.3-codex" + + out := ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.output_text.delta","delta":"hello"}`), ¶m) + if len(out) != 1 { + t.Fatalf("expected 1 chunk, got %d", len(out)) + } + + gotModel := gjson.Get(out[0], "model").String() + if gotModel != modelName { + t.Fatalf("expected model %q, got %q", modelName, gotModel) + } +} diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request.go b/internal/translator/codex/openai/responses/codex_openai-responses_request.go index 1161c515..360c037f 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request.go @@ -25,7 +25,12 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, rawJSON, _ = sjson.DeleteBytes(rawJSON, "max_completion_tokens") rawJSON, _ = sjson.DeleteBytes(rawJSON, "temperature") rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p") - rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier") + if v := gjson.GetBytes(rawJSON, "service_tier"); v.Exists() { + if v.String() != "priority" { + rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier") + } + } + rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation") rawJSON = applyResponsesCompactionCompatibility(rawJSON) diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go b/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go index 65732c3f..a2ede1b8 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go @@ -264,18 +264,18 @@ func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) { } } -func TestUserFieldDeletion(t *testing.T) { +func TestUserFieldDeletion(t *testing.T) { inputJSON := []byte(`{ "model": "gpt-5.2", "user": "test-user", "input": [{"role": "user", "content": "Hello"}] - }`) - - output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false) - outputStr := string(output) - - // Verify user field is deleted - userField := gjson.Get(outputStr, "user") + }`) + + output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false) + outputStr := string(output) + + // Verify user field is deleted + userField := gjson.Get(outputStr, "user") if userField.Exists() { t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw) } diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_response.go b/internal/translator/codex/openai/responses/codex_openai-responses_response.go index 4287206a..e84b817b 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_response.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_response.go @@ -6,24 +6,14 @@ import ( "fmt" "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) // ConvertCodexResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks // to OpenAI Responses SSE events (response.*). -func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { +func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string { if bytes.HasPrefix(rawJSON, []byte("data:")) { rawJSON = bytes.TrimSpace(rawJSON[5:]) - if typeResult := gjson.GetBytes(rawJSON, "type"); typeResult.Exists() { - typeStr := typeResult.String() - if typeStr == "response.created" || typeStr == "response.in_progress" || typeStr == "response.completed" { - if gjson.GetBytes(rawJSON, "response.instructions").Exists() { - instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String() - rawJSON, _ = sjson.SetBytes(rawJSON, "response.instructions", instructions) - } - } - } out := fmt.Sprintf("data: %s", string(rawJSON)) return []string{out} } @@ -32,17 +22,12 @@ func ConvertCodexResponseToOpenAIResponses(ctx context.Context, modelName string // ConvertCodexResponseToOpenAIResponsesNonStream builds a single Responses JSON // from a non-streaming OpenAI Chat Completions response. -func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { +func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string { rootResult := gjson.ParseBytes(rawJSON) // Verify this is a response.completed event if rootResult.Get("type").String() != "response.completed" { return "" } responseResult := rootResult.Get("response") - template := responseResult.Raw - if responseResult.Get("instructions").Exists() { - instructions := gjson.GetBytes(originalRequestRawJSON, "instructions").String() - template, _ = sjson.Set(template, "instructions", instructions) - } - return template + return responseResult.Raw } diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go index ee661381..e3753b03 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go @@ -156,6 +156,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) [] tool, _ = sjson.Delete(tool, "input_examples") tool, _ = sjson.Delete(tool, "type") tool, _ = sjson.Delete(tool, "cache_control") + tool, _ = sjson.Delete(tool, "defer_loading") if gjson.Valid(tool) && gjson.Parse(tool).IsObject() { if !hasTools { out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`) @@ -171,7 +172,35 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) [] } } - // Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled + // tool_choice + toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice") + if toolChoiceResult.Exists() { + toolChoiceType := "" + toolChoiceName := "" + if toolChoiceResult.IsObject() { + toolChoiceType = toolChoiceResult.Get("type").String() + toolChoiceName = toolChoiceResult.Get("name").String() + } else if toolChoiceResult.Type == gjson.String { + toolChoiceType = toolChoiceResult.String() + } + + switch toolChoiceType { + case "auto": + out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO") + case "none": + out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE") + case "any": + out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY") + case "tool": + out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY") + if toolChoiceName != "" { + out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName}) + } + } + } + + // Map Anthropic thinking -> Gemini CLI thinkingConfig when enabled + // Translator only does format conversion, ApplyThinking handles model capability validation. if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() { switch t.Get("type").String() { case "enabled": @@ -180,10 +209,20 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) [] out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true) } - case "adaptive": - // Keep adaptive as a high level sentinel; ApplyThinking resolves it - // to model-specific max capability. - out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high") + case "adaptive", "auto": + // For adaptive thinking: + // - If output_config.effort is explicitly present, pass through as thinkingLevel. + // - Otherwise, treat it as "enabled with target-model maximum" and emit high. + // ApplyThinking handles clamping to target model's supported levels. + effort := "" + if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String { + effort = strings.ToLower(strings.TrimSpace(v.String())) + } + if effort != "" { + out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort) + } else { + out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high") + } out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true) } } diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_request_test.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_request_test.go new file mode 100644 index 00000000..10364e75 --- /dev/null +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_request_test.go @@ -0,0 +1,42 @@ +package claude + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertClaudeRequestToCLI_ToolChoice_SpecificTool(t *testing.T) { + inputJSON := []byte(`{ + "model": "gemini-3-flash-preview", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hi"} + ] + } + ], + "tools": [ + { + "name": "json", + "description": "A JSON tool", + "input_schema": { + "type": "object", + "properties": {} + } + } + ], + "tool_choice": {"type": "tool", "name": "json"} + }`) + + output := ConvertClaudeRequestToCLI("gemini-3-flash-preview", inputJSON, false) + + if got := gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.mode").String(); got != "ANY" { + t.Fatalf("Expected request.toolConfig.functionCallingConfig.mode 'ANY', got '%s'", got) + } + allowed := gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Array() + if len(allowed) != 1 || allowed[0].String() != "json" { + t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Raw) + } +} diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go index 53da71f4..b0a6bddd 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go +++ b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go @@ -34,6 +34,11 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo // Model out, _ = sjson.SetBytes(out, "model", modelName) + // Let user-provided generationConfig pass through + if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() { + out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw)) + } + // Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig. // Inline translation-only mapping; capability checks happen later in ApplyThinking. re := gjson.GetBytes(rawJSON, "reasoning_effort") diff --git a/internal/translator/gemini/claude/gemini_claude_request.go b/internal/translator/gemini/claude/gemini_claude_request.go index e882f769..b13955bb 100644 --- a/internal/translator/gemini/claude/gemini_claude_request.go +++ b/internal/translator/gemini/claude/gemini_claude_request.go @@ -9,6 +9,7 @@ import ( "bytes" "strings" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -84,6 +85,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) case "tool_use": functionName := contentResult.Get("name").String() + if toolUseID := contentResult.Get("id").String(); toolUseID != "" { + if derived := toolNameFromClaudeToolUseID(toolUseID); derived != "" { + functionName = derived + } + } functionArgs := contentResult.Get("input").String() argsResult := gjson.Parse(functionArgs) if argsResult.IsObject() && gjson.Valid(functionArgs) { @@ -99,10 +105,9 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) if toolCallID == "" { return true } - funcName := toolCallID - toolCallIDs := strings.Split(toolCallID, "-") - if len(toolCallIDs) > 1 { - funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-") + funcName := toolNameFromClaudeToolUseID(toolCallID) + if funcName == "" { + funcName = toolCallID } responseData := contentResult.Get("content").Raw part := `{"functionResponse":{"name":"","response":{"result":""}}}` @@ -136,6 +141,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) tool, _ = sjson.Delete(tool, "input_examples") tool, _ = sjson.Delete(tool, "type") tool, _ = sjson.Delete(tool, "cache_control") + tool, _ = sjson.Delete(tool, "defer_loading") if gjson.Valid(tool) && gjson.Parse(tool).IsObject() { if !hasTools { out, _ = sjson.SetRaw(out, "tools", `[{"functionDeclarations":[]}]`) @@ -151,7 +157,34 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) } } - // Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled + // tool_choice + toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice") + if toolChoiceResult.Exists() { + toolChoiceType := "" + toolChoiceName := "" + if toolChoiceResult.IsObject() { + toolChoiceType = toolChoiceResult.Get("type").String() + toolChoiceName = toolChoiceResult.Get("name").String() + } else if toolChoiceResult.Type == gjson.String { + toolChoiceType = toolChoiceResult.String() + } + + switch toolChoiceType { + case "auto": + out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "AUTO") + case "none": + out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "NONE") + case "any": + out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "ANY") + case "tool": + out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "ANY") + if toolChoiceName != "" { + out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName}) + } + } + } + + // Map Anthropic thinking -> Gemini thinking config when enabled // Translator only does format conversion, ApplyThinking handles model capability validation. if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() { switch t.Get("type").String() { @@ -161,10 +194,28 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget) out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true) } - case "adaptive": - // Keep adaptive as a high level sentinel; ApplyThinking resolves it - // to model-specific max capability. - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high") + case "adaptive", "auto": + // For adaptive thinking: + // - If output_config.effort is explicitly present, pass through as thinkingLevel. + // - Otherwise, treat it as "enabled with target-model maximum" and emit thinkingBudget=max. + // ApplyThinking handles clamping to target model's supported levels. + effort := "" + if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String { + effort = strings.ToLower(strings.TrimSpace(v.String())) + } + if effort != "" { + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", effort) + } else { + maxBudget := 0 + if mi := registry.LookupModelInfo(modelName, "gemini"); mi != nil && mi.Thinking != nil { + maxBudget = mi.Thinking.Max + } + if maxBudget > 0 { + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", maxBudget) + } else { + out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high") + } + } out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true) } } @@ -183,3 +234,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) return result } + +func toolNameFromClaudeToolUseID(toolUseID string) string { + parts := strings.Split(toolUseID, "-") + if len(parts) <= 1 { + return "" + } + return strings.Join(parts[0:len(parts)-1], "-") +} diff --git a/internal/translator/gemini/claude/gemini_claude_request_test.go b/internal/translator/gemini/claude/gemini_claude_request_test.go new file mode 100644 index 00000000..e242c42c --- /dev/null +++ b/internal/translator/gemini/claude/gemini_claude_request_test.go @@ -0,0 +1,42 @@ +package claude + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertClaudeRequestToGemini_ToolChoice_SpecificTool(t *testing.T) { + inputJSON := []byte(`{ + "model": "gemini-3-flash-preview", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "hi"} + ] + } + ], + "tools": [ + { + "name": "json", + "description": "A JSON tool", + "input_schema": { + "type": "object", + "properties": {} + } + } + ], + "tool_choice": {"type": "tool", "name": "json"} + }`) + + output := ConvertClaudeRequestToGemini("gemini-3-flash-preview", inputJSON, false) + + if got := gjson.GetBytes(output, "toolConfig.functionCallingConfig.mode").String(); got != "ANY" { + t.Fatalf("Expected toolConfig.functionCallingConfig.mode 'ANY', got '%s'", got) + } + allowed := gjson.GetBytes(output, "toolConfig.functionCallingConfig.allowedFunctionNames").Array() + if len(allowed) != 1 || allowed[0].String() != "json" { + t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.GetBytes(output, "toolConfig.functionCallingConfig.allowedFunctionNames").Raw) + } +} diff --git a/internal/translator/gemini/claude/gemini_claude_response.go b/internal/translator/gemini/claude/gemini_claude_response.go index cfc06921..e5adcb5e 100644 --- a/internal/translator/gemini/claude/gemini_claude_response.go +++ b/internal/translator/gemini/claude/gemini_claude_response.go @@ -12,8 +12,8 @@ import ( "fmt" "strings" "sync/atomic" - "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -25,6 +25,8 @@ type Params struct { ResponseType int ResponseIndex int HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output + ToolNameMap map[string]string + SawToolCall bool } // toolUseIDCounter provides a process-wide unique counter for tool use identifiers. @@ -53,6 +55,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR HasFirstResponse: false, ResponseType: 0, ResponseIndex: 0, + ToolNameMap: util.ToolNameMapFromClaudeRequest(originalRequestRawJSON), + SawToolCall: false, } } @@ -66,8 +70,6 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR return []string{} } - // Track whether tools are being used in this response chunk - usedTool := false output := "" // Initialize the streaming session with a message_start event @@ -175,12 +177,13 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR } else if functionCallResult.Exists() { // Handle function/tool calls from the AI model // This processes tool usage requests and formats them for Claude API compatibility - usedTool = true - fcName := functionCallResult.Get("name").String() + (*param).(*Params).SawToolCall = true + upstreamToolName := functionCallResult.Get("name").String() + clientToolName := util.MapToolName((*param).(*Params).ToolNameMap, upstreamToolName) // FIX: Handle streaming split/delta where name might be empty in subsequent chunks. // If we are already in tool use mode and name is empty, treat as continuation (delta). - if (*param).(*Params).ResponseType == 3 && fcName == "" { + if (*param).(*Params).ResponseType == 3 && upstreamToolName == "" { if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { output = output + "event: content_block_delta\n" data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw) @@ -221,8 +224,8 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) - data, _ = sjson.Set(data, "content_block.name", fcName) + data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))) + data, _ = sjson.Set(data, "content_block.name", clientToolName) output = output + fmt.Sprintf("data: %s\n\n\n", data) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { @@ -249,7 +252,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR output = output + `data: ` template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` - if usedTool { + if (*param).(*Params).SawToolCall { template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` } else if finish := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" { template = `{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` @@ -278,10 +281,10 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR // Returns: // - string: A Claude-compatible JSON response. func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { - _ = originalRequestRawJSON _ = requestRawJSON root := gjson.ParseBytes(rawJSON) + toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON) out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` out, _ = sjson.Set(out, "id", root.Get("responseId").String()) @@ -336,11 +339,12 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina flushText() hasToolCall = true - name := functionCall.Get("name").String() + upstreamToolName := functionCall.Get("name").String() + clientToolName := util.MapToolName(toolNameMap, upstreamToolName) toolIDCounter++ toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}` - toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter)) - toolBlock, _ = sjson.Set(toolBlock, "name", name) + toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)) + toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName) inputRaw := "{}" if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() { inputRaw = args.Raw diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go index 5de35681..f18f45be 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go @@ -34,6 +34,11 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) // Model out, _ = sjson.SetBytes(out, "model", modelName) + // Let user-provided generationConfig pass through + if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() { + out, _ = sjson.SetRawBytes(out, "generationConfig", []byte(genConfig.Raw)) + } + // Apply thinking configuration: convert OpenAI reasoning_effort to Gemini thinkingConfig. // Inline translation-only mapping; capability checks happen later in ApplyThinking. re := gjson.GetBytes(rawJSON, "reasoning_effort") diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go index aca01717..143359d6 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go @@ -237,6 +237,33 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte partJSON, _ = sjson.Set(partJSON, "inline_data.data", data) } } + case "input_audio": + audioData := contentItem.Get("data").String() + audioFormat := contentItem.Get("format").String() + if audioData != "" { + audioMimeMap := map[string]string{ + "mp3": "audio/mpeg", + "wav": "audio/wav", + "ogg": "audio/ogg", + "flac": "audio/flac", + "aac": "audio/aac", + "webm": "audio/webm", + "pcm16": "audio/pcm", + "g711_ulaw": "audio/basic", + "g711_alaw": "audio/basic", + } + mimeType := "audio/wav" + if audioFormat != "" { + if mapped, ok := audioMimeMap[audioFormat]; ok { + mimeType = mapped + } else { + mimeType = "audio/" + audioFormat + } + } + partJSON = `{"inline_data":{"mime_type":"","data":""}}` + partJSON, _ = sjson.Set(partJSON, "inline_data.mime_type", mimeType) + partJSON, _ = sjson.Set(partJSON, "inline_data.data", audioData) + } } if partJSON != "" { @@ -354,22 +381,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte funcDecl, _ = sjson.Set(funcDecl, "description", desc.String()) } if params := tool.Get("parameters"); params.Exists() { - // Convert parameter types from OpenAI format to Gemini format - cleaned := params.Raw - // Convert type values to uppercase for Gemini - paramsResult := gjson.Parse(cleaned) - if properties := paramsResult.Get("properties"); properties.Exists() { - properties.ForEach(func(key, value gjson.Result) bool { - if propType := value.Get("type"); propType.Exists() { - upperType := strings.ToUpper(propType.String()) - cleaned, _ = sjson.Set(cleaned, "properties."+key.String()+".type", upperType) - } - return true - }) - } - // Set the overall type to OBJECT - cleaned, _ = sjson.Set(cleaned, "type", "OBJECT") - funcDecl, _ = sjson.SetRaw(funcDecl, "parametersJsonSchema", cleaned) + funcDecl, _ = sjson.SetRaw(funcDecl, "parametersJsonSchema", params.Raw) } geminiTools, _ = sjson.SetRaw(geminiTools, "0.functionDeclarations.-1", funcDecl) diff --git a/internal/translator/openai/claude/openai_claude_request.go b/internal/translator/openai/claude/openai_claude_request.go index acb79a13..b5280af8 100644 --- a/internal/translator/openai/claude/openai_claude_request.go +++ b/internal/translator/openai/claude/openai_claude_request.go @@ -75,10 +75,18 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream out, _ = sjson.Set(out, "reasoning_effort", effort) } } - case "adaptive": - // Claude adaptive means "enable with max capacity"; keep it as highest level - // and let ApplyThinking normalize per target model capability. - out, _ = sjson.Set(out, "reasoning_effort", string(thinking.LevelXHigh)) + case "adaptive", "auto": + // Adaptive thinking can carry an explicit effort in output_config.effort (Claude 4.6). + // Pass through directly; ApplyThinking handles clamping to target model's levels. + effort := "" + if v := root.Get("output_config.effort"); v.Exists() && v.Type == gjson.String { + effort = strings.ToLower(strings.TrimSpace(v.String())) + } + if effort != "" { + out, _ = sjson.Set(out, "reasoning_effort", effort) + } else { + out, _ = sjson.Set(out, "reasoning_effort", string(thinking.LevelXHigh)) + } case "disabled": if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" { out, _ = sjson.Set(out, "reasoning_effort", effort) @@ -175,7 +183,12 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream // Collect tool_result to emit after the main message (ensures tool results follow tool_calls) toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}` toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String()) - toolResultJSON, _ = sjson.Set(toolResultJSON, "content", convertClaudeToolResultContentToString(part.Get("content"))) + toolResultContent, toolResultContentRaw := convertClaudeToolResultContent(part.Get("content")) + if toolResultContentRaw { + toolResultJSON, _ = sjson.SetRaw(toolResultJSON, "content", toolResultContent) + } else { + toolResultJSON, _ = sjson.Set(toolResultJSON, "content", toolResultContent) + } toolResults = append(toolResults, toolResultJSON) } return true @@ -366,21 +379,41 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) { } } -func convertClaudeToolResultContentToString(content gjson.Result) string { +func convertClaudeToolResultContent(content gjson.Result) (string, bool) { if !content.Exists() { - return "" + return "", false } if content.Type == gjson.String { - return content.String() + return content.String(), false } if content.IsArray() { var parts []string + contentJSON := "[]" + hasImagePart := false content.ForEach(func(_, item gjson.Result) bool { switch { case item.Type == gjson.String: - parts = append(parts, item.String()) + text := item.String() + parts = append(parts, text) + textContent := `{"type":"text","text":""}` + textContent, _ = sjson.Set(textContent, "text", text) + contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent) + case item.IsObject() && item.Get("type").String() == "text": + text := item.Get("text").String() + parts = append(parts, text) + textContent := `{"type":"text","text":""}` + textContent, _ = sjson.Set(textContent, "text", text) + contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent) + case item.IsObject() && item.Get("type").String() == "image": + contentItem, ok := convertClaudeContentPart(item) + if ok { + contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem) + hasImagePart = true + } else { + parts = append(parts, item.Raw) + } case item.IsObject() && item.Get("text").Exists() && item.Get("text").Type == gjson.String: parts = append(parts, item.Get("text").String()) default: @@ -389,19 +422,31 @@ func convertClaudeToolResultContentToString(content gjson.Result) string { return true }) + if hasImagePart { + return contentJSON, true + } + joined := strings.Join(parts, "\n\n") if strings.TrimSpace(joined) != "" { - return joined + return joined, false } - return content.Raw + return content.Raw, false } if content.IsObject() { - if text := content.Get("text"); text.Exists() && text.Type == gjson.String { - return text.String() + if content.Get("type").String() == "image" { + contentItem, ok := convertClaudeContentPart(content) + if ok { + contentJSON := "[]" + contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem) + return contentJSON, true + } } - return content.Raw + if text := content.Get("text"); text.Exists() && text.Type == gjson.String { + return text.String(), false + } + return content.Raw, false } - return content.Raw + return content.Raw, false } diff --git a/internal/translator/openai/claude/openai_claude_request_test.go b/internal/translator/openai/claude/openai_claude_request_test.go index d08de1b2..3fd4707f 100644 --- a/internal/translator/openai/claude/openai_claude_request_test.go +++ b/internal/translator/openai/claude/openai_claude_request_test.go @@ -488,6 +488,114 @@ func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) { } } +func TestConvertClaudeRequestToOpenAI_ToolResultTextAndImageContent(t *testing.T) { + inputJSON := `{ + "model": "claude-3-opus", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}} + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "call_1", + "content": [ + {"type": "text", "text": "tool ok"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "iVBORw0KGgoAAAANSUhEUg==" + } + } + ] + } + ] + } + ] + }` + + result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false) + resultJSON := gjson.ParseBytes(result) + messages := resultJSON.Get("messages").Array() + + if len(messages) != 2 { + t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) + } + + toolContent := messages[1].Get("content") + if !toolContent.IsArray() { + t.Fatalf("Expected tool content array, got %s", toolContent.Raw) + } + if got := toolContent.Get("0.type").String(); got != "text" { + t.Fatalf("Expected first tool content type %q, got %q", "text", got) + } + if got := toolContent.Get("0.text").String(); got != "tool ok" { + t.Fatalf("Expected first tool content text %q, got %q", "tool ok", got) + } + if got := toolContent.Get("1.type").String(); got != "image_url" { + t.Fatalf("Expected second tool content type %q, got %q", "image_url", got) + } + if got := toolContent.Get("1.image_url.url").String(); got != "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==" { + t.Fatalf("Unexpected image_url: %q", got) + } +} + +func TestConvertClaudeRequestToOpenAI_ToolResultURLImageOnly(t *testing.T) { + inputJSON := `{ + "model": "claude-3-opus", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}} + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "call_1", + "content": { + "type": "image", + "source": { + "type": "url", + "url": "https://example.com/tool.png" + } + } + } + ] + } + ] + }` + + result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false) + resultJSON := gjson.ParseBytes(result) + messages := resultJSON.Get("messages").Array() + + if len(messages) != 2 { + t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw) + } + + toolContent := messages[1].Get("content") + if !toolContent.IsArray() { + t.Fatalf("Expected tool content array, got %s", toolContent.Raw) + } + if got := toolContent.Get("0.type").String(); got != "image_url" { + t.Fatalf("Expected tool content type %q, got %q", "image_url", got) + } + if got := toolContent.Get("0.image_url.url").String(); got != "https://example.com/tool.png" { + t.Fatalf("Unexpected image_url: %q", got) + } +} + func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) { inputJSON := `{ "model": "claude-3-opus", diff --git a/internal/translator/openai/claude/openai_claude_response.go b/internal/translator/openai/claude/openai_claude_response.go index ca20c848..7bb496a2 100644 --- a/internal/translator/openai/claude/openai_claude_response.go +++ b/internal/translator/openai/claude/openai_claude_response.go @@ -22,9 +22,11 @@ var ( // ConvertOpenAIResponseToAnthropicParams holds parameters for response conversion type ConvertOpenAIResponseToAnthropicParams struct { - MessageID string - Model string - CreatedAt int64 + MessageID string + Model string + CreatedAt int64 + ToolNameMap map[string]string + SawToolCall bool // Content accumulator for streaming ContentAccumulator strings.Builder // Tool calls accumulator for streaming @@ -78,6 +80,8 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR MessageID: "", Model: "", CreatedAt: 0, + ToolNameMap: nil, + SawToolCall: false, ContentAccumulator: strings.Builder{}, ToolCallsAccumulator: nil, TextContentBlockStarted: false, @@ -97,6 +101,10 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR } rawJSON = bytes.TrimSpace(rawJSON[5:]) + if (*param).(*ConvertOpenAIResponseToAnthropicParams).ToolNameMap == nil { + (*param).(*ConvertOpenAIResponseToAnthropicParams).ToolNameMap = util.ToolNameMapFromClaudeRequest(originalRequestRawJSON) + } + // Check if this is the [DONE] marker rawStr := strings.TrimSpace(string(rawJSON)) if rawStr == "[DONE]" { @@ -111,6 +119,16 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR } } +func effectiveOpenAIFinishReason(param *ConvertOpenAIResponseToAnthropicParams) string { + if param == nil { + return "" + } + if param.SawToolCall { + return "tool_calls" + } + return param.FinishReason +} + // convertOpenAIStreamingChunkToAnthropic converts OpenAI streaming chunk to Anthropic streaming events func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAIResponseToAnthropicParams) []string { root := gjson.ParseBytes(rawJSON) @@ -197,6 +215,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI } toolCalls.ForEach(func(_, toolCall gjson.Result) bool { + param.SawToolCall = true index := int(toolCall.Get("index").Int()) blockIndex := param.toolContentBlockIndex(index) @@ -215,7 +234,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI // Handle function name if function := toolCall.Get("function"); function.Exists() { if name := function.Get("name"); name.Exists() { - accumulator.Name = name.String() + accumulator.Name = util.MapToolName(param.ToolNameMap, name.String()) stopThinkingContentBlock(param, &results) @@ -246,7 +265,11 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI // Handle finish_reason (but don't send message_delta/message_stop yet) if finishReason := root.Get("choices.0.finish_reason"); finishReason.Exists() && finishReason.String() != "" { reason := finishReason.String() - param.FinishReason = reason + if param.SawToolCall { + param.FinishReason = "tool_calls" + } else { + param.FinishReason = reason + } // Send content_block_stop for thinking content if needed if param.ThinkingContentBlockStarted { @@ -294,7 +317,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI inputTokens, outputTokens, cachedTokens = extractOpenAIUsage(usage) // Send message_delta with usage messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` - messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason)) + messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param))) messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens) messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens) if cachedTokens > 0 { @@ -348,7 +371,7 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams) // If we haven't sent message_delta yet (no usage info was received), send it now if param.FinishReason != "" && !param.MessageDeltaSent { messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` - messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(param.FinishReason)) + messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param))) results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n") param.MessageDeltaSent = true } @@ -531,10 +554,10 @@ func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results // Returns: // - string: An Anthropic-compatible JSON response. func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string { - _ = originalRequestRawJSON _ = requestRawJSON root := gjson.ParseBytes(rawJSON) + toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON) out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}` out, _ = sjson.Set(out, "id", root.Get("id").String()) out, _ = sjson.Set(out, "model", root.Get("model").String()) @@ -590,7 +613,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina hasToolCall = true toolUse := `{"type":"tool_use","id":"","name":"","input":{}}` toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String()) - toolUse, _ = sjson.Set(toolUse, "name", tc.Get("function.name").String()) + toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String())) argsStr := util.FixJSON(tc.Get("function.arguments").String()) if argsStr != "" && gjson.Valid(argsStr) { @@ -647,7 +670,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina hasToolCall = true toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}` toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String()) - toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String()) + toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String())) argsStr := util.FixJSON(toolCall.Get("function.arguments").String()) if argsStr != "" && gjson.Valid(argsStr) { diff --git a/internal/util/gemini_schema.go b/internal/util/gemini_schema.go index b8d07bf4..8617b846 100644 --- a/internal/util/gemini_schema.go +++ b/internal/util/gemini_schema.go @@ -430,7 +430,7 @@ func removeUnsupportedKeywords(jsonStr string) string { keywords := append(unsupportedConstraints, "$schema", "$defs", "definitions", "const", "$ref", "$id", "additionalProperties", "propertyNames", "patternProperties", // Gemini doesn't support these schema keywords - "enumTitles", "prefill", // Claude/OpenCode schema metadata fields unsupported by Gemini + "enumTitles", "prefill", "deprecated", // Schema metadata fields unsupported by Gemini ) deletePaths := make([]string, 0) diff --git a/internal/util/translator.go b/internal/util/translator.go index 51ecb748..669ba745 100644 --- a/internal/util/translator.go +++ b/internal/util/translator.go @@ -6,6 +6,7 @@ package util import ( "bytes" "fmt" + "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -219,3 +220,54 @@ func FixJSON(input string) string { return out.String() } + +func CanonicalToolName(name string) string { + canonical := strings.TrimSpace(name) + canonical = strings.TrimLeft(canonical, "_") + return strings.ToLower(canonical) +} + +// ToolNameMapFromClaudeRequest returns a canonical-name -> original-name map extracted from a Claude request. +// It is used to restore exact tool name casing for clients that require strict tool name matching (e.g. Claude Code). +func ToolNameMapFromClaudeRequest(rawJSON []byte) map[string]string { + if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) { + return nil + } + + tools := gjson.GetBytes(rawJSON, "tools") + if !tools.Exists() || !tools.IsArray() { + return nil + } + + toolResults := tools.Array() + out := make(map[string]string, len(toolResults)) + tools.ForEach(func(_, tool gjson.Result) bool { + name := strings.TrimSpace(tool.Get("name").String()) + if name == "" { + return true + } + key := CanonicalToolName(name) + if key == "" { + return true + } + if _, exists := out[key]; !exists { + out[key] = name + } + return true + }) + + if len(out) == 0 { + return nil + } + return out +} + +func MapToolName(toolNameMap map[string]string, name string) string { + if name == "" || toolNameMap == nil { + return name + } + if mapped, ok := toolNameMap[CanonicalToolName(name)]; ok && mapped != "" { + return mapped + } + return name +} diff --git a/internal/watcher/clients.go b/internal/watcher/clients.go index 0d0b6fe7..f63223ae 100644 --- a/internal/watcher/clients.go +++ b/internal/watcher/clients.go @@ -369,3 +369,79 @@ func (w *Watcher) persistAuthAsync(message string, paths ...string) { } }() } + +func (w *Watcher) stopServerUpdateTimer() { + w.serverUpdateMu.Lock() + defer w.serverUpdateMu.Unlock() + if w.serverUpdateTimer != nil { + w.serverUpdateTimer.Stop() + w.serverUpdateTimer = nil + } + w.serverUpdatePend = false +} + +func (w *Watcher) triggerServerUpdate(cfg *config.Config) { + if w == nil || w.reloadCallback == nil || cfg == nil { + return + } + if w.stopped.Load() { + return + } + + now := time.Now() + + w.serverUpdateMu.Lock() + if w.serverUpdateLast.IsZero() || now.Sub(w.serverUpdateLast) >= serverUpdateDebounce { + w.serverUpdateLast = now + if w.serverUpdateTimer != nil { + w.serverUpdateTimer.Stop() + w.serverUpdateTimer = nil + } + w.serverUpdatePend = false + w.serverUpdateMu.Unlock() + w.reloadCallback(cfg) + return + } + + if w.serverUpdatePend { + w.serverUpdateMu.Unlock() + return + } + + delay := serverUpdateDebounce - now.Sub(w.serverUpdateLast) + if delay < 10*time.Millisecond { + delay = 10 * time.Millisecond + } + w.serverUpdatePend = true + if w.serverUpdateTimer != nil { + w.serverUpdateTimer.Stop() + w.serverUpdateTimer = nil + } + var timer *time.Timer + timer = time.AfterFunc(delay, func() { + if w.stopped.Load() { + return + } + w.clientsMutex.RLock() + latestCfg := w.config + w.clientsMutex.RUnlock() + + w.serverUpdateMu.Lock() + if w.serverUpdateTimer != timer || !w.serverUpdatePend { + w.serverUpdateMu.Unlock() + return + } + w.serverUpdateTimer = nil + w.serverUpdatePend = false + if latestCfg == nil || w.reloadCallback == nil || w.stopped.Load() { + w.serverUpdateMu.Unlock() + return + } + + w.serverUpdateLast = time.Now() + w.serverUpdateMu.Unlock() + w.reloadCallback(latestCfg) + }) + w.serverUpdateTimer = timer + w.serverUpdateMu.Unlock() +} diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go index b7d537da..7997f04e 100644 --- a/internal/watcher/diff/config_diff.go +++ b/internal/watcher/diff/config_diff.go @@ -304,6 +304,11 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if oldModels.hash != newModels.hash { changes = append(changes, fmt.Sprintf("vertex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count)) } + oldExcluded := SummarizeExcludedModels(o.ExcludedModels) + newExcluded := SummarizeExcludedModels(n.ExcludedModels) + if oldExcluded.hash != newExcluded.hash { + changes = append(changes, fmt.Sprintf("vertex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count)) + } if !equalStringMap(o.Headers, n.Headers) { changes = append(changes, fmt.Sprintf("vertex[%d].headers: updated", i)) } diff --git a/internal/watcher/synthesizer/config.go b/internal/watcher/synthesizer/config.go index 69194efc..52ae9a48 100644 --- a/internal/watcher/synthesizer/config.go +++ b/internal/watcher/synthesizer/config.go @@ -315,7 +315,7 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor CreatedAt: now, UpdatedAt: now, } - ApplyAuthExcludedModelsMeta(a, cfg, nil, "apikey") + ApplyAuthExcludedModelsMeta(a, cfg, compat.ExcludedModels, "apikey") out = append(out, a) } return out diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index 50f3a2ab..02a0cefa 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "runtime" "strconv" "strings" "time" @@ -92,6 +93,9 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) [] id = rel } } + if runtime.GOOS == "windows" { + id = strings.ToLower(id) + } proxyURL := "" if p, ok := metadata["proxy_url"].(string); ok { diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 8180e474..cf890a4c 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -6,6 +6,7 @@ import ( "context" "strings" "sync" + "sync/atomic" "time" "github.com/fsnotify/fsnotify" @@ -35,6 +36,11 @@ type Watcher struct { clientsMutex sync.RWMutex configReloadMu sync.Mutex configReloadTimer *time.Timer + serverUpdateMu sync.Mutex + serverUpdateTimer *time.Timer + serverUpdateLast time.Time + serverUpdatePend bool + stopped atomic.Bool reloadCallback func(*config.Config) watcher *fsnotify.Watcher lastAuthHashes map[string]string @@ -77,6 +83,7 @@ const ( replaceCheckDelay = 50 * time.Millisecond configReloadDebounce = 150 * time.Millisecond authRemoveDebounceWindow = 1 * time.Second + serverUpdateDebounce = 1 * time.Second ) // NewWatcher creates a new file watcher instance @@ -116,8 +123,10 @@ func (w *Watcher) Start(ctx context.Context) error { // Stop stops the file watcher func (w *Watcher) Stop() error { + w.stopped.Store(true) w.stopDispatch() w.stopConfigReloadTimer() + w.stopServerUpdateTimer() return w.watcher.Close() } diff --git a/internal/watcher/watcher_test.go b/internal/watcher/watcher_test.go index 27d28419..00a7a143 100644 --- a/internal/watcher/watcher_test.go +++ b/internal/watcher/watcher_test.go @@ -543,6 +543,46 @@ func TestAuthSliceToMap(t *testing.T) { } } +func TestTriggerServerUpdateCancelsPendingTimerOnImmediate(t *testing.T) { + tmpDir := t.TempDir() + cfg := &config.Config{AuthDir: tmpDir} + + var reloads int32 + w := &Watcher{ + reloadCallback: func(*config.Config) { + atomic.AddInt32(&reloads, 1) + }, + } + w.SetConfig(cfg) + + w.serverUpdateMu.Lock() + w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce - 100*time.Millisecond)) + w.serverUpdateMu.Unlock() + w.triggerServerUpdate(cfg) + + if got := atomic.LoadInt32(&reloads); got != 0 { + t.Fatalf("expected no immediate reload, got %d", got) + } + + w.serverUpdateMu.Lock() + if !w.serverUpdatePend || w.serverUpdateTimer == nil { + w.serverUpdateMu.Unlock() + t.Fatal("expected a pending server update timer") + } + w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce + 10*time.Millisecond)) + w.serverUpdateMu.Unlock() + + w.triggerServerUpdate(cfg) + if got := atomic.LoadInt32(&reloads); got != 1 { + t.Fatalf("expected immediate reload once, got %d", got) + } + + time.Sleep(250 * time.Millisecond) + if got := atomic.LoadInt32(&reloads); got != 1 { + t.Fatalf("expected pending timer to be cancelled, got %d reloads", got) + } +} + func TestShouldDebounceRemove(t *testing.T) { w := &Watcher{} path := filepath.Clean("test.json") diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index f2d44f05..5e2beb94 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -26,7 +26,6 @@ const ( wsRequestTypeAppend = "response.append" wsEventTypeError = "error" wsEventTypeCompleted = "response.completed" - wsEventTypeDone = "response.done" wsDoneMarker = "[DONE]" wsTurnStateHeader = "x-codex-turn-state" wsRequestBodyKey = "REQUEST_BODY_OVERRIDE" @@ -469,9 +468,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( for i := range payloads { eventType := gjson.GetBytes(payloads[i], "type").String() if eventType == wsEventTypeCompleted { - // log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone) - payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone) - completed = true completedOutput = responseCompletedOutputFromPayload(payloads[i]) } diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 9b6cec78..a04bb18c 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -2,12 +2,15 @@ package openai import ( "bytes" + "errors" "net/http" "net/http/httptest" "strings" "testing" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/tidwall/gjson" ) @@ -247,3 +250,79 @@ func TestSetWebsocketRequestBody(t *testing.T) { t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body") } } + +func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + + serverErrCh := make(chan error, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil) + if err != nil { + serverErrCh <- err + return + } + defer func() { + errClose := conn.Close() + if errClose != nil { + serverErrCh <- errClose + } + }() + + ctx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ctx.Request = r + + data := make(chan []byte, 1) + errCh := make(chan *interfaces.ErrorMessage) + data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n") + close(data) + close(errCh) + + var bodyLog strings.Builder + completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( + ctx, + conn, + func(...interface{}) {}, + data, + errCh, + &bodyLog, + "session-1", + ) + if err != nil { + serverErrCh <- err + return + } + if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" { + serverErrCh <- errors.New("completed output not captured") + return + } + serverErrCh <- nil + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + errClose := conn.Close() + if errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + _, payload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read websocket message: %v", errReadMessage) + } + if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted { + t.Fatalf("payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted) + } + if strings.Contains(string(payload), "response.done") { + t.Fatalf("payload unexpectedly rewrote completed event: %s", payload) + } + + if errServer := <-serverErrCh; errServer != nil { + t.Fatalf("server error: %v", errServer) + } +} diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index c424a89b..987d305e 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "path/filepath" + "runtime" "strings" "sync" "time" @@ -257,14 +258,17 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, } func (s *FileTokenStore) idFor(path, baseDir string) string { - if baseDir == "" { - return path + id := path + if baseDir != "" { + if rel, errRel := filepath.Rel(baseDir, path); errRel == nil && rel != "" { + id = rel + } } - rel, err := filepath.Rel(baseDir, path) - if err != nil { - return path + // On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths. + if runtime.GOOS == "windows" { + id = strings.ToLower(id) } - return rel + return id } func (s *FileTokenStore) resolveAuthPath(auth *cliproxyauth.Auth) (string, error) { diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 3434b7a7..ae5b745c 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -463,9 +463,14 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) { return nil, nil } m.mu.Lock() - if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == "" { - auth.Index = existing.Index - auth.indexAssigned = existing.indexAssigned + if existing, ok := m.auths[auth.ID]; ok && existing != nil { + if !auth.indexAssigned && auth.Index == "" { + auth.Index = existing.Index + auth.indexAssigned = existing.indexAssigned + } + if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 { + auth.ModelStates = existing.ModelStates + } } auth.EnsureIndex() m.auths[auth.ID] = auth.Clone() diff --git a/sdk/cliproxy/auth/conductor_overrides_test.go b/sdk/cliproxy/auth/conductor_overrides_test.go index e5792c68..7aca49da 100644 --- a/sdk/cliproxy/auth/conductor_overrides_test.go +++ b/sdk/cliproxy/auth/conductor_overrides_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/google/uuid" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" ) @@ -115,8 +117,19 @@ func newCredentialRetryLimitTestManager(t *testing.T, maxRetryCredentials int) ( executor := &credentialRetryLimitExecutor{id: "claude"} m.RegisterExecutor(executor) - auth1 := &Auth{ID: "auth-1", Provider: "claude"} - auth2 := &Auth{ID: "auth-2", Provider: "claude"} + baseID := uuid.NewString() + auth1 := &Auth{ID: baseID + "-auth-1", Provider: "claude"} + auth2 := &Auth{ID: baseID + "-auth-2", Provider: "claude"} + + // Auth selection requires that the global model registry knows each credential supports the model. + reg := registry.GetGlobalRegistry() + reg.RegisterClient(auth1.ID, "claude", []*registry.ModelInfo{{ID: "test-model"}}) + reg.RegisterClient(auth2.ID, "claude", []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + reg.UnregisterClient(auth1.ID) + reg.UnregisterClient(auth2.ID) + }) + if _, errRegister := m.Register(context.Background(), auth1); errRegister != nil { t.Fatalf("register auth1: %v", errRegister) } diff --git a/sdk/cliproxy/auth/conductor_update_test.go b/sdk/cliproxy/auth/conductor_update_test.go new file mode 100644 index 00000000..f058f517 --- /dev/null +++ b/sdk/cliproxy/auth/conductor_update_test.go @@ -0,0 +1,49 @@ +package auth + +import ( + "context" + "testing" +) + +func TestManager_Update_PreservesModelStates(t *testing.T) { + m := NewManager(nil, nil, nil) + + model := "test-model" + backoffLevel := 7 + + if _, errRegister := m.Register(context.Background(), &Auth{ + ID: "auth-1", + Provider: "claude", + Metadata: map[string]any{"k": "v"}, + ModelStates: map[string]*ModelState{ + model: { + Quota: QuotaState{BackoffLevel: backoffLevel}, + }, + }, + }); errRegister != nil { + t.Fatalf("register auth: %v", errRegister) + } + + if _, errUpdate := m.Update(context.Background(), &Auth{ + ID: "auth-1", + Provider: "claude", + Metadata: map[string]any{"k": "v2"}, + }); errUpdate != nil { + t.Fatalf("update auth: %v", errUpdate) + } + + updated, ok := m.GetByID("auth-1") + if !ok || updated == nil { + t.Fatalf("expected auth to be present") + } + if len(updated.ModelStates) == 0 { + t.Fatalf("expected ModelStates to be preserved") + } + state := updated.ModelStates[model] + if state == nil { + t.Fatalf("expected model state to be present") + } + if state.Quota.BackoffLevel != backoffLevel { + t.Fatalf("expected BackoffLevel to be %d, got %d", backoffLevel, state.Quota.BackoffLevel) + } +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 4be83816..6124f8b1 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -290,6 +290,9 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A auth.CreatedAt = existing.CreatedAt auth.LastRefreshedAt = existing.LastRefreshedAt auth.NextRefreshAfter = existing.NextRefreshAfter + if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 { + auth.ModelStates = existing.ModelStates + } op = "update" _, err = s.coreManager.Update(ctx, auth) } else { @@ -788,10 +791,13 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { case "vertex": // Vertex AI Gemini supports the same model identifiers as Gemini. models = registry.GetGeminiVertexModels() - if authKind == "apikey" { - if entry := s.resolveConfigVertexCompatKey(a); entry != nil && len(entry.Models) > 0 { + if entry := s.resolveConfigVertexCompatKey(a); entry != nil { + if len(entry.Models) > 0 { models = buildVertexCompatConfigModels(entry) } + if authKind == "apikey" { + excluded = entry.ExcludedModels + } } models = applyExcludedModels(models, excluded) case "gemini-cli": diff --git a/test/thinking_conversion_test.go b/test/thinking_conversion_test.go index 781a1667..7d9b7b86 100644 --- a/test/thinking_conversion_test.go +++ b/test/thinking_conversion_test.go @@ -34,6 +34,8 @@ type thinkingTestCase struct { inputJSON string expectField string expectValue string + expectField2 string + expectValue2 string includeThoughts string expectErr bool } @@ -384,15 +386,17 @@ func TestThinkingE2EMatrix_Suffix(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 30: Effort xhigh → not in low/high → error + // Case 30: Effort xhigh → clamped to high { - name: "30", - from: "openai", - to: "gemini", - model: "gemini-mixed-model(xhigh)", - inputJSON: `{"model":"gemini-mixed-model(xhigh)","messages":[{"role":"user","content":"hi"}]}`, - expectField: "", - expectErr: true, + name: "30", + from: "openai", + to: "gemini", + model: "gemini-mixed-model(xhigh)", + inputJSON: `{"model":"gemini-mixed-model(xhigh)","messages":[{"role":"user","content":"hi"}]}`, + expectField: "generationConfig.thinkingConfig.thinkingLevel", + expectValue: "high", + includeThoughts: "true", + expectErr: false, }, // Case 31: Effort none → clamped to low (min supported) → includeThoughts=false { @@ -1666,15 +1670,17 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // Case 30: reasoning_effort=xhigh → error (not in low/high) + // Case 30: reasoning_effort=xhigh → clamped to high { - name: "30", - from: "openai", - to: "gemini", - model: "gemini-mixed-model", - inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`, - expectField: "", - expectErr: true, + name: "30", + from: "openai", + to: "gemini", + model: "gemini-mixed-model", + inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`, + expectField: "generationConfig.thinkingConfig.thinkingLevel", + expectValue: "high", + includeThoughts: "true", + expectErr: false, }, // Case 31: reasoning_effort=none → clamped to low → includeThoughts=false { @@ -2590,9 +2596,8 @@ func TestThinkingE2EMatrix_Body(t *testing.T) { runThinkingTests(t, cases) } -// TestThinkingE2EClaudeAdaptive_Body tests Claude thinking.type=adaptive extended body-only cases. -// These cases validate that adaptive means "thinking enabled without explicit budget", and -// cross-protocol conversion should resolve to target-model maximum thinking capability. +// TestThinkingE2EClaudeAdaptive_Body covers Group 3 cases in docs/thinking-e2e-test-cases.md. +// It focuses on Claude 4.6 adaptive thinking and effort/level cross-protocol semantics (body-only). func TestThinkingE2EClaudeAdaptive_Body(t *testing.T) { reg := registry.GetGlobalRegistry() uid := fmt.Sprintf("thinking-e2e-claude-adaptive-%d", time.Now().UnixNano()) @@ -2601,32 +2606,347 @@ func TestThinkingE2EClaudeAdaptive_Body(t *testing.T) { defer reg.UnregisterClient(uid) cases := []thinkingTestCase{ - // A1: Claude adaptive to OpenAI level model -> highest supported level + // A subgroup: OpenAI -> Claude (reasoning_effort -> output_config.effort) { name: "A1", + from: "openai", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"minimal"}`, + expectField: "output_config.effort", + expectValue: "low", + expectErr: false, + }, + { + name: "A2", + from: "openai", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"low"}`, + expectField: "output_config.effort", + expectValue: "low", + expectErr: false, + }, + { + name: "A3", + from: "openai", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"medium"}`, + expectField: "output_config.effort", + expectValue: "medium", + expectErr: false, + }, + { + name: "A4", + from: "openai", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"high"}`, + expectField: "output_config.effort", + expectValue: "high", + expectErr: false, + }, + { + name: "A5", + from: "openai", + to: "claude", + model: "claude-opus-4-6-model", + inputJSON: `{"model":"claude-opus-4-6-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`, + expectField: "output_config.effort", + expectValue: "max", + expectErr: false, + }, + { + name: "A6", + from: "openai", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`, + expectField: "output_config.effort", + expectValue: "high", + expectErr: false, + }, + { + name: "A7", + from: "openai", + to: "claude", + model: "claude-opus-4-6-model", + inputJSON: `{"model":"claude-opus-4-6-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"max"}`, + expectField: "output_config.effort", + expectValue: "max", + expectErr: false, + }, + { + name: "A8", + from: "openai", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"max"}`, + expectField: "output_config.effort", + expectValue: "high", + expectErr: false, + }, + + // B subgroup: Gemini -> Claude (thinkingLevel/thinkingBudget -> output_config.effort) + { + name: "B1", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingLevel":"minimal"}}}`, + expectField: "output_config.effort", + expectValue: "low", + expectErr: false, + }, + { + name: "B2", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingLevel":"low"}}}`, + expectField: "output_config.effort", + expectValue: "low", + expectErr: false, + }, + { + name: "B3", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingLevel":"medium"}}}`, + expectField: "output_config.effort", + expectValue: "medium", + expectErr: false, + }, + { + name: "B4", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingLevel":"high"}}}`, + expectField: "output_config.effort", + expectValue: "high", + expectErr: false, + }, + { + name: "B5", + from: "gemini", + to: "claude", + model: "claude-opus-4-6-model", + inputJSON: `{"model":"claude-opus-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingLevel":"xhigh"}}}`, + expectField: "output_config.effort", + expectValue: "max", + expectErr: false, + }, + { + name: "B6", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingLevel":"xhigh"}}}`, + expectField: "output_config.effort", + expectValue: "high", + expectErr: false, + }, + { + name: "B7", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":512}}}`, + expectField: "output_config.effort", + expectValue: "low", + expectErr: false, + }, + { + name: "B8", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":1024}}}`, + expectField: "output_config.effort", + expectValue: "low", + expectErr: false, + }, + { + name: "B9", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}`, + expectField: "output_config.effort", + expectValue: "medium", + expectErr: false, + }, + { + name: "B10", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":24576}}}`, + expectField: "output_config.effort", + expectValue: "high", + expectErr: false, + }, + { + name: "B11", + from: "gemini", + to: "claude", + model: "claude-opus-4-6-model", + inputJSON: `{"model":"claude-opus-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":32768}}}`, + expectField: "output_config.effort", + expectValue: "max", + expectErr: false, + }, + { + name: "B12", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":32768}}}`, + expectField: "output_config.effort", + expectValue: "high", + expectErr: false, + }, + { + name: "B13", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":0}}}`, + expectField: "thinking.type", + expectValue: "disabled", + expectErr: false, + }, + { + name: "B14", + from: "gemini", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}`, + expectField: "output_config.effort", + expectValue: "high", + expectErr: false, + }, + + // C subgroup: Claude adaptive + effort cross-protocol conversion + { + name: "C1", from: "claude", to: "openai", model: "level-model", - inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`, + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"minimal"}}`, + expectField: "reasoning_effort", + expectValue: "minimal", + expectErr: false, + }, + { + name: "C2", + from: "claude", + to: "openai", + model: "level-model", + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"low"}}`, + expectField: "reasoning_effort", + expectValue: "low", + expectErr: false, + }, + { + name: "C3", + from: "claude", + to: "openai", + model: "level-model", + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"medium"}}`, + expectField: "reasoning_effort", + expectValue: "medium", + expectErr: false, + }, + { + name: "C4", + from: "claude", + to: "openai", + model: "level-model", + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`, expectField: "reasoning_effort", expectValue: "high", expectErr: false, }, - // A2: Claude adaptive to Gemini level subset model -> highest supported level { - name: "A2", + name: "C5", + from: "claude", + to: "openai", + model: "level-model", + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"xhigh"}}`, + expectField: "reasoning_effort", + expectValue: "high", + expectErr: false, + }, + { + name: "C6", + from: "claude", + to: "openai", + model: "level-model", + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"max"}}`, + expectField: "reasoning_effort", + expectValue: "high", + expectErr: false, + }, + { + name: "C7", + from: "claude", + to: "openai", + model: "no-thinking-model", + inputJSON: `{"model":"no-thinking-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`, + expectField: "", + expectErr: false, + }, + + { + name: "C8", from: "claude", to: "gemini", model: "level-subset-model", - inputJSON: `{"model":"level-subset-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`, + inputJSON: `{"model":"level-subset-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`, expectField: "generationConfig.thinkingConfig.thinkingLevel", expectValue: "high", includeThoughts: "true", expectErr: false, }, - // A3: Claude adaptive to Gemini budget model -> max budget { - name: "A3", + name: "C9", + from: "claude", + to: "gemini", + model: "gemini-budget-model", + inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"low"}}`, + expectField: "generationConfig.thinkingConfig.thinkingBudget", + expectValue: "1024", + includeThoughts: "true", + expectErr: false, + }, + { + name: "C10", + from: "claude", + to: "gemini", + model: "gemini-budget-model", + inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"medium"}}`, + expectField: "generationConfig.thinkingConfig.thinkingBudget", + expectValue: "8192", + includeThoughts: "true", + expectErr: false, + }, + { + name: "C11", + from: "claude", + to: "gemini", + model: "gemini-budget-model", + inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`, + expectField: "generationConfig.thinkingConfig.thinkingBudget", + expectValue: "20000", + includeThoughts: "true", + expectErr: false, + }, + { + name: "C12", from: "claude", to: "gemini", model: "gemini-budget-model", @@ -2636,32 +2956,91 @@ func TestThinkingE2EClaudeAdaptive_Body(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // A4: Claude adaptive to Gemini mixed model -> highest supported level { - name: "A4", + name: "C13", from: "claude", to: "gemini", model: "gemini-mixed-model", - inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`, + inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`, expectField: "generationConfig.thinkingConfig.thinkingLevel", expectValue: "high", includeThoughts: "true", expectErr: false, }, - // A5: Claude adaptive passthrough for same protocol + { - name: "A5", + name: "C14", from: "claude", - to: "claude", - model: "claude-budget-model", - inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`, - expectField: "thinking.type", - expectValue: "adaptive", + to: "codex", + model: "level-model", + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"minimal"}}`, + expectField: "reasoning.effort", + expectValue: "minimal", expectErr: false, }, - // A6: Claude adaptive to Antigravity budget model -> max budget { - name: "A6", + name: "C15", + from: "claude", + to: "codex", + model: "level-model", + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"low"}}`, + expectField: "reasoning.effort", + expectValue: "low", + expectErr: false, + }, + { + name: "C16", + from: "claude", + to: "codex", + model: "level-model", + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`, + expectField: "reasoning.effort", + expectValue: "high", + expectErr: false, + }, + { + name: "C17", + from: "claude", + to: "codex", + model: "level-model", + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"xhigh"}}`, + expectField: "reasoning.effort", + expectValue: "high", + expectErr: false, + }, + { + name: "C18", + from: "claude", + to: "codex", + model: "level-model", + inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"max"}}`, + expectField: "reasoning.effort", + expectValue: "high", + expectErr: false, + }, + + { + name: "C19", + from: "claude", + to: "iflow", + model: "glm-test", + inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"minimal"}}`, + expectField: "chat_template_kwargs.enable_thinking", + expectValue: "true", + expectErr: false, + }, + { + name: "C20", + from: "claude", + to: "iflow", + model: "minimax-test", + inputJSON: `{"model":"minimax-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`, + expectField: "reasoning_split", + expectValue: "true", + expectErr: false, + }, + { + name: "C21", from: "claude", to: "antigravity", model: "antigravity-budget-model", @@ -2671,48 +3050,66 @@ func TestThinkingE2EClaudeAdaptive_Body(t *testing.T) { includeThoughts: "true", expectErr: false, }, - // A7: Claude adaptive to iFlow GLM -> enabled boolean + { - name: "A7", - from: "claude", - to: "iflow", - model: "glm-test", - inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`, - expectField: "chat_template_kwargs.enable_thinking", - expectValue: "true", - expectErr: false, + name: "C22", + from: "claude", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"medium"}}`, + expectField: "thinking.type", + expectValue: "adaptive", + expectField2: "output_config.effort", + expectValue2: "medium", + expectErr: false, }, - // A8: Claude adaptive to iFlow MiniMax -> enabled boolean { - name: "A8", - from: "claude", - to: "iflow", - model: "minimax-test", - inputJSON: `{"model":"minimax-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`, - expectField: "reasoning_split", - expectValue: "true", - expectErr: false, + name: "C23", + from: "claude", + to: "claude", + model: "claude-opus-4-6-model", + inputJSON: `{"model":"claude-opus-4-6-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"max"}}`, + expectField: "thinking.type", + expectValue: "adaptive", + expectField2: "output_config.effort", + expectValue2: "max", + expectErr: false, }, - // A9: Claude adaptive to Codex level model -> highest supported level { - name: "A9", - from: "claude", - to: "codex", - model: "level-model", - inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`, - expectField: "reasoning.effort", - expectValue: "high", - expectErr: false, + name: "C24", + from: "claude", + to: "claude", + model: "claude-opus-4-6-model", + inputJSON: `{"model":"claude-opus-4-6-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"xhigh"}}`, + expectErr: true, }, - // A10: Claude adaptive on non-thinking model should still be stripped { - name: "A10", - from: "claude", - to: "openai", - model: "no-thinking-model", - inputJSON: `{"model":"no-thinking-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}}`, - expectField: "", - expectErr: false, + name: "C25", + from: "claude", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`, + expectField: "thinking.type", + expectValue: "adaptive", + expectField2: "output_config.effort", + expectValue2: "high", + expectErr: false, + }, + { + name: "C26", + from: "claude", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"max"}}`, + expectErr: true, + }, + { + name: "C27", + from: "claude", + to: "claude", + model: "claude-sonnet-4-6-model", + inputJSON: `{"model":"claude-sonnet-4-6-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"},"output_config":{"effort":"xhigh"}}`, + expectErr: true, }, } @@ -2767,6 +3164,29 @@ func getTestModels() []*registry.ModelInfo { DisplayName: "Claude Budget Model", Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false}, }, + { + ID: "claude-sonnet-4-6-model", + Object: "model", + Created: 1771372800, // 2026-02-17 + OwnedBy: "anthropic", + Type: "claude", + DisplayName: "Claude 4.6 Sonnet", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high"}}, + }, + { + ID: "claude-opus-4-6-model", + Object: "model", + Created: 1770318000, // 2026-02-05 + OwnedBy: "anthropic", + Type: "claude", + DisplayName: "Claude 4.6 Opus", + Description: "Premium model combining maximum intelligence with practical performance", + ContextLength: 1000000, + MaxCompletionTokens: 128000, + Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high", "max"}}, + }, { ID: "antigravity-budget-model", Object: "model", @@ -2879,17 +3299,23 @@ func runThinkingTests(t *testing.T, cases []thinkingTestCase) { return } - val := gjson.GetBytes(body, tc.expectField) - if !val.Exists() { - t.Fatalf("expected field %s not found, body=%s", tc.expectField, string(body)) + assertField := func(fieldPath, expected string) { + val := gjson.GetBytes(body, fieldPath) + if !val.Exists() { + t.Fatalf("expected field %s not found, body=%s", fieldPath, string(body)) + } + actualValue := val.String() + if val.Type == gjson.Number { + actualValue = fmt.Sprintf("%d", val.Int()) + } + if actualValue != expected { + t.Fatalf("field %s: expected %q, got %q, body=%s", fieldPath, expected, actualValue, string(body)) + } } - actualValue := val.String() - if val.Type == gjson.Number { - actualValue = fmt.Sprintf("%d", val.Int()) - } - if actualValue != tc.expectValue { - t.Fatalf("field %s: expected %q, got %q, body=%s", tc.expectField, tc.expectValue, actualValue, string(body)) + assertField(tc.expectField, tc.expectValue) + if tc.expectField2 != "" { + assertField(tc.expectField2, tc.expectValue2) } if tc.includeThoughts != "" && (tc.to == "gemini" || tc.to == "gemini-cli" || tc.to == "antigravity") {