Merge branch 'dev' into codex/custom-useragent-request

This commit is contained in:
Luis Pater
2026-03-11 22:55:50 +08:00
committed by GitHub
58 changed files with 7068 additions and 1724 deletions
+4
View File
@@ -15,6 +15,8 @@ jobs:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: Login to DockerHub - name: Login to DockerHub
@@ -46,6 +48,8 @@ jobs:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 uses: docker/setup-buildx-action@v3
- name: Login to DockerHub - name: Login to DockerHub
+2
View File
@@ -12,6 +12,8 @@ jobs:
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
+2
View File
@@ -16,6 +16,8 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
- run: git fetch --force --tags - run: git fetch --force --tags
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
+4
View File
@@ -150,6 +150,10 @@ A Windows tray application implemented using PowerShell scripts, without relying
A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed. A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
Browser extension for one-stop management of New API-compatible relay site accounts, featuring balance and usage dashboards, auto check-in, one-click key export to common apps, in-page API availability testing, and channel/model sync and redirection. It integrates with CLIProxyAPI through the Management API for one-click provider import and config sync.
> [!NOTE] > [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list. > If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
+4
View File
@@ -149,6 +149,10 @@ Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方
一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。 一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
用于一站式管理 New API 兼容中转站账号的浏览器扩展,提供余额与用量看板、自动签到、密钥一键导出到常用应用、网页内 API 可用性测试,以及渠道与模型同步和重定向。支持通过 CLIProxyAPI Management API 一键导入 Provider 与同步配置。
> [!NOTE] > [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。 > 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
+3
View File
@@ -24,6 +24,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging" "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset" "github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store" "github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui" "github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
@@ -494,6 +495,7 @@ func main() {
if standalone { if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it. // Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath) managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background())
hook := tui.NewLogHook(2000) hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{}) hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook) log.AddHook(hook)
@@ -566,6 +568,7 @@ func main() {
} else { } else {
// Start the main proxy service // Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath) managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background())
cmd.StartService(cfg, configFilePath, password) cmd.StartService(cfg, configFilePath, password)
} }
} }
+17
View File
@@ -63,6 +63,7 @@ error-logs-max-files: 10
usage-statistics-enabled: false usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/ # Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
proxy-url: "" proxy-url: ""
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name). # When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
@@ -110,6 +111,7 @@ nonstream-keepalive-interval: 0
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # proxy-url: "socks5://proxy.example.com:1080"
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models: # models:
# - name: "gemini-2.5-flash" # upstream model name # - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model # alias: "gemini-flash" # client alias mapped to the upstream model
@@ -128,6 +130,7 @@ nonstream-keepalive-interval: 0
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models: # models:
# - name: "gpt-5-codex" # upstream model name # - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model # alias: "codex-latest" # client alias mapped to the upstream model
@@ -146,6 +149,7 @@ nonstream-keepalive-interval: 0
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models: # models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name # - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model # alias: "claude-sonnet-latest" # client alias mapped to the upstream model
@@ -191,10 +195,22 @@ nonstream-keepalive-interval: 0
# api-key-entries: # api-key-entries:
# - api-key: "sk-or-v1-...b780" # - api-key: "sk-or-v1-...b780"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# - api-key: "sk-or-v1-...b781" # without proxy-url # - api-key: "sk-or-v1-...b781" # without proxy-url
# models: # The models supported by the provider. # models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name. # - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API. # alias: "kimi-k2" # The alias used in the API.
# # You may repeat the same alias to build an internal model pool.
# # The client still sees only one alias in the model list.
# # Requests to that alias will round-robin across the upstream names below,
# # and if the chosen upstream fails before producing output, the request will
# # continue with the next upstream model in the same alias pool.
# - name: "qwen3.5-plus"
# alias: "claude-opus-4.66"
# - name: "glm-5"
# alias: "claude-opus-4.66"
# - name: "kimi-k2.5"
# alias: "claude-opus-4.66"
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL) # Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
# vertex-api-key: # vertex-api-key:
@@ -202,6 +218,7 @@ nonstream-keepalive-interval: 0
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential # prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api # base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override # proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# headers: # headers:
# X-Custom-Header: "custom-value" # X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names # models: # optional: map aliases to upstream model names
+5 -41
View File
@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@@ -14,8 +13,8 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
) )
@@ -660,45 +659,10 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
} }
func buildProxyTransport(proxyStr string) *http.Transport { func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr) transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
if proxyStr == "" { if errBuild != nil {
log.WithError(errBuild).Debug("build proxy transport failed")
return nil return nil
} }
return transport
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
} }
@@ -1,173 +1,58 @@
package management package management
import ( import (
"context"
"encoding/json"
"io"
"net/http" "net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing" "testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
) )
type memoryAuthStore struct { func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
mu sync.Mutex t.Parallel()
items map[string]*coreauth.Auth
}
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) { h := &Handler{
_ = ctx cfg: &config.Config{
s.mu.Lock() SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, a := range s.items {
out = append(out, a.Clone())
}
return out, nil
}
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
_ = ctx
if auth == nil {
return "", nil
}
s.mu.Lock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth.Clone()
s.mu.Unlock()
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
_ = ctx
s.mu.Lock()
delete(s.items, id)
s.mu.Unlock()
return nil
}
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
t.Fatalf("unexpected content-type: %s", ct)
}
bodyBytes, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
values, err := url.ParseQuery(string(bodyBytes))
if err != nil {
t.Fatalf("parse form: %v", err)
}
if values.Get("grant_type") != "refresh_token" {
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
}
if values.Get("refresh_token") != "rt" {
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
}
if values.Get("client_id") != antigravityOAuthClientID {
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
}
if values.Get("client_secret") != antigravityOAuthClientSecret {
t.Fatalf("unexpected client_secret")
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-token",
"refresh_token": "rt2",
"expires_in": int64(3600),
"token_type": "Bearer",
})
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
auth := &coreauth.Auth{
ID: "antigravity-test.json",
FileName: "antigravity-test.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "old-token",
"refresh_token": "rt",
"expires_in": int64(3600),
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
}, },
} }
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err)
}
h := &Handler{authManager: manager} transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
token, err := h.resolveTokenForAuth(context.Background(), auth) httpTransport, ok := transport.(*http.Transport)
if err != nil { if !ok {
t.Fatalf("resolveTokenForAuth: %v", err) t.Fatalf("transport type = %T, want *http.Transport", transport)
} }
if token != "new-token" { if httpTransport.Proxy != nil {
t.Fatalf("expected refreshed token, got %q", token) t.Fatal("expected direct transport to disable proxy function")
}
if callCount != 1 {
t.Fatalf("expected 1 refresh call, got %d", callCount)
}
updated, ok := manager.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth in manager after update")
}
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
t.Fatalf("expected manager metadata updated, got %q", got)
} }
} }
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) { func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
var callCount int t.Parallel()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL h := &Handler{
antigravityOAuthTokenURL = srv.URL cfg: &config.Config{
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL }) SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
auth := &coreauth.Auth{
ID: "antigravity-valid.json",
FileName: "antigravity-valid.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "ok-token",
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
}, },
} }
h := &Handler{}
token, err := h.resolveTokenForAuth(context.Background(), auth) transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
if err != nil { httpTransport, ok := transport.(*http.Transport)
t.Fatalf("resolveTokenForAuth: %v", err) if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
} }
if token != "ok-token" {
t.Fatalf("expected existing token, got %q", token) req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
if errRequest != nil {
t.Fatalf("http.NewRequest returned error: %v", errRequest)
} }
if callCount != 0 {
t.Fatalf("expected no refresh calls, got %d", callCount) proxyURL, errProxy := httpTransport.Proxy(req)
if errProxy != nil {
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
}
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
} }
} }
@@ -1306,12 +1306,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
if errAll != nil { if errAll != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll))
return return
} }
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify) log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status") SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify))
return return
} }
ts.ProjectID = strings.Join(projects, ",") ts.ProjectID = strings.Join(projects, ",")
@@ -1320,7 +1320,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ts.Auto = false ts.Auto = false
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil { if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
log.Errorf("Google One auto-discovery failed: %v", errSetup) log.Errorf("Google One auto-discovery failed: %v", errSetup)
SetOAuthSessionError(state, "Google One auto-discovery failed") SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup))
return return
} }
if strings.TrimSpace(ts.ProjectID) == "" { if strings.TrimSpace(ts.ProjectID) == "" {
@@ -1331,19 +1331,19 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil { if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status") SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
return return
} }
ts.Checked = isChecked ts.Checked = isChecked
if !isChecked { if !isChecked {
log.Error("Cloud AI API is not enabled for the auto-discovered project") log.Error("Cloud AI API is not enabled for the auto-discovered project")
SetOAuthSessionError(state, "Cloud AI API not enabled") SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
return return
} }
} else { } else {
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure))
return return
} }
@@ -1356,13 +1356,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil { if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status") SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
return return
} }
ts.Checked = isChecked ts.Checked = isChecked
if !isChecked { if !isChecked {
log.Error("Cloud AI API is not enabled for the selected project") log.Error("Cloud AI API is not enabled for the selected project")
SetOAuthSessionError(state, "Cloud AI API not enabled") SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
return return
} }
} }
@@ -0,0 +1,49 @@
package management
import (
"context"
"sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// memoryAuthStore keeps auth records in a mutex-guarded in-memory map keyed by
// auth ID. NOTE(review): presumably a test double for the auth store interface
// consumed by coreauth — confirm against the tests that construct it.
type memoryAuthStore struct {
// mu serializes every access to items.
mu sync.Mutex
// items maps auth ID to the stored auth pointer; it stays nil until the first Save.
items map[string]*coreauth.Auth
}
// List returns all stored auth pointers in map-iteration (unspecified) order.
// The returned slice is never nil; the stored pointers are returned as-is, not
// copied. The context argument is ignored and the error is always nil.
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	auths := make([]*coreauth.Auth, 0, len(s.items))
	for id := range s.items {
		auths = append(auths, s.items[id])
	}
	return auths, nil
}
// Save stores the given auth pointer under its ID and returns that ID.
// A nil auth is a no-op that returns an empty ID. The backing map is created
// on first use. The context argument is ignored and the error is always nil.
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
	if auth == nil {
		return "", nil
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Lazily allocate so the zero-value store is usable.
	if s.items == nil {
		s.items = map[string]*coreauth.Auth{}
	}
	key := auth.ID
	s.items[key] = auth
	return key, nil
}
// Delete removes the entry stored under id, if any. Deleting an unknown id
// (or deleting from a store that has never been written to) is a no-op.
// The context argument is ignored and the error is always nil.
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
	s.mu.Lock()
	// delete on a nil or missing key cannot panic, so an explicit unlock is safe here.
	delete(s.items, id)
	s.mu.Unlock()
	return nil
}
// SetBaseDir intentionally does nothing: the in-memory store has no directory.
func (s *memoryAuthStore) SetBaseDir(_ string) {
}
+7 -12
View File
@@ -4,12 +4,12 @@ package claude
import ( import (
"net/http" "net/http"
"net/url"
"strings" "strings"
"sync" "sync"
tls "github.com/refraction-networking/utls" tls "github.com/refraction-networking/utls"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support // newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper { func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
var dialer proxy.Dialer = proxy.Direct var dialer proxy.Dialer = proxy.Direct
if cfg != nil && cfg.ProxyURL != "" { if cfg != nil {
proxyURL, err := url.Parse(cfg.ProxyURL) proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
if err != nil { if errBuild != nil {
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err) log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
} else { } else if mode != proxyutil.ModeInherit && proxyDialer != nil {
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct) dialer = proxyDialer
if err != nil {
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
} else {
dialer = pDialer
}
} }
} }
+7 -29
View File
@@ -10,9 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/url"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
@@ -20,9 +18,9 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"golang.org/x/net/proxy"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
@@ -80,35 +78,15 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
} }
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort) callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Configure proxy settings for the HTTP client if a proxy URL is provided. transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
proxyURL, err := url.Parse(cfg.ProxyURL) if errBuild != nil {
if err == nil { log.Errorf("%v", errBuild)
var transport *http.Transport } else if transport != nil {
if proxyURL.Scheme == "socks5" {
// Handle SOCKS5 proxy.
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
auth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Handle HTTP/HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
if transport != nil {
proxyClient := &http.Client{Transport: transport} proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
} }
}
var err error
// Configure the OAuth2 client. // Configure the OAuth2 client.
conf := &oauth2.Config{ conf := &oauth2.Config{
+139 -13
View File
@@ -1,5 +1,5 @@
// Package registry provides model definitions and lookup helpers for various AI providers. // Package registry provides model definitions and lookup helpers for various AI providers.
// Static model metadata is stored in model_definitions_static_data.go. // Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
package registry package registry
import ( import (
@@ -7,6 +7,131 @@ import (
"strings" "strings"
) )
// AntigravityModelConfig captures static antigravity model overrides, including
// thinking budget limits and provider max completion tokens.
type AntigravityModelConfig struct {
// Thinking is the optional thinking-support override; it is omitted from JSON when nil.
Thinking *ThinkingSupport `json:"thinking,omitempty"`
// MaxCompletionTokens is the provider's max completion tokens; it is omitted from JSON when zero.
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
}
// staticModelsJSON mirrors the top-level structure of models.json.
// Each slice field holds the model list found under the JSON key named in its
// tag; the accessor functions in this file hand out clones of these lists.
type staticModelsJSON struct {
Claude []*ModelInfo `json:"claude"`
Gemini []*ModelInfo `json:"gemini"`
Vertex []*ModelInfo `json:"vertex"`
GeminiCLI []*ModelInfo `json:"gemini-cli"`
AIStudio []*ModelInfo `json:"aistudio"`
// The four Codex lists are split per plan tier (free, team, plus, pro).
CodexFree []*ModelInfo `json:"codex-free"`
CodexTeam []*ModelInfo `json:"codex-team"`
CodexPlus []*ModelInfo `json:"codex-plus"`
CodexPro []*ModelInfo `json:"codex-pro"`
Qwen []*ModelInfo `json:"qwen"`
IFlow []*ModelInfo `json:"iflow"`
Kimi []*ModelInfo `json:"kimi"`
// Antigravity is a map of per-model overrides rather than a model list.
Antigravity map[string]*AntigravityModelConfig `json:"antigravity"`
}
// GetClaudeModels returns a cloned copy of the standard Claude model
// definitions, or nil when the loaded catalog has none.
func GetClaudeModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Claude)
}
// GetGeminiModels returns a cloned copy of the standard Gemini model
// definitions, or nil when the loaded catalog has none.
func GetGeminiModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Gemini)
}
// GetGeminiVertexModels returns a cloned copy of the Gemini model definitions
// for Vertex AI, or nil when the loaded catalog has none.
func GetGeminiVertexModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Vertex)
}
// GetGeminiCLIModels returns a cloned copy of the Gemini model definitions
// for the Gemini CLI, or nil when the loaded catalog has none.
func GetGeminiCLIModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.GeminiCLI)
}
// GetAIStudioModels returns a cloned copy of the AI Studio model definitions,
// or nil when the loaded catalog has none.
func GetAIStudioModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.AIStudio)
}
// GetCodexFreeModels returns a cloned copy of the model definitions for the
// Codex free plan tier, or nil when the loaded catalog has none.
func GetCodexFreeModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.CodexFree)
}
// GetCodexTeamModels returns a cloned copy of the model definitions for the
// Codex team plan tier, or nil when the loaded catalog has none.
func GetCodexTeamModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.CodexTeam)
}
// GetCodexPlusModels returns a cloned copy of the model definitions for the
// Codex plus plan tier, or nil when the loaded catalog has none.
func GetCodexPlusModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.CodexPlus)
}
// GetCodexProModels returns a cloned copy of the model definitions for the
// Codex pro plan tier, or nil when the loaded catalog has none.
func GetCodexProModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.CodexPro)
}
// GetQwenModels returns a cloned copy of the standard Qwen model definitions,
// or nil when the loaded catalog has none.
func GetQwenModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Qwen)
}
// GetIFlowModels returns a cloned copy of the standard iFlow model
// definitions, or nil when the loaded catalog has none.
func GetIFlowModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.IFlow)
}
// GetKimiModels returns a cloned copy of the standard Kimi (Moonshot AI) model
// definitions, or nil when the loaded catalog has none.
func GetKimiModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Kimi)
}
// GetAntigravityModelConfig returns static configuration for antigravity models.
// Keys use upstream model names returned by the Antigravity models endpoint.
// The result is a fresh map whose values are copies made by
// cloneAntigravityModelConfig; it is nil when no overrides are loaded.
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
	overrides := getModels().Antigravity
	if len(overrides) == 0 {
		return nil
	}

	cloned := make(map[string]*AntigravityModelConfig, len(overrides))
	for name := range overrides {
		cloned[name] = cloneAntigravityModelConfig(overrides[name])
	}
	return cloned
}
// cloneAntigravityModelConfig returns an independent copy of cfg, or nil when
// cfg is nil. The Thinking struct and its Levels slice are duplicated so the
// copy shares no mutable state with cfg through them.
func cloneAntigravityModelConfig(cfg *AntigravityModelConfig) *AntigravityModelConfig {
	if cfg == nil {
		return nil
	}

	dup := new(AntigravityModelConfig)
	*dup = *cfg

	src := cfg.Thinking
	if src == nil {
		return dup
	}

	thinking := *src
	if len(src.Levels) > 0 {
		thinking.Levels = append([]string(nil), src.Levels...)
	}
	dup.Thinking = &thinking
	return dup
}
// cloneModelInfos builds a new slice whose elements are cloneModelInfo copies
// of the entries in models. It returns nil for an empty or nil input.
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
	count := len(models)
	if count == 0 {
		return nil
	}

	cloned := make([]*ModelInfo, 0, count)
	for _, info := range models {
		cloned = append(cloned, cloneModelInfo(info))
	}
	return cloned
}
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider. // GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
// It returns nil when the channel is unknown. // It returns nil when the channel is unknown.
// //
@@ -35,7 +160,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
case "aistudio": case "aistudio":
return GetAIStudioModels() return GetAIStudioModels()
case "codex": case "codex":
return GetOpenAIModels() return GetCodexProModels()
case "qwen": case "qwen":
return GetQwenModels() return GetQwenModels()
case "iflow": case "iflow":
@@ -77,27 +202,28 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
return nil return nil
} }
data := getModels()
allModels := [][]*ModelInfo{ allModels := [][]*ModelInfo{
GetClaudeModels(), data.Claude,
GetGeminiModels(), data.Gemini,
GetGeminiVertexModels(), data.Vertex,
GetGeminiCLIModels(), data.GeminiCLI,
GetAIStudioModels(), data.AIStudio,
GetOpenAIModels(), data.CodexPro,
GetQwenModels(), data.Qwen,
GetIFlowModels(), data.IFlow,
GetKimiModels(), data.Kimi,
} }
for _, models := range allModels { for _, models := range allModels {
for _, m := range models { for _, m := range models {
if m != nil && m.ID == modelID { if m != nil && m.ID == modelID {
return m return cloneModelInfo(m)
} }
} }
} }
// Check Antigravity static config // Check Antigravity static config
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil { if cfg := cloneAntigravityModelConfig(data.Antigravity[modelID]); cfg != nil {
return &ModelInfo{ return &ModelInfo{
ID: modelID, ID: modelID,
Thinking: cfg.Thinking, Thinking: cfg.Thinking,
File diff suppressed because it is too large Load Diff
+138 -28
View File
@@ -62,6 +62,11 @@ type ModelInfo struct {
UserDefined bool `json:"-"` UserDefined bool `json:"-"`
} }
// availableModelsCacheEntry is one cached GetAvailableModels result for a
// single handler type.
type availableModelsCacheEntry struct {
// models is the cached model list; GetAvailableModels hands out clones of it.
models []map[string]any
// expiresAt is the instant after which the entry is no longer served.
// The zero value means the entry has no time-based expiry.
expiresAt time.Time
}
// ThinkingSupport describes a model family's supported internal reasoning budget range. // ThinkingSupport describes a model family's supported internal reasoning budget range.
// Values are interpreted in provider-native token units. // Values are interpreted in provider-native token units.
type ThinkingSupport struct { type ThinkingSupport struct {
@@ -116,6 +121,8 @@ type ModelRegistry struct {
clientProviders map[string]string clientProviders map[string]string
// mutex ensures thread-safe access to the registry // mutex ensures thread-safe access to the registry
mutex *sync.RWMutex mutex *sync.RWMutex
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
availableModelsCache map[string]availableModelsCacheEntry
// hook is an optional callback sink for model registration changes // hook is an optional callback sink for model registration changes
hook ModelRegistryHook hook ModelRegistryHook
} }
@@ -132,11 +139,24 @@ func GetGlobalRegistry() *ModelRegistry {
clientModels: make(map[string][]string), clientModels: make(map[string][]string),
clientModelInfos: make(map[string]map[string]*ModelInfo), clientModelInfos: make(map[string]map[string]*ModelInfo),
clientProviders: make(map[string]string), clientProviders: make(map[string]string),
availableModelsCache: make(map[string]availableModelsCacheEntry),
mutex: &sync.RWMutex{}, mutex: &sync.RWMutex{},
} }
}) })
return globalRegistry return globalRegistry
} }
// ensureAvailableModelsCacheLocked allocates the available-models cache map if
// it has not been created yet. An existing map is left untouched.
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
	if r.availableModelsCache != nil {
		return
	}
	r.availableModelsCache = make(map[string]availableModelsCacheEntry)
}
// invalidateAvailableModelsCacheLocked drops every cached available-models
// entry. The map itself is kept; a nil or empty cache is left as it is.
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
	if len(r.availableModelsCache) > 0 {
		clear(r.availableModelsCache)
	}
}
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions. // LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
func LookupModelInfo(modelID string, provider ...string) *ModelInfo { func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
@@ -151,9 +171,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
} }
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil { if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
return info return cloneModelInfo(info)
} }
return LookupStaticModelInfo(modelID) return cloneModelInfo(LookupStaticModelInfo(modelID))
} }
// SetHook sets an optional hook for observing model registration changes. // SetHook sets an optional hook for observing model registration changes.
@@ -211,6 +231,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) { func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
provider := strings.ToLower(clientProvider) provider := strings.ToLower(clientProvider)
uniqueModelIDs := make([]string, 0, len(models)) uniqueModelIDs := make([]string, 0, len(models))
@@ -236,6 +257,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientModels, clientID) delete(r.clientModels, clientID)
delete(r.clientModelInfos, clientID) delete(r.clientModelInfos, clientID)
delete(r.clientProviders, clientID) delete(r.clientProviders, clientID)
r.invalidateAvailableModelsCacheLocked()
misc.LogCredentialSeparator() misc.LogCredentialSeparator()
return return
} }
@@ -263,6 +285,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
} else { } else {
delete(r.clientProviders, clientID) delete(r.clientProviders, clientID)
} }
r.invalidateAvailableModelsCacheLocked()
r.triggerModelsRegistered(provider, clientID, models) r.triggerModelsRegistered(provider, clientID, models)
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs)) log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
misc.LogCredentialSeparator() misc.LogCredentialSeparator()
@@ -406,6 +429,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientProviders, clientID) delete(r.clientProviders, clientID)
} }
r.invalidateAvailableModelsCacheLocked()
r.triggerModelsRegistered(provider, clientID, models) r.triggerModelsRegistered(provider, clientID, models)
if len(added) == 0 && len(removed) == 0 && !providerChanged { if len(added) == 0 && len(removed) == 0 && !providerChanged {
// Only metadata (e.g., display name) changed; skip separator when no log output. // Only metadata (e.g., display name) changed; skip separator when no log output.
@@ -509,6 +533,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
if len(model.SupportedOutputModalities) > 0 { if len(model.SupportedOutputModalities) > 0 {
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...) copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
} }
if model.Thinking != nil {
copyThinking := *model.Thinking
if len(model.Thinking.Levels) > 0 {
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
}
copyModel.Thinking = &copyThinking
}
return &copyModel return &copyModel
} }
@@ -538,6 +569,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
r.unregisterClientInternal(clientID) r.unregisterClientInternal(clientID)
r.invalidateAvailableModelsCacheLocked()
} }
// unregisterClientInternal performs the actual client unregistration (internal, no locking) // unregisterClientInternal performs the actual client unregistration (internal, no locking)
@@ -604,9 +636,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) { func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if registration, exists := r.models[modelID]; exists { if registration, exists := r.models[modelID]; exists {
registration.QuotaExceededClients[clientID] = new(time.Now()) now := time.Now()
registration.QuotaExceededClients[clientID] = &now
r.invalidateAvailableModelsCacheLocked()
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID) log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
} }
} }
@@ -618,9 +653,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) { func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if registration, exists := r.models[modelID]; exists { if registration, exists := r.models[modelID]; exists {
delete(registration.QuotaExceededClients, clientID) delete(registration.QuotaExceededClients, clientID)
r.invalidateAvailableModelsCacheLocked()
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID) // log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
} }
} }
@@ -636,6 +673,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
} }
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
registration, exists := r.models[modelID] registration, exists := r.models[modelID]
if !exists || registration == nil { if !exists || registration == nil {
@@ -649,6 +687,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
} }
registration.SuspendedClients[clientID] = reason registration.SuspendedClients[clientID] = reason
registration.LastUpdated = time.Now() registration.LastUpdated = time.Now()
r.invalidateAvailableModelsCacheLocked()
if reason != "" { if reason != "" {
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason) log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
} else { } else {
@@ -666,6 +705,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
} }
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
registration, exists := r.models[modelID] registration, exists := r.models[modelID]
if !exists || registration == nil || registration.SuspendedClients == nil { if !exists || registration == nil || registration.SuspendedClients == nil {
@@ -676,6 +716,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
} }
delete(registration.SuspendedClients, clientID) delete(registration.SuspendedClients, clientID)
registration.LastUpdated = time.Now() registration.LastUpdated = time.Now()
r.invalidateAvailableModelsCacheLocked()
log.Debugf("Resumed client %s for model %s", clientID, modelID) log.Debugf("Resumed client %s for model %s", clientID, modelID)
} }
@@ -711,22 +752,52 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
// Returns: // Returns:
// - []map[string]any: List of available models in the requested format // - []map[string]any: List of available models in the requested format
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any { func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
r.mutex.RLock()
defer r.mutex.RUnlock()
models := make([]map[string]any, 0)
quotaExpiredDuration := 5 * time.Minute
for _, registration := range r.models {
// Check if model has any non-quota-exceeded clients
availableClients := registration.Count
now := time.Now() now := time.Now()
// Count clients that have exceeded quota but haven't recovered yet r.mutex.RLock()
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
models := cloneModelMaps(cache.models)
r.mutex.RUnlock()
return models
}
r.mutex.RUnlock()
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
return cloneModelMaps(cache.models)
}
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
models: cloneModelMaps(models),
expiresAt: expiresAt,
}
return models
}
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
models := make([]map[string]any, 0, len(r.models))
quotaExpiredDuration := 5 * time.Minute
var expiresAt time.Time
for _, registration := range r.models {
availableClients := registration.Count
expiredClients := 0 expiredClients := 0
for _, quotaTime := range registration.QuotaExceededClients { for _, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration { if quotaTime == nil {
continue
}
recoveryAt := quotaTime.Add(quotaExpiredDuration)
if now.Before(recoveryAt) {
expiredClients++ expiredClients++
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
expiresAt = recoveryAt
}
} }
} }
@@ -747,7 +818,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
effectiveClients = 0 effectiveClients = 0
} }
// Include models that have available clients, or those solely cooling down.
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) { if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
model := r.convertModelToMap(registration.Info, handlerType) model := r.convertModelToMap(registration.Info, handlerType)
if model != nil { if model != nil {
@@ -756,7 +826,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
} }
} }
return models return models, expiresAt
}
// cloneModelMaps returns a new slice of the same length as models in which
// every non-nil map is rebuilt key by key, with each value passed through
// cloneModelMapValue. Nil maps stay nil, and the result is never a nil slice.
func cloneModelMaps(models []map[string]any) []map[string]any {
	snapshot := make([]map[string]any, len(models))
	for i, src := range models {
		if src == nil {
			// The slot already holds a nil map.
			continue
		}
		dup := make(map[string]any, len(src))
		for k, v := range src {
			dup[k] = cloneModelMapValue(v)
		}
		snapshot[i] = dup
	}
	return snapshot
}
// cloneModelMapValue copies a single value taken from a model map.
//
// map[string]any and []any values are copied recursively, []string values are
// copied, and any other value is returned as-is.
func cloneModelMapValue(value any) any {
	if nested, ok := value.(map[string]any); ok {
		dup := make(map[string]any, len(nested))
		for k, v := range nested {
			dup[k] = cloneModelMapValue(v)
		}
		return dup
	}
	if list, ok := value.([]any); ok {
		dup := make([]any, len(list))
		for i := range list {
			dup[i] = cloneModelMapValue(list[i])
		}
		return dup
	}
	if strs, ok := value.([]string); ok {
		return append([]string(nil), strs...)
	}
	return value
}
// GetAvailableModelsByProvider returns models available for the given provider identifier. // GetAvailableModelsByProvider returns models available for the given provider identifier.
@@ -872,11 +979,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) { if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
if entry.info != nil { if entry.info != nil {
result = append(result, entry.info) result = append(result, cloneModelInfo(entry.info))
continue continue
} }
if ok && registration != nil && registration.Info != nil { if ok && registration != nil && registration.Info != nil {
result = append(result, registration.Info) result = append(result, cloneModelInfo(registration.Info))
} }
} }
} }
@@ -985,13 +1092,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
if reg.Providers != nil { if reg.Providers != nil {
if count, ok := reg.Providers[provider]; ok && count > 0 { if count, ok := reg.Providers[provider]; ok && count > 0 {
if info, ok := reg.InfoByProvider[provider]; ok && info != nil { if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
return info return cloneModelInfo(info)
} }
} }
} }
} }
// Fallback to global info (last registered) // Fallback to global info (last registered)
return reg.Info return cloneModelInfo(reg.Info)
} }
return nil return nil
} }
@@ -1031,7 +1138,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
result["max_completion_tokens"] = model.MaxCompletionTokens result["max_completion_tokens"] = model.MaxCompletionTokens
} }
if len(model.SupportedParameters) > 0 { if len(model.SupportedParameters) > 0 {
result["supported_parameters"] = model.SupportedParameters result["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
} }
return result return result
@@ -1075,13 +1182,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
result["outputTokenLimit"] = model.OutputTokenLimit result["outputTokenLimit"] = model.OutputTokenLimit
} }
if len(model.SupportedGenerationMethods) > 0 { if len(model.SupportedGenerationMethods) > 0 {
result["supportedGenerationMethods"] = model.SupportedGenerationMethods result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
} }
if len(model.SupportedInputModalities) > 0 { if len(model.SupportedInputModalities) > 0 {
result["supportedInputModalities"] = model.SupportedInputModalities result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
} }
if len(model.SupportedOutputModalities) > 0 { if len(model.SupportedOutputModalities) > 0 {
result["supportedOutputModalities"] = model.SupportedOutputModalities result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
} }
return result return result
@@ -1111,15 +1218,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
now := time.Now() now := time.Now()
quotaExpiredDuration := 5 * time.Minute quotaExpiredDuration := 5 * time.Minute
invalidated := false
for modelID, registration := range r.models { for modelID, registration := range r.models {
for clientID, quotaTime := range registration.QuotaExceededClients { for clientID, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration { if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
delete(registration.QuotaExceededClients, clientID) delete(registration.QuotaExceededClients, clientID)
invalidated = true
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID) log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
} }
} }
} }
if invalidated {
r.invalidateAvailableModelsCacheLocked()
}
} }
// GetFirstAvailableModel returns the first available model for the given handler type. // GetFirstAvailableModel returns the first available model for the given handler type.
@@ -1133,8 +1245,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
// - string: The model ID of the first available model, or empty string if none available // - string: The model ID of the first available model, or empty string if none available
// - error: An error if no models are available // - error: An error if no models are available
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) { func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
r.mutex.RLock()
defer r.mutex.RUnlock()
// Get all available models for this handler type // Get all available models for this handler type
models := r.GetAvailableModels(handlerType) models := r.GetAvailableModels(handlerType)
@@ -1194,13 +1304,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
// Prefer client's own model info to preserve original type/owned_by // Prefer client's own model info to preserve original type/owned_by
if clientInfos != nil { if clientInfos != nil {
if info, ok := clientInfos[modelID]; ok && info != nil { if info, ok := clientInfos[modelID]; ok && info != nil {
result = append(result, info) result = append(result, cloneModelInfo(info))
continue continue
} }
} }
// Fallback to global registry (for backwards compatibility) // Fallback to global registry (for backwards compatibility)
if reg, ok := r.models[modelID]; ok && reg.Info != nil { if reg, ok := r.models[modelID]; ok && reg.Info != nil {
result = append(result, reg.Info) result = append(result, cloneModelInfo(reg.Info))
} }
} }
return result return result
@@ -0,0 +1,54 @@
package registry
import "testing"
// TestGetAvailableModelsReturnsClonedSnapshots verifies that mutating a map
// returned by GetAvailableModels does not change what a later call returns.
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
	r := newTestModelRegistry()
	r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})

	first := r.GetAvailableModels("openai")
	if len(first) != 1 {
		t.Fatalf("expected 1 model, got %d", len(first))
	}
	// Mutate the returned snapshot; the registry must not observe this.
	first[0]["id"] = "mutated"
	first[0]["display_name"] = "Mutated"

	second := r.GetAvailableModels("openai")
	// Guard the index below so a regression fails the test instead of panicking.
	if len(second) != 1 {
		t.Fatalf("expected 1 model on second fetch, got %d", len(second))
	}
	if got := second[0]["id"]; got != "m1" {
		t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
	}
	if got := second[0]["display_name"]; got != "Model One" {
		t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
	}
}
// TestGetAvailableModelsInvalidatesCacheOnRegistryChanges verifies that
// re-registering, suspending and resuming a client are each reflected by the
// next GetAvailableModels call.
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
	r := newTestModelRegistry()
	r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})

	models := r.GetAvailableModels("openai")
	if len(models) != 1 {
		t.Fatalf("expected 1 model, got %d", len(models))
	}
	if got := models[0]["display_name"]; got != "Model One" {
		t.Fatalf("expected initial display_name Model One, got %v", got)
	}

	// Re-register the same client with changed metadata.
	r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
	models = r.GetAvailableModels("openai")
	// Guard the index below so a regression fails the test instead of panicking.
	if len(models) != 1 {
		t.Fatalf("expected 1 model after re-registration, got %d", len(models))
	}
	if got := models[0]["display_name"]; got != "Model One Updated" {
		t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
	}

	r.SuspendClientModel("client-1", "m1", "manual")
	models = r.GetAvailableModels("openai")
	if len(models) != 0 {
		t.Fatalf("expected no available models after suspension, got %d", len(models))
	}

	r.ResumeClientModel("client-1", "m1")
	models = r.GetAvailableModels("openai")
	if len(models) != 1 {
		t.Fatalf("expected model to reappear after resume, got %d", len(models))
	}
}
@@ -0,0 +1,149 @@
package registry
import (
"testing"
"time"
)
// TestGetModelInfoReturnsClone verifies that changes made to a ModelInfo
// returned by GetModelInfo, including its Thinking.Levels slice, are not
// visible on a later GetModelInfo call.
func TestGetModelInfoReturnsClone(t *testing.T) {
	r := newTestModelRegistry()
	r.RegisterClient("client-1", "gemini", []*ModelInfo{{
		ID:          "m1",
		DisplayName: "Model One",
		Thinking:    &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
	}})

	first := r.GetModelInfo("m1", "gemini")
	if first == nil {
		t.Fatal("expected model info")
	}
	// Guard the dereference below so a regression fails the test instead of panicking.
	if first.Thinking == nil || len(first.Thinking.Levels) == 0 {
		t.Fatalf("expected thinking levels on first fetch, got %+v", first.Thinking)
	}
	first.DisplayName = "mutated"
	first.Thinking.Levels[0] = "mutated"

	second := r.GetModelInfo("m1", "gemini")
	if second == nil {
		t.Fatal("expected model info on second fetch")
	}
	if second.DisplayName != "Model One" {
		t.Fatalf("expected cloned display name, got %q", second.DisplayName)
	}
	if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
		t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
	}
}
// TestGetModelsForClientReturnsClones checks that the ModelInfo values handed
// out by GetModelsForClient can be modified without affecting the next call.
func TestGetModelsForClientReturnsClones(t *testing.T) {
	registry := newTestModelRegistry()
	registered := []*ModelInfo{{
		ID:          "m1",
		DisplayName: "Model One",
		Thinking:    &ThinkingSupport{Levels: []string{"low", "high"}},
	}}
	registry.RegisterClient("client-1", "gemini", registered)

	scratch := registry.GetModelsForClient("client-1")
	if len(scratch) != 1 || scratch[0] == nil {
		t.Fatalf("expected one model, got %+v", scratch)
	}
	// Scribble over the returned value.
	scratch[0].DisplayName = "mutated"
	scratch[0].Thinking.Levels[0] = "mutated"

	fresh := registry.GetModelsForClient("client-1")
	if len(fresh) != 1 || fresh[0] == nil {
		t.Fatalf("expected one model on second fetch, got %+v", fresh)
	}
	model := fresh[0]
	if model.DisplayName != "Model One" {
		t.Fatalf("expected cloned display name, got %q", model.DisplayName)
	}
	if th := model.Thinking; th == nil || len(th.Levels) == 0 || th.Levels[0] != "low" {
		t.Fatalf("expected cloned thinking levels, got %+v", model.Thinking)
	}
}
// TestGetAvailableModelsByProviderReturnsClones checks that the ModelInfo
// values handed out by GetAvailableModelsByProvider can be modified without
// affecting the next call.
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
	registry := newTestModelRegistry()
	registered := []*ModelInfo{{
		ID:          "m1",
		DisplayName: "Model One",
		Thinking:    &ThinkingSupport{Levels: []string{"low", "high"}},
	}}
	registry.RegisterClient("client-1", "gemini", registered)

	scratch := registry.GetAvailableModelsByProvider("gemini")
	if len(scratch) != 1 || scratch[0] == nil {
		t.Fatalf("expected one model, got %+v", scratch)
	}
	// Scribble over the returned value.
	scratch[0].DisplayName = "mutated"
	scratch[0].Thinking.Levels[0] = "mutated"

	fresh := registry.GetAvailableModelsByProvider("gemini")
	if len(fresh) != 1 || fresh[0] == nil {
		t.Fatalf("expected one model on second fetch, got %+v", fresh)
	}
	model := fresh[0]
	if model.DisplayName != "Model One" {
		t.Fatalf("expected cloned display name, got %q", model.DisplayName)
	}
	if th := model.Thinking; th == nil || len(th.Levels) == 0 || th.Levels[0] != "low" {
		t.Fatalf("expected cloned thinking levels, got %+v", model.Thinking)
	}
}
// TestCleanupExpiredQuotasInvalidatesAvailableModelsCache marks a model as
// quota-exceeded, backdates that mark by six minutes, runs
// CleanupExpiredQuotas and checks the model is still counted and listed.
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
	registry := newTestModelRegistry()
	registry.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
	registry.SetModelQuotaExceeded("client-1", "m1")

	listed := registry.GetAvailableModels("openai")
	if len(listed) != 1 {
		t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(listed))
	}

	// Rewrite the quota timestamp directly so the entry looks six minutes old.
	backdated := time.Now().Add(-6 * time.Minute)
	registry.mutex.Lock()
	registry.models["m1"].QuotaExceededClients["client-1"] = &backdated
	registry.mutex.Unlock()

	registry.CleanupExpiredQuotas()

	count := registry.GetModelCount("m1")
	if count != 1 {
		t.Fatalf("expected model count 1 after cleanup, got %d", count)
	}

	listed = registry.GetAvailableModels("openai")
	if len(listed) != 1 {
		t.Fatalf("expected model to stay available after cleanup, got %d", len(listed))
	}
	if got := listed[0]["id"]; got != "m1" {
		t.Fatalf("expected model id m1, got %v", got)
	}
}
// TestGetAvailableModelsReturnsClonedSupportedParameters verifies that writing
// into the "supported_parameters" slice of a returned model map does not
// change the slice returned by a later GetAvailableModels call.
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
	r := newTestModelRegistry()
	r.RegisterClient("client-1", "openai", []*ModelInfo{{
		ID:                  "m1",
		DisplayName:         "Model One",
		SupportedParameters: []string{"temperature", "top_p"},
	}})

	first := r.GetAvailableModels("openai")
	if len(first) != 1 {
		t.Fatalf("expected one model, got %d", len(first))
	}
	params, ok := first[0]["supported_parameters"].([]string)
	if !ok || len(params) != 2 {
		t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
	}
	params[0] = "mutated"

	second := r.GetAvailableModels("openai")
	// Guard the index below so a regression fails the test instead of panicking.
	if len(second) != 1 {
		t.Fatalf("expected one model on second fetch, got %d", len(second))
	}
	params, ok = second[0]["supported_parameters"].([]string)
	if !ok || len(params) != 2 || params[0] != "temperature" {
		t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
	}
}
// TestLookupModelInfoReturnsCloneForStaticDefinitions looks up "glm-4.6"
// twice and checks that a write to the first result's Thinking.Levels does
// not show up in the second result.
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
	const modelID = "glm-4.6"

	scratch := LookupModelInfo(modelID)
	if scratch == nil || scratch.Thinking == nil || len(scratch.Thinking.Levels) == 0 {
		t.Fatalf("expected static model with thinking levels, got %+v", scratch)
	}
	scratch.Thinking.Levels[0] = "mutated"

	fresh := LookupModelInfo(modelID)
	intact := fresh != nil &&
		fresh.Thinking != nil &&
		len(fresh.Thinking.Levels) != 0 &&
		fresh.Thinking.Levels[0] != "mutated"
	if !intact {
		t.Fatalf("expected static lookup clone, got %+v", fresh)
	}
}
+198
View File
@@ -0,0 +1,198 @@
package registry
import (
"context"
_ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
// modelsFetchTimeout bounds a single catalog download attempt. It is applied
// both as the HTTP client timeout and as the per-request context deadline.
modelsFetchTimeout = 30 * time.Second
)
// modelsURLs lists the remote catalog locations. They are tried in order until
// one yields a catalog that downloads and loads successfully.
var modelsURLs = []string{
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
"https://models.router-for.me/models.json",
}
// embeddedModelsJSON is the models.json compiled into the binary. init loads it
// so a catalog is available before any network refresh is attempted.
//
//go:embed models/models.json
var embeddedModelsJSON []byte
// modelStore holds the active models catalog together with the lock that
// guards it.
type modelStore struct {
// mu guards data.
mu sync.RWMutex
// data is the most recently loaded catalog; loadModelsFromBytes replaces the
// pointer wholesale after a successful parse and validation.
data *staticModelsJSON
}
// modelsCatalogStore is the package-wide catalog store read by getModels.
var modelsCatalogStore = &modelStore{}
// updaterOnce makes StartModelsUpdater run its refresh at most once.
var updaterOnce sync.Once
// init loads the embedded models.json into the catalog store so a catalog is
// available before any refresh runs. It panics if the embedded data does not
// load.
func init() {
	// Load embedded data as fallback on startup.
	err := loadModelsFromBytes(embeddedModelsJSON, "embed")
	if err == nil {
		return
	}
	panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
}
// StartModelsUpdater runs a one-time models refresh on startup.
// It blocks until the startup fetch attempt finishes so service initialization
// can wait for the refreshed catalog before registering auth-backed models.
// Safe to call multiple times; only one refresh will run.
func StartModelsUpdater(ctx context.Context) {
updaterOnce.Do(func() {
runModelsUpdater(ctx)
})
}
// runModelsUpdater performs the startup catalog refresh by making a single
// call to tryRefreshModels with ctx.
func runModelsUpdater(ctx context.Context) {
// Try network fetch once on startup, then stop.
// Periodic refresh is disabled - models are only refreshed at startup.
tryRefreshModels(ctx)
}
// tryRefreshModels attempts to replace the in-memory models catalog with a
// copy downloaded from one of modelsURLs.
//
// URLs are tried in order; the first payload that downloads, decodes and
// validates wins. A failure is logged and the next URL is tried. When every
// URL fails, the catalog that is already loaded stays in effect. The loop
// stops as soon as ctx is done, because every remaining request would fail
// with the same context error.
func tryRefreshModels(ctx context.Context) {
	client := &http.Client{Timeout: modelsFetchTimeout}
	for _, url := range modelsURLs {
		if err := ctx.Err(); err != nil {
			log.Warnf("models refresh aborted, using current data: %v", err)
			return
		}
		data, err := fetchModelsCatalog(ctx, client, url)
		if err != nil {
			log.Debugf("models fetch failed from %s: %v", url, err)
			continue
		}
		if err := loadModelsFromBytes(data, url); err != nil {
			log.Warnf("models parse failed from %s: %v", url, err)
			continue
		}
		log.Infof("models updated from %s", url)
		return
	}
	log.Warn("models refresh failed from all URLs, using current data")
}

