Merge branch 'dev' into codex/custom-useragent-request

This commit is contained in:
Luis Pater
2026-03-11 22:55:50 +08:00
committed by GitHub
58 changed files with 7068 additions and 1724 deletions
+4
View File
@@ -15,6 +15,8 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
@@ -46,6 +48,8 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
+2
View File
@@ -12,6 +12,8 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
- name: Set up Go
uses: actions/setup-go@v5
with:
+2
View File
@@ -16,6 +16,8 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Refresh models catalog
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
- run: git fetch --force --tags
- uses: actions/setup-go@v4
with:
+4
View File
@@ -150,6 +150,10 @@ A Windows tray application implemented using PowerShell scripts, without relying
A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
Browser extension for one-stop management of New API-compatible relay site accounts, featuring balance and usage dashboards, auto check-in, one-click key export to common apps, in-page API availability testing, and channel/model sync and redirection. It integrates with CLIProxyAPI through the Management API for one-click provider import and config sync.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
+4
View File
@@ -149,6 +149,10 @@ Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方
一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
用于一站式管理 New API 兼容中转站账号的浏览器扩展,提供余额与用量看板、自动签到、密钥一键导出到常用应用、网页内 API 可用性测试,以及渠道与模型同步和重定向。支持通过 CLIProxyAPI Management API 一键导入 Provider 与同步配置。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
+3
View File
@@ -24,6 +24,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
@@ -494,6 +495,7 @@ func main() {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background())
hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook)
@@ -566,6 +568,7 @@ func main() {
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background())
cmd.StartService(cfg, configFilePath, password)
}
}
+17
View File
@@ -63,6 +63,7 @@ error-logs-max-files: 10
usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
proxy-url: ""
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
@@ -110,6 +111,7 @@ nonstream-keepalive-interval: 0
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080"
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model
@@ -128,6 +130,7 @@ nonstream-keepalive-interval: 0
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model
@@ -146,6 +149,7 @@ nonstream-keepalive-interval: 0
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
@@ -191,10 +195,22 @@ nonstream-keepalive-interval: 0
# api-key-entries:
# - api-key: "sk-or-v1-...b780"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# - api-key: "sk-or-v1-...b781" # without proxy-url
# models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API.
# # You may repeat the same alias to build an internal model pool.
# # The client still sees only one alias in the model list.
# # Requests to that alias will round-robin across the upstream names below,
# # and if the chosen upstream fails before producing output, the request will
# # continue with the next upstream model in the same alias pool.
# - name: "qwen3.5-plus"
# alias: "claude-opus-4.66"
# - name: "glm-5"
# alias: "claude-opus-4.66"
# - name: "kimi-k2.5"
# alias: "claude-opus-4.66"
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
# vertex-api-key:
@@ -202,6 +218,7 @@ nonstream-keepalive-interval: 0
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# headers:
# X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names
+5 -41
View File
@@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
@@ -14,8 +13,8 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
@@ -660,45 +659,10 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
}
func buildProxyTransport(proxyStr string) *http.Transport {
proxyStr = strings.TrimSpace(proxyStr)
if proxyStr == "" {
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
if errBuild != nil {
log.WithError(errBuild).Debug("build proxy transport failed")
return nil
}
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.WithError(errParse).Debug("parse proxy URL failed")
return nil
}
if proxyURL.Scheme == "" || proxyURL.Host == "" {
log.Debug("proxy URL missing scheme/host")
return nil
}
if proxyURL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
return nil
}
return &http.Transport{
Proxy: nil,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
}
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
return nil
return transport
}
@@ -1,173 +1,58 @@
package management
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
type memoryAuthStore struct {
mu sync.Mutex
items map[string]*coreauth.Auth
}
func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
_ = ctx
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, a := range s.items {
out = append(out, a.Clone())
}
return out, nil
}
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
_ = ctx
if auth == nil {
return "", nil
}
s.mu.Lock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth.Clone()
s.mu.Unlock()
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
_ = ctx
s.mu.Lock()
delete(s.items, id)
s.mu.Unlock()
return nil
}
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
if r.Method != http.MethodPost {
t.Fatalf("expected POST, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
t.Fatalf("unexpected content-type: %s", ct)
}
bodyBytes, _ := io.ReadAll(r.Body)
_ = r.Body.Close()
values, err := url.ParseQuery(string(bodyBytes))
if err != nil {
t.Fatalf("parse form: %v", err)
}
if values.Get("grant_type") != "refresh_token" {
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
}
if values.Get("refresh_token") != "rt" {
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
}
if values.Get("client_id") != antigravityOAuthClientID {
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
}
if values.Get("client_secret") != antigravityOAuthClientSecret {
t.Fatalf("unexpected client_secret")
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-token",
"refresh_token": "rt2",
"expires_in": int64(3600),
"token_type": "Bearer",
})
}))
t.Cleanup(srv.Close)
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
store := &memoryAuthStore{}
manager := coreauth.NewManager(store, nil, nil)
auth := &coreauth.Auth{
ID: "antigravity-test.json",
FileName: "antigravity-test.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "old-token",
"refresh_token": "rt",
"expires_in": int64(3600),
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
h := &Handler{
cfg: &config.Config{
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
},
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err)
}
h := &Handler{authManager: manager}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
httpTransport, ok := transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
}
if token != "new-token" {
t.Fatalf("expected refreshed token, got %q", token)
}
if callCount != 1 {
t.Fatalf("expected 1 refresh call, got %d", callCount)
}
updated, ok := manager.GetByID(auth.ID)
if !ok || updated == nil {
t.Fatalf("expected auth in manager after update")
}
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
t.Fatalf("expected manager metadata updated, got %q", got)
if httpTransport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
var callCount int
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.WriteHeader(http.StatusInternalServerError)
}))
t.Cleanup(srv.Close)
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
t.Parallel()
originalURL := antigravityOAuthTokenURL
antigravityOAuthTokenURL = srv.URL
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
auth := &coreauth.Auth{
ID: "antigravity-valid.json",
FileName: "antigravity-valid.json",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"access_token": "ok-token",
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
h := &Handler{
cfg: &config.Config{
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
},
}
h := &Handler{}
token, err := h.resolveTokenForAuth(context.Background(), auth)
if err != nil {
t.Fatalf("resolveTokenForAuth: %v", err)
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
httpTransport, ok := transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", transport)
}
if token != "ok-token" {
t.Fatalf("expected existing token, got %q", token)
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
if errRequest != nil {
t.Fatalf("http.NewRequest returned error: %v", errRequest)
}
if callCount != 0 {
t.Fatalf("expected no refresh calls, got %d", callCount)
proxyURL, errProxy := httpTransport.Proxy(req)
if errProxy != nil {
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
}
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
}
}
@@ -1306,12 +1306,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
if errAll != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll))
return
}
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify))
return
}
ts.ProjectID = strings.Join(projects, ",")
@@ -1320,7 +1320,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ts.Auto = false
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
log.Errorf("Google One auto-discovery failed: %v", errSetup)
SetOAuthSessionError(state, "Google One auto-discovery failed")
SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup))
return
}
if strings.TrimSpace(ts.ProjectID) == "" {
@@ -1331,19 +1331,19 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
return
}
ts.Checked = isChecked
if !isChecked {
log.Error("Cloud AI API is not enabled for the auto-discovered project")
SetOAuthSessionError(state, "Cloud AI API not enabled")
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
return
}
} else {
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure))
return
}
@@ -1356,13 +1356,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
return
}
ts.Checked = isChecked
if !isChecked {
log.Error("Cloud AI API is not enabled for the selected project")
SetOAuthSessionError(state, "Cloud AI API not enabled")
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
return
}
}
@@ -0,0 +1,49 @@
package management
import (
"context"
"sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
type memoryAuthStore struct {
mu sync.Mutex
items map[string]*coreauth.Auth
}
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := make([]*coreauth.Auth, 0, len(s.items))
for _, item := range s.items {
out = append(out, item)
}
return out, nil
}
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
if auth == nil {
return "", nil
}
s.mu.Lock()
defer s.mu.Unlock()
if s.items == nil {
s.items = make(map[string]*coreauth.Auth)
}
s.items[auth.ID] = auth
return auth.ID, nil
}
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.items, id)
return nil
}
func (s *memoryAuthStore) SetBaseDir(string) {}
+7 -12
View File
@@ -4,12 +4,12 @@ package claude
import (
"net/http"
"net/url"
"strings"
"sync"
tls "github.com/refraction-networking/utls"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/proxy"
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
var dialer proxy.Dialer = proxy.Direct
if cfg != nil && cfg.ProxyURL != "" {
proxyURL, err := url.Parse(cfg.ProxyURL)
if err != nil {
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
} else {
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
if err != nil {
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
} else {
dialer = pDialer
}
if cfg != nil {
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
if errBuild != nil {
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
dialer = proxyDialer
}
}
+9 -31
View File
@@ -10,9 +10,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
@@ -20,9 +18,9 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
@@ -80,36 +78,16 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
}
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil {
var transport *http.Transport
if proxyURL.Scheme == "socks5" {
// Handle SOCKS5 proxy.
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
auth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Handle HTTP/HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
if transport != nil {
proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
}
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
if errBuild != nil {
log.Errorf("%v", errBuild)
} else if transport != nil {
proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
}
var err error
// Configure the OAuth2 client.
conf := &oauth2.Config{
ClientID: ClientID,
+139 -13
View File
@@ -1,5 +1,5 @@
// Package registry provides model definitions and lookup helpers for various AI providers.
// Static model metadata is stored in model_definitions_static_data.go.
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
package registry
import (
@@ -7,6 +7,131 @@ import (
"strings"
)
// AntigravityModelConfig captures static antigravity model overrides, including
// Thinking budget limits and provider max completion tokens.
type AntigravityModelConfig struct {
Thinking *ThinkingSupport `json:"thinking,omitempty"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
}
// staticModelsJSON mirrors the top-level structure of models.json.
type staticModelsJSON struct {
Claude []*ModelInfo `json:"claude"`
Gemini []*ModelInfo `json:"gemini"`
Vertex []*ModelInfo `json:"vertex"`
GeminiCLI []*ModelInfo `json:"gemini-cli"`
AIStudio []*ModelInfo `json:"aistudio"`
CodexFree []*ModelInfo `json:"codex-free"`
CodexTeam []*ModelInfo `json:"codex-team"`
CodexPlus []*ModelInfo `json:"codex-plus"`
CodexPro []*ModelInfo `json:"codex-pro"`
Qwen []*ModelInfo `json:"qwen"`
IFlow []*ModelInfo `json:"iflow"`
Kimi []*ModelInfo `json:"kimi"`
Antigravity map[string]*AntigravityModelConfig `json:"antigravity"`
}
// GetClaudeModels returns the standard Claude model definitions.
func GetClaudeModels() []*ModelInfo {
return cloneModelInfos(getModels().Claude)
}
// GetGeminiModels returns the standard Gemini model definitions.
func GetGeminiModels() []*ModelInfo {
return cloneModelInfos(getModels().Gemini)
}
// GetGeminiVertexModels returns Gemini model definitions for Vertex AI.
func GetGeminiVertexModels() []*ModelInfo {
return cloneModelInfos(getModels().Vertex)
}
// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI.
func GetGeminiCLIModels() []*ModelInfo {
return cloneModelInfos(getModels().GeminiCLI)
}
// GetAIStudioModels returns model definitions for AI Studio.
func GetAIStudioModels() []*ModelInfo {
return cloneModelInfos(getModels().AIStudio)
}
// GetCodexFreeModels returns model definitions for the Codex free plan tier.
func GetCodexFreeModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexFree)
}
// GetCodexTeamModels returns model definitions for the Codex team plan tier.
func GetCodexTeamModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexTeam)
}
// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
func GetCodexPlusModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexPlus)
}
// GetCodexProModels returns model definitions for the Codex pro plan tier.
func GetCodexProModels() []*ModelInfo {
return cloneModelInfos(getModels().CodexPro)
}
// GetQwenModels returns the standard Qwen model definitions.
func GetQwenModels() []*ModelInfo {
return cloneModelInfos(getModels().Qwen)
}
// GetIFlowModels returns the standard iFlow model definitions.
func GetIFlowModels() []*ModelInfo {
return cloneModelInfos(getModels().IFlow)
}
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions.
func GetKimiModels() []*ModelInfo {
return cloneModelInfos(getModels().Kimi)
}
// GetAntigravityModelConfig returns static configuration for antigravity models.
// Keys use upstream model names returned by the Antigravity models endpoint.
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
data := getModels()
if len(data.Antigravity) == 0 {
return nil
}
out := make(map[string]*AntigravityModelConfig, len(data.Antigravity))
for k, v := range data.Antigravity {
out[k] = cloneAntigravityModelConfig(v)
}
return out
}
func cloneAntigravityModelConfig(cfg *AntigravityModelConfig) *AntigravityModelConfig {
if cfg == nil {
return nil
}
copyConfig := *cfg
if cfg.Thinking != nil {
copyThinking := *cfg.Thinking
if len(cfg.Thinking.Levels) > 0 {
copyThinking.Levels = append([]string(nil), cfg.Thinking.Levels...)
}
copyConfig.Thinking = &copyThinking
}
return &copyConfig
}
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
if len(models) == 0 {
return nil
}
out := make([]*ModelInfo, len(models))
for i, m := range models {
out[i] = cloneModelInfo(m)
}
return out
}
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
// It returns nil when the channel is unknown.
//
@@ -35,7 +160,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
case "aistudio":
return GetAIStudioModels()
case "codex":
return GetOpenAIModels()
return GetCodexProModels()
case "qwen":
return GetQwenModels()
case "iflow":
@@ -77,27 +202,28 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
return nil
}
data := getModels()
allModels := [][]*ModelInfo{
GetClaudeModels(),
GetGeminiModels(),
GetGeminiVertexModels(),
GetGeminiCLIModels(),
GetAIStudioModels(),
GetOpenAIModels(),
GetQwenModels(),
GetIFlowModels(),
GetKimiModels(),
data.Claude,
data.Gemini,
data.Vertex,
data.GeminiCLI,
data.AIStudio,
data.CodexPro,
data.Qwen,
data.IFlow,
data.Kimi,
}
for _, models := range allModels {
for _, m := range models {
if m != nil && m.ID == modelID {
return m
return cloneModelInfo(m)
}
}
}
// Check Antigravity static config
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
if cfg := cloneAntigravityModelConfig(data.Antigravity[modelID]); cfg != nil {
return &ModelInfo{
ID: modelID,
Thinking: cfg.Thinking,
File diff suppressed because it is too large Load Diff
+139 -29
View File
@@ -62,6 +62,11 @@ type ModelInfo struct {
UserDefined bool `json:"-"`
}
type availableModelsCacheEntry struct {
models []map[string]any
expiresAt time.Time
}
// ThinkingSupport describes a model family's supported internal reasoning budget range.
// Values are interpreted in provider-native token units.
type ThinkingSupport struct {
@@ -116,6 +121,8 @@ type ModelRegistry struct {
clientProviders map[string]string
// mutex ensures thread-safe access to the registry
mutex *sync.RWMutex
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
availableModelsCache map[string]availableModelsCacheEntry
// hook is an optional callback sink for model registration changes
hook ModelRegistryHook
}
@@ -128,15 +135,28 @@ var registryOnce sync.Once
func GetGlobalRegistry() *ModelRegistry {
registryOnce.Do(func() {
globalRegistry = &ModelRegistry{
models: make(map[string]*ModelRegistration),
clientModels: make(map[string][]string),
clientModelInfos: make(map[string]map[string]*ModelInfo),
clientProviders: make(map[string]string),
mutex: &sync.RWMutex{},
models: make(map[string]*ModelRegistration),
clientModels: make(map[string][]string),
clientModelInfos: make(map[string]map[string]*ModelInfo),
clientProviders: make(map[string]string),
availableModelsCache: make(map[string]availableModelsCacheEntry),
mutex: &sync.RWMutex{},
}
})
return globalRegistry
}
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
if r.availableModelsCache == nil {
r.availableModelsCache = make(map[string]availableModelsCacheEntry)
}
}
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
if len(r.availableModelsCache) == 0 {
return
}
clear(r.availableModelsCache)
}
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
@@ -151,9 +171,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
}
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
return info
return cloneModelInfo(info)
}
return LookupStaticModelInfo(modelID)
return cloneModelInfo(LookupStaticModelInfo(modelID))
}
// SetHook sets an optional hook for observing model registration changes.
@@ -211,6 +231,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
provider := strings.ToLower(clientProvider)
uniqueModelIDs := make([]string, 0, len(models))
@@ -236,6 +257,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientModels, clientID)
delete(r.clientModelInfos, clientID)
delete(r.clientProviders, clientID)
r.invalidateAvailableModelsCacheLocked()
misc.LogCredentialSeparator()
return
}
@@ -263,6 +285,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
} else {
delete(r.clientProviders, clientID)
}
r.invalidateAvailableModelsCacheLocked()
r.triggerModelsRegistered(provider, clientID, models)
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
misc.LogCredentialSeparator()
@@ -406,6 +429,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
delete(r.clientProviders, clientID)
}
r.invalidateAvailableModelsCacheLocked()
r.triggerModelsRegistered(provider, clientID, models)
if len(added) == 0 && len(removed) == 0 && !providerChanged {
// Only metadata (e.g., display name) changed; skip separator when no log output.
@@ -509,6 +533,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
if len(model.SupportedOutputModalities) > 0 {
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
}
if model.Thinking != nil {
copyThinking := *model.Thinking
if len(model.Thinking.Levels) > 0 {
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
}
copyModel.Thinking = &copyThinking
}
return &copyModel
}
@@ -538,6 +569,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.unregisterClientInternal(clientID)
r.invalidateAvailableModelsCacheLocked()
}
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
@@ -604,9 +636,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if registration, exists := r.models[modelID]; exists {
registration.QuotaExceededClients[clientID] = new(time.Now())
now := time.Now()
registration.QuotaExceededClients[clientID] = &now
r.invalidateAvailableModelsCacheLocked()
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
}
}
@@ -618,9 +653,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if registration, exists := r.models[modelID]; exists {
delete(registration.QuotaExceededClients, clientID)
r.invalidateAvailableModelsCacheLocked()
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
}
}
@@ -636,6 +673,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
}
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
registration, exists := r.models[modelID]
if !exists || registration == nil {
@@ -649,6 +687,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
}
registration.SuspendedClients[clientID] = reason
registration.LastUpdated = time.Now()
r.invalidateAvailableModelsCacheLocked()
if reason != "" {
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
} else {
@@ -666,6 +705,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
}
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
registration, exists := r.models[modelID]
if !exists || registration == nil || registration.SuspendedClients == nil {
@@ -676,6 +716,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
}
delete(registration.SuspendedClients, clientID)
registration.LastUpdated = time.Now()
r.invalidateAvailableModelsCacheLocked()
log.Debugf("Resumed client %s for model %s", clientID, modelID)
}
@@ -711,22 +752,52 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
// Returns:
// - []map[string]any: List of available models in the requested format
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
r.mutex.RLock()
defer r.mutex.RUnlock()
now := time.Now()
models := make([]map[string]any, 0)
r.mutex.RLock()
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
models := cloneModelMaps(cache.models)
r.mutex.RUnlock()
return models
}
r.mutex.RUnlock()
r.mutex.Lock()
defer r.mutex.Unlock()
r.ensureAvailableModelsCacheLocked()
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
return cloneModelMaps(cache.models)
}
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
models: cloneModelMaps(models),
expiresAt: expiresAt,
}
return models
}
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
models := make([]map[string]any, 0, len(r.models))
quotaExpiredDuration := 5 * time.Minute
var expiresAt time.Time
for _, registration := range r.models {
// Check if model has any non-quota-exceeded clients
availableClients := registration.Count
now := time.Now()
// Count clients that have exceeded quota but haven't recovered yet
expiredClients := 0
for _, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
if quotaTime == nil {
continue
}
recoveryAt := quotaTime.Add(quotaExpiredDuration)
if now.Before(recoveryAt) {
expiredClients++
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
expiresAt = recoveryAt
}
}
}
@@ -747,7 +818,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
effectiveClients = 0
}
// Include models that have available clients, or those solely cooling down.
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
model := r.convertModelToMap(registration.Info, handlerType)
if model != nil {
@@ -756,7 +826,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
}
}
return models
return models, expiresAt
}
func cloneModelMaps(models []map[string]any) []map[string]any {
cloned := make([]map[string]any, 0, len(models))
for _, model := range models {
if model == nil {
cloned = append(cloned, nil)
continue
}
copyModel := make(map[string]any, len(model))
for key, value := range model {
copyModel[key] = cloneModelMapValue(value)
}
cloned = append(cloned, copyModel)
}
return cloned
}
func cloneModelMapValue(value any) any {
switch typed := value.(type) {
case map[string]any:
copyMap := make(map[string]any, len(typed))
for key, entry := range typed {
copyMap[key] = cloneModelMapValue(entry)
}
return copyMap
case []any:
copySlice := make([]any, len(typed))
for i, entry := range typed {
copySlice[i] = cloneModelMapValue(entry)
}
return copySlice
case []string:
return append([]string(nil), typed...)
default:
return value
}
}
// GetAvailableModelsByProvider returns models available for the given provider identifier.
@@ -872,11 +979,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
if entry.info != nil {
result = append(result, entry.info)
result = append(result, cloneModelInfo(entry.info))
continue
}
if ok && registration != nil && registration.Info != nil {
result = append(result, registration.Info)
result = append(result, cloneModelInfo(registration.Info))
}
}
}
@@ -985,13 +1092,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
if reg.Providers != nil {
if count, ok := reg.Providers[provider]; ok && count > 0 {
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
return info
return cloneModelInfo(info)
}
}
}
}
// Fallback to global info (last registered)
return reg.Info
return cloneModelInfo(reg.Info)
}
return nil
}
@@ -1031,7 +1138,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
result["max_completion_tokens"] = model.MaxCompletionTokens
}
if len(model.SupportedParameters) > 0 {
result["supported_parameters"] = model.SupportedParameters
result["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
}
return result
@@ -1075,13 +1182,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
result["outputTokenLimit"] = model.OutputTokenLimit
}
if len(model.SupportedGenerationMethods) > 0 {
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedInputModalities) > 0 {
result["supportedInputModalities"] = model.SupportedInputModalities
result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
}
if len(model.SupportedOutputModalities) > 0 {
result["supportedOutputModalities"] = model.SupportedOutputModalities
result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
}
return result
@@ -1111,15 +1218,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
now := time.Now()
quotaExpiredDuration := 5 * time.Minute
invalidated := false
for modelID, registration := range r.models {
for clientID, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
delete(registration.QuotaExceededClients, clientID)
invalidated = true
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
}
}
}
if invalidated {
r.invalidateAvailableModelsCacheLocked()
}
}
// GetFirstAvailableModel returns the first available model for the given handler type.
@@ -1133,8 +1245,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
// - string: The model ID of the first available model, or empty string if none available
// - error: An error if no models are available
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
r.mutex.RLock()
defer r.mutex.RUnlock()
// Get all available models for this handler type
models := r.GetAvailableModels(handlerType)
@@ -1194,13 +1304,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
// Prefer client's own model info to preserve original type/owned_by
if clientInfos != nil {
if info, ok := clientInfos[modelID]; ok && info != nil {
result = append(result, info)
result = append(result, cloneModelInfo(info))
continue
}
}
// Fallback to global registry (for backwards compatibility)
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
result = append(result, reg.Info)
result = append(result, cloneModelInfo(reg.Info))
}
}
return result
@@ -0,0 +1,54 @@
package registry
import "testing"
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
first := r.GetAvailableModels("openai")
if len(first) != 1 {
t.Fatalf("expected 1 model, got %d", len(first))
}
first[0]["id"] = "mutated"
first[0]["display_name"] = "Mutated"
second := r.GetAvailableModels("openai")
if got := second[0]["id"]; got != "m1" {
t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
}
if got := second[0]["display_name"]; got != "Model One" {
t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
}
}
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
models := r.GetAvailableModels("openai")
if len(models) != 1 {
t.Fatalf("expected 1 model, got %d", len(models))
}
if got := models[0]["display_name"]; got != "Model One" {
t.Fatalf("expected initial display_name Model One, got %v", got)
}
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
models = r.GetAvailableModels("openai")
if got := models[0]["display_name"]; got != "Model One Updated" {
t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
}
r.SuspendClientModel("client-1", "m1", "manual")
models = r.GetAvailableModels("openai")
if len(models) != 0 {
t.Fatalf("expected no available models after suspension, got %d", len(models))
}
r.ResumeClientModel("client-1", "m1")
models = r.GetAvailableModels("openai")
if len(models) != 1 {
t.Fatalf("expected model to reappear after resume, got %d", len(models))
}
}
@@ -0,0 +1,149 @@
package registry
import (
"testing"
"time"
)
func TestGetModelInfoReturnsClone(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
Thinking: &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
}})
first := r.GetModelInfo("m1", "gemini")
if first == nil {
t.Fatal("expected model info")
}
first.DisplayName = "mutated"
first.Thinking.Levels[0] = "mutated"
second := r.GetModelInfo("m1", "gemini")
if second.DisplayName != "Model One" {
t.Fatalf("expected cloned display name, got %q", second.DisplayName)
}
if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
}
}
func TestGetModelsForClientReturnsClones(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
}})
first := r.GetModelsForClient("client-1")
if len(first) != 1 || first[0] == nil {
t.Fatalf("expected one model, got %+v", first)
}
first[0].DisplayName = "mutated"
first[0].Thinking.Levels[0] = "mutated"
second := r.GetModelsForClient("client-1")
if len(second) != 1 || second[0] == nil {
t.Fatalf("expected one model on second fetch, got %+v", second)
}
if second[0].DisplayName != "Model One" {
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
}
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
}
}
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
}})
first := r.GetAvailableModelsByProvider("gemini")
if len(first) != 1 || first[0] == nil {
t.Fatalf("expected one model, got %+v", first)
}
first[0].DisplayName = "mutated"
first[0].Thinking.Levels[0] = "mutated"
second := r.GetAvailableModelsByProvider("gemini")
if len(second) != 1 || second[0] == nil {
t.Fatalf("expected one model on second fetch, got %+v", second)
}
if second[0].DisplayName != "Model One" {
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
}
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
}
}
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
r.SetModelQuotaExceeded("client-1", "m1")
if models := r.GetAvailableModels("openai"); len(models) != 1 {
t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models))
}
r.mutex.Lock()
quotaTime := time.Now().Add(-6 * time.Minute)
r.models["m1"].QuotaExceededClients["client-1"] = &quotaTime
r.mutex.Unlock()
r.CleanupExpiredQuotas()
if count := r.GetModelCount("m1"); count != 1 {
t.Fatalf("expected model count 1 after cleanup, got %d", count)
}
models := r.GetAvailableModels("openai")
if len(models) != 1 {
t.Fatalf("expected model to stay available after cleanup, got %d", len(models))
}
if got := models[0]["id"]; got != "m1" {
t.Fatalf("expected model id m1, got %v", got)
}
}
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "openai", []*ModelInfo{{
ID: "m1",
DisplayName: "Model One",
SupportedParameters: []string{"temperature", "top_p"},
}})
first := r.GetAvailableModels("openai")
if len(first) != 1 {
t.Fatalf("expected one model, got %d", len(first))
}
params, ok := first[0]["supported_parameters"].([]string)
if !ok || len(params) != 2 {
t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
}
params[0] = "mutated"
second := r.GetAvailableModels("openai")
params, ok = second[0]["supported_parameters"].([]string)
if !ok || len(params) != 2 || params[0] != "temperature" {
t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
}
}
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
first := LookupModelInfo("glm-4.6")
if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
t.Fatalf("expected static model with thinking levels, got %+v", first)
}
first.Thinking.Levels[0] = "mutated"
second := LookupModelInfo("glm-4.6")
if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
t.Fatalf("expected static lookup clone, got %+v", second)
}
}
+198
View File
@@ -0,0 +1,198 @@
package registry
import (
"context"
_ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
modelsFetchTimeout = 30 * time.Second
)
var modelsURLs = []string{
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
"https://models.router-for.me/models.json",
}
//go:embed models/models.json
var embeddedModelsJSON []byte
type modelStore struct {
mu sync.RWMutex
data *staticModelsJSON
}
var modelsCatalogStore = &modelStore{}
var updaterOnce sync.Once
func init() {
// Load embedded data as fallback on startup.
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
}
}
// StartModelsUpdater runs a one-time models refresh on startup.
// It blocks until the startup fetch attempt finishes so service initialization
// can wait for the refreshed catalog before registering auth-backed models.
// Safe to call multiple times; only one refresh will run.
func StartModelsUpdater(ctx context.Context) {
updaterOnce.Do(func() {
runModelsUpdater(ctx)
})
}
func runModelsUpdater(ctx context.Context) {
// Try network fetch once on startup, then stop.
// Periodic refresh is disabled - models are only refreshed at startup.
tryRefreshModels(ctx)
}
func tryRefreshModels(ctx context.Context) {
client := &http.Client{Timeout: modelsFetchTimeout}
for _, url := range modelsURLs {
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
req, err := http.NewRequestWithContext(reqCtx, "GET", url, nil)
if err != nil {
cancel()
log.Debugf("models fetch request creation failed for %s: %v", url, err)
continue
}
resp, err := client.Do(req)
if err != nil {
cancel()
log.Debugf("models fetch failed from %s: %v", url, err)
continue
}
if resp.StatusCode != 200 {
resp.Body.Close()
cancel()
log.Debugf("models fetch returned %d from %s", resp.StatusCode, url)
continue
}
data, err := io.ReadAll(resp.Body)
resp.Body.Close()
cancel()
if err != nil {
log.Debugf("models fetch read error from %s: %v", url, err)
continue
}
if err := loadModelsFromBytes(data, url); err != nil {
log.Warnf("models parse failed from %s: %v", url, err)
continue
}
log.Infof("models updated from %s", url)
return
}
log.Warn("models refresh failed from all URLs, using current data")
}
func loadModelsFromBytes(data []byte, source string) error {
var parsed staticModelsJSON
if err := json.Unmarshal(data, &parsed); err != nil {
return fmt.Errorf("%s: decode models catalog: %w", source, err)
}
if err := validateModelsCatalog(&parsed); err != nil {
return fmt.Errorf("%s: validate models catalog: %w", source, err)
}
modelsCatalogStore.mu.Lock()
modelsCatalogStore.data = &parsed
modelsCatalogStore.mu.Unlock()
return nil
}
func getModels() *staticModelsJSON {
modelsCatalogStore.mu.RLock()
defer modelsCatalogStore.mu.RUnlock()
return modelsCatalogStore.data
}
func validateModelsCatalog(data *staticModelsJSON) error {
if data == nil {
return fmt.Errorf("catalog is nil")
}
requiredSections := []struct {
name string
models []*ModelInfo
}{
{name: "claude", models: data.Claude},
{name: "gemini", models: data.Gemini},
{name: "vertex", models: data.Vertex},
{name: "gemini-cli", models: data.GeminiCLI},
{name: "aistudio", models: data.AIStudio},
{name: "codex-free", models: data.CodexFree},
{name: "codex-team", models: data.CodexTeam},
{name: "codex-plus", models: data.CodexPlus},
{name: "codex-pro", models: data.CodexPro},
{name: "qwen", models: data.Qwen},
{name: "iflow", models: data.IFlow},
{name: "kimi", models: data.Kimi},
}
for _, section := range requiredSections {
if err := validateModelSection(section.name, section.models); err != nil {
return err
}
}
if err := validateAntigravitySection(data.Antigravity); err != nil {
return err
}
return nil
}
func validateModelSection(section string, models []*ModelInfo) error {
if len(models) == 0 {
return fmt.Errorf("%s section is empty", section)
}
seen := make(map[string]struct{}, len(models))
for i, model := range models {
if model == nil {
return fmt.Errorf("%s[%d] is null", section, i)
}
modelID := strings.TrimSpace(model.ID)
if modelID == "" {
return fmt.Errorf("%s[%d] has empty id", section, i)
}
if _, exists := seen[modelID]; exists {
return fmt.Errorf("%s contains duplicate model id %q", section, modelID)
}
seen[modelID] = struct{}{}
}
return nil
}
func validateAntigravitySection(configs map[string]*AntigravityModelConfig) error {
if len(configs) == 0 {
return fmt.Errorf("antigravity section is empty")
}
for modelID, cfg := range configs {
trimmedID := strings.TrimSpace(modelID)
if trimmedID == "" {
return fmt.Errorf("antigravity contains empty model id")
}
if cfg == nil {
return fmt.Errorf("antigravity[%q] is null", trimmedID)
}
}
return nil
}
File diff suppressed because it is too large Load Diff
+23 -7
View File
@@ -1266,6 +1266,10 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
}
return true
})
} else if system.Type == gjson.String && system.String() != "" {
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
partJSON, _ = sjson.Set(partJSON, "text", system.String())
result += "," + partJSON
}
result += "]"
@@ -1485,25 +1489,27 @@ func countCacheControlsMap(root map[string]any) int {
return count
}
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) {
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
ccRaw, exists := obj["cache_control"]
if !exists {
return
return false
}
cc, ok := asObject(ccRaw)
if !ok {
*seen5m = true
return
return false
}
ttlRaw, ttlExists := cc["ttl"]
ttl, ttlIsString := ttlRaw.(string)
if !ttlExists || !ttlIsString || ttl != "1h" {
*seen5m = true
return
return false
}
if *seen5m {
delete(cc, "ttl")
return true
}
return false
}
func findLastCacheControlIndex(arr []any) int {
@@ -1599,11 +1605,14 @@ func normalizeCacheControlTTL(payload []byte) []byte {
}
seen5m := false
modified := false
if tools, ok := asArray(root["tools"]); ok {
for _, tool := range tools {
if obj, ok := asObject(tool); ok {
normalizeTTLForBlock(obj, &seen5m)
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
@@ -1611,7 +1620,9 @@ func normalizeCacheControlTTL(payload []byte) []byte {
if system, ok := asArray(root["system"]); ok {
for _, item := range system {
if obj, ok := asObject(item); ok {
normalizeTTLForBlock(obj, &seen5m)
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
@@ -1628,12 +1639,17 @@ func normalizeCacheControlTTL(payload []byte) []byte {
}
for _, item := range content {
if obj, ok := asObject(item); ok {
normalizeTTLForBlock(obj, &seen5m)
if normalizeTTLForBlock(obj, &seen5m) {
modified = true
}
}
}
}
}
if !modified {
return payload
}
return marshalPayloadObject(payload, root)
}
@@ -369,6 +369,19 @@ func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
}
}
func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.T) {
// Payload where no TTL normalization is needed (all blocks use 1h with no
// preceding 5m block). The text intentionally contains HTML chars (<, >, &)
// that json.Marshal would escape to \u003c etc., altering byte identity.
payload := []byte(`{"tools":[{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],"system":[{"type":"text","text":"<system-reminder>foo & bar</system-reminder>","cache_control":{"type":"ephemeral","ttl":"1h"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
out := normalizeCacheControlTTL(payload)
if !bytes.Equal(out, payload) {
t.Fatalf("normalizeCacheControlTTL altered bytes when no change was needed.\noriginal: %s\ngot: %s", payload, out)
}
}
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
payload := []byte(`{
"tools": [
@@ -829,8 +842,8 @@ func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity
executor := NewClaudeExecutor(&config.Config{})
// Inject Accept-Encoding via the custom header attribute mechanism.
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
@@ -967,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
}
}
// Test case 1: String system prompt is preserved and converted to a content block
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
system := gjson.GetBytes(out, "system")
if !system.IsArray() {
t.Fatalf("system should be an array, got %s", system.Type)
}
blocks := system.Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
}
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
}
if blocks[2].Get("text").String() != "You are a helpful assistant." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
}
}
// Test case 2: Strict mode drops the string system prompt
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, true)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks))
}
}
// Test case 3: Empty string system prompt does not produce a spurious block
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks))
}
}
// Test case 4: Array system prompt is unaffected by the string handling
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if blocks[2].Get("text").String() != "Be concise." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
}
// Test case 5: Special characters in string system prompt survive conversion
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
payload := []byte(`{"system":"Use <xml> tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if blocks[2].Get("text").String() != `Use <xml> tags & "quotes" in output.` {
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
}
}
@@ -23,6 +23,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -705,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
return dialer
}
parsedURL, errParse := url.Parse(proxyURL)
setting, errParse := proxyutil.Parse(proxyURL)
if errParse != nil {
log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse)
log.Errorf("codex websockets executor: %v", errParse)
return dialer
}
switch parsedURL.Scheme {
switch setting.Mode {
case proxyutil.ModeDirect:
dialer.Proxy = nil
return dialer
case proxyutil.ModeProxy:
default:
return dialer
}
switch setting.URL.Scheme {
case "socks5":
var proxyAuth *proxy.Auth
if parsedURL.User != nil {
username := parsedURL.User.Username()
password, _ := parsedURL.User.Password()
if setting.URL.User != nil {
username := setting.URL.User.Username()
password, _ := setting.URL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
return dialer
@@ -729,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
return socksDialer.Dial(network, addr)
}
case "http", "https":
dialer.Proxy = http.ProxyURL(parsedURL)
dialer.Proxy = http.ProxyURL(setting.URL)
default:
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme)
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
}
return dialer
@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
@@ -187,3 +188,16 @@ func contextWithGinHeaders(headers map[string]string) context.Context {
}
return context.WithValue(context.Background(), "gin", ginCtx)
}
func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
t.Parallel()
dialer := newProxyAwareWebsocketDialer(
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
)
if dialer.Proxy != nil {
t.Fatal("expected websocket proxy function to be nil for direct mode")
}
}
@@ -460,7 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
if opts.Alt != "" && action != "countTokens" {
@@ -683,7 +683,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
action := getVertexAction(baseModel, true)
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
// Imagen models don't support streaming, skip SSE params
@@ -883,7 +883,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")
+4 -41
View File
@@ -2,16 +2,14 @@ package executor
import (
"context"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
)
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
@@ -72,45 +70,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
// Returns:
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid
func buildProxyTransport(proxyURL string) *http.Transport {
if proxyURL == "" {
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL)
if errBuild != nil {
log.Errorf("%v", errBuild)
return nil
}
parsedURL, errParse := url.Parse(proxyURL)
if errParse != nil {
log.Errorf("parse proxy URL failed: %v", errParse)
return nil
}
var transport *http.Transport
// Handle different proxy schemes
if parsedURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication
var proxyAuth *proxy.Auth
if parsedURL.User != nil {
username := parsedURL.User.Username()
password, _ := parsedURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil
}
// Set up a custom transport using the SOCKS5 dialer
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
} else {
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
return nil
}
return transport
}
@@ -0,0 +1,30 @@
package executor
import (
"context"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
client := newProxyAwareHTTPClient(
context.Background(),
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
0,
)
transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
}
if transport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}
+7 -1
View File
@@ -257,7 +257,10 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma
if suffixResult.HasSuffix {
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
} else {
config = extractThinkingConfig(body, toFormat)
config = extractThinkingConfig(body, fromFormat)
if !hasThinkingConfig(config) && fromFormat != toFormat {
config = extractThinkingConfig(body, toFormat)
}
}
if !hasThinkingConfig(config) {
@@ -293,6 +296,9 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri
if config.Mode != ModeLevel {
return config
}
if toFormat == "claude" {
return config
}
if !isBudgetCapableProvider(toFormat) {
return config
}
@@ -0,0 +1,55 @@
package thinking_test
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
"github.com/tidwall/gjson"
)
func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) {
reg := registry.GetGlobalRegistry()
clientID := "test-user-defined-claude-" + t.Name()
modelID := "custom-claude-4-6"
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}})
t.Cleanup(func() {
reg.UnregisterClient(clientID)
})
tests := []struct {
name string
model string
body []byte
}{
{
name: "claude adaptive effort body",
model: modelID,
body: []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`),
},
{
name: "suffix level",
model: modelID + "(high)",
body: []byte(`{}`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
out, err := thinking.ApplyThinking(tt.body, tt.model, "openai", "claude", "claude")
if err != nil {
t.Fatalf("ApplyThinking() error = %v", err)
}
if got := gjson.GetBytes(out, "thinking.type").String(); got != "adaptive" {
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "adaptive", string(out))
}
if got := gjson.GetBytes(out, "output_config.effort").String(); got != "high" {
t.Fatalf("output_config.effort = %q, want %q, body=%s", got, "high", string(out))
}
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
}
})
}
}
@@ -477,9 +477,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
effort = strings.ToLower(strings.TrimSpace(v.String()))
}
if effort != "" {
if effort == "max" {
effort = "high"
}
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
} else {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
@@ -1235,64 +1235,3 @@ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *t
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
}
}
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels(t *testing.T) {
tests := []struct {
name string
effort string
expected string
}{
{"low", "low", "low"},
{"medium", "medium", "medium"},
{"high", "high", "high"},
{"max", "max", "high"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"thinking": {"type": "adaptive"},
"output_config": {"effort": "` + tt.effort + `"}
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
outputStr := string(output)
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
if !thinkingConfig.Exists() {
t.Fatal("thinkingConfig should exist for adaptive thinking")
}
if thinkingConfig.Get("thinkingLevel").String() != tt.expected {
t.Errorf("Expected thinkingLevel %q, got %q", tt.expected, thinkingConfig.Get("thinkingLevel").String())
}
if !thinkingConfig.Get("includeThoughts").Bool() {
t.Error("includeThoughts should be true")
}
})
}
}
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"thinking": {"type": "adaptive"}
}`)
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
outputStr := string(output)
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
if !thinkingConfig.Exists() {
t.Fatal("thinkingConfig should exist for adaptive thinking without effort")
}
if thinkingConfig.Get("thinkingLevel").String() != "high" {
t.Errorf("Expected default thinkingLevel \"high\", got %q", thinkingConfig.Get("thinkingLevel").String())
}
if !thinkingConfig.Get("includeThoughts").Bool() {
t.Error("includeThoughts should be true")
}
}
@@ -15,6 +15,7 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -256,7 +257,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -43,23 +43,32 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
// Process system messages and convert them to input content format.
systemsResult := rootResult.Get("system")
if systemsResult.IsArray() {
systemResults := systemsResult.Array()
if systemsResult.Exists() {
message := `{"type":"message","role":"developer","content":[]}`
contentIndex := 0
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i]
systemTypeResult := systemResult.Get("type")
if systemTypeResult.String() == "text" {
text := systemResult.Get("text").String()
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
continue
appendSystemText := func(text string) {
if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
return
}
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
}
if systemsResult.Type == gjson.String {
appendSystemText(systemsResult.String())
} else if systemsResult.IsArray() {
systemResults := systemsResult.Array()
for i := 0; i < len(systemResults); i++ {
systemResult := systemResults[i]
if systemResult.Get("type").String() == "text" {
appendSystemText(systemResult.Get("text").String())
}
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
contentIndex++
}
}
if contentIndex > 0 {
template, _ = sjson.SetRaw(template, "input.-1", message)
}
@@ -0,0 +1,89 @@
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantHasDeveloper bool
wantTexts []string
}{
{
name: "No system field",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: false,
},
{
name: "Empty string system field",
inputJSON: `{
"model": "claude-3-opus",
"system": "",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: false,
},
{
name: "String system field",
inputJSON: `{
"model": "claude-3-opus",
"system": "Be helpful",
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: true,
wantTexts: []string{"Be helpful"},
},
{
name: "Array system field with filtered billing header",
inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "text", "text": "x-anthropic-billing-header: tenant-123"},
{"type": "text", "text": "Block 1"},
{"type": "text", "text": "Block 2"}
],
"messages": [{"role": "user", "content": "hello"}]
}`,
wantHasDeveloper: true,
wantTexts: []string{"Block 1", "Block 2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
inputs := resultJSON.Get("input").Array()
hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer"
if hasDeveloper != tt.wantHasDeveloper {
t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw)
}
if !tt.wantHasDeveloper {
return
}
content := inputs[0].Get("content").Array()
if len(content) != len(tt.wantTexts) {
t.Fatalf("got %d system content items, want %d. Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw)
}
for i, wantText := range tt.wantTexts {
if gotType := content[i].Get("type").String(); gotType != "input_text" {
t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text")
}
if gotText := content[i].Get("text").String(); gotText != wantText {
t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText)
}
}
})
}
}
@@ -12,6 +12,7 @@ import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -141,7 +142,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{
// Restore original tool name if shortened
name := itemResult.Get("name").String()
@@ -310,7 +311,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
}
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
inputRaw := "{}"
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
@@ -14,6 +14,7 @@ import (
"sync/atomic"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -209,7 +210,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -224,7 +224,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1)))
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", clientToolName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -343,7 +343,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
toolIDCounter++
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter))
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)))
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
inputRaw := "{}"
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
@@ -147,21 +147,21 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
content := m.Get("content")
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> system_instruction as a user message style
// system -> systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.String())
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String())
systemPartIndex++
} else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.Get("text").String())
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
systemPartIndex++
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
systemPartIndex++
}
}
@@ -26,7 +26,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
if instructions := root.Get("instructions"); instructions.Exists() {
systemInstr := `{"parts":[{"text":""}]}`
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
}
// Convert input messages to Gemini contents format
@@ -119,7 +119,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
if strings.EqualFold(itemRole, "system") {
if contentArray := item.Get("content"); contentArray.Exists() {
systemInstr := ""
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() {
systemInstr = systemInstructionResult.Raw
} else {
systemInstr = `{"parts":[]}`
@@ -140,7 +140,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
}
if systemInstr != `{"parts":[]}` {
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
}
}
continue
@@ -243,7 +243,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send content_block_start for tool_use
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID)
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
}
@@ -414,7 +414,7 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
@@ -612,7 +612,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
toolCalls.ForEach(func(_, tc gjson.Result) bool {
hasToolCall = true
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
toolUse, _ = sjson.Set(toolUse, "id", util.SanitizeClaudeToolID(tc.Get("id").String()))
toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
argsStr := util.FixJSON(tc.Get("function.arguments").String())
@@ -669,7 +669,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
hasToolCall = true
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String()))
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
+24
View File
@@ -0,0 +1,24 @@
package util
import (
"fmt"
"regexp"
"sync/atomic"
"time"
)
var (
claudeToolUseIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
claudeToolUseIDCounter uint64
)
// SanitizeClaudeToolID ensures the given id conforms to Claude's
// tool_use.id regex ^[a-zA-Z0-9_-]+$. Non-conforming characters are
// replaced with '_'; an empty result gets a generated fallback.
func SanitizeClaudeToolID(id string) string {
s := claudeToolUseIDSanitizer.ReplaceAllString(id, "_")
if s == "" {
s = fmt.Sprintf("toolu_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&claudeToolUseIDCounter, 1))
}
return s
}
+8 -33
View File
@@ -4,50 +4,25 @@
package util
import (
"context"
"net"
"net/http"
"net/url"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
)
// SetProxy configures the provided HTTP client with proxy settings from the configuration.
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
// to route requests through the configured proxy server.
func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client {
var transport *http.Transport
// Attempt to parse the proxy URL from the configuration.
proxyURL, errParse := url.Parse(cfg.ProxyURL)
if errParse == nil {
// Handle different proxy schemes.
if proxyURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication.
var proxyAuth *proxy.Auth
if proxyURL.User != nil {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return httpClient
}
// Set up a custom transport using the SOCKS5 dialer.
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
if cfg == nil || httpClient == nil {
return httpClient
}
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
if errBuild != nil {
log.Errorf("%v", errBuild)
}
// If a new transport was created, apply it to the HTTP client.
if transport != nil {
httpClient.Transport = transport
}
+11
View File
@@ -10,6 +10,7 @@ import (
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
@@ -149,6 +150,16 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []
}
}
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
// For codex auth files, extract plan_type from the JWT id_token.
if provider == "codex" {
if idTokenRaw, ok := metadata["id_token"].(string); ok && strings.TrimSpace(idTokenRaw) != "" {
if claims, errParse := codex.ParseJWTToken(idTokenRaw); errParse == nil && claims != nil {
if pt := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); pt != "" {
a.Attributes["plan_type"] = pt
}
}
}
}
if provider == "gemini-cli" {
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
for _, v := range virtuals {
@@ -34,6 +34,8 @@ const (
wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
wsPayloadLogMaxSize = 2048
wsBodyLogMaxSize = 64 * 1024
wsBodyLogTruncated = "\n[websocket log truncated]\n"
)
var responsesWebsocketUpgrader = websocket.Upgrader{
@@ -825,18 +827,71 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
if builder == nil {
return
}
if builder.Len() >= wsBodyLogMaxSize {
return
}
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return
}
if builder.Len() > 0 {
builder.WriteString("\n")
if !appendWebsocketLogString(builder, "\n") {
return
}
}
builder.WriteString("websocket.")
builder.WriteString(eventType)
builder.WriteString("\n")
builder.Write(trimmedPayload)
builder.WriteString("\n")
if !appendWebsocketLogString(builder, "websocket.") {
return
}
if !appendWebsocketLogString(builder, eventType) {
return
}
if !appendWebsocketLogString(builder, "\n") {
return
}
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
appendWebsocketLogString(builder, wsBodyLogTruncated)
return
}
appendWebsocketLogString(builder, "\n")
}
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.WriteString(value)
return true
}
builder.WriteString(value[:remaining])
return false
}
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
if builder == nil {
return false
}
remaining := wsBodyLogMaxSize - builder.Len()
if remaining <= 0 {
return false
}
if len(value) <= remaining {
builder.Write(value)
return true
}
limit := remaining - reserveForSuffix
if limit < 0 {
limit = 0
}
if limit > len(value) {
limit = len(value)
}
builder.Write(value[:limit])
return false
}
func websocketPayloadEventType(payload []byte) string {
@@ -266,6 +266,33 @@ func TestAppendWebsocketEvent(t *testing.T) {
}
}
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
var builder strings.Builder
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
appendWebsocketEvent(&builder, "request", payload)
got := builder.String()
if len(got) > wsBodyLogMaxSize {
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
}
if !strings.Contains(got, wsBodyLogTruncated) {
t.Fatalf("expected truncation marker in body log")
}
}
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
var builder strings.Builder
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
initial := builder.String()
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
if builder.String() != initial {
t.Fatalf("builder grew after reaching limit")
}
}
func TestSetWebsocketRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
+3
View File
@@ -287,5 +287,8 @@ func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundl
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
Attributes: map[string]string{
"plan_type": planType,
},
}, nil
}
+571 -108
View File
@@ -134,6 +134,7 @@ type Manager struct {
hook Hook
mu sync.RWMutex
auths map[string]*Auth
scheduler *authScheduler
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
providerOffsets map[string]int
@@ -149,6 +150,9 @@ type Manager struct {
// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
apiKeyModelAlias atomic.Value
// modelPoolOffsets tracks per-auth alias pool rotation state.
modelPoolOffsets map[string]int
// runtimeConfig stores the latest application config for request-time decisions.
// It is initialized in NewManager; never Load() before first Store().
runtimeConfig atomic.Value
@@ -176,14 +180,59 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
hook: hook,
auths: make(map[string]*Auth),
providerOffsets: make(map[string]int),
modelPoolOffsets: make(map[string]int),
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
}
// atomic.Value requires non-nil initial value.
manager.runtimeConfig.Store(&internalconfig.Config{})
manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
manager.scheduler = newAuthScheduler(selector)
return manager
}
func isBuiltInSelector(selector Selector) bool {
switch selector.(type) {
case *RoundRobinSelector, *FillFirstSelector:
return true
default:
return false
}
}
func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) {
if m == nil || m.scheduler == nil {
return
}
m.scheduler.rebuild(auths)
}
func (m *Manager) syncScheduler() {
if m == nil || m.scheduler == nil {
return
}
m.syncSchedulerFromSnapshot(m.snapshotAuths())
}
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
// supportedModelSet is rebuilt from the current global model registry state.
// This must be called after models have been registered for a newly added auth,
// because the initial scheduler.upsertAuth during Register/Update runs before
// registerModelsForAuth and therefore snapshots an empty model set.
func (m *Manager) RefreshSchedulerEntry(authID string) {
if m == nil || m.scheduler == nil || authID == "" {
return
}
m.mu.RLock()
auth, ok := m.auths[authID]
if !ok || auth == nil {
m.mu.RUnlock()
return
}
snapshot := auth.Clone()
m.mu.RUnlock()
m.scheduler.upsertAuth(snapshot)
}
func (m *Manager) SetSelector(selector Selector) {
if m == nil {
return
@@ -194,6 +243,10 @@ func (m *Manager) SetSelector(selector Selector) {
m.mu.Lock()
m.selector = selector
m.mu.Unlock()
if m.scheduler != nil {
m.scheduler.setSelector(selector)
m.syncScheduler()
}
}
// SetStore swaps the underlying persistence store.
@@ -251,16 +304,323 @@ func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) strin
if resolved == "" {
return ""
}
// Preserve thinking suffix from the client's requested model unless config already has one.
requestResult := thinking.ParseSuffix(requestedModel)
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
return resolved
return preserveRequestedModelSuffix(requestedModel, resolved)
}
func isAPIKeyAuth(auth *Auth) bool {
if auth == nil {
return false
}
kind, _ := auth.AccountInfo()
return strings.EqualFold(strings.TrimSpace(kind), "api_key")
}
func isOpenAICompatAPIKeyAuth(auth *Auth) bool {
if !isAPIKeyAuth(auth) {
return false
}
if strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
return true
}
if auth.Attributes == nil {
return false
}
return strings.TrimSpace(auth.Attributes["compat_name"]) != ""
}
func openAICompatProviderKey(auth *Auth) string {
if auth == nil {
return ""
}
if auth.Attributes != nil {
if providerKey := strings.TrimSpace(auth.Attributes["provider_key"]); providerKey != "" {
return strings.ToLower(providerKey)
}
if compatName := strings.TrimSpace(auth.Attributes["compat_name"]); compatName != "" {
return strings.ToLower(compatName)
}
}
return strings.ToLower(strings.TrimSpace(auth.Provider))
}
func openAICompatModelPoolKey(auth *Auth, requestedModel string) string {
base := strings.TrimSpace(thinking.ParseSuffix(requestedModel).ModelName)
if base == "" {
base = strings.TrimSpace(requestedModel)
}
return strings.ToLower(strings.TrimSpace(auth.ID)) + "|" + openAICompatProviderKey(auth) + "|" + strings.ToLower(base)
}
func (m *Manager) nextModelPoolOffset(key string, size int) int {
if m == nil || size <= 1 {
return 0
}
key = strings.TrimSpace(key)
if key == "" {
return 0
}
m.mu.Lock()
defer m.mu.Unlock()
if m.modelPoolOffsets == nil {
m.modelPoolOffsets = make(map[string]int)
}
offset := m.modelPoolOffsets[key]
if offset >= 2_147_483_640 {
offset = 0
}
m.modelPoolOffsets[key] = offset + 1
if size <= 0 {
return 0
}
return offset % size
}
func rotateStrings(values []string, offset int) []string {
if len(values) <= 1 {
return values
}
if offset <= 0 {
out := make([]string, len(values))
copy(out, values)
return out
}
offset = offset % len(values)
out := make([]string, 0, len(values))
out = append(out, values[offset:]...)
out = append(out, values[:offset]...)
return out
}
func (m *Manager) resolveOpenAICompatUpstreamModelPool(auth *Auth, requestedModel string) []string {
if m == nil || !isOpenAICompatAPIKeyAuth(auth) {
return nil
}
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return nil
}
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
if cfg == nil {
cfg = &internalconfig.Config{}
}
providerKey := ""
compatName := ""
if auth.Attributes != nil {
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
}
entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
if entry == nil {
return nil
}
return resolveModelAliasPoolFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
}
func preserveRequestedModelSuffix(requestedModel, resolved string) string {
return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel))
}
func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string {
return m.prepareExecutionModels(auth, routeModel)
}
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
requestedModel := rewriteModelForAuth(routeModel, auth)
requestedModel = m.applyOAuthModelAlias(auth, requestedModel)
if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 {
if len(pool) == 1 {
return pool
}
offset := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, requestedModel), len(pool))
return rotateStrings(pool, offset)
}
resolved := m.applyAPIKeyModelAlias(auth, requestedModel)
if strings.TrimSpace(resolved) == "" {
resolved = requestedModel
}
return []string{resolved}
}
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
if ch == nil {
return
}
go func() {
for range ch {
}
}()
}
func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) {
if ch == nil {
return nil, true, nil
}
buffered := make([]cliproxyexecutor.StreamChunk, 0, 1)
for {
var (
chunk cliproxyexecutor.StreamChunk
ok bool
)
if ctx != nil {
select {
case <-ctx.Done():
return nil, false, ctx.Err()
case chunk, ok = <-ch:
}
} else {
chunk, ok = <-ch
}
if !ok {
return buffered, true, nil
}
if chunk.Err != nil {
return nil, false, chunk.Err
}
buffered = append(buffered, chunk)
if len(chunk.Payload) > 0 {
return buffered, false, nil
}
}
}
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
var failed bool
forward := true
emit := func(chunk cliproxyexecutor.StreamChunk) bool {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
}
if !forward {
return false
}
if ctx == nil {
out <- chunk
return true
}
select {
case <-ctx.Done():
forward = false
return false
case out <- chunk:
return true
}
}
for _, chunk := range buffered {
if ok := emit(chunk); !ok {
discardStreamChunks(remaining)
return
}
}
for chunk := range remaining {
if ok := emit(chunk); !ok {
discardStreamChunks(remaining)
return
}
}
if !failed {
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true})
}
}()
return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
}
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
if executor == nil {
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
execModels := m.prepareExecutionModels(auth, routeModel)
var lastErr error
for idx, execModel := range execModels {
execReq := req
execReq.Model = execModel
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
if errStream != nil {
if errCtx := ctx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(ctx, result)
if isRequestInvalidError(errStream) {
return nil, errStream
}
lastErr = errStream
continue
}
buffered, closed, bootstrapErr := readStreamBootstrap(ctx, streamResult.Chunks)
if bootstrapErr != nil {
if errCtx := ctx.Err(); errCtx != nil {
discardStreamChunks(streamResult.Chunks)
return nil, errCtx
}
if isRequestInvalidError(bootstrapErr) {
rerr := &Error{Message: bootstrapErr.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(bootstrapErr)
m.MarkResult(ctx, result)
discardStreamChunks(streamResult.Chunks)
return nil, bootstrapErr
}
if idx < len(execModels)-1 {
rerr := &Error{Message: bootstrapErr.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(bootstrapErr)
m.MarkResult(ctx, result)
discardStreamChunks(streamResult.Chunks)
lastErr = bootstrapErr
continue
}
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
}
if closed && len(buffered) == 0 {
emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr}
m.MarkResult(ctx, result)
if idx < len(execModels)-1 {
lastErr = emptyErr
continue
}
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
close(errCh)
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
}
remaining := streamResult.Chunks
if closed {
closedCh := make(chan cliproxyexecutor.StreamChunk)
close(closedCh)
remaining = closedCh
}
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
}
if lastErr == nil {
lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
}
return nil, lastErr
}
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
@@ -448,10 +808,14 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
auth.ID = uuid.NewString()
}
auth.EnsureIndex()
authClone := auth.Clone()
m.mu.Lock()
m.auths[auth.ID] = auth.Clone()
m.auths[auth.ID] = authClone
m.mu.Unlock()
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
_ = m.persist(ctx, auth)
m.hook.OnAuthRegistered(ctx, auth.Clone())
return auth.Clone(), nil
@@ -473,9 +837,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
}
}
auth.EnsureIndex()
m.auths[auth.ID] = auth.Clone()
authClone := auth.Clone()
m.auths[auth.ID] = authClone
m.mu.Unlock()
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
if m.scheduler != nil {
m.scheduler.upsertAuth(authClone)
}
_ = m.persist(ctx, auth)
m.hook.OnAuthUpdated(ctx, auth.Clone())
return auth.Clone(), nil
@@ -484,12 +852,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
// Load resets manager state from the backing store.
func (m *Manager) Load(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.store == nil {
m.mu.Unlock()
return nil
}
items, err := m.store.List(ctx)
if err != nil {
m.mu.Unlock()
return err
}
m.auths = make(map[string]*Auth, len(items))
@@ -505,6 +874,8 @@ func (m *Manager) Load(ctx context.Context) error {
cfg = &internalconfig.Config{}
}
m.rebuildAPIKeyModelAliasLocked(cfg)
m.mu.Unlock()
m.syncScheduler()
return nil
}
@@ -634,32 +1005,42 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
models := m.prepareExecutionModels(auth, routeModel)
var authErr error
for _, upstreamModel := range models {
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
}
authErr = errExec
continue
}
m.MarkResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
return resp, nil
}
if authErr != nil {
if isRequestInvalidError(authErr) {
return cliproxyexecutor.Response{}, authErr
}
lastErr = errExec
lastErr = authErr
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
}
@@ -696,32 +1077,42 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
models := m.prepareExecutionModels(auth, routeModel)
var authErr error
for _, upstreamModel := range models {
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.hook.OnResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
}
authErr = errExec
continue
}
m.hook.OnResult(execCtx, result)
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
return resp, nil
}
if authErr != nil {
if isRequestInvalidError(authErr) {
return cliproxyexecutor.Response{}, authErr
}
lastErr = errExec
lastErr = authErr
continue
}
m.hook.OnResult(execCtx, result)
return resp, nil
}
}
@@ -758,63 +1149,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
if isRequestInvalidError(errStream) {
return nil, errStream
}
lastErr = errStream
continue
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
forward := true
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
if !forward {
continue
}
if streamCtx == nil {
out <- chunk
continue
}
select {
case <-streamCtx.Done():
forward = false
case out <- chunk:
}
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
return &cliproxyexecutor.StreamResult{
Headers: streamResult.Headers,
Chunks: out,
}, nil
return streamResult, nil
}
}
@@ -1245,6 +1591,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
suspendReason := ""
clearModelQuota := false
setModelQuota := false
var authSnapshot *Auth
m.mu.Lock()
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
@@ -1338,8 +1685,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
}
_ = m.persist(ctx, auth)
authSnapshot = auth.Clone()
}
m.mu.Unlock()
if m.scheduler != nil && authSnapshot != nil {
m.scheduler.upsertAuth(authSnapshot)
}
if clearModelQuota && result.Model != "" {
registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
@@ -1533,18 +1884,22 @@ func statusCodeFromResult(err *Error) int {
}
// isRequestInvalidError returns true if the error represents a client request
// error that should not be retried. Specifically, it checks for 400 Bad Request
// with "invalid_request_error" in the message, indicating the request itself is
// malformed and switching to a different auth will not help.
// error that should not be retried. Specifically, it treats 400 responses with
// "invalid_request_error" and all 422 responses as request-shape failures,
// where switching auths or pooled upstream models will not help.
func isRequestInvalidError(err error) bool {
if err == nil {
return false
}
status := statusCodeFromError(err)
if status != http.StatusBadRequest {
switch status {
case http.StatusBadRequest:
return strings.Contains(err.Error(), "invalid_request_error")
case http.StatusUnprocessableEntity:
return true
default:
return false
}
return strings.Contains(err.Error(), "invalid_request_error")
}
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
@@ -1692,7 +2047,29 @@ func (m *Manager) CloseExecutionSession(sessionID string) {
}
}
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
func (m *Manager) useSchedulerFastPath() bool {
if m == nil || m.scheduler == nil {
return false
}
return isBuiltInSelector(m.selector)
}
func shouldRetrySchedulerPick(err error) bool {
if err == nil {
return false
}
var cooldownErr *modelCooldownError
if errors.As(err, &cooldownErr) {
return true
}
var authErr *Error
if !errors.As(err, &authErr) || authErr == nil {
return false
}
return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable"
}
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
m.mu.RLock()
@@ -1752,7 +2129,38 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
return authCopy, executor, nil
}
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
if !m.useSchedulerFastPath() {
return m.pickNextLegacy(ctx, provider, model, opts, tried)
}
executor, okExecutor := m.Executor(provider)
if !okExecutor {
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
selected, errPick := m.scheduler.pickSingle(ctx, provider, model, opts, tried)
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
m.syncScheduler()
selected, errPick = m.scheduler.pickSingle(ctx, provider, model, opts, tried)
}
if errPick != nil {
return nil, nil, errPick
}
if selected == nil {
return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
}
authCopy := selected.Clone()
if !selected.indexAssigned {
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, nil
}
func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
providerSet := make(map[string]struct{}, len(providers))
@@ -1835,6 +2243,58 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
return authCopy, executor, providerKey, nil
}
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
if !m.useSchedulerFastPath() {
return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
}
eligibleProviders := make([]string, 0, len(providers))
seenProviders := make(map[string]struct{}, len(providers))
for _, provider := range providers {
providerKey := strings.TrimSpace(strings.ToLower(provider))
if providerKey == "" {
continue
}
if _, seen := seenProviders[providerKey]; seen {
continue
}
if _, okExecutor := m.Executor(providerKey); !okExecutor {
continue
}
seenProviders[providerKey] = struct{}{}
eligibleProviders = append(eligibleProviders, providerKey)
}
if len(eligibleProviders) == 0 {
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
selected, providerKey, errPick := m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
m.syncScheduler()
selected, providerKey, errPick = m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
}
if errPick != nil {
return nil, nil, "", errPick
}
if selected == nil {
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
}
executor, okExecutor := m.Executor(providerKey)
if !okExecutor {
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
}
authCopy := selected.Clone()
if !selected.indexAssigned {
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, providerKey, nil
}
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
if m.store == nil || auth == nil {
return nil
@@ -2186,6 +2646,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
current.LastError = &Error{Message: err.Error()}
m.auths[id] = current
if m.scheduler != nil {
m.scheduler.upsertAuth(current.Clone())
}
}
m.mu.Unlock()
return
@@ -0,0 +1,163 @@
package auth
import (
"context"
"errors"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type schedulerProviderTestExecutor struct {
provider string
}
func (e schedulerProviderTestExecutor) Identifier() string { return e.provider }
func (e schedulerProviderTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerProviderTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, nil
}
func (e schedulerProviderTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e schedulerProviderTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerProviderTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
return nil, nil
}
func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) {
ctx := context.Background()
testCases := []struct {
name string
prime func(*Manager, *Auth) error
}{
{
name: "register",
prime: func(manager *Manager, auth *Auth) error {
_, errRegister := manager.Register(ctx, auth)
return errRegister
},
},
{
name: "update",
prime: func(manager *Manager, auth *Auth) error {
_, errRegister := manager.Register(ctx, auth)
if errRegister != nil {
return errRegister
}
updated := auth.Clone()
updated.Metadata = map[string]any{"updated": true}
_, errUpdate := manager.Update(ctx, updated)
return errUpdate
},
},
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
manager := NewManager(nil, &RoundRobinSelector{}, nil)
auth := &Auth{
ID: "refresh-entry-" + testCase.name,
Provider: "gemini",
}
if errPrime := testCase.prime(manager, auth); errPrime != nil {
t.Fatalf("prime auth %s: %v", testCase.name, errPrime)
}
registerSchedulerModels(t, "gemini", "scheduler-refresh-model", auth.ID)
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
var authErr *Error
if !errors.As(errPick, &authErr) || authErr == nil {
t.Fatalf("pickSingle() before refresh error = %v, want auth_not_found", errPick)
}
if authErr.Code != "auth_not_found" {
t.Fatalf("pickSingle() before refresh code = %q, want %q", authErr.Code, "auth_not_found")
}
if got != nil {
t.Fatalf("pickSingle() before refresh auth = %v, want nil", got)
}
manager.RefreshSchedulerEntry(auth.ID)
got, errPick = manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() after refresh error = %v", errPick)
}
if got == nil || got.ID != auth.ID {
t.Fatalf("pickSingle() after refresh auth = %v, want %q", got, auth.ID)
}
})
}
}
func TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError(t *testing.T) {
ctx := context.Background()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.RegisterExecutor(schedulerProviderTestExecutor{provider: "gemini"})
registerSchedulerModels(t, "gemini", "scheduler-cooldown-rebuild-model", "cooldown-stale-old")
oldAuth := &Auth{
ID: "cooldown-stale-old",
Provider: "gemini",
}
if _, errRegister := manager.Register(ctx, oldAuth); errRegister != nil {
t.Fatalf("register old auth: %v", errRegister)
}
manager.MarkResult(ctx, Result{
AuthID: oldAuth.ID,
Provider: "gemini",
Model: "scheduler-cooldown-rebuild-model",
Success: false,
Error: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"},
})
newAuth := &Auth{
ID: "cooldown-stale-new",
Provider: "gemini",
}
if _, errRegister := manager.Register(ctx, newAuth); errRegister != nil {
t.Fatalf("register new auth: %v", errRegister)
}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(newAuth.ID, "gemini", []*registry.ModelInfo{{ID: "scheduler-cooldown-rebuild-model"}})
t.Cleanup(func() {
reg.UnregisterClient(newAuth.ID)
})
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
var cooldownErr *modelCooldownError
if !errors.As(errPick, &cooldownErr) {
t.Fatalf("pickSingle() before sync error = %v, want modelCooldownError", errPick)
}
if got != nil {
t.Fatalf("pickSingle() before sync auth = %v, want nil", got)
}
got, executor, errPick := manager.pickNext(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickNext() error = %v", errPick)
}
if executor == nil {
t.Fatal("pickNext() executor = nil")
}
if got == nil || got.ID != newAuth.ID {
t.Fatalf("pickNext() auth = %v, want %q", got, newAuth.ID)
}
}
+69 -25
View File
@@ -80,54 +80,98 @@ func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string
return upstreamModel
}
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
func modelAliasLookupCandidates(requestedModel string) (thinking.SuffixResult, []string) {
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return ""
return thinking.SuffixResult{}, nil
}
if len(models) == 0 {
return ""
}
requestResult := thinking.ParseSuffix(requestedModel)
base := requestResult.ModelName
if base == "" {
base = requestedModel
}
candidates := []string{base}
if base != requestedModel {
candidates = append(candidates, requestedModel)
}
return requestResult, candidates
}
preserveSuffix := func(resolved string) string {
resolved = strings.TrimSpace(resolved)
if resolved == "" {
return ""
}
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
func preserveResolvedModelSuffix(resolved string, requestResult thinking.SuffixResult) string {
resolved = strings.TrimSpace(resolved)
if resolved == "" {
return ""
}
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
return resolved
}
func resolveModelAliasPoolFromConfigModels(requestedModel string, models []modelAliasEntry) []string {
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return nil
}
if len(models) == 0 {
return nil
}
requestResult, candidates := modelAliasLookupCandidates(requestedModel)
if len(candidates) == 0 {
return nil
}
out := make([]string, 0)
seen := make(map[string]struct{})
for i := range models {
name := strings.TrimSpace(models[i].GetName())
alias := strings.TrimSpace(models[i].GetAlias())
for _, candidate := range candidates {
if candidate == "" {
if candidate == "" || alias == "" || !strings.EqualFold(alias, candidate) {
continue
}
if alias != "" && strings.EqualFold(alias, candidate) {
if name != "" {
return preserveSuffix(name)
}
return preserveSuffix(candidate)
resolved := candidate
if name != "" {
resolved = name
}
if name != "" && strings.EqualFold(name, candidate) {
return preserveSuffix(name)
resolved = preserveResolvedModelSuffix(resolved, requestResult)
key := strings.ToLower(strings.TrimSpace(resolved))
if key == "" {
break
}
if _, exists := seen[key]; exists {
break
}
seen[key] = struct{}{}
out = append(out, resolved)
break
}
}
if len(out) > 0 {
return out
}
for i := range models {
name := strings.TrimSpace(models[i].GetName())
for _, candidate := range candidates {
if candidate == "" || name == "" || !strings.EqualFold(name, candidate) {
continue
}
return []string{preserveResolvedModelSuffix(name, requestResult)}
}
}
return nil
}
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
resolved := resolveModelAliasPoolFromConfigModels(requestedModel, models)
if len(resolved) > 0 {
return resolved[0]
}
return ""
}
@@ -0,0 +1,419 @@
package auth
import (
"context"
"net/http"
"sync"
"testing"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type openAICompatPoolExecutor struct {
id string
mu sync.Mutex
executeModels []string
countModels []string
streamModels []string
executeErrors map[string]error
countErrors map[string]error
streamFirstErrors map[string]error
streamPayloads map[string][]cliproxyexecutor.StreamChunk
}
func (e *openAICompatPoolExecutor) Identifier() string { return e.id }
func (e *openAICompatPoolExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_ = ctx
_ = auth
_ = opts
e.mu.Lock()
e.executeModels = append(e.executeModels, req.Model)
err := e.executeErrors[req.Model]
e.mu.Unlock()
if err != nil {
return cliproxyexecutor.Response{}, err
}
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
}
func (e *openAICompatPoolExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
_ = ctx
_ = auth
_ = opts
e.mu.Lock()
e.streamModels = append(e.streamModels, req.Model)
err := e.streamFirstErrors[req.Model]
payloadChunks, hasCustomChunks := e.streamPayloads[req.Model]
chunks := append([]cliproxyexecutor.StreamChunk(nil), payloadChunks...)
e.mu.Unlock()
ch := make(chan cliproxyexecutor.StreamChunk, max(1, len(chunks)))
if err != nil {
ch <- cliproxyexecutor.StreamChunk{Err: err}
close(ch)
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil
}
if !hasCustomChunks {
ch <- cliproxyexecutor.StreamChunk{Payload: []byte(req.Model)}
} else {
for _, chunk := range chunks {
ch <- chunk
}
}
close(ch)
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil
}
func (e *openAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e *openAICompatPoolExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_ = ctx
_ = auth
_ = opts
e.mu.Lock()
e.countModels = append(e.countModels, req.Model)
err := e.countErrors[req.Model]
e.mu.Unlock()
if err != nil {
return cliproxyexecutor.Response{}, err
}
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
}
func (e *openAICompatPoolExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
_ = ctx
_ = auth
_ = req
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
}
func (e *openAICompatPoolExecutor) ExecuteModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.executeModels))
copy(out, e.executeModels)
return out
}
func (e *openAICompatPoolExecutor) CountModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.countModels))
copy(out, e.countModels)
return out
}
func (e *openAICompatPoolExecutor) StreamModels() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.streamModels))
copy(out, e.streamModels)
return out
}
func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager {
t.Helper()
cfg := &internalconfig.Config{
OpenAICompatibility: []internalconfig.OpenAICompatibility{{
Name: "pool",
Models: models,
}},
}
m := NewManager(nil, nil, nil)
m.SetConfig(cfg)
if executor == nil {
executor = &openAICompatPoolExecutor{id: "pool"}
}
m.RegisterExecutor(executor)
auth := &Auth{
ID: "pool-auth-" + t.Name(),
Provider: "pool",
Status: StatusActive,
Attributes: map[string]string{
"api_key": "test-key",
"compat_name": "pool",
"provider_key": "pool",
},
}
if _, err := m.Register(context.Background(), auth); err != nil {
t.Fatalf("register auth: %v", err)
}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(auth.ID, "pool", []*registry.ModelInfo{{ID: alias}})
t.Cleanup(func() {
reg.UnregisterClient(auth.ID)
})
return m
}
func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
executor := &openAICompatPoolExecutor{
id: "pool",
countErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
_, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil || err.Error() != invalidErr.Error() {
t.Fatalf("execute count error = %v, want %v", err, invalidErr)
}
got := executor.CountModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("count calls = %v, want only first invalid model", got)
}
}
func TestResolveModelAliasPoolFromConfigModels(t *testing.T) {
models := []modelAliasEntry{
internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: "claude-opus-4.66"},
internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: "claude-opus-4.66"},
internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: "claude-opus-4.66"},
}
got := resolveModelAliasPoolFromConfigModels("claude-opus-4.66(8192)", models)
want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"}
if len(got) != len(want) {
t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("pool[%d] = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{id: "pool"}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
for i := 0; i < 3; i++ {
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute %d: %v", i, err)
}
if len(resp.Payload) == 0 {
t.Fatalf("execute %d returned empty payload", i)
}
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"}
if len(got) != len(want) {
t.Fatalf("execute calls = %v, want %v", got, want)
}
for i := range want {
if got[i] != want[i] {
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
_, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil || err.Error() != invalidErr.Error() {
t.Fatalf("execute error = %v, want %v", err, invalidErr)
}
got := executor.ExecuteModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("execute calls = %v, want only first invalid model", got)
}
}
func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute: %v", err)
}
if string(resp.Payload) != "glm-5" {
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
}
got := executor.ExecuteModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
streamPayloads: map[string][]cliproxyexecutor.StreamChunk{
"qwen3.5-plus": {},
},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute stream: %v", err)
}
var payload []byte
for chunk := range streamResult.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected stream error: %v", chunk.Err)
}
payload = append(payload, chunk.Payload...)
}
if string(payload) != "glm-5" {
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
}
got := executor.StreamModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute stream: %v", err)
}
var payload []byte
for chunk := range streamResult.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected stream error: %v", chunk.Err)
}
payload = append(payload, chunk.Payload...)
}
if string(payload) != "glm-5" {
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
}
got := executor.StreamModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
}
}
if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" {
t.Fatalf("header X-Model = %q, want %q", gotHeader, "glm-5")
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
_, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil || err.Error() != invalidErr.Error() {
t.Fatalf("execute stream error = %v, want %v", err, invalidErr)
}
got := executor.StreamModels()
if len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("stream calls = %v, want only first invalid model", got)
}
}
func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
alias := "claude-opus-4.66"
executor := &openAICompatPoolExecutor{id: "pool"}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
for i := 0; i < 2; i++ {
resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err != nil {
t.Fatalf("execute count %d: %v", i, err)
}
if len(resp.Payload) == 0 {
t.Fatalf("execute count %d returned empty payload", i)
}
}
got := executor.CountModels()
want := []string{"qwen3.5-plus", "glm-5"}
for i := range want {
if got[i] != want[i] {
t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
}
}
}
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) {
alias := "claude-opus-4.66"
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
executor := &openAICompatPoolExecutor{
id: "pool",
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
}
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
{Name: "qwen3.5-plus", Alias: alias},
{Name: "glm-5", Alias: alias},
}, executor)
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
if err == nil {
t.Fatal("expected invalid request error")
}
if err != invalidErr {
t.Fatalf("error = %v, want %v", err, invalidErr)
}
if streamResult != nil {
t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult)
}
if got := executor.StreamModels(); len(got) != 1 || got[0] != "qwen3.5-plus" {
t.Fatalf("stream calls = %v, want only first upstream model", got)
}
}
+904
View File
@@ -0,0 +1,904 @@
package auth
import (
"context"
"sort"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerStrategy identifies which built-in routing semantics the scheduler should apply.
type schedulerStrategy int
const (
schedulerStrategyCustom schedulerStrategy = iota
schedulerStrategyRoundRobin
schedulerStrategyFillFirst
)
// scheduledState describes how an auth currently participates in a model shard.
type scheduledState int
const (
scheduledStateReady scheduledState = iota
scheduledStateCooldown
scheduledStateBlocked
scheduledStateDisabled
)
// authScheduler keeps the incremental provider/model scheduling state used by Manager.
type authScheduler struct {
mu sync.Mutex
strategy schedulerStrategy
providers map[string]*providerScheduler
authProviders map[string]string
mixedCursors map[string]int
}
// providerScheduler stores auth metadata and model shards for a single provider.
type providerScheduler struct {
providerKey string
auths map[string]*scheduledAuthMeta
modelShards map[string]*modelScheduler
}
// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot.
type scheduledAuthMeta struct {
auth *Auth
providerKey string
priority int
virtualParent string
websocketEnabled bool
supportedModelSet map[string]struct{}
}
// modelScheduler tracks ready and blocked auths for one provider/model combination.
type modelScheduler struct {
modelKey string
entries map[string]*scheduledAuth
priorityOrder []int
readyByPriority map[int]*readyBucket
blocked cooldownQueue
}
// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard.
type scheduledAuth struct {
meta *scheduledAuthMeta
auth *Auth
state scheduledState
nextRetryAt time.Time
}
// readyBucket keeps the ready views for one priority level.
type readyBucket struct {
all readyView
ws readyView
}
// readyView holds the selection order for flat or grouped round-robin traversal.
type readyView struct {
flat []*scheduledAuth
cursor int
parentOrder []string
parentCursor int
children map[string]*childBucket
}
// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths.
type childBucket struct {
items []*scheduledAuth
cursor int
}
// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds.
type cooldownQueue []*scheduledAuth
// newAuthScheduler constructs an empty scheduler configured for the supplied selector strategy.
func newAuthScheduler(selector Selector) *authScheduler {
return &authScheduler{
strategy: selectorStrategy(selector),
providers: make(map[string]*providerScheduler),
authProviders: make(map[string]string),
mixedCursors: make(map[string]int),
}
}
// selectorStrategy maps a selector implementation to the scheduler semantics it should emulate.
func selectorStrategy(selector Selector) schedulerStrategy {
switch selector.(type) {
case *FillFirstSelector:
return schedulerStrategyFillFirst
case nil, *RoundRobinSelector:
return schedulerStrategyRoundRobin
default:
return schedulerStrategyCustom
}
}
// setSelector updates the active built-in strategy and resets mixed-provider cursors.
func (s *authScheduler) setSelector(selector Selector) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.strategy = selectorStrategy(selector)
clear(s.mixedCursors)
}
// rebuild recreates the complete scheduler state from an auth snapshot.
func (s *authScheduler) rebuild(auths []*Auth) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.providers = make(map[string]*providerScheduler)
s.authProviders = make(map[string]string)
s.mixedCursors = make(map[string]int)
now := time.Now()
for _, auth := range auths {
s.upsertAuthLocked(auth, now)
}
}
// upsertAuth incrementally synchronizes one auth into the scheduler.
func (s *authScheduler) upsertAuth(auth *Auth) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.upsertAuthLocked(auth, time.Now())
}
// removeAuth deletes one auth from every scheduler shard that references it.
func (s *authScheduler) removeAuth(authID string) {
if s == nil {
return
}
authID = strings.TrimSpace(authID)
if authID == "" {
return
}
s.mu.Lock()
defer s.mu.Unlock()
s.removeAuthLocked(authID)
}
// pickSingle returns the next auth for a single provider/model request using scheduler state.
func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) {
if s == nil {
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
providerKey := strings.ToLower(strings.TrimSpace(provider))
modelKey := canonicalModelKey(model)
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == ""
s.mu.Lock()
defer s.mu.Unlock()
providerState := s.providers[providerKey]
if providerState == nil {
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
shard := providerState.ensureModelLocked(modelKey, time.Now())
if shard == nil {
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
predicate := func(entry *scheduledAuth) bool {
if entry == nil || entry.auth == nil {
return false
}
if pinnedAuthID != "" && entry.auth.ID != pinnedAuthID {
return false
}
if len(tried) > 0 {
if _, ok := tried[entry.auth.ID]; ok {
return false
}
}
return true
}
if picked := shard.pickReadyLocked(preferWebsocket, s.strategy, predicate); picked != nil {
return picked, nil
}
return nil, shard.unavailableErrorLocked(provider, model, predicate)
}
// pickMixed returns the next auth and provider for a mixed-provider request.
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
if s == nil {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
normalized := normalizeProviderKeys(providers)
if len(normalized) == 0 {
return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
modelKey := canonicalModelKey(model)
s.mu.Lock()
defer s.mu.Unlock()
if pinnedAuthID != "" {
providerKey := s.authProviders[pinnedAuthID]
if providerKey == "" || !containsProvider(normalized, providerKey) {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
providerState := s.providers[providerKey]
if providerState == nil {
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
shard := providerState.ensureModelLocked(modelKey, time.Now())
predicate := func(entry *scheduledAuth) bool {
if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID {
return false
}
if len(tried) == 0 {
return true
}
_, ok := tried[pinnedAuthID]
return !ok
}
if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil {
return picked, providerKey, nil
}
return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
}
predicate := triedPredicate(tried)
candidateShards := make([]*modelScheduler, len(normalized))
bestPriority := 0
hasCandidate := false
now := time.Now()
for providerIndex, providerKey := range normalized {
providerState := s.providers[providerKey]
if providerState == nil {
continue
}
shard := providerState.ensureModelLocked(modelKey, now)
candidateShards[providerIndex] = shard
if shard == nil {
continue
}
priorityReady, okPriority := shard.highestReadyPriorityLocked(false, predicate)
if !okPriority {
continue
}
if !hasCandidate || priorityReady > bestPriority {
bestPriority = priorityReady
hasCandidate = true
}
}
if !hasCandidate {
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
if s.strategy == schedulerStrategyFillFirst {
for providerIndex, providerKey := range normalized {
shard := candidateShards[providerIndex]
if shard == nil {
continue
}
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, s.strategy, predicate)
if picked != nil {
return picked, providerKey, nil
}
}
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
start := 0
if len(normalized) > 0 {
start = s.mixedCursors[cursorKey] % len(normalized)
}
for offset := 0; offset < len(normalized); offset++ {
providerIndex := (start + offset) % len(normalized)
providerKey := normalized[providerIndex]
shard := candidateShards[providerIndex]
if shard == nil {
continue
}
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, schedulerStrategyRoundRobin, predicate)
if picked == nil {
continue
}
s.mixedCursors[cursorKey] = providerIndex + 1
return picked, providerKey, nil
}
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
// mixedUnavailableErrorLocked synthesizes the mixed-provider cooldown or unavailable error.
func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error {
now := time.Now()
total := 0
cooldownCount := 0
earliest := time.Time{}
for _, providerKey := range providers {
providerState := s.providers[providerKey]
if providerState == nil {
continue
}
shard := providerState.ensureModelLocked(canonicalModelKey(model), now)
if shard == nil {
continue
}
localTotal, localCooldownCount, localEarliest := shard.availabilitySummaryLocked(triedPredicate(tried))
total += localTotal
cooldownCount += localCooldownCount
if !localEarliest.IsZero() && (earliest.IsZero() || localEarliest.Before(earliest)) {
earliest = localEarliest
}
}
if total == 0 {
return &Error{Code: "auth_not_found", Message: "no auth available"}
}
if cooldownCount == total && !earliest.IsZero() {
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return newModelCooldownError(model, "", resetIn)
}
return &Error{Code: "auth_unavailable", Message: "no auth available"}
}
// triedPredicate builds a filter that excludes auths already attempted for the current request.
func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool {
if len(tried) == 0 {
return func(entry *scheduledAuth) bool { return entry != nil && entry.auth != nil }
}
return func(entry *scheduledAuth) bool {
if entry == nil || entry.auth == nil {
return false
}
_, ok := tried[entry.auth.ID]
return !ok
}
}
// normalizeProviderKeys lowercases, trims, and de-duplicates provider keys while preserving order.
func normalizeProviderKeys(providers []string) []string {
seen := make(map[string]struct{}, len(providers))
out := make([]string, 0, len(providers))
for _, provider := range providers {
providerKey := strings.ToLower(strings.TrimSpace(provider))
if providerKey == "" {
continue
}
if _, ok := seen[providerKey]; ok {
continue
}
seen[providerKey] = struct{}{}
out = append(out, providerKey)
}
return out
}
// containsProvider reports whether provider is present in the normalized provider list.
func containsProvider(providers []string, provider string) bool {
for _, candidate := range providers {
if candidate == provider {
return true
}
}
return false
}
// upsertAuthLocked updates one auth in-place while the scheduler mutex is held.
func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) {
if auth == nil {
return
}
authID := strings.TrimSpace(auth.ID)
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
if authID == "" || providerKey == "" || auth.Disabled {
s.removeAuthLocked(authID)
return
}
if previousProvider := s.authProviders[authID]; previousProvider != "" && previousProvider != providerKey {
if previousState := s.providers[previousProvider]; previousState != nil {
previousState.removeAuthLocked(authID)
}
}
meta := buildScheduledAuthMeta(auth)
s.authProviders[authID] = providerKey
s.ensureProviderLocked(providerKey).upsertAuthLocked(meta, now)
}
// removeAuthLocked removes one auth from the scheduler while the scheduler mutex is held.
func (s *authScheduler) removeAuthLocked(authID string) {
if authID == "" {
return
}
if providerKey := s.authProviders[authID]; providerKey != "" {
if providerState := s.providers[providerKey]; providerState != nil {
providerState.removeAuthLocked(authID)
}
delete(s.authProviders, authID)
}
}
// ensureProviderLocked returns the provider scheduler for providerKey, creating it when needed.
func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler {
if s.providers == nil {
s.providers = make(map[string]*providerScheduler)
}
providerState := s.providers[providerKey]
if providerState == nil {
providerState = &providerScheduler{
providerKey: providerKey,
auths: make(map[string]*scheduledAuthMeta),
modelShards: make(map[string]*modelScheduler),
}
s.providers[providerKey] = providerState
}
return providerState
}
// buildScheduledAuthMeta extracts the scheduling metadata needed for shard bookkeeping.
func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta {
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
virtualParent := ""
if auth.Attributes != nil {
virtualParent = strings.TrimSpace(auth.Attributes["gemini_virtual_parent"])
}
return &scheduledAuthMeta{
auth: auth,
providerKey: providerKey,
priority: authPriority(auth),
virtualParent: virtualParent,
websocketEnabled: authWebsocketsEnabled(auth),
supportedModelSet: supportedModelSetForAuth(auth.ID),
}
}
// supportedModelSetForAuth snapshots the registry models currently registered for an auth.
func supportedModelSetForAuth(authID string) map[string]struct{} {
authID = strings.TrimSpace(authID)
if authID == "" {
return nil
}
models := registry.GetGlobalRegistry().GetModelsForClient(authID)
if len(models) == 0 {
return nil
}
set := make(map[string]struct{}, len(models))
for _, model := range models {
if model == nil {
continue
}
modelKey := canonicalModelKey(model.ID)
if modelKey == "" {
continue
}
set[modelKey] = struct{}{}
}
return set
}
// upsertAuthLocked updates every existing model shard that can reference the auth metadata.
func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) {
if p == nil || meta == nil || meta.auth == nil {
return
}
p.auths[meta.auth.ID] = meta
for modelKey, shard := range p.modelShards {
if shard == nil {
continue
}
if !meta.supportsModel(modelKey) {
shard.removeEntryLocked(meta.auth.ID)
continue
}
shard.upsertEntryLocked(meta, now)
}
}
// removeAuthLocked removes an auth from all model shards owned by the provider scheduler.
func (p *providerScheduler) removeAuthLocked(authID string) {
if p == nil || authID == "" {
return
}
delete(p.auths, authID)
for _, shard := range p.modelShards {
if shard != nil {
shard.removeEntryLocked(authID)
}
}
}
// ensureModelLocked returns the shard for modelKey, building it lazily from provider auths.
func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler {
if p == nil {
return nil
}
modelKey = canonicalModelKey(modelKey)
if shard, ok := p.modelShards[modelKey]; ok && shard != nil {
shard.promoteExpiredLocked(now)
return shard
}
shard := &modelScheduler{
modelKey: modelKey,
entries: make(map[string]*scheduledAuth),
readyByPriority: make(map[int]*readyBucket),
}
for _, meta := range p.auths {
if meta == nil || !meta.supportsModel(modelKey) {
continue
}
shard.upsertEntryLocked(meta, now)
}
p.modelShards[modelKey] = shard
return shard
}
// supportsModel reports whether the auth metadata currently supports modelKey.
func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
modelKey = canonicalModelKey(modelKey)
if modelKey == "" {
return true
}
if len(m.supportedModelSet) == 0 {
return false
}
_, ok := m.supportedModelSet[modelKey]
return ok
}
// upsertEntryLocked updates or inserts one auth entry and rebuilds indexes when ordering changes.
func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) {
if m == nil || meta == nil || meta.auth == nil {
return
}
entry, ok := m.entries[meta.auth.ID]
if !ok || entry == nil {
entry = &scheduledAuth{}
m.entries[meta.auth.ID] = entry
}
previousState := entry.state
previousNextRetryAt := entry.nextRetryAt
previousPriority := 0
previousParent := ""
previousWebsocketEnabled := false
if entry.meta != nil {
previousPriority = entry.meta.priority
previousParent = entry.meta.virtualParent
previousWebsocketEnabled = entry.meta.websocketEnabled
}
entry.meta = meta
entry.auth = meta.auth
entry.nextRetryAt = time.Time{}
blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now)
switch {
case !blocked:
entry.state = scheduledStateReady
case reason == blockReasonCooldown:
entry.state = scheduledStateCooldown
entry.nextRetryAt = next
case reason == blockReasonDisabled:
entry.state = scheduledStateDisabled
default:
entry.state = scheduledStateBlocked
entry.nextRetryAt = next
}
if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled {
return
}
m.rebuildIndexesLocked()
}
// removeEntryLocked deletes one auth entry and rebuilds the shard indexes if needed.
func (m *modelScheduler) removeEntryLocked(authID string) {
if m == nil || authID == "" {
return
}
if _, ok := m.entries[authID]; !ok {
return
}
delete(m.entries, authID)
m.rebuildIndexesLocked()
}
// promoteExpiredLocked reevaluates blocked auths whose retry time has elapsed.
func (m *modelScheduler) promoteExpiredLocked(now time.Time) {
if m == nil || len(m.blocked) == 0 {
return
}
changed := false
for _, entry := range m.blocked {
if entry == nil || entry.auth == nil {
continue
}
if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) {
continue
}
blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now)
switch {
case !blocked:
entry.state = scheduledStateReady
entry.nextRetryAt = time.Time{}
case reason == blockReasonCooldown:
entry.state = scheduledStateCooldown
entry.nextRetryAt = next
case reason == blockReasonDisabled:
entry.state = scheduledStateDisabled
entry.nextRetryAt = time.Time{}
default:
entry.state = scheduledStateBlocked
entry.nextRetryAt = next
}
changed = true
}
if changed {
m.rebuildIndexesLocked()
}
}
// pickReadyLocked selects the next ready auth from the highest available priority bucket.
func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
if m == nil {
return nil
}
m.promoteExpiredLocked(time.Now())
priorityReady, okPriority := m.highestReadyPriorityLocked(preferWebsocket, predicate)
if !okPriority {
return nil
}
return m.pickReadyAtPriorityLocked(preferWebsocket, priorityReady, strategy, predicate)
}
// highestReadyPriorityLocked returns the highest priority bucket that still has a matching ready auth.
// The caller must ensure expired entries are already promoted when needed.
func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predicate func(*scheduledAuth) bool) (int, bool) {
if m == nil {
return 0, false
}
for _, priority := range m.priorityOrder {
bucket := m.readyByPriority[priority]
if bucket == nil {
continue
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
view = &bucket.ws
}
if view.pickFirst(predicate) != nil {
return priority, true
}
}
return 0, false
}
// pickReadyAtPriorityLocked selects the next ready auth from a specific priority bucket.
// The caller must ensure expired entries are already promoted when needed.
func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priority int, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
if m == nil {
return nil
}
bucket := m.readyByPriority[priority]
if bucket == nil {
return nil
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
view = &bucket.ws
}
var picked *scheduledAuth
if strategy == schedulerStrategyFillFirst {
picked = view.pickFirst(predicate)
} else {
picked = view.pickRoundRobin(predicate)
}
if picked == nil || picked.auth == nil {
return nil
}
return picked.auth
}
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
now := time.Now()
total, cooldownCount, earliest := m.availabilitySummaryLocked(predicate)
if total == 0 {
return &Error{Code: "auth_not_found", Message: "no auth available"}
}
if cooldownCount == total && !earliest.IsZero() {
providerForError := provider
if providerForError == "mixed" {
providerForError = ""
}
resetIn := earliest.Sub(now)
if resetIn < 0 {
resetIn = 0
}
return newModelCooldownError(model, providerForError, resetIn)
}
return &Error{Code: "auth_unavailable", Message: "no auth available"}
}
// availabilitySummaryLocked summarizes total candidates, cooldown count, and earliest retry time.
func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) {
if m == nil {
return 0, 0, time.Time{}
}
total := 0
cooldownCount := 0
earliest := time.Time{}
for _, entry := range m.entries {
if predicate != nil && !predicate(entry) {
continue
}
total++
if entry == nil || entry.auth == nil {
continue
}
if entry.state != scheduledStateCooldown {
continue
}
cooldownCount++
if !entry.nextRetryAt.IsZero() && (earliest.IsZero() || entry.nextRetryAt.Before(earliest)) {
earliest = entry.nextRetryAt
}
}
return total, cooldownCount, earliest
}
// rebuildIndexesLocked reconstructs ready and blocked views from the current entry map.
func (m *modelScheduler) rebuildIndexesLocked() {
m.readyByPriority = make(map[int]*readyBucket)
m.priorityOrder = m.priorityOrder[:0]
m.blocked = m.blocked[:0]
priorityBuckets := make(map[int][]*scheduledAuth)
for _, entry := range m.entries {
if entry == nil || entry.auth == nil {
continue
}
switch entry.state {
case scheduledStateReady:
priority := entry.meta.priority
priorityBuckets[priority] = append(priorityBuckets[priority], entry)
case scheduledStateCooldown, scheduledStateBlocked:
m.blocked = append(m.blocked, entry)
}
}
for priority, entries := range priorityBuckets {
sort.Slice(entries, func(i, j int) bool {
return entries[i].auth.ID < entries[j].auth.ID
})
m.readyByPriority[priority] = buildReadyBucket(entries)
m.priorityOrder = append(m.priorityOrder, priority)
}
sort.Slice(m.priorityOrder, func(i, j int) bool {
return m.priorityOrder[i] > m.priorityOrder[j]
})
sort.Slice(m.blocked, func(i, j int) bool {
left := m.blocked[i]
right := m.blocked[j]
if left == nil || right == nil {
return left != nil
}
if left.nextRetryAt.Equal(right.nextRetryAt) {
return left.auth.ID < right.auth.ID
}
if left.nextRetryAt.IsZero() {
return false
}
if right.nextRetryAt.IsZero() {
return true
}
return left.nextRetryAt.Before(right.nextRetryAt)
})
}
// buildReadyBucket prepares the general and websocket-only ready views for one priority bucket.
func buildReadyBucket(entries []*scheduledAuth) *readyBucket {
bucket := &readyBucket{}
bucket.all = buildReadyView(entries)
wsEntries := make([]*scheduledAuth, 0, len(entries))
for _, entry := range entries {
if entry != nil && entry.meta != nil && entry.meta.websocketEnabled {
wsEntries = append(wsEntries, entry)
}
}
bucket.ws = buildReadyView(wsEntries)
return bucket
}
// buildReadyView creates either a flat view or a grouped parent/child view for rotation.
func buildReadyView(entries []*scheduledAuth) readyView {
view := readyView{flat: append([]*scheduledAuth(nil), entries...)}
if len(entries) == 0 {
return view
}
groups := make(map[string][]*scheduledAuth)
for _, entry := range entries {
if entry == nil || entry.meta == nil || entry.meta.virtualParent == "" {
return view
}
groups[entry.meta.virtualParent] = append(groups[entry.meta.virtualParent], entry)
}
if len(groups) <= 1 {
return view
}
view.children = make(map[string]*childBucket, len(groups))
view.parentOrder = make([]string, 0, len(groups))
for parent := range groups {
view.parentOrder = append(view.parentOrder, parent)
}
sort.Strings(view.parentOrder)
for _, parent := range view.parentOrder {
view.children[parent] = &childBucket{items: append([]*scheduledAuth(nil), groups[parent]...)}
}
return view
}
// pickFirst returns the first ready entry that satisfies predicate without advancing cursors.
func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth {
for _, entry := range v.flat {
if predicate == nil || predicate(entry) {
return entry
}
}
return nil
}
// pickRoundRobin returns the next ready entry using flat or grouped round-robin traversal.
func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
if len(v.parentOrder) > 1 && len(v.children) > 0 {
return v.pickGroupedRoundRobin(predicate)
}
if len(v.flat) == 0 {
return nil
}
start := 0
if len(v.flat) > 0 {
start = v.cursor % len(v.flat)
}
for offset := 0; offset < len(v.flat); offset++ {
index := (start + offset) % len(v.flat)
entry := v.flat[index]
if predicate != nil && !predicate(entry) {
continue
}
v.cursor = index + 1
return entry
}
return nil
}
// pickGroupedRoundRobin rotates across parents first and then within the selected parent.
func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
start := 0
if len(v.parentOrder) > 0 {
start = v.parentCursor % len(v.parentOrder)
}
for offset := 0; offset < len(v.parentOrder); offset++ {
parentIndex := (start + offset) % len(v.parentOrder)
parent := v.parentOrder[parentIndex]
child := v.children[parent]
if child == nil || len(child.items) == 0 {
continue
}
itemStart := child.cursor % len(child.items)
for itemOffset := 0; itemOffset < len(child.items); itemOffset++ {
itemIndex := (itemStart + itemOffset) % len(child.items)
entry := child.items[itemIndex]
if predicate != nil && !predicate(entry) {
continue
}
child.cursor = itemIndex + 1
v.parentCursor = parentIndex + 1
return entry
}
}
return nil
}
@@ -0,0 +1,216 @@
package auth
import (
"context"
"fmt"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type schedulerBenchmarkExecutor struct {
id string
}
func (e schedulerBenchmarkExecutor) Identifier() string { return e.id }
func (e schedulerBenchmarkExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerBenchmarkExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, nil
}
func (e schedulerBenchmarkExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e schedulerBenchmarkExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e schedulerBenchmarkExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
return nil, nil
}
func benchmarkManagerSetup(b *testing.B, total int, mixed bool, withPriority bool) (*Manager, []string, string) {
b.Helper()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
providers := []string{"gemini"}
manager.executors["gemini"] = schedulerBenchmarkExecutor{id: "gemini"}
if mixed {
providers = []string{"gemini", "claude"}
manager.executors["claude"] = schedulerBenchmarkExecutor{id: "claude"}
}
reg := registry.GetGlobalRegistry()
model := "bench-model"
for index := 0; index < total; index++ {
provider := providers[0]
if mixed && index%2 == 1 {
provider = providers[1]
}
auth := &Auth{ID: fmt.Sprintf("bench-%s-%04d", provider, index), Provider: provider}
if withPriority {
priority := "0"
if index%2 == 0 {
priority = "10"
}
auth.Attributes = map[string]string{"priority": priority}
}
_, errRegister := manager.Register(context.Background(), auth)
if errRegister != nil {
b.Fatalf("Register(%s) error = %v", auth.ID, errRegister)
}
reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: model}})
}
manager.syncScheduler()
b.Cleanup(func() {
for index := 0; index < total; index++ {
provider := providers[0]
if mixed && index%2 == 1 {
provider = providers[1]
}
reg.UnregisterClient(fmt.Sprintf("bench-%s-%04d", provider, index))
}
})
return manager, providers, model
}
func BenchmarkManagerPickNext500(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 500, false, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNext1000(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNextPriority500(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 500, false, true)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNextPriority1000(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 1000, false, true)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil || exec == nil {
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
}
}
}
func BenchmarkManagerPickNextMixed500(b *testing.B) {
manager, providers, model := benchmarkManagerSetup(b, 500, true, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
if errPick != nil || auth == nil || exec == nil || provider == "" {
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
}
}
}
func BenchmarkManagerPickNextMixedPriority500(b *testing.B) {
manager, providers, model := benchmarkManagerSetup(b, 500, true, true)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
if errPick != nil || auth == nil || exec == nil || provider == "" {
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
}
}
}
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
ctx := context.Background()
opts := cliproxyexecutor.Options{}
tried := map[string]struct{}{}
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
b.Fatalf("warmup pickNext error = %v", errWarm)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
auth, _, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
if errPick != nil || auth == nil {
b.Fatalf("pickNext failed: auth=%v err=%v", auth, errPick)
}
manager.MarkResult(ctx, Result{AuthID: auth.ID, Provider: "gemini", Model: model, Success: true})
}
}
+503
View File
@@ -0,0 +1,503 @@
package auth
import (
"context"
"net/http"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type schedulerTestExecutor struct{}
func (schedulerTestExecutor) Identifier() string { return "test" }
func (schedulerTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (schedulerTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, nil
}
func (schedulerTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (schedulerTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (schedulerTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
return nil, nil
}
type trackingSelector struct {
calls int
lastAuthID []string
}
func (s *trackingSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
s.calls++
s.lastAuthID = s.lastAuthID[:0]
for _, auth := range auths {
s.lastAuthID = append(s.lastAuthID, auth.ID)
}
if len(auths) == 0 {
return nil, nil
}
return auths[len(auths)-1], nil
}
func newSchedulerForTest(selector Selector, auths ...*Auth) *authScheduler {
scheduler := newAuthScheduler(selector)
scheduler.rebuild(auths)
return scheduler
}
func registerSchedulerModels(t *testing.T, provider string, model string, authIDs ...string) {
t.Helper()
reg := registry.GetGlobalRegistry()
for _, authID := range authIDs {
reg.RegisterClient(authID, provider, []*registry.ModelInfo{{ID: model}})
}
t.Cleanup(func() {
for _, authID := range authIDs {
reg.UnregisterClient(authID)
}
})
}
func TestSchedulerPick_RoundRobinHighestPriority(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "low", Provider: "gemini", Attributes: map[string]string{"priority": "0"}},
&Auth{ID: "high-b", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
&Auth{ID: "high-a", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
)
want := []string{"high-a", "high-b", "high-a"}
for index, wantID := range want {
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantID {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
}
}
}
func TestSchedulerPick_FillFirstSticksToFirstReady(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&FillFirstSelector{},
&Auth{ID: "b", Provider: "gemini"},
&Auth{ID: "a", Provider: "gemini"},
&Auth{ID: "c", Provider: "gemini"},
)
for index := 0; index < 3; index++ {
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != "a" {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, "a")
}
}
}
func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) {
t.Parallel()
model := "gemini-2.5-pro"
registerSchedulerModels(t, "gemini", model, "cooldown-expired")
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{
ID: "cooldown-expired",
Provider: "gemini",
ModelStates: map[string]*ModelState{
model: {
Status: StatusError,
Unavailable: true,
NextRetryAfter: time.Now().Add(-1 * time.Second),
},
},
},
)
got, errPick := scheduler.pickSingle(context.Background(), "gemini", model, cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() error = %v", errPick)
}
if got == nil {
t.Fatalf("pickSingle() auth = nil")
}
if got.ID != "cooldown-expired" {
t.Fatalf("pickSingle() auth.ID = %q, want %q", got.ID, "cooldown-expired")
}
}
func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) {
t.Parallel()
registerSchedulerModels(t, "gemini-cli", "gemini-2.5-pro", "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2")
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "cred-a::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
&Auth{ID: "cred-a::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
&Auth{ID: "cred-b::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
&Auth{ID: "cred-b::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
)
wantParents := []string{"cred-a", "cred-b", "cred-a", "cred-b"}
wantIDs := []string{"cred-a::proj-1", "cred-b::proj-1", "cred-a::proj-2", "cred-b::proj-2"}
for index := range wantIDs {
got, errPick := scheduler.pickSingle(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantIDs[index] {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
if got.Attributes["gemini_virtual_parent"] != wantParents[index] {
t.Fatalf("pickSingle() #%d parent = %q, want %q", index, got.Attributes["gemini_virtual_parent"], wantParents[index])
}
}
}
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "codex-http", Provider: "codex"},
&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
)
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"}
for index, wantID := range want {
got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickSingle() #%d auth = nil", index)
}
if got.ID != wantID {
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
}
}
}
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "gemini-a", Provider: "gemini"},
&Auth{ID: "gemini-b", Provider: "gemini"},
&Auth{ID: "claude-a", Provider: "claude"},
)
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
for index := range wantProviders {
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
t.Parallel()
model := "gpt-default"
registerSchedulerModels(t, "provider-low", model, "low")
registerSchedulerModels(t, "provider-high-a", model, "high-a")
registerSchedulerModels(t, "provider-high-b", model, "high-b")
scheduler := newSchedulerForTest(
&RoundRobinSelector{},
&Auth{ID: "low", Provider: "provider-low", Attributes: map[string]string{"priority": "4"}},
&Auth{ID: "high-a", Provider: "provider-high-a", Attributes: map[string]string{"priority": "7"}},
&Auth{ID: "high-b", Provider: "provider-high-b", Attributes: map[string]string{"priority": "7"}},
)
providers := []string{"provider-low", "provider-high-a", "provider-high-b"}
wantProviders := []string{"provider-high-a", "provider-high-b", "provider-high-a", "provider-high-b"}
wantIDs := []string{"high-a", "high-b", "high-a", "high-b"}
for index := range wantProviders {
got, provider, errPick := scheduler.pickMixed(context.Background(), providers, model, cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.executors["gemini"] = schedulerTestExecutor{}
manager.executors["claude"] = schedulerTestExecutor{}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-b) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
if errPick != nil {
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickNextMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestManagerCustomSelector_FallsBackToLegacyPath(t *testing.T) {
t.Parallel()
selector := &trackingSelector{}
manager := NewManager(nil, selector, nil)
manager.executors["gemini"] = schedulerTestExecutor{}
manager.auths["auth-a"] = &Auth{ID: "auth-a", Provider: "gemini"}
manager.auths["auth-b"] = &Auth{ID: "auth-b", Provider: "gemini"}
got, _, errPick := manager.pickNext(context.Background(), "gemini", "", cliproxyexecutor.Options{}, map[string]struct{}{})
if errPick != nil {
t.Fatalf("pickNext() error = %v", errPick)
}
if got == nil {
t.Fatalf("pickNext() auth = nil")
}
if selector.calls != 1 {
t.Fatalf("selector.calls = %d, want %d", selector.calls, 1)
}
if len(selector.lastAuthID) != 2 {
t.Fatalf("len(selector.lastAuthID) = %d, want %d", len(selector.lastAuthID), 2)
}
if got.ID != selector.lastAuthID[len(selector.lastAuthID)-1] {
t.Fatalf("pickNext() auth.ID = %q, want selector-picked %q", got.ID, selector.lastAuthID[len(selector.lastAuthID)-1])
}
}
func TestManager_InitializesSchedulerForBuiltInSelector(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
if manager.scheduler == nil {
t.Fatalf("manager.scheduler = nil")
}
if manager.scheduler.strategy != schedulerStrategyRoundRobin {
t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyRoundRobin)
}
manager.SetSelector(&FillFirstSelector{})
if manager.scheduler.strategy != schedulerStrategyFillFirst {
t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyFillFirst)
}
}
func TestManager_SchedulerTracksRegisterAndUpdate(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-b) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-a) error = %v", errRegister)
}
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() error = %v", errPick)
}
if got == nil || got.ID != "auth-a" {
t.Fatalf("scheduler.pickSingle() auth = %v, want auth-a", got)
}
if _, errUpdate := manager.Update(context.Background(), &Auth{ID: "auth-a", Provider: "gemini", Disabled: true}); errUpdate != nil {
t.Fatalf("Update(auth-a) error = %v", errUpdate)
}
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() after update error = %v", errPick)
}
if got == nil || got.ID != "auth-b" {
t.Fatalf("scheduler.pickSingle() after update auth = %v, want auth-b", got)
}
}
func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.executors["gemini"] = schedulerTestExecutor{}
manager.executors["claude"] = schedulerTestExecutor{}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-b) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("pickNextMixed() #%d auth = nil", index)
}
if provider != wantProviders[index] {
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
}
if got.ID != wantIDs[index] {
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
}
}
}
func TestManager_PickNextMixed_SkipsProvidersWithoutExecutors(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.executors["claude"] = schedulerTestExecutor{}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(gemini-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickNextMixed() error = %v", errPick)
}
if got == nil {
t.Fatalf("pickNextMixed() auth = nil")
}
if provider != "claude" {
t.Fatalf("pickNextMixed() provider = %q, want %q", provider, "claude")
}
if got.ID != "claude-a" {
t.Fatalf("pickNextMixed() auth.ID = %q, want %q", got.ID, "claude-a")
}
}
func TestManager_SchedulerTracksMarkResultCooldownAndRecovery(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("auth-a", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
reg.RegisterClient("auth-b", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
reg.UnregisterClient("auth-a")
reg.UnregisterClient("auth-b")
})
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-a) error = %v", errRegister)
}
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
t.Fatalf("Register(auth-b) error = %v", errRegister)
}
manager.MarkResult(context.Background(), Result{
AuthID: "auth-a",
Provider: "gemini",
Model: "test-model",
Success: false,
Error: &Error{HTTPStatus: 429, Message: "quota"},
})
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() after cooldown error = %v", errPick)
}
if got == nil || got.ID != "auth-b" {
t.Fatalf("scheduler.pickSingle() after cooldown auth = %v, want auth-b", got)
}
manager.MarkResult(context.Background(), Result{
AuthID: "auth-a",
Provider: "gemini",
Model: "test-model",
Success: true,
})
seen := make(map[string]struct{}, 2)
for index := 0; index < 2; index++ {
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("scheduler.pickSingle() after recovery #%d error = %v", index, errPick)
}
if got == nil {
t.Fatalf("scheduler.pickSingle() after recovery #%d auth = nil", index)
}
seen[got.ID] = struct{}{}
}
if len(seen) != 2 {
t.Fatalf("len(seen) = %d, want %d", len(seen), 2)
}
}
+5 -31
View File
@@ -1,16 +1,13 @@
package cliproxy
import (
"context"
"net"
"net/http"
"net/url"
"strings"
"sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
)
// defaultRoundTripperProvider returns a per-auth HTTP RoundTripper based on
@@ -39,35 +36,12 @@ func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http.
if rt != nil {
return rt
}
// Parse the proxy URL to determine the scheme.
proxyURL, errParse := url.Parse(proxyStr)
if errParse != nil {
log.Errorf("parse proxy URL failed: %v", errParse)
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
if errBuild != nil {
log.Errorf("%v", errBuild)
return nil
}
var transport *http.Transport
// Handle different proxy schemes.
if proxyURL.Scheme == "socks5" {
// Configure SOCKS5 proxy with optional authentication.
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil
}
// Set up a custom transport using the SOCKS5 dialer.
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Configure HTTP or HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
} else {
log.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
if transport == nil {
return nil
}
p.mu.Lock()
+22
View File
@@ -0,0 +1,22 @@
package cliproxy
import (
"net/http"
"testing"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestRoundTripperForDirectBypassesProxy(t *testing.T) {
t.Parallel()
provider := newDefaultRoundTripperProvider()
rt := provider.RoundTripperFor(&coreauth.Auth{ProxyURL: "direct"})
transport, ok := rt.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", rt)
}
if transport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}
+22 -1
View File
@@ -312,6 +312,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
// This operation may block on network calls, but the auth configuration
// is already effective at this point.
s.registerModelsForAuth(auth)
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
// from the now-populated global model registry. Without this, newly added auths
// have an empty supportedModelSet (because Register/Update upserts into the
// scheduler before registerModelsForAuth runs) and are invisible to the scheduler.
s.coreManager.RefreshSchedulerEntry(auth.ID)
}
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
@@ -823,7 +829,22 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
models = applyExcludedModels(models, excluded)
case "codex":
models = registry.GetOpenAIModels()
codexPlanType := ""
if a.Attributes != nil {
codexPlanType = strings.TrimSpace(a.Attributes["plan_type"])
}
switch strings.ToLower(codexPlanType) {
case "pro":
models = registry.GetCodexProModels()
case "plus":
models = registry.GetCodexPlusModels()
case "team":
models = registry.GetCodexTeamModels()
case "free":
models = registry.GetCodexFreeModels()
default:
models = registry.GetCodexProModels()
}
if entry := s.resolveConfigCodexKey(a); entry != nil {
if len(entry.Models) > 0 {
models = buildCodexConfigModels(entry)
+139
View File
@@ -0,0 +1,139 @@
package proxyutil
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"golang.org/x/net/proxy"
)
// Mode describes how a proxy setting should be interpreted.
type Mode int
const (
// ModeInherit means no explicit proxy behavior was configured.
ModeInherit Mode = iota
// ModeDirect means outbound requests must bypass proxies explicitly.
ModeDirect
// ModeProxy means a concrete proxy URL was configured.
ModeProxy
// ModeInvalid means the proxy setting is present but malformed or unsupported.
ModeInvalid
)
// Setting is the normalized interpretation of a proxy configuration value.
type Setting struct {
Raw string
Mode Mode
URL *url.URL
}
// Parse normalizes a proxy configuration value into inherit, direct, or proxy modes.
func Parse(raw string) (Setting, error) {
trimmed := strings.TrimSpace(raw)
setting := Setting{Raw: trimmed}
if trimmed == "" {
setting.Mode = ModeInherit
return setting, nil
}
if strings.EqualFold(trimmed, "direct") || strings.EqualFold(trimmed, "none") {
setting.Mode = ModeDirect
return setting, nil
}
parsedURL, errParse := url.Parse(trimmed)
if errParse != nil {
setting.Mode = ModeInvalid
return setting, fmt.Errorf("parse proxy URL failed: %w", errParse)
}
if parsedURL.Scheme == "" || parsedURL.Host == "" {
setting.Mode = ModeInvalid
return setting, fmt.Errorf("proxy URL missing scheme/host")
}
switch parsedURL.Scheme {
case "socks5", "http", "https":
setting.Mode = ModeProxy
setting.URL = parsedURL
return setting, nil
default:
setting.Mode = ModeInvalid
return setting, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
}
}
// NewDirectTransport returns a transport that bypasses environment proxies.
func NewDirectTransport() *http.Transport {
if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
clone := transport.Clone()
clone.Proxy = nil
return clone
}
return &http.Transport{Proxy: nil}
}
// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting.
func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
setting, errParse := Parse(raw)
if errParse != nil {
return nil, setting.Mode, errParse
}
switch setting.Mode {
case ModeInherit:
return nil, setting.Mode, nil
case ModeDirect:
return NewDirectTransport(), setting.Mode, nil
case ModeProxy:
if setting.URL.Scheme == "socks5" {
var proxyAuth *proxy.Auth
if setting.URL.User != nil {
username := setting.URL.User.Username()
password, _ := setting.URL.User.Password()
proxyAuth = &proxy.Auth{User: username, Password: password}
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
}
return &http.Transport{
Proxy: nil,
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}, setting.Mode, nil
}
return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil
default:
return nil, setting.Mode, nil
}
}
// BuildDialer constructs a proxy dialer for settings that operate at the connection layer.
func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
setting, errParse := Parse(raw)
if errParse != nil {
return nil, setting.Mode, errParse
}
switch setting.Mode {
case ModeInherit:
return nil, setting.Mode, nil
case ModeDirect:
return proxy.Direct, setting.Mode, nil
case ModeProxy:
dialer, errDialer := proxy.FromURL(setting.URL, proxy.Direct)
if errDialer != nil {
return nil, setting.Mode, fmt.Errorf("create proxy dialer failed: %w", errDialer)
}
return dialer, setting.Mode, nil
default:
return nil, setting.Mode, nil
}
}
+89
View File
@@ -0,0 +1,89 @@
package proxyutil
import (
"net/http"
"testing"
)
func TestParse(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input string
want Mode
wantErr bool
}{
{name: "inherit", input: "", want: ModeInherit},
{name: "direct", input: "direct", want: ModeDirect},
{name: "none", input: "none", want: ModeDirect},
{name: "http", input: "http://proxy.example.com:8080", want: ModeProxy},
{name: "https", input: "https://proxy.example.com:8443", want: ModeProxy},
{name: "socks5", input: "socks5://proxy.example.com:1080", want: ModeProxy},
{name: "invalid", input: "bad-value", want: ModeInvalid, wantErr: true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
setting, errParse := Parse(tt.input)
if tt.wantErr && errParse == nil {
t.Fatal("expected error, got nil")
}
if !tt.wantErr && errParse != nil {
t.Fatalf("unexpected error: %v", errParse)
}
if setting.Mode != tt.want {
t.Fatalf("mode = %d, want %d", setting.Mode, tt.want)
}
})
}
}
func TestBuildHTTPTransportDirectBypassesProxy(t *testing.T) {
t.Parallel()
transport, mode, errBuild := BuildHTTPTransport("direct")
if errBuild != nil {
t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
}
if mode != ModeDirect {
t.Fatalf("mode = %d, want %d", mode, ModeDirect)
}
if transport == nil {
t.Fatal("expected transport, got nil")
}
if transport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}
func TestBuildHTTPTransportHTTPProxy(t *testing.T) {
t.Parallel()
transport, mode, errBuild := BuildHTTPTransport("http://proxy.example.com:8080")
if errBuild != nil {
t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
}
if mode != ModeProxy {
t.Fatalf("mode = %d, want %d", mode, ModeProxy)
}
if transport == nil {
t.Fatal("expected transport, got nil")
}
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
if errRequest != nil {
t.Fatalf("http.NewRequest returned error: %v", errRequest)
}
proxyURL, errProxy := transport.Proxy(req)
if errProxy != nil {
t.Fatalf("transport.Proxy returned error: %v", errProxy)
}
if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" {
t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL)
}
}