diff --git a/.github/workflows/docker-image.yml b/.github/workflows/docker-image.yml index 6c99b21b..4a9501c0 100644 --- a/.github/workflows/docker-image.yml +++ b/.github/workflows/docker-image.yml @@ -15,6 +15,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + - name: Refresh models catalog + run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to DockerHub @@ -46,6 +48,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + - name: Refresh models catalog + run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to DockerHub diff --git a/.github/workflows/pr-test-build.yml b/.github/workflows/pr-test-build.yml index 477ff049..b24b1fcb 100644 --- a/.github/workflows/pr-test-build.yml +++ b/.github/workflows/pr-test-build.yml @@ -12,6 +12,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 + - name: Refresh models catalog + run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json - name: Set up Go uses: actions/setup-go@v5 with: diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 64e7a5b7..30cdbeab 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -16,6 +16,8 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + - name: Refresh models catalog + run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json - run: git fetch --force --tags - uses: actions/setup-go@v4 with: diff --git a/README.md b/README.md index 8491b97c..722fa86b 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,10 @@ A 
Windows tray application implemented using PowerShell scripts, without relying A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed. +### [All API Hub](https://github.com/qixing-jk/all-api-hub) + +Browser extension for one-stop management of New API-compatible relay site accounts, featuring balance and usage dashboards, auto check-in, one-click key export to common apps, in-page API availability testing, and channel/model sync and redirection. It integrates with CLIProxyAPI through the Management API for one-click provider import and config sync. + > [!NOTE] > If you developed a project based on CLIProxyAPI, please open a PR to add it to this list. diff --git a/README_CN.md b/README_CN.md index 6e987fdf..5dff9c55 100644 --- a/README_CN.md +++ b/README_CN.md @@ -149,6 +149,10 @@ Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方 一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。 +### [All API Hub](https://github.com/qixing-jk/all-api-hub) + +用于一站式管理 New API 兼容中转站账号的浏览器扩展,提供余额与用量看板、自动签到、密钥一键导出到常用应用、网页内 API 可用性测试,以及渠道与模型同步和重定向。支持通过 CLIProxyAPI Management API 一键导入 Provider 与同步配置。 + > [!NOTE] > 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。 diff --git a/cmd/server/main.go b/cmd/server/main.go index 7353c7d9..3d9ee6cf 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -24,6 +24,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" 
"github.com/router-for-me/CLIProxyAPI/v6/internal/store" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator" "github.com/router-for-me/CLIProxyAPI/v6/internal/tui" @@ -494,6 +495,7 @@ func main() { if standalone { // Standalone mode: start an embedded local server and connect TUI client to it. managementasset.StartAutoUpdater(context.Background(), configFilePath) + registry.StartModelsUpdater(context.Background()) hook := tui.NewLogHook(2000) hook.SetFormatter(&logging.LogFormatter{}) log.AddHook(hook) @@ -566,6 +568,7 @@ func main() { } else { // Start the main proxy service managementasset.StartAutoUpdater(context.Background(), configFilePath) + registry.StartModelsUpdater(context.Background()) cmd.StartService(cfg, configFilePath, password) } } diff --git a/config.example.yaml b/config.example.yaml index 4297eb15..fb29477d 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -63,6 +63,7 @@ error-logs-max-files: 10 usage-statistics-enabled: false # Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/ +# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly. proxy-url: "" # When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name). 
@@ -110,6 +111,7 @@ nonstream-keepalive-interval: 0 # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" +# # proxy-url: "direct" # optional: explicit direct connect for this credential # models: # - name: "gemini-2.5-flash" # upstream model name # alias: "gemini-flash" # client alias mapped to the upstream model @@ -128,6 +130,7 @@ nonstream-keepalive-interval: 0 # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override +# # proxy-url: "direct" # optional: explicit direct connect for this credential # models: # - name: "gpt-5-codex" # upstream model name # alias: "codex-latest" # client alias mapped to the upstream model @@ -146,6 +149,7 @@ nonstream-keepalive-interval: 0 # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override +# # proxy-url: "direct" # optional: explicit direct connect for this credential # models: # - name: "claude-3-5-sonnet-20241022" # upstream model name # alias: "claude-sonnet-latest" # client alias mapped to the upstream model @@ -191,10 +195,22 @@ nonstream-keepalive-interval: 0 # api-key-entries: # - api-key: "sk-or-v1-...b780" # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override +# # proxy-url: "direct" # optional: explicit direct connect for this credential # - api-key: "sk-or-v1-...b781" # without proxy-url # models: # The models supported by the provider. # - name: "moonshotai/kimi-k2:free" # The actual model name. # alias: "kimi-k2" # The alias used in the API. +# # You may repeat the same alias to build an internal model pool. +# # The client still sees only one alias in the model list. +# # Requests to that alias will round-robin across the upstream names below, +# # and if the chosen upstream fails before producing output, the request will +# # continue with the next upstream model in the same alias pool. 
+# - name: "qwen3.5-plus" +# alias: "claude-opus-4.66" +# - name: "glm-5" +# alias: "claude-opus-4.66" +# - name: "kimi-k2.5" +# alias: "claude-opus-4.66" # Vertex API keys (Vertex-compatible endpoints, use API key + base URL) # vertex-api-key: @@ -202,6 +218,7 @@ nonstream-keepalive-interval: 0 # prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential # base-url: "https://example.com/api" # e.g. https://zenmux.ai/api # proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override +# # proxy-url: "direct" # optional: explicit direct connect for this credential # headers: # X-Custom-Header: "custom-value" # models: # optional: map aliases to upstream model names diff --git a/internal/api/handlers/management/api_tools.go b/internal/api/handlers/management/api_tools.go index c7846a75..de546ea8 100644 --- a/internal/api/handlers/management/api_tools.go +++ b/internal/api/handlers/management/api_tools.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "net/url" "strings" @@ -14,8 +13,8 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" - "golang.org/x/net/proxy" "golang.org/x/oauth2" "golang.org/x/oauth2/google" ) @@ -660,45 +659,10 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper { } func buildProxyTransport(proxyStr string) *http.Transport { - proxyStr = strings.TrimSpace(proxyStr) - if proxyStr == "" { + transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr) + if errBuild != nil { + log.WithError(errBuild).Debug("build proxy transport failed") return nil } - - proxyURL, errParse := url.Parse(proxyStr) - if errParse != nil { - log.WithError(errParse).Debug("parse proxy URL failed") - return nil - } - if proxyURL.Scheme == "" || proxyURL.Host == 
"" { - log.Debug("proxy URL missing scheme/host") - return nil - } - - if proxyURL.Scheme == "socks5" { - var proxyAuth *proxy.Auth - if proxyURL.User != nil { - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - proxyAuth = &proxy.Auth{User: username, Password: password} - } - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) - if errSOCKS5 != nil { - log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed") - return nil - } - return &http.Transport{ - Proxy: nil, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } - - if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - return &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } - - log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme) - return nil + return transport } diff --git a/internal/api/handlers/management/api_tools_test.go b/internal/api/handlers/management/api_tools_test.go index fecbee9c..5b0c6369 100644 --- a/internal/api/handlers/management/api_tools_test.go +++ b/internal/api/handlers/management/api_tools_test.go @@ -1,173 +1,58 @@ package management import ( - "context" - "encoding/json" - "io" "net/http" - "net/http/httptest" - "net/url" - "strings" - "sync" "testing" - "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) -type memoryAuthStore struct { - mu sync.Mutex - items map[string]*coreauth.Auth -} +func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) { + t.Parallel() -func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) { - _ = ctx - s.mu.Lock() - defer s.mu.Unlock() - out := make([]*coreauth.Auth, 0, len(s.items)) - for _, a := range s.items { - out = append(out, a.Clone()) - } - return out, nil -} - -func (s *memoryAuthStore) Save(ctx 
context.Context, auth *coreauth.Auth) (string, error) { - _ = ctx - if auth == nil { - return "", nil - } - s.mu.Lock() - if s.items == nil { - s.items = make(map[string]*coreauth.Auth) - } - s.items[auth.ID] = auth.Clone() - s.mu.Unlock() - return auth.ID, nil -} - -func (s *memoryAuthStore) Delete(ctx context.Context, id string) error { - _ = ctx - s.mu.Lock() - delete(s.items, id) - s.mu.Unlock() - return nil -} - -func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) { - var callCount int - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - if r.Method != http.MethodPost { - t.Fatalf("expected POST, got %s", r.Method) - } - if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") { - t.Fatalf("unexpected content-type: %s", ct) - } - bodyBytes, _ := io.ReadAll(r.Body) - _ = r.Body.Close() - values, err := url.ParseQuery(string(bodyBytes)) - if err != nil { - t.Fatalf("parse form: %v", err) - } - if values.Get("grant_type") != "refresh_token" { - t.Fatalf("unexpected grant_type: %s", values.Get("grant_type")) - } - if values.Get("refresh_token") != "rt" { - t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token")) - } - if values.Get("client_id") != antigravityOAuthClientID { - t.Fatalf("unexpected client_id: %s", values.Get("client_id")) - } - if values.Get("client_secret") != antigravityOAuthClientSecret { - t.Fatalf("unexpected client_secret") - } - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "access_token": "new-token", - "refresh_token": "rt2", - "expires_in": int64(3600), - "token_type": "Bearer", - }) - })) - t.Cleanup(srv.Close) - - originalURL := antigravityOAuthTokenURL - antigravityOAuthTokenURL = srv.URL - t.Cleanup(func() { antigravityOAuthTokenURL = originalURL }) - - store := &memoryAuthStore{} - manager := coreauth.NewManager(store, nil, nil) - - auth := 
&coreauth.Auth{ - ID: "antigravity-test.json", - FileName: "antigravity-test.json", - Provider: "antigravity", - Metadata: map[string]any{ - "type": "antigravity", - "access_token": "old-token", - "refresh_token": "rt", - "expires_in": int64(3600), - "timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(), - "expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339), + h := &Handler{ + cfg: &config.Config{ + SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}, }, } - if _, err := manager.Register(context.Background(), auth); err != nil { - t.Fatalf("register auth: %v", err) - } - h := &Handler{authManager: manager} - token, err := h.resolveTokenForAuth(context.Background(), auth) - if err != nil { - t.Fatalf("resolveTokenForAuth: %v", err) + transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"}) + httpTransport, ok := transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", transport) } - if token != "new-token" { - t.Fatalf("expected refreshed token, got %q", token) - } - if callCount != 1 { - t.Fatalf("expected 1 refresh call, got %d", callCount) - } - - updated, ok := manager.GetByID(auth.ID) - if !ok || updated == nil { - t.Fatalf("expected auth in manager after update") - } - if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" { - t.Fatalf("expected manager metadata updated, got %q", got) + if httpTransport.Proxy != nil { + t.Fatal("expected direct transport to disable proxy function") } } -func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) { - var callCount int - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - w.WriteHeader(http.StatusInternalServerError) - })) - t.Cleanup(srv.Close) +func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) { + t.Parallel() - originalURL := antigravityOAuthTokenURL - antigravityOAuthTokenURL = srv.URL - t.Cleanup(func() { 
antigravityOAuthTokenURL = originalURL }) - - auth := &coreauth.Auth{ - ID: "antigravity-valid.json", - FileName: "antigravity-valid.json", - Provider: "antigravity", - Metadata: map[string]any{ - "type": "antigravity", - "access_token": "ok-token", - "expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + h := &Handler{ + cfg: &config.Config{ + SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}, }, } - h := &Handler{} - token, err := h.resolveTokenForAuth(context.Background(), auth) - if err != nil { - t.Fatalf("resolveTokenForAuth: %v", err) + + transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"}) + httpTransport, ok := transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", transport) } - if token != "ok-token" { - t.Fatalf("expected existing token, got %q", token) + + req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil) + if errRequest != nil { + t.Fatalf("http.NewRequest returned error: %v", errRequest) } - if callCount != 0 { - t.Fatalf("expected no refresh calls, got %d", callCount) + + proxyURL, errProxy := httpTransport.Proxy(req) + if errProxy != nil { + t.Fatalf("httpTransport.Proxy returned error: %v", errProxy) + } + if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" { + t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL) } } diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index e0a16377..2e471ae8 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -1306,12 +1306,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) if errAll != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) - SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") + 
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll)) return } if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errVerify) - SetOAuthSessionError(state, "Failed to verify Cloud AI API status") + SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify)) return } ts.ProjectID = strings.Join(projects, ",") @@ -1320,7 +1320,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { ts.Auto = false if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil { log.Errorf("Google One auto-discovery failed: %v", errSetup) - SetOAuthSessionError(state, "Google One auto-discovery failed") + SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup)) return } if strings.TrimSpace(ts.ProjectID) == "" { @@ -1331,19 +1331,19 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) if errCheck != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - SetOAuthSessionError(state, "Failed to verify Cloud AI API status") + SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck)) return } ts.Checked = isChecked if !isChecked { log.Error("Cloud AI API is not enabled for the auto-discovered project") - SetOAuthSessionError(state, "Cloud AI API not enabled") + SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID)) return } } else { if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) - SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") + SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure)) return 
} @@ -1356,13 +1356,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) if errCheck != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - SetOAuthSessionError(state, "Failed to verify Cloud AI API status") + SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck)) return } ts.Checked = isChecked if !isChecked { log.Error("Cloud AI API is not enabled for the selected project") - SetOAuthSessionError(state, "Cloud AI API not enabled") + SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID)) return } } diff --git a/internal/api/handlers/management/test_store_test.go b/internal/api/handlers/management/test_store_test.go new file mode 100644 index 00000000..cf7dbaf7 --- /dev/null +++ b/internal/api/handlers/management/test_store_test.go @@ -0,0 +1,49 @@ +package management + +import ( + "context" + "sync" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +type memoryAuthStore struct { + mu sync.Mutex + items map[string]*coreauth.Auth +} + +func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) { + s.mu.Lock() + defer s.mu.Unlock() + + out := make([]*coreauth.Auth, 0, len(s.items)) + for _, item := range s.items { + out = append(out, item) + } + return out, nil +} + +func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) { + if auth == nil { + return "", nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + if s.items == nil { + s.items = make(map[string]*coreauth.Auth) + } + s.items[auth.ID] = auth + return auth.ID, nil +} + +func (s *memoryAuthStore) Delete(_ context.Context, id string) error { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.items, id) + return nil +} + +func (s *memoryAuthStore) SetBaseDir(string) {} diff --git a/internal/auth/claude/utls_transport.go 
b/internal/auth/claude/utls_transport.go index 27ec87e1..88b69c9b 100644 --- a/internal/auth/claude/utls_transport.go +++ b/internal/auth/claude/utls_transport.go @@ -4,12 +4,12 @@ package claude import ( "net/http" - "net/url" "strings" "sync" tls "github.com/refraction-networking/utls" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" "golang.org/x/net/http2" "golang.org/x/net/proxy" @@ -31,17 +31,12 @@ type utlsRoundTripper struct { // newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper { var dialer proxy.Dialer = proxy.Direct - if cfg != nil && cfg.ProxyURL != "" { - proxyURL, err := url.Parse(cfg.ProxyURL) - if err != nil { - log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err) - } else { - pDialer, err := proxy.FromURL(proxyURL, proxy.Direct) - if err != nil { - log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err) - } else { - dialer = pDialer - } + if cfg != nil { + proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL) + if errBuild != nil { + log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild) + } else if mode != proxyutil.ModeInherit && proxyDialer != nil { + dialer = proxyDialer } } diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index 6406a0e1..c459c5ca 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -10,9 +10,7 @@ import ( "errors" "fmt" "io" - "net" "net/http" - "net/url" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" @@ -20,9 +18,9 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log 
"github.com/sirupsen/logrus" "github.com/tidwall/gjson" - "golang.org/x/net/proxy" "golang.org/x/oauth2" "golang.org/x/oauth2/google" @@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken } callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort) - // Configure proxy settings for the HTTP client if a proxy URL is provided. - proxyURL, err := url.Parse(cfg.ProxyURL) - if err == nil { - var transport *http.Transport - if proxyURL.Scheme == "socks5" { - // Handle SOCKS5 proxy. - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - auth := &proxy.Auth{User: username, Password: password} - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) - if errSOCKS5 != nil { - log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) - return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5) - } - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Handle HTTP/HTTPS proxy. - transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } - - if transport != nil { - proxyClient := &http.Client{Transport: transport} - ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) - } + transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL) + if errBuild != nil { + log.Errorf("%v", errBuild) + } else if transport != nil { + proxyClient := &http.Client{Transport: transport} + ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) } + var err error + // Configure the OAuth2 client. 
conf := &oauth2.Config{ ClientID: ClientID, diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index c1796979..b7f5edb1 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -1,5 +1,5 @@ // Package registry provides model definitions and lookup helpers for various AI providers. -// Static model metadata is stored in model_definitions_static_data.go. +// Static model metadata is loaded from the embedded models.json file and can be refreshed from network. package registry import ( @@ -7,6 +7,131 @@ import ( "strings" ) +// AntigravityModelConfig captures static antigravity model overrides, including +// Thinking budget limits and provider max completion tokens. +type AntigravityModelConfig struct { + Thinking *ThinkingSupport `json:"thinking,omitempty"` + MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` +} + +// staticModelsJSON mirrors the top-level structure of models.json. +type staticModelsJSON struct { + Claude []*ModelInfo `json:"claude"` + Gemini []*ModelInfo `json:"gemini"` + Vertex []*ModelInfo `json:"vertex"` + GeminiCLI []*ModelInfo `json:"gemini-cli"` + AIStudio []*ModelInfo `json:"aistudio"` + CodexFree []*ModelInfo `json:"codex-free"` + CodexTeam []*ModelInfo `json:"codex-team"` + CodexPlus []*ModelInfo `json:"codex-plus"` + CodexPro []*ModelInfo `json:"codex-pro"` + Qwen []*ModelInfo `json:"qwen"` + IFlow []*ModelInfo `json:"iflow"` + Kimi []*ModelInfo `json:"kimi"` + Antigravity map[string]*AntigravityModelConfig `json:"antigravity"` +} + +// GetClaudeModels returns the standard Claude model definitions. +func GetClaudeModels() []*ModelInfo { + return cloneModelInfos(getModels().Claude) +} + +// GetGeminiModels returns the standard Gemini model definitions. +func GetGeminiModels() []*ModelInfo { + return cloneModelInfos(getModels().Gemini) +} + +// GetGeminiVertexModels returns Gemini model definitions for Vertex AI. 
+func GetGeminiVertexModels() []*ModelInfo { + return cloneModelInfos(getModels().Vertex) +} + +// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI. +func GetGeminiCLIModels() []*ModelInfo { + return cloneModelInfos(getModels().GeminiCLI) +} + +// GetAIStudioModels returns model definitions for AI Studio. +func GetAIStudioModels() []*ModelInfo { + return cloneModelInfos(getModels().AIStudio) +} + +// GetCodexFreeModels returns model definitions for the Codex free plan tier. +func GetCodexFreeModels() []*ModelInfo { + return cloneModelInfos(getModels().CodexFree) +} + +// GetCodexTeamModels returns model definitions for the Codex team plan tier. +func GetCodexTeamModels() []*ModelInfo { + return cloneModelInfos(getModels().CodexTeam) +} + +// GetCodexPlusModels returns model definitions for the Codex plus plan tier. +func GetCodexPlusModels() []*ModelInfo { + return cloneModelInfos(getModels().CodexPlus) +} + +// GetCodexProModels returns model definitions for the Codex pro plan tier. +func GetCodexProModels() []*ModelInfo { + return cloneModelInfos(getModels().CodexPro) +} + +// GetQwenModels returns the standard Qwen model definitions. +func GetQwenModels() []*ModelInfo { + return cloneModelInfos(getModels().Qwen) +} + +// GetIFlowModels returns the standard iFlow model definitions. +func GetIFlowModels() []*ModelInfo { + return cloneModelInfos(getModels().IFlow) +} + +// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions. +func GetKimiModels() []*ModelInfo { + return cloneModelInfos(getModels().Kimi) +} + +// GetAntigravityModelConfig returns static configuration for antigravity models. +// Keys use upstream model names returned by the Antigravity models endpoint. 
+func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { + data := getModels() + if len(data.Antigravity) == 0 { + return nil + } + out := make(map[string]*AntigravityModelConfig, len(data.Antigravity)) + for k, v := range data.Antigravity { + out[k] = cloneAntigravityModelConfig(v) + } + return out +} + +func cloneAntigravityModelConfig(cfg *AntigravityModelConfig) *AntigravityModelConfig { + if cfg == nil { + return nil + } + copyConfig := *cfg + if cfg.Thinking != nil { + copyThinking := *cfg.Thinking + if len(cfg.Thinking.Levels) > 0 { + copyThinking.Levels = append([]string(nil), cfg.Thinking.Levels...) + } + copyConfig.Thinking = &copyThinking + } + return &copyConfig +} + +// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned. +func cloneModelInfos(models []*ModelInfo) []*ModelInfo { + if len(models) == 0 { + return nil + } + out := make([]*ModelInfo, len(models)) + for i, m := range models { + out[i] = cloneModelInfo(m) + } + return out +} + // GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider. // It returns nil when the channel is unknown.
// @@ -35,7 +160,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { case "aistudio": return GetAIStudioModels() case "codex": - return GetOpenAIModels() + return GetCodexProModels() case "qwen": return GetQwenModels() case "iflow": @@ -77,27 +202,28 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { return nil } + data := getModels() allModels := [][]*ModelInfo{ - GetClaudeModels(), - GetGeminiModels(), - GetGeminiVertexModels(), - GetGeminiCLIModels(), - GetAIStudioModels(), - GetOpenAIModels(), - GetQwenModels(), - GetIFlowModels(), - GetKimiModels(), + data.Claude, + data.Gemini, + data.Vertex, + data.GeminiCLI, + data.AIStudio, + data.CodexPro, + data.Qwen, + data.IFlow, + data.Kimi, } for _, models := range allModels { for _, m := range models { if m != nil && m.ID == modelID { - return m + return cloneModelInfo(m) } } } // Check Antigravity static config - if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil { + if cfg := cloneAntigravityModelConfig(data.Antigravity[modelID]); cfg != nil { return &ModelInfo{ ID: modelID, Thinking: cfg.Thinking, diff --git a/internal/registry/model_definitions_static_data.go b/internal/registry/model_definitions_static_data.go deleted file mode 100644 index 750aa4b4..00000000 --- a/internal/registry/model_definitions_static_data.go +++ /dev/null @@ -1,1078 +0,0 @@ -// Package registry provides model definitions for various AI service providers. -// This file stores the static model metadata catalog. 
-package registry - -// GetClaudeModels returns the standard Claude model definitions -func GetClaudeModels() []*ModelInfo { - return []*ModelInfo{ - - { - ID: "claude-haiku-4-5-20251001", - Object: "model", - Created: 1759276800, // 2025-10-01 - OwnedBy: "anthropic", - Type: "claude", - DisplayName: "Claude 4.5 Haiku", - ContextLength: 200000, - MaxCompletionTokens: 64000, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false}, - }, - { - ID: "claude-sonnet-4-5-20250929", - Object: "model", - Created: 1759104000, // 2025-09-29 - OwnedBy: "anthropic", - Type: "claude", - DisplayName: "Claude 4.5 Sonnet", - ContextLength: 200000, - MaxCompletionTokens: 64000, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false}, - }, - { - ID: "claude-sonnet-4-6", - Object: "model", - Created: 1771372800, // 2026-02-17 - OwnedBy: "anthropic", - Type: "claude", - DisplayName: "Claude 4.6 Sonnet", - ContextLength: 200000, - MaxCompletionTokens: 64000, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high"}}, - }, - { - ID: "claude-opus-4-6", - Object: "model", - Created: 1770318000, // 2026-02-05 - OwnedBy: "anthropic", - Type: "claude", - DisplayName: "Claude 4.6 Opus", - Description: "Premium model combining maximum intelligence with practical performance", - ContextLength: 1000000, - MaxCompletionTokens: 128000, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high", "max"}}, - }, - { - ID: "claude-opus-4-5-20251101", - Object: "model", - Created: 1761955200, // 2025-11-01 - OwnedBy: "anthropic", - Type: "claude", - DisplayName: "Claude 4.5 Opus", - Description: "Premium model combining maximum intelligence with practical performance", - ContextLength: 200000, - MaxCompletionTokens: 64000, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, 
ZeroAllowed: true, DynamicAllowed: false}, - }, - { - ID: "claude-opus-4-1-20250805", - Object: "model", - Created: 1722945600, // 2025-08-05 - OwnedBy: "anthropic", - Type: "claude", - DisplayName: "Claude 4.1 Opus", - ContextLength: 200000, - MaxCompletionTokens: 32000, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false}, - }, - { - ID: "claude-opus-4-20250514", - Object: "model", - Created: 1715644800, // 2025-05-14 - OwnedBy: "anthropic", - Type: "claude", - DisplayName: "Claude 4 Opus", - ContextLength: 200000, - MaxCompletionTokens: 32000, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false}, - }, - { - ID: "claude-sonnet-4-20250514", - Object: "model", - Created: 1715644800, // 2025-05-14 - OwnedBy: "anthropic", - Type: "claude", - DisplayName: "Claude 4 Sonnet", - ContextLength: 200000, - MaxCompletionTokens: 64000, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false}, - }, - { - ID: "claude-3-7-sonnet-20250219", - Object: "model", - Created: 1708300800, // 2025-02-19 - OwnedBy: "anthropic", - Type: "claude", - DisplayName: "Claude 3.7 Sonnet", - ContextLength: 128000, - MaxCompletionTokens: 8192, - Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: false, DynamicAllowed: false}, - }, - { - ID: "claude-3-5-haiku-20241022", - Object: "model", - Created: 1729555200, // 2024-10-22 - OwnedBy: "anthropic", - Type: "claude", - DisplayName: "Claude 3.5 Haiku", - ContextLength: 128000, - MaxCompletionTokens: 8192, - // Thinking: not supported for Haiku models - }, - } -} - -// GetGeminiModels returns the standard Gemini model definitions -func GetGeminiModels() []*ModelInfo { - return []*ModelInfo{ - { - ID: "gemini-2.5-pro", - Object: "model", - Created: 1750118400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-pro", - Version: "2.5", - DisplayName: "Gemini 2.5 Pro", - Description: "Stable release (June 
17th, 2025) of Gemini 2.5 Pro", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - }, - { - ID: "gemini-2.5-flash", - Object: "model", - Created: 1750118400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-flash", - Version: "001", - DisplayName: "Gemini 2.5 Flash", - Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - }, - { - ID: "gemini-2.5-flash-lite", - Object: "model", - Created: 1753142400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-flash-lite", - Version: "2.5", - DisplayName: "Gemini 2.5 Flash Lite", - Description: "Our smallest and most cost effective model, built for at scale usage.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - }, - { - ID: "gemini-3-pro-preview", - Object: "model", - Created: 1737158400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3-pro-preview", - Version: "3.0", - DisplayName: "Gemini 3 Pro Preview", - Description: "Gemini 3 Pro Preview", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, 
DynamicAllowed: true, Levels: []string{"low", "high"}}, - }, - { - ID: "gemini-3.1-pro-preview", - Object: "model", - Created: 1771459200, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3.1-pro-preview", - Version: "3.1", - DisplayName: "Gemini 3.1 Pro Preview", - Description: "Gemini 3.1 Pro Preview", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, - }, - { - ID: "gemini-3-flash-preview", - Object: "model", - Created: 1765929600, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3-flash-preview", - Version: "3.0", - DisplayName: "Gemini 3 Flash Preview", - Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, - }, - { - ID: "gemini-3.1-flash-lite-preview", - Object: "model", - Created: 1776288000, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3.1-flash-lite-preview", - Version: "3.1", - DisplayName: "Gemini 3.1 Flash Lite Preview", - Description: "Our smallest and most cost effective model, built for at scale usage.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}, - }, - { - ID: "gemini-3-pro-image-preview", - Object: "model", 
- Created: 1737158400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3-pro-image-preview", - Version: "3.0", - DisplayName: "Gemini 3 Pro Image Preview", - Description: "Gemini 3 Pro Image Preview", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, - }, - } -} - -func GetGeminiVertexModels() []*ModelInfo { - return []*ModelInfo{ - { - ID: "gemini-2.5-pro", - Object: "model", - Created: 1750118400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-pro", - Version: "2.5", - DisplayName: "Gemini 2.5 Pro", - Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - }, - { - ID: "gemini-2.5-flash", - Object: "model", - Created: 1750118400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-flash", - Version: "001", - DisplayName: "Gemini 2.5 Flash", - Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - }, - { - ID: "gemini-2.5-flash-lite", - Object: "model", - Created: 1753142400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-flash-lite", - Version: "2.5", - DisplayName: "Gemini 2.5 Flash Lite", - Description: "Our smallest 
and most cost effective model, built for at scale usage.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - }, - { - ID: "gemini-3-pro-preview", - Object: "model", - Created: 1737158400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3-pro-preview", - Version: "3.0", - DisplayName: "Gemini 3 Pro Preview", - Description: "Gemini 3 Pro Preview", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, - }, - { - ID: "gemini-3-flash-preview", - Object: "model", - Created: 1765929600, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3-flash-preview", - Version: "3.0", - DisplayName: "Gemini 3 Flash Preview", - Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, - }, - { - ID: "gemini-3.1-pro-preview", - Object: "model", - Created: 1771459200, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3.1-pro-preview", - Version: "3.1", - DisplayName: "Gemini 3.1 Pro Preview", - Description: "Gemini 3.1 Pro Preview", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, 
- Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, - }, - { - ID: "gemini-3.1-flash-lite-preview", - Object: "model", - Created: 1776288000, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3.1-flash-lite-preview", - Version: "3.1", - DisplayName: "Gemini 3.1 Flash Lite Preview", - Description: "Our smallest and most cost effective model, built for at scale usage.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}, - }, - { - ID: "gemini-3-pro-image-preview", - Object: "model", - Created: 1737158400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3-pro-image-preview", - Version: "3.0", - DisplayName: "Gemini 3 Pro Image Preview", - Description: "Gemini 3 Pro Image Preview", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, - }, - // Imagen image generation models - use :predict action - { - ID: "imagen-4.0-generate-001", - Object: "model", - Created: 1750000000, - OwnedBy: "google", - Type: "gemini", - Name: "models/imagen-4.0-generate-001", - Version: "4.0", - DisplayName: "Imagen 4.0 Generate", - Description: "Imagen 4.0 image generation model", - SupportedGenerationMethods: []string{"predict"}, - }, - { - ID: "imagen-4.0-ultra-generate-001", - Object: "model", - Created: 1750000000, - OwnedBy: "google", - Type: "gemini", - Name: "models/imagen-4.0-ultra-generate-001", - Version: "4.0", - DisplayName: "Imagen 4.0 Ultra Generate", - Description: "Imagen 4.0 
Ultra high-quality image generation model", - SupportedGenerationMethods: []string{"predict"}, - }, - { - ID: "imagen-3.0-generate-002", - Object: "model", - Created: 1740000000, - OwnedBy: "google", - Type: "gemini", - Name: "models/imagen-3.0-generate-002", - Version: "3.0", - DisplayName: "Imagen 3.0 Generate", - Description: "Imagen 3.0 image generation model", - SupportedGenerationMethods: []string{"predict"}, - }, - { - ID: "imagen-3.0-fast-generate-001", - Object: "model", - Created: 1740000000, - OwnedBy: "google", - Type: "gemini", - Name: "models/imagen-3.0-fast-generate-001", - Version: "3.0", - DisplayName: "Imagen 3.0 Fast Generate", - Description: "Imagen 3.0 fast image generation model", - SupportedGenerationMethods: []string{"predict"}, - }, - { - ID: "imagen-4.0-fast-generate-001", - Object: "model", - Created: 1750000000, - OwnedBy: "google", - Type: "gemini", - Name: "models/imagen-4.0-fast-generate-001", - Version: "4.0", - DisplayName: "Imagen 4.0 Fast Generate", - Description: "Imagen 4.0 fast image generation model", - SupportedGenerationMethods: []string{"predict"}, - }, - } -} - -// GetGeminiCLIModels returns the standard Gemini model definitions -func GetGeminiCLIModels() []*ModelInfo { - return []*ModelInfo{ - { - ID: "gemini-2.5-pro", - Object: "model", - Created: 1750118400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-pro", - Version: "2.5", - DisplayName: "Gemini 2.5 Pro", - Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - }, - { - ID: "gemini-2.5-flash", - Object: "model", - Created: 1750118400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-flash", - Version: "001", - DisplayName: "Gemini 2.5 Flash", - 
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - }, - { - ID: "gemini-2.5-flash-lite", - Object: "model", - Created: 1753142400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-flash-lite", - Version: "2.5", - DisplayName: "Gemini 2.5 Flash Lite", - Description: "Our smallest and most cost effective model, built for at scale usage.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - }, - { - ID: "gemini-3-pro-preview", - Object: "model", - Created: 1737158400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3-pro-preview", - Version: "3.0", - DisplayName: "Gemini 3 Pro Preview", - Description: "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, - }, - { - ID: "gemini-3.1-pro-preview", - Object: "model", - Created: 1771459200, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3.1-pro-preview", - Version: "3.1", - DisplayName: "Gemini 3.1 Pro Preview", - Description: "Gemini 3.1 Pro Preview", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: 
[]string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, - }, - { - ID: "gemini-3-flash-preview", - Object: "model", - Created: 1765929600, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3-flash-preview", - Version: "3.0", - DisplayName: "Gemini 3 Flash Preview", - Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, - }, - { - ID: "gemini-3.1-flash-lite-preview", - Object: "model", - Created: 1776288000, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3.1-flash-lite-preview", - Version: "3.1", - DisplayName: "Gemini 3.1 Flash Lite Preview", - Description: "Our smallest and most cost effective model, built for at scale usage.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}, - }, - } -} - -// GetAIStudioModels returns the Gemini model definitions for AI Studio integrations -func GetAIStudioModels() []*ModelInfo { - return []*ModelInfo{ - { - ID: "gemini-2.5-pro", - Object: "model", - Created: 1750118400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-pro", - Version: "2.5", - DisplayName: "Gemini 2.5 Pro", - Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro", - InputTokenLimit: 1048576, - 
OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - }, - { - ID: "gemini-2.5-flash", - Object: "model", - Created: 1750118400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-flash", - Version: "001", - DisplayName: "Gemini 2.5 Flash", - Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - }, - { - ID: "gemini-2.5-flash-lite", - Object: "model", - Created: 1753142400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-flash-lite", - Version: "2.5", - DisplayName: "Gemini 2.5 Flash Lite", - Description: "Our smallest and most cost effective model, built for at scale usage.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - }, - { - ID: "gemini-3-pro-preview", - Object: "model", - Created: 1737158400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3-pro-preview", - Version: "3.0", - DisplayName: "Gemini 3 Pro Preview", - Description: "Gemini 3 Pro Preview", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - }, - { - ID: "gemini-3.1-pro-preview", 
- Object: "model", - Created: 1771459200, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3.1-pro-preview", - Version: "3.1", - DisplayName: "Gemini 3.1 Pro Preview", - Description: "Gemini 3.1 Pro Preview", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - }, - { - ID: "gemini-3-flash-preview", - Object: "model", - Created: 1765929600, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3-flash-preview", - Version: "3.0", - DisplayName: "Gemini 3 Flash Preview", - Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - }, - { - ID: "gemini-3.1-flash-lite-preview", - Object: "model", - Created: 1776288000, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-3.1-flash-lite-preview", - Version: "3.1", - DisplayName: "Gemini 3.1 Flash Lite Preview", - Description: "Our smallest and most cost effective model, built for at scale usage.", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}, - }, - { - ID: "gemini-pro-latest", - Object: "model", - Created: 1750118400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-pro-latest", - Version: "2.5", - DisplayName: "Gemini Pro Latest", - Description: "Latest release of Gemini 
Pro", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, - }, - { - ID: "gemini-flash-latest", - Object: "model", - Created: 1750118400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-flash-latest", - Version: "2.5", - DisplayName: "Gemini Flash Latest", - Description: "Latest release of Gemini Flash", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - }, - { - ID: "gemini-flash-lite-latest", - Object: "model", - Created: 1753142400, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-flash-lite-latest", - Version: "2.5", - DisplayName: "Gemini Flash-Lite Latest", - Description: "Latest release of Gemini Flash-Lite", - InputTokenLimit: 1048576, - OutputTokenLimit: 65536, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - Thinking: &ThinkingSupport{Min: 512, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, - }, - // { - // ID: "gemini-2.5-flash-image-preview", - // Object: "model", - // Created: 1756166400, - // OwnedBy: "google", - // Type: "gemini", - // Name: "models/gemini-2.5-flash-image-preview", - // Version: "2.5", - // DisplayName: "Gemini 2.5 Flash Image Preview", - // Description: "State-of-the-art image generation and editing model.", - // InputTokenLimit: 1048576, - // OutputTokenLimit: 8192, - // SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - // // image models don't support thinkingConfig; leave Thinking nil - // }, - { - ID: 
"gemini-2.5-flash-image", - Object: "model", - Created: 1759363200, - OwnedBy: "google", - Type: "gemini", - Name: "models/gemini-2.5-flash-image", - Version: "2.5", - DisplayName: "Gemini 2.5 Flash Image", - Description: "State-of-the-art image generation and editing model.", - InputTokenLimit: 1048576, - OutputTokenLimit: 8192, - SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, - // image models don't support thinkingConfig; leave Thinking nil - }, - } -} - -// GetOpenAIModels returns the standard OpenAI model definitions -func GetOpenAIModels() []*ModelInfo { - return []*ModelInfo{ - { - ID: "gpt-5", - Object: "model", - Created: 1754524800, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5-2025-08-07", - DisplayName: "GPT 5", - Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - ContextLength: 400000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"minimal", "low", "medium", "high"}}, - }, - { - ID: "gpt-5-codex", - Object: "model", - Created: 1757894400, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5-2025-09-15", - DisplayName: "GPT 5 Codex", - Description: "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", - ContextLength: 400000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, - }, - { - ID: "gpt-5-codex-mini", - Object: "model", - Created: 1762473600, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5-2025-11-07", - DisplayName: "GPT 5 Codex Mini", - Description: "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", - ContextLength: 400000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"low", 
"medium", "high"}}, - }, - { - ID: "gpt-5.1", - Object: "model", - Created: 1762905600, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5", - Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - ContextLength: 400000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high"}}, - }, - { - ID: "gpt-5.1-codex", - Object: "model", - Created: 1762905600, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5.1 Codex", - Description: "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", - ContextLength: 400000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, - }, - { - ID: "gpt-5.1-codex-mini", - Object: "model", - Created: 1762905600, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5.1-2025-11-12", - DisplayName: "GPT 5.1 Codex Mini", - Description: "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", - ContextLength: 400000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}, - }, - { - ID: "gpt-5.1-codex-max", - Object: "model", - Created: 1763424000, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5.1-max", - DisplayName: "GPT 5.1 Codex Max", - Description: "Stable version of GPT 5.1 Codex Max", - ContextLength: 400000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}}, - }, - { - ID: "gpt-5.2", - Object: "model", - Created: 1765440000, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5.2", - DisplayName: "GPT 5.2", - Description: "Stable version 
of GPT 5.2", - ContextLength: 400000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}}, - }, - { - ID: "gpt-5.2-codex", - Object: "model", - Created: 1765440000, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5.2", - DisplayName: "GPT 5.2 Codex", - Description: "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.", - ContextLength: 400000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}}, - }, - { - ID: "gpt-5.3-codex", - Object: "model", - Created: 1770307200, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5.3", - DisplayName: "GPT 5.3 Codex", - Description: "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.", - ContextLength: 400000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}}, - }, - { - ID: "gpt-5.3-codex-spark", - Object: "model", - Created: 1770912000, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5.3", - DisplayName: "GPT 5.3 Codex Spark", - Description: "Ultra-fast coding model.", - ContextLength: 128000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}}, - }, - { - ID: "gpt-5.4", - Object: "model", - Created: 1772668800, - OwnedBy: "openai", - Type: "openai", - Version: "gpt-5.4", - DisplayName: "GPT 5.4", - Description: "Stable version of GPT 5.4", - ContextLength: 1_050_000, - MaxCompletionTokens: 128000, - SupportedParameters: []string{"tools"}, - Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}}, - }, - } -} - -// GetQwenModels returns the standard Qwen model definitions -func GetQwenModels() 
[]*ModelInfo { - return []*ModelInfo{ - { - ID: "qwen3-coder-plus", - Object: "model", - Created: 1753228800, - OwnedBy: "qwen", - Type: "qwen", - Version: "3.0", - DisplayName: "Qwen3 Coder Plus", - Description: "Advanced code generation and understanding model", - ContextLength: 32768, - MaxCompletionTokens: 8192, - SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"}, - }, - { - ID: "qwen3-coder-flash", - Object: "model", - Created: 1753228800, - OwnedBy: "qwen", - Type: "qwen", - Version: "3.0", - DisplayName: "Qwen3 Coder Flash", - Description: "Fast code generation model", - ContextLength: 8192, - MaxCompletionTokens: 2048, - SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"}, - }, - { - ID: "coder-model", - Object: "model", - Created: 1771171200, - OwnedBy: "qwen", - Type: "qwen", - Version: "3.5", - DisplayName: "Qwen 3.5 Plus", - Description: "efficient hybrid model with leading coding performance", - ContextLength: 1048576, - MaxCompletionTokens: 65536, - SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"}, - }, - { - ID: "vision-model", - Object: "model", - Created: 1758672000, - OwnedBy: "qwen", - Type: "qwen", - Version: "3.0", - DisplayName: "Qwen3 Vision Model", - Description: "Vision model model", - ContextLength: 32768, - MaxCompletionTokens: 2048, - SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"}, - }, - } -} - -// iFlowThinkingSupport is a shared ThinkingSupport configuration for iFlow models -// that support thinking mode via chat_template_kwargs.enable_thinking (boolean toggle). -// Uses level-based configuration so standard normalization flows apply before conversion. -var iFlowThinkingSupport = &ThinkingSupport{ - Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"}, -} - -// GetIFlowModels returns supported models for iFlow OAuth accounts. 
-func GetIFlowModels() []*ModelInfo { - entries := []struct { - ID string - DisplayName string - Description string - Created int64 - Thinking *ThinkingSupport - }{ - {ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800}, - {ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000}, - {ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000}, - {ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400, Thinking: iFlowThinkingSupport}, - {ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport}, - {ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000}, - {ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000, Thinking: iFlowThinkingSupport}, - {ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200, Thinking: iFlowThinkingSupport}, - {ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200}, - {ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200}, - {ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400}, - {ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600}, - {ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600}, - {ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600}, - {ID: "iflow-rome-30ba3b", 
DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200}, - } - models := make([]*ModelInfo, 0, len(entries)) - for _, entry := range entries { - models = append(models, &ModelInfo{ - ID: entry.ID, - Object: "model", - Created: entry.Created, - OwnedBy: "iflow", - Type: "iflow", - DisplayName: entry.DisplayName, - Description: entry.Description, - Thinking: entry.Thinking, - }) - } - return models -} - -// AntigravityModelConfig captures static antigravity model overrides, including -// Thinking budget limits and provider max completion tokens. -type AntigravityModelConfig struct { - Thinking *ThinkingSupport - MaxCompletionTokens int -} - -// GetAntigravityModelConfig returns static configuration for antigravity models. -// Keys use upstream model names returned by the Antigravity models endpoint. -func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { - return map[string]*AntigravityModelConfig{ - "gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, - "gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}}, - "gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, - "gemini-3-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, - "gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, - "gemini-3.1-pro-low": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}}, - "gemini-3.1-flash-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}}, - "gemini-3.1-flash-lite-preview": {Thinking: 
&ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "high"}}}, - "gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}}, - "claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000}, - "claude-sonnet-4-6": {Thinking: &ThinkingSupport{Min: 1024, Max: 64000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000}, - "gpt-oss-120b-medium": {}, - } -} - -// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions -func GetKimiModels() []*ModelInfo { - return []*ModelInfo{ - { - ID: "kimi-k2", - Object: "model", - Created: 1752192000, // 2025-07-11 - OwnedBy: "moonshot", - Type: "kimi", - DisplayName: "Kimi K2", - Description: "Kimi K2 - Moonshot AI's flagship coding model", - ContextLength: 131072, - MaxCompletionTokens: 32768, - }, - { - ID: "kimi-k2-thinking", - Object: "model", - Created: 1762387200, // 2025-11-06 - OwnedBy: "moonshot", - Type: "kimi", - DisplayName: "Kimi K2 Thinking", - Description: "Kimi K2 Thinking - Extended reasoning model", - ContextLength: 131072, - MaxCompletionTokens: 32768, - Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, - }, - { - ID: "kimi-k2.5", - Object: "model", - Created: 1769472000, // 2026-01-26 - OwnedBy: "moonshot", - Type: "kimi", - DisplayName: "Kimi K2.5", - Description: "Kimi K2.5 - Latest Moonshot AI coding model with improved capabilities", - ContextLength: 131072, - MaxCompletionTokens: 32768, - Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true}, - }, - } -} diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index e036a04f..8f56c43d 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go 
@@ -62,6 +62,11 @@ type ModelInfo struct { UserDefined bool `json:"-"` } +type availableModelsCacheEntry struct { + models []map[string]any + expiresAt time.Time +} + // ThinkingSupport describes a model family's supported internal reasoning budget range. // Values are interpreted in provider-native token units. type ThinkingSupport struct { @@ -116,6 +121,8 @@ type ModelRegistry struct { clientProviders map[string]string // mutex ensures thread-safe access to the registry mutex *sync.RWMutex + // availableModelsCache stores per-handler snapshots for GetAvailableModels. + availableModelsCache map[string]availableModelsCacheEntry // hook is an optional callback sink for model registration changes hook ModelRegistryHook } @@ -128,15 +135,28 @@ var registryOnce sync.Once func GetGlobalRegistry() *ModelRegistry { registryOnce.Do(func() { globalRegistry = &ModelRegistry{ - models: make(map[string]*ModelRegistration), - clientModels: make(map[string][]string), - clientModelInfos: make(map[string]map[string]*ModelInfo), - clientProviders: make(map[string]string), - mutex: &sync.RWMutex{}, + models: make(map[string]*ModelRegistration), + clientModels: make(map[string][]string), + clientModelInfos: make(map[string]map[string]*ModelInfo), + clientProviders: make(map[string]string), + availableModelsCache: make(map[string]availableModelsCacheEntry), + mutex: &sync.RWMutex{}, } }) return globalRegistry } +func (r *ModelRegistry) ensureAvailableModelsCacheLocked() { + if r.availableModelsCache == nil { + r.availableModelsCache = make(map[string]availableModelsCacheEntry) + } +} + +func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() { + if len(r.availableModelsCache) == 0 { + return + } + clear(r.availableModelsCache) +} // LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions. 
func LookupModelInfo(modelID string, provider ...string) *ModelInfo { @@ -151,9 +171,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo { } if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil { - return info + return cloneModelInfo(info) } - return LookupStaticModelInfo(modelID) + return cloneModelInfo(LookupStaticModelInfo(modelID)) } // SetHook sets an optional hook for observing model registration changes. @@ -211,6 +231,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) { func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) { r.mutex.Lock() defer r.mutex.Unlock() + r.ensureAvailableModelsCacheLocked() provider := strings.ToLower(clientProvider) uniqueModelIDs := make([]string, 0, len(models)) @@ -236,6 +257,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ delete(r.clientModels, clientID) delete(r.clientModelInfos, clientID) delete(r.clientProviders, clientID) + r.invalidateAvailableModelsCacheLocked() misc.LogCredentialSeparator() return } @@ -263,6 +285,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ } else { delete(r.clientProviders, clientID) } + r.invalidateAvailableModelsCacheLocked() r.triggerModelsRegistered(provider, clientID, models) log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs)) misc.LogCredentialSeparator() @@ -406,6 +429,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ delete(r.clientProviders, clientID) } + r.invalidateAvailableModelsCacheLocked() r.triggerModelsRegistered(provider, clientID, models) if len(added) == 0 && len(removed) == 0 && !providerChanged { // Only metadata (e.g., display name) changed; skip separator when no log output. 
@@ -509,6 +533,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo { if len(model.SupportedOutputModalities) > 0 { copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...) } + if model.Thinking != nil { + copyThinking := *model.Thinking + if len(model.Thinking.Levels) > 0 { + copyThinking.Levels = append([]string(nil), model.Thinking.Levels...) + } + copyModel.Thinking = ©Thinking + } return ©Model } @@ -538,6 +569,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) { r.mutex.Lock() defer r.mutex.Unlock() r.unregisterClientInternal(clientID) + r.invalidateAvailableModelsCacheLocked() } // unregisterClientInternal performs the actual client unregistration (internal, no locking) @@ -604,9 +636,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) { func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) { r.mutex.Lock() defer r.mutex.Unlock() + r.ensureAvailableModelsCacheLocked() if registration, exists := r.models[modelID]; exists { - registration.QuotaExceededClients[clientID] = new(time.Now()) + now := time.Now() + registration.QuotaExceededClients[clientID] = &now + r.invalidateAvailableModelsCacheLocked() log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID) } } @@ -618,9 +653,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) { func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) { r.mutex.Lock() defer r.mutex.Unlock() + r.ensureAvailableModelsCacheLocked() if registration, exists := r.models[modelID]; exists { delete(registration.QuotaExceededClients, clientID) + r.invalidateAvailableModelsCacheLocked() // log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID) } } @@ -636,6 +673,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) { } r.mutex.Lock() defer r.mutex.Unlock() + r.ensureAvailableModelsCacheLocked() registration, exists := 
r.models[modelID] if !exists || registration == nil { @@ -649,6 +687,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) { } registration.SuspendedClients[clientID] = reason registration.LastUpdated = time.Now() + r.invalidateAvailableModelsCacheLocked() if reason != "" { log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason) } else { @@ -666,6 +705,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) { } r.mutex.Lock() defer r.mutex.Unlock() + r.ensureAvailableModelsCacheLocked() registration, exists := r.models[modelID] if !exists || registration == nil || registration.SuspendedClients == nil { @@ -676,6 +716,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) { } delete(registration.SuspendedClients, clientID) registration.LastUpdated = time.Now() + r.invalidateAvailableModelsCacheLocked() log.Debugf("Resumed client %s for model %s", clientID, modelID) } @@ -711,22 +752,52 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool { // Returns: // - []map[string]any: List of available models in the requested format func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any { - r.mutex.RLock() - defer r.mutex.RUnlock() + now := time.Now() - models := make([]map[string]any, 0) + r.mutex.RLock() + if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) { + models := cloneModelMaps(cache.models) + r.mutex.RUnlock() + return models + } + r.mutex.RUnlock() + + r.mutex.Lock() + defer r.mutex.Unlock() + r.ensureAvailableModelsCacheLocked() + + if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) { + return cloneModelMaps(cache.models) + } + + models, expiresAt := r.buildAvailableModelsLocked(handlerType, now) + r.availableModelsCache[handlerType] = availableModelsCacheEntry{ + models: cloneModelMaps(models), + 
expiresAt: expiresAt, + } + + return models +} + +func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) { + models := make([]map[string]any, 0, len(r.models)) quotaExpiredDuration := 5 * time.Minute + var expiresAt time.Time for _, registration := range r.models { - // Check if model has any non-quota-exceeded clients availableClients := registration.Count - now := time.Now() - // Count clients that have exceeded quota but haven't recovered yet expiredClients := 0 for _, quotaTime := range registration.QuotaExceededClients { - if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { + if quotaTime == nil { + continue + } + recoveryAt := quotaTime.Add(quotaExpiredDuration) + if now.Before(recoveryAt) { expiredClients++ + if expiresAt.IsZero() || recoveryAt.Before(expiresAt) { + expiresAt = recoveryAt + } } } @@ -747,7 +818,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any effectiveClients = 0 } - // Include models that have available clients, or those solely cooling down. 
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) { model := r.convertModelToMap(registration.Info, handlerType) if model != nil { @@ -756,7 +826,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any } } - return models + return models, expiresAt +} + +func cloneModelMaps(models []map[string]any) []map[string]any { + cloned := make([]map[string]any, 0, len(models)) + for _, model := range models { + if model == nil { + cloned = append(cloned, nil) + continue + } + copyModel := make(map[string]any, len(model)) + for key, value := range model { + copyModel[key] = cloneModelMapValue(value) + } + cloned = append(cloned, copyModel) + } + return cloned +} + +func cloneModelMapValue(value any) any { + switch typed := value.(type) { + case map[string]any: + copyMap := make(map[string]any, len(typed)) + for key, entry := range typed { + copyMap[key] = cloneModelMapValue(entry) + } + return copyMap + case []any: + copySlice := make([]any, len(typed)) + for i, entry := range typed { + copySlice[i] = cloneModelMapValue(entry) + } + return copySlice + case []string: + return append([]string(nil), typed...) + default: + return value + } } // GetAvailableModelsByProvider returns models available for the given provider identifier. 
@@ -872,11 +979,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) { if entry.info != nil { - result = append(result, entry.info) + result = append(result, cloneModelInfo(entry.info)) continue } if ok && registration != nil && registration.Info != nil { - result = append(result, registration.Info) + result = append(result, cloneModelInfo(registration.Info)) } } } @@ -985,13 +1092,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo { if reg.Providers != nil { if count, ok := reg.Providers[provider]; ok && count > 0 { if info, ok := reg.InfoByProvider[provider]; ok && info != nil { - return info + return cloneModelInfo(info) } } } } // Fallback to global info (last registered) - return reg.Info + return cloneModelInfo(reg.Info) } return nil } @@ -1031,7 +1138,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) result["max_completion_tokens"] = model.MaxCompletionTokens } if len(model.SupportedParameters) > 0 { - result["supported_parameters"] = model.SupportedParameters + result["supported_parameters"] = append([]string(nil), model.SupportedParameters...) } return result @@ -1075,13 +1182,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) result["outputTokenLimit"] = model.OutputTokenLimit } if len(model.SupportedGenerationMethods) > 0 { - result["supportedGenerationMethods"] = model.SupportedGenerationMethods + result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...) } if len(model.SupportedInputModalities) > 0 { - result["supportedInputModalities"] = model.SupportedInputModalities + result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...) 
} if len(model.SupportedOutputModalities) > 0 { - result["supportedOutputModalities"] = model.SupportedOutputModalities + result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...) } return result @@ -1111,15 +1218,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() { now := time.Now() quotaExpiredDuration := 5 * time.Minute + invalidated := false for modelID, registration := range r.models { for clientID, quotaTime := range registration.QuotaExceededClients { if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration { delete(registration.QuotaExceededClients, clientID) + invalidated = true log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID) } } } + if invalidated { + r.invalidateAvailableModelsCacheLocked() + } } // GetFirstAvailableModel returns the first available model for the given handler type. @@ -1133,8 +1245,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() { // - string: The model ID of the first available model, or empty string if none available // - error: An error if no models are available func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) { - r.mutex.RLock() - defer r.mutex.RUnlock() // Get all available models for this handler type models := r.GetAvailableModels(handlerType) @@ -1194,13 +1304,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo { // Prefer client's own model info to preserve original type/owned_by if clientInfos != nil { if info, ok := clientInfos[modelID]; ok && info != nil { - result = append(result, info) + result = append(result, cloneModelInfo(info)) continue } } // Fallback to global registry (for backwards compatibility) if reg, ok := r.models[modelID]; ok && reg.Info != nil { - result = append(result, reg.Info) + result = append(result, cloneModelInfo(reg.Info)) } } return result diff --git a/internal/registry/model_registry_cache_test.go 
b/internal/registry/model_registry_cache_test.go new file mode 100644 index 00000000..4653167b --- /dev/null +++ b/internal/registry/model_registry_cache_test.go @@ -0,0 +1,54 @@ +package registry + +import "testing" + +func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) { + r := newTestModelRegistry() + r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}}) + + first := r.GetAvailableModels("openai") + if len(first) != 1 { + t.Fatalf("expected 1 model, got %d", len(first)) + } + first[0]["id"] = "mutated" + first[0]["display_name"] = "Mutated" + + second := r.GetAvailableModels("openai") + if got := second[0]["id"]; got != "m1" { + t.Fatalf("expected cached snapshot to stay isolated, got id %v", got) + } + if got := second[0]["display_name"]; got != "Model One" { + t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got) + } +} + +func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) { + r := newTestModelRegistry() + r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}}) + + models := r.GetAvailableModels("openai") + if len(models) != 1 { + t.Fatalf("expected 1 model, got %d", len(models)) + } + if got := models[0]["display_name"]; got != "Model One" { + t.Fatalf("expected initial display_name Model One, got %v", got) + } + + r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}}) + models = r.GetAvailableModels("openai") + if got := models[0]["display_name"]; got != "Model One Updated" { + t.Fatalf("expected updated display_name after cache invalidation, got %v", got) + } + + r.SuspendClientModel("client-1", "m1", "manual") + models = r.GetAvailableModels("openai") + if len(models) != 0 { + t.Fatalf("expected no available models after suspension, got %d", len(models)) + } + + r.ResumeClientModel("client-1", "m1") + models = 
r.GetAvailableModels("openai") + if len(models) != 1 { + t.Fatalf("expected model to reappear after resume, got %d", len(models)) + } +} diff --git a/internal/registry/model_registry_safety_test.go b/internal/registry/model_registry_safety_test.go new file mode 100644 index 00000000..5f4f65d2 --- /dev/null +++ b/internal/registry/model_registry_safety_test.go @@ -0,0 +1,149 @@ +package registry + +import ( + "testing" + "time" +) + +func TestGetModelInfoReturnsClone(t *testing.T) { + r := newTestModelRegistry() + r.RegisterClient("client-1", "gemini", []*ModelInfo{{ + ID: "m1", + DisplayName: "Model One", + Thinking: &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}}, + }}) + + first := r.GetModelInfo("m1", "gemini") + if first == nil { + t.Fatal("expected model info") + } + first.DisplayName = "mutated" + first.Thinking.Levels[0] = "mutated" + + second := r.GetModelInfo("m1", "gemini") + if second.DisplayName != "Model One" { + t.Fatalf("expected cloned display name, got %q", second.DisplayName) + } + if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" { + t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking) + } +} + +func TestGetModelsForClientReturnsClones(t *testing.T) { + r := newTestModelRegistry() + r.RegisterClient("client-1", "gemini", []*ModelInfo{{ + ID: "m1", + DisplayName: "Model One", + Thinking: &ThinkingSupport{Levels: []string{"low", "high"}}, + }}) + + first := r.GetModelsForClient("client-1") + if len(first) != 1 || first[0] == nil { + t.Fatalf("expected one model, got %+v", first) + } + first[0].DisplayName = "mutated" + first[0].Thinking.Levels[0] = "mutated" + + second := r.GetModelsForClient("client-1") + if len(second) != 1 || second[0] == nil { + t.Fatalf("expected one model on second fetch, got %+v", second) + } + if second[0].DisplayName != "Model One" { + t.Fatalf("expected cloned display name, got %q", second[0].DisplayName) + } + if second[0].Thinking == 
nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" { + t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking) + } +} + +func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) { + r := newTestModelRegistry() + r.RegisterClient("client-1", "gemini", []*ModelInfo{{ + ID: "m1", + DisplayName: "Model One", + Thinking: &ThinkingSupport{Levels: []string{"low", "high"}}, + }}) + + first := r.GetAvailableModelsByProvider("gemini") + if len(first) != 1 || first[0] == nil { + t.Fatalf("expected one model, got %+v", first) + } + first[0].DisplayName = "mutated" + first[0].Thinking.Levels[0] = "mutated" + + second := r.GetAvailableModelsByProvider("gemini") + if len(second) != 1 || second[0] == nil { + t.Fatalf("expected one model on second fetch, got %+v", second) + } + if second[0].DisplayName != "Model One" { + t.Fatalf("expected cloned display name, got %q", second[0].DisplayName) + } + if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" { + t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking) + } +} + +func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) { + r := newTestModelRegistry() + r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}}) + r.SetModelQuotaExceeded("client-1", "m1") + if models := r.GetAvailableModels("openai"); len(models) != 1 { + t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models)) + } + + r.mutex.Lock() + quotaTime := time.Now().Add(-6 * time.Minute) + r.models["m1"].QuotaExceededClients["client-1"] = "aTime + r.mutex.Unlock() + + r.CleanupExpiredQuotas() + + if count := r.GetModelCount("m1"); count != 1 { + t.Fatalf("expected model count 1 after cleanup, got %d", count) + } + models := r.GetAvailableModels("openai") + if len(models) != 1 { + t.Fatalf("expected model to stay available after cleanup, got %d", len(models)) + } + if got := 
models[0]["id"]; got != "m1" { + t.Fatalf("expected model id m1, got %v", got) + } +} + +func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) { + r := newTestModelRegistry() + r.RegisterClient("client-1", "openai", []*ModelInfo{{ + ID: "m1", + DisplayName: "Model One", + SupportedParameters: []string{"temperature", "top_p"}, + }}) + + first := r.GetAvailableModels("openai") + if len(first) != 1 { + t.Fatalf("expected one model, got %d", len(first)) + } + params, ok := first[0]["supported_parameters"].([]string) + if !ok || len(params) != 2 { + t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"]) + } + params[0] = "mutated" + + second := r.GetAvailableModels("openai") + params, ok = second[0]["supported_parameters"].([]string) + if !ok || len(params) != 2 || params[0] != "temperature" { + t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"]) + } +} + +func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) { + first := LookupModelInfo("glm-4.6") + if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 { + t.Fatalf("expected static model with thinking levels, got %+v", first) + } + first.Thinking.Levels[0] = "mutated" + + second := LookupModelInfo("glm-4.6") + if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" { + t.Fatalf("expected static lookup clone, got %+v", second) + } +} diff --git a/internal/registry/model_updater.go b/internal/registry/model_updater.go new file mode 100644 index 00000000..84c9d6aa --- /dev/null +++ b/internal/registry/model_updater.go @@ -0,0 +1,198 @@ +package registry + +import ( + "context" + _ "embed" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + modelsFetchTimeout = 30 * time.Second +) + +var modelsURLs = []string{ + 
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json", + "https://models.router-for.me/models.json", +} + +//go:embed models/models.json +var embeddedModelsJSON []byte + +type modelStore struct { + mu sync.RWMutex + data *staticModelsJSON +} + +var modelsCatalogStore = &modelStore{} + +var updaterOnce sync.Once + +func init() { + // Load embedded data as fallback on startup. + if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil { + panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err)) + } +} + +// StartModelsUpdater runs a one-time models refresh on startup. +// It blocks until the startup fetch attempt finishes so service initialization +// can wait for the refreshed catalog before registering auth-backed models. +// Safe to call multiple times; only one refresh will run. +func StartModelsUpdater(ctx context.Context) { + updaterOnce.Do(func() { + runModelsUpdater(ctx) + }) +} + +func runModelsUpdater(ctx context.Context) { + // Try network fetch once on startup, then stop. + // Periodic refresh is disabled - models are only refreshed at startup. 
+ tryRefreshModels(ctx) +} + +func tryRefreshModels(ctx context.Context) { + client := &http.Client{Timeout: modelsFetchTimeout} + for _, url := range modelsURLs { + reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout) + req, err := http.NewRequestWithContext(reqCtx, "GET", url, nil) + if err != nil { + cancel() + log.Debugf("models fetch request creation failed for %s: %v", url, err) + continue + } + + resp, err := client.Do(req) + if err != nil { + cancel() + log.Debugf("models fetch failed from %s: %v", url, err) + continue + } + + if resp.StatusCode != 200 { + resp.Body.Close() + cancel() + log.Debugf("models fetch returned %d from %s", resp.StatusCode, url) + continue + } + + data, err := io.ReadAll(resp.Body) + resp.Body.Close() + cancel() + + if err != nil { + log.Debugf("models fetch read error from %s: %v", url, err) + continue + } + + if err := loadModelsFromBytes(data, url); err != nil { + log.Warnf("models parse failed from %s: %v", url, err) + continue + } + + log.Infof("models updated from %s", url) + return + } + log.Warn("models refresh failed from all URLs, using current data") +} + +func loadModelsFromBytes(data []byte, source string) error { + var parsed staticModelsJSON + if err := json.Unmarshal(data, &parsed); err != nil { + return fmt.Errorf("%s: decode models catalog: %w", source, err) + } + if err := validateModelsCatalog(&parsed); err != nil { + return fmt.Errorf("%s: validate models catalog: %w", source, err) + } + + modelsCatalogStore.mu.Lock() + modelsCatalogStore.data = &parsed + modelsCatalogStore.mu.Unlock() + return nil +} + +func getModels() *staticModelsJSON { + modelsCatalogStore.mu.RLock() + defer modelsCatalogStore.mu.RUnlock() + return modelsCatalogStore.data +} + +func validateModelsCatalog(data *staticModelsJSON) error { + if data == nil { + return fmt.Errorf("catalog is nil") + } + + requiredSections := []struct { + name string + models []*ModelInfo + }{ + {name: "claude", models: data.Claude}, + {name: "gemini", 
models: data.Gemini}, + {name: "vertex", models: data.Vertex}, + {name: "gemini-cli", models: data.GeminiCLI}, + {name: "aistudio", models: data.AIStudio}, + {name: "codex-free", models: data.CodexFree}, + {name: "codex-team", models: data.CodexTeam}, + {name: "codex-plus", models: data.CodexPlus}, + {name: "codex-pro", models: data.CodexPro}, + {name: "qwen", models: data.Qwen}, + {name: "iflow", models: data.IFlow}, + {name: "kimi", models: data.Kimi}, + } + + for _, section := range requiredSections { + if err := validateModelSection(section.name, section.models); err != nil { + return err + } + } + if err := validateAntigravitySection(data.Antigravity); err != nil { + return err + } + return nil +} + +func validateModelSection(section string, models []*ModelInfo) error { + if len(models) == 0 { + return fmt.Errorf("%s section is empty", section) + } + + seen := make(map[string]struct{}, len(models)) + for i, model := range models { + if model == nil { + return fmt.Errorf("%s[%d] is null", section, i) + } + modelID := strings.TrimSpace(model.ID) + if modelID == "" { + return fmt.Errorf("%s[%d] has empty id", section, i) + } + if _, exists := seen[modelID]; exists { + return fmt.Errorf("%s contains duplicate model id %q", section, modelID) + } + seen[modelID] = struct{}{} + } + return nil +} + +func validateAntigravitySection(configs map[string]*AntigravityModelConfig) error { + if len(configs) == 0 { + return fmt.Errorf("antigravity section is empty") + } + + for modelID, cfg := range configs { + trimmedID := strings.TrimSpace(modelID) + if trimmedID == "" { + return fmt.Errorf("antigravity contains empty model id") + } + if cfg == nil { + return fmt.Errorf("antigravity[%q] is null", trimmedID) + } + } + return nil +} diff --git a/internal/registry/models/models.json b/internal/registry/models/models.json new file mode 100644 index 00000000..5f919f9f --- /dev/null +++ b/internal/registry/models/models.json @@ -0,0 +1,2598 @@ +{ + "claude": [ + { + "id": 
"claude-haiku-4-5-20251001", + "object": "model", + "created": 1759276800, + "owned_by": "anthropic", + "type": "claude", + "display_name": "Claude 4.5 Haiku", + "context_length": 200000, + "max_completion_tokens": 64000, + "thinking": { + "min": 1024, + "max": 128000, + "zero_allowed": true + } + }, + { + "id": "claude-sonnet-4-5-20250929", + "object": "model", + "created": 1759104000, + "owned_by": "anthropic", + "type": "claude", + "display_name": "Claude 4.5 Sonnet", + "context_length": 200000, + "max_completion_tokens": 64000, + "thinking": { + "min": 1024, + "max": 128000, + "zero_allowed": true + } + }, + { + "id": "claude-sonnet-4-6", + "object": "model", + "created": 1771372800, + "owned_by": "anthropic", + "type": "claude", + "display_name": "Claude 4.6 Sonnet", + "context_length": 200000, + "max_completion_tokens": 64000, + "thinking": { + "min": 1024, + "max": 128000, + "zero_allowed": true, + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "claude-opus-4-6", + "object": "model", + "created": 1770318000, + "owned_by": "anthropic", + "type": "claude", + "display_name": "Claude 4.6 Opus", + "description": "Premium model combining maximum intelligence with practical performance", + "context_length": 1000000, + "max_completion_tokens": 128000, + "thinking": { + "min": 1024, + "max": 128000, + "zero_allowed": true, + "levels": [ + "low", + "medium", + "high", + "max" + ] + } + }, + { + "id": "claude-opus-4-5-20251101", + "object": "model", + "created": 1761955200, + "owned_by": "anthropic", + "type": "claude", + "display_name": "Claude 4.5 Opus", + "description": "Premium model combining maximum intelligence with practical performance", + "context_length": 200000, + "max_completion_tokens": 64000, + "thinking": { + "min": 1024, + "max": 128000, + "zero_allowed": true + } + }, + { + "id": "claude-opus-4-1-20250805", + "object": "model", + "created": 1722945600, + "owned_by": "anthropic", + "type": "claude", + "display_name": "Claude 4.1 
Opus", + "context_length": 200000, + "max_completion_tokens": 32000, + "thinking": { + "min": 1024, + "max": 128000 + } + }, + { + "id": "claude-opus-4-20250514", + "object": "model", + "created": 1715644800, + "owned_by": "anthropic", + "type": "claude", + "display_name": "Claude 4 Opus", + "context_length": 200000, + "max_completion_tokens": 32000, + "thinking": { + "min": 1024, + "max": 128000 + } + }, + { + "id": "claude-sonnet-4-20250514", + "object": "model", + "created": 1715644800, + "owned_by": "anthropic", + "type": "claude", + "display_name": "Claude 4 Sonnet", + "context_length": 200000, + "max_completion_tokens": 64000, + "thinking": { + "min": 1024, + "max": 128000 + } + }, + { + "id": "claude-3-7-sonnet-20250219", + "object": "model", + "created": 1708300800, + "owned_by": "anthropic", + "type": "claude", + "display_name": "Claude 3.7 Sonnet", + "context_length": 128000, + "max_completion_tokens": 8192, + "thinking": { + "min": 1024, + "max": 128000 + } + }, + { + "id": "claude-3-5-haiku-20241022", + "object": "model", + "created": 1729555200, + "owned_by": "anthropic", + "type": "claude", + "display_name": "Claude 3.5 Haiku", + "context_length": 128000, + "max_completion_tokens": 8192 + } + ], + "gemini": [ + { + "id": "gemini-2.5-pro", + "object": "model", + "created": 1750118400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Pro", + "name": "models/gemini-2.5-pro", + "version": "2.5", + "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true + } + }, + { + "id": "gemini-2.5-flash", + "object": "model", + "created": 1750118400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Flash", + "name": "models/gemini-2.5-flash", + 
"version": "001", + "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": "gemini-2.5-flash-lite", + "object": "model", + "created": 1753142400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Flash Lite", + "name": "models/gemini-2.5-flash-lite", + "version": "2.5", + "description": "Our smallest and most cost effective model, built for at scale usage.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": "gemini-3-pro-preview", + "object": "model", + "created": 1737158400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3 Pro Preview", + "name": "models/gemini-3-pro-preview", + "version": "3.0", + "description": "Gemini 3 Pro Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + { + "id": "gemini-3.1-pro-preview", + "object": "model", + "created": 1771459200, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3.1 Pro Preview", + "name": "models/gemini-3.1-pro-preview", + "version": "3.1", + "description": "Gemini 3.1 Pro Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + 
"generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + { + "id": "gemini-3.1-flash-image-preview", + "object": "model", + "created": 1771459200, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3.1 Flash Image Preview", + "name": "models/gemini-3.1-flash-image-preview", + "version": "3.1", + "description": "Gemini 3.1 Flash Image Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "high" + ] + } + }, + { + "id": "gemini-3-flash-preview", + "object": "model", + "created": 1765929600, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3 Flash Preview", + "name": "models/gemini-3-flash-preview", + "version": "3.0", + "description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gemini-3.1-flash-lite-preview", + "object": "model", + "created": 1776288000, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3.1 Flash Lite Preview", + "name": "models/gemini-3.1-flash-lite-preview", + "version": "3.1", + "description": "Our smallest and most cost effective model, built for at scale usage.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", 
+ "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "high" + ] + } + }, + { + "id": "gemini-3-pro-image-preview", + "object": "model", + "created": 1737158400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3 Pro Image Preview", + "name": "models/gemini-3-pro-image-preview", + "version": "3.0", + "description": "Gemini 3 Pro Image Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + } + ], + "vertex": [ + { + "id": "gemini-2.5-pro", + "object": "model", + "created": 1750118400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Pro", + "name": "models/gemini-2.5-pro", + "version": "2.5", + "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true + } + }, + { + "id": "gemini-2.5-flash", + "object": "model", + "created": 1750118400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Flash", + "name": "models/gemini-2.5-flash", + "version": "001", + "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": 
"gemini-2.5-flash-lite", + "object": "model", + "created": 1753142400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Flash Lite", + "name": "models/gemini-2.5-flash-lite", + "version": "2.5", + "description": "Our smallest and most cost effective model, built for at scale usage.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": "gemini-3-pro-preview", + "object": "model", + "created": 1737158400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3 Pro Preview", + "name": "models/gemini-3-pro-preview", + "version": "3.0", + "description": "Gemini 3 Pro Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + { + "id": "gemini-3-flash-preview", + "object": "model", + "created": 1765929600, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3 Flash Preview", + "name": "models/gemini-3-flash-preview", + "version": "3.0", + "description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gemini-3.1-pro-preview", + "object": "model", + "created": 1771459200, + "owned_by": "google", + "type": "gemini", + 
"display_name": "Gemini 3.1 Pro Preview", + "name": "models/gemini-3.1-pro-preview", + "version": "3.1", + "description": "Gemini 3.1 Pro Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + { + "id": "gemini-3.1-flash-image-preview", + "object": "model", + "created": 1771459200, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3.1 Flash Image Preview", + "name": "models/gemini-3.1-flash-image-preview", + "version": "3.1", + "description": "Gemini 3.1 Flash Image Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "high" + ] + } + }, + { + "id": "gemini-3.1-flash-lite-preview", + "object": "model", + "created": 1776288000, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3.1 Flash Lite Preview", + "name": "models/gemini-3.1-flash-lite-preview", + "version": "3.1", + "description": "Our smallest and most cost effective model, built for at scale usage.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "high" + ] + } + }, + { + "id": "gemini-3-pro-image-preview", + "object": "model", + "created": 1737158400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3 Pro Image Preview", + "name": "models/gemini-3-pro-image-preview", + "version": "3.0", + "description": "Gemini 3 
Pro Image Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + { + "id": "imagen-4.0-generate-001", + "object": "model", + "created": 1750000000, + "owned_by": "google", + "type": "gemini", + "display_name": "Imagen 4.0 Generate", + "name": "models/imagen-4.0-generate-001", + "version": "4.0", + "description": "Imagen 4.0 image generation model", + "supportedGenerationMethods": [ + "predict" + ] + }, + { + "id": "imagen-4.0-ultra-generate-001", + "object": "model", + "created": 1750000000, + "owned_by": "google", + "type": "gemini", + "display_name": "Imagen 4.0 Ultra Generate", + "name": "models/imagen-4.0-ultra-generate-001", + "version": "4.0", + "description": "Imagen 4.0 Ultra high-quality image generation model", + "supportedGenerationMethods": [ + "predict" + ] + }, + { + "id": "imagen-3.0-generate-002", + "object": "model", + "created": 1740000000, + "owned_by": "google", + "type": "gemini", + "display_name": "Imagen 3.0 Generate", + "name": "models/imagen-3.0-generate-002", + "version": "3.0", + "description": "Imagen 3.0 image generation model", + "supportedGenerationMethods": [ + "predict" + ] + }, + { + "id": "imagen-3.0-fast-generate-001", + "object": "model", + "created": 1740000000, + "owned_by": "google", + "type": "gemini", + "display_name": "Imagen 3.0 Fast Generate", + "name": "models/imagen-3.0-fast-generate-001", + "version": "3.0", + "description": "Imagen 3.0 fast image generation model", + "supportedGenerationMethods": [ + "predict" + ] + }, + { + "id": "imagen-4.0-fast-generate-001", + "object": "model", + "created": 1750000000, + "owned_by": "google", + "type": "gemini", + "display_name": "Imagen 4.0 Fast Generate", + "name": "models/imagen-4.0-fast-generate-001", + "version": 
"4.0", + "description": "Imagen 4.0 fast image generation model", + "supportedGenerationMethods": [ + "predict" + ] + } + ], + "gemini-cli": [ + { + "id": "gemini-2.5-pro", + "object": "model", + "created": 1750118400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Pro", + "name": "models/gemini-2.5-pro", + "version": "2.5", + "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true + } + }, + { + "id": "gemini-2.5-flash", + "object": "model", + "created": 1750118400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Flash", + "name": "models/gemini-2.5-flash", + "version": "001", + "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": "gemini-2.5-flash-lite", + "object": "model", + "created": 1753142400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Flash Lite", + "name": "models/gemini-2.5-flash-lite", + "version": "2.5", + "description": "Our smallest and most cost effective model, built for at scale usage.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": "gemini-3-pro-preview", + "object": "model", + 
"created": 1737158400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3 Pro Preview", + "name": "models/gemini-3-pro-preview", + "version": "3.0", + "description": "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + { + "id": "gemini-3.1-pro-preview", + "object": "model", + "created": 1771459200, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3.1 Pro Preview", + "name": "models/gemini-3.1-pro-preview", + "version": "3.1", + "description": "Gemini 3.1 Pro Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + { + "id": "gemini-3-flash-preview", + "object": "model", + "created": 1765929600, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3 Flash Preview", + "name": "models/gemini-3-flash-preview", + "version": "3.0", + "description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gemini-3.1-flash-lite-preview", + "object": "model", + "created": 1776288000, + 
"owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3.1 Flash Lite Preview", + "name": "models/gemini-3.1-flash-lite-preview", + "version": "3.1", + "description": "Our smallest and most cost effective model, built for at scale usage.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "high" + ] + } + } + ], + "aistudio": [ + { + "id": "gemini-2.5-pro", + "object": "model", + "created": 1750118400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Pro", + "name": "models/gemini-2.5-pro", + "version": "2.5", + "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true + } + }, + { + "id": "gemini-2.5-flash", + "object": "model", + "created": 1750118400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Flash", + "name": "models/gemini-2.5-flash", + "version": "001", + "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": "gemini-2.5-flash-lite", + "object": "model", + "created": 1753142400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Flash Lite", + "name": "models/gemini-2.5-flash-lite", + "version": 
"2.5", + "description": "Our smallest and most cost effective model, built for at scale usage.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": "gemini-3-pro-preview", + "object": "model", + "created": 1737158400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3 Pro Preview", + "name": "models/gemini-3-pro-preview", + "version": "3.0", + "description": "Gemini 3 Pro Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true + } + }, + { + "id": "gemini-3.1-pro-preview", + "object": "model", + "created": 1771459200, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3.1 Pro Preview", + "name": "models/gemini-3.1-pro-preview", + "version": "3.1", + "description": "Gemini 3.1 Pro Preview", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true + } + }, + { + "id": "gemini-3-flash-preview", + "object": "model", + "created": 1765929600, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3 Flash Preview", + "name": "models/gemini-3-flash-preview", + "version": "3.0", + "description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + 
"batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true + } + }, + { + "id": "gemini-3.1-flash-lite-preview", + "object": "model", + "created": 1776288000, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 3.1 Flash Lite Preview", + "name": "models/gemini-3.1-flash-lite-preview", + "version": "3.1", + "description": "Our smallest and most cost effective model, built for at scale usage.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "high" + ] + } + }, + { + "id": "gemini-pro-latest", + "object": "model", + "created": 1750118400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini Pro Latest", + "name": "models/gemini-pro-latest", + "version": "2.5", + "description": "Latest release of Gemini Pro", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true + } + }, + { + "id": "gemini-flash-latest", + "object": "model", + "created": 1750118400, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini Flash Latest", + "name": "models/gemini-flash-latest", + "version": "2.5", + "description": "Latest release of Gemini Flash", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": "gemini-flash-lite-latest", + "object": "model", + "created": 1753142400, + "owned_by": "google", + "type": "gemini", + 
"display_name": "Gemini Flash-Lite Latest", + "name": "models/gemini-flash-lite-latest", + "version": "2.5", + "description": "Latest release of Gemini Flash-Lite", + "inputTokenLimit": 1048576, + "outputTokenLimit": 65536, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ], + "thinking": { + "min": 512, + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": "gemini-2.5-flash-image", + "object": "model", + "created": 1759363200, + "owned_by": "google", + "type": "gemini", + "display_name": "Gemini 2.5 Flash Image", + "name": "models/gemini-2.5-flash-image", + "version": "2.5", + "description": "State-of-the-art image generation and editing model.", + "inputTokenLimit": 1048576, + "outputTokenLimit": 8192, + "supportedGenerationMethods": [ + "generateContent", + "countTokens", + "createCachedContent", + "batchGenerateContent" + ] + } + ], + "codex-free": [ + { + "id": "gpt-5", + "object": "model", + "created": 1754524800, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5", + "version": "gpt-5-2025-08-07", + "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "minimal", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5-codex", + "object": "model", + "created": 1757894400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5 Codex", + "version": "gpt-5-2025-09-15", + "description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5-codex-mini", + "object": "model", + "created": 1762473600, + 
"owned_by": "openai", + "type": "openai", + "display_name": "GPT 5 Codex Mini", + "version": "gpt-5-2025-11-07", + "description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "none", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex-mini", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex Mini", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex-max", + "object": "model", + "created": 1763424000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 
5.1 Codex Max", + "version": "gpt-5.1-max", + "description": "Stable version of GPT 5.1 Codex Max", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.2", + "object": "model", + "created": 1765440000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.2", + "version": "gpt-5.2", + "description": "Stable version of GPT 5.2", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "none", + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.2-codex", + "object": "model", + "created": 1765440000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.2 Codex", + "version": "gpt-5.2", + "description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + } + ], + "codex-team": [ + { + "id": "gpt-5", + "object": "model", + "created": 1754524800, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5", + "version": "gpt-5-2025-08-07", + "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "minimal", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5-codex", + "object": "model", + "created": 1757894400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5 Codex", + "version": "gpt-5-2025-09-15", + "description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + 
"max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5-codex-mini", + "object": "model", + "created": 1762473600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5 Codex Mini", + "version": "gpt-5-2025-11-07", + "description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "none", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex-mini", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex Mini", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + 
"tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex-max", + "object": "model", + "created": 1763424000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex Max", + "version": "gpt-5.1-max", + "description": "Stable version of GPT 5.1 Codex Max", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.2", + "object": "model", + "created": 1765440000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.2", + "version": "gpt-5.2", + "description": "Stable version of GPT 5.2", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "none", + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.2-codex", + "object": "model", + "created": 1765440000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.2 Codex", + "version": "gpt-5.2", + "description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.3-codex", + "object": "model", + "created": 1770307200, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.3 Codex", + "version": "gpt-5.3", + "description": "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.4", + "object": "model", + "created": 1772668800, + "owned_by": "openai", + "type": 
"openai", + "display_name": "GPT 5.4", + "version": "gpt-5.4", + "description": "Stable version of GPT 5.4", + "context_length": 1050000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + } + ], + "codex-plus": [ + { + "id": "gpt-5", + "object": "model", + "created": 1754524800, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5", + "version": "gpt-5-2025-08-07", + "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "minimal", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5-codex", + "object": "model", + "created": 1757894400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5 Codex", + "version": "gpt-5-2025-09-15", + "description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5-codex-mini", + "object": "model", + "created": 1762473600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5 Codex Mini", + "version": "gpt-5-2025-11-07", + "description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5, The best model for coding and 
agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "none", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex-mini", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex Mini", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex-max", + "object": "model", + "created": 1763424000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex Max", + "version": "gpt-5.1-max", + "description": "Stable version of GPT 5.1 Codex Max", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.2", + "object": "model", + "created": 1765440000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.2", + "version": "gpt-5.2", + "description": "Stable version of GPT 5.2", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "none", + 
"low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.2-codex", + "object": "model", + "created": 1765440000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.2 Codex", + "version": "gpt-5.2", + "description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.3-codex", + "object": "model", + "created": 1770307200, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.3 Codex", + "version": "gpt-5.3", + "description": "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.3-codex-spark", + "object": "model", + "created": 1770912000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.3 Codex Spark", + "version": "gpt-5.3", + "description": "Ultra-fast coding model.", + "context_length": 128000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.4", + "object": "model", + "created": 1772668800, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.4", + "version": "gpt-5.4", + "description": "Stable version of GPT 5.4", + "context_length": 1050000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + } + ], + "codex-pro": [ + { + "id": "gpt-5", + "object": "model", + "created": 1754524800, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 
5", + "version": "gpt-5-2025-08-07", + "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "minimal", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5-codex", + "object": "model", + "created": 1757894400, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5 Codex", + "version": "gpt-5-2025-09-15", + "description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5-codex-mini", + "object": "model", + "created": 1762473600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5 Codex Mini", + "version": "gpt-5-2025-11-07", + "description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "none", + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5.1 
Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex-mini", + "object": "model", + "created": 1762905600, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex Mini", + "version": "gpt-5.1-2025-11-12", + "description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high" + ] + } + }, + { + "id": "gpt-5.1-codex-max", + "object": "model", + "created": 1763424000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.1 Codex Max", + "version": "gpt-5.1-max", + "description": "Stable version of GPT 5.1 Codex Max", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.2", + "object": "model", + "created": 1765440000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.2", + "version": "gpt-5.2", + "description": "Stable version of GPT 5.2", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "none", + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.2-codex", + "object": "model", + "created": 1765440000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.2 Codex", + "version": "gpt-5.2", + "description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + 
"thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.3-codex", + "object": "model", + "created": 1770307200, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.3 Codex", + "version": "gpt-5.3", + "description": "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.", + "context_length": 400000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.3-codex-spark", + "object": "model", + "created": 1770912000, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.3 Codex Spark", + "version": "gpt-5.3", + "description": "Ultra-fast coding model.", + "context_length": 128000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "gpt-5.4", + "object": "model", + "created": 1772668800, + "owned_by": "openai", + "type": "openai", + "display_name": "GPT 5.4", + "version": "gpt-5.4", + "description": "Stable version of GPT 5.4", + "context_length": 1050000, + "max_completion_tokens": 128000, + "supported_parameters": [ + "tools" + ], + "thinking": { + "levels": [ + "low", + "medium", + "high", + "xhigh" + ] + } + } + ], + "qwen": [ + { + "id": "qwen3-coder-plus", + "object": "model", + "created": 1753228800, + "owned_by": "qwen", + "type": "qwen", + "display_name": "Qwen3 Coder Plus", + "version": "3.0", + "description": "Advanced code generation and understanding model", + "context_length": 32768, + "max_completion_tokens": 8192, + "supported_parameters": [ + "temperature", + "top_p", + "max_tokens", + "stream", + "stop" + ] + }, + { + "id": "qwen3-coder-flash", + "object": "model", + "created": 1753228800, + "owned_by": "qwen", + "type": "qwen", + "display_name": "Qwen3 Coder Flash", + "version": "3.0", + 
"description": "Fast code generation model", + "context_length": 8192, + "max_completion_tokens": 2048, + "supported_parameters": [ + "temperature", + "top_p", + "max_tokens", + "stream", + "stop" + ] + }, + { + "id": "coder-model", + "object": "model", + "created": 1771171200, + "owned_by": "qwen", + "type": "qwen", + "display_name": "Qwen 3.5 Plus", + "version": "3.5", + "description": "efficient hybrid model with leading coding performance", + "context_length": 1048576, + "max_completion_tokens": 65536, + "supported_parameters": [ + "temperature", + "top_p", + "max_tokens", + "stream", + "stop" + ] + }, + { + "id": "vision-model", + "object": "model", + "created": 1758672000, + "owned_by": "qwen", + "type": "qwen", + "display_name": "Qwen3 Vision Model", + "version": "3.0", + "description": "Vision model model", + "context_length": 32768, + "max_completion_tokens": 2048, + "supported_parameters": [ + "temperature", + "top_p", + "max_tokens", + "stream", + "stop" + ] + } + ], + "iflow": [ + { + "id": "qwen3-coder-plus", + "object": "model", + "created": 1753228800, + "owned_by": "iflow", + "type": "iflow", + "display_name": "Qwen3-Coder-Plus", + "description": "Qwen3 Coder Plus code generation" + }, + { + "id": "qwen3-max", + "object": "model", + "created": 1758672000, + "owned_by": "iflow", + "type": "iflow", + "display_name": "Qwen3-Max", + "description": "Qwen3 flagship model" + }, + { + "id": "qwen3-vl-plus", + "object": "model", + "created": 1758672000, + "owned_by": "iflow", + "type": "iflow", + "display_name": "Qwen3-VL-Plus", + "description": "Qwen3 multimodal vision-language" + }, + { + "id": "qwen3-max-preview", + "object": "model", + "created": 1757030400, + "owned_by": "iflow", + "type": "iflow", + "display_name": "Qwen3-Max-Preview", + "description": "Qwen3 Max preview build", + "thinking": { + "levels": [ + "none", + "auto", + "minimal", + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "glm-4.6", + "object": "model", + "created": 
1759190400, + "owned_by": "iflow", + "type": "iflow", + "display_name": "GLM-4.6", + "description": "Zhipu GLM 4.6 general model", + "thinking": { + "levels": [ + "none", + "auto", + "minimal", + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "kimi-k2", + "object": "model", + "created": 1752192000, + "owned_by": "iflow", + "type": "iflow", + "display_name": "Kimi-K2", + "description": "Moonshot Kimi K2 general model" + }, + { + "id": "deepseek-v3.2", + "object": "model", + "created": 1759104000, + "owned_by": "iflow", + "type": "iflow", + "display_name": "DeepSeek-V3.2-Exp", + "description": "DeepSeek V3.2 experimental", + "thinking": { + "levels": [ + "none", + "auto", + "minimal", + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "deepseek-v3.1", + "object": "model", + "created": 1756339200, + "owned_by": "iflow", + "type": "iflow", + "display_name": "DeepSeek-V3.1-Terminus", + "description": "DeepSeek V3.1 Terminus", + "thinking": { + "levels": [ + "none", + "auto", + "minimal", + "low", + "medium", + "high", + "xhigh" + ] + } + }, + { + "id": "deepseek-r1", + "object": "model", + "created": 1737331200, + "owned_by": "iflow", + "type": "iflow", + "display_name": "DeepSeek-R1", + "description": "DeepSeek reasoning model R1" + }, + { + "id": "deepseek-v3", + "object": "model", + "created": 1734307200, + "owned_by": "iflow", + "type": "iflow", + "display_name": "DeepSeek-V3-671B", + "description": "DeepSeek V3 671B" + }, + { + "id": "qwen3-32b", + "object": "model", + "created": 1747094400, + "owned_by": "iflow", + "type": "iflow", + "display_name": "Qwen3-32B", + "description": "Qwen3 32B" + }, + { + "id": "qwen3-235b-a22b-thinking-2507", + "object": "model", + "created": 1753401600, + "owned_by": "iflow", + "type": "iflow", + "display_name": "Qwen3-235B-A22B-Thinking", + "description": "Qwen3 235B A22B Thinking (2507)" + }, + { + "id": "qwen3-235b-a22b-instruct", + "object": "model", + "created": 1753401600, + "owned_by": "iflow", 
+ "type": "iflow", + "display_name": "Qwen3-235B-A22B-Instruct", + "description": "Qwen3 235B A22B Instruct" + }, + { + "id": "qwen3-235b", + "object": "model", + "created": 1753401600, + "owned_by": "iflow", + "type": "iflow", + "display_name": "Qwen3-235B-A22B", + "description": "Qwen3 235B A22B" + }, + { + "id": "iflow-rome-30ba3b", + "object": "model", + "created": 1736899200, + "owned_by": "iflow", + "type": "iflow", + "display_name": "iFlow-ROME", + "description": "iFlow Rome 30BA3B model" + } + ], + "kimi": [ + { + "id": "kimi-k2", + "object": "model", + "created": 1752192000, + "owned_by": "moonshot", + "type": "kimi", + "display_name": "Kimi K2", + "description": "Kimi K2 - Moonshot AI's flagship coding model", + "context_length": 131072, + "max_completion_tokens": 32768 + }, + { + "id": "kimi-k2-thinking", + "object": "model", + "created": 1762387200, + "owned_by": "moonshot", + "type": "kimi", + "display_name": "Kimi K2 Thinking", + "description": "Kimi K2 Thinking - Extended reasoning model", + "context_length": 131072, + "max_completion_tokens": 32768, + "thinking": { + "min": 1024, + "max": 32000, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + { + "id": "kimi-k2.5", + "object": "model", + "created": 1769472000, + "owned_by": "moonshot", + "type": "kimi", + "display_name": "Kimi K2.5", + "description": "Kimi K2.5 - Latest Moonshot AI coding model with improved capabilities", + "context_length": 131072, + "max_completion_tokens": 32768, + "thinking": { + "min": 1024, + "max": 32000, + "zero_allowed": true, + "dynamic_allowed": true + } + } + ], + "antigravity": { + "claude-opus-4-6-thinking": { + "thinking": { + "min": 1024, + "max": 64000, + "zero_allowed": true, + "dynamic_allowed": true + }, + "max_completion_tokens": 64000 + }, + "claude-sonnet-4-6": { + "thinking": { + "min": 1024, + "max": 64000, + "zero_allowed": true, + "dynamic_allowed": true + }, + "max_completion_tokens": 64000 + }, + "gemini-2.5-flash": { + "thinking": { + 
"max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + "gemini-2.5-flash-lite": { + "thinking": { + "max": 24576, + "zero_allowed": true, + "dynamic_allowed": true + } + }, + "gemini-3-flash": { + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "low", + "medium", + "high" + ] + } + }, + "gemini-3-pro-high": { + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + "gemini-3-pro-low": { + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + "gemini-3.1-flash-image": { + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "high" + ] + } + }, + "gemini-3.1-flash-lite-preview": { + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "minimal", + "high" + ] + } + }, + "gemini-3.1-pro-high": { + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + "gemini-3.1-pro-low": { + "thinking": { + "min": 128, + "max": 32768, + "dynamic_allowed": true, + "levels": [ + "low", + "high" + ] + } + }, + "gpt-oss-120b-medium": {} + } +} \ No newline at end of file diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 7d0ddcf2..82b12a2f 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -1266,6 +1266,10 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte { } return true }) + } else if system.Type == gjson.String && system.String() != "" { + partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}` + partJSON, _ = sjson.Set(partJSON, "text", system.String()) + result += "," + partJSON } result += "]" @@ -1485,25 +1489,27 @@ func countCacheControlsMap(root map[string]any) int { return count } 
-func normalizeTTLForBlock(obj map[string]any, seen5m *bool) { +func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool { ccRaw, exists := obj["cache_control"] if !exists { - return + return false } cc, ok := asObject(ccRaw) if !ok { *seen5m = true - return + return false } ttlRaw, ttlExists := cc["ttl"] ttl, ttlIsString := ttlRaw.(string) if !ttlExists || !ttlIsString || ttl != "1h" { *seen5m = true - return + return false } if *seen5m { delete(cc, "ttl") + return true } + return false } func findLastCacheControlIndex(arr []any) int { @@ -1599,11 +1605,14 @@ func normalizeCacheControlTTL(payload []byte) []byte { } seen5m := false + modified := false if tools, ok := asArray(root["tools"]); ok { for _, tool := range tools { if obj, ok := asObject(tool); ok { - normalizeTTLForBlock(obj, &seen5m) + if normalizeTTLForBlock(obj, &seen5m) { + modified = true + } } } } @@ -1611,7 +1620,9 @@ func normalizeCacheControlTTL(payload []byte) []byte { if system, ok := asArray(root["system"]); ok { for _, item := range system { if obj, ok := asObject(item); ok { - normalizeTTLForBlock(obj, &seen5m) + if normalizeTTLForBlock(obj, &seen5m) { + modified = true + } } } } @@ -1628,12 +1639,17 @@ func normalizeCacheControlTTL(payload []byte) []byte { } for _, item := range content { if obj, ok := asObject(item); ok { - normalizeTTLForBlock(obj, &seen5m) + if normalizeTTLForBlock(obj, &seen5m) { + modified = true + } } } } } + if !modified { + return payload + } return marshalPayloadObject(payload, root) } diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go index c4a4d644..fa458c0f 100644 --- a/internal/runtime/executor/claude_executor_test.go +++ b/internal/runtime/executor/claude_executor_test.go @@ -369,6 +369,19 @@ func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) { } } +func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.T) { + // Payload where no TTL 
normalization is needed (all blocks use 1h with no + // preceding 5m block). The text intentionally contains HTML chars (<, >, &) + // that json.Marshal would escape to \u003c etc., altering byte identity. + payload := []byte(`{"tools":[{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],"system":[{"type":"text","text":"foo & bar","cache_control":{"type":"ephemeral","ttl":"1h"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + + out := normalizeCacheControlTTL(payload) + + if !bytes.Equal(out, payload) { + t.Fatalf("normalizeCacheControlTTL altered bytes when no change was needed.\noriginal: %s\ngot: %s", payload, out) + } +} + func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) { payload := []byte(`{ "tools": [ @@ -829,8 +842,8 @@ func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity executor := NewClaudeExecutor(&config.Config{}) // Inject Accept-Encoding via the custom header attribute mechanism. 
auth := &cliproxyauth.Auth{Attributes: map[string]string{ - "api_key": "key-123", - "base_url": server.URL, + "api_key": "key-123", + "base_url": server.URL, "header:Accept-Encoding": "gzip, deflate, br, zstd", }} payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`) @@ -967,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te t.Errorf("error message should contain decompressed JSON, got: %q", err.Error()) } } + +// Test case 1: String system prompt is preserved and converted to a content block +func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) { + payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`) + + out := checkSystemInstructionsWithMode(payload, false) + + system := gjson.GetBytes(out, "system") + if !system.IsArray() { + t.Fatalf("system should be an array, got %s", system.Type) + } + + blocks := system.Array() + if len(blocks) != 3 { + t.Fatalf("expected 3 system blocks, got %d", len(blocks)) + } + + if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") { + t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String()) + } + if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." { + t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String()) + } + if blocks[2].Get("text").String() != "You are a helpful assistant." 
{ + t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String()) + } + if blocks[2].Get("cache_control.type").String() != "ephemeral" { + t.Fatalf("blocks[2] should have cache_control.type=ephemeral") + } +} + +// Test case 2: Strict mode drops the string system prompt +func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) { + payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`) + + out := checkSystemInstructionsWithMode(payload, true) + + blocks := gjson.GetBytes(out, "system").Array() + if len(blocks) != 2 { + t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks)) + } +} + +// Test case 3: Empty string system prompt does not produce a spurious block +func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) { + payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`) + + out := checkSystemInstructionsWithMode(payload, false) + + blocks := gjson.GetBytes(out, "system").Array() + if len(blocks) != 2 { + t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks)) + } +} + +// Test case 4: Array system prompt is unaffected by the string handling +func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) { + payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`) + + out := checkSystemInstructionsWithMode(payload, false) + + blocks := gjson.GetBytes(out, "system").Array() + if len(blocks) != 3 { + t.Fatalf("expected 3 system blocks, got %d", len(blocks)) + } + if blocks[2].Get("text").String() != "Be concise." 
{ + t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String()) + } +} + +// Test case 5: Special characters in string system prompt survive conversion +func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) { + payload := []byte(`{"system":"Use tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`) + + out := checkSystemInstructionsWithMode(payload, false) + + blocks := gjson.GetBytes(out, "system").Array() + if len(blocks) != 3 { + t.Fatalf("expected 3 system blocks, got %d", len(blocks)) + } + if blocks[2].Get("text").String() != `Use tags & "quotes" in output.` { + t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String()) + } +} diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 2a4f4a3f..571a23a1 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -23,6 +23,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/util" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -705,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) * return dialer } - parsedURL, errParse := url.Parse(proxyURL) + setting, errParse := proxyutil.Parse(proxyURL) if errParse != nil { - log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse) + log.Errorf("codex websockets executor: %v", errParse) return dialer } - switch parsedURL.Scheme { + switch setting.Mode { + case proxyutil.ModeDirect: + dialer.Proxy = nil + return dialer + case proxyutil.ModeProxy: + default: + return dialer + 
} + + switch setting.URL.Scheme { case "socks5": var proxyAuth *proxy.Auth - if parsedURL.User != nil { - username := parsedURL.User.Username() - password, _ := parsedURL.User.Password() + if setting.URL.User != nil { + username := setting.URL.User.Username() + password, _ := setting.URL.User.Password() proxyAuth = &proxy.Auth{User: username, Password: password} } - socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct) + socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct) if errSOCKS5 != nil { log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5) return dialer @@ -729,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) * return socksDialer.Dial(network, addr) } case "http", "https": - dialer.Proxy = http.ProxyURL(parsedURL) + dialer.Proxy = http.ProxyURL(setting.URL) default: - log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme) + log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme) } return dialer diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index e1335386..755ac56a 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" "github.com/tidwall/gjson" ) @@ -187,3 +188,16 @@ func contextWithGinHeaders(headers map[string]string) context.Context { } return context.WithValue(context.Background(), "gin", ginCtx) } + +func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) { + t.Parallel() + + dialer := newProxyAwareWebsocketDialer( + 
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}}, + &cliproxyauth.Auth{ProxyURL: "direct"}, + ) + + if dialer.Proxy != nil { + t.Fatal("expected websocket proxy function to be nil for direct mode") + } +} diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 7ad1c618..84df56f9 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -460,7 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip // For API key auth, use simpler URL format without project/location if baseURL == "" { - baseURL = "https://generativelanguage.googleapis.com" + baseURL = "https://aiplatform.googleapis.com" } url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action) if opts.Alt != "" && action != "countTokens" { @@ -683,7 +683,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth action := getVertexAction(baseModel, true) // For API key auth, use simpler URL format without project/location if baseURL == "" { - baseURL = "https://generativelanguage.googleapis.com" + baseURL = "https://aiplatform.googleapis.com" } url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action) // Imagen models don't support streaming, skip SSE params @@ -883,7 +883,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * // For API key auth, use simpler URL format without project/location if baseURL == "" { - baseURL = "https://generativelanguage.googleapis.com" + baseURL = "https://aiplatform.googleapis.com" } url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens") diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index 
ab0f626a..5511497b 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -2,16 +2,14 @@ package executor import ( "context" - "net" "net/http" - "net/url" "strings" "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" - "golang.org/x/net/proxy" ) // newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority: @@ -72,45 +70,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip // Returns: // - *http.Transport: A configured transport, or nil if the proxy URL is invalid func buildProxyTransport(proxyURL string) *http.Transport { - if proxyURL == "" { + transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL) + if errBuild != nil { + log.Errorf("%v", errBuild) return nil } - - parsedURL, errParse := url.Parse(proxyURL) - if errParse != nil { - log.Errorf("parse proxy URL failed: %v", errParse) - return nil - } - - var transport *http.Transport - - // Handle different proxy schemes - if parsedURL.Scheme == "socks5" { - // Configure SOCKS5 proxy with optional authentication - var proxyAuth *proxy.Auth - if parsedURL.User != nil { - username := parsedURL.User.Username() - password, _ := parsedURL.User.Password() - proxyAuth = &proxy.Auth{User: username, Password: password} - } - dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct) - if errSOCKS5 != nil { - log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) - return nil - } - // Set up a custom transport using the SOCKS5 dialer - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy - 
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)} - } else { - log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme) - return nil - } - return transport } diff --git a/internal/runtime/executor/proxy_helpers_test.go b/internal/runtime/executor/proxy_helpers_test.go new file mode 100644 index 00000000..4ae5c937 --- /dev/null +++ b/internal/runtime/executor/proxy_helpers_test.go @@ -0,0 +1,30 @@ +package executor + +import ( + "context" + "net/http" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) { + t.Parallel() + + client := newProxyAwareHTTPClient( + context.Background(), + &config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}}, + &cliproxyauth.Auth{ProxyURL: "direct"}, + 0, + ) + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", client.Transport) + } + if transport.Proxy != nil { + t.Fatal("expected direct transport to disable proxy function") + } +} diff --git a/internal/thinking/apply.go b/internal/thinking/apply.go index b8a0fcae..c79ecd8e 100644 --- a/internal/thinking/apply.go +++ b/internal/thinking/apply.go @@ -257,7 +257,10 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma if suffixResult.HasSuffix { config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID) } else { - config = extractThinkingConfig(body, toFormat) + config = extractThinkingConfig(body, fromFormat) + if !hasThinkingConfig(config) && fromFormat != toFormat { + config = extractThinkingConfig(body, toFormat) + } } if !hasThinkingConfig(config) { @@ -293,6 +296,9 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri if config.Mode != ModeLevel { 
return config } + if toFormat == "claude" { + return config + } if !isBudgetCapableProvider(toFormat) { return config } diff --git a/internal/thinking/apply_user_defined_test.go b/internal/thinking/apply_user_defined_test.go new file mode 100644 index 00000000..aa24ab8e --- /dev/null +++ b/internal/thinking/apply_user_defined_test.go @@ -0,0 +1,55 @@ +package thinking_test + +import ( + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + _ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude" + "github.com/tidwall/gjson" +) + +func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) { + reg := registry.GetGlobalRegistry() + clientID := "test-user-defined-claude-" + t.Name() + modelID := "custom-claude-4-6" + reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}}) + t.Cleanup(func() { + reg.UnregisterClient(clientID) + }) + + tests := []struct { + name string + model string + body []byte + }{ + { + name: "claude adaptive effort body", + model: modelID, + body: []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`), + }, + { + name: "suffix level", + model: modelID + "(high)", + body: []byte(`{}`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := thinking.ApplyThinking(tt.body, tt.model, "openai", "claude", "claude") + if err != nil { + t.Fatalf("ApplyThinking() error = %v", err) + } + if got := gjson.GetBytes(out, "thinking.type").String(); got != "adaptive" { + t.Fatalf("thinking.type = %q, want %q, body=%s", got, "adaptive", string(out)) + } + if got := gjson.GetBytes(out, "output_config.effort").String(); got != "high" { + t.Fatalf("output_config.effort = %q, want %q, body=%s", got, "high", string(out)) + } + if gjson.GetBytes(out, "thinking.budget_tokens").Exists() { + t.Fatalf("thinking.budget_tokens should be removed, body=%s", 
string(out)) + } + }) + } +} diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index 8c1a38c5..3a6ba4b5 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -477,9 +477,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ effort = strings.ToLower(strings.TrimSpace(v.String())) } if effort != "" { - if effort == "max" { - effort = "high" - } out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort) } else { out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high") diff --git a/internal/translator/antigravity/claude/antigravity_claude_request_test.go b/internal/translator/antigravity/claude/antigravity_claude_request_test.go index 39dc493d..696240ef 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request_test.go @@ -1235,64 +1235,3 @@ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *t t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw) } } - -func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels(t *testing.T) { - tests := []struct { - name string - effort string - expected string - }{ - {"low", "low", "low"}, - {"medium", "medium", "medium"}, - {"high", "high", "high"}, - {"max", "max", "high"}, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - inputJSON := []byte(`{ - "model": "claude-opus-4-6-thinking", - "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], - "thinking": {"type": "adaptive"}, - "output_config": {"effort": "` + tt.effort + `"} - }`) - - output := 
ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false) - outputStr := string(output) - - thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig") - if !thinkingConfig.Exists() { - t.Fatal("thinkingConfig should exist for adaptive thinking") - } - if thinkingConfig.Get("thinkingLevel").String() != tt.expected { - t.Errorf("Expected thinkingLevel %q, got %q", tt.expected, thinkingConfig.Get("thinkingLevel").String()) - } - if !thinkingConfig.Get("includeThoughts").Bool() { - t.Error("includeThoughts should be true") - } - }) - } -} - -func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort(t *testing.T) { - inputJSON := []byte(`{ - "model": "claude-opus-4-6-thinking", - "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], - "thinking": {"type": "adaptive"} - }`) - - output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false) - outputStr := string(output) - - thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig") - if !thinkingConfig.Exists() { - t.Fatal("thinkingConfig should exist for adaptive thinking without effort") - } - if thinkingConfig.Get("thinkingLevel").String() != "high" { - t.Errorf("Expected default thinkingLevel \"high\", got %q", thinkingConfig.Get("thinkingLevel").String()) - } - if !thinkingConfig.Get("includeThoughts").Bool() { - t.Error("includeThoughts should be true") - } -} diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 3c834f6f..893e4d07 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -15,6 +15,7 @@ import ( "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" 
"github.com/tidwall/gjson" @@ -256,7 +257,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) + data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))) data, _ = sjson.Set(data, "content_block.name", fcName) output = output + fmt.Sprintf("data: %s\n\n\n", data) diff --git a/internal/translator/codex/claude/codex_claude_request.go b/internal/translator/codex/claude/codex_claude_request.go index 6373e693..4bc116b9 100644 --- a/internal/translator/codex/claude/codex_claude_request.go +++ b/internal/translator/codex/claude/codex_claude_request.go @@ -43,23 +43,32 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) // Process system messages and convert them to input content format. 
systemsResult := rootResult.Get("system") - if systemsResult.IsArray() { - systemResults := systemsResult.Array() + if systemsResult.Exists() { message := `{"type":"message","role":"developer","content":[]}` contentIndex := 0 - for i := 0; i < len(systemResults); i++ { - systemResult := systemResults[i] - systemTypeResult := systemResult.Get("type") - if systemTypeResult.String() == "text" { - text := systemResult.Get("text").String() - if strings.HasPrefix(text, "x-anthropic-billing-header: ") { - continue + + appendSystemText := func(text string) { + if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") { + return + } + + message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text") + message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text) + contentIndex++ + } + + if systemsResult.Type == gjson.String { + appendSystemText(systemsResult.String()) + } else if systemsResult.IsArray() { + systemResults := systemsResult.Array() + for i := 0; i < len(systemResults); i++ { + systemResult := systemResults[i] + if systemResult.Get("type").String() == "text" { + appendSystemText(systemResult.Get("text").String()) } - message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text") - message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text) - contentIndex++ } } + if contentIndex > 0 { template, _ = sjson.SetRaw(template, "input.-1", message) } diff --git a/internal/translator/codex/claude/codex_claude_request_test.go b/internal/translator/codex/claude/codex_claude_request_test.go new file mode 100644 index 00000000..bdd41639 --- /dev/null +++ b/internal/translator/codex/claude/codex_claude_request_test.go @@ -0,0 +1,89 @@ +package claude + +import ( + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) { + tests := []struct { + name string + inputJSON string + wantHasDeveloper 
bool + wantTexts []string + }{ + { + name: "No system field", + inputJSON: `{ + "model": "claude-3-opus", + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasDeveloper: false, + }, + { + name: "Empty string system field", + inputJSON: `{ + "model": "claude-3-opus", + "system": "", + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasDeveloper: false, + }, + { + name: "String system field", + inputJSON: `{ + "model": "claude-3-opus", + "system": "Be helpful", + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasDeveloper: true, + wantTexts: []string{"Be helpful"}, + }, + { + name: "Array system field with filtered billing header", + inputJSON: `{ + "model": "claude-3-opus", + "system": [ + {"type": "text", "text": "x-anthropic-billing-header: tenant-123"}, + {"type": "text", "text": "Block 1"}, + {"type": "text", "text": "Block 2"} + ], + "messages": [{"role": "user", "content": "hello"}] + }`, + wantHasDeveloper: true, + wantTexts: []string{"Block 1", "Block 2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false) + resultJSON := gjson.ParseBytes(result) + inputs := resultJSON.Get("input").Array() + + hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer" + if hasDeveloper != tt.wantHasDeveloper { + t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw) + } + + if !tt.wantHasDeveloper { + return + } + + content := inputs[0].Get("content").Array() + if len(content) != len(tt.wantTexts) { + t.Fatalf("got %d system content items, want %d. 
Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw) + } + + for i, wantText := range tt.wantTexts { + if gotType := content[i].Get("type").String(); gotType != "input_text" { + t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text") + } + if gotText := content[i].Get("text").String(); gotText != wantText { + t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText) + } + } + }) + } +} diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index 7f597062..cf0fee46 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -12,6 +12,7 @@ import ( "fmt" "strings" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -141,7 +142,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa (*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}` template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) - template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String()) + template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String())) { // Restore original tool name if shortened name := itemResult.Get("name").String() @@ -310,7 +311,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original } toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}` - toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String()) + toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String())) toolBlock, _ = sjson.Set(toolBlock, "name", name) inputRaw := "{}" if argsStr 
:= item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) { diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go index 1126f1ee..3d310d8b 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_response.go @@ -14,6 +14,7 @@ import ( "sync/atomic" "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -209,7 +210,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) + data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))) data, _ = sjson.Set(data, "content_block.name", fcName) output = output + fmt.Sprintf("data: %s\n\n\n", data) diff --git a/internal/translator/gemini/claude/gemini_claude_response.go b/internal/translator/gemini/claude/gemini_claude_response.go index e5adcb5e..eeb4af11 100644 --- a/internal/translator/gemini/claude/gemini_claude_response.go +++ b/internal/translator/gemini/claude/gemini_claude_response.go @@ -224,7 +224,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR // Create the tool use block with unique ID and function details data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) - data, _ = sjson.Set(data, "content_block.id", 
fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))) + data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1)))) data, _ = sjson.Set(data, "content_block.name", clientToolName) output = output + fmt.Sprintf("data: %s\n\n\n", data) @@ -343,7 +343,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina clientToolName := util.MapToolName(toolNameMap, upstreamToolName) toolIDCounter++ toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}` - toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)) + toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter))) toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName) inputRaw := "{}" if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() { diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go index f18f45be..c8948ac5 100644 --- a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go @@ -147,21 +147,21 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) content := m.Get("content") if (role == "system" || role == "developer") && len(arr) > 1 { - // system -> system_instruction as a user message style + // system -> systemInstruction as a user message style if content.Type == gjson.String { - out, _ = sjson.SetBytes(out, "system_instruction.role", "user") - out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.String()) + out, _ = sjson.SetBytes(out, "systemInstruction.role", "user") + out, _ = sjson.SetBytes(out, 
fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String()) systemPartIndex++ } else if content.IsObject() && content.Get("type").String() == "text" { - out, _ = sjson.SetBytes(out, "system_instruction.role", "user") - out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.Get("text").String()) + out, _ = sjson.SetBytes(out, "systemInstruction.role", "user") + out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String()) systemPartIndex++ } else if content.IsArray() { contents := content.Array() if len(contents) > 0 { - out, _ = sjson.SetBytes(out, "system_instruction.role", "user") + out, _ = sjson.SetBytes(out, "systemInstruction.role", "user") for j := 0; j < len(contents); j++ { - out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String()) + out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String()) systemPartIndex++ } } diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go index 143359d6..463203a7 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_request.go @@ -26,7 +26,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte if instructions := root.Get("instructions"); instructions.Exists() { systemInstr := `{"parts":[{"text":""}]}` systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String()) - out, _ = sjson.SetRaw(out, "system_instruction", systemInstr) + out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr) } // Convert input messages to Gemini contents format @@ -119,7 +119,7 @@ func 
ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte if strings.EqualFold(itemRole, "system") { if contentArray := item.Get("content"); contentArray.Exists() { systemInstr := "" - if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() { + if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() { systemInstr = systemInstructionResult.Raw } else { systemInstr = `{"parts":[]}` @@ -140,7 +140,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte } if systemInstr != `{"parts":[]}` { - out, _ = sjson.SetRaw(out, "system_instruction", systemInstr) + out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr) } } continue diff --git a/internal/translator/openai/claude/openai_claude_response.go b/internal/translator/openai/claude/openai_claude_response.go index 7bb496a2..eddead62 100644 --- a/internal/translator/openai/claude/openai_claude_response.go +++ b/internal/translator/openai/claude/openai_claude_response.go @@ -243,7 +243,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI // Send content_block_start for tool_use contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}` contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex) - contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID) + contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID)) contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name) results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n") } @@ -414,7 +414,7 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string { if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() 
&& toolCalls.IsArray() { toolCalls.ForEach(func(_, toolCall gjson.Result) bool { toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}` - toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String()) + toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String())) toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String()) argsStr := util.FixJSON(toolCall.Get("function.arguments").String()) @@ -612,7 +612,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina toolCalls.ForEach(func(_, tc gjson.Result) bool { hasToolCall = true toolUse := `{"type":"tool_use","id":"","name":"","input":{}}` - toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String()) + toolUse, _ = sjson.Set(toolUse, "id", util.SanitizeClaudeToolID(tc.Get("id").String())) toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String())) argsStr := util.FixJSON(tc.Get("function.arguments").String()) @@ -669,7 +669,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina toolCalls.ForEach(func(_, toolCall gjson.Result) bool { hasToolCall = true toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}` - toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String()) + toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String())) toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String())) argsStr := util.FixJSON(toolCall.Get("function.arguments").String()) diff --git a/internal/util/claude_tool_id.go b/internal/util/claude_tool_id.go new file mode 100644 index 00000000..46545168 --- /dev/null +++ b/internal/util/claude_tool_id.go @@ -0,0 +1,24 @@ +package util + +import ( + "fmt" + "regexp" + "sync/atomic" + "time" +) + +var ( + claudeToolUseIDSanitizer = 
regexp.MustCompile(`[^a-zA-Z0-9_-]`) + claudeToolUseIDCounter uint64 +) + +// SanitizeClaudeToolID ensures the given id conforms to Claude's +// tool_use.id regex ^[a-zA-Z0-9_-]+$. Non-conforming characters are +// replaced with '_'; an empty result gets a generated fallback. +func SanitizeClaudeToolID(id string) string { + s := claudeToolUseIDSanitizer.ReplaceAllString(id, "_") + if s == "" { + s = fmt.Sprintf("toolu_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&claudeToolUseIDCounter, 1)) + } + return s +} diff --git a/internal/util/proxy.go b/internal/util/proxy.go index aea52ba8..9b57ca17 100644 --- a/internal/util/proxy.go +++ b/internal/util/proxy.go @@ -4,50 +4,25 @@ package util import ( - "context" - "net" "net/http" - "net/url" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" - "golang.org/x/net/proxy" ) // SetProxy configures the provided HTTP client with proxy settings from the configuration. // It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport // to route requests through the configured proxy server. func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client { - var transport *http.Transport - // Attempt to parse the proxy URL from the configuration. - proxyURL, errParse := url.Parse(cfg.ProxyURL) - if errParse == nil { - // Handle different proxy schemes. - if proxyURL.Scheme == "socks5" { - // Configure SOCKS5 proxy with optional authentication. - var proxyAuth *proxy.Auth - if proxyURL.User != nil { - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - proxyAuth = &proxy.Auth{User: username, Password: password} - } - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) - if errSOCKS5 != nil { - log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) - return httpClient - } - // Set up a custom transport using the SOCKS5 dialer. 
- transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy. - transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } + if cfg == nil || httpClient == nil { + return httpClient + } + + transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL) + if errBuild != nil { + log.Errorf("%v", errBuild) } - // If a new transport was created, apply it to the HTTP client. if transport != nil { httpClient.Transport = transport } diff --git a/internal/watcher/synthesizer/file.go b/internal/watcher/synthesizer/file.go index 02a0cefa..ab54aeaa 100644 --- a/internal/watcher/synthesizer/file.go +++ b/internal/watcher/synthesizer/file.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" ) @@ -149,6 +150,16 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) [] } } ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth") + // For codex auth files, extract plan_type from the JWT id_token. 
+ if provider == "codex" { + if idTokenRaw, ok := metadata["id_token"].(string); ok && strings.TrimSpace(idTokenRaw) != "" { + if claims, errParse := codex.ParseJWTToken(idTokenRaw); errParse == nil && claims != nil { + if pt := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); pt != "" { + a.Attributes["plan_type"] = pt + } + } + } + } if provider == "gemini-cli" { if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { for _, v := range virtuals { diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 6a444b45..d417d6b2 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -34,6 +34,8 @@ const ( wsTurnStateHeader = "x-codex-turn-state" wsRequestBodyKey = "REQUEST_BODY_OVERRIDE" wsPayloadLogMaxSize = 2048 + wsBodyLogMaxSize = 64 * 1024 + wsBodyLogTruncated = "\n[websocket log truncated]\n" ) var responsesWebsocketUpgrader = websocket.Upgrader{ @@ -825,18 +827,71 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload [] if builder == nil { return } + if builder.Len() >= wsBodyLogMaxSize { + return + } trimmedPayload := bytes.TrimSpace(payload) if len(trimmedPayload) == 0 { return } if builder.Len() > 0 { - builder.WriteString("\n") + if !appendWebsocketLogString(builder, "\n") { + return + } } - builder.WriteString("websocket.") - builder.WriteString(eventType) - builder.WriteString("\n") - builder.Write(trimmedPayload) - builder.WriteString("\n") + if !appendWebsocketLogString(builder, "websocket.") { + return + } + if !appendWebsocketLogString(builder, eventType) { + return + } + if !appendWebsocketLogString(builder, "\n") { + return + } + if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) { + appendWebsocketLogString(builder, wsBodyLogTruncated) + return + } + appendWebsocketLogString(builder, "\n") +} + +func 
appendWebsocketLogString(builder *strings.Builder, value string) bool { + if builder == nil { + return false + } + remaining := wsBodyLogMaxSize - builder.Len() + if remaining <= 0 { + return false + } + if len(value) <= remaining { + builder.WriteString(value) + return true + } + builder.WriteString(value[:remaining]) + return false +} + +func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool { + if builder == nil { + return false + } + remaining := wsBodyLogMaxSize - builder.Len() + if remaining <= 0 { + return false + } + if len(value) <= remaining { + builder.Write(value) + return true + } + limit := remaining - reserveForSuffix + if limit < 0 { + limit = 0 + } + if limit > len(value) { + limit = len(value) + } + builder.Write(value[:limit]) + return false } func websocketPayloadEventType(payload []byte) string { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index d30c648d..981c6630 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -266,6 +266,33 @@ func TestAppendWebsocketEvent(t *testing.T) { } } +func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) { + var builder strings.Builder + payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize) + + appendWebsocketEvent(&builder, "request", payload) + + got := builder.String() + if len(got) > wsBodyLogMaxSize { + t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize) + } + if !strings.Contains(got, wsBodyLogTruncated) { + t.Fatalf("expected truncation marker in body log") + } +} + +func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) { + var builder strings.Builder + appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize)) + initial := builder.String() + + appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`)) + + 
if builder.String() != initial { + t.Fatalf("builder grew after reaching limit") + } +} + func TestSetWebsocketRequestBody(t *testing.T) { gin.SetMode(gin.TestMode) recorder := httptest.NewRecorder() diff --git a/sdk/auth/codex_device.go b/sdk/auth/codex_device.go index 78a95af8..10f59fb9 100644 --- a/sdk/auth/codex_device.go +++ b/sdk/auth/codex_device.go @@ -287,5 +287,8 @@ func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundl FileName: fileName, Storage: tokenStorage, Metadata: metadata, + Attributes: map[string]string{ + "plan_type": planType, + }, }, nil } diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index ae5b745c..b29e04db 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -134,6 +134,7 @@ type Manager struct { hook Hook mu sync.RWMutex auths map[string]*Auth + scheduler *authScheduler // providerOffsets tracks per-model provider rotation state for multi-provider routing. providerOffsets map[string]int @@ -149,6 +150,9 @@ type Manager struct { // Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix). apiKeyModelAlias atomic.Value + // modelPoolOffsets tracks per-auth alias pool rotation state. + modelPoolOffsets map[string]int + // runtimeConfig stores the latest application config for request-time decisions. // It is initialized in NewManager; never Load() before first Store(). runtimeConfig atomic.Value @@ -176,14 +180,59 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { hook: hook, auths: make(map[string]*Auth), providerOffsets: make(map[string]int), + modelPoolOffsets: make(map[string]int), refreshSemaphore: make(chan struct{}, refreshMaxConcurrency), } // atomic.Value requires non-nil initial value. 
manager.runtimeConfig.Store(&internalconfig.Config{}) manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil)) + manager.scheduler = newAuthScheduler(selector) return manager } +func isBuiltInSelector(selector Selector) bool { + switch selector.(type) { + case *RoundRobinSelector, *FillFirstSelector: + return true + default: + return false + } +} + +func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) { + if m == nil || m.scheduler == nil { + return + } + m.scheduler.rebuild(auths) +} + +func (m *Manager) syncScheduler() { + if m == nil || m.scheduler == nil { + return + } + m.syncSchedulerFromSnapshot(m.snapshotAuths()) +} + +// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its +// supportedModelSet is rebuilt from the current global model registry state. +// This must be called after models have been registered for a newly added auth, +// because the initial scheduler.upsertAuth during Register/Update runs before +// registerModelsForAuth and therefore snapshots an empty model set. +func (m *Manager) RefreshSchedulerEntry(authID string) { + if m == nil || m.scheduler == nil || authID == "" { + return + } + m.mu.RLock() + auth, ok := m.auths[authID] + if !ok || auth == nil { + m.mu.RUnlock() + return + } + snapshot := auth.Clone() + m.mu.RUnlock() + m.scheduler.upsertAuth(snapshot) +} + func (m *Manager) SetSelector(selector Selector) { if m == nil { return @@ -194,6 +243,10 @@ func (m *Manager) SetSelector(selector Selector) { m.mu.Lock() m.selector = selector m.mu.Unlock() + if m.scheduler != nil { + m.scheduler.setSelector(selector) + m.syncScheduler() + } } // SetStore swaps the underlying persistence store. @@ -251,16 +304,323 @@ func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) strin if resolved == "" { return "" } - // Preserve thinking suffix from the client's requested model unless config already has one. 
- requestResult := thinking.ParseSuffix(requestedModel) - if thinking.ParseSuffix(resolved).HasSuffix { - return resolved - } - if requestResult.HasSuffix && requestResult.RawSuffix != "" { - return resolved + "(" + requestResult.RawSuffix + ")" - } - return resolved + return preserveRequestedModelSuffix(requestedModel, resolved) +} +func isAPIKeyAuth(auth *Auth) bool { + if auth == nil { + return false + } + kind, _ := auth.AccountInfo() + return strings.EqualFold(strings.TrimSpace(kind), "api_key") +} + +func isOpenAICompatAPIKeyAuth(auth *Auth) bool { + if !isAPIKeyAuth(auth) { + return false + } + if strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") { + return true + } + if auth.Attributes == nil { + return false + } + return strings.TrimSpace(auth.Attributes["compat_name"]) != "" +} + +func openAICompatProviderKey(auth *Auth) string { + if auth == nil { + return "" + } + if auth.Attributes != nil { + if providerKey := strings.TrimSpace(auth.Attributes["provider_key"]); providerKey != "" { + return strings.ToLower(providerKey) + } + if compatName := strings.TrimSpace(auth.Attributes["compat_name"]); compatName != "" { + return strings.ToLower(compatName) + } + } + return strings.ToLower(strings.TrimSpace(auth.Provider)) +} + +func openAICompatModelPoolKey(auth *Auth, requestedModel string) string { + base := strings.TrimSpace(thinking.ParseSuffix(requestedModel).ModelName) + if base == "" { + base = strings.TrimSpace(requestedModel) + } + return strings.ToLower(strings.TrimSpace(auth.ID)) + "|" + openAICompatProviderKey(auth) + "|" + strings.ToLower(base) +} + +func (m *Manager) nextModelPoolOffset(key string, size int) int { + if m == nil || size <= 1 { + return 0 + } + key = strings.TrimSpace(key) + if key == "" { + return 0 + } + m.mu.Lock() + defer m.mu.Unlock() + if m.modelPoolOffsets == nil { + m.modelPoolOffsets = make(map[string]int) + } + offset := m.modelPoolOffsets[key] + if offset >= 2_147_483_640 { + offset = 0 + } + 
m.modelPoolOffsets[key] = offset + 1 + if size <= 0 { + return 0 + } + return offset % size +} + +func rotateStrings(values []string, offset int) []string { + if len(values) <= 1 { + return values + } + if offset <= 0 { + out := make([]string, len(values)) + copy(out, values) + return out + } + offset = offset % len(values) + out := make([]string, 0, len(values)) + out = append(out, values[offset:]...) + out = append(out, values[:offset]...) + return out +} + +func (m *Manager) resolveOpenAICompatUpstreamModelPool(auth *Auth, requestedModel string) []string { + if m == nil || !isOpenAICompatAPIKeyAuth(auth) { + return nil + } + requestedModel = strings.TrimSpace(requestedModel) + if requestedModel == "" { + return nil + } + cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config) + if cfg == nil { + cfg = &internalconfig.Config{} + } + providerKey := "" + compatName := "" + if auth.Attributes != nil { + providerKey = strings.TrimSpace(auth.Attributes["provider_key"]) + compatName = strings.TrimSpace(auth.Attributes["compat_name"]) + } + entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider) + if entry == nil { + return nil + } + return resolveModelAliasPoolFromConfigModels(requestedModel, asModelAliasEntries(entry.Models)) +} + +func preserveRequestedModelSuffix(requestedModel, resolved string) string { + return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel)) +} + +func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string { + return m.prepareExecutionModels(auth, routeModel) +} + +func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string { + requestedModel := rewriteModelForAuth(routeModel, auth) + requestedModel = m.applyOAuthModelAlias(auth, requestedModel) + if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 { + if len(pool) == 1 { + return pool + } + offset := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, 
requestedModel), len(pool)) + return rotateStrings(pool, offset) + } + resolved := m.applyAPIKeyModelAlias(auth, requestedModel) + if strings.TrimSpace(resolved) == "" { + resolved = requestedModel + } + return []string{resolved} +} + +func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) { + if ch == nil { + return + } + go func() { + for range ch { + } + }() +} + +func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) { + if ch == nil { + return nil, true, nil + } + buffered := make([]cliproxyexecutor.StreamChunk, 0, 1) + for { + var ( + chunk cliproxyexecutor.StreamChunk + ok bool + ) + if ctx != nil { + select { + case <-ctx.Done(): + return nil, false, ctx.Err() + case chunk, ok = <-ch: + } + } else { + chunk, ok = <-ch + } + if !ok { + return buffered, true, nil + } + if chunk.Err != nil { + return nil, false, chunk.Err + } + buffered = append(buffered, chunk) + if len(chunk.Payload) > 0 { + return buffered, false, nil + } + } +} + +func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult { + out := make(chan cliproxyexecutor.StreamChunk) + go func() { + defer close(out) + var failed bool + forward := true + emit := func(chunk cliproxyexecutor.StreamChunk) bool { + if chunk.Err != nil && !failed { + failed = true + rerr := &Error{Message: chunk.Err.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}) + } + if !forward { + return false + } + if ctx == nil { + out <- chunk + return true + } + select { + case <-ctx.Done(): + forward = false + return false + case out <- chunk: + return true + } + } + for 
_, chunk := range buffered { + if ok := emit(chunk); !ok { + discardStreamChunks(remaining) + return + } + } + for chunk := range remaining { + if ok := emit(chunk); !ok { + discardStreamChunks(remaining) + return + } + } + if !failed { + m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true}) + } + }() + return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out} +} + +func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) { + if executor == nil { + return nil, &Error{Code: "executor_not_found", Message: "executor not registered"} + } + execModels := m.prepareExecutionModels(auth, routeModel) + var lastErr error + for idx, execModel := range execModels { + execReq := req + execReq.Model = execModel + streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts) + if errStream != nil { + if errCtx := ctx.Err(); errCtx != nil { + return nil, errCtx + } + rerr := &Error{Message: errStream.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} + result.RetryAfter = retryAfterFromError(errStream) + m.MarkResult(ctx, result) + if isRequestInvalidError(errStream) { + return nil, errStream + } + lastErr = errStream + continue + } + + buffered, closed, bootstrapErr := readStreamBootstrap(ctx, streamResult.Chunks) + if bootstrapErr != nil { + if errCtx := ctx.Err(); errCtx != nil { + discardStreamChunks(streamResult.Chunks) + return nil, errCtx + } + if isRequestInvalidError(bootstrapErr) { + rerr := &Error{Message: bootstrapErr.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil { + 
rerr.HTTPStatus = se.StatusCode() + } + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} + result.RetryAfter = retryAfterFromError(bootstrapErr) + m.MarkResult(ctx, result) + discardStreamChunks(streamResult.Chunks) + return nil, bootstrapErr + } + if idx < len(execModels)-1 { + rerr := &Error{Message: bootstrapErr.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil { + rerr.HTTPStatus = se.StatusCode() + } + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} + result.RetryAfter = retryAfterFromError(bootstrapErr) + m.MarkResult(ctx, result) + discardStreamChunks(streamResult.Chunks) + lastErr = bootstrapErr + continue + } + errCh := make(chan cliproxyexecutor.StreamChunk, 1) + errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr} + close(errCh) + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil + } + + if closed && len(buffered) == 0 { + emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true} + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr} + m.MarkResult(ctx, result) + if idx < len(execModels)-1 { + lastErr = emptyErr + continue + } + errCh := make(chan cliproxyexecutor.StreamChunk, 1) + errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr} + close(errCh) + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil + } + + remaining := streamResult.Chunks + if closed { + closedCh := make(chan cliproxyexecutor.StreamChunk) + close(closedCh) + remaining = closedCh + } + return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil + } + if lastErr == nil { + lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"} + } + 
return nil, lastErr } func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() { @@ -448,10 +808,14 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) { auth.ID = uuid.NewString() } auth.EnsureIndex() + authClone := auth.Clone() m.mu.Lock() - m.auths[auth.ID] = auth.Clone() + m.auths[auth.ID] = authClone m.mu.Unlock() m.rebuildAPIKeyModelAliasFromRuntimeConfig() + if m.scheduler != nil { + m.scheduler.upsertAuth(authClone) + } _ = m.persist(ctx, auth) m.hook.OnAuthRegistered(ctx, auth.Clone()) return auth.Clone(), nil @@ -473,9 +837,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) { } } auth.EnsureIndex() - m.auths[auth.ID] = auth.Clone() + authClone := auth.Clone() + m.auths[auth.ID] = authClone m.mu.Unlock() m.rebuildAPIKeyModelAliasFromRuntimeConfig() + if m.scheduler != nil { + m.scheduler.upsertAuth(authClone) + } _ = m.persist(ctx, auth) m.hook.OnAuthUpdated(ctx, auth.Clone()) return auth.Clone(), nil @@ -484,12 +852,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) { // Load resets manager state from the backing store. 
func (m *Manager) Load(ctx context.Context) error { m.mu.Lock() - defer m.mu.Unlock() if m.store == nil { + m.mu.Unlock() return nil } items, err := m.store.List(ctx) if err != nil { + m.mu.Unlock() return err } m.auths = make(map[string]*Auth, len(items)) @@ -505,6 +874,8 @@ func (m *Manager) Load(ctx context.Context) error { cfg = &internalconfig.Config{} } m.rebuildAPIKeyModelAliasLocked(cfg) + m.mu.Unlock() + m.syncScheduler() return nil } @@ -634,32 +1005,42 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - execReq := req - execReq.Model = rewriteModelForAuth(routeModel, auth) - execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) - execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - resp, errExec := executor.Execute(execCtx, auth, execReq, opts) - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} - if errExec != nil { - if errCtx := execCtx.Err(); errCtx != nil { - return cliproxyexecutor.Response{}, errCtx - } - result.Error = &Error{Message: errExec.Error()} - if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil { - result.Error.HTTPStatus = se.StatusCode() - } - if ra := retryAfterFromError(errExec); ra != nil { - result.RetryAfter = ra + + models := m.prepareExecutionModels(auth, routeModel) + var authErr error + for _, upstreamModel := range models { + execReq := req + execReq.Model = upstreamModel + resp, errExec := executor.Execute(execCtx, auth, execReq, opts) + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} + if errExec != nil { + if errCtx := execCtx.Err(); errCtx != nil { + return cliproxyexecutor.Response{}, errCtx + } + result.Error = &Error{Message: errExec.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); 
ok && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + if ra := retryAfterFromError(errExec); ra != nil { + result.RetryAfter = ra + } + m.MarkResult(execCtx, result) + if isRequestInvalidError(errExec) { + return cliproxyexecutor.Response{}, errExec + } + authErr = errExec + continue } m.MarkResult(execCtx, result) - if isRequestInvalidError(errExec) { - return cliproxyexecutor.Response{}, errExec + return resp, nil + } + if authErr != nil { + if isRequestInvalidError(authErr) { + return cliproxyexecutor.Response{}, authErr } - lastErr = errExec + lastErr = authErr continue } - m.MarkResult(execCtx, result) - return resp, nil } } @@ -696,32 +1077,42 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - execReq := req - execReq.Model = rewriteModelForAuth(routeModel, auth) - execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) - execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} - if errExec != nil { - if errCtx := execCtx.Err(); errCtx != nil { - return cliproxyexecutor.Response{}, errCtx - } - result.Error = &Error{Message: errExec.Error()} - if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil { - result.Error.HTTPStatus = se.StatusCode() - } - if ra := retryAfterFromError(errExec); ra != nil { - result.RetryAfter = ra + + models := m.prepareExecutionModels(auth, routeModel) + var authErr error + for _, upstreamModel := range models { + execReq := req + execReq.Model = upstreamModel + resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} + if errExec != nil { + 
if errCtx := execCtx.Err(); errCtx != nil { + return cliproxyexecutor.Response{}, errCtx + } + result.Error = &Error{Message: errExec.Error()} + if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + if ra := retryAfterFromError(errExec); ra != nil { + result.RetryAfter = ra + } + m.hook.OnResult(execCtx, result) + if isRequestInvalidError(errExec) { + return cliproxyexecutor.Response{}, errExec + } + authErr = errExec + continue } m.hook.OnResult(execCtx, result) - if isRequestInvalidError(errExec) { - return cliproxyexecutor.Response{}, errExec + return resp, nil + } + if authErr != nil { + if isRequestInvalidError(authErr) { + return cliproxyexecutor.Response{}, authErr } - lastErr = errExec + lastErr = authErr continue } - m.hook.OnResult(execCtx, result) - return resp, nil } } @@ -758,63 +1149,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - execReq := req - execReq.Model = rewriteModelForAuth(routeModel, auth) - execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model) - execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model) - streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) + streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel) if errStream != nil { if errCtx := execCtx.Err(); errCtx != nil { return nil, errCtx } - rerr := &Error{Message: errStream.Error()} - if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil { - rerr.HTTPStatus = se.StatusCode() - } - result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} - result.RetryAfter = retryAfterFromError(errStream) - m.MarkResult(execCtx, result) if isRequestInvalidError(errStream) { 
return nil, errStream } lastErr = errStream continue } - out := make(chan cliproxyexecutor.StreamChunk) - go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) { - defer close(out) - var failed bool - forward := true - for chunk := range streamChunks { - if chunk.Err != nil && !failed { - failed = true - rerr := &Error{Message: chunk.Err.Error()} - if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil { - rerr.HTTPStatus = se.StatusCode() - } - m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr}) - } - if !forward { - continue - } - if streamCtx == nil { - out <- chunk - continue - } - select { - case <-streamCtx.Done(): - forward = false - case out <- chunk: - } - } - if !failed { - m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) - } - }(execCtx, auth.Clone(), provider, streamResult.Chunks) - return &cliproxyexecutor.StreamResult{ - Headers: streamResult.Headers, - Chunks: out, - }, nil + return streamResult, nil } } @@ -1245,6 +1591,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { suspendReason := "" clearModelQuota := false setModelQuota := false + var authSnapshot *Auth m.mu.Lock() if auth, ok := m.auths[result.AuthID]; ok && auth != nil { @@ -1338,8 +1685,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) { } _ = m.persist(ctx, auth) + authSnapshot = auth.Clone() } m.mu.Unlock() + if m.scheduler != nil && authSnapshot != nil { + m.scheduler.upsertAuth(authSnapshot) + } if clearModelQuota && result.Model != "" { registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model) @@ -1533,18 +1884,22 @@ func statusCodeFromResult(err *Error) int { } // isRequestInvalidError returns true if the error represents a client request -// error that should not be retried. 
Specifically, it checks for 400 Bad Request -// with "invalid_request_error" in the message, indicating the request itself is -// malformed and switching to a different auth will not help. +// error that should not be retried. Specifically, it treats 400 responses with +// "invalid_request_error" and all 422 responses as request-shape failures, +// where switching auths or pooled upstream models will not help. func isRequestInvalidError(err error) bool { if err == nil { return false } status := statusCodeFromError(err) - if status != http.StatusBadRequest { + switch status { + case http.StatusBadRequest: + return strings.Contains(err.Error(), "invalid_request_error") + case http.StatusUnprocessableEntity: + return true + default: return false } - return strings.Contains(err.Error(), "invalid_request_error") } func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) { @@ -1692,7 +2047,29 @@ func (m *Manager) CloseExecutionSession(sessionID string) { } } -func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) { +func (m *Manager) useSchedulerFastPath() bool { + if m == nil || m.scheduler == nil { + return false + } + return isBuiltInSelector(m.selector) +} + +func shouldRetrySchedulerPick(err error) bool { + if err == nil { + return false + } + var cooldownErr *modelCooldownError + if errors.As(err, &cooldownErr) { + return true + } + var authErr *Error + if !errors.As(err, &authErr) || authErr == nil { + return false + } + return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable" +} + +func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) { pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) m.mu.RLock() @@ -1752,7 +2129,38 @@ func (m *Manager) pickNext(ctx context.Context, 
provider, model string, opts cli return authCopy, executor, nil } -func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) { +func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) { + if !m.useSchedulerFastPath() { + return m.pickNextLegacy(ctx, provider, model, opts, tried) + } + executor, okExecutor := m.Executor(provider) + if !okExecutor { + return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"} + } + selected, errPick := m.scheduler.pickSingle(ctx, provider, model, opts, tried) + if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) { + m.syncScheduler() + selected, errPick = m.scheduler.pickSingle(ctx, provider, model, opts, tried) + } + if errPick != nil { + return nil, nil, errPick + } + if selected == nil { + return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"} + } + authCopy := selected.Clone() + if !selected.indexAssigned { + m.mu.Lock() + if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned { + current.EnsureIndex() + authCopy = current.Clone() + } + m.mu.Unlock() + } + return authCopy, executor, nil +} + +func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) { pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) providerSet := make(map[string]struct{}, len(providers)) @@ -1835,6 +2243,58 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s return authCopy, executor, providerKey, nil } +func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, 
ProviderExecutor, string, error) { + if !m.useSchedulerFastPath() { + return m.pickNextMixedLegacy(ctx, providers, model, opts, tried) + } + + eligibleProviders := make([]string, 0, len(providers)) + seenProviders := make(map[string]struct{}, len(providers)) + for _, provider := range providers { + providerKey := strings.TrimSpace(strings.ToLower(provider)) + if providerKey == "" { + continue + } + if _, seen := seenProviders[providerKey]; seen { + continue + } + if _, okExecutor := m.Executor(providerKey); !okExecutor { + continue + } + seenProviders[providerKey] = struct{}{} + eligibleProviders = append(eligibleProviders, providerKey) + } + if len(eligibleProviders) == 0 { + return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"} + } + + selected, providerKey, errPick := m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried) + if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) { + m.syncScheduler() + selected, providerKey, errPick = m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried) + } + if errPick != nil { + return nil, nil, "", errPick + } + if selected == nil { + return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"} + } + executor, okExecutor := m.Executor(providerKey) + if !okExecutor { + return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"} + } + authCopy := selected.Clone() + if !selected.indexAssigned { + m.mu.Lock() + if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned { + current.EnsureIndex() + authCopy = current.Clone() + } + m.mu.Unlock() + } + return authCopy, executor, providerKey, nil +} + func (m *Manager) persist(ctx context.Context, auth *Auth) error { if m.store == nil || auth == nil { return nil @@ -2186,6 +2646,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) { current.NextRefreshAfter = now.Add(refreshFailureBackoff) current.LastError = &Error{Message: 
err.Error()} m.auths[id] = current + if m.scheduler != nil { + m.scheduler.upsertAuth(current.Clone()) + } } m.mu.Unlock() return diff --git a/sdk/cliproxy/auth/conductor_scheduler_refresh_test.go b/sdk/cliproxy/auth/conductor_scheduler_refresh_test.go new file mode 100644 index 00000000..5c6eff78 --- /dev/null +++ b/sdk/cliproxy/auth/conductor_scheduler_refresh_test.go @@ -0,0 +1,163 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +type schedulerProviderTestExecutor struct { + provider string +} + +func (e schedulerProviderTestExecutor) Identifier() string { return e.provider } + +func (e schedulerProviderTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, nil +} + +func (e schedulerProviderTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + return nil, nil +} + +func (e schedulerProviderTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (e schedulerProviderTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, nil +} + +func (e schedulerProviderTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) { + return nil, nil +} + +func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) { + ctx := context.Background() + + testCases := []struct { + name string + prime func(*Manager, *Auth) error + }{ + { + name: "register", + prime: func(manager *Manager, auth *Auth) 
error { + _, errRegister := manager.Register(ctx, auth) + return errRegister + }, + }, + { + name: "update", + prime: func(manager *Manager, auth *Auth) error { + _, errRegister := manager.Register(ctx, auth) + if errRegister != nil { + return errRegister + } + updated := auth.Clone() + updated.Metadata = map[string]any{"updated": true} + _, errUpdate := manager.Update(ctx, updated) + return errUpdate + }, + }, + } + + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + manager := NewManager(nil, &RoundRobinSelector{}, nil) + auth := &Auth{ + ID: "refresh-entry-" + testCase.name, + Provider: "gemini", + } + if errPrime := testCase.prime(manager, auth); errPrime != nil { + t.Fatalf("prime auth %s: %v", testCase.name, errPrime) + } + + registerSchedulerModels(t, "gemini", "scheduler-refresh-model", auth.ID) + + got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil) + var authErr *Error + if !errors.As(errPick, &authErr) || authErr == nil { + t.Fatalf("pickSingle() before refresh error = %v, want auth_not_found", errPick) + } + if authErr.Code != "auth_not_found" { + t.Fatalf("pickSingle() before refresh code = %q, want %q", authErr.Code, "auth_not_found") + } + if got != nil { + t.Fatalf("pickSingle() before refresh auth = %v, want nil", got) + } + + manager.RefreshSchedulerEntry(auth.ID) + + got, errPick = manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickSingle() after refresh error = %v", errPick) + } + if got == nil || got.ID != auth.ID { + t.Fatalf("pickSingle() after refresh auth = %v, want %q", got, auth.ID) + } + }) + } +} + +func TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError(t *testing.T) { + ctx := context.Background() + manager := NewManager(nil, &RoundRobinSelector{}, nil) + 
manager.RegisterExecutor(schedulerProviderTestExecutor{provider: "gemini"}) + + registerSchedulerModels(t, "gemini", "scheduler-cooldown-rebuild-model", "cooldown-stale-old") + + oldAuth := &Auth{ + ID: "cooldown-stale-old", + Provider: "gemini", + } + if _, errRegister := manager.Register(ctx, oldAuth); errRegister != nil { + t.Fatalf("register old auth: %v", errRegister) + } + + manager.MarkResult(ctx, Result{ + AuthID: oldAuth.ID, + Provider: "gemini", + Model: "scheduler-cooldown-rebuild-model", + Success: false, + Error: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}, + }) + + newAuth := &Auth{ + ID: "cooldown-stale-new", + Provider: "gemini", + } + if _, errRegister := manager.Register(ctx, newAuth); errRegister != nil { + t.Fatalf("register new auth: %v", errRegister) + } + + reg := registry.GetGlobalRegistry() + reg.RegisterClient(newAuth.ID, "gemini", []*registry.ModelInfo{{ID: "scheduler-cooldown-rebuild-model"}}) + t.Cleanup(func() { + reg.UnregisterClient(newAuth.ID) + }) + + got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil) + var cooldownErr *modelCooldownError + if !errors.As(errPick, &cooldownErr) { + t.Fatalf("pickSingle() before sync error = %v, want modelCooldownError", errPick) + } + if got != nil { + t.Fatalf("pickSingle() before sync auth = %v, want nil", got) + } + + got, executor, errPick := manager.pickNext(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickNext() error = %v", errPick) + } + if executor == nil { + t.Fatal("pickNext() executor = nil") + } + if got == nil || got.ID != newAuth.ID { + t.Fatalf("pickNext() auth = %v, want %q", got, newAuth.ID) + } +} diff --git a/sdk/cliproxy/auth/oauth_model_alias.go b/sdk/cliproxy/auth/oauth_model_alias.go index d5d2ff8a..77a11c19 100644 --- a/sdk/cliproxy/auth/oauth_model_alias.go +++ b/sdk/cliproxy/auth/oauth_model_alias.go @@ 
-80,54 +80,98 @@ func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string return upstreamModel } -func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string { +func modelAliasLookupCandidates(requestedModel string) (thinking.SuffixResult, []string) { requestedModel = strings.TrimSpace(requestedModel) if requestedModel == "" { - return "" + return thinking.SuffixResult{}, nil } - if len(models) == 0 { - return "" - } - requestResult := thinking.ParseSuffix(requestedModel) base := requestResult.ModelName + if base == "" { + base = requestedModel + } candidates := []string{base} if base != requestedModel { candidates = append(candidates, requestedModel) } + return requestResult, candidates +} - preserveSuffix := func(resolved string) string { - resolved = strings.TrimSpace(resolved) - if resolved == "" { - return "" - } - if thinking.ParseSuffix(resolved).HasSuffix { - return resolved - } - if requestResult.HasSuffix && requestResult.RawSuffix != "" { - return resolved + "(" + requestResult.RawSuffix + ")" - } +func preserveResolvedModelSuffix(resolved string, requestResult thinking.SuffixResult) string { + resolved = strings.TrimSpace(resolved) + if resolved == "" { + return "" + } + if thinking.ParseSuffix(resolved).HasSuffix { return resolved } + if requestResult.HasSuffix && requestResult.RawSuffix != "" { + return resolved + "(" + requestResult.RawSuffix + ")" + } + return resolved +} +func resolveModelAliasPoolFromConfigModels(requestedModel string, models []modelAliasEntry) []string { + requestedModel = strings.TrimSpace(requestedModel) + if requestedModel == "" { + return nil + } + if len(models) == 0 { + return nil + } + + requestResult, candidates := modelAliasLookupCandidates(requestedModel) + if len(candidates) == 0 { + return nil + } + + out := make([]string, 0) + seen := make(map[string]struct{}) for i := range models { name := strings.TrimSpace(models[i].GetName()) alias := 
strings.TrimSpace(models[i].GetAlias()) for _, candidate := range candidates { - if candidate == "" { + if candidate == "" || alias == "" || !strings.EqualFold(alias, candidate) { continue } - if alias != "" && strings.EqualFold(alias, candidate) { - if name != "" { - return preserveSuffix(name) - } - return preserveSuffix(candidate) + resolved := candidate + if name != "" { + resolved = name } - if name != "" && strings.EqualFold(name, candidate) { - return preserveSuffix(name) + resolved = preserveResolvedModelSuffix(resolved, requestResult) + key := strings.ToLower(strings.TrimSpace(resolved)) + if key == "" { + break } + if _, exists := seen[key]; exists { + break + } + seen[key] = struct{}{} + out = append(out, resolved) + break } } + if len(out) > 0 { + return out + } + + for i := range models { + name := strings.TrimSpace(models[i].GetName()) + for _, candidate := range candidates { + if candidate == "" || name == "" || !strings.EqualFold(name, candidate) { + continue + } + return []string{preserveResolvedModelSuffix(name, requestResult)} + } + } + return nil +} + +func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string { + resolved := resolveModelAliasPoolFromConfigModels(requestedModel, models) + if len(resolved) > 0 { + return resolved[0] + } return "" } diff --git a/sdk/cliproxy/auth/openai_compat_pool_test.go b/sdk/cliproxy/auth/openai_compat_pool_test.go new file mode 100644 index 00000000..5a5ecb4f --- /dev/null +++ b/sdk/cliproxy/auth/openai_compat_pool_test.go @@ -0,0 +1,419 @@ +package auth + +import ( + "context" + "net/http" + "sync" + "testing" + + internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +type openAICompatPoolExecutor struct { + id string + + mu sync.Mutex + executeModels []string + countModels []string + streamModels []string 
+ executeErrors map[string]error + countErrors map[string]error + streamFirstErrors map[string]error + streamPayloads map[string][]cliproxyexecutor.StreamChunk +} + +func (e *openAICompatPoolExecutor) Identifier() string { return e.id } + +func (e *openAICompatPoolExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + _ = ctx + _ = auth + _ = opts + e.mu.Lock() + e.executeModels = append(e.executeModels, req.Model) + err := e.executeErrors[req.Model] + e.mu.Unlock() + if err != nil { + return cliproxyexecutor.Response{}, err + } + return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil +} + +func (e *openAICompatPoolExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + _ = ctx + _ = auth + _ = opts + e.mu.Lock() + e.streamModels = append(e.streamModels, req.Model) + err := e.streamFirstErrors[req.Model] + payloadChunks, hasCustomChunks := e.streamPayloads[req.Model] + chunks := append([]cliproxyexecutor.StreamChunk(nil), payloadChunks...) 
+ e.mu.Unlock() + ch := make(chan cliproxyexecutor.StreamChunk, max(1, len(chunks))) + if err != nil { + ch <- cliproxyexecutor.StreamChunk{Err: err} + close(ch) + return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil + } + if !hasCustomChunks { + ch <- cliproxyexecutor.StreamChunk{Payload: []byte(req.Model)} + } else { + for _, chunk := range chunks { + ch <- chunk + } + } + close(ch) + return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil +} + +func (e *openAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (e *openAICompatPoolExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + _ = ctx + _ = auth + _ = opts + e.mu.Lock() + e.countModels = append(e.countModels, req.Model) + err := e.countErrors[req.Model] + e.mu.Unlock() + if err != nil { + return cliproxyexecutor.Response{}, err + } + return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil +} + +func (e *openAICompatPoolExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) { + _ = ctx + _ = auth + _ = req + return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"} +} + +func (e *openAICompatPoolExecutor) ExecuteModels() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.executeModels)) + copy(out, e.executeModels) + return out +} + +func (e *openAICompatPoolExecutor) CountModels() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.countModels)) + copy(out, e.countModels) + return out +} + +func (e *openAICompatPoolExecutor) StreamModels() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.streamModels)) + copy(out, e.streamModels) + return out +} + +func newOpenAICompatPoolTestManager(t 
*testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager { + t.Helper() + cfg := &internalconfig.Config{ + OpenAICompatibility: []internalconfig.OpenAICompatibility{{ + Name: "pool", + Models: models, + }}, + } + m := NewManager(nil, nil, nil) + m.SetConfig(cfg) + if executor == nil { + executor = &openAICompatPoolExecutor{id: "pool"} + } + m.RegisterExecutor(executor) + + auth := &Auth{ + ID: "pool-auth-" + t.Name(), + Provider: "pool", + Status: StatusActive, + Attributes: map[string]string{ + "api_key": "test-key", + "compat_name": "pool", + "provider_key": "pool", + }, + } + if _, err := m.Register(context.Background(), auth); err != nil { + t.Fatalf("register auth: %v", err) + } + + reg := registry.GetGlobalRegistry() + reg.RegisterClient(auth.ID, "pool", []*registry.ModelInfo{{ID: alias}}) + t.Cleanup(func() { + reg.UnregisterClient(auth.ID) + }) + return m +} + +func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) { + alias := "claude-opus-4.66" + invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"} + executor := &openAICompatPoolExecutor{ + id: "pool", + countErrors: map[string]error{"qwen3.5-plus": invalidErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + _, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err == nil || err.Error() != invalidErr.Error() { + t.Fatalf("execute count error = %v, want %v", err, invalidErr) + } + got := executor.CountModels() + if len(got) != 1 || got[0] != "qwen3.5-plus" { + t.Fatalf("count calls = %v, want only first invalid model", got) + } +} +func TestResolveModelAliasPoolFromConfigModels(t *testing.T) { + models := []modelAliasEntry{ + 
internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: "claude-opus-4.66"}, + internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: "claude-opus-4.66"}, + internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: "claude-opus-4.66"}, + } + got := resolveModelAliasPoolFromConfigModels("claude-opus-4.66(8192)", models) + want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"} + if len(got) != len(want) { + t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("pool[%d] = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{id: "pool"} + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + for i := 0; i < 3; i++ { + resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute %d: %v", i, err) + } + if len(resp.Payload) == 0 { + t.Fatalf("execute %d returned empty payload", i) + } + } + + got := executor.ExecuteModels() + want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"} + if len(got) != len(want) { + t.Fatalf("execute calls = %v, want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) { + alias := "claude-opus-4.66" + invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"} + executor := &openAICompatPoolExecutor{ + id: "pool", + executeErrors: map[string]error{"qwen3.5-plus": invalidErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, 
[]internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + _, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err == nil || err.Error() != invalidErr.Error() { + t.Fatalf("execute error = %v, want %v", err, invalidErr) + } + got := executor.ExecuteModels() + if len(got) != 1 || got[0] != "qwen3.5-plus" { + t.Fatalf("execute calls = %v, want only first invalid model", got) + } +} +func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{ + id: "pool", + executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute: %v", err) + } + if string(resp.Payload) != "glm-5" { + t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5") + } + got := executor.ExecuteModels() + want := []string{"qwen3.5-plus", "glm-5"} + for i := range want { + if got[i] != want[i] { + t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{ + id: "pool", + streamPayloads: map[string][]cliproxyexecutor.StreamChunk{ + "qwen3.5-plus": {}, + }, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + streamResult, err := 
m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute stream: %v", err) + } + var payload []byte + for chunk := range streamResult.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream error: %v", chunk.Err) + } + payload = append(payload, chunk.Payload...) + } + if string(payload) != "glm-5" { + t.Fatalf("payload = %q, want %q", string(payload), "glm-5") + } + got := executor.StreamModels() + want := []string{"qwen3.5-plus", "glm-5"} + for i := range want { + if got[i] != want[i] { + t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{ + id: "pool", + streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute stream: %v", err) + } + var payload []byte + for chunk := range streamResult.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream error: %v", chunk.Err) + } + payload = append(payload, chunk.Payload...) 
+ } + if string(payload) != "glm-5" { + t.Fatalf("payload = %q, want %q", string(payload), "glm-5") + } + got := executor.StreamModels() + want := []string{"qwen3.5-plus", "glm-5"} + for i := range want { + if got[i] != want[i] { + t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i]) + } + } + if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" { + t.Fatalf("header X-Model = %q, want %q", gotHeader, "glm-5") + } +} + +func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) { + alias := "claude-opus-4.66" + invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"} + executor := &openAICompatPoolExecutor{ + id: "pool", + streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + _, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err == nil || err.Error() != invalidErr.Error() { + t.Fatalf("execute stream error = %v, want %v", err, invalidErr) + } + got := executor.StreamModels() + if len(got) != 1 || got[0] != "qwen3.5-plus" { + t.Fatalf("stream calls = %v, want only first invalid model", got) + } +} +func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) { + alias := "claude-opus-4.66" + executor := &openAICompatPoolExecutor{id: "pool"} + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + for i := 0; i < 2; i++ { + resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("execute count %d: %v", i, err) + } + if 
len(resp.Payload) == 0 { + t.Fatalf("execute count %d returned empty payload", i) + } + } + + got := executor.CountModels() + want := []string{"qwen3.5-plus", "glm-5"} + for i := range want { + if got[i] != want[i] { + t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i]) + } + } +} + +func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) { + alias := "claude-opus-4.66" + invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"} + executor := &openAICompatPoolExecutor{ + id: "pool", + streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr}, + } + m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{ + {Name: "qwen3.5-plus", Alias: alias}, + {Name: "glm-5", Alias: alias}, + }, executor) + + streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}) + if err == nil { + t.Fatal("expected invalid request error") + } + if err != invalidErr { + t.Fatalf("error = %v, want %v", err, invalidErr) + } + if streamResult != nil { + t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult) + } + if got := executor.StreamModels(); len(got) != 1 || got[0] != "qwen3.5-plus" { + t.Fatalf("stream calls = %v, want only first upstream model", got) + } +} diff --git a/sdk/cliproxy/auth/scheduler.go b/sdk/cliproxy/auth/scheduler.go new file mode 100644 index 00000000..bfff53bf --- /dev/null +++ b/sdk/cliproxy/auth/scheduler.go @@ -0,0 +1,904 @@ +package auth + +import ( + "context" + "sort" + "strings" + "sync" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +// schedulerStrategy identifies which built-in routing semantics the scheduler should apply. 
+type schedulerStrategy int + +const ( + schedulerStrategyCustom schedulerStrategy = iota + schedulerStrategyRoundRobin + schedulerStrategyFillFirst +) + +// scheduledState describes how an auth currently participates in a model shard. +type scheduledState int + +const ( + scheduledStateReady scheduledState = iota + scheduledStateCooldown + scheduledStateBlocked + scheduledStateDisabled +) + +// authScheduler keeps the incremental provider/model scheduling state used by Manager. +type authScheduler struct { + mu sync.Mutex + strategy schedulerStrategy + providers map[string]*providerScheduler + authProviders map[string]string + mixedCursors map[string]int +} + +// providerScheduler stores auth metadata and model shards for a single provider. +type providerScheduler struct { + providerKey string + auths map[string]*scheduledAuthMeta + modelShards map[string]*modelScheduler +} + +// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot. +type scheduledAuthMeta struct { + auth *Auth + providerKey string + priority int + virtualParent string + websocketEnabled bool + supportedModelSet map[string]struct{} +} + +// modelScheduler tracks ready and blocked auths for one provider/model combination. +type modelScheduler struct { + modelKey string + entries map[string]*scheduledAuth + priorityOrder []int + readyByPriority map[int]*readyBucket + blocked cooldownQueue +} + +// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard. +type scheduledAuth struct { + meta *scheduledAuthMeta + auth *Auth + state scheduledState + nextRetryAt time.Time +} + +// readyBucket keeps the ready views for one priority level. +type readyBucket struct { + all readyView + ws readyView +} + +// readyView holds the selection order for flat or grouped round-robin traversal. 
+type readyView struct { + flat []*scheduledAuth + cursor int + parentOrder []string + parentCursor int + children map[string]*childBucket +} + +// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths. +type childBucket struct { + items []*scheduledAuth + cursor int +} + +// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds. +type cooldownQueue []*scheduledAuth + +// newAuthScheduler constructs an empty scheduler configured for the supplied selector strategy. +func newAuthScheduler(selector Selector) *authScheduler { + return &authScheduler{ + strategy: selectorStrategy(selector), + providers: make(map[string]*providerScheduler), + authProviders: make(map[string]string), + mixedCursors: make(map[string]int), + } +} + +// selectorStrategy maps a selector implementation to the scheduler semantics it should emulate. +func selectorStrategy(selector Selector) schedulerStrategy { + switch selector.(type) { + case *FillFirstSelector: + return schedulerStrategyFillFirst + case nil, *RoundRobinSelector: + return schedulerStrategyRoundRobin + default: + return schedulerStrategyCustom + } +} + +// setSelector updates the active built-in strategy and resets mixed-provider cursors. +func (s *authScheduler) setSelector(selector Selector) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.strategy = selectorStrategy(selector) + clear(s.mixedCursors) +} + +// rebuild recreates the complete scheduler state from an auth snapshot. +func (s *authScheduler) rebuild(auths []*Auth) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.providers = make(map[string]*providerScheduler) + s.authProviders = make(map[string]string) + s.mixedCursors = make(map[string]int) + now := time.Now() + for _, auth := range auths { + s.upsertAuthLocked(auth, now) + } +} + +// upsertAuth incrementally synchronizes one auth into the scheduler. 
+func (s *authScheduler) upsertAuth(auth *Auth) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.upsertAuthLocked(auth, time.Now()) +} + +// removeAuth deletes one auth from every scheduler shard that references it. +func (s *authScheduler) removeAuth(authID string) { + if s == nil { + return + } + authID = strings.TrimSpace(authID) + if authID == "" { + return + } + s.mu.Lock() + defer s.mu.Unlock() + s.removeAuthLocked(authID) +} + +// pickSingle returns the next auth for a single provider/model request using scheduler state. +func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) { + if s == nil { + return nil, &Error{Code: "auth_not_found", Message: "no auth available"} + } + providerKey := strings.ToLower(strings.TrimSpace(provider)) + modelKey := canonicalModelKey(model) + pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) + preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == "" + + s.mu.Lock() + defer s.mu.Unlock() + providerState := s.providers[providerKey] + if providerState == nil { + return nil, &Error{Code: "auth_not_found", Message: "no auth available"} + } + shard := providerState.ensureModelLocked(modelKey, time.Now()) + if shard == nil { + return nil, &Error{Code: "auth_not_found", Message: "no auth available"} + } + predicate := func(entry *scheduledAuth) bool { + if entry == nil || entry.auth == nil { + return false + } + if pinnedAuthID != "" && entry.auth.ID != pinnedAuthID { + return false + } + if len(tried) > 0 { + if _, ok := tried[entry.auth.ID]; ok { + return false + } + } + return true + } + if picked := shard.pickReadyLocked(preferWebsocket, s.strategy, predicate); picked != nil { + return picked, nil + } + return nil, shard.unavailableErrorLocked(provider, model, predicate) +} + +// pickMixed returns the next auth and provider for a mixed-provider request. 
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
	if s == nil {
		return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	normalized := normalizeProviderKeys(providers)
	if len(normalized) == 0 {
		return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
	modelKey := canonicalModelKey(model)

	s.mu.Lock()
	defer s.mu.Unlock()
	if pinnedAuthID != "" {
		// Pinned requests bypass rotation: resolve the owning provider and try only that auth.
		providerKey := s.authProviders[pinnedAuthID]
		if providerKey == "" || !containsProvider(normalized, providerKey) {
			return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		providerState := s.providers[providerKey]
		if providerState == nil {
			return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		shard := providerState.ensureModelLocked(modelKey, time.Now())
		predicate := func(entry *scheduledAuth) bool {
			if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID {
				return false
			}
			if len(tried) == 0 {
				return true
			}
			_, ok := tried[pinnedAuthID]
			return !ok
		}
		if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil {
			return picked, providerKey, nil
		}
		return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
	}

	predicate := triedPredicate(tried)
	candidateShards := make([]*modelScheduler, len(normalized))
	bestPriority := 0
	hasCandidate := false
	now := time.Now()
	// First pass: determine the highest priority that has a ready candidate in any provider.
	for providerIndex, providerKey := range normalized {
		providerState := s.providers[providerKey]
		if providerState == nil {
			continue
		}
		shard := providerState.ensureModelLocked(modelKey, now)
		candidateShards[providerIndex] = shard
		if shard == nil {
			continue
		}
		priorityReady, okPriority := shard.highestReadyPriorityLocked(false, predicate)
		if !okPriority {
			continue
		}
		if !hasCandidate || priorityReady > bestPriority {
			bestPriority = priorityReady
			hasCandidate = true
		}
	}
	if !hasCandidate {
		return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
	}

	if s.strategy == schedulerStrategyFillFirst {
		// Fill-first: take the first provider (in request order) able to serve the best priority.
		for providerIndex, providerKey := range normalized {
			shard := candidateShards[providerIndex]
			if shard == nil {
				continue
			}
			picked := shard.pickReadyAtPriorityLocked(false, bestPriority, s.strategy, predicate)
			if picked != nil {
				return picked, providerKey, nil
			}
		}
		return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
	}

	// Round-robin: rotate across providers using a cursor keyed by provider set and model.
	cursorKey := strings.Join(normalized, ",") + ":" + modelKey
	start := 0
	if len(normalized) > 0 {
		start = s.mixedCursors[cursorKey] % len(normalized)
	}
	for offset := 0; offset < len(normalized); offset++ {
		providerIndex := (start + offset) % len(normalized)
		providerKey := normalized[providerIndex]
		shard := candidateShards[providerIndex]
		if shard == nil {
			continue
		}
		picked := shard.pickReadyAtPriorityLocked(false, bestPriority, schedulerStrategyRoundRobin, predicate)
		if picked == nil {
			continue
		}
		s.mixedCursors[cursorKey] = providerIndex + 1
		return picked, providerKey, nil
	}
	return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}

// mixedUnavailableErrorLocked synthesizes the mixed-provider cooldown or unavailable error.
func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error {
	now := time.Now()
	total := 0
	cooldownCount := 0
	earliest := time.Time{}
	// Aggregate availability across every provider shard for this model.
	for _, providerKey := range providers {
		providerState := s.providers[providerKey]
		if providerState == nil {
			continue
		}
		shard := providerState.ensureModelLocked(canonicalModelKey(model), now)
		if shard == nil {
			continue
		}
		localTotal, localCooldownCount, localEarliest := shard.availabilitySummaryLocked(triedPredicate(tried))
		total += localTotal
		cooldownCount += localCooldownCount
		if !localEarliest.IsZero() && (earliest.IsZero() || localEarliest.Before(earliest)) {
			earliest = localEarliest
		}
	}
	if total == 0 {
		return &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	// Only report a cooldown when every single candidate is cooling down.
	if cooldownCount == total && !earliest.IsZero() {
		resetIn := earliest.Sub(now)
		if resetIn < 0 {
			resetIn = 0
		}
		return newModelCooldownError(model, "", resetIn)
	}
	return &Error{Code: "auth_unavailable", Message: "no auth available"}
}

// triedPredicate builds a filter that excludes auths already attempted for the current request.
func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool {
	if len(tried) == 0 {
		return func(entry *scheduledAuth) bool { return entry != nil && entry.auth != nil }
	}
	return func(entry *scheduledAuth) bool {
		if entry == nil || entry.auth == nil {
			return false
		}
		_, ok := tried[entry.auth.ID]
		return !ok
	}
}

// normalizeProviderKeys lowercases, trims, and de-duplicates provider keys while preserving order.
+func normalizeProviderKeys(providers []string) []string { + seen := make(map[string]struct{}, len(providers)) + out := make([]string, 0, len(providers)) + for _, provider := range providers { + providerKey := strings.ToLower(strings.TrimSpace(provider)) + if providerKey == "" { + continue + } + if _, ok := seen[providerKey]; ok { + continue + } + seen[providerKey] = struct{}{} + out = append(out, providerKey) + } + return out +} + +// containsProvider reports whether provider is present in the normalized provider list. +func containsProvider(providers []string, provider string) bool { + for _, candidate := range providers { + if candidate == provider { + return true + } + } + return false +} + +// upsertAuthLocked updates one auth in-place while the scheduler mutex is held. +func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) { + if auth == nil { + return + } + authID := strings.TrimSpace(auth.ID) + providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) + if authID == "" || providerKey == "" || auth.Disabled { + s.removeAuthLocked(authID) + return + } + if previousProvider := s.authProviders[authID]; previousProvider != "" && previousProvider != providerKey { + if previousState := s.providers[previousProvider]; previousState != nil { + previousState.removeAuthLocked(authID) + } + } + meta := buildScheduledAuthMeta(auth) + s.authProviders[authID] = providerKey + s.ensureProviderLocked(providerKey).upsertAuthLocked(meta, now) +} + +// removeAuthLocked removes one auth from the scheduler while the scheduler mutex is held. +func (s *authScheduler) removeAuthLocked(authID string) { + if authID == "" { + return + } + if providerKey := s.authProviders[authID]; providerKey != "" { + if providerState := s.providers[providerKey]; providerState != nil { + providerState.removeAuthLocked(authID) + } + delete(s.authProviders, authID) + } +} + +// ensureProviderLocked returns the provider scheduler for providerKey, creating it when needed. 
+func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler { + if s.providers == nil { + s.providers = make(map[string]*providerScheduler) + } + providerState := s.providers[providerKey] + if providerState == nil { + providerState = &providerScheduler{ + providerKey: providerKey, + auths: make(map[string]*scheduledAuthMeta), + modelShards: make(map[string]*modelScheduler), + } + s.providers[providerKey] = providerState + } + return providerState +} + +// buildScheduledAuthMeta extracts the scheduling metadata needed for shard bookkeeping. +func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta { + providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) + virtualParent := "" + if auth.Attributes != nil { + virtualParent = strings.TrimSpace(auth.Attributes["gemini_virtual_parent"]) + } + return &scheduledAuthMeta{ + auth: auth, + providerKey: providerKey, + priority: authPriority(auth), + virtualParent: virtualParent, + websocketEnabled: authWebsocketsEnabled(auth), + supportedModelSet: supportedModelSetForAuth(auth.ID), + } +} + +// supportedModelSetForAuth snapshots the registry models currently registered for an auth. +func supportedModelSetForAuth(authID string) map[string]struct{} { + authID = strings.TrimSpace(authID) + if authID == "" { + return nil + } + models := registry.GetGlobalRegistry().GetModelsForClient(authID) + if len(models) == 0 { + return nil + } + set := make(map[string]struct{}, len(models)) + for _, model := range models { + if model == nil { + continue + } + modelKey := canonicalModelKey(model.ID) + if modelKey == "" { + continue + } + set[modelKey] = struct{}{} + } + return set +} + +// upsertAuthLocked updates every existing model shard that can reference the auth metadata. 
+func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) { + if p == nil || meta == nil || meta.auth == nil { + return + } + p.auths[meta.auth.ID] = meta + for modelKey, shard := range p.modelShards { + if shard == nil { + continue + } + if !meta.supportsModel(modelKey) { + shard.removeEntryLocked(meta.auth.ID) + continue + } + shard.upsertEntryLocked(meta, now) + } +} + +// removeAuthLocked removes an auth from all model shards owned by the provider scheduler. +func (p *providerScheduler) removeAuthLocked(authID string) { + if p == nil || authID == "" { + return + } + delete(p.auths, authID) + for _, shard := range p.modelShards { + if shard != nil { + shard.removeEntryLocked(authID) + } + } +} + +// ensureModelLocked returns the shard for modelKey, building it lazily from provider auths. +func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler { + if p == nil { + return nil + } + modelKey = canonicalModelKey(modelKey) + if shard, ok := p.modelShards[modelKey]; ok && shard != nil { + shard.promoteExpiredLocked(now) + return shard + } + shard := &modelScheduler{ + modelKey: modelKey, + entries: make(map[string]*scheduledAuth), + readyByPriority: make(map[int]*readyBucket), + } + for _, meta := range p.auths { + if meta == nil || !meta.supportsModel(modelKey) { + continue + } + shard.upsertEntryLocked(meta, now) + } + p.modelShards[modelKey] = shard + return shard +} + +// supportsModel reports whether the auth metadata currently supports modelKey. +func (m *scheduledAuthMeta) supportsModel(modelKey string) bool { + modelKey = canonicalModelKey(modelKey) + if modelKey == "" { + return true + } + if len(m.supportedModelSet) == 0 { + return false + } + _, ok := m.supportedModelSet[modelKey] + return ok +} + +// upsertEntryLocked updates or inserts one auth entry and rebuilds indexes when ordering changes. 
func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) {
	if m == nil || meta == nil || meta.auth == nil {
		return
	}
	entry, ok := m.entries[meta.auth.ID]
	if !ok || entry == nil {
		entry = &scheduledAuth{}
		m.entries[meta.auth.ID] = entry
	}
	// Snapshot the ordering-relevant fields so the index rebuild can be skipped
	// when nothing that affects ready/blocked placement actually changed.
	previousState := entry.state
	previousNextRetryAt := entry.nextRetryAt
	previousPriority := 0
	previousParent := ""
	previousWebsocketEnabled := false
	if entry.meta != nil {
		previousPriority = entry.meta.priority
		previousParent = entry.meta.virtualParent
		previousWebsocketEnabled = entry.meta.websocketEnabled
	}

	entry.meta = meta
	entry.auth = meta.auth
	entry.nextRetryAt = time.Time{}
	blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now)
	switch {
	case !blocked:
		entry.state = scheduledStateReady
	case reason == blockReasonCooldown:
		entry.state = scheduledStateCooldown
		entry.nextRetryAt = next
	case reason == blockReasonDisabled:
		entry.state = scheduledStateDisabled
	default:
		entry.state = scheduledStateBlocked
		entry.nextRetryAt = next
	}

	// Rebuild only for a brand-new entry or when some index-affecting field moved.
	if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled {
		return
	}
	m.rebuildIndexesLocked()
}

// removeEntryLocked deletes one auth entry and rebuilds the shard indexes if needed.
func (m *modelScheduler) removeEntryLocked(authID string) {
	if m == nil || authID == "" {
		return
	}
	if _, ok := m.entries[authID]; !ok {
		return
	}
	delete(m.entries, authID)
	m.rebuildIndexesLocked()
}

// promoteExpiredLocked reevaluates blocked auths whose retry time has elapsed.
+func (m *modelScheduler) promoteExpiredLocked(now time.Time) { + if m == nil || len(m.blocked) == 0 { + return + } + changed := false + for _, entry := range m.blocked { + if entry == nil || entry.auth == nil { + continue + } + if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) { + continue + } + blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now) + switch { + case !blocked: + entry.state = scheduledStateReady + entry.nextRetryAt = time.Time{} + case reason == blockReasonCooldown: + entry.state = scheduledStateCooldown + entry.nextRetryAt = next + case reason == blockReasonDisabled: + entry.state = scheduledStateDisabled + entry.nextRetryAt = time.Time{} + default: + entry.state = scheduledStateBlocked + entry.nextRetryAt = next + } + changed = true + } + if changed { + m.rebuildIndexesLocked() + } +} + +// pickReadyLocked selects the next ready auth from the highest available priority bucket. +func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth { + if m == nil { + return nil + } + m.promoteExpiredLocked(time.Now()) + priorityReady, okPriority := m.highestReadyPriorityLocked(preferWebsocket, predicate) + if !okPriority { + return nil + } + return m.pickReadyAtPriorityLocked(preferWebsocket, priorityReady, strategy, predicate) +} + +// highestReadyPriorityLocked returns the highest priority bucket that still has a matching ready auth. +// The caller must ensure expired entries are already promoted when needed. 
+func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predicate func(*scheduledAuth) bool) (int, bool) { + if m == nil { + return 0, false + } + for _, priority := range m.priorityOrder { + bucket := m.readyByPriority[priority] + if bucket == nil { + continue + } + view := &bucket.all + if preferWebsocket && len(bucket.ws.flat) > 0 { + view = &bucket.ws + } + if view.pickFirst(predicate) != nil { + return priority, true + } + } + return 0, false +} + +// pickReadyAtPriorityLocked selects the next ready auth from a specific priority bucket. +// The caller must ensure expired entries are already promoted when needed. +func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priority int, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth { + if m == nil { + return nil + } + bucket := m.readyByPriority[priority] + if bucket == nil { + return nil + } + view := &bucket.all + if preferWebsocket && len(bucket.ws.flat) > 0 { + view = &bucket.ws + } + var picked *scheduledAuth + if strategy == schedulerStrategyFillFirst { + picked = view.pickFirst(predicate) + } else { + picked = view.pickRoundRobin(predicate) + } + if picked == nil || picked.auth == nil { + return nil + } + return picked.auth +} + +// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard. 
+func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error { + now := time.Now() + total, cooldownCount, earliest := m.availabilitySummaryLocked(predicate) + if total == 0 { + return &Error{Code: "auth_not_found", Message: "no auth available"} + } + if cooldownCount == total && !earliest.IsZero() { + providerForError := provider + if providerForError == "mixed" { + providerForError = "" + } + resetIn := earliest.Sub(now) + if resetIn < 0 { + resetIn = 0 + } + return newModelCooldownError(model, providerForError, resetIn) + } + return &Error{Code: "auth_unavailable", Message: "no auth available"} +} + +// availabilitySummaryLocked summarizes total candidates, cooldown count, and earliest retry time. +func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) { + if m == nil { + return 0, 0, time.Time{} + } + total := 0 + cooldownCount := 0 + earliest := time.Time{} + for _, entry := range m.entries { + if predicate != nil && !predicate(entry) { + continue + } + total++ + if entry == nil || entry.auth == nil { + continue + } + if entry.state != scheduledStateCooldown { + continue + } + cooldownCount++ + if !entry.nextRetryAt.IsZero() && (earliest.IsZero() || entry.nextRetryAt.Before(earliest)) { + earliest = entry.nextRetryAt + } + } + return total, cooldownCount, earliest +} + +// rebuildIndexesLocked reconstructs ready and blocked views from the current entry map. 
func (m *modelScheduler) rebuildIndexesLocked() {
	m.readyByPriority = make(map[int]*readyBucket)
	m.priorityOrder = m.priorityOrder[:0]
	m.blocked = m.blocked[:0]
	priorityBuckets := make(map[int][]*scheduledAuth)
	for _, entry := range m.entries {
		if entry == nil || entry.auth == nil {
			continue
		}
		switch entry.state {
		case scheduledStateReady:
			priority := entry.meta.priority
			priorityBuckets[priority] = append(priorityBuckets[priority], entry)
		case scheduledStateCooldown, scheduledStateBlocked:
			m.blocked = append(m.blocked, entry)
		}
		// Other states (e.g. disabled) are intentionally indexed nowhere.
	}
	for priority, entries := range priorityBuckets {
		// Deterministic auth-ID order keeps rotation stable across rebuilds.
		sort.Slice(entries, func(i, j int) bool {
			return entries[i].auth.ID < entries[j].auth.ID
		})
		m.readyByPriority[priority] = buildReadyBucket(entries)
		m.priorityOrder = append(m.priorityOrder, priority)
	}
	// Highest priority first.
	sort.Slice(m.priorityOrder, func(i, j int) bool {
		return m.priorityOrder[i] > m.priorityOrder[j]
	})
	// Blocked entries are ordered by next retry time: zero times sort last,
	// ties break by auth ID for determinism.
	sort.Slice(m.blocked, func(i, j int) bool {
		left := m.blocked[i]
		right := m.blocked[j]
		if left == nil || right == nil {
			return left != nil
		}
		if left.nextRetryAt.Equal(right.nextRetryAt) {
			return left.auth.ID < right.auth.ID
		}
		if left.nextRetryAt.IsZero() {
			return false
		}
		if right.nextRetryAt.IsZero() {
			return true
		}
		return left.nextRetryAt.Before(right.nextRetryAt)
	})
}

// buildReadyBucket prepares the general and websocket-only ready views for one priority bucket.
func buildReadyBucket(entries []*scheduledAuth) *readyBucket {
	bucket := &readyBucket{}
	bucket.all = buildReadyView(entries)
	wsEntries := make([]*scheduledAuth, 0, len(entries))
	for _, entry := range entries {
		if entry != nil && entry.meta != nil && entry.meta.websocketEnabled {
			wsEntries = append(wsEntries, entry)
		}
	}
	bucket.ws = buildReadyView(wsEntries)
	return bucket
}

// buildReadyView creates either a flat view or a grouped parent/child view for rotation.
+func buildReadyView(entries []*scheduledAuth) readyView { + view := readyView{flat: append([]*scheduledAuth(nil), entries...)} + if len(entries) == 0 { + return view + } + groups := make(map[string][]*scheduledAuth) + for _, entry := range entries { + if entry == nil || entry.meta == nil || entry.meta.virtualParent == "" { + return view + } + groups[entry.meta.virtualParent] = append(groups[entry.meta.virtualParent], entry) + } + if len(groups) <= 1 { + return view + } + view.children = make(map[string]*childBucket, len(groups)) + view.parentOrder = make([]string, 0, len(groups)) + for parent := range groups { + view.parentOrder = append(view.parentOrder, parent) + } + sort.Strings(view.parentOrder) + for _, parent := range view.parentOrder { + view.children[parent] = &childBucket{items: append([]*scheduledAuth(nil), groups[parent]...)} + } + return view +} + +// pickFirst returns the first ready entry that satisfies predicate without advancing cursors. +func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth { + for _, entry := range v.flat { + if predicate == nil || predicate(entry) { + return entry + } + } + return nil +} + +// pickRoundRobin returns the next ready entry using flat or grouped round-robin traversal. +func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth { + if len(v.parentOrder) > 1 && len(v.children) > 0 { + return v.pickGroupedRoundRobin(predicate) + } + if len(v.flat) == 0 { + return nil + } + start := 0 + if len(v.flat) > 0 { + start = v.cursor % len(v.flat) + } + for offset := 0; offset < len(v.flat); offset++ { + index := (start + offset) % len(v.flat) + entry := v.flat[index] + if predicate != nil && !predicate(entry) { + continue + } + v.cursor = index + 1 + return entry + } + return nil +} + +// pickGroupedRoundRobin rotates across parents first and then within the selected parent. 
func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
	start := 0
	if len(v.parentOrder) > 0 {
		start = v.parentCursor % len(v.parentOrder)
	}
	for offset := 0; offset < len(v.parentOrder); offset++ {
		parentIndex := (start + offset) % len(v.parentOrder)
		parent := v.parentOrder[parentIndex]
		child := v.children[parent]
		if child == nil || len(child.items) == 0 {
			continue
		}
		itemStart := child.cursor % len(child.items)
		for itemOffset := 0; itemOffset < len(child.items); itemOffset++ {
			itemIndex := (itemStart + itemOffset) % len(child.items)
			entry := child.items[itemIndex]
			if predicate != nil && !predicate(entry) {
				continue
			}
			// Advance both cursors so the next pick moves to the following parent.
			child.cursor = itemIndex + 1
			v.parentCursor = parentIndex + 1
			return entry
		}
	}
	return nil
}
diff --git a/sdk/cliproxy/auth/scheduler_benchmark_test.go b/sdk/cliproxy/auth/scheduler_benchmark_test.go
new file mode 100644
index 00000000..050a7cbd
--- /dev/null
+++ b/sdk/cliproxy/auth/scheduler_benchmark_test.go
@@ -0,0 +1,216 @@
package auth

import (
	"context"
	"fmt"
	"net/http"
	"testing"

	"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)

// schedulerBenchmarkExecutor is a no-op executor stub used by the scheduler benchmarks.
type schedulerBenchmarkExecutor struct {
	id string
}

func (e schedulerBenchmarkExecutor) Identifier() string { return e.id }

func (e schedulerBenchmarkExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

func (e schedulerBenchmarkExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, nil
}

func (e schedulerBenchmarkExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

func (e schedulerBenchmarkExecutor)
CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

func (e schedulerBenchmarkExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
	return nil, nil
}

// benchmarkManagerSetup registers `total` auths (split across two providers when
// mixed, alternating priorities 10/0 when withPriority) plus their registry
// models, and returns the manager, provider keys, and the benchmark model ID.
// Registry registrations are removed via b.Cleanup.
func benchmarkManagerSetup(b *testing.B, total int, mixed bool, withPriority bool) (*Manager, []string, string) {
	b.Helper()
	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	providers := []string{"gemini"}
	manager.executors["gemini"] = schedulerBenchmarkExecutor{id: "gemini"}
	if mixed {
		providers = []string{"gemini", "claude"}
		manager.executors["claude"] = schedulerBenchmarkExecutor{id: "claude"}
	}

	reg := registry.GetGlobalRegistry()
	model := "bench-model"
	for index := 0; index < total; index++ {
		provider := providers[0]
		if mixed && index%2 == 1 {
			provider = providers[1]
		}
		auth := &Auth{ID: fmt.Sprintf("bench-%s-%04d", provider, index), Provider: provider}
		if withPriority {
			priority := "0"
			if index%2 == 0 {
				priority = "10"
			}
			auth.Attributes = map[string]string{"priority": priority}
		}
		_, errRegister := manager.Register(context.Background(), auth)
		if errRegister != nil {
			b.Fatalf("Register(%s) error = %v", auth.ID, errRegister)
		}
		reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: model}})
	}
	manager.syncScheduler()
	b.Cleanup(func() {
		for index := 0; index < total; index++ {
			provider := providers[0]
			if mixed && index%2 == 1 {
				provider = providers[1]
			}
			reg.UnregisterClient(fmt.Sprintf("bench-%s-%04d", provider, index))
		}
	})

	return manager, providers, model
}

// BenchmarkManagerPickNext500 measures single-provider pickNext over 500 auths.
func BenchmarkManagerPickNext500(b *testing.B) {
	manager, _, model := benchmarkManagerSetup(b, 500, false, false)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	// Warm up once so shard construction is excluded from the timed loop.
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}

// BenchmarkManagerPickNext1000 measures single-provider pickNext over 1000 auths.
func BenchmarkManagerPickNext1000(b *testing.B) {
	manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}

// BenchmarkManagerPickNextPriority500 adds mixed priorities to the 500-auth case.
func BenchmarkManagerPickNextPriority500(b *testing.B) {
	manager, _, model := benchmarkManagerSetup(b, 500, false, true)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}

// BenchmarkManagerPickNextPriority1000 adds mixed priorities to the 1000-auth case.
func BenchmarkManagerPickNextPriority1000(b *testing.B) {
	manager, _, model := benchmarkManagerSetup(b, 1000, false, true)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}

// BenchmarkManagerPickNextMixed500 measures two-provider pickNextMixed over 500 auths.
func BenchmarkManagerPickNextMixed500(b *testing.B) {
	manager, providers, model := benchmarkManagerSetup(b, 500, true, false)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNextMixed error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
		if errPick != nil || auth == nil || exec == nil || provider == "" {
			b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
		}
	}
}

// BenchmarkManagerPickNextMixedPriority500 adds mixed priorities to the mixed case.
func BenchmarkManagerPickNextMixedPriority500(b *testing.B) {
	manager, providers, model := benchmarkManagerSetup(b, 500, true, true)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNextMixed error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
		if errPick != nil || auth == nil || exec == nil || provider == "" {
			b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
		}
	}
}

// BenchmarkManagerPickNextAndMarkResult1000 measures the pick+MarkResult round trip.
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
	manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, _, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil {
			b.Fatalf("pickNext failed: auth=%v err=%v", auth, errPick)
		}
		manager.MarkResult(ctx, Result{AuthID: auth.ID, Provider: "gemini", Model: model, Success: true})
	}
}
diff --git a/sdk/cliproxy/auth/scheduler_test.go b/sdk/cliproxy/auth/scheduler_test.go
new file mode 100644
index 00000000..e7d435a9
--- /dev/null
+++ b/sdk/cliproxy/auth/scheduler_test.go
@@ -0,0 +1,503 @@
package auth

import (
	"context"
	"net/http"
	"testing"
	"time"

	"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)

// schedulerTestExecutor is a minimal no-op executor stub for scheduler tests.
type schedulerTestExecutor struct{}

func (schedulerTestExecutor) Identifier() string { return "test" }

func (schedulerTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

func (schedulerTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, nil
}

func (schedulerTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

func (schedulerTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

func (schedulerTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
	return nil, nil
}

type
trackingSelector struct {
	calls      int      // number of Pick invocations observed
	lastAuthID []string // auth IDs offered on the most recent Pick call
}

// Pick records the call and the candidate IDs, then returns the last candidate.
func (s *trackingSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
	s.calls++
	s.lastAuthID = s.lastAuthID[:0]
	for _, auth := range auths {
		s.lastAuthID = append(s.lastAuthID, auth.ID)
	}
	if len(auths) == 0 {
		return nil, nil
	}
	return auths[len(auths)-1], nil
}

// newSchedulerForTest builds a scheduler pre-populated with the given auths.
func newSchedulerForTest(selector Selector, auths ...*Auth) *authScheduler {
	scheduler := newAuthScheduler(selector)
	scheduler.rebuild(auths)
	return scheduler
}

// registerSchedulerModels registers one model for each auth ID in the global
// registry and schedules cleanup so tests stay isolated.
func registerSchedulerModels(t *testing.T, provider string, model string, authIDs ...string) {
	t.Helper()
	reg := registry.GetGlobalRegistry()
	for _, authID := range authIDs {
		reg.RegisterClient(authID, provider, []*registry.ModelInfo{{ID: model}})
	}
	t.Cleanup(func() {
		for _, authID := range authIDs {
			reg.UnregisterClient(authID)
		}
	})
}

func TestSchedulerPick_RoundRobinHighestPriority(t *testing.T) {
	t.Parallel()

	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "low", Provider: "gemini", Attributes: map[string]string{"priority": "0"}},
		&Auth{ID: "high-b", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
		&Auth{ID: "high-a", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
	)

	// Rotation stays inside the priority-10 bucket and never reaches "low".
	want := []string{"high-a", "high-b", "high-a"}
	for index, wantID := range want {
		got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickSingle() #%d error = %v", index, errPick)
		}
		if got == nil {
			t.Fatalf("pickSingle() #%d auth = nil", index)
		}
		if got.ID != wantID {
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
		}
	}
}

func TestSchedulerPick_FillFirstSticksToFirstReady(t *testing.T) {
	t.Parallel()

	scheduler := newSchedulerForTest(
		&FillFirstSelector{},
		&Auth{ID: "b", Provider: "gemini"},
		&Auth{ID: "a", Provider: "gemini"},
		&Auth{ID: "c", Provider: "gemini"},
	)

	// Fill-first always returns the first entry in auth-ID order.
	for index := 0; index < 3; index++ {
		got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickSingle() #%d error = %v", index, errPick)
		}
		if got == nil {
			t.Fatalf("pickSingle() #%d auth = nil", index)
		}
		if got.ID != "a" {
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, "a")
		}
	}
}

func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) {
	t.Parallel()

	model := "gemini-2.5-pro"
	registerSchedulerModels(t, "gemini", model, "cooldown-expired")
	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{
			ID:       "cooldown-expired",
			Provider: "gemini",
			ModelStates: map[string]*ModelState{
				model: {
					Status:         StatusError,
					Unavailable:    true,
					NextRetryAfter: time.Now().Add(-1 * time.Second),
				},
			},
		},
	)

	got, errPick := scheduler.pickSingle(context.Background(), "gemini", model, cliproxyexecutor.Options{}, nil)
	if errPick != nil {
		t.Fatalf("pickSingle() error = %v", errPick)
	}
	if got == nil {
		t.Fatalf("pickSingle() auth = nil")
	}
	if got.ID != "cooldown-expired" {
		t.Fatalf("pickSingle() auth.ID = %q, want %q", got.ID, "cooldown-expired")
	}
}

func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) {
	t.Parallel()

	registerSchedulerModels(t, "gemini-cli", "gemini-2.5-pro", "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2")
	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "cred-a::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
		&Auth{ID: "cred-a::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
		&Auth{ID: "cred-b::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
		&Auth{ID: "cred-b::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
	)

	// Rotation alternates parents first, then projects within each parent.
	wantParents := []string{"cred-a", "cred-b", "cred-a", "cred-b"}
	wantIDs := []string{"cred-a::proj-1", "cred-b::proj-1", "cred-a::proj-2", "cred-b::proj-2"}
	for index := range wantIDs {
		got, errPick := scheduler.pickSingle(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickSingle() #%d error = %v", index, errPick)
		}
		if got == nil {
			t.Fatalf("pickSingle() #%d auth = nil", index)
		}
		if got.ID != wantIDs[index] {
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
		}
		if got.Attributes["gemini_virtual_parent"] != wantParents[index] {
			t.Fatalf("pickSingle() #%d parent = %q, want %q", index, got.Attributes["gemini_virtual_parent"], wantParents[index])
		}
	}
}

func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
	t.Parallel()

	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "codex-http", Provider: "codex"},
		&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
		&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
	)

	// Downstream websocket requests rotate only within the websocket-enabled subset.
	ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
	want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"}
	for index, wantID := range want {
		got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickSingle() #%d error = %v", index, errPick)
		}
		if got == nil {
			t.Fatalf("pickSingle() #%d auth = nil", index)
		}
		if got.ID != wantID {
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
		}
	}
}

func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
	t.Parallel()

	scheduler :=
newSchedulerForTest( + &RoundRobinSelector{}, + &Auth{ID: "gemini-a", Provider: "gemini"}, + &Auth{ID: "gemini-b", Provider: "gemini"}, + &Auth{ID: "claude-a", Provider: "claude"}, + ) + + wantProviders := []string{"gemini", "claude", "gemini", "claude"} + wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"} + for index := range wantProviders { + got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickMixed() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickMixed() #%d auth = nil", index) + } + if provider != wantProviders[index] { + t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index]) + } + if got.ID != wantIDs[index] { + t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index]) + } + } +} + +func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) { + t.Parallel() + + model := "gpt-default" + registerSchedulerModels(t, "provider-low", model, "low") + registerSchedulerModels(t, "provider-high-a", model, "high-a") + registerSchedulerModels(t, "provider-high-b", model, "high-b") + + scheduler := newSchedulerForTest( + &RoundRobinSelector{}, + &Auth{ID: "low", Provider: "provider-low", Attributes: map[string]string{"priority": "4"}}, + &Auth{ID: "high-a", Provider: "provider-high-a", Attributes: map[string]string{"priority": "7"}}, + &Auth{ID: "high-b", Provider: "provider-high-b", Attributes: map[string]string{"priority": "7"}}, + ) + + providers := []string{"provider-low", "provider-high-a", "provider-high-b"} + wantProviders := []string{"provider-high-a", "provider-high-b", "provider-high-a", "provider-high-b"} + wantIDs := []string{"high-a", "high-b", "high-a", "high-b"} + for index := range wantProviders { + got, provider, errPick := scheduler.pickMixed(context.Background(), providers, model, cliproxyexecutor.Options{}, nil) + if 
errPick != nil { + t.Fatalf("pickMixed() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickMixed() #%d auth = nil", index) + } + if provider != wantProviders[index] { + t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index]) + } + if got.ID != wantIDs[index] { + t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index]) + } + } +} + +func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + manager.executors["gemini"] = schedulerTestExecutor{} + manager.executors["claude"] = schedulerTestExecutor{} + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(gemini-a) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(gemini-b) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil { + t.Fatalf("Register(claude-a) error = %v", errRegister) + } + + wantProviders := []string{"gemini", "claude", "gemini", "claude"} + wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"} + for index := range wantProviders { + got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{}) + if errPick != nil { + t.Fatalf("pickNextMixed() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickNextMixed() #%d auth = nil", index) + } + if provider != wantProviders[index] { + t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index]) + } + if got.ID != wantIDs[index] { + t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, 
wantIDs[index]) + } + } +} + +func TestManagerCustomSelector_FallsBackToLegacyPath(t *testing.T) { + t.Parallel() + + selector := &trackingSelector{} + manager := NewManager(nil, selector, nil) + manager.executors["gemini"] = schedulerTestExecutor{} + manager.auths["auth-a"] = &Auth{ID: "auth-a", Provider: "gemini"} + manager.auths["auth-b"] = &Auth{ID: "auth-b", Provider: "gemini"} + + got, _, errPick := manager.pickNext(context.Background(), "gemini", "", cliproxyexecutor.Options{}, map[string]struct{}{}) + if errPick != nil { + t.Fatalf("pickNext() error = %v", errPick) + } + if got == nil { + t.Fatalf("pickNext() auth = nil") + } + if selector.calls != 1 { + t.Fatalf("selector.calls = %d, want %d", selector.calls, 1) + } + if len(selector.lastAuthID) != 2 { + t.Fatalf("len(selector.lastAuthID) = %d, want %d", len(selector.lastAuthID), 2) + } + if got.ID != selector.lastAuthID[len(selector.lastAuthID)-1] { + t.Fatalf("pickNext() auth.ID = %q, want selector-picked %q", got.ID, selector.lastAuthID[len(selector.lastAuthID)-1]) + } +} + +func TestManager_InitializesSchedulerForBuiltInSelector(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + if manager.scheduler == nil { + t.Fatalf("manager.scheduler = nil") + } + if manager.scheduler.strategy != schedulerStrategyRoundRobin { + t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyRoundRobin) + } + + manager.SetSelector(&FillFirstSelector{}) + if manager.scheduler.strategy != schedulerStrategyFillFirst { + t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyFillFirst) + } +} + +func TestManager_SchedulerTracksRegisterAndUpdate(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(auth-b) error = %v", 
errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(auth-a) error = %v", errRegister) + } + + got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("scheduler.pickSingle() error = %v", errPick) + } + if got == nil || got.ID != "auth-a" { + t.Fatalf("scheduler.pickSingle() auth = %v, want auth-a", got) + } + + if _, errUpdate := manager.Update(context.Background(), &Auth{ID: "auth-a", Provider: "gemini", Disabled: true}); errUpdate != nil { + t.Fatalf("Update(auth-a) error = %v", errUpdate) + } + + got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("scheduler.pickSingle() after update error = %v", errPick) + } + if got == nil || got.ID != "auth-b" { + t.Fatalf("scheduler.pickSingle() after update auth = %v, want auth-b", got) + } +} + +func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + manager.executors["gemini"] = schedulerTestExecutor{} + manager.executors["claude"] = schedulerTestExecutor{} + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(gemini-a) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(gemini-b) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil { + t.Fatalf("Register(claude-a) error = %v", errRegister) + } + + wantProviders := []string{"gemini", "claude", "gemini", "claude"} + wantIDs := []string{"gemini-a", "claude-a", "gemini-b", 
"claude-a"} + for index := range wantProviders { + got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickNextMixed() #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("pickNextMixed() #%d auth = nil", index) + } + if provider != wantProviders[index] { + t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index]) + } + if got.ID != wantIDs[index] { + t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index]) + } + } +} + +func TestManager_PickNextMixed_SkipsProvidersWithoutExecutors(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + manager.executors["claude"] = schedulerTestExecutor{} + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(gemini-a) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil { + t.Fatalf("Register(claude-a) error = %v", errRegister) + } + + got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("pickNextMixed() error = %v", errPick) + } + if got == nil { + t.Fatalf("pickNextMixed() auth = nil") + } + if provider != "claude" { + t.Fatalf("pickNextMixed() provider = %q, want %q", provider, "claude") + } + if got.ID != "claude-a" { + t.Fatalf("pickNextMixed() auth.ID = %q, want %q", got.ID, "claude-a") + } +} + +func TestManager_SchedulerTracksMarkResultCooldownAndRecovery(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, &RoundRobinSelector{}, nil) + reg := registry.GetGlobalRegistry() + reg.RegisterClient("auth-a", "gemini", []*registry.ModelInfo{{ID: "test-model"}}) + 
reg.RegisterClient("auth-b", "gemini", []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + reg.UnregisterClient("auth-a") + reg.UnregisterClient("auth-b") + }) + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(auth-a) error = %v", errRegister) + } + if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil { + t.Fatalf("Register(auth-b) error = %v", errRegister) + } + + manager.MarkResult(context.Background(), Result{ + AuthID: "auth-a", + Provider: "gemini", + Model: "test-model", + Success: false, + Error: &Error{HTTPStatus: 429, Message: "quota"}, + }) + + got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("scheduler.pickSingle() after cooldown error = %v", errPick) + } + if got == nil || got.ID != "auth-b" { + t.Fatalf("scheduler.pickSingle() after cooldown auth = %v, want auth-b", got) + } + + manager.MarkResult(context.Background(), Result{ + AuthID: "auth-a", + Provider: "gemini", + Model: "test-model", + Success: true, + }) + + seen := make(map[string]struct{}, 2) + for index := 0; index < 2; index++ { + got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil) + if errPick != nil { + t.Fatalf("scheduler.pickSingle() after recovery #%d error = %v", index, errPick) + } + if got == nil { + t.Fatalf("scheduler.pickSingle() after recovery #%d auth = nil", index) + } + seen[got.ID] = struct{}{} + } + if len(seen) != 2 { + t.Fatalf("len(seen) = %d, want %d", len(seen), 2) + } +} diff --git a/sdk/cliproxy/rtprovider.go b/sdk/cliproxy/rtprovider.go index dad4fc23..5c4f579a 100644 --- a/sdk/cliproxy/rtprovider.go +++ b/sdk/cliproxy/rtprovider.go @@ -1,16 +1,13 @@ package cliproxy import ( - "context" - "net" "net/http" - 
"net/url" "strings" "sync" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" log "github.com/sirupsen/logrus" - "golang.org/x/net/proxy" ) // defaultRoundTripperProvider returns a per-auth HTTP RoundTripper based on @@ -39,35 +36,12 @@ func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http. if rt != nil { return rt } - // Parse the proxy URL to determine the scheme. - proxyURL, errParse := url.Parse(proxyStr) - if errParse != nil { - log.Errorf("parse proxy URL failed: %v", errParse) + transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr) + if errBuild != nil { + log.Errorf("%v", errBuild) return nil } - var transport *http.Transport - // Handle different proxy schemes. - if proxyURL.Scheme == "socks5" { - // Configure SOCKS5 proxy with optional authentication. - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - proxyAuth := &proxy.Auth{User: username, Password: password} - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) - if errSOCKS5 != nil { - log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5) - return nil - } - // Set up a custom transport using the SOCKS5 dialer. - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Configure HTTP or HTTPS proxy. 
- transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } else { - log.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) + if transport == nil { return nil } p.mu.Lock() diff --git a/sdk/cliproxy/rtprovider_test.go b/sdk/cliproxy/rtprovider_test.go new file mode 100644 index 00000000..f907081e --- /dev/null +++ b/sdk/cliproxy/rtprovider_test.go @@ -0,0 +1,22 @@ +package cliproxy + +import ( + "net/http" + "testing" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestRoundTripperForDirectBypassesProxy(t *testing.T) { + t.Parallel() + + provider := newDefaultRoundTripperProvider() + rt := provider.RoundTripperFor(&coreauth.Auth{ProxyURL: "direct"}) + transport, ok := rt.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", rt) + } + if transport.Proxy != nil { + t.Fatal("expected direct transport to disable proxy function") + } +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 6124f8b1..596db3dd 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -312,6 +312,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A // This operation may block on network calls, but the auth configuration // is already effective at this point. s.registerModelsForAuth(auth) + + // Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt + // from the now-populated global model registry. Without this, newly added auths + // have an empty supportedModelSet (because Register/Update upserts into the + // scheduler before registerModelsForAuth runs) and are invisible to the scheduler. 
+ s.coreManager.RefreshSchedulerEntry(auth.ID) } func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) { @@ -823,7 +829,22 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } models = applyExcludedModels(models, excluded) case "codex": - models = registry.GetOpenAIModels() + codexPlanType := "" + if a.Attributes != nil { + codexPlanType = strings.TrimSpace(a.Attributes["plan_type"]) + } + switch strings.ToLower(codexPlanType) { + case "pro": + models = registry.GetCodexProModels() + case "plus": + models = registry.GetCodexPlusModels() + case "team": + models = registry.GetCodexTeamModels() + case "free": + models = registry.GetCodexFreeModels() + default: + models = registry.GetCodexProModels() + } if entry := s.resolveConfigCodexKey(a); entry != nil { if len(entry.Models) > 0 { models = buildCodexConfigModels(entry) diff --git a/sdk/proxyutil/proxy.go b/sdk/proxyutil/proxy.go new file mode 100644 index 00000000..591ec9d9 --- /dev/null +++ b/sdk/proxyutil/proxy.go @@ -0,0 +1,139 @@ +package proxyutil + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/proxy" +) + +// Mode describes how a proxy setting should be interpreted. +type Mode int + +const ( + // ModeInherit means no explicit proxy behavior was configured. + ModeInherit Mode = iota + // ModeDirect means outbound requests must bypass proxies explicitly. + ModeDirect + // ModeProxy means a concrete proxy URL was configured. + ModeProxy + // ModeInvalid means the proxy setting is present but malformed or unsupported. + ModeInvalid +) + +// Setting is the normalized interpretation of a proxy configuration value. +type Setting struct { + Raw string + Mode Mode + URL *url.URL +} + +// Parse normalizes a proxy configuration value into inherit, direct, or proxy modes. 
+func Parse(raw string) (Setting, error) { + trimmed := strings.TrimSpace(raw) + setting := Setting{Raw: trimmed} + + if trimmed == "" { + setting.Mode = ModeInherit + return setting, nil + } + + if strings.EqualFold(trimmed, "direct") || strings.EqualFold(trimmed, "none") { + setting.Mode = ModeDirect + return setting, nil + } + + parsedURL, errParse := url.Parse(trimmed) + if errParse != nil { + setting.Mode = ModeInvalid + return setting, fmt.Errorf("parse proxy URL failed: %w", errParse) + } + if parsedURL.Scheme == "" || parsedURL.Host == "" { + setting.Mode = ModeInvalid + return setting, fmt.Errorf("proxy URL missing scheme/host") + } + + switch parsedURL.Scheme { + case "socks5", "http", "https": + setting.Mode = ModeProxy + setting.URL = parsedURL + return setting, nil + default: + setting.Mode = ModeInvalid + return setting, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme) + } +} + +// NewDirectTransport returns a transport that bypasses environment proxies. +func NewDirectTransport() *http.Transport { + if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil { + clone := transport.Clone() + clone.Proxy = nil + return clone + } + return &http.Transport{Proxy: nil} +} + +// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting. 
+func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) { + setting, errParse := Parse(raw) + if errParse != nil { + return nil, setting.Mode, errParse + } + + switch setting.Mode { + case ModeInherit: + return nil, setting.Mode, nil + case ModeDirect: + return NewDirectTransport(), setting.Mode, nil + case ModeProxy: + if setting.URL.Scheme == "socks5" { + var proxyAuth *proxy.Auth + if setting.URL.User != nil { + username := setting.URL.User.Username() + password, _ := setting.URL.User.Password() + proxyAuth = &proxy.Auth{User: username, Password: password} + } + dialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct) + if errSOCKS5 != nil { + return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5) + } + return &http.Transport{ + Proxy: nil, + DialContext: func(_ context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + }, + }, setting.Mode, nil + } + return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil + default: + return nil, setting.Mode, nil + } +} + +// BuildDialer constructs a proxy dialer for settings that operate at the connection layer. 
+func BuildDialer(raw string) (proxy.Dialer, Mode, error) { + setting, errParse := Parse(raw) + if errParse != nil { + return nil, setting.Mode, errParse + } + + switch setting.Mode { + case ModeInherit: + return nil, setting.Mode, nil + case ModeDirect: + return proxy.Direct, setting.Mode, nil + case ModeProxy: + dialer, errDialer := proxy.FromURL(setting.URL, proxy.Direct) + if errDialer != nil { + return nil, setting.Mode, fmt.Errorf("create proxy dialer failed: %w", errDialer) + } + return dialer, setting.Mode, nil + default: + return nil, setting.Mode, nil + } +} diff --git a/sdk/proxyutil/proxy_test.go b/sdk/proxyutil/proxy_test.go new file mode 100644 index 00000000..bea413dc --- /dev/null +++ b/sdk/proxyutil/proxy_test.go @@ -0,0 +1,89 @@ +package proxyutil + +import ( + "net/http" + "testing" +) + +func TestParse(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + want Mode + wantErr bool + }{ + {name: "inherit", input: "", want: ModeInherit}, + {name: "direct", input: "direct", want: ModeDirect}, + {name: "none", input: "none", want: ModeDirect}, + {name: "http", input: "http://proxy.example.com:8080", want: ModeProxy}, + {name: "https", input: "https://proxy.example.com:8443", want: ModeProxy}, + {name: "socks5", input: "socks5://proxy.example.com:1080", want: ModeProxy}, + {name: "invalid", input: "bad-value", want: ModeInvalid, wantErr: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + setting, errParse := Parse(tt.input) + if tt.wantErr && errParse == nil { + t.Fatal("expected error, got nil") + } + if !tt.wantErr && errParse != nil { + t.Fatalf("unexpected error: %v", errParse) + } + if setting.Mode != tt.want { + t.Fatalf("mode = %d, want %d", setting.Mode, tt.want) + } + }) + } +} + +func TestBuildHTTPTransportDirectBypassesProxy(t *testing.T) { + t.Parallel() + + transport, mode, errBuild := BuildHTTPTransport("direct") + if errBuild != nil { + 
t.Fatalf("BuildHTTPTransport returned error: %v", errBuild) + } + if mode != ModeDirect { + t.Fatalf("mode = %d, want %d", mode, ModeDirect) + } + if transport == nil { + t.Fatal("expected transport, got nil") + } + if transport.Proxy != nil { + t.Fatal("expected direct transport to disable proxy function") + } +} + +func TestBuildHTTPTransportHTTPProxy(t *testing.T) { + t.Parallel() + + transport, mode, errBuild := BuildHTTPTransport("http://proxy.example.com:8080") + if errBuild != nil { + t.Fatalf("BuildHTTPTransport returned error: %v", errBuild) + } + if mode != ModeProxy { + t.Fatalf("mode = %d, want %d", mode, ModeProxy) + } + if transport == nil { + t.Fatal("expected transport, got nil") + } + + req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil) + if errRequest != nil { + t.Fatalf("http.NewRequest returned error: %v", errRequest) + } + + proxyURL, errProxy := transport.Proxy(req) + if errProxy != nil { + t.Fatalf("transport.Proxy returned error: %v", errProxy) + } + if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" { + t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL) + } +}