// fetchModelsCatalog downloads the raw catalog bytes from url using client.
//
// The request is bounded by modelsFetchTimeout on top of ctx. Any response
// status other than 200 OK is reported as an error. The per-request context
// and the response body are released on every return path.
func fetchModelsCatalog(ctx context.Context, client *http.Client, url string) ([]byte, error) {
	reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status %d", resp.StatusCode)
	}
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read body: %w", err)
	}
	return data, nil
}
// loadModelsFromBytes decodes data as a models catalog, validates it and, on
// success, installs it as the active catalog.
//
// source names where data came from and prefixes any returned error. When
// decoding or validation fails the active catalog is left unchanged.
func loadModelsFromBytes(data []byte, source string) error {
	catalog := new(staticModelsJSON)
	if err := json.Unmarshal(data, catalog); err != nil {
		return fmt.Errorf("%s: decode models catalog: %w", source, err)
	}
	if err := validateModelsCatalog(catalog); err != nil {
		return fmt.Errorf("%s: validate models catalog: %w", source, err)
	}

	store := modelsCatalogStore
	store.mu.Lock()
	defer store.mu.Unlock()
	store.data = catalog
	return nil
}
// getModels returns the catalog currently held by modelsCatalogStore, read
// under the store's read lock.
func getModels() *staticModelsJSON {
	store := modelsCatalogStore
	store.mu.RLock()
	current := store.data
	store.mu.RUnlock()
	return current
}
// validateModelsCatalog checks that data is non-nil and that every required
// provider section passes validateModelSection, then validates the antigravity
// section with validateAntigravitySection. Sections are checked in the order
// listed below and the first failure is returned as-is.
func validateModelsCatalog(data *staticModelsJSON) error {
	if data == nil {
		return fmt.Errorf("catalog is nil")
	}

	type namedSection struct {
		label   string
		entries []*ModelInfo
	}
	sections := []namedSection{
		{label: "claude", entries: data.Claude},
		{label: "gemini", entries: data.Gemini},
		{label: "vertex", entries: data.Vertex},
		{label: "gemini-cli", entries: data.GeminiCLI},
		{label: "aistudio", entries: data.AIStudio},
		{label: "codex-free", entries: data.CodexFree},
		{label: "codex-team", entries: data.CodexTeam},
		{label: "codex-plus", entries: data.CodexPlus},
		{label: "codex-pro", entries: data.CodexPro},
		{label: "qwen", entries: data.Qwen},
		{label: "iflow", entries: data.IFlow},
		{label: "kimi", entries: data.Kimi},
	}

	for _, sec := range sections {
		if errSection := validateModelSection(sec.label, sec.entries); errSection != nil {
			return errSection
		}
	}

	// Antigravity is a map keyed by model id rather than a list, so it has
	// its own validator.
	return validateAntigravitySection(data.Antigravity)
}
// validateModelSection verifies one list-shaped catalog section. It fails when
// the list is empty, when an entry is nil, when an entry's ID is blank after
// trimming whitespace, or when two entries share the same trimmed ID. section
// is the label used in error messages.
func validateModelSection(section string, models []*ModelInfo) error {
	if len(models) == 0 {
		return fmt.Errorf("%s section is empty", section)
	}

	// Trimmed IDs seen so far; used only for duplicate detection.
	seen := make(map[string]struct{}, len(models))
	for idx, entry := range models {
		if entry == nil {
			return fmt.Errorf("%s[%d] is null", section, idx)
		}

		id := strings.TrimSpace(entry.ID)
		_, duplicate := seen[id]
		switch {
		case id == "":
			return fmt.Errorf("%s[%d] has empty id", section, idx)
		case duplicate:
			return fmt.Errorf("%s contains duplicate model id %q", section, id)
		}
		seen[id] = struct{}{}
	}
	return nil
}
// validateAntigravitySection verifies the map-shaped antigravity section. It
// fails when the map is empty, when a key is blank after trimming whitespace,
// or when a key maps to a nil config. Map iteration order is unspecified, so
// with several invalid entries the one reported may vary between calls.
func validateAntigravitySection(configs map[string]*AntigravityModelConfig) error {
	if len(configs) == 0 {
		return fmt.Errorf("antigravity section is empty")
	}

	for rawID, entry := range configs {
		id := strings.TrimSpace(rawID)
		switch {
		case id == "":
			return fmt.Errorf("antigravity contains empty model id")
		case entry == nil:
			return fmt.Errorf("antigravity[%q] is null", id)
		}
	}
	return nil
}
File diff suppressed because it is too large Load Diff
+23 -7
View File
@@ -1266,6 +1266,10 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
} }
return true return true
}) })
} else if system.Type == gjson.String && system.String() != "" {
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
partJSON, _ = sjson.Set(partJSON, "text", system.String())
result += "," + partJSON
} }
result += "]" result += "]"
@@ -1485,25 +1489,27 @@ func countCacheControlsMap(root map[string]any) int {
return count return count
} }
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) { func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
ccRaw, exists := obj["cache_control"] ccRaw, exists := obj["cache_control"]
if !exists { if !exists {
return return false
} }
cc, ok := asObject(ccRaw) cc, ok := asObject(ccRaw)
if !ok { if !ok {
*seen5m = true *seen5m = true
return return false
} }
ttlRaw, ttlExists := cc["ttl"] ttlRaw, ttlExists := cc["ttl"]
ttl, ttlIsString := ttlRaw.(string) ttl, ttlIsString := ttlRaw.(string)
if !ttlExists || !ttlIsString || ttl != "1h" { if !ttlExists || !ttlIsString || ttl != "1h" {
*seen5m = true *seen5m = true
return return false
} }
if *seen5m { if *seen5m {
delete(cc, "ttl") delete(cc, "ttl")
return true
} }
return false
} }
func findLastCacheControlIndex(arr []any) int { func findLastCacheControlIndex(arr []any) int {
@@ -1599,11 +1605,14 @@ func normalizeCacheControlTTL(payload []byte) []byte {
} }
seen5m := false seen5m := false
modified := false
if tools, ok := asArray(root["tools"]); ok { if tools, ok := asArray(root["tools"]); ok {
for _, tool := range tools { for _, tool := range tools {
if obj, ok := asObject(tool); ok { if obj, ok := asObject(tool); ok {
normalizeTTLForBlock(obj, &seen5m) if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
} }
} }
} }
@@ -1611,7 +1620,9 @@ func normalizeCacheControlTTL(payload []byte) []byte {
if system, ok := asArray(root["system"]); ok { if system, ok := asArray(root["system"]); ok {
for _, item := range system { for _, item := range system {
if obj, ok := asObject(item); ok { if obj, ok := asObject(item); ok {
normalizeTTLForBlock(obj, &seen5m) if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
} }
} }
} }
@@ -1628,12 +1639,17 @@ func normalizeCacheControlTTL(payload []byte) []byte {
} }
for _, item := range content { for _, item := range content {
if obj, ok := asObject(item); ok { if obj, ok := asObject(item); ok {
normalizeTTLForBlock(obj, &seen5m) if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
} }
} }
} }
} }
if !modified {
return payload
}
return marshalPayloadObject(payload, root) return marshalPayloadObject(payload, root)
} }
@@ -369,6 +369,19 @@ func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
} }
} }
// TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange verifies that
// normalizeCacheControlTTL hands back a byte-identical payload when there is
// nothing to normalize.
func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.T) {
	// Every cache_control block below uses ttl "1h" and none follows a 5m
	// block, so no TTL needs downgrading. The system text deliberately holds
	// HTML characters (<, >, &): json.Marshal would escape them to \u003c
	// etc., so a needless re-marshal would change the bytes.
	original := []byte(`{"tools":[{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],"system":[{"type":"text","text":"<system-reminder>foo & bar</system-reminder>","cache_control":{"type":"ephemeral","ttl":"1h"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)

	if normalized := normalizeCacheControlTTL(original); !bytes.Equal(normalized, original) {
		t.Fatalf("normalizeCacheControlTTL altered bytes when no change was needed.\noriginal: %s\ngot: %s", original, normalized)
	}
}
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) { func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
payload := []byte(`{ payload := []byte(`{
"tools": [ "tools": [
@@ -967,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error()) t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
} }
} }
// Test case 1: String system prompt is preserved and converted to a content block
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
system := gjson.GetBytes(out, "system")
if !system.IsArray() {
t.Fatalf("system should be an array, got %s", system.Type)
}
blocks := system.Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
}
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
}
if blocks[2].Get("text").String() != "You are a helpful assistant." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
}
}
// TestCheckSystemInstructionsWithMode_StringSystemStrict checks that strict
// mode drops a string "system" prompt, leaving exactly two system blocks.
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
	input := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
	result := checkSystemInstructionsWithMode(input, true)

	if count := len(gjson.GetBytes(result, "system").Array()); count != 2 {
		t.Fatalf("strict mode should produce 2 blocks, got %d", count)
	}
}
// TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored checks that an
// empty-string "system" prompt adds no extra block in non-strict mode: the
// output still has exactly two system blocks.
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
	input := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
	result := checkSystemInstructionsWithMode(input, false)

	if count := len(gjson.GetBytes(result, "system").Array()); count != 2 {
		t.Fatalf("empty string system should produce 2 blocks, got %d", count)
	}
}
// TestCheckSystemInstructionsWithMode_ArraySystemStillWorks checks that an
// array-form "system" prompt still ends up as the third system block in
// non-strict mode.
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
	input := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
	result := checkSystemInstructionsWithMode(input, false)

	blocks := gjson.GetBytes(result, "system").Array()
	if len(blocks) != 3 {
		t.Fatalf("expected 3 system blocks, got %d", len(blocks))
	}
	if userText := blocks[2].Get("text").String(); userText != "Be concise." {
		t.Fatalf("blocks[2] should be user system prompt, got %q", userText)
	}
}
// TestCheckSystemInstructionsWithMode_StringWithSpecialChars checks that angle
// brackets, ampersands and escaped quotes in a string "system" prompt come
// through the conversion to a content block unchanged.
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
	input := []byte(`{"system":"Use <xml> tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
	result := checkSystemInstructionsWithMode(input, false)

	blocks := gjson.GetBytes(result, "system").Array()
	if len(blocks) != 3 {
		t.Fatalf("expected 3 system blocks, got %d", len(blocks))
	}
	if userText := blocks[2].Get("text").String(); userText != `Use <xml> tags & "quotes" in output.` {
		t.Fatalf("blocks[2] text mangled, got %q", userText)
	}
}
@@ -23,6 +23,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -705,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
return dialer return dialer
} }
parsedURL, errParse := url.Parse(proxyURL) setting, errParse := proxyutil.Parse(proxyURL)
if errParse != nil { if errParse != nil {
log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse) log.Errorf("codex websockets executor: %v", errParse)
return dialer return dialer
} }
switch parsedURL.Scheme { switch setting.Mode {
case proxyutil.ModeDirect:
dialer.Proxy = nil
return dialer
case proxyutil.ModeProxy:
default:
return dialer
}
switch setting.URL.Scheme {
case "socks5": case "socks5":
var proxyAuth *proxy.Auth var proxyAuth *proxy.Auth
if parsedURL.User != nil { if setting.URL.User != nil {
username := parsedURL.User.Username() username := setting.URL.User.Username()
password, _ := parsedURL.User.Password() password, _ := setting.URL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password} proxyAuth = &proxy.Auth{User: username, Password: password}
} }
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct) socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil { if errSOCKS5 != nil {
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5) log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
return dialer return dialer
@@ -729,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
return socksDialer.Dial(network, addr) return socksDialer.Dial(network, addr)
} }
case "http", "https": case "http", "https":
dialer.Proxy = http.ProxyURL(parsedURL) dialer.Proxy = http.ProxyURL(setting.URL)
default: default:
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme) log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
} }
return dialer return dialer
@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -187,3 +188,16 @@ func contextWithGinHeaders(headers map[string]string) context.Context {
} }
return context.WithValue(context.Background(), "gin", ginCtx) return context.WithValue(context.Background(), "gin", ginCtx)
} }
func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
t.Parallel()
dialer := newProxyAwareWebsocketDialer(
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
)
if dialer.Proxy != nil {
t.Fatal("expected websocket proxy function to be nil for direct mode")
}
}
@@ -460,7 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
// For API key auth, use simpler URL format without project/location // For API key auth, use simpler URL format without project/location
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://aiplatform.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action) url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
if opts.Alt != "" && action != "countTokens" { if opts.Alt != "" && action != "countTokens" {
@@ -683,7 +683,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
action := getVertexAction(baseModel, true) action := getVertexAction(baseModel, true)
// For API key auth, use simpler URL format without project/location // For API key auth, use simpler URL format without project/location
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://aiplatform.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action) url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
// Imagen models don't support streaming, skip SSE params // Imagen models don't support streaming, skip SSE params
@@ -883,7 +883,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
// For API key auth, use simpler URL format without project/location // For API key auth, use simpler URL format without project/location
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://aiplatform.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens") url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")
+4 -41
View File
@@ -2,16 +2,14 @@ package executor
import ( import (
"context" "context"
"net"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
) )
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority: // newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
@@ -72,45 +70,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
// Returns: // Returns:
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid // - *http.Transport: A configured transport, or nil if the proxy URL is invalid
func buildProxyTransport(proxyURL string) *http.Transport { func buildProxyTransport(proxyURL string) *http.Transport {
if proxyURL == "" { transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL)
if errBuild != nil {
log.Errorf("%v", errBuild)
return nil return nil
} }
parsedURL, errParse := url.Parse(proxyURL)
if errParse != nil {
log.Errorf("parse proxy URL failed: %v", errParse)
return nil
}
var transport *http.Transport
// Handle different proxy schemes
if parsedURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication
var proxyAuth *proxy.Auth
if parsedURL.User != nil {
username := parsedURL.User.Username()
password, _ := parsedURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil
}
// Set up a custom transport using the SOCKS5 dialer
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
} else {
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
return nil
}
return transport return transport
} }
@@ -0,0 +1,30 @@
package executor
import (
"context"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
client := newProxyAwareHTTPClient(
context.Background(),
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
0,
)
transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
}
if transport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}
+6
View File
@@ -257,8 +257,11 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma
if suffixResult.HasSuffix { if suffixResult.HasSuffix {
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID) config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
} else { } else {
config = extractThinkingConfig(body, fromFormat)
if !hasThinkingConfig(config) && fromFormat != toFormat {
config = extractThinkingConfig(body, toFormat) config = extractThinkingConfig(body, toFormat)
} }
}
if !hasThinkingConfig(config) { if !hasThinkingConfig(config) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@@ -293,6 +296,9 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri
if config.Mode != ModeLevel { if config.Mode != ModeLevel {
return config return config
} }
if toFormat == "claude" {
return config
}
if !isBudgetCapableProvider(toFormat) { if !isBudgetCapableProvider(toFormat) {
return config return config
} }
@@ -0,0 +1,55 @@
package thinking_test
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
"github.com/tidwall/gjson"
)
// TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel registers a
// user-defined model under the "claude" provider and checks that ApplyThinking
// produces adaptive thinking with effort "high" and no budget_tokens. The level
// is supplied once through the request body and once through a "(high)" suffix
// on the model name.
func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) {
	modelID := "custom-claude-4-6"
	clientID := "test-user-defined-claude-" + t.Name()

	globalRegistry := registry.GetGlobalRegistry()
	globalRegistry.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}})
	// Remove the registration afterwards so the global registry is left clean.
	t.Cleanup(func() { globalRegistry.UnregisterClient(clientID) })

	type scenario struct {
		name  string
		model string
		body  []byte
	}
	scenarios := []scenario{
		{
			name:  "claude adaptive effort body",
			model: modelID,
			body:  []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`),
		},
		{
			name:  "suffix level",
			model: modelID + "(high)",
			body:  []byte(`{}`),
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			out, err := thinking.ApplyThinking(sc.body, sc.model, "openai", "claude", "claude")
			if err != nil {
				t.Fatalf("ApplyThinking() error = %v", err)
			}

			field := func(path string) gjson.Result { return gjson.GetBytes(out, path) }

			if thinkingType := field("thinking.type").String(); thinkingType != "adaptive" {
				t.Fatalf("thinking.type = %q, want %q, body=%s", thinkingType, "adaptive", string(out))
			}
			if effort := field("output_config.effort").String(); effort != "high" {
				t.Fatalf("output_config.effort = %q, want %q, body=%s", effort, "high", string(out))
			}
			if field("thinking.budget_tokens").Exists() {
				t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
			}
		})
	}
}
@@ -477,9 +477,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
effort = strings.ToLower(strings.TrimSpace(v.String())) effort = strings.ToLower(strings.TrimSpace(v.String()))
} }
if effort != "" { if effort != "" {
if effort == "max" {
effort = "high"
}
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort) out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
} else { } else {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high") out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
@@ -1235,64 +1235,3 @@ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *t
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw) t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
} }
} }
// TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels runs
// ConvertClaudeRequestToAntigravity on adaptive-thinking requests, one per
// output_config.effort value, and checks the thinkingLevel written under
// request.generationConfig.thinkingConfig. The table expects "max" to come
// back as "high" and the other levels to pass through unchanged;
// includeThoughts must be true in every case.
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels(t *testing.T) {
	// Columns: subtest name, effort sent in the request, expected thinkingLevel.
	tests := []struct {
		name string
		effort string
		expected string
	}{
		{"low", "low", "low"},
		{"medium", "medium", "medium"},
		{"high", "high", "high"},
		{"max", "max", "high"},
	}
	for _, tt := range tests {
		// Shadow the loop variable so the subtest closure has its own copy.
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// The effort under test is spliced into the raw JSON request.
			inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"thinking": {"type": "adaptive"},
"output_config": {"effort": "` + tt.effort + `"}
}`)
			output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
			outputStr := string(output)
			thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
			if !thinkingConfig.Exists() {
				t.Fatal("thinkingConfig should exist for adaptive thinking")
			}
			if thinkingConfig.Get("thinkingLevel").String() != tt.expected {
				t.Errorf("Expected thinkingLevel %q, got %q", tt.expected, thinkingConfig.Get("thinkingLevel").String())
			}
			if !thinkingConfig.Get("includeThoughts").Bool() {
				t.Error("includeThoughts should be true")
			}
		})
	}
}
// TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort checks that
// an adaptive-thinking request with no output_config.effort still produces a
// thinkingConfig, with thinkingLevel "high" and includeThoughts set to true.
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort(t *testing.T) {
	// The request enables adaptive thinking but deliberately omits output_config.
	inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"thinking": {"type": "adaptive"}
}`)
	output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
	outputStr := string(output)
	thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
	if !thinkingConfig.Exists() {
		t.Fatal("thinkingConfig should exist for adaptive thinking without effort")
	}
	if thinkingConfig.Get("thinkingLevel").String() != "high" {
		t.Errorf("Expected default thinkingLevel \"high\", got %q", thinkingConfig.Get("thinkingLevel").String())
	}
	if !thinkingConfig.Get("includeThoughts").Bool() {
		t.Error("includeThoughts should be true")
	}
}
@@ -15,6 +15,7 @@ import (
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache" "github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -256,7 +257,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Create the tool use block with unique ID and function details // Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex) data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName) data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -43,23 +43,32 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Process system messages and convert them to input content format. // Process system messages and convert them to input content format.
systemsResult := rootResult.Get("system") systemsResult := rootResult.Get("system")
if systemsResult.IsArray() { if systemsResult.Exists() {
systemResults := systemsResult.Array()
message := `{"type":"message","role":"developer","content":[]}` message := `{"type":"message","role":"developer","content":[]}`
contentIndex := 0 contentIndex := 0
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i] appendSystemText := func(text string) {
systemTypeResult := systemResult.Get("type") if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
if systemTypeResult.String() == "text" { return
text := systemResult.Get("text").String()
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
continue
} }
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text") message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text) message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++ contentIndex++
} }
if systemsResult.Type == gjson.String {
appendSystemText(systemsResult.String())
} else if systemsResult.IsArray() {
systemResults := systemsResult.Array()
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i]
if systemResult.Get("type").String() == "text" {
appendSystemText(systemResult.Get("text").String())
} }
}
}
if contentIndex > 0 { if contentIndex > 0 {
template, _ = sjson.SetRaw(template, "input.-1", message) template, _ = sjson.SetRaw(template, "input.-1", message)
} }
@@ -0,0 +1,89 @@
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
// TestConvertClaudeRequestToCodex_SystemMessageScenarios is a table-driven test
// of how ConvertClaudeRequestToCodex turns the Claude "system" field into a
// leading "developer" message in the output "input" array. It covers a missing
// field, an empty string, a non-empty string, and an array whose
// billing-header entry is expected to be filtered out.
func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
	// wantHasDeveloper: whether input[0] should have role "developer".
	// wantTexts: the expected text of each content item of that message, in order.
	tests := []struct {
		name string
		inputJSON string
		wantHasDeveloper bool
		wantTexts []string
	}{
		{
			name: "No system field",
			inputJSON: `{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "hello"}]
}`,
			wantHasDeveloper: false,
		},
		{
			name: "Empty string system field",
			inputJSON: `{
"model": "claude-3-opus",
"system": "",
"messages": [{"role": "user", "content": "hello"}]
}`,
			wantHasDeveloper: false,
		},
		{
			name: "String system field",
			inputJSON: `{
"model": "claude-3-opus",
"system": "Be helpful",
"messages": [{"role": "user", "content": "hello"}]
}`,
			wantHasDeveloper: true,
			wantTexts: []string{"Be helpful"},
		},
		{
			name: "Array system field with filtered billing header",
			inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "text", "text": "x-anthropic-billing-header: tenant-123"},
{"type": "text", "text": "Block 1"},
{"type": "text", "text": "Block 2"}
],
"messages": [{"role": "user", "content": "hello"}]
}`,
			wantHasDeveloper: true,
			wantTexts: []string{"Block 1", "Block 2"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
			resultJSON := gjson.ParseBytes(result)
			inputs := resultJSON.Get("input").Array()
			// A developer message, when present, is expected at the front of input.
			hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer"
			if hasDeveloper != tt.wantHasDeveloper {
				t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw)
			}
			// Nothing further to check when no developer message is expected.
			if !tt.wantHasDeveloper {
				return
			}
			content := inputs[0].Get("content").Array()
			if len(content) != len(tt.wantTexts) {
				t.Fatalf("got %d system content items, want %d. Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw)
			}
			// Each surviving system text must be an input_text item, in order.
			for i, wantText := range tt.wantTexts {
				if gotType := content[i].Get("type").String(); gotType != "input_text" {
					t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text")
				}
				if gotText := content[i].Get("text").String(); gotText != wantText {
					t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText)
				}
			}
		})
	}
}
@@ -12,6 +12,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -141,7 +142,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false (*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}` template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex) template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String()) template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{ {
// Restore original tool name if shortened // Restore original tool name if shortened
name := itemResult.Get("name").String() name := itemResult.Get("name").String()
@@ -310,7 +311,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
} }
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}` toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String()) toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.Set(toolBlock, "name", name) toolBlock, _ = sjson.Set(toolBlock, "name", name)
inputRaw := "{}" inputRaw := "{}"
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) { if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
@@ -14,6 +14,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -209,7 +210,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Create the tool use block with unique ID and function details // Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))) data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName) data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -224,7 +224,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// Create the tool use block with unique ID and function details // Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex) data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))) data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", clientToolName) data, _ = sjson.Set(data, "content_block.name", clientToolName)
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -343,7 +343,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
clientToolName := util.MapToolName(toolNameMap, upstreamToolName) clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
toolIDCounter++ toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}` toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)) toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)))
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName) toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
inputRaw := "{}" inputRaw := "{}"
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() { if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
@@ -147,21 +147,21 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
content := m.Get("content") content := m.Get("content")
if (role == "system" || role == "developer") && len(arr) > 1 { if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> system_instruction as a user message style // system -> systemInstruction as a user message style
if content.Type == gjson.String { if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user") out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.String()) out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String())
systemPartIndex++ systemPartIndex++
} else if content.IsObject() && content.Get("type").String() == "text" { } else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user") out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.Get("text").String()) out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
systemPartIndex++ systemPartIndex++
} else if content.IsArray() { } else if content.IsArray() {
contents := content.Array() contents := content.Array()
if len(contents) > 0 { if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user") out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
for j := 0; j < len(contents); j++ { for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String()) out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
systemPartIndex++ systemPartIndex++
} }
} }
@@ -26,7 +26,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
if instructions := root.Get("instructions"); instructions.Exists() { if instructions := root.Get("instructions"); instructions.Exists() {
systemInstr := `{"parts":[{"text":""}]}` systemInstr := `{"parts":[{"text":""}]}`
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String()) systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr) out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
} }
// Convert input messages to Gemini contents format // Convert input messages to Gemini contents format
@@ -119,7 +119,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
if strings.EqualFold(itemRole, "system") { if strings.EqualFold(itemRole, "system") {
if contentArray := item.Get("content"); contentArray.Exists() { if contentArray := item.Get("content"); contentArray.Exists() {
systemInstr := "" systemInstr := ""
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() { if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() {
systemInstr = systemInstructionResult.Raw systemInstr = systemInstructionResult.Raw
} else { } else {
systemInstr = `{"parts":[]}` systemInstr = `{"parts":[]}`
@@ -140,7 +140,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
} }
if systemInstr != `{"parts":[]}` { if systemInstr != `{"parts":[]}` {
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr) out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
} }
} }
continue continue
@@ -243,7 +243,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send content_block_start for tool_use // Send content_block_start for tool_use
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}` contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex) contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID) contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name) contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n") results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
} }
@@ -414,7 +414,7 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() { if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool { toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}` toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String()) toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String()) toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
argsStr := util.FixJSON(toolCall.Get("function.arguments").String()) argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
@@ -612,7 +612,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
toolCalls.ForEach(func(_, tc gjson.Result) bool { toolCalls.ForEach(func(_, tc gjson.Result) bool {
hasToolCall = true hasToolCall = true
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}` toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String()) toolUse, _ = sjson.Set(toolUse, "id", util.SanitizeClaudeToolID(tc.Get("id").String()))
toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String())) toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
argsStr := util.FixJSON(tc.Get("function.arguments").String()) argsStr := util.FixJSON(tc.Get("function.arguments").String())
@@ -669,7 +669,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
toolCalls.ForEach(func(_, toolCall gjson.Result) bool { toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
hasToolCall = true hasToolCall = true
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}` toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String()) toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String())) toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String()))
argsStr := util.FixJSON(toolCall.Get("function.arguments").String()) argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
+24
View File
@@ -0,0 +1,24 @@
package util
import (
"fmt"
"regexp"
"sync/atomic"
"time"
)
var (
	// claudeToolUseIDSanitizer matches any single character that is not
	// allowed in a Claude tool_use.id.
	claudeToolUseIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
	// claudeToolUseIDCounter is a process-wide sequence number appended to
	// generated fallback IDs.
	claudeToolUseIDCounter uint64
)

// SanitizeClaudeToolID rewrites id so that it matches Claude's tool_use.id
// pattern ^[a-zA-Z0-9_-]+$.
//
// Every character outside that alphabet is replaced with '_'. When the
// result would be empty (id itself was empty), a fallback of the form
// "toolu_<unix-nanos>_<sequence>" is generated instead.
func SanitizeClaudeToolID(id string) string {
	if cleaned := claudeToolUseIDSanitizer.ReplaceAllString(id, "_"); cleaned != "" {
		return cleaned
	}
	stamp := time.Now().UnixNano()
	seq := atomic.AddUint64(&claudeToolUseIDCounter, 1)
	return fmt.Sprintf("toolu_%d_%d", stamp, seq)
}
+6 -31
View File
@@ -4,50 +4,25 @@
package util package util
import ( import (
"context"
"net"
"net/http" "net/http"
"net/url"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config" "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
) )
// SetProxy configures the provided HTTP client with proxy settings from the configuration. // SetProxy configures the provided HTTP client with proxy settings from the configuration.
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport // It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
// to route requests through the configured proxy server. // to route requests through the configured proxy server.
func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client { func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client {
var transport *http.Transport if cfg == nil || httpClient == nil {
// Attempt to parse the proxy URL from the configuration.
proxyURL, errParse := url.Parse(cfg.ProxyURL)
if errParse == nil {
// Handle different proxy schemes.
if proxyURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication.
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return httpClient return httpClient
} }
// Set up a custom transport using the SOCKS5 dialer.
transport = &http.Transport{ transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if errBuild != nil {
return dialer.Dial(network, addr) log.Errorf("%v", errBuild)
},
} }
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
}
// If a new transport was created, apply it to the HTTP client.
if transport != nil { if transport != nil {
httpClient.Transport = transport httpClient.Transport = transport
} }
+11
View File
@@ -10,6 +10,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
) )
@@ -149,6 +150,16 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []
} }
} }
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth") ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
// For codex auth files, extract plan_type from the JWT id_token.
if provider == "codex" {
if idTokenRaw, ok := metadata["id_token"].(string); ok && strings.TrimSpace(idTokenRaw) != "" {
if claims, errParse := codex.ParseJWTToken(idTokenRaw); errParse == nil && claims != nil {
if pt := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); pt != "" {
a.Attributes["plan_type"] = pt
}
}
}
}
if provider == "gemini-cli" { if provider == "gemini-cli" {
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 { if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
for _, v := range virtuals { for _, v := range virtuals {
@@ -34,6 +34,8 @@ const (
wsTurnStateHeader = "x-codex-turn-state" wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE" wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
wsPayloadLogMaxSize = 2048 wsPayloadLogMaxSize = 2048
wsBodyLogMaxSize = 64 * 1024
wsBodyLogTruncated = "\n[websocket log truncated]\n"
) )
var responsesWebsocketUpgrader = websocket.Upgrader{ var responsesWebsocketUpgrader = websocket.Upgrader{
@@ -825,18 +827,71 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
if builder == nil { if builder == nil {
return return
} }
if builder.Len() >= wsBodyLogMaxSize {
return
}
trimmedPayload := bytes.TrimSpace(payload) trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 { if len(trimmedPayload) == 0 {
return return
} }
if builder.Len() > 0 { if builder.Len() > 0 {
builder.WriteString("\n") if !appendWebsocketLogString(builder, "\n") {
return
} }
builder.WriteString("websocket.") }
builder.WriteString(eventType) if !appendWebsocketLogString(builder, "websocket.") {
builder.WriteString("\n") return
builder.Write(trimmedPayload) }
builder.WriteString("\n") if !appendWebsocketLogString(builder, eventType) {
return
}
if !appendWebsocketLogString(builder, "\n") {
return
}
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
appendWebsocketLogString(builder, wsBodyLogTruncated)
return
}
appendWebsocketLogString(builder, "\n")
}
// appendWebsocketLogString appends value to builder without letting the
// builder grow past wsBodyLogMaxSize bytes.
//
// It returns true only when value was written in full. When the builder is
// nil or already at the limit nothing is written; when value does not fit,
// only the leading bytes that still fit are written. Both cases return false.
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
	if builder == nil {
		return false
	}
	room := wsBodyLogMaxSize - builder.Len()
	switch {
	case room <= 0:
		return false
	case len(value) > room:
		// Partial write: fill the builder exactly up to the cap.
		builder.WriteString(value[:room])
		return false
	default:
		builder.WriteString(value)
		return true
	}
}
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.Write(value)
return true
}
limit := remaining - reserveForSuffix
if limit < 0 {
limit = 0
}
if limit > len(value) {
limit = len(value)
}
builder.Write(value[:limit])
return false
} }
func websocketPayloadEventType(payload []byte) string { func websocketPayloadEventType(payload []byte) string {
@@ -266,6 +266,33 @@ func TestAppendWebsocketEvent(t *testing.T) {
} }
} }
// TestAppendWebsocketEventTruncatesAtLimit checks that a payload as large as
// the whole log budget is cut so the log stays within wsBodyLogMaxSize and
// carries the truncation marker.
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
	oversized := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)

	var logBody strings.Builder
	appendWebsocketEvent(&logBody, "request", oversized)

	body := logBody.String()
	if size := len(body); size > wsBodyLogMaxSize {
		t.Fatalf("body log len = %d, want <= %d", size, wsBodyLogMaxSize)
	}
	if strings.Index(body, wsBodyLogTruncated) < 0 {
		t.Fatalf("expected truncation marker in body log")
	}
}
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
var builder strings.Builder
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
initial := builder.String()
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
if builder.String() != initial {
t.Fatalf("builder grew after reaching limit")
}
}
func TestSetWebsocketRequestBody(t *testing.T) { func TestSetWebsocketRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+3
View File
@@ -287,5 +287,8 @@ func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundl
FileName: fileName, FileName: fileName,
Storage: tokenStorage, Storage: tokenStorage,
Metadata: metadata, Metadata: metadata,
Attributes: map[string]string{
"plan_type": planType,
},
}, nil }, nil
} }
+536 -73
View File
@@ -134,6 +134,7 @@ type Manager struct {
hook Hook hook Hook
mu sync.RWMutex mu sync.RWMutex
auths map[string]*Auth auths map[string]*Auth
scheduler *authScheduler
// providerOffsets tracks per-model provider rotation state for multi-provider routing. // providerOffsets tracks per-model provider rotation state for multi-provider routing.
providerOffsets map[string]int providerOffsets map[string]int
@@ -149,6 +150,9 @@ type Manager struct {
// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix). // Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
apiKeyModelAlias atomic.Value apiKeyModelAlias atomic.Value
// modelPoolOffsets tracks per-auth alias pool rotation state.
modelPoolOffsets map[string]int
// runtimeConfig stores the latest application config for request-time decisions. // runtimeConfig stores the latest application config for request-time decisions.
// It is initialized in NewManager; never Load() before first Store(). // It is initialized in NewManager; never Load() before first Store().
runtimeConfig atomic.Value runtimeConfig atomic.Value
@@ -176,14 +180,59 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
hook: hook, hook: hook,
auths: make(map[string]*Auth), auths: make(map[string]*Auth),
providerOffsets: make(map[string]int), providerOffsets: make(map[string]int),
modelPoolOffsets: make(map[string]int),
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency), refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
} }
// atomic.Value requires non-nil initial value. // atomic.Value requires non-nil initial value.
manager.runtimeConfig.Store(&internalconfig.Config{}) manager.runtimeConfig.Store(&internalconfig.Config{})
manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil)) manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
manager.scheduler = newAuthScheduler(selector)
return manager return manager
} }
// isBuiltInSelector reports whether selector's dynamic type is one of the
// selectors shipped with this package (*RoundRobinSelector or
// *FillFirstSelector). A nil selector yields false.
func isBuiltInSelector(selector Selector) bool {
	if _, isRoundRobin := selector.(*RoundRobinSelector); isRoundRobin {
		return true
	}
	_, isFillFirst := selector.(*FillFirstSelector)
	return isFillFirst
}
// syncSchedulerFromSnapshot rebuilds the scheduler from the given auth
// snapshot. It is a no-op when the manager or its scheduler is nil.
func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) {
	if m != nil && m.scheduler != nil {
		m.scheduler.rebuild(auths)
	}
}
// syncScheduler takes a fresh snapshot of the manager's auths and rebuilds
// the scheduler from it. It is a no-op when the manager or its scheduler is
// nil.
func (m *Manager) syncScheduler() {
	if m != nil && m.scheduler != nil {
		snapshot := m.snapshotAuths()
		m.syncSchedulerFromSnapshot(snapshot)
	}
}
// RefreshSchedulerEntry upserts the auth identified by authID into the
// scheduler again, so that the scheduler's supportedModelSet for it is
// recomputed against the current global model registry.
//
// Call it once models have been registered for a newly added auth: the
// scheduler.upsertAuth performed during Register/Update happens before
// registerModelsForAuth and therefore captures an empty model set.
//
// Nothing happens when the manager or scheduler is nil, authID is empty, or
// no non-nil auth is stored under authID.
func (m *Manager) RefreshSchedulerEntry(authID string) {
	if m == nil || m.scheduler == nil || authID == "" {
		return
	}

	var (
		snapshot *Auth
		found    bool
	)
	// Clone under the read lock; the scheduler is updated after releasing it.
	m.mu.RLock()
	if stored := m.auths[authID]; stored != nil {
		snapshot, found = stored.Clone(), true
	}
	m.mu.RUnlock()

	if !found {
		return
	}
	m.scheduler.upsertAuth(snapshot)
}
func (m *Manager) SetSelector(selector Selector) { func (m *Manager) SetSelector(selector Selector) {
if m == nil { if m == nil {
return return
@@ -194,6 +243,10 @@ func (m *Manager) SetSelector(selector Selector) {
m.mu.Lock() m.mu.Lock()
m.selector = selector m.selector = selector
m.mu.Unlock() m.mu.Unlock()
if m.scheduler != nil {
m.scheduler.setSelector(selector)
m.syncScheduler()
}
} }
// SetStore swaps the underlying persistence store. // SetStore swaps the underlying persistence store.
@@ -251,16 +304,323 @@ func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) strin
if resolved == "" { if resolved == "" {
return "" return ""
} }
// Preserve thinking suffix from the client's requested model unless config already has one. return preserveRequestedModelSuffix(requestedModel, resolved)
requestResult := thinking.ParseSuffix(requestedModel)
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
} }
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
return resolved
// isAPIKeyAuth reports whether auth's account kind, as returned by
// AccountInfo, equals "api_key" (ignoring case and surrounding whitespace).
// A nil auth yields false.
func isAPIKeyAuth(auth *Auth) bool {
	if auth == nil {
		return false
	}
	accountKind, _ := auth.AccountInfo()
	normalized := strings.TrimSpace(accountKind)
	return strings.EqualFold(normalized, "api_key")
}
// isOpenAICompatAPIKeyAuth reports whether auth is an API-key auth that
// belongs to an OpenAI-compatibility provider: either its Provider equals
// "openai-compatibility" (case-insensitive, trimmed) or it carries a
// non-blank "compat_name" attribute.
func isOpenAICompatAPIKeyAuth(auth *Auth) bool {
	switch {
	case !isAPIKeyAuth(auth):
		return false
	case strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility"):
		return true
	case auth.Attributes == nil:
		return false
	default:
		return strings.TrimSpace(auth.Attributes["compat_name"]) != ""
	}
}
// openAICompatProviderKey returns the lower-cased provider key for auth.
//
// The first non-blank value among the "provider_key" attribute, the
// "compat_name" attribute and auth.Provider is used, in that order. A nil
// auth yields the empty string.
func openAICompatProviderKey(auth *Auth) string {
	if auth == nil {
		return ""
	}
	if attrs := auth.Attributes; attrs != nil {
		for _, attrName := range [...]string{"provider_key", "compat_name"} {
			if candidate := strings.TrimSpace(attrs[attrName]); candidate != "" {
				return strings.ToLower(candidate)
			}
		}
	}
	fallback := strings.TrimSpace(auth.Provider)
	return strings.ToLower(fallback)
}
// openAICompatModelPoolKey builds the rotation-state key for a model pool:
// "<auth id>|<provider key>|<base model>", all lower-cased.
//
// The base model is requestedModel with its thinking suffix removed; when
// that is blank the trimmed requestedModel is used as-is. auth must be
// non-nil.
func openAICompatModelPoolKey(auth *Auth, requestedModel string) string {
	modelName := strings.TrimSpace(thinking.ParseSuffix(requestedModel).ModelName)
	if modelName == "" {
		modelName = strings.TrimSpace(requestedModel)
	}
	segments := []string{
		strings.ToLower(strings.TrimSpace(auth.ID)),
		openAICompatProviderKey(auth),
		strings.ToLower(modelName),
	}
	return strings.Join(segments, "|")
}
// nextModelPoolOffset returns the current rotation offset for key, reduced
// modulo size, and advances the stored counter by one.
//
// It returns 0 without touching any state when the manager is nil, size is
// at most 1, or key is blank. The stored counter restarts from 0 once it
// reaches 2_147_483_640. The counter map is guarded by m.mu.
func (m *Manager) nextModelPoolOffset(key string, size int) int {
	if m == nil || size <= 1 {
		return 0
	}
	if key = strings.TrimSpace(key); key == "" {
		return 0
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	if m.modelPoolOffsets == nil {
		m.modelPoolOffsets = make(map[string]int)
	}
	current := m.modelPoolOffsets[key]
	if current >= 2_147_483_640 {
		current = 0
	}
	m.modelPoolOffsets[key] = current + 1
	// size is at least 2 here because of the guard at the top.
	return current % size
}
// rotateStrings returns values rotated left by offset positions.
//
// Slices of length 0 or 1 are returned as-is (same backing array). For
// longer slices a new slice is always returned: an unrotated copy when
// offset <= 0, otherwise values[offset%len:] followed by values[:offset%len].
func rotateStrings(values []string, offset int) []string {
	n := len(values)
	if n <= 1 {
		return values
	}
	shift := 0
	if offset > 0 {
		shift = offset % n
	}
	rotated := make([]string, n)
	head := copy(rotated, values[shift:])
	copy(rotated[head:], values[:shift])
	return rotated
}
// resolveOpenAICompatUpstreamModelPool returns the upstream models that the
// OpenAI-compatibility config entry for auth maps requestedModel to.
//
// It returns nil when the manager is nil, auth is not an OpenAI-compat
// API-key auth, requestedModel is blank, or no matching config entry exists.
// The config entry is looked up from the runtime config using the auth's
// "provider_key" / "compat_name" attributes and its Provider.
func (m *Manager) resolveOpenAICompatUpstreamModelPool(auth *Auth, requestedModel string) []string {
	if m == nil || !isOpenAICompatAPIKeyAuth(auth) {
		return nil
	}
	model := strings.TrimSpace(requestedModel)
	if model == "" {
		return nil
	}

	cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
	if cfg == nil {
		cfg = &internalconfig.Config{}
	}

	var providerKey, compatName string
	if attrs := auth.Attributes; attrs != nil {
		providerKey = strings.TrimSpace(attrs["provider_key"])
		compatName = strings.TrimSpace(attrs["compat_name"])
	}

	if entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider); entry != nil {
		return resolveModelAliasPoolFromConfigModels(model, asModelAliasEntries(entry.Models))
	}
	return nil
}
func preserveRequestedModelSuffix(requestedModel, resolved string) string {
return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel))
}
// executionModelCandidates returns the upstream models to try for routeModel
// on auth. It delegates to prepareExecutionModels.
func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string {
	candidates := m.prepareExecutionModels(auth, routeModel)
	return candidates
}
// prepareExecutionModels resolves routeModel into the ordered list of
// upstream models to attempt on auth.
//
// The route model is first rewritten for the auth and passed through the
// OAuth model alias. If an OpenAI-compat model pool exists for the result,
// the pool is returned: unchanged when it has a single entry, otherwise
// rotated by the next per-key pool offset. Without a pool, the API-key model
// alias is applied and a single-element list is returned, falling back to
// the un-aliased model when the alias resolves to blank.
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
	requested := m.applyOAuthModelAlias(auth, rewriteModelForAuth(routeModel, auth))

	pool := m.resolveOpenAICompatUpstreamModelPool(auth, requested)
	switch len(pool) {
	case 0:
		// No pool configured; resolve through the API-key alias below.
	case 1:
		return pool
	default:
		poolKey := openAICompatModelPoolKey(auth, requested)
		return rotateStrings(pool, m.nextModelPoolOffset(poolKey, len(pool)))
	}

	upstream := m.applyAPIKeyModelAlias(auth, requested)
	if strings.TrimSpace(upstream) == "" {
		upstream = requested
	}
	return []string{upstream}
}
// discardStreamChunks starts a goroutine that receives from ch and throws
// the chunks away until ch is closed. A nil channel is ignored.
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
	if ch == nil {
		return
	}
	drain := func() {
		for {
			if _, open := <-ch; !open {
				return
			}
		}
	}
	go drain()
}
// readStreamBootstrap reads from ch until the first chunk with a non-empty
// Payload arrives, collecting every chunk read on the way.
//
// Returns:
//   - (chunks, false, nil) once a chunk with payload has been read; chunks
//     holds that chunk and any payload-less chunks before it.
//   - (chunks, true, nil) when ch is closed first; chunks holds whatever was
//     read (possibly nothing). A nil ch returns (nil, true, nil).
//   - (nil, false, err) when a chunk carries Err, or when ctx is done
//     (err is ctx.Err()).
//
// A nil ctx is accepted and simply disables cancellation.
func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) {
	if ch == nil {
		return nil, true, nil
	}

	// With a nil ctx this stays nil; receiving from a nil channel never
	// proceeds, so the select below then only ever waits on ch.
	var cancelled <-chan struct{}
	if ctx != nil {
		cancelled = ctx.Done()
	}

	pending := make([]cliproxyexecutor.StreamChunk, 0, 1)
	for {
		select {
		case <-cancelled:
			return nil, false, ctx.Err()
		case next, open := <-ch:
			switch {
			case !open:
				return pending, true, nil
			case next.Err != nil:
				return nil, false, next.Err
			}
			pending = append(pending, next)
			if len(next.Payload) > 0 {
				return pending, false, nil
			}
		}
	}
}
// wrapStreamResult returns a StreamResult whose Chunks channel replays
// buffered and then relays remaining, while reporting the stream's outcome
// for auth through m.MarkResult.
//
// A goroutine owns the output channel and closes it when done. The first
// chunk carrying Err is reported as a failed Result (with the HTTP status
// taken from the error when it is a cliproxyexecutor.StatusError); that
// chunk is still forwarded to the consumer. If every chunk was forwarded and
// none carried Err, a successful Result is reported. When ctx is cancelled
// while a chunk is waiting to be sent, forwarding stops, the rest of
// remaining is drained in the background, and no success is reported.
//
// auth must be non-nil; headers are passed through unchanged.
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
	// Unbuffered: each chunk is handed over only when the consumer receives it.
	out := make(chan cliproxyexecutor.StreamChunk)
	go func() {
		defer close(out)
		// failed records that an error chunk has already been reported, so
		// MarkResult is called at most once for a failure.
		var failed bool
		// forward turns false once ctx cancellation interrupted a send.
		forward := true
		// emit reports the chunk's error (first one only) and then tries to
		// deliver the chunk; it returns false when delivery was abandoned.
		emit := func(chunk cliproxyexecutor.StreamChunk) bool {
			if chunk.Err != nil && !failed {
				failed = true
				rerr := &Error{Message: chunk.Err.Error()}
				if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
					rerr.HTTPStatus = se.StatusCode()
				}
				m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
			}
			if !forward {
				return false
			}
			// Without a context there is nothing to cancel on; block on the send.
			if ctx == nil {
				out <- chunk
				return true
			}
			select {
			case <-ctx.Done():
				forward = false
				return false
			case out <- chunk:
				return true
			}
		}
		// Replay the chunks that were consumed while bootstrapping the stream.
		for _, chunk := range buffered {
			if ok := emit(chunk); !ok {
				discardStreamChunks(remaining)
				return
			}
		}
		// Relay the live stream until the upstream closes it.
		for chunk := range remaining {
			if ok := emit(chunk); !ok {
				// Keep draining in the background after giving up on forwarding.
				discardStreamChunks(remaining)
				return
			}
		}
		if !failed {
			m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true})
		}
	}()
	return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
}
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
if executor == nil {
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
execModels := m.prepareExecutionModels(auth, routeModel)
var lastErr error
for idx, execModel := range execModels {
execReq := req
execReq.Model = execModel
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
if errStream != nil {
if errCtx := ctx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(ctx, result)
if isRequestInvalidError(errStream) {
return nil, errStream
}
lastErr = errStream
continue
}
buffered, closed, bootstrapErr := readStreamBootstrap(ctx, streamResult.Chunks)
if bootstrapErr != nil {
if errCtx := ctx.Err(); errCtx != nil {
discardStreamChunks(streamResult.Chunks)
return nil, errCtx
}
if isRequestInvalidError(bootstrapErr) {
rerr := &Error{Message: bootstrapErr.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(bootstrapErr)
m.MarkResult(ctx, result)
discardStreamChunks(streamResult.Chunks)
return nil, bootstrapErr
}
if idx < len(execModels)-1 {
rerr := &Error{Message: bootstrapErr.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(bootstrapErr)
m.MarkResult(ctx, result)
discardStreamChunks(streamResult.Chunks)
lastErr = bootstrapErr
continue
}
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
}
if closed && len(buffered) == 0 {
emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr}
m.MarkResult(ctx, result)
if idx < len(execModels)-1 {
lastErr = emptyErr
continue
}
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
}
remaining := streamResult.Chunks
if closed {
closedCh := make(chan cliproxyexecutor.StreamChunk)
close(closedCh)
remaining = closedCh
}
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
}
if lastErr == nil {
lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
}
return nil, lastErr
} }
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() { func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
@@ -448,10 +808,14 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
auth.ID = uuid.NewString() auth.ID = uuid.NewString()
} }
auth.EnsureIndex() auth.EnsureIndex()
authClone := auth.Clone()
m.mu.Lock() m.mu.Lock()
m.auths[auth.ID] = auth.Clone() m.auths[auth.ID] = authClone
m.mu.Unlock() m.mu.Unlock()
m.rebuildAPIKeyModelAliasFromRuntimeConfig() m.rebuildAPIKeyModelAliasFromRuntimeConfig()
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
_ = m.persist(ctx, auth) _ = m.persist(ctx, auth)
m.hook.OnAuthRegistered(ctx, auth.Clone()) m.hook.OnAuthRegistered(ctx, auth.Clone())
return auth.Clone(), nil return auth.Clone(), nil
@@ -473,9 +837,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
} }
} }
auth.EnsureIndex() auth.EnsureIndex()
m.auths[auth.ID] = auth.Clone() authClone := auth.Clone()
m.auths[auth.ID] = authClone
m.mu.Unlock() m.mu.Unlock()
m.rebuildAPIKeyModelAliasFromRuntimeConfig() m.rebuildAPIKeyModelAliasFromRuntimeConfig()
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
_ = m.persist(ctx, auth) _ = m.persist(ctx, auth)
m.hook.OnAuthUpdated(ctx, auth.Clone()) m.hook.OnAuthUpdated(ctx, auth.Clone())
return auth.Clone(), nil return auth.Clone(), nil
@@ -484,12 +852,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
// Load resets manager state from the backing store. // Load resets manager state from the backing store.
func (m *Manager) Load(ctx context.Context) error { func (m *Manager) Load(ctx context.Context) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock()
if m.store == nil { if m.store == nil {
m.mu.Unlock()
return nil return nil
} }
items, err := m.store.List(ctx) items, err := m.store.List(ctx)
if err != nil { if err != nil {
m.mu.Unlock()
return err return err
} }
m.auths = make(map[string]*Auth, len(items)) m.auths = make(map[string]*Auth, len(items))
@@ -505,6 +874,8 @@ func (m *Manager) Load(ctx context.Context) error {
cfg = &internalconfig.Config{} cfg = &internalconfig.Config{}
} }
m.rebuildAPIKeyModelAliasLocked(cfg) m.rebuildAPIKeyModelAliasLocked(cfg)
m.mu.Unlock()
m.syncScheduler()
return nil return nil
} }
@@ -634,10 +1005,12 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
} }
models := m.prepareExecutionModels(auth, routeModel)
var authErr error
for _, upstreamModel := range models {
execReq := req execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth) execReq.Model = upstreamModel
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts) resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil { if errExec != nil {
@@ -655,12 +1028,20 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
if isRequestInvalidError(errExec) { if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec return cliproxyexecutor.Response{}, errExec
} }
lastErr = errExec authErr = errExec
continue continue
} }
m.MarkResult(execCtx, result) m.MarkResult(execCtx, result)
return resp, nil return resp, nil
} }
if authErr != nil {
if isRequestInvalidError(authErr) {
return cliproxyexecutor.Response{}, authErr
}
lastErr = authErr
continue
}
}
} }
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) { func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) {
@@ -696,10 +1077,12 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
} }
models := m.prepareExecutionModels(auth, routeModel)
var authErr error
for _, upstreamModel := range models {
execReq := req execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth) execReq.Model = upstreamModel
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil { if errExec != nil {
@@ -717,12 +1100,20 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
if isRequestInvalidError(errExec) { if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec return cliproxyexecutor.Response{}, errExec
} }
lastErr = errExec authErr = errExec
continue continue
} }
m.hook.OnResult(execCtx, result) m.hook.OnResult(execCtx, result)
return resp, nil return resp, nil
} }
if authErr != nil {
if isRequestInvalidError(authErr) {
return cliproxyexecutor.Response{}, authErr
}
lastErr = authErr
continue
}
}
} }
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (*cliproxyexecutor.StreamResult, error) { func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (*cliproxyexecutor.StreamResult, error) {
@@ -758,63 +1149,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
} }
execReq := req streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil { if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil { if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx return nil, errCtx
} }
rerr := &Error{Message: errStream.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
if isRequestInvalidError(errStream) { if isRequestInvalidError(errStream) {
return nil, errStream return nil, errStream
} }
lastErr = errStream lastErr = errStream
continue continue
} }
out := make(chan cliproxyexecutor.StreamChunk) return streamResult, nil
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
forward := true
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
if !forward {
continue
}
if streamCtx == nil {
out <- chunk
continue
}
select {
case <-streamCtx.Done():
forward = false
case out <- chunk:
}
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
return &cliproxyexecutor.StreamResult{
Headers: streamResult.Headers,
Chunks: out,
}, nil
} }
} }
@@ -1245,6 +1591,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
suspendReason := "" suspendReason := ""
clearModelQuota := false clearModelQuota := false
setModelQuota := false setModelQuota := false
var authSnapshot *Auth
m.mu.Lock() m.mu.Lock()
if auth, ok := m.auths[result.AuthID]; ok && auth != nil { if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
@@ -1338,8 +1685,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
} }
_ = m.persist(ctx, auth) _ = m.persist(ctx, auth)
authSnapshot = auth.Clone()
} }
m.mu.Unlock() m.mu.Unlock()
if m.scheduler != nil && authSnapshot != nil {
m.scheduler.upsertAuth(authSnapshot)
}
if clearModelQuota && result.Model != "" { if clearModelQuota && result.Model != "" {
registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model) registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
@@ -1533,18 +1884,22 @@ func statusCodeFromResult(err *Error) int {
} }
// isRequestInvalidError returns true if the error represents a client request // isRequestInvalidError returns true if the error represents a client request
// error that should not be retried. Specifically, it checks for 400 Bad Request // error that should not be retried. Specifically, it treats 400 responses with
// with "invalid_request_error" in the message, indicating the request itself is // "invalid_request_error" and all 422 responses as request-shape failures,
// malformed and switching to a different auth will not help. // where switching auths or pooled upstream models will not help.
func isRequestInvalidError(err error) bool { func isRequestInvalidError(err error) bool {
if err == nil { if err == nil {
return false return false
} }
status := statusCodeFromError(err) status := statusCodeFromError(err)
if status != http.StatusBadRequest { switch status {
case http.StatusBadRequest:
return strings.Contains(err.Error(), "invalid_request_error")
case http.StatusUnprocessableEntity:
return true
default:
return false return false
} }
return strings.Contains(err.Error(), "invalid_request_error")
} }
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) { func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
@@ -1692,7 +2047,29 @@ func (m *Manager) CloseExecutionSession(sessionID string) {
} }
} }
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) { func (m *Manager) useSchedulerFastPath() bool {
if m == nil || m.scheduler == nil {
return false
}
return isBuiltInSelector(m.selector)
}
// shouldRetrySchedulerPick reports whether a failed scheduler pick qualifies
// for a retry. It returns true when err wraps a *modelCooldownError, or wraps
// a non-nil *Error whose Code is "auth_not_found" or "auth_unavailable".
// Any other error, including nil, yields false.
func shouldRetrySchedulerPick(err error) bool {
	if err == nil {
		return false
	}

	var cooling *modelCooldownError
	if errors.As(err, &cooling) {
		return true
	}

	var pickErr *Error
	if errors.As(err, &pickErr) && pickErr != nil {
		switch pickErr.Code {
		case "auth_not_found", "auth_unavailable":
			return true
		}
	}
	return false
}
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
m.mu.RLock() m.mu.RLock()
@@ -1752,7 +2129,38 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
return authCopy, executor, nil return authCopy, executor, nil
} }
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) { func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
if !m.useSchedulerFastPath() {
return m.pickNextLegacy(ctx, provider, model, opts, tried)
}
executor, okExecutor := m.Executor(provider)
if !okExecutor {
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
selected, errPick := m.scheduler.pickSingle(ctx, provider, model, opts, tried)
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
m.syncScheduler()
selected, errPick = m.scheduler.pickSingle(ctx, provider, model, opts, tried)
}
if errPick != nil {
return nil, nil, errPick
}
if selected == nil {
return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
}
authCopy := selected.Clone()
if !selected.indexAssigned {
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, nil
}
func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
providerSet := make(map[string]struct{}, len(providers)) providerSet := make(map[string]struct{}, len(providers))
@@ -1835,6 +2243,58 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
return authCopy, executor, providerKey, nil return authCopy, executor, providerKey, nil
} }
// pickNextMixed selects the next auth across several providers.
//
// When the scheduler fast path is not in use it defers to pickNextMixedLegacy.
// Otherwise it normalizes the provider names (lower-cased, trimmed,
// de-duplicated, and restricted to providers with a registered executor) and
// asks the scheduler for a pick. If that pick fails for a non-empty model with
// an error accepted by shouldRetrySchedulerPick, the scheduler is re-synced
// and the pick is attempted once more.
//
// On success it returns a clone of the selected auth, the executor of the
// chosen provider, and that provider's normalized key.
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
	if !m.useSchedulerFastPath() {
		return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
	}

	candidates := make([]string, 0, len(providers))
	known := make(map[string]struct{}, len(providers))
	for _, raw := range providers {
		key := strings.TrimSpace(strings.ToLower(raw))
		if key == "" {
			continue
		}
		if _, dup := known[key]; dup {
			continue
		}
		if _, registered := m.Executor(key); !registered {
			continue
		}
		known[key] = struct{}{}
		candidates = append(candidates, key)
	}
	if len(candidates) == 0 {
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
	}

	picked, pickedProvider, errPick := m.scheduler.pickMixed(ctx, candidates, model, opts, tried)
	if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
		// Re-sync the scheduler from the manager's state and try exactly once more.
		m.syncScheduler()
		picked, pickedProvider, errPick = m.scheduler.pickMixed(ctx, candidates, model, opts, tried)
	}
	switch {
	case errPick != nil:
		return nil, nil, "", errPick
	case picked == nil:
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
	}

	exec, found := m.Executor(pickedProvider)
	if !found {
		return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
	}

	result := picked.Clone()
	if picked.indexAssigned {
		return result, exec, pickedProvider, nil
	}

	// The picked entry has no index yet: assign one on the manager's own copy
	// (under the write lock) and hand out a clone of that copy instead.
	m.mu.Lock()
	if live := m.auths[result.ID]; live != nil && !live.indexAssigned {
		live.EnsureIndex()
		result = live.Clone()
	}
	m.mu.Unlock()
	return result, exec, pickedProvider, nil
}
func (m *Manager) persist(ctx context.Context, auth *Auth) error { func (m *Manager) persist(ctx context.Context, auth *Auth) error {
if m.store == nil || auth == nil { if m.store == nil || auth == nil {
return nil return nil
@@ -2186,6 +2646,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
current.NextRefreshAfter = now.Add(refreshFailureBackoff) current.NextRefreshAfter = now.Add(refreshFailureBackoff)
current.LastError = &Error{Message: err.Error()} current.LastError = &Error{Message: err.Error()}
m.auths[id] = current m.auths[id] = current
if m.scheduler != nil {
m.scheduler.upsertAuth(current.Clone())
}
} }
m.mu.Unlock() m.mu.Unlock()
return return
@@ -0,0 +1,163 @@
package auth
import (
"context"
"errors"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerProviderTestExecutor is a no-op executor stub for scheduler tests.
// Only Identifier carries information; every other method returns zero values
// and a nil error.
type schedulerProviderTestExecutor struct {
	// provider is the identifier reported by Identifier.
	provider string
}

// Identifier returns the provider name this stub was constructed with.
func (e schedulerProviderTestExecutor) Identifier() string { return e.provider }

// Execute ignores its arguments and returns an empty response with no error.
func (e schedulerProviderTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// ExecuteStream ignores its arguments and returns a nil stream result with no error.
func (e schedulerProviderTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, nil
}

// Refresh returns the auth it was given, unchanged.
func (e schedulerProviderTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

// CountTokens ignores its arguments and returns an empty response with no error.
func (e schedulerProviderTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// HttpRequest ignores its arguments and returns a nil response with no error.
func (e schedulerProviderTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
	return nil, nil
}
// TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration
// checks, for an auth primed through Register and through Register+Update,
// that a model registered afterwards is reported as auth_not_found by the
// scheduler until RefreshSchedulerEntry is called for that auth, after which
// the scheduler picks the auth for the model.
func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) {
	const (
		providerName = "gemini"
		modelID      = "scheduler-refresh-model"
	)
	ctx := context.Background()

	// primeRegister only registers the auth.
	primeRegister := func(manager *Manager, auth *Auth) error {
		_, errRegister := manager.Register(ctx, auth)
		return errRegister
	}
	// primeUpdate registers the auth and then pushes a modified clone through Update.
	primeUpdate := func(manager *Manager, auth *Auth) error {
		if _, errRegister := manager.Register(ctx, auth); errRegister != nil {
			return errRegister
		}
		updated := auth.Clone()
		updated.Metadata = map[string]any{"updated": true}
		_, errUpdate := manager.Update(ctx, updated)
		return errUpdate
	}

	scenarios := []struct {
		name  string
		prime func(*Manager, *Auth) error
	}{
		{name: "register", prime: primeRegister},
		{name: "update", prime: primeUpdate},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			manager := NewManager(nil, &RoundRobinSelector{}, nil)
			auth := &Auth{ID: "refresh-entry-" + sc.name, Provider: providerName}
			if errPrime := sc.prime(manager, auth); errPrime != nil {
				t.Fatalf("prime auth %s: %v", sc.name, errPrime)
			}

			// The model is registered only after the auth reached the scheduler.
			registerSchedulerModels(t, providerName, modelID, auth.ID)

			before, errBefore := manager.scheduler.pickSingle(ctx, providerName, modelID, cliproxyexecutor.Options{}, nil)
			var authErr *Error
			if !errors.As(errBefore, &authErr) || authErr == nil {
				t.Fatalf("pickSingle() before refresh error = %v, want auth_not_found", errBefore)
			}
			if authErr.Code != "auth_not_found" {
				t.Fatalf("pickSingle() before refresh code = %q, want %q", authErr.Code, "auth_not_found")
			}
			if before != nil {
				t.Fatalf("pickSingle() before refresh auth = %v, want nil", before)
			}

			manager.RefreshSchedulerEntry(auth.ID)

			after, errAfter := manager.scheduler.pickSingle(ctx, providerName, modelID, cliproxyexecutor.Options{}, nil)
			if errAfter != nil {
				t.Fatalf("pickSingle() after refresh error = %v", errAfter)
			}
			if after == nil || after.ID != auth.ID {
				t.Fatalf("pickSingle() after refresh auth = %v, want %q", after, auth.ID)
			}
		})
	}
}
// TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError puts one auth
// into a failed state for a model via a 429 result, registers a second auth
// whose model registration happens only after Register, and then asserts that
// a direct scheduler pick returns a modelCooldownError while Manager.pickNext
// succeeds and returns the second auth.
func TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError(t *testing.T) {
	const (
		providerName = "gemini"
		modelID      = "scheduler-cooldown-rebuild-model"
	)
	ctx := context.Background()

	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	manager.RegisterExecutor(schedulerProviderTestExecutor{provider: providerName})

	// First auth: model registered up front, then marked as failed with a 429.
	registerSchedulerModels(t, providerName, modelID, "cooldown-stale-old")
	staleAuth := &Auth{ID: "cooldown-stale-old", Provider: providerName}
	if _, errRegister := manager.Register(ctx, staleAuth); errRegister != nil {
		t.Fatalf("register old auth: %v", errRegister)
	}
	manager.MarkResult(ctx, Result{
		AuthID:   staleAuth.ID,
		Provider: providerName,
		Model:    modelID,
		Success:  false,
		Error:    &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"},
	})

	// Second auth: registered with the manager before its model is registered.
	freshAuth := &Auth{ID: "cooldown-stale-new", Provider: providerName}
	if _, errRegister := manager.Register(ctx, freshAuth); errRegister != nil {
		t.Fatalf("register new auth: %v", errRegister)
	}
	reg := registry.GetGlobalRegistry()
	reg.RegisterClient(freshAuth.ID, providerName, []*registry.ModelInfo{{ID: modelID}})
	t.Cleanup(func() {
		reg.UnregisterClient(freshAuth.ID)
	})

	direct, errDirect := manager.scheduler.pickSingle(ctx, providerName, modelID, cliproxyexecutor.Options{}, nil)
	var cooldownErr *modelCooldownError
	if !errors.As(errDirect, &cooldownErr) {
		t.Fatalf("pickSingle() before sync error = %v, want modelCooldownError", errDirect)
	}
	if direct != nil {
		t.Fatalf("pickSingle() before sync auth = %v, want nil", direct)
	}

	picked, executor, errPick := manager.pickNext(ctx, providerName, modelID, cliproxyexecutor.Options{}, nil)
	if errPick != nil {
		t.Fatalf("pickNext() error = %v", errPick)
	}
	if executor == nil {
		t.Fatal("pickNext() executor = nil")
	}
	if picked == nil || picked.ID != freshAuth.ID {
		t.Fatalf("pickNext() auth = %v, want %q", picked, freshAuth.ID)
	}
}
+57 -13
View File
@@ -80,23 +80,24 @@ func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string
return upstreamModel return upstreamModel
} }
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string { func modelAliasLookupCandidates(requestedModel string) (thinking.SuffixResult, []string) {
requestedModel = strings.TrimSpace(requestedModel) requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" { if requestedModel == "" {
return "" return thinking.SuffixResult{}, nil
} }
if len(models) == 0 {
return ""
}
requestResult := thinking.ParseSuffix(requestedModel) requestResult := thinking.ParseSuffix(requestedModel)
base := requestResult.ModelName base := requestResult.ModelName
if base == "" {
base = requestedModel
}
candidates := []string{base} candidates := []string{base}
if base != requestedModel { if base != requestedModel {
candidates = append(candidates, requestedModel) candidates = append(candidates, requestedModel)
} }
return requestResult, candidates
}
preserveSuffix := func(resolved string) string { func preserveResolvedModelSuffix(resolved string, requestResult thinking.SuffixResult) string {
resolved = strings.TrimSpace(resolved) resolved = strings.TrimSpace(resolved)
if resolved == "" { if resolved == "" {
return "" return ""
@@ -110,23 +111,66 @@ func resolveModelAliasFromConfigModels(requestedModel string, models []modelAlia
return resolved return resolved
} }
func resolveModelAliasPoolFromConfigModels(requestedModel string, models []modelAliasEntry) []string {
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return nil
}
if len(models) == 0 {
return nil
}
requestResult, candidates := modelAliasLookupCandidates(requestedModel)
if len(candidates) == 0 {
return nil
}
out := make([]string, 0)
seen := make(map[string]struct{})
for i := range models { for i := range models {
name := strings.TrimSpace(models[i].GetName()) name := strings.TrimSpace(models[i].GetName())
alias := strings.TrimSpace(models[i].GetAlias()) alias := strings.TrimSpace(models[i].GetAlias())
for _, candidate := range candidates { for _, candidate := range candidates {
if candidate == "" { if candidate == "" || alias == "" || !strings.EqualFold(alias, candidate) {
continue continue
} }
if alias != "" && strings.EqualFold(alias, candidate) { resolved := candidate
if name != "" { if name != "" {
return preserveSuffix(name) resolved = name
} }
return preserveSuffix(candidate) resolved = preserveResolvedModelSuffix(resolved, requestResult)
key := strings.ToLower(strings.TrimSpace(resolved))
if key == "" {
break
} }
if name != "" && strings.EqualFold(name, candidate) { if _, exists := seen[key]; exists {
return preserveSuffix(name) break
}
seen[key] = struct{}{}
out = append(out, resolved)
break
} }
} }
if len(out) > 0 {
return out
}
for i := range models {
name := strings.TrimSpace(models[i].GetName())
for _, candidate := range candidates {
if candidate == "" || name == "" || !strings.EqualFold(name, candidate) {
continue
}
return []string{preserveResolvedModelSuffix(name, requestResult)}
}
}
return nil
}
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
resolved := resolveModelAliasPoolFromConfigModels(requestedModel, models)
if len(resolved) > 0 {
return resolved[0]
} }
return "" return ""
} }
@@ -0,0 +1,419 @@
package auth
import (
"context"
"net/http"
"sync"
"testing"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// openAICompatPoolExecutor is a scriptable test executor. It records the
// upstream model of every Execute, CountTokens and ExecuteStream call and can
// be configured with per-model errors or stream chunks.
type openAICompatPoolExecutor struct {
	// id is the value returned by Identifier.
	id string

	// mu is held by the methods below while they read or append to the
	// recorded-model slices and look up the scripted maps.
	mu sync.Mutex
	// executeModels, countModels and streamModels hold req.Model of each
	// Execute, CountTokens and ExecuteStream call respectively, in call order.
	executeModels []string
	countModels   []string
	streamModels  []string
	// executeErrors and countErrors map an upstream model to the error that
	// Execute / CountTokens returns for it.
	executeErrors map[string]error
	countErrors   map[string]error
	// streamFirstErrors maps an upstream model to an error that ExecuteStream
	// delivers as the only chunk of the stream.
	streamFirstErrors map[string]error
	// streamPayloads maps an upstream model to the exact chunks ExecuteStream
	// emits for it; models without an entry get a single chunk whose payload
	// is the model name.
	streamPayloads map[string][]cliproxyexecutor.StreamChunk
}
// Identifier returns the configured executor id.
func (e *openAICompatPoolExecutor) Identifier() string { return e.id }
// Execute records req.Model and then either returns the error scripted for
// that model in executeErrors or a response whose payload is the model name.
func (e *openAICompatPoolExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	_, _, _ = ctx, auth, opts

	e.mu.Lock()
	e.executeModels = append(e.executeModels, req.Model)
	scripted := e.executeErrors[req.Model]
	e.mu.Unlock()

	if scripted == nil {
		return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
	}
	return cliproxyexecutor.Response{}, scripted
}
// ExecuteStream records req.Model and returns a pre-filled, already closed
// stream together with an "X-Model" header naming the model. The stream holds:
//   - a single error chunk when streamFirstErrors has an entry for the model;
//   - otherwise the chunks scripted in streamPayloads, if the model has an entry;
//   - otherwise a single chunk whose payload is the model name.
//
// The returned error is always nil; failures are delivered through the stream.
func (e *openAICompatPoolExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	_, _, _ = ctx, auth, opts

	e.mu.Lock()
	e.streamModels = append(e.streamModels, req.Model)
	firstErr := e.streamFirstErrors[req.Model]
	scripted, isScripted := e.streamPayloads[req.Model]
	// Snapshot the scripted chunks while the lock is held.
	queued := append([]cliproxyexecutor.StreamChunk(nil), scripted...)
	e.mu.Unlock()

	// Capacity is at least one so that every send below completes without a reader.
	out := make(chan cliproxyexecutor.StreamChunk, max(1, len(queued)))
	switch {
	case firstErr != nil:
		out <- cliproxyexecutor.StreamChunk{Err: firstErr}
	case !isScripted:
		out <- cliproxyexecutor.StreamChunk{Payload: []byte(req.Model)}
	default:
		for _, chunk := range queued {
			out <- chunk
		}
	}
	close(out)

	return &cliproxyexecutor.StreamResult{
		Headers: http.Header{"X-Model": {req.Model}},
		Chunks:  out,
	}, nil
}
// Refresh returns the auth it was given, unchanged, and never fails.
func (e *openAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}
// CountTokens records req.Model and then either returns the error scripted
// for that model in countErrors or a response whose payload is the model name.
func (e *openAICompatPoolExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	_, _, _ = ctx, auth, opts

	e.mu.Lock()
	e.countModels = append(e.countModels, req.Model)
	scripted := e.countErrors[req.Model]
	e.mu.Unlock()

	if scripted == nil {
		return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
	}
	return cliproxyexecutor.Response{}, scripted
}
// HttpRequest is not supported by this test executor; it always returns a nil
// response and an *Error carrying http.StatusNotImplemented.
func (e *openAICompatPoolExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
	_, _, _ = ctx, auth, req
	notImplemented := &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
	return nil, notImplemented
}
// ExecuteModels returns a copy of the models seen by Execute, in call order.
// The result is always a non-nil slice that the caller may modify freely.
func (e *openAICompatPoolExecutor) ExecuteModels() []string {
	e.mu.Lock()
	defer e.mu.Unlock()
	snapshot := make([]string, 0, len(e.executeModels))
	return append(snapshot, e.executeModels...)
}
// CountModels returns a copy of the models seen by CountTokens, in call order.
// The result is always a non-nil slice that the caller may modify freely.
func (e *openAICompatPoolExecutor) CountModels() []string {
	e.mu.Lock()
	defer e.mu.Unlock()
	snapshot := make([]string, 0, len(e.countModels))
	return append(snapshot, e.countModels...)
}
// StreamModels returns a copy of the models seen by ExecuteStream, in call
// order. The result is always a non-nil slice that the caller may modify freely.
func (e *openAICompatPoolExecutor) StreamModels() []string {
	e.mu.Lock()
	defer e.mu.Unlock()
	snapshot := make([]string, 0, len(e.streamModels))
	return append(snapshot, e.streamModels...)
}
// newOpenAICompatPoolTestManager builds a Manager for alias-pool tests.
//
// The manager is configured with one OpenAI-compatibility entry named "pool"
// carrying models, has executor registered (a default "pool" executor is
// created when executor is nil), and owns a single active auth for provider
// "pool". The alias is registered for that auth in the global model registry
// and unregistered again via t.Cleanup. Registration failure aborts the test.
func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager {
	t.Helper()

	const poolName = "pool"

	m := NewManager(nil, nil, nil)
	m.SetConfig(&internalconfig.Config{
		OpenAICompatibility: []internalconfig.OpenAICompatibility{{
			Name:   poolName,
			Models: models,
		}},
	})

	if executor == nil {
		executor = &openAICompatPoolExecutor{id: poolName}
	}
	m.RegisterExecutor(executor)

	auth := &Auth{
		ID:       "pool-auth-" + t.Name(),
		Provider: poolName,
		Status:   StatusActive,
		Attributes: map[string]string{
			"api_key":      "test-key",
			"compat_name":  poolName,
			"provider_key": poolName,
		},
	}
	if _, err := m.Register(context.Background(), auth); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	// Advertise the alias for this auth in the global registry for the test's lifetime.
	globalRegistry := registry.GetGlobalRegistry()
	globalRegistry.RegisterClient(auth.ID, poolName, []*registry.ModelInfo{{ID: alias}})
	t.Cleanup(func() {
		globalRegistry.UnregisterClient(auth.ID)
	})

	return m
}
// TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest asserts
// that a 422 from the first pooled model is returned by ExecuteCount as-is and
// that the second pooled model is never attempted.
func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
	const (
		alias      = "claude-opus-4.66"
		firstModel = "qwen3.5-plus"
	)
	invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
	executor := &openAICompatPoolExecutor{
		id:          "pool",
		countErrors: map[string]error{firstModel: invalidErr},
	}
	pool := []internalconfig.OpenAICompatibilityModel{
		{Name: firstModel, Alias: alias},
		{Name: "glm-5", Alias: alias},
	}
	m := newOpenAICompatPoolTestManager(t, alias, pool, executor)

	request := cliproxyexecutor.Request{Model: alias}
	_, err := m.ExecuteCount(context.Background(), []string{"pool"}, request, cliproxyexecutor.Options{})
	if err == nil || err.Error() != invalidErr.Error() {
		t.Fatalf("execute count error = %v, want %v", err, invalidErr)
	}

	calls := executor.CountModels()
	if len(calls) != 1 || calls[0] != firstModel {
		t.Fatalf("count calls = %v, want only first invalid model", calls)
	}
}
// TestResolveModelAliasPoolFromConfigModels asserts that an alias shared by
// three config models resolves to all three upstream names, in config order,
// each carrying the "(8192)" suffix of the requested model.
func TestResolveModelAliasPoolFromConfigModels(t *testing.T) {
	const alias = "claude-opus-4.66"
	entries := []modelAliasEntry{
		internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: alias},
		internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: alias},
		internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: alias},
	}
	want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"}

	got := resolveModelAliasPoolFromConfigModels(alias+"(8192)", entries)
	if len(got) != len(want) {
		t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got)
	}
	for i, expected := range want {
		if got[i] != expected {
			t.Fatalf("pool[%d] = %q, want %q", i, got[i], expected)
		}
	}
}
// TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth issues three
// successful Execute calls against a two-model alias pool and asserts the
// upstream models were used in the order first, second, first.
func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
	const alias = "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{id: "pool"}
	pool := []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}
	m := newOpenAICompatPoolTestManager(t, alias, pool, executor)

	// One Execute call per expected entry (three in total).
	want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"}
	for i := range want {
		resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
		if err != nil {
			t.Fatalf("execute %d: %v", i, err)
		}
		if len(resp.Payload) == 0 {
			t.Fatalf("execute %d returned empty payload", i)
		}
	}

	got := executor.ExecuteModels()
	if len(got) != len(want) {
		t.Fatalf("execute calls = %v, want %v", got, want)
	}
	for i, expected := range want {
		if got[i] != expected {
			t.Fatalf("execute call %d model = %q, want %q", i, got[i], expected)
		}
	}
}
// TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest expects a 400 from the first pooled
// model to be returned to the caller without the second pooled model being attempted.
func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
	alias := "claude-opus-4.66"
	badRequest := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}

	pool := &openAICompatPoolExecutor{id: "pool"}
	pool.executeErrors = map[string]error{"qwen3.5-plus": badRequest}

	manager := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, pool)

	if _, err := manager.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}); err == nil || err.Error() != badRequest.Error() {
		t.Fatalf("execute error = %v, want %v", err, badRequest)
	}

	// Only the failing model may appear in the call log.
	if calls := pool.ExecuteModels(); len(calls) != 1 || calls[0] != "qwen3.5-plus" {
		t.Fatalf("execute calls = %v, want only first invalid model", calls)
	}
}
// TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth makes the first pooled model
// fail with 429 and expects Execute to succeed with the payload produced for the second model,
// after having called the two upstream models in pool order.
func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) {
	alias := "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{
		id:            "pool",
		executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)
	resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err != nil {
		t.Fatalf("execute: %v", err)
	}
	if string(resp.Payload) != "glm-5" {
		t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
	}
	got := executor.ExecuteModels()
	want := []string{"qwen3.5-plus", "glm-5"}
	// Guard the indexed comparison below: a short call log must fail the test with a readable
	// message instead of panicking with an index-out-of-range.
	if len(got) < len(want) {
		t.Fatalf("execute calls = %v, want at least %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
		}
	}
}
// TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap configures the first
// pooled model with an empty chunk list and expects the stream to carry the second model's
// payload, with both upstream models called in pool order.
func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *testing.T) {
	alias := "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{
		id: "pool",
		streamPayloads: map[string][]cliproxyexecutor.StreamChunk{
			"qwen3.5-plus": {},
		},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)
	streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err != nil {
		t.Fatalf("execute stream: %v", err)
	}
	// Drain the stream and concatenate every chunk payload.
	var payload []byte
	for chunk := range streamResult.Chunks {
		if chunk.Err != nil {
			t.Fatalf("unexpected stream error: %v", chunk.Err)
		}
		payload = append(payload, chunk.Payload...)
	}
	if string(payload) != "glm-5" {
		t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
	}
	got := executor.StreamModels()
	want := []string{"qwen3.5-plus", "glm-5"}
	// Guard the indexed comparison below: a short call log must fail the test with a readable
	// message instead of panicking with an index-out-of-range.
	if len(got) < len(want) {
		t.Fatalf("stream calls = %v, want at least %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
		}
	}
}
// TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte makes the first pooled
// model fail with 429 on its first stream result and expects the stream, and its X-Model header,
// to come from the second pooled model.
func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *testing.T) {
	alias := "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{
		id:                "pool",
		streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)
	streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err != nil {
		t.Fatalf("execute stream: %v", err)
	}
	// Drain the stream and concatenate every chunk payload.
	var payload []byte
	for chunk := range streamResult.Chunks {
		if chunk.Err != nil {
			t.Fatalf("unexpected stream error: %v", chunk.Err)
		}
		payload = append(payload, chunk.Payload...)
	}
	if string(payload) != "glm-5" {
		t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
	}
	got := executor.StreamModels()
	want := []string{"qwen3.5-plus", "glm-5"}
	// Guard the indexed comparison below: a short call log must fail the test with a readable
	// message instead of panicking with an index-out-of-range.
	if len(got) < len(want) {
		t.Fatalf("stream calls = %v, want at least %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
		}
	}
	// The response headers must describe the model that actually served the stream.
	if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" {
		t.Fatalf("header X-Model = %q, want %q", gotHeader, "glm-5")
	}
}
// TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest expects a 422 reported by
// the first pooled model to be returned from ExecuteStream without trying the second model.
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
	alias := "claude-opus-4.66"
	unprocessable := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}

	pool := &openAICompatPoolExecutor{id: "pool"}
	pool.streamFirstErrors = map[string]error{"qwen3.5-plus": unprocessable}

	manager := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, pool)

	if _, err := manager.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{}); err == nil || err.Error() != unprocessable.Error() {
		t.Fatalf("execute stream error = %v, want %v", err, unprocessable)
	}

	// Only the failing model may appear in the call log.
	if calls := pool.StreamModels(); len(calls) != 1 || calls[0] != "qwen3.5-plus" {
		t.Fatalf("stream calls = %v, want only first invalid model", calls)
	}
}
// TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth issues two ExecuteCount calls
// against a two-model alias pool and expects each call to reach a different upstream model, in
// pool order.
func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
	alias := "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{id: "pool"}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)
	for i := 0; i < 2; i++ {
		resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
		if err != nil {
			t.Fatalf("execute count %d: %v", i, err)
		}
		if len(resp.Payload) == 0 {
			t.Fatalf("execute count %d returned empty payload", i)
		}
	}
	got := executor.CountModels()
	want := []string{"qwen3.5-plus", "glm-5"}
	// Guard the indexed comparison below: a short call log must fail the test with a readable
	// message instead of panicking with an index-out-of-range.
	if len(got) < len(want) {
		t.Fatalf("count calls = %v, want at least %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
		}
	}
}
// TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap expects a 400 raised on
// the first stream result to be returned as the very same error value, with a nil stream result
// and no attempt against the second pooled model.
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) {
	alias := "claude-opus-4.66"
	badRequest := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}

	pool := &openAICompatPoolExecutor{id: "pool"}
	pool.streamFirstErrors = map[string]error{"qwen3.5-plus": badRequest}

	manager := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, pool)

	result, err := manager.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	switch {
	case err == nil:
		t.Fatal("expected invalid request error")
	case err != badRequest:
		// Identity comparison: the executor's error must come back unwrapped.
		t.Fatalf("error = %v, want %v", err, badRequest)
	case result != nil:
		t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", result)
	}

	calls := pool.StreamModels()
	if len(calls) != 1 || calls[0] != "qwen3.5-plus" {
		t.Fatalf("stream calls = %v, want only first upstream model", calls)
	}
}
+904
View File
@@ -0,0 +1,904 @@
package auth
import (
"context"
"sort"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerStrategy identifies which built-in routing semantics the scheduler should apply.
// The value is derived from the configured Selector by selectorStrategy.
type schedulerStrategy int

const (
	// schedulerStrategyCustom is reported for any selector that is not one of the built-in
	// selector types.
	schedulerStrategyCustom schedulerStrategy = iota
	// schedulerStrategyRoundRobin is reported for a nil selector or a *RoundRobinSelector.
	schedulerStrategyRoundRobin
	// schedulerStrategyFillFirst is reported for a *FillFirstSelector.
	schedulerStrategyFillFirst
)
// scheduledState describes how an auth currently participates in a model shard.
// The state is derived from isAuthBlockedForModel whenever an entry is upserted or re-evaluated.
type scheduledState int

const (
	// scheduledStateReady marks an auth that is not blocked and is listed in the ready views.
	scheduledStateReady scheduledState = iota
	// scheduledStateCooldown marks an auth blocked with blockReasonCooldown.
	scheduledStateCooldown
	// scheduledStateBlocked marks an auth blocked for any reason other than cooldown or disabled.
	scheduledStateBlocked
	// scheduledStateDisabled marks an auth blocked with blockReasonDisabled; such entries appear
	// in neither the ready views nor the blocked queue.
	scheduledStateDisabled
)
// authScheduler keeps the incremental provider/model scheduling state used by Manager.
// Methods without the "Locked" suffix acquire mu themselves; "Locked" methods expect the caller
// to hold it already.
type authScheduler struct {
	// mu guards every field below.
	mu sync.Mutex
	// strategy is the built-in selection behaviour derived from the active selector.
	strategy schedulerStrategy
	// providers maps a trimmed, lower-cased provider name to its scheduling state.
	providers map[string]*providerScheduler
	// authProviders maps a trimmed auth ID to the provider key it is currently filed under.
	authProviders map[string]string
	// mixedCursors stores, per "provider,provider,...:model" key, the provider index at which
	// the next mixed-provider round-robin pick starts.
	mixedCursors map[string]int
}
// providerScheduler stores auth metadata and model shards for a single provider.
type providerScheduler struct {
	// providerKey is the normalized provider name this scheduler was created for.
	providerKey string
	// auths maps auth ID to the scheduling metadata captured at the last upsert.
	auths map[string]*scheduledAuthMeta
	// modelShards maps a canonical model key to its shard; shards are created lazily the first
	// time a model is requested (see ensureModelLocked).
	modelShards map[string]*modelScheduler
}
// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot.
// A new value is built by buildScheduledAuthMeta on every upsert rather than mutated in place.
type scheduledAuthMeta struct {
	// auth is the snapshot the remaining fields were derived from.
	auth *Auth
	// providerKey is the trimmed, lower-cased auth.Provider.
	providerKey string
	// priority is the result of authPriority(auth); larger values are tried first.
	priority int
	// virtualParent is the trimmed "gemini_virtual_parent" attribute, or "" when absent.
	virtualParent string
	// websocketEnabled is the result of authWebsocketsEnabled(auth).
	websocketEnabled bool
	// supportedModelSet holds the canonical model keys the global registry listed for the auth
	// when the metadata was built; nil when the registry listed none.
	supportedModelSet map[string]struct{}
}
// modelScheduler tracks ready and blocked auths for one provider/model combination.
// priorityOrder, readyByPriority and blocked are indexes derived from entries by
// rebuildIndexesLocked.
type modelScheduler struct {
	// modelKey is the canonical model key this shard serves.
	modelKey string
	// entries maps auth ID to its runtime scheduling state; it is the source of truth.
	entries map[string]*scheduledAuth
	// priorityOrder lists the priorities that currently have ready auths, highest first.
	priorityOrder []int
	// readyByPriority holds the ready views for each priority in priorityOrder.
	readyByPriority map[int]*readyBucket
	// blocked holds the cooldown and blocked entries, ordered by next retry time.
	blocked cooldownQueue
}
// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard.
type scheduledAuth struct {
	// meta is the scheduling metadata from the most recent upsert.
	meta *scheduledAuthMeta
	// auth is the auth snapshot taken from meta at the most recent upsert.
	auth *Auth
	// state is the availability computed for this shard's model.
	state scheduledState
	// nextRetryAt is the retry time reported by isAuthBlockedForModel for cooldown and blocked
	// entries; it is the zero time for ready and disabled entries.
	nextRetryAt time.Time
}
// readyBucket keeps the ready views for one priority level.
type readyBucket struct {
	// all covers every ready auth at this priority.
	all readyView
	// ws covers only the ready auths whose metadata has websocketEnabled set.
	ws readyView
}
// readyView holds the selection order for flat or grouped round-robin traversal.
// The grouped fields (parentOrder, parentCursor, children) are only populated when every entry
// has a virtual parent and at least two distinct parents exist; otherwise only flat is used.
type readyView struct {
	// flat lists every entry of the view in selection order.
	flat []*scheduledAuth
	// cursor is the position in flat at which the next flat round-robin pick starts.
	cursor int
	// parentOrder lists the distinct virtual parents in sorted order.
	parentOrder []string
	// parentCursor is the position in parentOrder at which the next grouped pick starts.
	parentCursor int
	// children maps each virtual parent to the entries that belong to it.
	children map[string]*childBucket
}
// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths.
type childBucket struct {
	// items lists the entries that share one virtual parent.
	items []*scheduledAuth
	// cursor is the position in items at which the next pick within this parent starts.
	cursor int
}
// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds.
// Entries without a retry time sort last and ties are broken by auth ID (see
// rebuildIndexesLocked).
type cooldownQueue []*scheduledAuth
// newAuthScheduler returns a scheduler with no providers or auths registered whose strategy is
// derived from selector.
func newAuthScheduler(selector Selector) *authScheduler {
	scheduler := new(authScheduler)
	scheduler.strategy = selectorStrategy(selector)
	scheduler.providers = map[string]*providerScheduler{}
	scheduler.authProviders = map[string]string{}
	scheduler.mixedCursors = map[string]int{}
	return scheduler
}
// selectorStrategy translates a selector into the built-in strategy the scheduler emulates.
// A nil selector and *RoundRobinSelector both mean round-robin, *FillFirstSelector means
// fill-first, and every other implementation is reported as custom.
func selectorStrategy(selector Selector) schedulerStrategy {
	if selector == nil {
		return schedulerStrategyRoundRobin
	}
	if _, isFillFirst := selector.(*FillFirstSelector); isFillFirst {
		return schedulerStrategyFillFirst
	}
	if _, isRoundRobin := selector.(*RoundRobinSelector); isRoundRobin {
		return schedulerStrategyRoundRobin
	}
	return schedulerStrategyCustom
}
// setSelector switches the scheduler to the strategy implied by selector and empties the
// mixed-provider cursor table. It is a no-op on a nil scheduler.
func (s *authScheduler) setSelector(selector Selector) {
	if s == nil {
		return
	}
	next := selectorStrategy(selector)

	s.mu.Lock()
	defer s.mu.Unlock()
	s.strategy = next
	clear(s.mixedCursors)
}
// rebuild discards all provider, auth and cursor state and repopulates the scheduler from the
// supplied auth snapshot, evaluating every auth against a single timestamp.
func (s *authScheduler) rebuild(auths []*Auth) {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()

	s.providers, s.authProviders, s.mixedCursors =
		map[string]*providerScheduler{}, map[string]string{}, map[string]int{}

	snapshotTime := time.Now()
	for i := range auths {
		s.upsertAuthLocked(auths[i], snapshotTime)
	}
}
// upsertAuth takes the scheduler lock and synchronizes a single auth into the scheduler state,
// using the current time to evaluate its availability. It is a no-op on a nil scheduler.
func (s *authScheduler) upsertAuth(auth *Auth) {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now()
	s.upsertAuthLocked(auth, now)
}
// removeAuth takes the scheduler lock and drops the auth with the given ID from its provider and
// from every model shard of that provider. Blank IDs and a nil scheduler are ignored.
func (s *authScheduler) removeAuth(authID string) {
	id := strings.TrimSpace(authID)
	if s == nil || id == "" {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.removeAuthLocked(id)
}
// pickSingle chooses the next auth for a request that targets exactly one provider.
// Auths listed in tried are skipped, and when opts.Metadata pins an auth ID only that auth is
// eligible. Websocket-enabled auths are preferred only for "codex" requests that arrive over a
// downstream websocket without a pinned auth. When nothing can be picked the returned error
// reports either a missing auth, an unavailable auth, or a model cooldown.
func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) {
	notFound := func() (*Auth, error) {
		return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	if s == nil {
		return notFound()
	}

	providerKey := strings.ToLower(strings.TrimSpace(provider))
	modelKey := canonicalModelKey(model)
	pinned := pinnedAuthIDFromMetadata(opts.Metadata)
	wantWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinned == ""

	s.mu.Lock()
	defer s.mu.Unlock()

	state := s.providers[providerKey]
	if state == nil {
		return notFound()
	}
	shard := state.ensureModelLocked(modelKey, time.Now())
	if shard == nil {
		return notFound()
	}

	// eligible filters out placeholders, auths other than the pinned one, and auths that were
	// already attempted for this request.
	eligible := func(candidate *scheduledAuth) bool {
		if candidate == nil || candidate.auth == nil {
			return false
		}
		if pinned != "" && candidate.auth.ID != pinned {
			return false
		}
		_, attempted := tried[candidate.auth.ID]
		return !attempted
	}

	if auth := shard.pickReadyLocked(wantWebsocket, s.strategy, eligible); auth != nil {
		return auth, nil
	}
	return nil, shard.unavailableErrorLocked(provider, model, eligible)
}
// pickMixed returns the next auth and provider for a mixed-provider request.
//
// Providers are normalized (trimmed, lower-cased, de-duplicated) before use. When opts.Metadata
// pins an auth ID, only that auth is considered and it must be filed under one of the requested
// providers. Otherwise the highest priority that still has a ready, untried auth in any provider
// is determined first and the pick is restricted to that priority: fill-first walks the
// providers in the order given, every other strategy rotates across providers using a cursor
// stored per provider-list/model combination. Websocket preference is never applied here.
//
// The returned string is the normalized provider key that owns the picked auth.
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
	if s == nil {
		return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	normalized := normalizeProviderKeys(providers)
	if len(normalized) == 0 {
		return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
	modelKey := canonicalModelKey(model)
	s.mu.Lock()
	defer s.mu.Unlock()
	// Pinned path: resolve the provider that owns the pinned auth and pick only within it.
	if pinnedAuthID != "" {
		providerKey := s.authProviders[pinnedAuthID]
		if providerKey == "" || !containsProvider(normalized, providerKey) {
			return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		providerState := s.providers[providerKey]
		if providerState == nil {
			return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		shard := providerState.ensureModelLocked(modelKey, time.Now())
		// Accept only the pinned auth, and only if it has not been tried yet.
		predicate := func(entry *scheduledAuth) bool {
			if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID {
				return false
			}
			if len(tried) == 0 {
				return true
			}
			_, ok := tried[pinnedAuthID]
			return !ok
		}
		if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil {
			return picked, providerKey, nil
		}
		return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
	}
	predicate := triedPredicate(tried)
	// candidateShards is index-aligned with normalized; slots stay nil for unknown providers.
	candidateShards := make([]*modelScheduler, len(normalized))
	bestPriority := 0
	hasCandidate := false
	now := time.Now()
	// First pass: find the highest priority with a ready, untried auth across all providers.
	for providerIndex, providerKey := range normalized {
		providerState := s.providers[providerKey]
		if providerState == nil {
			continue
		}
		shard := providerState.ensureModelLocked(modelKey, now)
		candidateShards[providerIndex] = shard
		if shard == nil {
			continue
		}
		priorityReady, okPriority := shard.highestReadyPriorityLocked(false, predicate)
		if !okPriority {
			continue
		}
		if !hasCandidate || priorityReady > bestPriority {
			bestPriority = priorityReady
			hasCandidate = true
		}
	}
	if !hasCandidate {
		return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
	}
	// Fill-first: take the first provider, in request order, that has a match at bestPriority.
	if s.strategy == schedulerStrategyFillFirst {
		for providerIndex, providerKey := range normalized {
			shard := candidateShards[providerIndex]
			if shard == nil {
				continue
			}
			picked := shard.pickReadyAtPriorityLocked(false, bestPriority, s.strategy, predicate)
			if picked != nil {
				return picked, providerKey, nil
			}
		}
		return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
	}
	// Round-robin across providers: start at the stored cursor and wrap around once.
	cursorKey := strings.Join(normalized, ",") + ":" + modelKey
	start := 0
	if len(normalized) > 0 {
		start = s.mixedCursors[cursorKey] % len(normalized)
	}
	for offset := 0; offset < len(normalized); offset++ {
		providerIndex := (start + offset) % len(normalized)
		providerKey := normalized[providerIndex]
		shard := candidateShards[providerIndex]
		if shard == nil {
			continue
		}
		picked := shard.pickReadyAtPriorityLocked(false, bestPriority, schedulerStrategyRoundRobin, predicate)
		if picked == nil {
			continue
		}
		// Remember the provider after the one just used so the next request starts there.
		s.mixedCursors[cursorKey] = providerIndex + 1
		return picked, providerKey, nil
	}
	return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
// mixedUnavailableErrorLocked builds the error returned when a mixed-provider pick finds no
// usable auth. It aggregates the untried candidates of every listed provider: no candidates at
// all yields "auth_not_found", all candidates cooling down (with a known retry time) yields a
// model cooldown error for the earliest retry, and anything else yields "auth_unavailable".
func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error {
	now := time.Now()
	modelKey := canonicalModelKey(model)
	notTried := triedPredicate(tried)

	var (
		candidates int
		cooling    int
		soonest    time.Time
	)
	for i := range providers {
		state := s.providers[providers[i]]
		if state == nil {
			continue
		}
		shard := state.ensureModelLocked(modelKey, now)
		if shard == nil {
			continue
		}
		shardTotal, shardCooling, shardRetryAt := shard.availabilitySummaryLocked(notTried)
		candidates += shardTotal
		cooling += shardCooling
		if shardRetryAt.IsZero() {
			continue
		}
		if soonest.IsZero() || shardRetryAt.Before(soonest) {
			soonest = shardRetryAt
		}
	}

	switch {
	case candidates == 0:
		return &Error{Code: "auth_not_found", Message: "no auth available"}
	case cooling != candidates || soonest.IsZero():
		return &Error{Code: "auth_unavailable", Message: "no auth available"}
	}

	// Every candidate is cooling down: report how long until the first one may be retried,
	// never a negative duration.
	wait := soonest.Sub(now)
	if wait < 0 {
		wait = 0
	}
	return newModelCooldownError(model, "", wait)
}
// triedPredicate returns a filter that rejects nil placeholders and, when tried is non-empty at
// the time of the call, also rejects every auth whose ID appears in tried.
func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool {
	usable := func(entry *scheduledAuth) bool {
		return entry != nil && entry.auth != nil
	}
	if len(tried) == 0 {
		return usable
	}
	return func(entry *scheduledAuth) bool {
		if !usable(entry) {
			return false
		}
		_, attempted := tried[entry.auth.ID]
		return !attempted
	}
}
// normalizeProviderKeys returns the provider names trimmed and lower-cased, with blank names and
// repeats dropped. The first occurrence of each name keeps its position, and the result is
// always a non-nil slice.
func normalizeProviderKeys(providers []string) []string {
	normalized := make([]string, 0, len(providers))
	known := make(map[string]struct{}, len(providers))
	for i := range providers {
		key := strings.ToLower(strings.TrimSpace(providers[i]))
		if _, duplicate := known[key]; duplicate || key == "" {
			continue
		}
		known[key] = struct{}{}
		normalized = append(normalized, key)
	}
	return normalized
}
// containsProvider reports whether provider occurs in providers. The comparison is an exact
// string match, so callers are expected to pass already-normalized values.
func containsProvider(providers []string, provider string) bool {
	for i := 0; i < len(providers); i++ {
		if providers[i] == provider {
			return true
		}
	}
	return false
}
// upsertAuthLocked synchronizes one auth while the scheduler mutex is held.
// An auth that is disabled, or whose ID or provider is blank, is removed instead of stored.
// When the auth moved to another provider it is first detached from the previous one.
func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) {
	if auth == nil {
		return
	}
	id := strings.TrimSpace(auth.ID)
	provider := strings.ToLower(strings.TrimSpace(auth.Provider))
	if auth.Disabled || id == "" || provider == "" {
		s.removeAuthLocked(id)
		return
	}

	// Detach from the provider the auth was filed under before, if that changed.
	if former := s.authProviders[id]; former != "" && former != provider {
		if formerState := s.providers[former]; formerState != nil {
			formerState.removeAuthLocked(id)
		}
	}

	meta := buildScheduledAuthMeta(auth)
	s.authProviders[id] = provider
	target := s.ensureProviderLocked(provider)
	target.upsertAuthLocked(meta, now)
}
// removeAuthLocked forgets one auth while the scheduler mutex is held: the auth is removed from
// the provider it is filed under and its ID-to-provider mapping is deleted. IDs that are empty
// or not mapped to a provider are ignored.
func (s *authScheduler) removeAuthLocked(authID string) {
	if authID == "" {
		return
	}
	owner := s.authProviders[authID]
	if owner == "" {
		return
	}
	if state := s.providers[owner]; state != nil {
		state.removeAuthLocked(authID)
	}
	delete(s.authProviders, authID)
}
// ensureProviderLocked looks up the scheduler for providerKey and creates an empty one, along
// with the providers map itself if necessary, when none is registered yet.
func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler {
	if s.providers == nil {
		s.providers = make(map[string]*providerScheduler)
	}
	if existing := s.providers[providerKey]; existing != nil {
		return existing
	}
	created := &providerScheduler{
		providerKey: providerKey,
		auths:       make(map[string]*scheduledAuthMeta),
		modelShards: make(map[string]*modelScheduler),
	}
	s.providers[providerKey] = created
	return created
}
// buildScheduledAuthMeta captures the scheduling-relevant facts of an auth: normalized provider,
// priority, websocket capability, the models currently registered for it, and its
// "gemini_virtual_parent" attribute when present. auth must not be nil.
func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta {
	meta := &scheduledAuthMeta{
		auth:              auth,
		providerKey:       strings.ToLower(strings.TrimSpace(auth.Provider)),
		priority:          authPriority(auth),
		websocketEnabled:  authWebsocketsEnabled(auth),
		supportedModelSet: supportedModelSetForAuth(auth.ID),
	}
	if attrs := auth.Attributes; attrs != nil {
		meta.virtualParent = strings.TrimSpace(attrs["gemini_virtual_parent"])
	}
	return meta
}
// supportedModelSetForAuth returns the canonical keys of the models the global registry lists
// for authID at the time of the call. It returns nil for a blank ID or when the registry lists
// no models; nil models and models with an empty canonical key are skipped.
func supportedModelSetForAuth(authID string) map[string]struct{} {
	id := strings.TrimSpace(authID)
	if id == "" {
		return nil
	}
	registered := registry.GetGlobalRegistry().GetModelsForClient(id)
	if len(registered) == 0 {
		return nil
	}
	keys := make(map[string]struct{}, len(registered))
	for _, info := range registered {
		if info == nil {
			continue
		}
		if key := canonicalModelKey(info.ID); key != "" {
			keys[key] = struct{}{}
		}
	}
	return keys
}
// upsertAuthLocked records the auth metadata and brings every already-built model shard in line
// with it: shards for models the auth supports get the entry inserted or refreshed, all other
// shards drop it. Shards that have not been built yet are left to ensureModelLocked.
func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) {
	if p == nil || meta == nil || meta.auth == nil {
		return
	}
	id := meta.auth.ID
	p.auths[id] = meta
	for key, shard := range p.modelShards {
		switch {
		case shard == nil:
			// Nothing to update for an empty slot.
		case meta.supportsModel(key):
			shard.upsertEntryLocked(meta, now)
		default:
			shard.removeEntryLocked(id)
		}
	}
}
// removeAuthLocked drops the auth's metadata and deletes its entry from each model shard of this
// provider. A nil receiver or an empty ID is ignored.
func (p *providerScheduler) removeAuthLocked(authID string) {
	if p == nil || authID == "" {
		return
	}
	delete(p.auths, authID)
	for _, shard := range p.modelShards {
		if shard == nil {
			continue
		}
		shard.removeEntryLocked(authID)
	}
}
// ensureModelLocked returns the shard for modelKey. An existing shard first has its expired
// blocked entries re-evaluated against now; a missing shard is created from every known auth
// that supports the model and is then cached. A nil receiver yields nil.
func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler {
	if p == nil {
		return nil
	}
	key := canonicalModelKey(modelKey)
	if existing := p.modelShards[key]; existing != nil {
		existing.promoteExpiredLocked(now)
		return existing
	}

	fresh := &modelScheduler{
		modelKey:        key,
		entries:         make(map[string]*scheduledAuth),
		readyByPriority: make(map[int]*readyBucket),
	}
	for _, meta := range p.auths {
		if meta != nil && meta.supportsModel(key) {
			fresh.upsertEntryLocked(meta, now)
		}
	}
	p.modelShards[key] = fresh
	return fresh
}
// supportsModel reports whether modelKey, after canonicalization, is in the auth's supported
// model set. An empty canonical key is accepted for every auth; an auth without any registered
// models supports no named model.
func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
	key := canonicalModelKey(modelKey)
	if key == "" {
		return true
	}
	// Looking up a key in a nil or empty set simply reports false.
	_, supported := m.supportedModelSet[key]
	return supported
}
// upsertEntryLocked updates or inserts one auth entry and rebuilds indexes when ordering changes.
//
// The entry's state is recomputed with isAuthBlockedForModel for this shard's model at time now.
// The shard indexes are rebuilt unless the entry already existed and none of the inputs that
// influence indexing changed: state, next retry time, priority, virtual parent, websocket flag.
func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) {
	if m == nil || meta == nil || meta.auth == nil {
		return
	}
	entry, ok := m.entries[meta.auth.ID]
	// A nil value stored under the key is replaced by a fresh entry, so it must count as "new"
	// below as well; otherwise a fresh ready entry with default metadata would compare equal to
	// its own zero values and skip the rebuild that makes it visible in the ready views.
	existed := ok && entry != nil
	if !existed {
		entry = &scheduledAuth{}
		m.entries[meta.auth.ID] = entry
	}
	// Remember everything that influences indexing before overwriting it.
	previousState := entry.state
	previousNextRetryAt := entry.nextRetryAt
	previousPriority := 0
	previousParent := ""
	previousWebsocketEnabled := false
	if entry.meta != nil {
		previousPriority = entry.meta.priority
		previousParent = entry.meta.virtualParent
		previousWebsocketEnabled = entry.meta.websocketEnabled
	}
	entry.meta = meta
	entry.auth = meta.auth
	entry.nextRetryAt = time.Time{}
	blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now)
	switch {
	case !blocked:
		entry.state = scheduledStateReady
	case reason == blockReasonCooldown:
		entry.state = scheduledStateCooldown
		entry.nextRetryAt = next
	case reason == blockReasonDisabled:
		entry.state = scheduledStateDisabled
	default:
		entry.state = scheduledStateBlocked
		entry.nextRetryAt = next
	}
	// Skip the rebuild (which also restarts the ready-view cursors) when nothing relevant moved.
	if existed && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled {
		return
	}
	m.rebuildIndexesLocked()
}
// removeEntryLocked drops the entry for authID and rebuilds the shard indexes. Nothing happens,
// and no rebuild is triggered, when the shard holds no entry under that ID.
func (m *modelScheduler) removeEntryLocked(authID string) {
	if m == nil || authID == "" {
		return
	}
	if _, present := m.entries[authID]; present {
		delete(m.entries, authID)
		m.rebuildIndexesLocked()
	}
}
// promoteExpiredLocked reevaluates blocked auths whose retry time has elapsed.
//
// Only entries in the blocked queue with a non-zero retry time at or before now are re-checked
// with isAuthBlockedForModel. The shard indexes are rebuilt only when at least one entry ended
// up with a different state or retry time: a rebuild recreates every ready bucket and thereby
// restarts its round-robin cursors, so it is skipped when the re-check changed nothing (the same
// short-circuit upsertEntryLocked applies).
func (m *modelScheduler) promoteExpiredLocked(now time.Time) {
	if m == nil || len(m.blocked) == 0 {
		return
	}
	changed := false
	for _, entry := range m.blocked {
		if entry == nil || entry.auth == nil {
			continue
		}
		if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) {
			continue
		}
		previousState := entry.state
		previousNextRetryAt := entry.nextRetryAt
		blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now)
		switch {
		case !blocked:
			entry.state = scheduledStateReady
			entry.nextRetryAt = time.Time{}
		case reason == blockReasonCooldown:
			entry.state = scheduledStateCooldown
			entry.nextRetryAt = next
		case reason == blockReasonDisabled:
			entry.state = scheduledStateDisabled
			entry.nextRetryAt = time.Time{}
		default:
			entry.state = scheduledStateBlocked
			entry.nextRetryAt = next
		}
		// State and retry time are the only index inputs touched here.
		if entry.state != previousState || !entry.nextRetryAt.Equal(previousNextRetryAt) {
			changed = true
		}
	}
	if changed {
		m.rebuildIndexesLocked()
	}
}
// pickReadyLocked re-evaluates expired blocked entries against the current time and then picks
// an auth from the highest priority that has a ready auth accepted by predicate. It returns nil
// when no such priority exists.
func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
	if m == nil {
		return nil
	}
	m.promoteExpiredLocked(time.Now())
	if priority, found := m.highestReadyPriorityLocked(preferWebsocket, predicate); found {
		return m.pickReadyAtPriorityLocked(preferWebsocket, priority, strategy, predicate)
	}
	return nil
}
// highestReadyPriorityLocked walks the priorities from highest to lowest and returns the first
// one whose view contains a ready auth accepted by predicate. With preferWebsocket set, a
// priority that has websocket-enabled auths is judged by those auths alone.
// The caller must ensure expired entries are already promoted when needed.
func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predicate func(*scheduledAuth) bool) (int, bool) {
	if m == nil {
		return 0, false
	}
	for i := range m.priorityOrder {
		level := m.priorityOrder[i]
		bucket := m.readyByPriority[level]
		if bucket == nil {
			continue
		}
		candidates := &bucket.all
		if preferWebsocket && len(bucket.ws.flat) != 0 {
			candidates = &bucket.ws
		}
		if candidates.pickFirst(predicate) == nil {
			continue
		}
		return level, true
	}
	return 0, false
}
// pickReadyAtPriorityLocked picks a ready auth from the bucket of the given priority.
// Fill-first returns the first auth accepted by predicate; every other strategy rotates through
// the view. With preferWebsocket set and websocket-enabled auths present, only those are used.
// The caller must ensure expired entries are already promoted when needed.
func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priority int, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
	if m == nil {
		return nil
	}
	bucket := m.readyByPriority[priority]
	if bucket == nil {
		return nil
	}
	candidates := &bucket.all
	if preferWebsocket && len(bucket.ws.flat) > 0 {
		candidates = &bucket.ws
	}

	var chosen *scheduledAuth
	switch strategy {
	case schedulerStrategyFillFirst:
		chosen = candidates.pickFirst(predicate)
	default:
		chosen = candidates.pickRoundRobin(predicate)
	}
	if chosen == nil {
		return nil
	}
	return chosen.auth
}
// unavailableErrorLocked builds the error for a shard that could not supply an auth.
// No candidate accepted by predicate yields "auth_not_found"; all candidates cooling down (with
// a known retry time) yields a model cooldown error for the earliest retry; anything else yields
// "auth_unavailable". The placeholder provider "mixed" is reported as an empty provider.
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
	now := time.Now()
	candidates, cooling, soonest := m.availabilitySummaryLocked(predicate)

	switch {
	case candidates == 0:
		return &Error{Code: "auth_not_found", Message: "no auth available"}
	case cooling != candidates || soonest.IsZero():
		return &Error{Code: "auth_unavailable", Message: "no auth available"}
	}

	reportedProvider := provider
	if reportedProvider == "mixed" {
		reportedProvider = ""
	}
	// Never report a negative wait when the retry time has already passed.
	wait := soonest.Sub(now)
	if wait < 0 {
		wait = 0
	}
	return newModelCooldownError(model, reportedProvider, wait)
}
// availabilitySummaryLocked counts the entries accepted by predicate (all entries when predicate
// is nil), how many of them are in cooldown, and the earliest non-zero retry time among the
// cooling entries. A nil shard reports zeros.
func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) {
	var (
		candidates int
		cooling    int
		soonest    time.Time
	)
	if m == nil {
		return candidates, cooling, soonest
	}
	for _, entry := range m.entries {
		if predicate != nil && !predicate(entry) {
			continue
		}
		candidates++
		if entry == nil || entry.auth == nil || entry.state != scheduledStateCooldown {
			continue
		}
		cooling++
		retryAt := entry.nextRetryAt
		if retryAt.IsZero() {
			continue
		}
		if soonest.IsZero() || retryAt.Before(soonest) {
			soonest = retryAt
		}
	}
	return candidates, cooling, soonest
}
// rebuildIndexesLocked reconstructs ready and blocked views from the current entry map.
//
// Ready entries are grouped by priority, sorted by auth ID and turned into fresh ready buckets;
// priorityOrder lists those priorities from highest to lowest. Cooldown and blocked entries go
// into the blocked queue ordered by retry time (entries without one last, ties by auth ID).
// Disabled entries appear in neither index. Because the ready buckets are recreated, their
// rotation cursors start over after every rebuild.
func (m *modelScheduler) rebuildIndexesLocked() {
	m.readyByPriority = make(map[int]*readyBucket)
	m.priorityOrder = m.priorityOrder[:0]
	// Zero the old elements before truncating: the backing array is reused, and pointers left
	// beyond the new length would keep removed or no-longer-blocked entries reachable.
	clear(m.blocked)
	m.blocked = m.blocked[:0]
	priorityBuckets := make(map[int][]*scheduledAuth)
	for _, entry := range m.entries {
		if entry == nil || entry.auth == nil {
			continue
		}
		switch entry.state {
		case scheduledStateReady:
			priority := entry.meta.priority
			priorityBuckets[priority] = append(priorityBuckets[priority], entry)
		case scheduledStateCooldown, scheduledStateBlocked:
			m.blocked = append(m.blocked, entry)
		}
	}
	for priority, entries := range priorityBuckets {
		// Map iteration is unordered, so fix the selection order by auth ID.
		sort.Slice(entries, func(i, j int) bool {
			return entries[i].auth.ID < entries[j].auth.ID
		})
		m.readyByPriority[priority] = buildReadyBucket(entries)
		m.priorityOrder = append(m.priorityOrder, priority)
	}
	// Highest priority first.
	sort.Slice(m.priorityOrder, func(i, j int) bool {
		return m.priorityOrder[i] > m.priorityOrder[j]
	})
	sort.Slice(m.blocked, func(i, j int) bool {
		left := m.blocked[i]
		right := m.blocked[j]
		if left == nil || right == nil {
			return left != nil
		}
		if left.nextRetryAt.Equal(right.nextRetryAt) {
			return left.auth.ID < right.auth.ID
		}
		// A zero retry time means "no known retry" and sorts after every real time.
		if left.nextRetryAt.IsZero() {
			return false
		}
		if right.nextRetryAt.IsZero() {
			return true
		}
		return left.nextRetryAt.Before(right.nextRetryAt)
	})
}
// buildReadyBucket builds the two views of one priority level: one over all entries and one
// over just the entries whose metadata enables websockets, both in the order given.
func buildReadyBucket(entries []*scheduledAuth) *readyBucket {
	websocketOnly := make([]*scheduledAuth, 0, len(entries))
	for i := range entries {
		candidate := entries[i]
		if candidate == nil || candidate.meta == nil || !candidate.meta.websocketEnabled {
			continue
		}
		websocketOnly = append(websocketOnly, candidate)
	}
	return &readyBucket{
		all: buildReadyView(entries),
		ws:  buildReadyView(websocketOnly),
	}
}
// buildReadyView copies entries into a view. The view stays flat unless every entry carries a
// virtual parent and at least two distinct parents are present; in that case the entries are
// additionally grouped per parent, with the parents kept in sorted order.
func buildReadyView(entries []*scheduledAuth) readyView {
	view := readyView{flat: append([]*scheduledAuth(nil), entries...)}
	if len(entries) == 0 {
		return view
	}

	byParent := make(map[string][]*scheduledAuth)
	for _, entry := range entries {
		if entry == nil || entry.meta == nil {
			return view
		}
		parent := entry.meta.virtualParent
		if parent == "" {
			// A single entry without a parent disables grouping for the whole view.
			return view
		}
		byParent[parent] = append(byParent[parent], entry)
	}
	if len(byParent) < 2 {
		return view
	}

	parents := make([]string, 0, len(byParent))
	children := make(map[string]*childBucket, len(byParent))
	for parent, members := range byParent {
		parents = append(parents, parent)
		children[parent] = &childBucket{items: append([]*scheduledAuth(nil), members...)}
	}
	sort.Strings(parents)

	view.parentOrder = parents
	view.children = children
	return view
}
// pickFirst scans the flat list in order and returns the first entry accepted by predicate (or
// simply the first entry when predicate is nil). No cursor is modified.
func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth {
	for i := range v.flat {
		if candidate := v.flat[i]; predicate == nil || predicate(candidate) {
			return candidate
		}
	}
	return nil
}
// pickRoundRobin returns the next entry accepted by predicate. Grouped views delegate to
// pickGroupedRoundRobin; flat views are scanned once starting at the cursor, wrapping around,
// and the cursor is moved just past the returned entry.
func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
	if len(v.parentOrder) > 1 && len(v.children) > 0 {
		return v.pickGroupedRoundRobin(predicate)
	}
	size := len(v.flat)
	if size == 0 {
		return nil
	}
	for step, pos := 0, v.cursor%size; step < size; step, pos = step+1, (pos+1)%size {
		candidate := v.flat[pos]
		if predicate == nil || predicate(candidate) {
			v.cursor = pos + 1
			return candidate
		}
	}
	return nil
}
// pickGroupedRoundRobin performs a two-level rotation: parents are visited
// starting at parentCursor, and within each parent its items are visited
// starting at that bucket's own cursor. The first entry accepted by predicate
// (nil accepts all) is returned, and both the bucket cursor and parentCursor
// are advanced past it. Parents with a missing or empty bucket are skipped.
// Returns nil, leaving all cursors untouched, when nothing is accepted.
func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
	parentCount := len(v.parentOrder)
	if parentCount == 0 {
		return nil
	}

	parentIndex := v.parentCursor % parentCount
	for visited := 0; visited < parentCount; visited, parentIndex = visited+1, (parentIndex+1)%parentCount {
		bucket := v.children[v.parentOrder[parentIndex]]
		if bucket == nil {
			continue
		}
		size := len(bucket.items)
		if size == 0 {
			continue
		}

		itemIndex := bucket.cursor % size
		for tried := 0; tried < size; tried++ {
			candidate := bucket.items[itemIndex]
			if predicate == nil || predicate(candidate) {
				bucket.cursor = itemIndex + 1
				v.parentCursor = parentIndex + 1
				return candidate
			}
			itemIndex = (itemIndex + 1) % size
		}
	}
	return nil
}
@@ -0,0 +1,216 @@
package auth
import (
"context"
"fmt"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerBenchmarkExecutor is a stub executor for the scheduler benchmarks.
// Every operation returns a zero result and a nil error.
type schedulerBenchmarkExecutor struct {
	id string // value reported by Identifier
}

// Identifier returns the id the stub was constructed with.
func (s schedulerBenchmarkExecutor) Identifier() string { return s.id }

// Execute returns an empty response and no error.
func (schedulerBenchmarkExecutor) Execute(_ context.Context, _ *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// ExecuteStream returns a nil stream result and no error.
func (schedulerBenchmarkExecutor) ExecuteStream(_ context.Context, _ *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, nil
}

// Refresh hands back the auth it was given, unchanged.
func (schedulerBenchmarkExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

// CountTokens returns an empty response and no error.
func (schedulerBenchmarkExecutor) CountTokens(_ context.Context, _ *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// HttpRequest returns a nil response and no error.
func (schedulerBenchmarkExecutor) HttpRequest(_ context.Context, _ *Auth, _ *http.Request) (*http.Response, error) {
	return nil, nil
}
// benchmarkManagerSetup builds a round-robin Manager holding total benchmark
// auths, registers each auth for a single model in the global registry, and
// syncs the scheduler. It returns the manager, the provider list, and the
// model name.
//
// When mixed is true, odd-indexed auths belong to "claude" and the rest to
// "gemini"; otherwise every auth is "gemini". When withPriority is true,
// even-indexed auths get priority "10" and odd-indexed auths priority "0".
// The registry entries are removed again through b.Cleanup.
//
// NOTE(review): with mixed and withPriority both set, every priority-10 auth
// is a gemini auth, so the top tier holds a single provider — confirm that is
// what the mixed-priority benchmark intends to measure.
func benchmarkManagerSetup(b *testing.B, total int, mixed bool, withPriority bool) (*Manager, []string, string) {
	b.Helper()

	const model = "bench-model"

	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	providers := []string{"gemini"}
	if mixed {
		providers = append(providers, "claude")
	}
	for _, name := range providers {
		manager.executors[name] = schedulerBenchmarkExecutor{id: name}
	}

	// providerFor and authIDFor are shared by registration and cleanup so both
	// derive exactly the same IDs.
	providerFor := func(index int) string {
		if mixed && index%2 == 1 {
			return providers[1]
		}
		return providers[0]
	}
	authIDFor := func(index int) string {
		return fmt.Sprintf("bench-%s-%04d", providerFor(index), index)
	}

	reg := registry.GetGlobalRegistry()
	for index := 0; index < total; index++ {
		provider := providerFor(index)
		auth := &Auth{ID: authIDFor(index), Provider: provider}
		if withPriority {
			priority := "10"
			if index%2 != 0 {
				priority = "0"
			}
			auth.Attributes = map[string]string{"priority": priority}
		}
		if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil {
			b.Fatalf("Register(%s) error = %v", auth.ID, errRegister)
		}
		reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: model}})
	}
	manager.syncScheduler()

	b.Cleanup(func() {
		for index := 0; index < total; index++ {
			reg.UnregisterClient(authIDFor(index))
		}
	})
	return manager, providers, model
}
// BenchmarkManagerPickNext500 measures Manager.pickNext over 500 gemini auths
// without priorities.
func BenchmarkManagerPickNext500(b *testing.B) {
	const provider = "gemini"

	mgr, _, model := benchmarkManagerSetup(b, 500, false, false)
	var (
		ctx       = context.Background()
		opts      = cliproxyexecutor.Options{}
		attempted = make(map[string]struct{})
	)

	// One untimed pick before the timer is reset.
	_, _, errWarm := mgr.pickNext(ctx, provider, model, opts, attempted)
	if errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for iteration := 0; iteration < b.N; iteration++ {
		auth, exec, errPick := mgr.pickNext(ctx, provider, model, opts, attempted)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}
// BenchmarkManagerPickNext1000 measures Manager.pickNext over 1000 gemini
// auths without priorities.
func BenchmarkManagerPickNext1000(b *testing.B) {
	const provider = "gemini"

	mgr, _, model := benchmarkManagerSetup(b, 1000, false, false)
	var (
		ctx       = context.Background()
		opts      = cliproxyexecutor.Options{}
		attempted = make(map[string]struct{})
	)

	// One untimed pick before the timer is reset.
	_, _, errWarm := mgr.pickNext(ctx, provider, model, opts, attempted)
	if errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for iteration := 0; iteration < b.N; iteration++ {
		auth, exec, errPick := mgr.pickNext(ctx, provider, model, opts, attempted)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}
// BenchmarkManagerPickNextPriority500 measures Manager.pickNext over 500
// gemini auths split across two priority values.
func BenchmarkManagerPickNextPriority500(b *testing.B) {
	const provider = "gemini"

	mgr, _, model := benchmarkManagerSetup(b, 500, false, true)
	var (
		ctx       = context.Background()
		opts      = cliproxyexecutor.Options{}
		attempted = make(map[string]struct{})
	)

	// One untimed pick before the timer is reset.
	_, _, errWarm := mgr.pickNext(ctx, provider, model, opts, attempted)
	if errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for iteration := 0; iteration < b.N; iteration++ {
		auth, exec, errPick := mgr.pickNext(ctx, provider, model, opts, attempted)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}
// BenchmarkManagerPickNextPriority1000 measures Manager.pickNext over 1000
// gemini auths split across two priority values.
func BenchmarkManagerPickNextPriority1000(b *testing.B) {
	const provider = "gemini"

	mgr, _, model := benchmarkManagerSetup(b, 1000, false, true)
	var (
		ctx       = context.Background()
		opts      = cliproxyexecutor.Options{}
		attempted = make(map[string]struct{})
	)

	// One untimed pick before the timer is reset.
	_, _, errWarm := mgr.pickNext(ctx, provider, model, opts, attempted)
	if errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for iteration := 0; iteration < b.N; iteration++ {
		auth, exec, errPick := mgr.pickNext(ctx, provider, model, opts, attempted)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}
// BenchmarkManagerPickNextMixed500 measures Manager.pickNextMixed over 500
// auths alternating between the gemini and claude providers, without
// priorities.
func BenchmarkManagerPickNextMixed500(b *testing.B) {
	mgr, providers, model := benchmarkManagerSetup(b, 500, true, false)
	var (
		ctx       = context.Background()
		opts      = cliproxyexecutor.Options{}
		attempted = make(map[string]struct{})
	)

	// One untimed pick before the timer is reset.
	_, _, _, errWarm := mgr.pickNextMixed(ctx, providers, model, opts, attempted)
	if errWarm != nil {
		b.Fatalf("warmup pickNextMixed error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for iteration := 0; iteration < b.N; iteration++ {
		auth, exec, provider, errPick := mgr.pickNextMixed(ctx, providers, model, opts, attempted)
		if errPick != nil || auth == nil || exec == nil || provider == "" {
			b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
		}
	}
}
// BenchmarkManagerPickNextMixedPriority500 measures Manager.pickNextMixed over
// 500 auths alternating between the gemini and claude providers, with
// priorities enabled.
func BenchmarkManagerPickNextMixedPriority500(b *testing.B) {
	mgr, providers, model := benchmarkManagerSetup(b, 500, true, true)
	var (
		ctx       = context.Background()
		opts      = cliproxyexecutor.Options{}
		attempted = make(map[string]struct{})
	)

	// One untimed pick before the timer is reset.
	_, _, _, errWarm := mgr.pickNextMixed(ctx, providers, model, opts, attempted)
	if errWarm != nil {
		b.Fatalf("warmup pickNextMixed error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for iteration := 0; iteration < b.N; iteration++ {
		auth, exec, provider, errPick := mgr.pickNextMixed(ctx, providers, model, opts, attempted)
		if errPick != nil || auth == nil || exec == nil || provider == "" {
			b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
		}
	}
}
// BenchmarkManagerPickNextAndMarkResult1000 measures a pickNext followed by a
// successful MarkResult for the picked auth, over 1000 gemini auths.
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
	const provider = "gemini"

	mgr, _, model := benchmarkManagerSetup(b, 1000, false, false)
	var (
		ctx       = context.Background()
		opts      = cliproxyexecutor.Options{}
		attempted = make(map[string]struct{})
	)

	// One untimed pick before the timer is reset.
	_, _, errWarm := mgr.pickNext(ctx, provider, model, opts, attempted)
	if errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}

	b.ReportAllocs()
	b.ResetTimer()
	for iteration := 0; iteration < b.N; iteration++ {
		auth, _, errPick := mgr.pickNext(ctx, provider, model, opts, attempted)
		if errPick != nil || auth == nil {
			b.Fatalf("pickNext failed: auth=%v err=%v", auth, errPick)
		}
		outcome := Result{AuthID: auth.ID, Provider: provider, Model: model, Success: true}
		mgr.MarkResult(ctx, outcome)
	}
}
+503
View File
@@ -0,0 +1,503 @@
package auth
import (
"context"
"net/http"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerTestExecutor is a stub executor for the scheduler tests. Every
// operation returns a zero result and a nil error.
type schedulerTestExecutor struct{}

// Identifier always reports "test".
func (schedulerTestExecutor) Identifier() string { return "test" }

// Execute returns an empty response and no error.
func (schedulerTestExecutor) Execute(_ context.Context, _ *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// ExecuteStream returns a nil stream result and no error.
func (schedulerTestExecutor) ExecuteStream(_ context.Context, _ *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, nil
}

// Refresh hands back the auth it was given, unchanged.
func (schedulerTestExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

// CountTokens returns an empty response and no error.
func (schedulerTestExecutor) CountTokens(_ context.Context, _ *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// HttpRequest returns a nil response and no error.
func (schedulerTestExecutor) HttpRequest(_ context.Context, _ *Auth, _ *http.Request) (*http.Response, error) {
	return nil, nil
}
// trackingSelector is a selector used by tests that records how often Pick was
// called and which auth IDs the most recent call received.
type trackingSelector struct {
	calls      int      // number of Pick invocations
	lastAuthID []string // IDs passed to the latest Pick, in order
}

// Pick records the call and the candidate IDs, then returns the last candidate.
// It returns (nil, nil) when auths is empty.
func (s *trackingSelector) Pick(_ context.Context, _, _ string, _ cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
	s.calls++

	// Reuse the backing array of the previous recording.
	s.lastAuthID = s.lastAuthID[:0]
	for index := range auths {
		s.lastAuthID = append(s.lastAuthID, auths[index].ID)
	}

	if last := len(auths) - 1; last >= 0 {
		return auths[last], nil
	}
	return nil, nil
}
// newSchedulerForTest creates an authScheduler for selector and rebuilds it
// from the given auths.
func newSchedulerForTest(selector Selector, auths ...*Auth) *authScheduler {
	built := newAuthScheduler(selector)
	built.rebuild(auths)
	return built
}
// registerSchedulerModels registers model for every auth ID under provider in
// the global registry and unregisters the same IDs when the test finishes.
func registerSchedulerModels(t *testing.T, provider string, model string, authIDs ...string) {
	t.Helper()

	reg := registry.GetGlobalRegistry()
	for index := range authIDs {
		models := []*registry.ModelInfo{{ID: model}}
		reg.RegisterClient(authIDs[index], provider, models)
	}

	unregisterAll := func() {
		for index := range authIDs {
			reg.UnregisterClient(authIDs[index])
		}
	}
	t.Cleanup(unregisterAll)
}
func TestSchedulerPick_RoundRobinHighestPriority(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "low", Provider: "gemini", Attributes: map[string]string{"priority": "0"}},
&Auth{ID: "high-b", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
&Auth{ID: "high-a", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
)
want := []string{"high-a", "high-b", "high-a"}
for index, wantID := range want {
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantID {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
}
}
}
func TestSchedulerPick_FillFirstSticksToFirstReady(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&FillFirstSelector{},
&Auth{ID: "b", Provider: "gemini"},
&Auth{ID: "a", Provider: "gemini"},
&Auth{ID: "c", Provider: "gemini"},
)
for index := 0; index < 3; index++ {
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != "a" {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, "a")
}
}
}
// TestSchedulerPick_PromotesExpiredCooldownBeforePick expects an auth whose
// model state is marked unavailable, but whose NextRetryAfter already lies one
// second in the past, to be returned by pickSingle.
func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) {
	t.Parallel()

	const authID = "cooldown-expired"
	model := "gemini-2.5-pro"
	registerSchedulerModels(t, "gemini", model, authID)

	expiredState := &ModelState{
		Status:         StatusError,
		Unavailable:    true,
		NextRetryAfter: time.Now().Add(-time.Second),
	}
	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{
			ID:          authID,
			Provider:    "gemini",
			ModelStates: map[string]*ModelState{model: expiredState},
		},
	)

	picked, errPick := scheduler.pickSingle(context.Background(), "gemini", model, cliproxyexecutor.Options{}, nil)
	switch {
	case errPick != nil:
		t.Fatalf("pickSingle() error = %v", errPick)
	case picked == nil:
		t.Fatalf("pickSingle() auth = nil")
	case picked.ID != authID:
		t.Fatalf("pickSingle() auth.ID = %q, want %q", picked.ID, "cooldown-expired")
	}
}
func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) {
t.Parallel()
registerSchedulerModels(t, "gemini-cli", "gemini-2.5-pro", "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2")
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "cred-a::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
&Auth{ID: "cred-a::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
&Auth{ID: "cred-b::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
&Auth{ID: "cred-b::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
)
wantParents := []string{"cred-a", "cred-b", "cred-a", "cred-b"}
wantIDs := []string{"cred-a::proj-1", "cred-b::proj-1", "cred-a::proj-2", "cred-b::proj-2"}
for index := range wantIDs {
got, errPick := scheduler.pickSingle(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantIDs[index] {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
if got.Attributes["gemini_virtual_parent"] != wantParents[index] {
t.Fatalf("pickSingle() #%d parent = %q, want %q", index, got.Attributes["gemini_virtual_parent"], wantParents[index])
}
}
}
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "codex-http", Provider: "codex"},
&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
)
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"}
for index, wantID := range want {
got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantID {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
}
}
}
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "gemini-a", Provider: "gemini"},
&Auth{ID: "gemini-b", Provider: "gemini"},
&Auth{ID: "claude-a", Provider: "claude"},
)
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
for index := range wantProviders {
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
// TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier expects pickMixed
// to rotate only between the two priority-7 providers and never return the
// priority-4 provider.
func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
	t.Parallel()

	model := "gpt-default"
	registerSchedulerModels(t, "provider-low", model, "low")
	registerSchedulerModels(t, "provider-high-a", model, "high-a")
	registerSchedulerModels(t, "provider-high-b", model, "high-b")

	auths := []*Auth{
		{ID: "low", Provider: "provider-low", Attributes: map[string]string{"priority": "4"}},
		{ID: "high-a", Provider: "provider-high-a", Attributes: map[string]string{"priority": "7"}},
		{ID: "high-b", Provider: "provider-high-b", Attributes: map[string]string{"priority": "7"}},
	}
	scheduler := newSchedulerForTest(&RoundRobinSelector{}, auths...)

	providers := []string{"provider-low", "provider-high-a", "provider-high-b"}
	expected := []struct{ provider, id string }{
		{"provider-high-a", "high-a"},
		{"provider-high-b", "high-b"},
		{"provider-high-a", "high-a"},
		{"provider-high-b", "high-b"},
	}
	for attempt, want := range expected {
		picked, provider, errPick := scheduler.pickMixed(context.Background(), providers, model, cliproxyexecutor.Options{}, nil)
		switch {
		case errPick != nil:
			t.Fatalf("pickMixed() #%d error = %v", attempt, errPick)
		case picked == nil:
			t.Fatalf("pickMixed() #%d auth = nil", attempt)
		case provider != want.provider:
			t.Fatalf("pickMixed() #%d provider = %q, want %q", attempt, provider, want.provider)
		case picked.ID != want.id:
			t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", attempt, picked.ID, want.id)
		}
	}
}
// TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation
// expects Manager.pickNextMixed, called with a fresh empty tried set each
// time, to alternate between gemini and claude while cycling through the
// gemini credentials.
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	manager.executors["gemini"] = schedulerTestExecutor{}
	manager.executors["claude"] = schedulerTestExecutor{}

	var errRegister error
	if _, errRegister = manager.Register(ctx, &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(gemini-a) error = %v", errRegister)
	}
	if _, errRegister = manager.Register(ctx, &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(gemini-b) error = %v", errRegister)
	}
	if _, errRegister = manager.Register(ctx, &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
		t.Fatalf("Register(claude-a) error = %v", errRegister)
	}

	expected := []struct{ provider, id string }{
		{"gemini", "gemini-a"},
		{"claude", "claude-a"},
		{"gemini", "gemini-b"},
		{"claude", "claude-a"},
	}
	for attempt, want := range expected {
		picked, _, provider, errPick := manager.pickNextMixed(ctx, []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
		switch {
		case errPick != nil:
			t.Fatalf("pickNextMixed() #%d error = %v", attempt, errPick)
		case picked == nil:
			t.Fatalf("pickNextMixed() #%d auth = nil", attempt)
		case provider != want.provider:
			t.Fatalf("pickNextMixed() #%d provider = %q, want %q", attempt, provider, want.provider)
		case picked.ID != want.id:
			t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", attempt, picked.ID, want.id)
		}
	}
}
// TestManagerCustomSelector_FallsBackToLegacyPath expects a Manager built with
// a custom selector to call that selector exactly once per pickNext, hand it
// both registered auths, and return the auth the selector chose.
func TestManagerCustomSelector_FallsBackToLegacyPath(t *testing.T) {
	t.Parallel()

	selector := &trackingSelector{}
	manager := NewManager(nil, selector, nil)
	manager.executors["gemini"] = schedulerTestExecutor{}
	// Auths are written into the map directly rather than through Register.
	for _, id := range []string{"auth-a", "auth-b"} {
		manager.auths[id] = &Auth{ID: id, Provider: "gemini"}
	}

	picked, _, errPick := manager.pickNext(context.Background(), "gemini", "", cliproxyexecutor.Options{}, map[string]struct{}{})
	if errPick != nil {
		t.Fatalf("pickNext() error = %v", errPick)
	}
	if picked == nil {
		t.Fatalf("pickNext() auth = nil")
	}
	if selector.calls != 1 {
		t.Fatalf("selector.calls = %d, want %d", selector.calls, 1)
	}
	if seen := len(selector.lastAuthID); seen != 2 {
		t.Fatalf("len(selector.lastAuthID) = %d, want %d", seen, 2)
	}
	// trackingSelector returns the last candidate it was shown.
	if chosen := selector.lastAuthID[len(selector.lastAuthID)-1]; picked.ID != chosen {
		t.Fatalf("pickNext() auth.ID = %q, want selector-picked %q", picked.ID, chosen)
	}
}
// TestManager_InitializesSchedulerForBuiltInSelector expects NewManager to
// create a scheduler with the round-robin strategy for a RoundRobinSelector,
// and SetSelector with a FillFirstSelector to switch it to fill-first.
func TestManager_InitializesSchedulerForBuiltInSelector(t *testing.T) {
	t.Parallel()

	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	switch {
	case manager.scheduler == nil:
		t.Fatalf("manager.scheduler = nil")
	case manager.scheduler.strategy != schedulerStrategyRoundRobin:
		t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyRoundRobin)
	}

	manager.SetSelector(&FillFirstSelector{})
	if manager.scheduler.strategy == schedulerStrategyFillFirst {
		return
	}
	t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyFillFirst)
}
// TestManager_SchedulerTracksRegisterAndUpdate expects the manager's scheduler
// to pick auth-a after both auths are registered, and auth-b once auth-a has
// been updated to Disabled.
func TestManager_SchedulerTracksRegisterAndUpdate(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	manager := NewManager(nil, &RoundRobinSelector{}, nil)

	var errRegister error
	if _, errRegister = manager.Register(ctx, &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(auth-b) error = %v", errRegister)
	}
	if _, errRegister = manager.Register(ctx, &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(auth-a) error = %v", errRegister)
	}

	picked, errPick := manager.scheduler.pickSingle(ctx, "gemini", "", cliproxyexecutor.Options{}, nil)
	switch {
	case errPick != nil:
		t.Fatalf("scheduler.pickSingle() error = %v", errPick)
	case picked == nil || picked.ID != "auth-a":
		t.Fatalf("scheduler.pickSingle() auth = %v, want auth-a", picked)
	}

	disabled := &Auth{ID: "auth-a", Provider: "gemini", Disabled: true}
	if _, errUpdate := manager.Update(ctx, disabled); errUpdate != nil {
		t.Fatalf("Update(auth-a) error = %v", errUpdate)
	}

	picked, errPick = manager.scheduler.pickSingle(ctx, "gemini", "", cliproxyexecutor.Options{}, nil)
	switch {
	case errPick != nil:
		t.Fatalf("scheduler.pickSingle() after update error = %v", errPick)
	case picked == nil || picked.ID != "auth-b":
		t.Fatalf("scheduler.pickSingle() after update auth = %v, want auth-b", picked)
	}
}
// TestManager_PickNextMixed_UsesSchedulerRotation expects Manager.pickNextMixed,
// called with a nil tried set, to alternate between gemini and claude while
// cycling through the gemini credentials.
func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	manager.executors["gemini"] = schedulerTestExecutor{}
	manager.executors["claude"] = schedulerTestExecutor{}

	var errRegister error
	if _, errRegister = manager.Register(ctx, &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(gemini-a) error = %v", errRegister)
	}
	if _, errRegister = manager.Register(ctx, &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(gemini-b) error = %v", errRegister)
	}
	if _, errRegister = manager.Register(ctx, &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
		t.Fatalf("Register(claude-a) error = %v", errRegister)
	}

	expected := []struct{ provider, id string }{
		{"gemini", "gemini-a"},
		{"claude", "claude-a"},
		{"gemini", "gemini-b"},
		{"claude", "claude-a"},
	}
	for attempt, want := range expected {
		picked, _, provider, errPick := manager.pickNextMixed(ctx, []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
		switch {
		case errPick != nil:
			t.Fatalf("pickNextMixed() #%d error = %v", attempt, errPick)
		case picked == nil:
			t.Fatalf("pickNextMixed() #%d auth = nil", attempt)
		case provider != want.provider:
			t.Fatalf("pickNextMixed() #%d provider = %q, want %q", attempt, provider, want.provider)
		case picked.ID != want.id:
			t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", attempt, picked.ID, want.id)
		}
	}
}
// TestManager_PickNextMixed_SkipsProvidersWithoutExecutors registers auths for
// gemini and claude but an executor only for claude, and expects pickNextMixed
// to return the claude auth.
func TestManager_PickNextMixed_SkipsProvidersWithoutExecutors(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	manager.executors["claude"] = schedulerTestExecutor{}

	var errRegister error
	if _, errRegister = manager.Register(ctx, &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(gemini-a) error = %v", errRegister)
	}
	if _, errRegister = manager.Register(ctx, &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
		t.Fatalf("Register(claude-a) error = %v", errRegister)
	}

	picked, _, provider, errPick := manager.pickNextMixed(ctx, []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
	switch {
	case errPick != nil:
		t.Fatalf("pickNextMixed() error = %v", errPick)
	case picked == nil:
		t.Fatalf("pickNextMixed() auth = nil")
	case provider != "claude":
		t.Fatalf("pickNextMixed() provider = %q, want %q", provider, "claude")
	case picked.ID != "claude-a":
		t.Fatalf("pickNextMixed() auth.ID = %q, want %q", picked.ID, "claude-a")
	}
}
// TestManager_SchedulerTracksMarkResultCooldownAndRecovery marks auth-a as
// failed with an HTTP 429 and expects the scheduler to pick auth-b; it then
// marks auth-a successful and expects two further picks to cover both auths.
func TestManager_SchedulerTracksMarkResultCooldownAndRecovery(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	manager := NewManager(nil, &RoundRobinSelector{}, nil)

	reg := registry.GetGlobalRegistry()
	reg.RegisterClient("auth-a", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
	reg.RegisterClient("auth-b", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
	t.Cleanup(func() {
		reg.UnregisterClient("auth-a")
		reg.UnregisterClient("auth-b")
	})

	var errRegister error
	if _, errRegister = manager.Register(ctx, &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(auth-a) error = %v", errRegister)
	}
	if _, errRegister = manager.Register(ctx, &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
		t.Fatalf("Register(auth-b) error = %v", errRegister)
	}

	// Report a quota failure for auth-a.
	quotaFailure := Result{
		AuthID:   "auth-a",
		Provider: "gemini",
		Model:    "test-model",
		Success:  false,
		Error:    &Error{HTTPStatus: 429, Message: "quota"},
	}
	manager.MarkResult(ctx, quotaFailure)

	picked, errPick := manager.scheduler.pickSingle(ctx, "gemini", "test-model", cliproxyexecutor.Options{}, nil)
	switch {
	case errPick != nil:
		t.Fatalf("scheduler.pickSingle() after cooldown error = %v", errPick)
	case picked == nil || picked.ID != "auth-b":
		t.Fatalf("scheduler.pickSingle() after cooldown auth = %v, want auth-b", picked)
	}

	// Report a success for auth-a.
	recovered := Result{
		AuthID:   "auth-a",
		Provider: "gemini",
		Model:    "test-model",
		Success:  true,
	}
	manager.MarkResult(ctx, recovered)

	distinct := make(map[string]struct{}, 2)
	for attempt := 0; attempt < 2; attempt++ {
		picked, errPick = manager.scheduler.pickSingle(ctx, "gemini", "test-model", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("scheduler.pickSingle() after recovery #%d error = %v", attempt, errPick)
		}
		if picked == nil {
			t.Fatalf("scheduler.pickSingle() after recovery #%d auth = nil", attempt)
		}
		distinct[picked.ID] = struct{}{}
	}
	if count := len(distinct); count != 2 {
		t.Fatalf("len(seen) = %d, want %d", count, 2)
	}
}
+5 -31
View File
@@ -1,16 +1,13 @@
package cliproxy package cliproxy
import ( import (
"context"
"net"
"net/http" "net/http"
"net/url"
"strings" "strings"
"sync" "sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
) )
// defaultRoundTripperProvider returns a per-auth HTTP RoundTripper based on // defaultRoundTripperProvider returns a per-auth HTTP RoundTripper based on
@@ -39,35 +36,12 @@ func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http.
if rt != nil { if rt != nil {
return rt return rt
} }
// Parse the proxy URL to determine the scheme. transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
proxyURL, errParse := url.Parse(proxyStr) if errBuild != nil {
if errParse != nil { log.Errorf("%v", errBuild)
log.Errorf("parse proxy URL failed: %v", errParse)
return nil return nil
} }
var transport *http.Transport if transport == nil {
// Handle different proxy schemes.
if proxyURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication.
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil
}
// Set up a custom transport using the SOCKS5 dialer.
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
} else {
log.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil return nil
} }
p.mu.Lock() p.mu.Lock()
+22
View File
@@ -0,0 +1,22 @@
package cliproxy
import (
"net/http"
"testing"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestRoundTripperForDirectBypassesProxy(t *testing.T) {
t.Parallel()
provider := newDefaultRoundTripperProvider()
rt := provider.RoundTripperFor(&coreauth.Auth{ProxyURL: "direct"})
transport, ok := rt.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", rt)
}
if transport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}
+22 -1
View File
@@ -312,6 +312,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
// This operation may block on network calls, but the auth configuration // This operation may block on network calls, but the auth configuration
// is already effective at this point. // is already effective at this point.
s.registerModelsForAuth(auth) s.registerModelsForAuth(auth)
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
// from the now-populated global model registry. Without this, newly added auths
// have an empty supportedModelSet (because Register/Update upserts into the
// scheduler before registerModelsForAuth runs) and are invisible to the scheduler.
s.coreManager.RefreshSchedulerEntry(auth.ID)
} }
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) { func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
@@ -823,7 +829,22 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
} }
models = applyExcludedModels(models, excluded) models = applyExcludedModels(models, excluded)
case "codex": case "codex":
models = registry.GetOpenAIModels() codexPlanType := ""
if a.Attributes != nil {
codexPlanType = strings.TrimSpace(a.Attributes["plan_type"])
}
switch strings.ToLower(codexPlanType) {
case "pro":
models = registry.GetCodexProModels()
case "plus":
models = registry.GetCodexPlusModels()
case "team":
models = registry.GetCodexTeamModels()
case "free":
models = registry.GetCodexFreeModels()
default:
models = registry.GetCodexProModels()
}
if entry := s.resolveConfigCodexKey(a); entry != nil { if entry := s.resolveConfigCodexKey(a); entry != nil {
if len(entry.Models) > 0 { if len(entry.Models) > 0 {
models = buildCodexConfigModels(entry) models = buildCodexConfigModels(entry)
+139
View File
@@ -0,0 +1,139 @@
package proxyutil
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"golang.org/x/net/proxy"
)
// Mode describes how a proxy setting should be interpreted. Parse assigns one
// of the values below to Setting.Mode.
type Mode int

const (
	// ModeInherit means no explicit proxy behavior was configured. Parse
	// reports it for an empty or whitespace-only value; it is the zero value.
	ModeInherit Mode = iota
	// ModeDirect means outbound requests must bypass proxies explicitly. Parse
	// reports it for the keywords "direct" and "none" (case-insensitive).
	ModeDirect
	// ModeProxy means a concrete proxy URL was configured. Parse reports it for
	// socks5, http, and https URLs that carry a host.
	ModeProxy
	// ModeInvalid means the proxy setting is present but malformed or
	// unsupported. Parse returns it together with a non-nil error.
	ModeInvalid
)
// Setting is the normalized interpretation of a proxy configuration value, as
// produced by Parse.
type Setting struct {
	// Raw is the configured value with surrounding whitespace trimmed.
	Raw string
	// Mode tells how the value was interpreted.
	Mode Mode
	// URL is the parsed proxy URL. Parse sets it only for ModeProxy.
	URL *url.URL
}
// Parse interprets a proxy configuration value.
//
// Surrounding whitespace is trimmed first. An empty value yields ModeInherit,
// the keywords "direct" and "none" (case-insensitive) yield ModeDirect, and a
// socks5/http/https URL with a host yields ModeProxy with URL set. Anything
// else yields ModeInvalid together with a non-nil error. Raw always holds the
// trimmed input.
func Parse(raw string) (Setting, error) {
	value := strings.TrimSpace(raw)
	result := Setting{Raw: value}

	switch {
	case value == "":
		result.Mode = ModeInherit
		return result, nil
	case strings.EqualFold(value, "direct"), strings.EqualFold(value, "none"):
		result.Mode = ModeDirect
		return result, nil
	}

	// reject marks the setting invalid and pairs it with the given error.
	reject := func(err error) (Setting, error) {
		result.Mode = ModeInvalid
		return result, err
	}

	target, errParse := url.Parse(value)
	if errParse != nil {
		return reject(fmt.Errorf("parse proxy URL failed: %w", errParse))
	}
	if target.Scheme == "" || target.Host == "" {
		return reject(fmt.Errorf("proxy URL missing scheme/host"))
	}
	if scheme := target.Scheme; scheme != "socks5" && scheme != "http" && scheme != "https" {
		return reject(fmt.Errorf("unsupported proxy scheme: %s", scheme))
	}

	result.Mode = ModeProxy
	result.URL = target
	return result, nil
}
// NewDirectTransport returns a transport whose Proxy function is nil, so it
// does not consult environment proxy settings. When http.DefaultTransport is
// an *http.Transport, the result is a clone of it with only Proxy cleared;
// otherwise a zero-valued transport is returned.
func NewDirectTransport() *http.Transport {
	base, ok := http.DefaultTransport.(*http.Transport)
	if !ok || base == nil {
		return &http.Transport{}
	}
	direct := base.Clone()
	direct.Proxy = nil
	return direct
}
// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting.
//
// raw is interpreted by Parse. The returned Mode is always the one Parse
// reported, and the other results depend on it:
//   - parse failure: nil transport and the parse error;
//   - ModeInherit: nil transport and nil error, meaning nothing was built;
//   - ModeDirect: the transport from NewDirectTransport;
//   - ModeProxy with an http or https URL: a new transport whose Proxy
//     function always returns that URL;
//   - ModeProxy with a socks5 URL: a new transport that dials every
//     connection through the SOCKS5 server, using the URL's user info (if
//     any) as credentials.
//
// An error is also returned when the SOCKS5 dialer cannot be created.
func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
	setting, errParse := Parse(raw)
	if errParse != nil {
		return nil, setting.Mode, errParse
	}

	switch setting.Mode {
	case ModeInherit:
		return nil, setting.Mode, nil
	case ModeDirect:
		return NewDirectTransport(), setting.Mode, nil
	case ModeProxy:
		if setting.URL.Scheme != "socks5" {
			return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil
		}

		var proxyAuth *proxy.Auth
		if setting.URL.User != nil {
			// The "password set" flag is deliberately ignored: a URL
			// without a password simply yields an empty string.
			password, _ := setting.URL.User.Password()
			proxyAuth = &proxy.Auth{User: setting.URL.User.Username(), Password: password}
		}
		dialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
		if errSOCKS5 != nil {
			return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
		}

		// Fallback for dialers that only expose Dial; the context cannot be
		// honored on this path.
		dialContext := func(_ context.Context, network, addr string) (net.Conn, error) {
			return dialer.Dial(network, addr)
		}
		// Prefer the context-aware entry point so the context the transport
		// passes in (cancellation, deadlines) reaches the proxy connection
		// and handshake instead of being dropped.
		if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
			dialContext = contextDialer.DialContext
		}
		return &http.Transport{
			// The SOCKS5 hop happens in DialContext, so no HTTP-level proxy.
			Proxy:       nil,
			DialContext: dialContext,
		}, setting.Mode, nil
	default:
		return nil, setting.Mode, nil
	}
}
// BuildDialer constructs a connection-level dialer for the provided proxy
// setting.
//
// raw is interpreted by Parse and the returned Mode is always the one Parse
// reported. A parse failure yields a nil dialer and the parse error.
// ModeDirect yields proxy.Direct. ModeProxy yields the dialer created by
// proxy.FromURL (forwarding through proxy.Direct), or an error if that
// fails. Every other mode, including ModeInherit, yields a nil dialer and a
// nil error.
//
// NOTE(review): proxy.FromURL appears to understand only SOCKS5 schemes
// unless others are registered via proxy.RegisterDialerType — confirm that
// http/https proxy URLs are expected to work through this function.
func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
	setting, errParse := Parse(raw)
	mode := setting.Mode
	if errParse != nil {
		return nil, mode, errParse
	}

	if mode == ModeDirect {
		return proxy.Direct, mode, nil
	}
	if mode != ModeProxy {
		// Nothing to build for ModeInherit or any other mode.
		return nil, mode, nil
	}

	forward, errDialer := proxy.FromURL(setting.URL, proxy.Direct)
	if errDialer != nil {
		return nil, mode, fmt.Errorf("create proxy dialer failed: %w", errDialer)
	}
	return forward, mode, nil
}
+89
View File
@@ -0,0 +1,89 @@
package proxyutil
import (
"net/http"
"testing"
)
// TestParse checks that Parse classifies representative configuration values
// into the expected Mode, reports an error exactly when the value is invalid,
// and populates Setting.URL only for ModeProxy.
func TestParse(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name    string
		input   string
		want    Mode
		wantErr bool
	}{
		{name: "inherit", input: "", want: ModeInherit},
		{name: "whitespace-only", input: " \t ", want: ModeInherit},
		{name: "direct", input: "direct", want: ModeDirect},
		{name: "none", input: "none", want: ModeDirect},
		{name: "direct-mixed-case-padded", input: "  Direct ", want: ModeDirect},
		{name: "none-upper-case", input: "NONE", want: ModeDirect},
		{name: "http", input: "http://proxy.example.com:8080", want: ModeProxy},
		{name: "https", input: "https://proxy.example.com:8443", want: ModeProxy},
		{name: "socks5", input: "socks5://proxy.example.com:1080", want: ModeProxy},
		{name: "socks5-with-credentials", input: "socks5://user:pass@proxy.example.com:1080", want: ModeProxy},
		{name: "invalid", input: "bad-value", want: ModeInvalid, wantErr: true},
		{name: "missing-host", input: "http://", want: ModeInvalid, wantErr: true},
		{name: "unsupported-scheme", input: "ftp://proxy.example.com:21", want: ModeInvalid, wantErr: true},
	}

	for _, tt := range tests {
		tt := tt // capture per iteration for toolchains before Go 1.22
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			setting, errParse := Parse(tt.input)
			switch {
			case tt.wantErr && errParse == nil:
				t.Fatal("expected error, got nil")
			case !tt.wantErr && errParse != nil:
				t.Fatalf("unexpected error: %v", errParse)
			}
			if setting.Mode != tt.want {
				t.Fatalf("mode = %d, want %d", setting.Mode, tt.want)
			}
			// Parse attaches the URL only to successfully parsed proxy settings.
			if hasURL := setting.URL != nil; hasURL != (tt.want == ModeProxy) {
				t.Fatalf("URL = %v, want non-nil only for ModeProxy", setting.URL)
			}
		})
	}
}
// TestBuildHTTPTransportDirectBypassesProxy verifies that the "direct"
// setting produces a transport in ModeDirect whose Proxy function is nil.
func TestBuildHTTPTransportDirectBypassesProxy(t *testing.T) {
	t.Parallel()

	transport, mode, errBuild := BuildHTTPTransport("direct")

	// Cases are evaluated in order, so transport.Proxy is only touched once
	// the transport is known to be non-nil.
	switch {
	case errBuild != nil:
		t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
	case mode != ModeDirect:
		t.Fatalf("mode = %d, want %d", mode, ModeDirect)
	case transport == nil:
		t.Fatal("expected transport, got nil")
	case transport.Proxy != nil:
		t.Fatal("expected direct transport to disable proxy function")
	}
}
// TestBuildHTTPTransportHTTPProxy verifies that an http:// proxy setting
// produces a transport in ModeProxy whose Proxy function resolves every
// request to the configured proxy URL.
func TestBuildHTTPTransportHTTPProxy(t *testing.T) {
	t.Parallel()

	const proxyAddr = "http://proxy.example.com:8080"

	transport, mode, errBuild := BuildHTTPTransport(proxyAddr)
	if errBuild != nil {
		t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
	}
	if mode != ModeProxy {
		t.Fatalf("mode = %d, want %d", mode, ModeProxy)
	}
	if transport == nil {
		t.Fatal("expected transport, got nil")
	}

	// Any outbound request will do; only the proxy selection is inspected.
	req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
	if errRequest != nil {
		t.Fatalf("http.NewRequest returned error: %v", errRequest)
	}

	proxyURL, errProxy := transport.Proxy(req)
	if errProxy != nil {
		t.Fatalf("transport.Proxy returned error: %v", errProxy)
	}
	if proxyURL == nil || proxyURL.String() != proxyAddr {
		t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL)
	}
}