diff --git a/cmd/server/main.go b/cmd/server/main.go index 684d9295..7353c7d9 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -58,6 +58,7 @@ func main() { // Command-line flags to control the application's behavior. var login bool var codexLogin bool + var codexDeviceLogin bool var claudeLogin bool var qwenLogin bool var iflowLogin bool @@ -76,6 +77,7 @@ func main() { // Define command-line flags for different operation modes. flag.BoolVar(&login, "login", false, "Login Google Account") flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth") + flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow") flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth") flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth") flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth") @@ -467,6 +469,9 @@ func main() { } else if codexLogin { // Handle Codex login cmd.DoCodexLogin(cfg, options) + } else if codexDeviceLogin { + // Handle Codex device-code login + cmd.DoCodexDeviceLogin(cfg, options) } else if claudeLogin { // Handle Claude login cmd.DoClaudeLogin(cfg, options) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index e133a436..10edfa29 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -408,6 +408,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H { if !auth.LastRefreshedAt.IsZero() { entry["last_refresh"] = auth.LastRefreshedAt } + if !auth.NextRetryAfter.IsZero() { + entry["next_retry_after"] = auth.NextRetryAfter + } if path != "" { entry["path"] = path entry["source"] = "file" @@ -947,11 +950,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s if store == nil { return "", fmt.Errorf("token store unavailable") } + if h.postAuthHook != nil { + if err := 
h.postAuthHook(ctx, record); err != nil { + return "", fmt.Errorf("post-auth hook failed: %w", err) + } + } return store.Save(ctx, record) } func (h *Handler) RequestAnthropicToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Claude authentication...") @@ -1096,6 +1105,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{}) ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient) @@ -1354,6 +1364,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { func (h *Handler) RequestCodexToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Codex authentication...") @@ -1499,6 +1510,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { func (h *Handler) RequestAntigravityToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Antigravity authentication...") @@ -1663,6 +1675,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { func (h *Handler) RequestQwenToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Qwen authentication...") @@ -1718,6 +1731,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { func (h *Handler) RequestKimiToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing Kimi authentication...") @@ -1794,6 +1808,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) { func (h *Handler) RequestIFlowToken(c *gin.Context) { ctx := context.Background() + ctx = PopulateAuthContext(ctx, c) fmt.Println("Initializing iFlow authentication...") @@ -2412,3 +2427,12 @@ func (h *Handler) GetAuthStatus(c *gin.Context) { } c.JSON(http.StatusOK, 
gin.H{"status": "wait"}) } + +// PopulateAuthContext extracts request info and adds it to the context +func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context { + info := &coreauth.RequestInfo{ + Query: c.Request.URL.Query(), + Headers: c.Request.Header, + } + return coreauth.WithRequestInfo(ctx, info) +} diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go index 613c9841..45786b9d 100644 --- a/internal/api/handlers/management/handler.go +++ b/internal/api/handlers/management/handler.go @@ -47,6 +47,7 @@ type Handler struct { allowRemoteOverride bool envSecret string logDir string + postAuthHook coreauth.PostAuthHook } // NewHandler creates a new management handler instance. @@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) { h.logDir = dir } +// SetPostAuthHook registers a hook to be called after auth record creation but before persistence. +func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) { + h.postAuthHook = hook +} + // Middleware enforces access control for management endpoints. // All requests (local and remote) require a valid management key. // Additionally, remote access requires allow-remote-management=true. diff --git a/internal/api/server.go b/internal/api/server.go index 76e9a33a..a7aef0aa 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -51,6 +51,7 @@ type serverOptionConfig struct { keepAliveEnabled bool keepAliveTimeout time.Duration keepAliveOnTimeout func() + postAuthHook auth.PostAuthHook } // ServerOption customises HTTP server construction. @@ -111,6 +112,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque } } +// WithPostAuthHook registers a hook to be called after auth record creation. +func WithPostAuthHook(hook auth.PostAuthHook) ServerOption { + return func(cfg *serverOptionConfig) { + cfg.postAuthHook = hook + } +} + // Server represents the main API server. 
// It encapsulates the Gin engine, HTTP server, handlers, and configuration. type Server struct { @@ -262,6 +270,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk } logDir := logging.ResolveLogDirectory(cfg) s.mgmt.SetLogDirectory(logDir) + if optionState.postAuthHook != nil { + s.mgmt.SetPostAuthHook(optionState.postAuthHook) + } s.localPassword = optionState.localPassword // Setup routes diff --git a/internal/auth/claude/token.go b/internal/auth/claude/token.go index cda10d58..6ebb0f2f 100644 --- a/internal/auth/claude/token.go +++ b/internal/auth/claude/token.go @@ -36,11 +36,21 @@ type ClaudeTokenStorage struct { // Expire is the timestamp when the current access token expires. Expire string `json:"expired"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // SaveTokenToFile serializes the Claude token storage to a JSON file. // This method creates the necessary directory structure and writes the token // data in JSON format to the specified file path for persistent storage. +// It merges any injected metadata into the top-level JSON object. 
// // Parameters: // - authFilePath: The full path where the token file should be saved @@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error { _ = f.Close() }() + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } + // Encode and write the token data as JSON - if err = json.NewEncoder(f).Encode(ts); err != nil { + if err = json.NewEncoder(f).Encode(data); err != nil { return fmt.Errorf("failed to write token to file: %w", err) } return nil diff --git a/internal/auth/codex/openai_auth.go b/internal/auth/codex/openai_auth.go index 89deeadb..64bc00a6 100644 --- a/internal/auth/codex/openai_auth.go +++ b/internal/auth/codex/openai_auth.go @@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, // It performs an HTTP POST request to the OpenAI token endpoint with the provided // authorization code and PKCE verifier. func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) { + return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes) +} + +// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using +// a caller-provided redirect URI. This supports alternate auth flows such as device +// login while preserving the existing token parsing and storage behavior. 
+func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) { if pkceCodes == nil { return nil, fmt.Errorf("PKCE codes are required for token exchange") } + if strings.TrimSpace(redirectURI) == "" { + return nil, fmt.Errorf("redirect URI is required for token exchange") + } // Prepare token exchange request data := url.Values{ "grant_type": {"authorization_code"}, "client_id": {ClientID}, "code": {code}, - "redirect_uri": {RedirectURI}, + "redirect_uri": {strings.TrimSpace(redirectURI)}, "code_verifier": {pkceCodes.CodeVerifier}, } @@ -266,6 +276,10 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str if err == nil { return tokenData, nil } + if isNonRetryableRefreshErr(err) { + log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err) + return nil, err + } lastErr = err log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err) @@ -274,6 +288,14 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr) } +func isNonRetryableRefreshErr(err error) bool { + if err == nil { + return false + } + raw := strings.ToLower(err.Error()) + return strings.Contains(raw, "refresh_token_reused") +} + // UpdateTokenStorage updates an existing CodexTokenStorage with new token data. // This is typically called after a successful token refresh to persist the new credentials. 
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) { diff --git a/internal/auth/codex/openai_auth_test.go b/internal/auth/codex/openai_auth_test.go new file mode 100644 index 00000000..3327eb4a --- /dev/null +++ b/internal/auth/codex/openai_auth_test.go @@ -0,0 +1,44 @@ +package codex + +import ( + "context" + "io" + "net/http" + "strings" + "sync/atomic" + "testing" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) { + var calls int32 + auth := &CodexAuth{ + httpClient: &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + atomic.AddInt32(&calls, 1) + return &http.Response{ + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)), + Header: make(http.Header), + Request: req, + }, nil + }), + }, + } + + _, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3) + if err == nil { + t.Fatalf("expected error for non-retryable refresh failure") + } + if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") { + t.Fatalf("expected refresh_token_reused in error, got: %v", err) + } + if got := atomic.LoadInt32(&calls); got != 1 { + t.Fatalf("expected 1 refresh attempt, got %d", got) + } +} diff --git a/internal/auth/codex/token.go b/internal/auth/codex/token.go index e93fc417..7f032071 100644 --- a/internal/auth/codex/token.go +++ b/internal/auth/codex/token.go @@ -32,11 +32,21 @@ type CodexTokenStorage struct { Type string `json:"type"` // Expire is the timestamp when the current access token expires. Expire string `json:"expired"` + + // Metadata holds arbitrary key-value pairs injected via hooks. 
+ // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // SaveTokenToFile serializes the Codex token storage to a JSON file. // This method creates the necessary directory structure and writes the token // data in JSON format to the specified file path for persistent storage. +// It merges any injected metadata into the top-level JSON object. // // Parameters: // - authFilePath: The full path where the token file should be saved @@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error { _ = f.Close() }() - if err = json.NewEncoder(f).Encode(ts); err != nil { + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } + + if err = json.NewEncoder(f).Encode(data); err != nil { return fmt.Errorf("failed to write token to file: %w", err) } return nil diff --git a/internal/auth/gemini/gemini_token.go b/internal/auth/gemini/gemini_token.go index 0ec7da17..6848b708 100644 --- a/internal/auth/gemini/gemini_token.go +++ b/internal/auth/gemini/gemini_token.go @@ -35,11 +35,21 @@ type GeminiTokenStorage struct { // Type indicates the authentication provider type, always "gemini" for this storage. Type string `json:"type"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // SaveTokenToFile serializes the Gemini token storage to a JSON file. 
// This method creates the necessary directory structure and writes the token // data in JSON format to the specified file path for persistent storage. +// It merges any injected metadata into the top-level JSON object. // // Parameters: // - authFilePath: The full path where the token file should be saved @@ -49,6 +59,11 @@ type GeminiTokenStorage struct { func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error { misc.LogSavingCredentials(authFilePath) ts.Type = "gemini" + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil { return fmt.Errorf("failed to create directory: %v", err) } @@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error { } }() - if err = json.NewEncoder(f).Encode(ts); err != nil { + enc := json.NewEncoder(f) + enc.SetIndent("", " ") + if err := enc.Encode(data); err != nil { return fmt.Errorf("failed to write token to file: %w", err) } return nil diff --git a/internal/auth/iflow/iflow_token.go b/internal/auth/iflow/iflow_token.go index 6d2beb39..a515c926 100644 --- a/internal/auth/iflow/iflow_token.go +++ b/internal/auth/iflow/iflow_token.go @@ -21,6 +21,15 @@ type IFlowTokenStorage struct { Scope string `json:"scope"` Cookie string `json:"cookie"` Type string `json:"type"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // SaveTokenToFile serialises the token storage to disk. 
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error { } defer func() { _ = f.Close() }() - if err = json.NewEncoder(f).Encode(ts); err != nil { + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } + + if err = json.NewEncoder(f).Encode(data); err != nil { return fmt.Errorf("iflow token: encode token failed: %w", err) } return nil diff --git a/internal/auth/kimi/token.go b/internal/auth/kimi/token.go index d4d06b64..7320d760 100644 --- a/internal/auth/kimi/token.go +++ b/internal/auth/kimi/token.go @@ -29,6 +29,15 @@ type KimiTokenStorage struct { Expired string `json:"expired,omitempty"` // Type indicates the authentication provider type, always "kimi" for this storage. Type string `json:"type"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // KimiTokenData holds the raw OAuth token response from Kimi. 
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error { _ = f.Close() }() + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } + encoder := json.NewEncoder(f) encoder.SetIndent("", " ") - if err = encoder.Encode(ts); err != nil { + if err = encoder.Encode(data); err != nil { return fmt.Errorf("failed to write token to file: %w", err) } return nil diff --git a/internal/auth/qwen/qwen_token.go b/internal/auth/qwen/qwen_token.go index 4a2b3a2d..276c8b40 100644 --- a/internal/auth/qwen/qwen_token.go +++ b/internal/auth/qwen/qwen_token.go @@ -30,11 +30,21 @@ type QwenTokenStorage struct { Type string `json:"type"` // Expire is the timestamp when the current access token expires. Expire string `json:"expired"` + + // Metadata holds arbitrary key-value pairs injected via hooks. + // It is not exported to JSON directly to allow flattening during serialization. + Metadata map[string]any `json:"-"` +} + +// SetMetadata allows external callers to inject metadata into the storage before saving. +func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) { + ts.Metadata = meta } // SaveTokenToFile serializes the Qwen token storage to a JSON file. // This method creates the necessary directory structure and writes the token // data in JSON format to the specified file path for persistent storage. +// It merges any injected metadata into the top-level JSON object. 
// // Parameters: // - authFilePath: The full path where the token file should be saved @@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error { _ = f.Close() }() - if err = json.NewEncoder(f).Encode(ts); err != nil { + // Merge metadata using helper + data, errMerge := misc.MergeMetadata(ts, ts.Metadata) + if errMerge != nil { + return fmt.Errorf("failed to merge metadata: %w", errMerge) + } + + if err = json.NewEncoder(f).Encode(data); err != nil { return fmt.Errorf("failed to write token to file: %w", err) } return nil diff --git a/internal/cmd/openai_device_login.go b/internal/cmd/openai_device_login.go new file mode 100644 index 00000000..1b7351e6 --- /dev/null +++ b/internal/cmd/openai_device_login.go @@ -0,0 +1,60 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" + log "github.com/sirupsen/logrus" +) + +const ( + codexLoginModeMetadataKey = "codex_login_mode" + codexLoginModeDevice = "device" +) + +// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the +// existing codex-login OAuth callback flow intact. 
+func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) { + if options == nil { + options = &LoginOptions{} + } + + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + + manager := newAuthManager() + + authOpts := &sdkAuth.LoginOptions{ + NoBrowser: options.NoBrowser, + CallbackPort: options.CallbackPort, + Metadata: map[string]string{ + codexLoginModeMetadataKey: codexLoginModeDevice, + }, + Prompt: promptFn, + } + + _, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts) + if err != nil { + if authErr := (*codex.AuthenticationError)(nil); errors.As(err, &authErr) { + log.Error(codex.GetUserFriendlyMessage(authErr)) + if authErr.Type == codex.ErrPortInUse.Type { + os.Exit(codex.ErrPortInUse.Code) + } + return + } + fmt.Printf("Codex device authentication failed: %v\n", err) + return + } + + if savedPath != "" { + fmt.Printf("Authentication saved to %s\n", savedPath) + } + fmt.Println("Codex device authentication successful!") +} diff --git a/internal/misc/credentials.go b/internal/misc/credentials.go index b03cd788..6b4f9ced 100644 --- a/internal/misc/credentials.go +++ b/internal/misc/credentials.go @@ -1,6 +1,7 @@ package misc import ( + "encoding/json" "fmt" "path/filepath" "strings" @@ -24,3 +25,37 @@ func LogSavingCredentials(path string) { func LogCredentialSeparator() { log.Debug(credentialSeparator) } + +// MergeMetadata serializes the source struct into a map and merges the provided metadata into it. 
+func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) { + var data map[string]any + + // Fast path: if source is already a map, just copy it to avoid mutation of original + if srcMap, ok := source.(map[string]any); ok { + data = make(map[string]any, len(srcMap)+len(metadata)) + for k, v := range srcMap { + data[k] = v + } + } else { + // Slow path: marshal to JSON and back to map to respect JSON tags + temp, err := json.Marshal(source) + if err != nil { + return nil, fmt.Errorf("failed to marshal source: %w", err) + } + if err := json.Unmarshal(temp, &data); err != nil { + return nil, fmt.Errorf("failed to unmarshal to map: %w", err) + } + } + + // Merge extra metadata + if metadata != nil { + if data == nil { + data = make(map[string]any) + } + for k, v := range metadata { + data[k] = v + } + } + + return data, nil +} diff --git a/internal/registry/model_definitions_static_data.go b/internal/registry/model_definitions_static_data.go index 30f3b628..e03d878b 100644 --- a/internal/registry/model_definitions_static_data.go +++ b/internal/registry/model_definitions_static_data.go @@ -904,19 +904,12 @@ func GetIFlowModels() []*ModelInfo { Created int64 Thinking *ThinkingSupport }{ - {ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600}, {ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800}, {ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000}, {ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000}, {ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport}, - {ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400}, {ID: 
"glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport}, - {ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport}, - {ID: "glm-5", DisplayName: "GLM-5", Description: "Zhipu GLM 5 general model", Created: 1770768000, Thinking: iFlowThinkingSupport}, {ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000}, - {ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200}, - {ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000}, - {ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000}, {ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport}, {ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport}, {ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200}, @@ -925,11 +918,7 @@ func GetIFlowModels() []*ModelInfo { {ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600}, {ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600}, {ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600}, - {ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport}, - {ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: 
iFlowThinkingSupport}, - {ID: "minimax-m2.5", DisplayName: "MiniMax-M2.5", Description: "MiniMax M2.5", Created: 1770825600, Thinking: iFlowThinkingSupport}, {ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200}, - {ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport}, } models := make([]*ModelInfo, 0, len(entries)) for _, entry := range entries { @@ -963,6 +952,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { "gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, "gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, "gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, + "gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, "gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}}, "claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000}, "claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000}, diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index e697b64e..aa2be677 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -55,8 +55,78 @@ const ( var ( randSource = rand.New(rand.NewSource(time.Now().UnixNano())) randSourceMutex sync.Mutex + // 
antigravityPrimaryModelsCache keeps the latest non-empty model list fetched + // from any antigravity auth. Empty fetches never overwrite this cache. + antigravityPrimaryModelsCache struct { + mu sync.RWMutex + models []*registry.ModelInfo + } ) +func cloneAntigravityModels(models []*registry.ModelInfo) []*registry.ModelInfo { + if len(models) == 0 { + return nil + } + out := make([]*registry.ModelInfo, 0, len(models)) + for _, model := range models { + if model == nil || strings.TrimSpace(model.ID) == "" { + continue + } + out = append(out, cloneAntigravityModelInfo(model)) + } + if len(out) == 0 { + return nil + } + return out +} + +func cloneAntigravityModelInfo(model *registry.ModelInfo) *registry.ModelInfo { + if model == nil { + return nil + } + clone := *model + if len(model.SupportedGenerationMethods) > 0 { + clone.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...) + } + if len(model.SupportedParameters) > 0 { + clone.SupportedParameters = append([]string(nil), model.SupportedParameters...) + } + if model.Thinking != nil { + thinkingClone := *model.Thinking + if len(model.Thinking.Levels) > 0 { + thinkingClone.Levels = append([]string(nil), model.Thinking.Levels...) 
+ } + clone.Thinking = &thinkingClone + } + return &clone +} + +func storeAntigravityPrimaryModels(models []*registry.ModelInfo) bool { + cloned := cloneAntigravityModels(models) + if len(cloned) == 0 { + return false + } + antigravityPrimaryModelsCache.mu.Lock() + antigravityPrimaryModelsCache.models = cloned + antigravityPrimaryModelsCache.mu.Unlock() + return true +} + +func loadAntigravityPrimaryModels() []*registry.ModelInfo { + antigravityPrimaryModelsCache.mu.RLock() + cloned := cloneAntigravityModels(antigravityPrimaryModelsCache.models) + antigravityPrimaryModelsCache.mu.RUnlock() + return cloned +} + +func fallbackAntigravityPrimaryModels() []*registry.ModelInfo { + models := loadAntigravityPrimaryModels() + if len(models) > 0 { + log.Debugf("antigravity executor: using cached primary model list (%d models)", len(models)) + } + return models +} + // AntigravityExecutor proxies requests to the antigravity upstream. type AntigravityExecutor struct { cfg *config.Config @@ -1072,7 +1142,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c exec := &AntigravityExecutor{cfg: cfg} token, updatedAuth, errToken := exec.ensureAccessToken(ctx, auth) if errToken != nil || token == "" { - return nil + return fallbackAntigravityPrimaryModels() } if updatedAuth != nil { auth = updatedAuth @@ -1096,7 +1166,7 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader(payload)) if errReq != nil { - return nil + return fallbackAntigravityPrimaryModels() } httpReq.Close = true httpReq.Header.Set("Content-Type", "application/json") @@ -1109,14 +1179,13 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c httpResp, errDo := httpClient.Do(httpReq) if errDo != nil { if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) { - return nil + return 
fallbackAntigravityPrimaryModels() } if idx+1 < len(baseURLs) { log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) continue } - log.Errorf("antigravity executor: models request failed: %v", errDo) - return nil + return fallbackAntigravityPrimaryModels() } bodyBytes, errRead := io.ReadAll(httpResp.Body) @@ -1128,21 +1197,27 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) continue } - log.Errorf("antigravity executor: models read body failed: %v", errRead) - return nil + return fallbackAntigravityPrimaryModels() } if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices { if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) { log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) continue } - log.Errorf("antigravity executor: models request error status %d: %s", httpResp.StatusCode, string(bodyBytes)) - return nil + if idx+1 < len(baseURLs) { + log.Debugf("antigravity executor: models request failed with status %d on base url %s, retrying with fallback base url: %s", httpResp.StatusCode, baseURL, baseURLs[idx+1]) + continue + } + return fallbackAntigravityPrimaryModels() } result := gjson.GetBytes(bodyBytes, "models") if !result.Exists() { - return nil + if idx+1 < len(baseURLs) { + log.Debugf("antigravity executor: models field missing on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) + continue + } + return fallbackAntigravityPrimaryModels() } now := time.Now().Unix() @@ -1210,9 +1285,18 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c } models = append(models, modelInfo) } + if len(models) == 0 { + if idx+1 < 
len(baseURLs) { + log.Debugf("antigravity executor: empty models list on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1]) + continue + } + log.Debug("antigravity executor: fetched empty model list; retaining cached primary model list") + return fallbackAntigravityPrimaryModels() + } + storeAntigravityPrimaryModels(models) return models } - return nil + return fallbackAntigravityPrimaryModels() } func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) { diff --git a/internal/runtime/executor/antigravity_executor_models_cache_test.go b/internal/runtime/executor/antigravity_executor_models_cache_test.go new file mode 100644 index 00000000..be49a7c1 --- /dev/null +++ b/internal/runtime/executor/antigravity_executor_models_cache_test.go @@ -0,0 +1,90 @@ +package executor + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +func resetAntigravityPrimaryModelsCacheForTest() { + antigravityPrimaryModelsCache.mu.Lock() + antigravityPrimaryModelsCache.models = nil + antigravityPrimaryModelsCache.mu.Unlock() +} + +func TestStoreAntigravityPrimaryModels_EmptyDoesNotOverwrite(t *testing.T) { + resetAntigravityPrimaryModelsCacheForTest() + t.Cleanup(resetAntigravityPrimaryModelsCacheForTest) + + seed := []*registry.ModelInfo{ + {ID: "claude-sonnet-4-5"}, + {ID: "gemini-2.5-pro"}, + } + if updated := storeAntigravityPrimaryModels(seed); !updated { + t.Fatal("expected non-empty model list to update primary cache") + } + + if updated := storeAntigravityPrimaryModels(nil); updated { + t.Fatal("expected nil model list not to overwrite primary cache") + } + if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{}); updated { + t.Fatal("expected empty model list not to overwrite primary cache") + } + + got := loadAntigravityPrimaryModels() + if len(got) != 2 { + t.Fatalf("expected cached model count 2, got %d", len(got)) + } + if 
got[0].ID != "claude-sonnet-4-5" || got[1].ID != "gemini-2.5-pro" { + t.Fatalf("unexpected cached model ids: %q, %q", got[0].ID, got[1].ID) + } +} + +func TestLoadAntigravityPrimaryModels_ReturnsClone(t *testing.T) { + resetAntigravityPrimaryModelsCacheForTest() + t.Cleanup(resetAntigravityPrimaryModelsCacheForTest) + + if updated := storeAntigravityPrimaryModels([]*registry.ModelInfo{{ + ID: "gpt-5", + DisplayName: "GPT-5", + SupportedGenerationMethods: []string{"generateContent"}, + SupportedParameters: []string{"temperature"}, + Thinking: ®istry.ThinkingSupport{ + Levels: []string{"high"}, + }, + }}); !updated { + t.Fatal("expected model cache update") + } + + got := loadAntigravityPrimaryModels() + if len(got) != 1 { + t.Fatalf("expected one cached model, got %d", len(got)) + } + got[0].ID = "mutated-id" + if len(got[0].SupportedGenerationMethods) > 0 { + got[0].SupportedGenerationMethods[0] = "mutated-method" + } + if len(got[0].SupportedParameters) > 0 { + got[0].SupportedParameters[0] = "mutated-parameter" + } + if got[0].Thinking != nil && len(got[0].Thinking.Levels) > 0 { + got[0].Thinking.Levels[0] = "mutated-level" + } + + again := loadAntigravityPrimaryModels() + if len(again) != 1 { + t.Fatalf("expected one cached model after mutation, got %d", len(again)) + } + if again[0].ID != "gpt-5" { + t.Fatalf("expected cached model id to remain %q, got %q", "gpt-5", again[0].ID) + } + if len(again[0].SupportedGenerationMethods) == 0 || again[0].SupportedGenerationMethods[0] != "generateContent" { + t.Fatalf("expected cached generation methods to be unmutated, got %v", again[0].SupportedGenerationMethods) + } + if len(again[0].SupportedParameters) == 0 || again[0].SupportedParameters[0] != "temperature" { + t.Fatalf("expected cached supported parameters to be unmutated, got %v", again[0].SupportedParameters) + } + if again[0].Thinking == nil || len(again[0].Thinking.Levels) == 0 || again[0].Thinking.Levels[0] != "high" { + t.Fatalf("expected cached model 
thinking levels to be unmutated, got %v", again[0].Thinking) + } +} diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 01de8f97..a0cbc0d5 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -156,7 +156,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - err = statusErr{code: httpResp.StatusCode, msg: string(b)} + err = newCodexStatusErr(httpResp.StatusCode, b) return resp, err } data, err := io.ReadAll(httpResp.Body) @@ -260,7 +260,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - err = statusErr{code: httpResp.StatusCode, msg: string(b)} + err = newCodexStatusErr(httpResp.StatusCode, b) return resp, err } data, err := io.ReadAll(httpResp.Body) @@ -358,7 +358,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } appendAPIResponseChunk(ctx, e.cfg, data) logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - err = statusErr{code: httpResp.StatusCode, msg: string(data)} + err = newCodexStatusErr(httpResp.StatusCode, data) return nil, err } out := make(chan cliproxyexecutor.StreamChunk) @@ -673,6 +673,35 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s util.ApplyCustomHeadersFromAttrs(r, attrs) } +func newCodexStatusErr(statusCode int, body 
[]byte) statusErr { + err := statusErr{code: statusCode, msg: string(body)} + if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil { + err.retryAfter = retryAfter + } + return err +} + +func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration { + if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 { + return nil + } + if strings.TrimSpace(gjson.GetBytes(errorBody, "error.type").String()) != "usage_limit_reached" { + return nil + } + if resetsAt := gjson.GetBytes(errorBody, "error.resets_at").Int(); resetsAt > 0 { + resetAtTime := time.Unix(resetsAt, 0) + if resetAtTime.After(now) { + retryAfter := resetAtTime.Sub(now) + return &retryAfter + } + } + if resetsInSeconds := gjson.GetBytes(errorBody, "error.resets_in_seconds").Int(); resetsInSeconds > 0 { + retryAfter := time.Duration(resetsInSeconds) * time.Second + return &retryAfter + } + return nil +} + func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) { if a == nil { return "", "" diff --git a/internal/runtime/executor/codex_executor_retry_test.go b/internal/runtime/executor/codex_executor_retry_test.go new file mode 100644 index 00000000..3e54ae7c --- /dev/null +++ b/internal/runtime/executor/codex_executor_retry_test.go @@ -0,0 +1,65 @@ +package executor + +import ( + "net/http" + "strconv" + "testing" + "time" +) + +func TestParseCodexRetryAfter(t *testing.T) { + now := time.Unix(1_700_000_000, 0) + + t.Run("resets_in_seconds", func(t *testing.T) { + body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":123}}`) + retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now) + if retryAfter == nil { + t.Fatalf("expected retryAfter, got nil") + } + if *retryAfter != 123*time.Second { + t.Fatalf("retryAfter = %v, want %v", *retryAfter, 123*time.Second) + } + }) + + t.Run("prefers resets_at", func(t *testing.T) { + resetAt := now.Add(5 * time.Minute).Unix() + body := 
[]byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":1}}`) + retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now) + if retryAfter == nil { + t.Fatalf("expected retryAfter, got nil") + } + if *retryAfter != 5*time.Minute { + t.Fatalf("retryAfter = %v, want %v", *retryAfter, 5*time.Minute) + } + }) + + t.Run("fallback when resets_at is past", func(t *testing.T) { + resetAt := now.Add(-1 * time.Minute).Unix() + body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":77}}`) + retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now) + if retryAfter == nil { + t.Fatalf("expected retryAfter, got nil") + } + if *retryAfter != 77*time.Second { + t.Fatalf("retryAfter = %v, want %v", *retryAfter, 77*time.Second) + } + }) + + t.Run("non-429 status code", func(t *testing.T) { + body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":30}}`) + if got := parseCodexRetryAfter(http.StatusBadRequest, body, now); got != nil { + t.Fatalf("expected nil for non-429, got %v", *got) + } + }) + + t.Run("non usage_limit_reached error type", func(t *testing.T) { + body := []byte(`{"error":{"type":"server_error","resets_in_seconds":30}}`) + if got := parseCodexRetryAfter(http.StatusTooManyRequests, body, now); got != nil { + t.Fatalf("expected nil for non-usage_limit_reached, got %v", *got) + } + }) +} + +func itoa(v int64) string { + return strconv.FormatInt(v, 10) +} diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index bcc4a057..e7957d29 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "strings" + "sync" "time" qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" @@ -22,9 +23,151 @@ import ( ) const ( - qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)" + qwenUserAgent = 
"QwenCode/0.10.3 (darwin; arm64)" + qwenRateLimitPerMin = 60 // 60 requests per minute per credential + qwenRateLimitWindow = time.Minute // sliding window duration ) +// qwenBeijingLoc caches the Beijing timezone to avoid repeated LoadLocation syscalls. +var qwenBeijingLoc = func() *time.Location { + loc, err := time.LoadLocation("Asia/Shanghai") + if err != nil || loc == nil { + log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", err) + return time.FixedZone("CST", 8*3600) + } + return loc +}() + +// qwenQuotaCodes is a package-level set of error codes that indicate quota exhaustion. +var qwenQuotaCodes = map[string]struct{}{ + "insufficient_quota": {}, + "quota_exceeded": {}, +} + +// qwenRateLimiter tracks request timestamps per credential for rate limiting. +// Qwen has a limit of 60 requests per minute per account. +var qwenRateLimiter = struct { + sync.Mutex + requests map[string][]time.Time // authID -> request timestamps +}{ + requests: make(map[string][]time.Time), +} + +// redactAuthID returns a redacted version of the auth ID for safe logging. +// Keeps a small prefix/suffix to allow correlation across events. +func redactAuthID(id string) string { + if id == "" { + return "" + } + if len(id) <= 8 { + return id + } + return id[:4] + "..." + id[len(id)-4:] +} + +// checkQwenRateLimit checks if the credential has exceeded the rate limit. +// Returns nil if allowed, or a statusErr with retryAfter if rate limited. 
+func checkQwenRateLimit(authID string) error { + if authID == "" { + // Empty authID should not bypass rate limiting in production + // Use debug level to avoid log spam for certain auth flows + log.Debug("qwen rate limit check: empty authID, skipping rate limit") + return nil + } + + now := time.Now() + windowStart := now.Add(-qwenRateLimitWindow) + + qwenRateLimiter.Lock() + defer qwenRateLimiter.Unlock() + + // Get and filter timestamps within the window + timestamps := qwenRateLimiter.requests[authID] + var validTimestamps []time.Time + for _, ts := range timestamps { + if ts.After(windowStart) { + validTimestamps = append(validTimestamps, ts) + } + } + + // Always prune expired entries to prevent memory leak + // Delete empty entries, otherwise update with pruned slice + if len(validTimestamps) == 0 { + delete(qwenRateLimiter.requests, authID) + } + + // Check if rate limit exceeded + if len(validTimestamps) >= qwenRateLimitPerMin { + // Calculate when the oldest request will expire + oldestInWindow := validTimestamps[0] + retryAfter := oldestInWindow.Add(qwenRateLimitWindow).Sub(now) + if retryAfter < time.Second { + retryAfter = time.Second + } + retryAfterSec := int(retryAfter.Seconds()) + return statusErr{ + code: http.StatusTooManyRequests, + msg: fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec), + retryAfter: &retryAfter, + } + } + + // Record this request and update the map with pruned timestamps + validTimestamps = append(validTimestamps, now) + qwenRateLimiter.requests[authID] = validTimestamps + + return nil +} + +// isQwenQuotaError checks if the error response indicates a quota exceeded error. +// Qwen returns HTTP 403 with error.code="insufficient_quota" when daily quota is exhausted. 
+func isQwenQuotaError(body []byte) bool { + code := strings.ToLower(gjson.GetBytes(body, "error.code").String()) + errType := strings.ToLower(gjson.GetBytes(body, "error.type").String()) + + // Primary check: exact match on error.code or error.type (most reliable) + if _, ok := qwenQuotaCodes[code]; ok { + return true + } + if _, ok := qwenQuotaCodes[errType]; ok { + return true + } + + // Fallback: check message only if code/type don't match (less reliable) + msg := strings.ToLower(gjson.GetBytes(body, "error.message").String()) + if strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded") || + strings.Contains(msg, "free allocated quota exceeded") { + return true + } + + return false +} + +// wrapQwenError wraps an HTTP error response, detecting quota errors and mapping them to 429. +// Returns the appropriate status code and retryAfter duration for statusErr. +// Only checks for quota errors when httpCode is 403 or 429 to avoid false positives. +func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) { + errCode = httpCode + // Only check quota errors for expected status codes to avoid false positives + // Qwen returns 403 for quota errors, 429 for rate limits + if (httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests) && isQwenQuotaError(body) { + errCode = http.StatusTooManyRequests // Map to 429 to trigger quota logic + cooldown := timeUntilNextDay() + retryAfter = &cooldown + logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, errCode, cooldown) + } + return errCode, retryAfter +} + +// timeUntilNextDay returns duration until midnight Beijing time (UTC+8). +// Qwen's daily quota resets at 00:00 Beijing time. 
+func timeUntilNextDay() time.Duration { + now := time.Now() + nowLocal := now.In(qwenBeijingLoc) + tomorrow := time.Date(nowLocal.Year(), nowLocal.Month(), nowLocal.Day()+1, 0, 0, 0, 0, qwenBeijingLoc) + return tomorrow.Sub(now) +} + // QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions. // If access token is unavailable, it falls back to legacy via ClientAdapter. type QwenExecutor struct { @@ -67,6 +210,17 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req if opts.Alt == "responses/compact" { return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } + + // Check rate limit before proceeding + var authID string + if auth != nil { + authID = auth.ID + } + if err := checkQwenRateLimit(authID); err != nil { + logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID)) + return resp, err + } + baseModel := thinking.ParseSuffix(req.Model).ModelName token, baseURL := qwenCreds(auth) @@ -102,9 +256,8 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req return resp, err } applyQwenHeaders(httpReq, token, false) - var authID, authLabel, authType, authValue string + var authLabel, authType, authValue string if auth != nil { - authID = auth.ID authLabel = auth.Label authType, authValue = auth.AccountInfo() } @@ -135,8 +288,10 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - err = statusErr{code: httpResp.StatusCode, msg: string(b)} + + errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b) + logWithRequestID(ctx).Debugf("request error, error status: %d 
(mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter} return resp, err } data, err := io.ReadAll(httpResp.Body) @@ -158,6 +313,17 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut if opts.Alt == "responses/compact" { return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} } + + // Check rate limit before proceeding + var authID string + if auth != nil { + authID = auth.ID + } + if err := checkQwenRateLimit(authID); err != nil { + logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID)) + return nil, err + } + baseModel := thinking.ParseSuffix(req.Model).ModelName token, baseURL := qwenCreds(auth) @@ -200,9 +366,8 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut return nil, err } applyQwenHeaders(httpReq, token, true) - var authID, authLabel, authType, authValue string + var authLabel, authType, authValue string if auth != nil { - authID = auth.ID authLabel = auth.Label authType, authValue = auth.AccountInfo() } @@ -228,11 +393,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + + errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b) + logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) if errClose := httpResp.Body.Close(); errClose != nil { log.Errorf("qwen executor: close response body error: %v", 
errClose) } - err = statusErr{code: httpResp.StatusCode, msg: string(b)} + err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter} return nil, err } out := make(chan cliproxyexecutor.StreamChunk) diff --git a/internal/thinking/provider/openai/apply.go b/internal/thinking/provider/openai/apply.go index e8a2562f..eaad30ee 100644 --- a/internal/thinking/provider/openai/apply.go +++ b/internal/thinking/provider/openai/apply.go @@ -10,53 +10,10 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" - log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) -// validReasoningEffortLevels contains the standard values accepted by the -// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal, -// auto) are NOT in this set and must be clamped before use. -var validReasoningEffortLevels = map[string]struct{}{ - "none": {}, - "low": {}, - "medium": {}, - "high": {}, -} - -// clampReasoningEffort maps any thinking level string to a value that is safe -// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are -// mapped to the nearest standard equivalent. 
-// -// Mapping rules: -// - none / low / medium / high → returned as-is (already valid) -// - xhigh → "high" (nearest lower standard level) -// - minimal → "low" (nearest higher standard level) -// - auto → "medium" (reasonable default) -// - anything else → "medium" (safe default) -func clampReasoningEffort(level string) string { - if _, ok := validReasoningEffortLevels[level]; ok { - return level - } - var clamped string - switch level { - case string(thinking.LevelXHigh): - clamped = string(thinking.LevelHigh) - case string(thinking.LevelMinimal): - clamped = string(thinking.LevelLow) - case string(thinking.LevelAuto): - clamped = string(thinking.LevelMedium) - default: - clamped = string(thinking.LevelMedium) - } - log.WithFields(log.Fields{ - "original": level, - "clamped": clamped, - }).Debug("openai: reasoning_effort clamped to nearest valid standard value") - return clamped -} - // Applier implements thinking.ProviderApplier for OpenAI models. // // OpenAI-specific behavior: @@ -101,7 +58,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo * } if config.Mode == thinking.ModeLevel { - result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level))) + result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level)) return result, nil } @@ -122,7 +79,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo * return body, nil } - result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort)) + result, _ := sjson.SetBytes(body, "reasoning_effort", effort) return result, nil } @@ -157,7 +114,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte, return body, nil } - result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort)) + result, _ := sjson.SetBytes(body, "reasoning_effort", effort) return result, nil } diff --git 
a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index 448aa976..b634436d 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -223,14 +223,65 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData) } else if functionResponseResult.IsArray() { frResults := functionResponseResult.Array() - if len(frResults) == 1 { - functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw) + nonImageCount := 0 + lastNonImageRaw := "" + filteredJSON := "[]" + imagePartsJSON := "[]" + for _, fr := range frResults { + if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" { + inlineDataJSON := `{}` + if mimeType := fr.Get("source.media_type").String(); mimeType != "" { + inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType) + } + if data := fr.Get("source.data").String(); data != "" { + inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data) + } + + imagePartJSON := `{}` + imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON) + imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON) + continue + } + + nonImageCount++ + lastNonImageRaw = fr.Raw + filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw) + } + + if nonImageCount == 1 { + functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw) + } else if nonImageCount > 1 { + functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON) } else { - functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw) + functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", 
"") + } + + // Place image data inside functionResponse.parts as inlineData + // instead of as sibling parts in the outer content, to avoid + // base64 data bloating the text context. + if gjson.Get(imagePartsJSON, "#").Int() > 0 { + functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON) } } else if functionResponseResult.IsObject() { - functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw) + if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" { + inlineDataJSON := `{}` + if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" { + inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType) + } + if data := functionResponseResult.Get("source.data").String(); data != "" { + inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data) + } + + imagePartJSON := `{}` + imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON) + imagePartsJSON := "[]" + imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON) + functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON) + functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "") + } else { + functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw) + } } else if functionResponseResult.Raw != "" { functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw) } else { @@ -248,7 +299,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ if sourceResult.Get("type").String() == "base64" { inlineDataJSON := `{}` if mimeType := sourceResult.Get("media_type").String(); mimeType != "" { - inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType) + inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType) } if data 
:= sourceResult.Get("data").String(); data != "" { inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data) diff --git a/internal/translator/antigravity/claude/antigravity_claude_request_test.go b/internal/translator/antigravity/claude/antigravity_claude_request_test.go index c28a14ec..865db668 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request_test.go @@ -413,8 +413,8 @@ func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) { if !inlineData.Exists() { t.Error("inlineData should exist") } - if inlineData.Get("mime_type").String() != "image/png" { - t.Error("mime_type mismatch") + if inlineData.Get("mimeType").String() != "image/png" { + t.Error("mimeType mismatch") } if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") { t.Error("data mismatch") @@ -740,6 +740,429 @@ func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) { } } +func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) { + // tool_result with array content containing text + image should place + // image data inside functionResponse.parts as inlineData, not as a + // sibling part in the outer content (to avoid base64 context bloat). 
+ inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "Read-123-456", + "content": [ + { + "type": "text", + "text": "File content here" + }, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "iVBORw0KGgoAAAANSUhEUg==" + } + } + ] + } + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + if !gjson.Valid(outputStr) { + t.Fatalf("Result is not valid JSON:\n%s", outputStr) + } + + // Image should be inside functionResponse.parts, not as outer sibling part + funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse") + if !funcResp.Exists() { + t.Fatal("functionResponse should exist") + } + + // Text content should be in response.result + resultText := funcResp.Get("response.result.text").String() + if resultText != "File content here" { + t.Errorf("Expected response.result.text = 'File content here', got '%s'", resultText) + } + + // Image should be in functionResponse.parts[0].inlineData + inlineData := funcResp.Get("parts.0.inlineData") + if !inlineData.Exists() { + t.Fatal("functionResponse.parts[0].inlineData should exist") + } + if inlineData.Get("mimeType").String() != "image/png" { + t.Errorf("Expected mimeType 'image/png', got '%s'", inlineData.Get("mimeType").String()) + } + if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") { + t.Error("data mismatch") + } + + // Image should NOT be in outer parts (only functionResponse part should exist) + outerParts := gjson.Get(outputStr, "request.contents.0.parts") + if outerParts.IsArray() && len(outerParts.Array()) > 1 { + t.Errorf("Expected only 1 outer part (functionResponse), got %d", len(outerParts.Array())) + } +} + +func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) { + // tool_result with single image 
object as content should place + // image data inside functionResponse.parts, not as outer sibling part. + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "Read-789-012", + "content": { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": "/9j/4AAQSkZJRgABAQ==" + } + } + } + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + if !gjson.Valid(outputStr) { + t.Fatalf("Result is not valid JSON:\n%s", outputStr) + } + + funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse") + if !funcResp.Exists() { + t.Fatal("functionResponse should exist") + } + + // response.result should be empty (image only) + if funcResp.Get("response.result").String() != "" { + t.Errorf("Expected empty response.result for image-only content, got '%s'", funcResp.Get("response.result").String()) + } + + // Image should be in functionResponse.parts[0].inlineData + inlineData := funcResp.Get("parts.0.inlineData") + if !inlineData.Exists() { + t.Fatal("functionResponse.parts[0].inlineData should exist") + } + if inlineData.Get("mimeType").String() != "image/jpeg" { + t.Errorf("Expected mimeType 'image/jpeg', got '%s'", inlineData.Get("mimeType").String()) + } + + // Image should NOT be in outer parts + outerParts := gjson.Get(outputStr, "request.contents.0.parts") + if outerParts.IsArray() && len(outerParts.Array()) > 1 { + t.Errorf("Expected only 1 outer part, got %d", len(outerParts.Array())) + } +} + +func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) { + // tool_result with array content: 2 text items + 2 images + // All images go into functionResponse.parts, texts into response.result array + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "user", 
+ "content": [ + { + "type": "tool_result", + "tool_use_id": "Multi-001", + "content": [ + {"type": "text", "text": "First text"}, + { + "type": "image", + "source": {"type": "base64", "media_type": "image/png", "data": "AAAA"} + }, + {"type": "text", "text": "Second text"}, + { + "type": "image", + "source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"} + } + ] + } + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + if !gjson.Valid(outputStr) { + t.Fatalf("Result is not valid JSON:\n%s", outputStr) + } + + funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse") + if !funcResp.Exists() { + t.Fatal("functionResponse should exist") + } + + // Multiple text items => response.result is an array + resultArr := funcResp.Get("response.result") + if !resultArr.IsArray() { + t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw) + } + results := resultArr.Array() + if len(results) != 2 { + t.Fatalf("Expected 2 result items, got %d", len(results)) + } + + // Both images should be in functionResponse.parts + imgParts := funcResp.Get("parts").Array() + if len(imgParts) != 2 { + t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", len(imgParts)) + } + if imgParts[0].Get("inlineData.mimeType").String() != "image/png" { + t.Errorf("Expected first image mimeType 'image/png', got '%s'", imgParts[0].Get("inlineData.mimeType").String()) + } + if imgParts[0].Get("inlineData.data").String() != "AAAA" { + t.Errorf("Expected first image data 'AAAA', got '%s'", imgParts[0].Get("inlineData.data").String()) + } + if imgParts[1].Get("inlineData.mimeType").String() != "image/jpeg" { + t.Errorf("Expected second image mimeType 'image/jpeg', got '%s'", imgParts[1].Get("inlineData.mimeType").String()) + } + if imgParts[1].Get("inlineData.data").String() != "BBBB" { + t.Errorf("Expected second image data 'BBBB', got '%s'", 
imgParts[1].Get("inlineData.data").String()) + } + + // Only 1 outer part (the functionResponse itself) + outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array() + if len(outerParts) != 1 { + t.Errorf("Expected 1 outer part, got %d", len(outerParts)) + } +} + +func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) { + // tool_result with only images (no text) — response.result should be empty string + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "ImgOnly-001", + "content": [ + { + "type": "image", + "source": {"type": "base64", "media_type": "image/png", "data": "PNG1"} + }, + { + "type": "image", + "source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"} + } + ] + } + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + if !gjson.Valid(outputStr) { + t.Fatalf("Result is not valid JSON:\n%s", outputStr) + } + + funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse") + if !funcResp.Exists() { + t.Fatal("functionResponse should exist") + } + + // No text => response.result should be empty string + if funcResp.Get("response.result").String() != "" { + t.Errorf("Expected empty response.result, got '%s'", funcResp.Get("response.result").String()) + } + + // Both images in functionResponse.parts + imgParts := funcResp.Get("parts").Array() + if len(imgParts) != 2 { + t.Fatalf("Expected 2 image parts, got %d", len(imgParts)) + } + if imgParts[0].Get("inlineData.mimeType").String() != "image/png" { + t.Error("first image mimeType mismatch") + } + if imgParts[1].Get("inlineData.mimeType").String() != "image/gif" { + t.Error("second image mimeType mismatch") + } + + // Only 1 outer part + outerParts := gjson.Get(outputStr, "request.contents.0.parts").Array() + if len(outerParts) != 1 { + 
t.Errorf("Expected 1 outer part, got %d", len(outerParts)) + } +} + +func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) { + // image with source.type != "base64" should be treated as non-image (falls through) + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "NotB64-001", + "content": [ + {"type": "text", "text": "some output"}, + { + "type": "image", + "source": {"type": "url", "url": "https://example.com/img.png"} + } + ] + } + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + if !gjson.Valid(outputStr) { + t.Fatalf("Result is not valid JSON:\n%s", outputStr) + } + + funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse") + if !funcResp.Exists() { + t.Fatal("functionResponse should exist") + } + + // Non-base64 image is treated as non-image, so it goes into the filtered results + // along with the text item. Since there are 2 non-image items, result is array. 
+ resultArr := funcResp.Get("response.result") + if !resultArr.IsArray() { + t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw) + } + results := resultArr.Array() + if len(results) != 2 { + t.Fatalf("Expected 2 result items, got %d", len(results)) + } + + // No functionResponse.parts (no base64 images collected) + if funcResp.Get("parts").Exists() { + t.Error("functionResponse.parts should NOT exist when no base64 images") + } +} + +func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) { + // image with source.type=base64 but missing data field + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "NoData-001", + "content": [ + {"type": "text", "text": "output"}, + { + "type": "image", + "source": {"type": "base64", "media_type": "image/png"} + } + ] + } + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + if !gjson.Valid(outputStr) { + t.Fatalf("Result is not valid JSON:\n%s", outputStr) + } + + funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse") + if !funcResp.Exists() { + t.Fatal("functionResponse should exist") + } + + // The image is still classified as base64 image (type check passes), + // but data field is missing => inlineData has mimeType but no data + imgParts := funcResp.Get("parts").Array() + if len(imgParts) != 1 { + t.Fatalf("Expected 1 image part, got %d", len(imgParts)) + } + if imgParts[0].Get("inlineData.mimeType").String() != "image/png" { + t.Error("mimeType should still be set") + } + if imgParts[0].Get("inlineData.data").Exists() { + t.Error("data should not exist when source.data is missing") + } +} + +func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) { + // image with source.type=base64 but missing 
media_type field + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "NoMime-001", + "content": [ + {"type": "text", "text": "output"}, + { + "type": "image", + "source": {"type": "base64", "data": "AAAA"} + } + ] + } + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + if !gjson.Valid(outputStr) { + t.Fatalf("Result is not valid JSON:\n%s", outputStr) + } + + funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse") + if !funcResp.Exists() { + t.Fatal("functionResponse should exist") + } + + // The image is still classified as base64 image, + // but media_type is missing => inlineData has data but no mimeType + imgParts := funcResp.Get("parts").Array() + if len(imgParts) != 1 { + t.Fatalf("Expected 1 image part, got %d", len(imgParts)) + } + if imgParts[0].Get("inlineData.mimeType").Exists() { + t.Error("mimeType should not exist when media_type is missing") + } + if imgParts[0].Get("inlineData.data").String() != "AAAA" { + t.Error("data should still be set") + } +} + func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) { // When tools + thinking but no system instruction, should create one with hint inputJSON := []byte(`{ diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go b/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go index 8867a30e..da581d1a 100644 --- a/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go +++ b/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go @@ -93,3 +93,81 @@ func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) { } } } + +func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) { + // When functionResponse contains a "parts" field with 
inlineData (from Claude + // translator's image embedding), fixCLIToolResponse should preserve it as-is. + // parseFunctionResponseRaw returns response.Raw for valid JSON objects, + // so extra fields like "parts" survive the pipeline. + input := `{ + "model": "claude-opus-4-6-thinking", + "request": { + "contents": [ + { + "role": "model", + "parts": [ + { + "functionCall": {"name": "screenshot", "args": {}} + } + ] + }, + { + "role": "function", + "parts": [ + { + "functionResponse": { + "id": "tool-001", + "name": "screenshot", + "response": {"result": "Screenshot taken"}, + "parts": [ + {"inlineData": {"mimeType": "image/png", "data": "iVBOR"}} + ] + } + } + ] + } + ] + } + }` + + result, err := fixCLIToolResponse(input) + if err != nil { + t.Fatalf("fixCLIToolResponse failed: %v", err) + } + + // Find the function response content (role=function) + contents := gjson.Get(result, "request.contents").Array() + var funcContent gjson.Result + for _, c := range contents { + if c.Get("role").String() == "function" { + funcContent = c + break + } + } + if !funcContent.Exists() { + t.Fatal("function role content should exist in output") + } + + // The functionResponse should be preserved with its parts field + funcResp := funcContent.Get("parts.0.functionResponse") + if !funcResp.Exists() { + t.Fatal("functionResponse should exist in output") + } + + // Verify the parts field with inlineData is preserved + inlineParts := funcResp.Get("parts").Array() + if len(inlineParts) != 1 { + t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", len(inlineParts)) + } + if inlineParts[0].Get("inlineData.mimeType").String() != "image/png" { + t.Errorf("Expected mimeType 'image/png', got '%s'", inlineParts[0].Get("inlineData.mimeType").String()) + } + if inlineParts[0].Get("inlineData.data").String() != "iVBOR" { + t.Errorf("Expected data 'iVBOR', got '%s'", inlineParts[0].Get("inlineData.data").String()) + } + + // Verify response.result is also preserved + if 
funcResp.Get("response.result").String() != "Screenshot taken" { + t.Errorf("Expected response.result 'Screenshot taken', got '%s'", funcResp.Get("response.result").String()) + } +} diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index a8105c4e..85b28b8b 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -187,7 +187,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if len(pieces) == 2 && len(pieces[1]) > 7 { mime := pieces[0] data := pieces[1][7:] - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) p++ @@ -201,7 +201,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ ext = sp[len(sp)-1] } if mimeType, ok := misc.MimeTypes[ext]; ok { - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData) p++ } else { @@ -235,7 +235,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ if len(pieces) == 2 && len(pieces[1]) > 7 { mime := pieces[0] data := pieces[1][7:] - node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime) + node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime) node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data) node, _ = sjson.SetBytes(node, 
"parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature) p++ diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go index af9ffef1..91bc0423 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go @@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) } - promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount + promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_request.go b/internal/translator/claude/openai/chat-completions/claude_openai_request.go index 3cad1882..f94825b2 100644 --- a/internal/translator/claude/openai/chat-completions/claude_openai_request.go +++ b/internal/translator/claude/openai/chat-completions/claude_openai_request.go @@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream msg, _ = sjson.SetRaw(msg, "content.-1", imagePart) } } + + case "file": + fileData := part.Get("file.file_data").String() + if strings.HasPrefix(fileData, "data:") { + semicolonIdx := strings.Index(fileData, ";") + commaIdx := 
strings.Index(fileData, ",") + if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx { + mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:") + data := fileData[commaIdx+1:] + docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}` + docPart, _ = sjson.Set(docPart, "source.media_type", mediaType) + docPart, _ = sjson.Set(docPart, "source.data", data) + msg, _ = sjson.SetRaw(msg, "content.-1", docPart) + } + } } return true }) diff --git a/internal/translator/claude/openai/responses/claude_openai-responses_request.go b/internal/translator/claude/openai/responses/claude_openai-responses_request.go index 337f9be9..33a81124 100644 --- a/internal/translator/claude/openai/responses/claude_openai-responses_request.go +++ b/internal/translator/claude/openai/responses/claude_openai-responses_request.go @@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte var textAggregate strings.Builder var partsJSON []string hasImage := false + hasFile := false if parts := item.Get("content"); parts.Exists() && parts.IsArray() { parts.ForEach(func(_, part gjson.Result) bool { ptype := part.Get("type").String() @@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte hasImage = true } } + case "input_file": + fileData := part.Get("file_data").String() + if fileData != "" { + mediaType := "application/octet-stream" + data := fileData + if strings.HasPrefix(fileData, "data:") { + trimmed := strings.TrimPrefix(fileData, "data:") + mediaAndData := strings.SplitN(trimmed, ";base64,", 2) + if len(mediaAndData) == 2 { + if mediaAndData[0] != "" { + mediaType = mediaAndData[0] + } + data = mediaAndData[1] + } + } + contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}` + contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType) + contentPart, _ = sjson.Set(contentPart, "source.data", data) + partsJSON = 
append(partsJSON, contentPart) + if role == "" { + role = "user" + } + hasFile = true + } } return true }) @@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte if len(partsJSON) > 0 { msg := `{"role":"","content":[]}` msg, _ = sjson.Set(msg, "role", role) - if len(partsJSON) == 1 && !hasImage { + if len(partsJSON) == 1 && !hasImage && !hasFile { // Preserve legacy behavior for single text content msg, _ = sjson.Delete(msg, "content") textPart := gjson.Parse(partsJSON[0]) diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_request.go b/internal/translator/codex/openai/chat-completions/codex_openai_request.go index e79f97cd..1ea9ca4b 100644 --- a/internal/translator/codex/openai/chat-completions/codex_openai_request.go +++ b/internal/translator/codex/openai/chat-completions/codex_openai_request.go @@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b msg, _ = sjson.SetRaw(msg, "content.-1", part) } case "file": - // Files are not specified in examples; skip for now + if role == "user" { + fileData := it.Get("file.file_data").String() + filename := it.Get("file.filename").String() + if fileData != "" { + part := `{}` + part, _ = sjson.Set(part, "type", "input_file") + part, _ = sjson.Set(part, "file_data", fileData) + if filename != "" { + part, _ = sjson.Set(part, "filename", filename) + } + msg, _ = sjson.SetRaw(msg, "content.-1", part) + } + } } } } diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request.go b/internal/translator/codex/openai/responses/codex_openai-responses_request.go index f0407149..1161c515 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request.go @@ -26,6 +26,8 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, rawJSON, _ = sjson.DeleteBytes(rawJSON, 
"temperature") rawJSON, _ = sjson.DeleteBytes(rawJSON, "top_p") rawJSON, _ = sjson.DeleteBytes(rawJSON, "service_tier") + rawJSON, _ = sjson.DeleteBytes(rawJSON, "truncation") + rawJSON = applyResponsesCompactionCompatibility(rawJSON) // Delete the user field as it is not supported by the Codex upstream. rawJSON, _ = sjson.DeleteBytes(rawJSON, "user") @@ -36,6 +38,23 @@ func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, return rawJSON } +// applyResponsesCompactionCompatibility handles OpenAI Responses context_management.compaction +// for Codex upstream compatibility. +// +// Codex /responses currently rejects context_management with: +// {"detail":"Unsupported parameter: context_management"}. +// +// Compatibility strategy: +// 1) Remove context_management before forwarding to Codex upstream. +func applyResponsesCompactionCompatibility(rawJSON []byte) []byte { + if !gjson.GetBytes(rawJSON, "context_management").Exists() { + return rawJSON + } + + rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management") + return rawJSON +} + // convertSystemRoleToDeveloper traverses the input array and converts any message items // with role "system" to role "developer". This is necessary because Codex API does not // accept "system" role in the input array. 
diff --git a/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go b/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go index 4f562486..65732c3f 100644 --- a/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go +++ b/internal/translator/codex/openai/responses/codex_openai-responses_request_test.go @@ -280,3 +280,41 @@ func TestUserFieldDeletion(t *testing.T) { t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw) } } + +func TestContextManagementCompactionCompatibility(t *testing.T) { + inputJSON := []byte(`{ + "model": "gpt-5.2", + "context_management": [ + { + "type": "compaction", + "compact_threshold": 12000 + } + ], + "input": [{"role":"user","content":"hello"}] + }`) + + output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false) + outputStr := string(output) + + if gjson.Get(outputStr, "context_management").Exists() { + t.Fatalf("context_management should be removed for Codex compatibility") + } + if gjson.Get(outputStr, "truncation").Exists() { + t.Fatalf("truncation should be removed for Codex compatibility") + } +} + +func TestTruncationRemovedForCodexCompatibility(t *testing.T) { + inputJSON := []byte(`{ + "model": "gpt-5.2", + "truncation": "disabled", + "input": [{"role":"user","content":"hello"}] + }`) + + output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false) + outputStr := string(output) + + if gjson.Get(outputStr, "truncation").Exists() { + t.Fatalf("truncation should be removed for Codex compatibility") + } +} diff --git a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go index 0415e014..b26d431f 100644 --- a/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go +++ 
b/internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go @@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ } promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go index ee581c46..aeec5e9e 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_response.go @@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int()) } - promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount + promptTokenCount := usageResult.Get("promptTokenCount").Int() thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() - baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount) if thoughtsTokenCount > 0 { baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } @@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina promptTokenCount := usageResult.Get("promptTokenCount").Int() 
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int() - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount) if thoughtsTokenCount > 0 { template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) } diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go index 985897fa..73609be7 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go @@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, // usage mapping if um := root.Get("usageMetadata"); um.Exists() { - // input tokens = prompt + thoughts - input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int() + // input tokens = prompt only (thoughts go to output) + input := um.Get("promptTokenCount").Int() completed, _ = sjson.Set(completed, "response.usage.input_tokens", input) // cached token details: align with OpenAI "cached_tokens" semantics. completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int()) @@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string // usage mapping if um := root.Get("usageMetadata"); um.Exists() { - // input tokens = prompt + thoughts - input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int() + // input tokens = prompt only (thoughts go to output) + input := um.Get("promptTokenCount").Int() resp, _ = sjson.Set(resp, "usage.input_tokens", input) // cached token details: align with OpenAI "cached_tokens" semantics. 
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int()) diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 68859853..0e490e32 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -716,6 +716,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl return } if len(chunk.Payload) > 0 { + if handlerType == "openai-response" { + if err := validateSSEDataJSON(chunk.Payload); err != nil { + _ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err}) + return + } + } sentPayload = true if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData { return @@ -727,6 +733,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl return dataChan, upstreamHeaders, errChan } +func validateSSEDataJSON(chunk []byte) error { + for _, line := range bytes.Split(chunk, []byte("\n")) { + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + if !bytes.HasPrefix(line, []byte("data:")) { + continue + } + data := bytes.TrimSpace(line[5:]) + if len(data) == 0 { + continue + } + if bytes.Equal(data, []byte("[DONE]")) { + continue + } + if json.Valid(data) { + continue + } + const max = 512 + preview := data + if len(preview) > max { + preview = preview[:max] + } + return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview) + } + return nil +} + func statusFromError(err error) int { if err == nil { return 0 diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index ba9dcac5..b08e3a99 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -134,6 +134,37 @@ type authAwareStreamExecutor struct { authIDs []string } +type invalidJSONStreamExecutor struct{} + +func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" } + +func 
(e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} +} + +func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { + ch := make(chan coreexecutor.StreamChunk, 1) + ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")} + close(ch) + return &coreexecutor.StreamResult{Chunks: ch}, nil +} + +func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} +} + +func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) { + return nil, &coreauth.Error{ + Code: "not_implemented", + Message: "HttpRequest not implemented", + HTTPStatus: http.StatusNotImplemented, + } +} + func (e *authAwareStreamExecutor) Identifier() string { return "codex" } func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2") } } + +func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) { + executor := &invalidJSONStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + 
ID: "auth1", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test1@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + }) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []byte + for chunk := range dataChan { + got = append(got, chunk...) + } + if len(got) != 0 { + t.Fatalf("expected empty payload, got %q", string(got)) + } + + gotErr := false + for msg := range errChan { + if msg == nil { + continue + } + if msg.StatusCode != http.StatusBadGateway { + t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode) + } + if msg.Error == nil { + t.Fatalf("expected error") + } + gotErr = true + } + if !gotErr { + t.Fatalf("expected terminal error") + } +} diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index 1cd7e04f..3bca75f9 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -265,8 +265,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush if errMsg.Error != nil && errMsg.Error.Error() != "" { errText = errMsg.Error.Error() } - body := handlers.BuildErrorResponseBody(status, errText) - _, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body)) + chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0) + _, _ = fmt.Fprintf(c.Writer, "\nevent: 
error\ndata: %s\n\n", string(chunk)) }, WriteDone: func() { _, _ = c.Writer.Write([]byte("\n")) diff --git a/sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go b/sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go new file mode 100644 index 00000000..dce73807 --- /dev/null +++ b/sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go @@ -0,0 +1,43 @@ +package openai + +import ( + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) { + gin.SetMode(gin.TestMode) + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil) + h := NewOpenAIResponsesAPIHandler(base) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + t.Fatalf("expected gin writer to implement http.Flusher") + } + + data := make(chan []byte) + errs := make(chan *interfaces.ErrorMessage, 1) + errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")} + close(errs) + + h.forwardResponsesStream(c, flusher, func(error) {}, data, errs) + body := recorder.Body.String() + if !strings.Contains(body, `"type":"error"`) { + t.Fatalf("expected responses error chunk, got: %q", body) + } + if strings.Contains(body, `"error":{`) { + t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body) + } +} diff --git a/sdk/api/handlers/openai_responses_stream_error.go b/sdk/api/handlers/openai_responses_stream_error.go new file mode 100644 index 00000000..e7760bd0 --- /dev/null +++ 
b/sdk/api/handlers/openai_responses_stream_error.go @@ -0,0 +1,119 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" +) + +type openAIResponsesStreamErrorChunk struct { + Type string `json:"type"` + Code string `json:"code"` + Message string `json:"message"` + SequenceNumber int `json:"sequence_number"` +} + +func openAIResponsesStreamErrorCode(status int) string { + switch status { + case http.StatusUnauthorized: + return "invalid_api_key" + case http.StatusForbidden: + return "insufficient_quota" + case http.StatusTooManyRequests: + return "rate_limit_exceeded" + case http.StatusNotFound: + return "model_not_found" + case http.StatusRequestTimeout: + return "request_timeout" + default: + if status >= http.StatusInternalServerError { + return "internal_server_error" + } + if status >= http.StatusBadRequest { + return "invalid_request_error" + } + return "unknown_error" + } +} + +// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk. +// +// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for +// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union +// of chunks that requires a top-level `type` field. 
+func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte { + if status <= 0 { + status = http.StatusInternalServerError + } + if sequenceNumber < 0 { + sequenceNumber = 0 + } + + message := strings.TrimSpace(errText) + if message == "" { + message = http.StatusText(status) + } + + code := openAIResponsesStreamErrorCode(status) + + trimmed := strings.TrimSpace(errText) + if trimmed != "" && json.Valid([]byte(trimmed)) { + var payload map[string]any + if err := json.Unmarshal([]byte(trimmed), &payload); err == nil { + if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" { + if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" { + message = strings.TrimSpace(m) + } + if v, ok := payload["code"]; ok && v != nil { + if c, ok := v.(string); ok && strings.TrimSpace(c) != "" { + code = strings.TrimSpace(c) + } else { + code = strings.TrimSpace(fmt.Sprint(v)) + } + } + if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 { + sequenceNumber = int(v) + } + } + if e, ok := payload["error"].(map[string]any); ok { + if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" { + message = strings.TrimSpace(m) + } + if v, ok := e["code"]; ok && v != nil { + if c, ok := v.(string); ok && strings.TrimSpace(c) != "" { + code = strings.TrimSpace(c) + } else { + code = strings.TrimSpace(fmt.Sprint(v)) + } + } + } + } + } + + if strings.TrimSpace(code) == "" { + code = "unknown_error" + } + + data, err := json.Marshal(openAIResponsesStreamErrorChunk{ + Type: "error", + Code: code, + Message: message, + SequenceNumber: sequenceNumber, + }) + if err == nil { + return data + } + + // Extremely defensive fallback. 
+ data, _ = json.Marshal(openAIResponsesStreamErrorChunk{ + Type: "error", + Code: "internal_server_error", + Message: message, + SequenceNumber: sequenceNumber, + }) + if len(data) > 0 { + return data + } + return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`) +} diff --git a/sdk/api/handlers/openai_responses_stream_error_test.go b/sdk/api/handlers/openai_responses_stream_error_test.go new file mode 100644 index 00000000..90b2c667 --- /dev/null +++ b/sdk/api/handlers/openai_responses_stream_error_test.go @@ -0,0 +1,48 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "testing" +) + +func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) { + chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0) + var payload map[string]any + if err := json.Unmarshal(chunk, &payload); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if payload["type"] != "error" { + t.Fatalf("type = %v, want %q", payload["type"], "error") + } + if payload["code"] != "internal_server_error" { + t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error") + } + if payload["message"] != "unexpected EOF" { + t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF") + } + if payload["sequence_number"] != float64(0) { + t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0) + } +} + +func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) { + chunk := BuildOpenAIResponsesStreamErrorChunk( + http.StatusInternalServerError, + `{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`, + 0, + ) + var payload map[string]any + if err := json.Unmarshal(chunk, &payload); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if payload["type"] != "error" { + t.Fatalf("type = %v, want %q", payload["type"], "error") + } + if payload["code"] != "internal_server_error" { + t.Fatalf("code = %v, want 
%q", payload["code"], "internal_server_error") + } + if payload["message"] != "oops" { + t.Fatalf("message = %v, want %q", payload["message"], "oops") + } +} diff --git a/sdk/auth/codex.go b/sdk/auth/codex.go index c81842eb..1af36936 100644 --- a/sdk/auth/codex.go +++ b/sdk/auth/codex.go @@ -2,8 +2,6 @@ package auth import ( "context" - "crypto/sha256" - "encoding/hex" "fmt" "net/http" "strings" @@ -48,6 +46,10 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts opts = &LoginOptions{} } + if shouldUseCodexDeviceFlow(opts) { + return a.loginWithDeviceFlow(ctx, cfg, opts) + } + callbackPort := a.CallbackPort if opts.CallbackPort > 0 { callbackPort = opts.CallbackPort @@ -186,39 +188,5 @@ waitForCallback: return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err) } - tokenStorage := authSvc.CreateTokenStorage(authBundle) - - if tokenStorage == nil || tokenStorage.Email == "" { - return nil, fmt.Errorf("codex token storage missing account information") - } - - planType := "" - hashAccountID := "" - if tokenStorage.IDToken != "" { - if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil { - planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType) - accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID) - if accountID != "" { - digest := sha256.Sum256([]byte(accountID)) - hashAccountID = hex.EncodeToString(digest[:])[:8] - } - } - } - fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true) - metadata := map[string]any{ - "email": tokenStorage.Email, - } - - fmt.Println("Codex authentication successful") - if authBundle.APIKey != "" { - fmt.Println("Codex API key obtained and stored") - } - - return &coreauth.Auth{ - ID: fileName, - Provider: a.Provider(), - FileName: fileName, - Storage: tokenStorage, - Metadata: metadata, - }, nil + return a.buildAuthRecord(authSvc, authBundle) } diff --git a/sdk/auth/codex_device.go 
b/sdk/auth/codex_device.go new file mode 100644 index 00000000..78a95af8 --- /dev/null +++ b/sdk/auth/codex_device.go @@ -0,0 +1,291 @@ +package auth + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" + "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + log "github.com/sirupsen/logrus" +) + +const ( + codexLoginModeMetadataKey = "codex_login_mode" + codexLoginModeDevice = "device" + codexDeviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode" + codexDeviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token" + codexDeviceVerificationURL = "https://auth.openai.com/codex/device" + codexDeviceTokenExchangeRedirectURI = "https://auth.openai.com/deviceauth/callback" + codexDeviceTimeout = 15 * time.Minute + codexDeviceDefaultPollIntervalSeconds = 5 +) + +type codexDeviceUserCodeRequest struct { + ClientID string `json:"client_id"` +} + +type codexDeviceUserCodeResponse struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` + UserCodeAlt string `json:"usercode"` + Interval json.RawMessage `json:"interval"` +} + +type codexDeviceTokenRequest struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` +} + +type codexDeviceTokenResponse struct { + AuthorizationCode string `json:"authorization_code"` + CodeVerifier string `json:"code_verifier"` + CodeChallenge string `json:"code_challenge"` +} + +func shouldUseCodexDeviceFlow(opts *LoginOptions) bool { + if opts == nil || opts.Metadata == nil { + return false + } + return strings.EqualFold(strings.TrimSpace(opts.Metadata[codexLoginModeMetadataKey]), 
codexLoginModeDevice) +} + +func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) { + if ctx == nil { + ctx = context.Background() + } + + httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{}) + + userCodeResp, err := requestCodexDeviceUserCode(ctx, httpClient) + if err != nil { + return nil, err + } + + deviceCode := strings.TrimSpace(userCodeResp.UserCode) + if deviceCode == "" { + deviceCode = strings.TrimSpace(userCodeResp.UserCodeAlt) + } + deviceAuthID := strings.TrimSpace(userCodeResp.DeviceAuthID) + if deviceCode == "" || deviceAuthID == "" { + return nil, fmt.Errorf("codex device flow did not return required fields") + } + + pollInterval := parseCodexDevicePollInterval(userCodeResp.Interval) + + fmt.Println("Starting Codex device authentication...") + fmt.Printf("Codex device URL: %s\n", codexDeviceVerificationURL) + fmt.Printf("Codex device code: %s\n", deviceCode) + + if !opts.NoBrowser { + if !browser.IsAvailable() { + log.Warn("No browser available; please open the device URL manually") + } else if errOpen := browser.OpenURL(codexDeviceVerificationURL); errOpen != nil { + log.Warnf("Failed to open browser automatically: %v", errOpen) + } + } + + tokenResp, err := pollCodexDeviceToken(ctx, httpClient, deviceAuthID, deviceCode, pollInterval) + if err != nil { + return nil, err + } + + authCode := strings.TrimSpace(tokenResp.AuthorizationCode) + codeVerifier := strings.TrimSpace(tokenResp.CodeVerifier) + codeChallenge := strings.TrimSpace(tokenResp.CodeChallenge) + if authCode == "" || codeVerifier == "" || codeChallenge == "" { + return nil, fmt.Errorf("codex device flow token response missing required fields") + } + + authSvc := codex.NewCodexAuth(cfg) + authBundle, err := authSvc.ExchangeCodeForTokensWithRedirect( + ctx, + authCode, + codexDeviceTokenExchangeRedirectURI, + &codex.PKCECodes{ + CodeVerifier: codeVerifier, + CodeChallenge: codeChallenge, + }, + ) + if 
err != nil { + return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err) + } + + return a.buildAuthRecord(authSvc, authBundle) +} + +func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) { + body, err := json.Marshal(codexDeviceUserCodeRequest{ClientID: codex.ClientID}) + if err != nil { + return nil, fmt.Errorf("failed to encode codex device request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceUserCodeURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create codex device request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to request codex device code: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read codex device code response: %w", err) + } + + if !codexDeviceIsSuccessStatus(resp.StatusCode) { + trimmed := strings.TrimSpace(string(respBody)) + if resp.StatusCode == http.StatusNotFound { + return nil, fmt.Errorf("codex device endpoint is unavailable (status %d)", resp.StatusCode) + } + if trimmed == "" { + trimmed = "empty response body" + } + return nil, fmt.Errorf("codex device code request failed with status %d: %s", resp.StatusCode, trimmed) + } + + var parsed codexDeviceUserCodeResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + return nil, fmt.Errorf("failed to decode codex device code response: %w", err) + } + + return &parsed, nil +} + +func pollCodexDeviceToken(ctx context.Context, client *http.Client, deviceAuthID, userCode string, interval time.Duration) (*codexDeviceTokenResponse, error) { + deadline := time.Now().Add(codexDeviceTimeout) + + for { + if time.Now().After(deadline) { + return nil, fmt.Errorf("codex device 
authentication timed out after 15 minutes") + } + + body, err := json.Marshal(codexDeviceTokenRequest{ + DeviceAuthID: deviceAuthID, + UserCode: userCode, + }) + if err != nil { + return nil, fmt.Errorf("failed to encode codex device poll request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceTokenURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed to create codex device poll request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to poll codex device token: %w", err) + } + + respBody, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if readErr != nil { + return nil, fmt.Errorf("failed to read codex device poll response: %w", readErr) + } + + switch { + case codexDeviceIsSuccessStatus(resp.StatusCode): + var parsed codexDeviceTokenResponse + if err := json.Unmarshal(respBody, &parsed); err != nil { + return nil, fmt.Errorf("failed to decode codex device token response: %w", err) + } + return &parsed, nil + case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound: + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + continue + } + default: + trimmed := strings.TrimSpace(string(respBody)) + if trimmed == "" { + trimmed = "empty response body" + } + return nil, fmt.Errorf("codex device token polling failed with status %d: %s", resp.StatusCode, trimmed) + } + } +} + +func parseCodexDevicePollInterval(raw json.RawMessage) time.Duration { + defaultInterval := time.Duration(codexDeviceDefaultPollIntervalSeconds) * time.Second + if len(raw) == 0 { + return defaultInterval + } + + var asString string + if err := json.Unmarshal(raw, &asString); err == nil { + if seconds, convErr := strconv.Atoi(strings.TrimSpace(asString)); convErr == nil && seconds > 0 { + return 
time.Duration(seconds) * time.Second + } + } + + var asInt int + if err := json.Unmarshal(raw, &asInt); err == nil && asInt > 0 { + return time.Duration(asInt) * time.Second + } + + return defaultInterval +} + +func codexDeviceIsSuccessStatus(code int) bool { + return code >= 200 && code < 300 +} + +func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) { + tokenStorage := authSvc.CreateTokenStorage(authBundle) + + if tokenStorage == nil || tokenStorage.Email == "" { + return nil, fmt.Errorf("codex token storage missing account information") + } + + planType := "" + hashAccountID := "" + if tokenStorage.IDToken != "" { + if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil { + planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType) + accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID) + if accountID != "" { + digest := sha256.Sum256([]byte(accountID)) + hashAccountID = hex.EncodeToString(digest[:])[:8] + } + } + } + + fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true) + metadata := map[string]any{ + "email": tokenStorage.Email, + } + + fmt.Println("Codex authentication successful") + if authBundle.APIKey != "" { + fmt.Println("Codex API key obtained and stored") + } + + return &coreauth.Auth{ + ID: fileName, + Provider: a.Provider(), + FileName: fileName, + Storage: tokenStorage, + Metadata: metadata, + }, nil +} diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index 795bba0d..c424a89b 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -64,8 +64,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str return "", fmt.Errorf("auth filestore: create dir failed: %w", err) } + // metadataSetter is a private interface for TokenStorage implementations that support metadata injection. 
+ type metadataSetter interface { + SetMetadata(map[string]any) + } + switch { case auth.Storage != nil: + if setter, ok := auth.Storage.(metadataSetter); ok { + setter.SetMetadata(auth.Metadata) + } if err = auth.Storage.SaveTokenToFile(path); err != nil { return "", err } diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index cd447e68..df44c855 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -60,6 +60,7 @@ type RefreshEvaluator interface { const ( refreshCheckInterval = 5 * time.Second + refreshMaxConcurrency = 16 refreshPendingBackoff = time.Minute refreshFailureBackoff = 5 * time.Minute quotaBackoffBase = time.Second @@ -155,7 +156,8 @@ type Manager struct { rtProvider RoundTripperProvider // Auto refresh state - refreshCancel context.CancelFunc + refreshCancel context.CancelFunc + refreshSemaphore chan struct{} } // NewManager constructs a manager with optional custom selector and hook. @@ -173,6 +175,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { hook: hook, auths: make(map[string]*Auth), providerOffsets: make(map[string]int), + refreshSemaphore: make(chan struct{}, refreshMaxConcurrency), } // atomic.Value requires non-nil initial value. manager.runtimeConfig.Store(&internalconfig.Config{}) @@ -1828,9 +1831,7 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error { // every few seconds and triggers refresh operations when required. // Only one loop is kept alive; starting a new one cancels the previous run. 
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) { - if interval <= 0 || interval > refreshCheckInterval { - interval = refreshCheckInterval - } else { + if interval <= 0 { interval = refreshCheckInterval } if m.refreshCancel != nil { @@ -1880,11 +1881,25 @@ func (m *Manager) checkRefreshes(ctx context.Context) { if !m.markRefreshPending(a.ID, now) { continue } - go m.refreshAuth(ctx, a.ID) + go m.refreshAuthWithLimit(ctx, a.ID) } } } +func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) { + if m.refreshSemaphore == nil { + m.refreshAuth(ctx, id) + return + } + select { + case m.refreshSemaphore <- struct{}{}: + defer func() { <-m.refreshSemaphore }() + case <-ctx.Done(): + return + } + m.refreshAuth(ctx, id) +} + func (m *Manager) snapshotAuths() []*Auth { m.mu.RLock() defer m.mu.RUnlock() diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index a173ed01..cf79e173 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math" + "math/rand/v2" "net/http" "sort" "strconv" @@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([] } // Pick selects the next available auth for the provider in a round-robin manner. +// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute), +// a two-level round-robin is used: first cycling across credential groups (parent +// accounts), then cycling within each group's project auths. 
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { _ = opts now := time.Now() @@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o if limit <= 0 { limit = 4096 } - if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit { - s.cursors = make(map[string]int) - } - index := s.cursors[key] + // Check if any available auth has gemini_virtual_parent attribute, + // indicating gemini-cli virtual auths that should use credential-level polling. + groups, parentOrder := groupByVirtualParent(available) + if len(parentOrder) > 1 { + // Two-level round-robin: first select a credential group, then pick within it. + groupKey := key + "::group" + s.ensureCursorKey(groupKey, limit) + if _, exists := s.cursors[groupKey]; !exists { + // Seed with a random initial offset so the starting credential is randomized. + s.cursors[groupKey] = rand.IntN(len(parentOrder)) + } + groupIndex := s.cursors[groupKey] + if groupIndex >= 2_147_483_640 { + groupIndex = 0 + } + s.cursors[groupKey] = groupIndex + 1 + + selectedParent := parentOrder[groupIndex%len(parentOrder)] + group := groups[selectedParent] + + // Second level: round-robin within the selected credential group. + innerKey := key + "::cred:" + selectedParent + s.ensureCursorKey(innerKey, limit) + innerIndex := s.cursors[innerKey] + if innerIndex >= 2_147_483_640 { + innerIndex = 0 + } + s.cursors[innerKey] = innerIndex + 1 + s.mu.Unlock() + return group[innerIndex%len(group)], nil + } + + // Flat round-robin for non-grouped auths (original behavior). 
+ s.ensureCursorKey(key, limit) + index := s.cursors[key] if index >= 2_147_483_640 { index = 0 } - s.cursors[key] = index + 1 s.mu.Unlock() - // log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available)) return available[index%len(available)], nil } +// ensureCursorKey ensures the cursor map has capacity for the given key. +// Must be called with s.mu held. +func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) { + if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit { + s.cursors = make(map[string]int) + } +} + +// groupByVirtualParent groups auths by their gemini_virtual_parent attribute. +// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration. +// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks +// this attribute, nil/nil is returned so the caller falls back to flat round-robin. +func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) { + if len(auths) == 0 { + return nil, nil + } + groups := make(map[string][]*Auth) + for _, a := range auths { + parent := "" + if a.Attributes != nil { + parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"]) + } + if parent == "" { + // Non-virtual auth present; fall back to flat round-robin. + return nil, nil + } + groups[parent] = append(groups[parent], a) + } + // Collect parent IDs in sorted order for stable cursor indexing. + parentOrder := make([]string, 0, len(groups)) + for p := range groups { + parentOrder = append(parentOrder, p) + } + sort.Strings(parentOrder) + return groups, parentOrder +} + // Pick selects the first available auth for the provider in a deterministic manner. 
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { _ = opts diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go index fe1cf15e..79431a9a 100644 --- a/sdk/cliproxy/auth/selector_test.go +++ b/sdk/cliproxy/auth/selector_test.go @@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) { t.Fatalf("selector.cursors missing key %q", "gemini:m3") } } + +func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + + // Simulate two gemini-cli credentials, each with multiple projects: + // Credential A (parent = "cred-a.json") has 3 projects + // Credential B (parent = "cred-b.json") has 2 projects + auths := []*Auth{ + {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}}, + {ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}}, + } + + // Two-level round-robin: consecutive picks must alternate between credentials. + // Credential group order is randomized, but within each call the group cursor + // advances by 1, so consecutive picks should cycle through different parents. 
+ picks := make([]string, 6) + parents := make([]string, 6) + for i := 0; i < 6; i++ { + got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + picks[i] = got.ID + parents[i] = got.Attributes["gemini_virtual_parent"] + } + + // Verify property: consecutive picks must alternate between credential groups. + for i := 1; i < len(parents); i++ { + if parents[i] == parents[i-1] { + t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials", + i-1, i, parents[i], picks[i-1], picks[i]) + } + } + + // Verify property: each credential's projects are picked in sequence (round-robin within group). + credPicks := map[string][]string{} + for i, id := range picks { + credPicks[parents[i]] = append(credPicks[parents[i]], id) + } + for parent, ids := range credPicks { + for i := 1; i < len(ids); i++ { + if ids[i] == ids[i-1] { + t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i]) + } + } + } +} + +func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + + // All auths from the same parent - should fall back to flat round-robin + // because there's only one credential group (no benefit from two-level). + auths := []*Auth{ + {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + } + + // With single parent group, parentOrder has length 1, so it uses flat round-robin. 
+ // Sorted by ID: proj-a1, proj-a2, proj-a3 + want := []string{ + "cred-a.json::proj-a1", + "cred-a.json::proj-a2", + "cred-a.json::proj-a3", + "cred-a.json::proj-a1", + } + + for i, expectedID := range want { + got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + if got.ID != expectedID { + t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID) + } + } +} + +func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + + // Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects + // alongside virtual ones). Should fall back to flat round-robin. + auths := []*Auth{ + {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-regular.json"}, // no gemini_virtual_parent + } + + // groupByVirtualParent returns nil when any auth lacks the attribute, + // so flat round-robin is used. 
Sorted by ID: cred-a.json::proj-a1, cred-regular.json + want := []string{ + "cred-a.json::proj-a1", + "cred-regular.json", + "cred-a.json::proj-a1", + } + + for i, expectedID := range want { + got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + if got.ID != expectedID { + t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID) + } + } +} diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index 96534bbe..0bfaf11a 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -1,9 +1,12 @@ package auth import ( + "context" "crypto/sha256" "encoding/hex" "encoding/json" + "net/http" + "net/url" "strconv" "strings" "sync" @@ -12,6 +15,33 @@ import ( baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth" ) +// PostAuthHook defines a function that is called after an Auth record is created +// but before it is persisted to storage. This allows for modification of the +// Auth record (e.g., injecting metadata) based on external context. +type PostAuthHook func(context.Context, *Auth) error + +// RequestInfo holds information extracted from the HTTP request. +// It is injected into the context passed to PostAuthHook. +type RequestInfo struct { + Query url.Values + Headers http.Header +} + +type requestInfoKey struct{} + +// WithRequestInfo returns a new context with the given RequestInfo attached. +func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context { + return context.WithValue(ctx, requestInfoKey{}, info) +} + +// GetRequestInfo retrieves the RequestInfo from the context, if present. 
+func GetRequestInfo(ctx context.Context) *RequestInfo {
+	if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok {
+		return val
+	}
+	return nil
+}
+
 // Auth encapsulates the runtime state and metadata associated with a single credential.
 type Auth struct {
 	// ID uniquely identifies the auth record across restarts.
diff --git a/sdk/cliproxy/builder.go b/sdk/cliproxy/builder.go
index 60ca07f5..0e6d1421 100644
--- a/sdk/cliproxy/builder.go
+++ b/sdk/cliproxy/builder.go
@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
 	return b
 }
 
+// WithPostAuthHook registers a hook to be called after an Auth record is created
+// but before it is persisted to storage.
+func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
+	if hook == nil {
+		return b
+	}
+	b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
+	return b
+}
+
 // Build validates inputs, applies defaults, and returns a ready-to-run service.
 func (b *Builder) Build() (*Service, error) {
 	if b.cfg == nil {
diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go
index e89c49c0..1f9f4d6f 100644
--- a/sdk/cliproxy/service.go
+++ b/sdk/cliproxy/service.go
@@ -925,6 +925,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
 		key = strings.ToLower(strings.TrimSpace(a.Provider))
 	}
 	GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
+	if provider == "antigravity" {
+		s.backfillAntigravityModels(a, models)
+	}
 	return
 }
 
@@ -1069,6 +1072,56 @@ func (s *Service) oauthExcludedModels(provider, authKind string) []string {
 	return cfg.OAuthExcludedModels[providerKey]
 }
 
+func (s *Service) backfillAntigravityModels(source *coreauth.Auth, primaryModels []*ModelInfo) {
+	if s == nil || s.coreManager == nil || len(primaryModels) == 0 {
+		return
+	}
+
+	sourceID := ""
+	if source != nil {
+		sourceID = strings.TrimSpace(source.ID)
+	}
+
+	reg := registry.GetGlobalRegistry()
+	for _, candidate := range s.coreManager.List() {
+		if candidate == nil || candidate.Disabled {
+			continue
+		}
+		candidateID := strings.TrimSpace(candidate.ID)
+		if candidateID == "" || candidateID == sourceID {
+			continue
+		}
+		if !strings.EqualFold(strings.TrimSpace(candidate.Provider), "antigravity") {
+			continue
+		}
+		if len(reg.GetModelsForClient(candidateID)) > 0 {
+			continue
+		}
+
+		authKind := strings.ToLower(strings.TrimSpace(candidate.Attributes["auth_kind"]))
+		if authKind == "" {
+			if kind, _ := candidate.AccountInfo(); strings.EqualFold(kind, "api_key") {
+				authKind = "apikey"
+			}
+		}
+		excluded := s.oauthExcludedModels("antigravity", authKind)
+		if candidate.Attributes != nil {
+			if val, ok := candidate.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" {
+				excluded = strings.Split(val, ",")
+			}
+		}
+
+		models := applyExcludedModels(primaryModels, excluded)
+		models = applyOAuthModelAlias(s.cfg, "antigravity", authKind, models)
+		if len(models) == 0 {
+			continue
+		}
+
+		reg.RegisterClient(candidateID, "antigravity", applyModelPrefixes(models, candidate.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
+		log.Debugf("antigravity models backfilled for auth %s using primary model list", candidateID)
+	}
+}
+
 func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
 	if len(models) == 0 || len(excluded) == 0 {
 		return models
diff --git a/sdk/cliproxy/service_antigravity_backfill_test.go b/sdk/cliproxy/service_antigravity_backfill_test.go
new file mode 100644
index 00000000..df087438
--- /dev/null
+++ b/sdk/cliproxy/service_antigravity_backfill_test.go
@@ -0,0 +1,135 @@
+package cliproxy
+
+import (
+	"context"
+	"strings"
+	"testing"
+
+	"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
+	coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+	"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
+)
+
+func TestBackfillAntigravityModels_RegistersMissingAuth(t *testing.T) {
+	source := &coreauth.Auth{
+		ID:       "ag-backfill-source",
+		Provider: "antigravity",
+		Status:   coreauth.StatusActive,
+		Attributes: map[string]string{
+			"auth_kind": "oauth",
+		},
+	}
+	target := &coreauth.Auth{
+		ID:       "ag-backfill-target",
+		Provider: "antigravity",
+		Status:   coreauth.StatusActive,
+		Attributes: map[string]string{
+			"auth_kind": "oauth",
+		},
+	}
+
+	manager := coreauth.NewManager(nil, nil, nil)
+	if _, err := manager.Register(context.Background(), source); err != nil {
+		t.Fatalf("register source auth: %v", err)
+	}
+	if _, err := manager.Register(context.Background(), target); err != nil {
+		t.Fatalf("register target auth: %v", err)
+	}
+
+	service := &Service{
+		cfg:         &config.Config{},
+		coreManager: manager,
+	}
+
+	reg := registry.GetGlobalRegistry()
+	reg.UnregisterClient(source.ID)
+	reg.UnregisterClient(target.ID)
+	t.Cleanup(func() {
+		reg.UnregisterClient(source.ID)
+		reg.UnregisterClient(target.ID)
+	})
+
+	primary := []*ModelInfo{
+		{ID: "claude-sonnet-4-5"},
+		{ID: "gemini-2.5-pro"},
+	}
+	reg.RegisterClient(source.ID, "antigravity", primary)
+
+	service.backfillAntigravityModels(source, primary)
+
+	got := reg.GetModelsForClient(target.ID)
+	if len(got) != 2 {
+		t.Fatalf("expected target auth to be backfilled with 2 models, got %d", len(got))
+	}
+
+	ids := make(map[string]struct{}, len(got))
+	for _, model := range got {
+		if model == nil {
+			continue
+		}
+		ids[strings.ToLower(strings.TrimSpace(model.ID))] = struct{}{}
+	}
+	if _, ok := ids["claude-sonnet-4-5"]; !ok {
+		t.Fatal("expected backfilled model claude-sonnet-4-5")
+	}
+	if _, ok := ids["gemini-2.5-pro"]; !ok {
+		t.Fatal("expected backfilled model gemini-2.5-pro")
+	}
+}
+
+func TestBackfillAntigravityModels_RespectsExcludedModels(t *testing.T) {
+	source := &coreauth.Auth{
+		ID:       "ag-backfill-source-excluded",
+		Provider: "antigravity",
+		Status:   coreauth.StatusActive,
+		Attributes: map[string]string{
+			"auth_kind": "oauth",
+		},
+	}
+	target := &coreauth.Auth{
+		ID:       "ag-backfill-target-excluded",
+		Provider: "antigravity",
+		Status:   coreauth.StatusActive,
+		Attributes: map[string]string{
+			"auth_kind":       "oauth",
+			"excluded_models": "gemini-2.5-pro",
+		},
+	}
+
+	manager := coreauth.NewManager(nil, nil, nil)
+	if _, err := manager.Register(context.Background(), source); err != nil {
+		t.Fatalf("register source auth: %v", err)
+	}
+	if _, err := manager.Register(context.Background(), target); err != nil {
+		t.Fatalf("register target auth: %v", err)
+	}
+
+	service := &Service{
+		cfg:         &config.Config{},
+		coreManager: manager,
+	}
+
+	reg := registry.GetGlobalRegistry()
+	reg.UnregisterClient(source.ID)
+	reg.UnregisterClient(target.ID)
+	t.Cleanup(func() {
+		reg.UnregisterClient(source.ID)
+		reg.UnregisterClient(target.ID)
+	})
+
+	primary := []*ModelInfo{
+		{ID: "claude-sonnet-4-5"},
+		{ID: "gemini-2.5-pro"},
+	}
+	reg.RegisterClient(source.ID, "antigravity", primary)
+
+	service.backfillAntigravityModels(source, primary)
+
+	got := reg.GetModelsForClient(target.ID)
+	if len(got) != 1 {
+		t.Fatalf("expected 1 model after exclusion, got %d", len(got))
+	}
+	if got[0] == nil || !strings.EqualFold(strings.TrimSpace(got[0].ID), "claude-sonnet-4-5") {
+		t.Fatalf("expected remaining model %q, got %+v", "claude-sonnet-4-5", got[0])
+	}
+}