Merge branch 'dev' into codex/custom-useragent-request
This commit is contained in:
@@ -15,6 +15,8 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
@@ -46,6 +48,8 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v3
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to DockerHub
|
- name: Login to DockerHub
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
- name: Refresh models catalog
|
||||||
|
run: curl -fsSL https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json -o internal/registry/models/models.json
|
||||||
- run: git fetch --force --tags
|
- run: git fetch --force --tags
|
||||||
- uses: actions/setup-go@v4
|
- uses: actions/setup-go@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
@@ -150,6 +150,10 @@ A Windows tray application implemented using PowerShell scripts, without relying
|
|||||||
|
|
||||||
A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
|
A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
|
||||||
|
|
||||||
|
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
|
||||||
|
|
||||||
|
Browser extension for one-stop management of New API-compatible relay site accounts, featuring balance and usage dashboards, auto check-in, one-click key export to common apps, in-page API availability testing, and channel/model sync and redirection. It integrates with CLIProxyAPI through the Management API for one-click provider import and config sync.
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||||
|
|
||||||
|
|||||||
@@ -149,6 +149,10 @@ Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方
|
|||||||
|
|
||||||
一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
|
一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
|
||||||
|
|
||||||
|
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
|
||||||
|
|
||||||
|
用于一站式管理 New API 兼容中转站账号的浏览器扩展,提供余额与用量看板、自动签到、密钥一键导出到常用应用、网页内 API 可用性测试,以及渠道与模型同步和重定向。支持通过 CLIProxyAPI Management API 一键导入 Provider 与同步配置。
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
|
||||||
@@ -494,6 +495,7 @@ func main() {
|
|||||||
if standalone {
|
if standalone {
|
||||||
// Standalone mode: start an embedded local server and connect TUI client to it.
|
// Standalone mode: start an embedded local server and connect TUI client to it.
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
registry.StartModelsUpdater(context.Background())
|
||||||
hook := tui.NewLogHook(2000)
|
hook := tui.NewLogHook(2000)
|
||||||
hook.SetFormatter(&logging.LogFormatter{})
|
hook.SetFormatter(&logging.LogFormatter{})
|
||||||
log.AddHook(hook)
|
log.AddHook(hook)
|
||||||
@@ -566,6 +568,7 @@ func main() {
|
|||||||
} else {
|
} else {
|
||||||
// Start the main proxy service
|
// Start the main proxy service
|
||||||
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
managementasset.StartAutoUpdater(context.Background(), configFilePath)
|
||||||
|
registry.StartModelsUpdater(context.Background())
|
||||||
cmd.StartService(cfg, configFilePath, password)
|
cmd.StartService(cfg, configFilePath, password)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ error-logs-max-files: 10
|
|||||||
usage-statistics-enabled: false
|
usage-statistics-enabled: false
|
||||||
|
|
||||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||||
|
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
|
||||||
proxy-url: ""
|
proxy-url: ""
|
||||||
|
|
||||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||||
@@ -110,6 +111,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080"
|
# proxy-url: "socks5://proxy.example.com:1080"
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "gemini-2.5-flash" # upstream model name
|
# - name: "gemini-2.5-flash" # upstream model name
|
||||||
# alias: "gemini-flash" # client alias mapped to the upstream model
|
# alias: "gemini-flash" # client alias mapped to the upstream model
|
||||||
@@ -128,6 +130,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "gpt-5-codex" # upstream model name
|
# - name: "gpt-5-codex" # upstream model name
|
||||||
# alias: "codex-latest" # client alias mapped to the upstream model
|
# alias: "codex-latest" # client alias mapped to the upstream model
|
||||||
@@ -146,6 +149,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# models:
|
# models:
|
||||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||||
@@ -191,10 +195,22 @@ nonstream-keepalive-interval: 0
|
|||||||
# api-key-entries:
|
# api-key-entries:
|
||||||
# - api-key: "sk-or-v1-...b780"
|
# - api-key: "sk-or-v1-...b780"
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
# - api-key: "sk-or-v1-...b781" # without proxy-url
|
||||||
# models: # The models supported by the provider.
|
# models: # The models supported by the provider.
|
||||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||||
# alias: "kimi-k2" # The alias used in the API.
|
# alias: "kimi-k2" # The alias used in the API.
|
||||||
|
# # You may repeat the same alias to build an internal model pool.
|
||||||
|
# # The client still sees only one alias in the model list.
|
||||||
|
# # Requests to that alias will round-robin across the upstream names below,
|
||||||
|
# # and if the chosen upstream fails before producing output, the request will
|
||||||
|
# # continue with the next upstream model in the same alias pool.
|
||||||
|
# - name: "qwen3.5-plus"
|
||||||
|
# alias: "claude-opus-4.66"
|
||||||
|
# - name: "glm-5"
|
||||||
|
# alias: "claude-opus-4.66"
|
||||||
|
# - name: "kimi-k2.5"
|
||||||
|
# alias: "claude-opus-4.66"
|
||||||
|
|
||||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||||
# vertex-api-key:
|
# vertex-api-key:
|
||||||
@@ -202,6 +218,7 @@ nonstream-keepalive-interval: 0
|
|||||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||||
|
# # proxy-url: "direct" # optional: explicit direct connect for this credential
|
||||||
# headers:
|
# headers:
|
||||||
# X-Custom-Header: "custom-value"
|
# X-Custom-Header: "custom-value"
|
||||||
# models: # optional: map aliases to upstream model names
|
# models: # optional: map aliases to upstream model names
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -14,8 +13,8 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
)
|
)
|
||||||
@@ -660,45 +659,10 @@ func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func buildProxyTransport(proxyStr string) *http.Transport {
|
func buildProxyTransport(proxyStr string) *http.Transport {
|
||||||
proxyStr = strings.TrimSpace(proxyStr)
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
||||||
if proxyStr == "" {
|
if errBuild != nil {
|
||||||
|
log.WithError(errBuild).Debug("build proxy transport failed")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
return transport
|
||||||
proxyURL, errParse := url.Parse(proxyStr)
|
|
||||||
if errParse != nil {
|
|
||||||
log.WithError(errParse).Debug("parse proxy URL failed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if proxyURL.Scheme == "" || proxyURL.Host == "" {
|
|
||||||
log.Debug("proxy URL missing scheme/host")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if proxyURL.Scheme == "socks5" {
|
|
||||||
var proxyAuth *proxy.Auth
|
|
||||||
if proxyURL.User != nil {
|
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
|
||||||
}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.WithError(errSOCKS5).Debug("create SOCKS5 dialer failed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return &http.Transport{
|
|
||||||
Proxy: nil,
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
|
||||||
return &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,173 +1,58 @@
|
|||||||
package management
|
package management
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
type memoryAuthStore struct {
|
func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
|
||||||
mu sync.Mutex
|
t.Parallel()
|
||||||
items map[string]*coreauth.Auth
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) List(ctx context.Context) ([]*coreauth.Auth, error) {
|
h := &Handler{
|
||||||
_ = ctx
|
cfg: &config.Config{
|
||||||
s.mu.Lock()
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
defer s.mu.Unlock()
|
|
||||||
out := make([]*coreauth.Auth, 0, len(s.items))
|
|
||||||
for _, a := range s.items {
|
|
||||||
out = append(out, a.Clone())
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) Save(ctx context.Context, auth *coreauth.Auth) (string, error) {
|
|
||||||
_ = ctx
|
|
||||||
if auth == nil {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
if s.items == nil {
|
|
||||||
s.items = make(map[string]*coreauth.Auth)
|
|
||||||
}
|
|
||||||
s.items[auth.ID] = auth.Clone()
|
|
||||||
s.mu.Unlock()
|
|
||||||
return auth.ID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *memoryAuthStore) Delete(ctx context.Context, id string) error {
|
|
||||||
_ = ctx
|
|
||||||
s.mu.Lock()
|
|
||||||
delete(s.items, id)
|
|
||||||
s.mu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolveTokenForAuth_Antigravity_RefreshesExpiredToken(t *testing.T) {
|
|
||||||
var callCount int
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
callCount++
|
|
||||||
if r.Method != http.MethodPost {
|
|
||||||
t.Fatalf("expected POST, got %s", r.Method)
|
|
||||||
}
|
|
||||||
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
|
||||||
t.Fatalf("unexpected content-type: %s", ct)
|
|
||||||
}
|
|
||||||
bodyBytes, _ := io.ReadAll(r.Body)
|
|
||||||
_ = r.Body.Close()
|
|
||||||
values, err := url.ParseQuery(string(bodyBytes))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("parse form: %v", err)
|
|
||||||
}
|
|
||||||
if values.Get("grant_type") != "refresh_token" {
|
|
||||||
t.Fatalf("unexpected grant_type: %s", values.Get("grant_type"))
|
|
||||||
}
|
|
||||||
if values.Get("refresh_token") != "rt" {
|
|
||||||
t.Fatalf("unexpected refresh_token: %s", values.Get("refresh_token"))
|
|
||||||
}
|
|
||||||
if values.Get("client_id") != antigravityOAuthClientID {
|
|
||||||
t.Fatalf("unexpected client_id: %s", values.Get("client_id"))
|
|
||||||
}
|
|
||||||
if values.Get("client_secret") != antigravityOAuthClientSecret {
|
|
||||||
t.Fatalf("unexpected client_secret")
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
|
||||||
"access_token": "new-token",
|
|
||||||
"refresh_token": "rt2",
|
|
||||||
"expires_in": int64(3600),
|
|
||||||
"token_type": "Bearer",
|
|
||||||
})
|
|
||||||
}))
|
|
||||||
t.Cleanup(srv.Close)
|
|
||||||
|
|
||||||
originalURL := antigravityOAuthTokenURL
|
|
||||||
antigravityOAuthTokenURL = srv.URL
|
|
||||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
|
||||||
|
|
||||||
store := &memoryAuthStore{}
|
|
||||||
manager := coreauth.NewManager(store, nil, nil)
|
|
||||||
|
|
||||||
auth := &coreauth.Auth{
|
|
||||||
ID: "antigravity-test.json",
|
|
||||||
FileName: "antigravity-test.json",
|
|
||||||
Provider: "antigravity",
|
|
||||||
Metadata: map[string]any{
|
|
||||||
"type": "antigravity",
|
|
||||||
"access_token": "old-token",
|
|
||||||
"refresh_token": "rt",
|
|
||||||
"expires_in": int64(3600),
|
|
||||||
"timestamp": time.Now().Add(-2 * time.Hour).UnixMilli(),
|
|
||||||
"expired": time.Now().Add(-1 * time.Hour).Format(time.RFC3339),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
|
||||||
t.Fatalf("register auth: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
h := &Handler{authManager: manager}
|
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
|
||||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
if err != nil {
|
if !ok {
|
||||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
}
|
}
|
||||||
if token != "new-token" {
|
if httpTransport.Proxy != nil {
|
||||||
t.Fatalf("expected refreshed token, got %q", token)
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
}
|
|
||||||
if callCount != 1 {
|
|
||||||
t.Fatalf("expected 1 refresh call, got %d", callCount)
|
|
||||||
}
|
|
||||||
|
|
||||||
updated, ok := manager.GetByID(auth.ID)
|
|
||||||
if !ok || updated == nil {
|
|
||||||
t.Fatalf("expected auth in manager after update")
|
|
||||||
}
|
|
||||||
if got := tokenValueFromMetadata(updated.Metadata); got != "new-token" {
|
|
||||||
t.Fatalf("expected manager metadata updated, got %q", got)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveTokenForAuth_Antigravity_SkipsRefreshWhenTokenValid(t *testing.T) {
|
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
|
||||||
var callCount int
|
t.Parallel()
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
callCount++
|
|
||||||
w.WriteHeader(http.StatusInternalServerError)
|
|
||||||
}))
|
|
||||||
t.Cleanup(srv.Close)
|
|
||||||
|
|
||||||
originalURL := antigravityOAuthTokenURL
|
h := &Handler{
|
||||||
antigravityOAuthTokenURL = srv.URL
|
cfg: &config.Config{
|
||||||
t.Cleanup(func() { antigravityOAuthTokenURL = originalURL })
|
SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
|
||||||
|
|
||||||
auth := &coreauth.Auth{
|
|
||||||
ID: "antigravity-valid.json",
|
|
||||||
FileName: "antigravity-valid.json",
|
|
||||||
Provider: "antigravity",
|
|
||||||
Metadata: map[string]any{
|
|
||||||
"type": "antigravity",
|
|
||||||
"access_token": "ok-token",
|
|
||||||
"expired": time.Now().Add(30 * time.Minute).Format(time.RFC3339),
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
h := &Handler{}
|
|
||||||
token, err := h.resolveTokenForAuth(context.Background(), auth)
|
transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
|
||||||
if err != nil {
|
httpTransport, ok := transport.(*http.Transport)
|
||||||
t.Fatalf("resolveTokenForAuth: %v", err)
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", transport)
|
||||||
}
|
}
|
||||||
if token != "ok-token" {
|
|
||||||
t.Fatalf("expected existing token, got %q", token)
|
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errRequest != nil {
|
||||||
|
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||||
}
|
}
|
||||||
if callCount != 0 {
|
|
||||||
t.Fatalf("expected no refresh calls, got %d", callCount)
|
proxyURL, errProxy := httpTransport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != "http://global-proxy.example.com:8080" {
|
||||||
|
t.Fatalf("proxy URL = %v, want http://global-proxy.example.com:8080", proxyURL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1306,12 +1306,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||||
if errAll != nil {
|
if errAll != nil {
|
||||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ts.ProjectID = strings.Join(projects, ",")
|
ts.ProjectID = strings.Join(projects, ",")
|
||||||
@@ -1320,7 +1320,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
ts.Auto = false
|
ts.Auto = false
|
||||||
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
|
if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
|
||||||
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
log.Errorf("Google One auto-discovery failed: %v", errSetup)
|
||||||
SetOAuthSessionError(state, "Google One auto-discovery failed")
|
SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||||
@@ -1331,19 +1331,19 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||||
if errCheck != nil {
|
if errCheck != nil {
|
||||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ts.Checked = isChecked
|
ts.Checked = isChecked
|
||||||
if !isChecked {
|
if !isChecked {
|
||||||
log.Error("Cloud AI API is not enabled for the auto-discovered project")
|
log.Error("Cloud AI API is not enabled for the auto-discovered project")
|
||||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1356,13 +1356,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
|||||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||||
if errCheck != nil {
|
if errCheck != nil {
|
||||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
ts.Checked = isChecked
|
ts.Checked = isChecked
|
||||||
if !isChecked {
|
if !isChecked {
|
||||||
log.Error("Cloud AI API is not enabled for the selected project")
|
log.Error("Cloud AI API is not enabled for the selected project")
|
||||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,49 @@
|
|||||||
|
package management
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type memoryAuthStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
items map[string]*coreauth.Auth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
out := make([]*coreauth.Auth, 0, len(s.items))
|
||||||
|
for _, item := range s.items {
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
|
||||||
|
if auth == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.items == nil {
|
||||||
|
s.items = make(map[string]*coreauth.Auth)
|
||||||
|
}
|
||||||
|
s.items[auth.ID] = auth
|
||||||
|
return auth.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
delete(s.items, id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memoryAuthStore) SetBaseDir(string) {}
|
||||||
@@ -4,12 +4,12 @@ package claude
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
tls "github.com/refraction-networking/utls"
|
tls "github.com/refraction-networking/utls"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
@@ -31,17 +31,12 @@ type utlsRoundTripper struct {
|
|||||||
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
// newUtlsRoundTripper creates a new utls-based round tripper with optional proxy support
|
||||||
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
|
||||||
var dialer proxy.Dialer = proxy.Direct
|
var dialer proxy.Dialer = proxy.Direct
|
||||||
if cfg != nil && cfg.ProxyURL != "" {
|
if cfg != nil {
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
|
||||||
if err != nil {
|
if errBuild != nil {
|
||||||
log.Errorf("failed to parse proxy URL %q: %v", cfg.ProxyURL, err)
|
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
|
||||||
} else {
|
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||||
pDialer, err := proxy.FromURL(proxyURL, proxy.Direct)
|
dialer = proxyDialer
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create proxy dialer for %q: %v", cfg.ProxyURL, err)
|
|
||||||
} else {
|
|
||||||
dialer = pDialer
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,9 +10,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
@@ -20,9 +18,9 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
@@ -80,35 +78,15 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
|||||||
}
|
}
|
||||||
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||||
|
|
||||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
if errBuild != nil {
|
||||||
if err == nil {
|
log.Errorf("%v", errBuild)
|
||||||
var transport *http.Transport
|
} else if transport != nil {
|
||||||
if proxyURL.Scheme == "socks5" {
|
|
||||||
// Handle SOCKS5 proxy.
|
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
auth := &proxy.Auth{User: username, Password: password}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
||||||
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
|
||||||
}
|
|
||||||
transport = &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
|
||||||
// Handle HTTP/HTTPS proxy.
|
|
||||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
|
||||||
}
|
|
||||||
|
|
||||||
if transport != nil {
|
|
||||||
proxyClient := &http.Client{Transport: transport}
|
proxyClient := &http.Client{Transport: transport}
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
var err error
|
||||||
|
|
||||||
// Configure the OAuth2 client.
|
// Configure the OAuth2 client.
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// Package registry provides model definitions and lookup helpers for various AI providers.
|
// Package registry provides model definitions and lookup helpers for various AI providers.
|
||||||
// Static model metadata is stored in model_definitions_static_data.go.
|
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -7,6 +7,131 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AntigravityModelConfig captures static antigravity model overrides, including
|
||||||
|
// Thinking budget limits and provider max completion tokens.
|
||||||
|
type AntigravityModelConfig struct {
|
||||||
|
Thinking *ThinkingSupport `json:"thinking,omitempty"`
|
||||||
|
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// staticModelsJSON mirrors the top-level structure of models.json.
|
||||||
|
type staticModelsJSON struct {
|
||||||
|
Claude []*ModelInfo `json:"claude"`
|
||||||
|
Gemini []*ModelInfo `json:"gemini"`
|
||||||
|
Vertex []*ModelInfo `json:"vertex"`
|
||||||
|
GeminiCLI []*ModelInfo `json:"gemini-cli"`
|
||||||
|
AIStudio []*ModelInfo `json:"aistudio"`
|
||||||
|
CodexFree []*ModelInfo `json:"codex-free"`
|
||||||
|
CodexTeam []*ModelInfo `json:"codex-team"`
|
||||||
|
CodexPlus []*ModelInfo `json:"codex-plus"`
|
||||||
|
CodexPro []*ModelInfo `json:"codex-pro"`
|
||||||
|
Qwen []*ModelInfo `json:"qwen"`
|
||||||
|
IFlow []*ModelInfo `json:"iflow"`
|
||||||
|
Kimi []*ModelInfo `json:"kimi"`
|
||||||
|
Antigravity map[string]*AntigravityModelConfig `json:"antigravity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClaudeModels returns the standard Claude model definitions.
|
||||||
|
func GetClaudeModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Claude)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiModels returns the standard Gemini model definitions.
|
||||||
|
func GetGeminiModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Gemini)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiVertexModels returns Gemini model definitions for Vertex AI.
|
||||||
|
func GetGeminiVertexModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Vertex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGeminiCLIModels returns Gemini model definitions for the Gemini CLI.
|
||||||
|
func GetGeminiCLIModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().GeminiCLI)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAIStudioModels returns model definitions for AI Studio.
|
||||||
|
func GetAIStudioModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().AIStudio)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexFreeModels returns model definitions for the Codex free plan tier.
|
||||||
|
func GetCodexFreeModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexFree)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexTeamModels returns model definitions for the Codex team plan tier.
|
||||||
|
func GetCodexTeamModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexTeam)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexPlusModels returns model definitions for the Codex plus plan tier.
|
||||||
|
func GetCodexPlusModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexPlus)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCodexProModels returns model definitions for the Codex pro plan tier.
|
||||||
|
func GetCodexProModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().CodexPro)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQwenModels returns the standard Qwen model definitions.
|
||||||
|
func GetQwenModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Qwen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIFlowModels returns the standard iFlow model definitions.
|
||||||
|
func GetIFlowModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().IFlow)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions.
|
||||||
|
func GetKimiModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().Kimi)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAntigravityModelConfig returns static configuration for antigravity models.
|
||||||
|
// Keys use upstream model names returned by the Antigravity models endpoint.
|
||||||
|
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||||
|
data := getModels()
|
||||||
|
if len(data.Antigravity) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]*AntigravityModelConfig, len(data.Antigravity))
|
||||||
|
for k, v := range data.Antigravity {
|
||||||
|
out[k] = cloneAntigravityModelConfig(v)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneAntigravityModelConfig(cfg *AntigravityModelConfig) *AntigravityModelConfig {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
copyConfig := *cfg
|
||||||
|
if cfg.Thinking != nil {
|
||||||
|
copyThinking := *cfg.Thinking
|
||||||
|
if len(cfg.Thinking.Levels) > 0 {
|
||||||
|
copyThinking.Levels = append([]string(nil), cfg.Thinking.Levels...)
|
||||||
|
}
|
||||||
|
copyConfig.Thinking = ©Thinking
|
||||||
|
}
|
||||||
|
return ©Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneModelInfos returns a shallow copy of the slice with each element deep-cloned.
|
||||||
|
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*ModelInfo, len(models))
|
||||||
|
for i, m := range models {
|
||||||
|
out[i] = cloneModelInfo(m)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
|
||||||
// It returns nil when the channel is unknown.
|
// It returns nil when the channel is unknown.
|
||||||
//
|
//
|
||||||
@@ -35,7 +160,7 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
|||||||
case "aistudio":
|
case "aistudio":
|
||||||
return GetAIStudioModels()
|
return GetAIStudioModels()
|
||||||
case "codex":
|
case "codex":
|
||||||
return GetOpenAIModels()
|
return GetCodexProModels()
|
||||||
case "qwen":
|
case "qwen":
|
||||||
return GetQwenModels()
|
return GetQwenModels()
|
||||||
case "iflow":
|
case "iflow":
|
||||||
@@ -77,27 +202,28 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data := getModels()
|
||||||
allModels := [][]*ModelInfo{
|
allModels := [][]*ModelInfo{
|
||||||
GetClaudeModels(),
|
data.Claude,
|
||||||
GetGeminiModels(),
|
data.Gemini,
|
||||||
GetGeminiVertexModels(),
|
data.Vertex,
|
||||||
GetGeminiCLIModels(),
|
data.GeminiCLI,
|
||||||
GetAIStudioModels(),
|
data.AIStudio,
|
||||||
GetOpenAIModels(),
|
data.CodexPro,
|
||||||
GetQwenModels(),
|
data.Qwen,
|
||||||
GetIFlowModels(),
|
data.IFlow,
|
||||||
GetKimiModels(),
|
data.Kimi,
|
||||||
}
|
}
|
||||||
for _, models := range allModels {
|
for _, models := range allModels {
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
if m != nil && m.ID == modelID {
|
if m != nil && m.ID == modelID {
|
||||||
return m
|
return cloneModelInfo(m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check Antigravity static config
|
// Check Antigravity static config
|
||||||
if cfg := GetAntigravityModelConfig()[modelID]; cfg != nil {
|
if cfg := cloneAntigravityModelConfig(data.Antigravity[modelID]); cfg != nil {
|
||||||
return &ModelInfo{
|
return &ModelInfo{
|
||||||
ID: modelID,
|
ID: modelID,
|
||||||
Thinking: cfg.Thinking,
|
Thinking: cfg.Thinking,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -62,6 +62,11 @@ type ModelInfo struct {
|
|||||||
UserDefined bool `json:"-"`
|
UserDefined bool `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type availableModelsCacheEntry struct {
|
||||||
|
models []map[string]any
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
||||||
// Values are interpreted in provider-native token units.
|
// Values are interpreted in provider-native token units.
|
||||||
type ThinkingSupport struct {
|
type ThinkingSupport struct {
|
||||||
@@ -116,6 +121,8 @@ type ModelRegistry struct {
|
|||||||
clientProviders map[string]string
|
clientProviders map[string]string
|
||||||
// mutex ensures thread-safe access to the registry
|
// mutex ensures thread-safe access to the registry
|
||||||
mutex *sync.RWMutex
|
mutex *sync.RWMutex
|
||||||
|
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
|
||||||
|
availableModelsCache map[string]availableModelsCacheEntry
|
||||||
// hook is an optional callback sink for model registration changes
|
// hook is an optional callback sink for model registration changes
|
||||||
hook ModelRegistryHook
|
hook ModelRegistryHook
|
||||||
}
|
}
|
||||||
@@ -132,11 +139,24 @@ func GetGlobalRegistry() *ModelRegistry {
|
|||||||
clientModels: make(map[string][]string),
|
clientModels: make(map[string][]string),
|
||||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||||
clientProviders: make(map[string]string),
|
clientProviders: make(map[string]string),
|
||||||
|
availableModelsCache: make(map[string]availableModelsCacheEntry),
|
||||||
mutex: &sync.RWMutex{},
|
mutex: &sync.RWMutex{},
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return globalRegistry
|
return globalRegistry
|
||||||
}
|
}
|
||||||
|
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
|
||||||
|
if r.availableModelsCache == nil {
|
||||||
|
r.availableModelsCache = make(map[string]availableModelsCacheEntry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
|
||||||
|
if len(r.availableModelsCache) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clear(r.availableModelsCache)
|
||||||
|
}
|
||||||
|
|
||||||
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||||
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||||
@@ -151,9 +171,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||||
return info
|
return cloneModelInfo(info)
|
||||||
}
|
}
|
||||||
return LookupStaticModelInfo(modelID)
|
return cloneModelInfo(LookupStaticModelInfo(modelID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetHook sets an optional hook for observing model registration changes.
|
// SetHook sets an optional hook for observing model registration changes.
|
||||||
@@ -211,6 +231,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
|
|||||||
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
|
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
provider := strings.ToLower(clientProvider)
|
provider := strings.ToLower(clientProvider)
|
||||||
uniqueModelIDs := make([]string, 0, len(models))
|
uniqueModelIDs := make([]string, 0, len(models))
|
||||||
@@ -236,6 +257,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
delete(r.clientModels, clientID)
|
delete(r.clientModels, clientID)
|
||||||
delete(r.clientModelInfos, clientID)
|
delete(r.clientModelInfos, clientID)
|
||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
misc.LogCredentialSeparator()
|
misc.LogCredentialSeparator()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -263,6 +285,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
} else {
|
} else {
|
||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
}
|
}
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
r.triggerModelsRegistered(provider, clientID, models)
|
r.triggerModelsRegistered(provider, clientID, models)
|
||||||
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
||||||
misc.LogCredentialSeparator()
|
misc.LogCredentialSeparator()
|
||||||
@@ -406,6 +429,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
r.triggerModelsRegistered(provider, clientID, models)
|
r.triggerModelsRegistered(provider, clientID, models)
|
||||||
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
||||||
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
||||||
@@ -509,6 +533,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
|||||||
if len(model.SupportedOutputModalities) > 0 {
|
if len(model.SupportedOutputModalities) > 0 {
|
||||||
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
||||||
}
|
}
|
||||||
|
if model.Thinking != nil {
|
||||||
|
copyThinking := *model.Thinking
|
||||||
|
if len(model.Thinking.Levels) > 0 {
|
||||||
|
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||||
|
}
|
||||||
|
copyModel.Thinking = ©Thinking
|
||||||
|
}
|
||||||
return ©Model
|
return ©Model
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -538,6 +569,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
|
|||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
r.unregisterClientInternal(clientID)
|
r.unregisterClientInternal(clientID)
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
|
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
|
||||||
@@ -604,9 +636,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
|||||||
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
if registration, exists := r.models[modelID]; exists {
|
if registration, exists := r.models[modelID]; exists {
|
||||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
now := time.Now()
|
||||||
|
registration.QuotaExceededClients[clientID] = &now
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -618,9 +653,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
|||||||
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
if registration, exists := r.models[modelID]; exists {
|
if registration, exists := r.models[modelID]; exists {
|
||||||
delete(registration.QuotaExceededClients, clientID)
|
delete(registration.QuotaExceededClients, clientID)
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -636,6 +673,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
|||||||
}
|
}
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
registration, exists := r.models[modelID]
|
registration, exists := r.models[modelID]
|
||||||
if !exists || registration == nil {
|
if !exists || registration == nil {
|
||||||
@@ -649,6 +687,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
|||||||
}
|
}
|
||||||
registration.SuspendedClients[clientID] = reason
|
registration.SuspendedClients[clientID] = reason
|
||||||
registration.LastUpdated = time.Now()
|
registration.LastUpdated = time.Now()
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
if reason != "" {
|
if reason != "" {
|
||||||
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
||||||
} else {
|
} else {
|
||||||
@@ -666,6 +705,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
|||||||
}
|
}
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
registration, exists := r.models[modelID]
|
registration, exists := r.models[modelID]
|
||||||
if !exists || registration == nil || registration.SuspendedClients == nil {
|
if !exists || registration == nil || registration.SuspendedClients == nil {
|
||||||
@@ -676,6 +716,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
|||||||
}
|
}
|
||||||
delete(registration.SuspendedClients, clientID)
|
delete(registration.SuspendedClients, clientID)
|
||||||
registration.LastUpdated = time.Now()
|
registration.LastUpdated = time.Now()
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -711,22 +752,52 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - []map[string]any: List of available models in the requested format
|
// - []map[string]any: List of available models in the requested format
|
||||||
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
|
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
|
||||||
r.mutex.RLock()
|
|
||||||
defer r.mutex.RUnlock()
|
|
||||||
|
|
||||||
models := make([]map[string]any, 0)
|
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
|
||||||
|
|
||||||
for _, registration := range r.models {
|
|
||||||
// Check if model has any non-quota-exceeded clients
|
|
||||||
availableClients := registration.Count
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// Count clients that have exceeded quota but haven't recovered yet
|
r.mutex.RLock()
|
||||||
|
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||||
|
models := cloneModelMaps(cache.models)
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
|
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||||
|
return cloneModelMaps(cache.models)
|
||||||
|
}
|
||||||
|
|
||||||
|
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
|
||||||
|
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
|
||||||
|
models: cloneModelMaps(models),
|
||||||
|
expiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
|
||||||
|
models := make([]map[string]any, 0, len(r.models))
|
||||||
|
quotaExpiredDuration := 5 * time.Minute
|
||||||
|
var expiresAt time.Time
|
||||||
|
|
||||||
|
for _, registration := range r.models {
|
||||||
|
availableClients := registration.Count
|
||||||
|
|
||||||
expiredClients := 0
|
expiredClients := 0
|
||||||
for _, quotaTime := range registration.QuotaExceededClients {
|
for _, quotaTime := range registration.QuotaExceededClients {
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
if quotaTime == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
recoveryAt := quotaTime.Add(quotaExpiredDuration)
|
||||||
|
if now.Before(recoveryAt) {
|
||||||
expiredClients++
|
expiredClients++
|
||||||
|
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
|
||||||
|
expiresAt = recoveryAt
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -747,7 +818,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
|||||||
effectiveClients = 0
|
effectiveClients = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Include models that have available clients, or those solely cooling down.
|
|
||||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||||
model := r.convertModelToMap(registration.Info, handlerType)
|
model := r.convertModelToMap(registration.Info, handlerType)
|
||||||
if model != nil {
|
if model != nil {
|
||||||
@@ -756,7 +826,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return models
|
return models, expiresAt
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneModelMaps(models []map[string]any) []map[string]any {
|
||||||
|
cloned := make([]map[string]any, 0, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil {
|
||||||
|
cloned = append(cloned, nil)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
copyModel := make(map[string]any, len(model))
|
||||||
|
for key, value := range model {
|
||||||
|
copyModel[key] = cloneModelMapValue(value)
|
||||||
|
}
|
||||||
|
cloned = append(cloned, copyModel)
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneModelMapValue(value any) any {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
copyMap := make(map[string]any, len(typed))
|
||||||
|
for key, entry := range typed {
|
||||||
|
copyMap[key] = cloneModelMapValue(entry)
|
||||||
|
}
|
||||||
|
return copyMap
|
||||||
|
case []any:
|
||||||
|
copySlice := make([]any, len(typed))
|
||||||
|
for i, entry := range typed {
|
||||||
|
copySlice[i] = cloneModelMapValue(entry)
|
||||||
|
}
|
||||||
|
return copySlice
|
||||||
|
case []string:
|
||||||
|
return append([]string(nil), typed...)
|
||||||
|
default:
|
||||||
|
return value
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
||||||
@@ -872,11 +979,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
|||||||
|
|
||||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||||
if entry.info != nil {
|
if entry.info != nil {
|
||||||
result = append(result, entry.info)
|
result = append(result, cloneModelInfo(entry.info))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if ok && registration != nil && registration.Info != nil {
|
if ok && registration != nil && registration.Info != nil {
|
||||||
result = append(result, registration.Info)
|
result = append(result, cloneModelInfo(registration.Info))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -985,13 +1092,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
|||||||
if reg.Providers != nil {
|
if reg.Providers != nil {
|
||||||
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||||
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||||
return info
|
return cloneModelInfo(info)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Fallback to global info (last registered)
|
// Fallback to global info (last registered)
|
||||||
return reg.Info
|
return cloneModelInfo(reg.Info)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1031,7 +1138,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
result["max_completion_tokens"] = model.MaxCompletionTokens
|
result["max_completion_tokens"] = model.MaxCompletionTokens
|
||||||
}
|
}
|
||||||
if len(model.SupportedParameters) > 0 {
|
if len(model.SupportedParameters) > 0 {
|
||||||
result["supported_parameters"] = model.SupportedParameters
|
result["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -1075,13 +1182,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
result["outputTokenLimit"] = model.OutputTokenLimit
|
result["outputTokenLimit"] = model.OutputTokenLimit
|
||||||
}
|
}
|
||||||
if len(model.SupportedGenerationMethods) > 0 {
|
if len(model.SupportedGenerationMethods) > 0 {
|
||||||
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
|
result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
|
||||||
}
|
}
|
||||||
if len(model.SupportedInputModalities) > 0 {
|
if len(model.SupportedInputModalities) > 0 {
|
||||||
result["supportedInputModalities"] = model.SupportedInputModalities
|
result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
|
||||||
}
|
}
|
||||||
if len(model.SupportedOutputModalities) > 0 {
|
if len(model.SupportedOutputModalities) > 0 {
|
||||||
result["supportedOutputModalities"] = model.SupportedOutputModalities
|
result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -1111,15 +1218,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
|||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
quotaExpiredDuration := 5 * time.Minute
|
||||||
|
invalidated := false
|
||||||
|
|
||||||
for modelID, registration := range r.models {
|
for modelID, registration := range r.models {
|
||||||
for clientID, quotaTime := range registration.QuotaExceededClients {
|
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
||||||
delete(registration.QuotaExceededClients, clientID)
|
delete(registration.QuotaExceededClients, clientID)
|
||||||
|
invalidated = true
|
||||||
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if invalidated {
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFirstAvailableModel returns the first available model for the given handler type.
|
// GetFirstAvailableModel returns the first available model for the given handler type.
|
||||||
@@ -1133,8 +1245,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
|||||||
// - string: The model ID of the first available model, or empty string if none available
|
// - string: The model ID of the first available model, or empty string if none available
|
||||||
// - error: An error if no models are available
|
// - error: An error if no models are available
|
||||||
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
|
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
|
||||||
r.mutex.RLock()
|
|
||||||
defer r.mutex.RUnlock()
|
|
||||||
|
|
||||||
// Get all available models for this handler type
|
// Get all available models for this handler type
|
||||||
models := r.GetAvailableModels(handlerType)
|
models := r.GetAvailableModels(handlerType)
|
||||||
@@ -1194,13 +1304,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
|
|||||||
// Prefer client's own model info to preserve original type/owned_by
|
// Prefer client's own model info to preserve original type/owned_by
|
||||||
if clientInfos != nil {
|
if clientInfos != nil {
|
||||||
if info, ok := clientInfos[modelID]; ok && info != nil {
|
if info, ok := clientInfos[modelID]; ok && info != nil {
|
||||||
result = append(result, info)
|
result = append(result, cloneModelInfo(info))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Fallback to global registry (for backwards compatibility)
|
// Fallback to global registry (for backwards compatibility)
|
||||||
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
||||||
result = append(result, reg.Info)
|
result = append(result, cloneModelInfo(reg.Info))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||||
|
|
||||||
|
first := r.GetAvailableModels("openai")
|
||||||
|
if len(first) != 1 {
|
||||||
|
t.Fatalf("expected 1 model, got %d", len(first))
|
||||||
|
}
|
||||||
|
first[0]["id"] = "mutated"
|
||||||
|
first[0]["display_name"] = "Mutated"
|
||||||
|
|
||||||
|
second := r.GetAvailableModels("openai")
|
||||||
|
if got := second[0]["id"]; got != "m1" {
|
||||||
|
t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
|
||||||
|
}
|
||||||
|
if got := second[0]["display_name"]; got != "Model One" {
|
||||||
|
t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||||
|
|
||||||
|
models := r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 1 {
|
||||||
|
t.Fatalf("expected 1 model, got %d", len(models))
|
||||||
|
}
|
||||||
|
if got := models[0]["display_name"]; got != "Model One" {
|
||||||
|
t.Fatalf("expected initial display_name Model One, got %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
|
||||||
|
models = r.GetAvailableModels("openai")
|
||||||
|
if got := models[0]["display_name"]; got != "Model One Updated" {
|
||||||
|
t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.SuspendClientModel("client-1", "m1", "manual")
|
||||||
|
models = r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 0 {
|
||||||
|
t.Fatalf("expected no available models after suspension, got %d", len(models))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ResumeClientModel("client-1", "m1")
|
||||||
|
models = r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 1 {
|
||||||
|
t.Fatalf("expected model to reappear after resume, got %d", len(models))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetModelInfoReturnsClone(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
Thinking: &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetModelInfo("m1", "gemini")
|
||||||
|
if first == nil {
|
||||||
|
t.Fatal("expected model info")
|
||||||
|
}
|
||||||
|
first.DisplayName = "mutated"
|
||||||
|
first.Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetModelInfo("m1", "gemini")
|
||||||
|
if second.DisplayName != "Model One" {
|
||||||
|
t.Fatalf("expected cloned display name, got %q", second.DisplayName)
|
||||||
|
}
|
||||||
|
if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
|
||||||
|
t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelsForClientReturnsClones(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetModelsForClient("client-1")
|
||||||
|
if len(first) != 1 || first[0] == nil {
|
||||||
|
t.Fatalf("expected one model, got %+v", first)
|
||||||
|
}
|
||||||
|
first[0].DisplayName = "mutated"
|
||||||
|
first[0].Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetModelsForClient("client-1")
|
||||||
|
if len(second) != 1 || second[0] == nil {
|
||||||
|
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||||
|
}
|
||||||
|
if second[0].DisplayName != "Model One" {
|
||||||
|
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||||
|
}
|
||||||
|
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||||
|
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetAvailableModelsByProvider("gemini")
|
||||||
|
if len(first) != 1 || first[0] == nil {
|
||||||
|
t.Fatalf("expected one model, got %+v", first)
|
||||||
|
}
|
||||||
|
first[0].DisplayName = "mutated"
|
||||||
|
first[0].Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetAvailableModelsByProvider("gemini")
|
||||||
|
if len(second) != 1 || second[0] == nil {
|
||||||
|
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||||
|
}
|
||||||
|
if second[0].DisplayName != "Model One" {
|
||||||
|
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||||
|
}
|
||||||
|
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||||
|
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
|
||||||
|
r.SetModelQuotaExceeded("client-1", "m1")
|
||||||
|
if models := r.GetAvailableModels("openai"); len(models) != 1 {
|
||||||
|
t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mutex.Lock()
|
||||||
|
quotaTime := time.Now().Add(-6 * time.Minute)
|
||||||
|
r.models["m1"].QuotaExceededClients["client-1"] = "aTime
|
||||||
|
r.mutex.Unlock()
|
||||||
|
|
||||||
|
r.CleanupExpiredQuotas()
|
||||||
|
|
||||||
|
if count := r.GetModelCount("m1"); count != 1 {
|
||||||
|
t.Fatalf("expected model count 1 after cleanup, got %d", count)
|
||||||
|
}
|
||||||
|
models := r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 1 {
|
||||||
|
t.Fatalf("expected model to stay available after cleanup, got %d", len(models))
|
||||||
|
}
|
||||||
|
if got := models[0]["id"]; got != "m1" {
|
||||||
|
t.Fatalf("expected model id m1, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "openai", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
SupportedParameters: []string{"temperature", "top_p"},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetAvailableModels("openai")
|
||||||
|
if len(first) != 1 {
|
||||||
|
t.Fatalf("expected one model, got %d", len(first))
|
||||||
|
}
|
||||||
|
params, ok := first[0]["supported_parameters"].([]string)
|
||||||
|
if !ok || len(params) != 2 {
|
||||||
|
t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
|
||||||
|
}
|
||||||
|
params[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetAvailableModels("openai")
|
||||||
|
params, ok = second[0]["supported_parameters"].([]string)
|
||||||
|
if !ok || len(params) != 2 || params[0] != "temperature" {
|
||||||
|
t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
|
||||||
|
first := LookupModelInfo("glm-4.6")
|
||||||
|
if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
|
||||||
|
t.Fatalf("expected static model with thinking levels, got %+v", first)
|
||||||
|
}
|
||||||
|
first.Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := LookupModelInfo("glm-4.6")
|
||||||
|
if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
|
||||||
|
t.Fatalf("expected static lookup clone, got %+v", second)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,198 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
_ "embed"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
modelsFetchTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var modelsURLs = []string{
|
||||||
|
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
|
||||||
|
"https://models.router-for.me/models.json",
|
||||||
|
}
|
||||||
|
|
||||||
|
//go:embed models/models.json
|
||||||
|
var embeddedModelsJSON []byte
|
||||||
|
|
||||||
|
type modelStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
data *staticModelsJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
var modelsCatalogStore = &modelStore{}
|
||||||
|
|
||||||
|
var updaterOnce sync.Once
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Load embedded data as fallback on startup.
|
||||||
|
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
|
||||||
|
panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartModelsUpdater runs a one-time models refresh on startup.
|
||||||
|
// It blocks until the startup fetch attempt finishes so service initialization
|
||||||
|
// can wait for the refreshed catalog before registering auth-backed models.
|
||||||
|
// Safe to call multiple times; only one refresh will run.
|
||||||
|
func StartModelsUpdater(ctx context.Context) {
|
||||||
|
updaterOnce.Do(func() {
|
||||||
|
runModelsUpdater(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func runModelsUpdater(ctx context.Context) {
|
||||||
|
// Try network fetch once on startup, then stop.
|
||||||
|
// Periodic refresh is disabled - models are only refreshed at startup.
|
||||||
|
tryRefreshModels(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func tryRefreshModels(ctx context.Context) {
|
||||||
|
client := &http.Client{Timeout: modelsFetchTimeout}
|
||||||
|
for _, url := range modelsURLs {
|
||||||
|
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
|
||||||
|
req, err := http.NewRequestWithContext(reqCtx, "GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
log.Debugf("models fetch request creation failed for %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
log.Debugf("models fetch failed from %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
resp.Body.Close()
|
||||||
|
cancel()
|
||||||
|
log.Debugf("models fetch returned %d from %s", resp.StatusCode, url)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("models fetch read error from %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := loadModelsFromBytes(data, url); err != nil {
|
||||||
|
log.Warnf("models parse failed from %s: %v", url, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("models updated from %s", url)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Warn("models refresh failed from all URLs, using current data")
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadModelsFromBytes(data []byte, source string) error {
|
||||||
|
var parsed staticModelsJSON
|
||||||
|
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||||
|
return fmt.Errorf("%s: decode models catalog: %w", source, err)
|
||||||
|
}
|
||||||
|
if err := validateModelsCatalog(&parsed); err != nil {
|
||||||
|
return fmt.Errorf("%s: validate models catalog: %w", source, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelsCatalogStore.mu.Lock()
|
||||||
|
modelsCatalogStore.data = &parsed
|
||||||
|
modelsCatalogStore.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getModels() *staticModelsJSON {
|
||||||
|
modelsCatalogStore.mu.RLock()
|
||||||
|
defer modelsCatalogStore.mu.RUnlock()
|
||||||
|
return modelsCatalogStore.data
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateModelsCatalog(data *staticModelsJSON) error {
|
||||||
|
if data == nil {
|
||||||
|
return fmt.Errorf("catalog is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
requiredSections := []struct {
|
||||||
|
name string
|
||||||
|
models []*ModelInfo
|
||||||
|
}{
|
||||||
|
{name: "claude", models: data.Claude},
|
||||||
|
{name: "gemini", models: data.Gemini},
|
||||||
|
{name: "vertex", models: data.Vertex},
|
||||||
|
{name: "gemini-cli", models: data.GeminiCLI},
|
||||||
|
{name: "aistudio", models: data.AIStudio},
|
||||||
|
{name: "codex-free", models: data.CodexFree},
|
||||||
|
{name: "codex-team", models: data.CodexTeam},
|
||||||
|
{name: "codex-plus", models: data.CodexPlus},
|
||||||
|
{name: "codex-pro", models: data.CodexPro},
|
||||||
|
{name: "qwen", models: data.Qwen},
|
||||||
|
{name: "iflow", models: data.IFlow},
|
||||||
|
{name: "kimi", models: data.Kimi},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, section := range requiredSections {
|
||||||
|
if err := validateModelSection(section.name, section.models); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := validateAntigravitySection(data.Antigravity); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateModelSection(section string, models []*ModelInfo) error {
|
||||||
|
if len(models) == 0 {
|
||||||
|
return fmt.Errorf("%s section is empty", section)
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]struct{}, len(models))
|
||||||
|
for i, model := range models {
|
||||||
|
if model == nil {
|
||||||
|
return fmt.Errorf("%s[%d] is null", section, i)
|
||||||
|
}
|
||||||
|
modelID := strings.TrimSpace(model.ID)
|
||||||
|
if modelID == "" {
|
||||||
|
return fmt.Errorf("%s[%d] has empty id", section, i)
|
||||||
|
}
|
||||||
|
if _, exists := seen[modelID]; exists {
|
||||||
|
return fmt.Errorf("%s contains duplicate model id %q", section, modelID)
|
||||||
|
}
|
||||||
|
seen[modelID] = struct{}{}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateAntigravitySection(configs map[string]*AntigravityModelConfig) error {
|
||||||
|
if len(configs) == 0 {
|
||||||
|
return fmt.Errorf("antigravity section is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
for modelID, cfg := range configs {
|
||||||
|
trimmedID := strings.TrimSpace(modelID)
|
||||||
|
if trimmedID == "" {
|
||||||
|
return fmt.Errorf("antigravity contains empty model id")
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
return fmt.Errorf("antigravity[%q] is null", trimmedID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1266,6 +1266,10 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
} else if system.Type == gjson.String && system.String() != "" {
|
||||||
|
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
|
||||||
|
partJSON, _ = sjson.Set(partJSON, "text", system.String())
|
||||||
|
result += "," + partJSON
|
||||||
}
|
}
|
||||||
result += "]"
|
result += "]"
|
||||||
|
|
||||||
@@ -1485,25 +1489,27 @@ func countCacheControlsMap(root map[string]any) int {
|
|||||||
return count
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) {
|
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
|
||||||
ccRaw, exists := obj["cache_control"]
|
ccRaw, exists := obj["cache_control"]
|
||||||
if !exists {
|
if !exists {
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
cc, ok := asObject(ccRaw)
|
cc, ok := asObject(ccRaw)
|
||||||
if !ok {
|
if !ok {
|
||||||
*seen5m = true
|
*seen5m = true
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
ttlRaw, ttlExists := cc["ttl"]
|
ttlRaw, ttlExists := cc["ttl"]
|
||||||
ttl, ttlIsString := ttlRaw.(string)
|
ttl, ttlIsString := ttlRaw.(string)
|
||||||
if !ttlExists || !ttlIsString || ttl != "1h" {
|
if !ttlExists || !ttlIsString || ttl != "1h" {
|
||||||
*seen5m = true
|
*seen5m = true
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
if *seen5m {
|
if *seen5m {
|
||||||
delete(cc, "ttl")
|
delete(cc, "ttl")
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func findLastCacheControlIndex(arr []any) int {
|
func findLastCacheControlIndex(arr []any) int {
|
||||||
@@ -1599,11 +1605,14 @@ func normalizeCacheControlTTL(payload []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
seen5m := false
|
seen5m := false
|
||||||
|
modified := false
|
||||||
|
|
||||||
if tools, ok := asArray(root["tools"]); ok {
|
if tools, ok := asArray(root["tools"]); ok {
|
||||||
for _, tool := range tools {
|
for _, tool := range tools {
|
||||||
if obj, ok := asObject(tool); ok {
|
if obj, ok := asObject(tool); ok {
|
||||||
normalizeTTLForBlock(obj, &seen5m)
|
if normalizeTTLForBlock(obj, &seen5m) {
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1611,7 +1620,9 @@ func normalizeCacheControlTTL(payload []byte) []byte {
|
|||||||
if system, ok := asArray(root["system"]); ok {
|
if system, ok := asArray(root["system"]); ok {
|
||||||
for _, item := range system {
|
for _, item := range system {
|
||||||
if obj, ok := asObject(item); ok {
|
if obj, ok := asObject(item); ok {
|
||||||
normalizeTTLForBlock(obj, &seen5m)
|
if normalizeTTLForBlock(obj, &seen5m) {
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1628,12 +1639,17 @@ func normalizeCacheControlTTL(payload []byte) []byte {
|
|||||||
}
|
}
|
||||||
for _, item := range content {
|
for _, item := range content {
|
||||||
if obj, ok := asObject(item); ok {
|
if obj, ok := asObject(item); ok {
|
||||||
normalizeTTLForBlock(obj, &seen5m)
|
if normalizeTTLForBlock(obj, &seen5m) {
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !modified {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
return marshalPayloadObject(payload, root)
|
return marshalPayloadObject(payload, root)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -369,6 +369,19 @@ func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.T) {
|
||||||
|
// Payload where no TTL normalization is needed (all blocks use 1h with no
|
||||||
|
// preceding 5m block). The text intentionally contains HTML chars (<, >, &)
|
||||||
|
// that json.Marshal would escape to \u003c etc., altering byte identity.
|
||||||
|
payload := []byte(`{"tools":[{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],"system":[{"type":"text","text":"<system-reminder>foo & bar</system-reminder>","cache_control":{"type":"ephemeral","ttl":"1h"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
|
||||||
|
|
||||||
|
out := normalizeCacheControlTTL(payload)
|
||||||
|
|
||||||
|
if !bytes.Equal(out, payload) {
|
||||||
|
t.Fatalf("normalizeCacheControlTTL altered bytes when no change was needed.\noriginal: %s\ngot: %s", payload, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
|
||||||
payload := []byte(`{
|
payload := []byte(`{
|
||||||
"tools": [
|
"tools": [
|
||||||
@@ -967,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
|
|||||||
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test case 1: String system prompt is preserved and converted to a content block
|
||||||
|
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
|
||||||
|
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
out := checkSystemInstructionsWithMode(payload, false)
|
||||||
|
|
||||||
|
system := gjson.GetBytes(out, "system")
|
||||||
|
if !system.IsArray() {
|
||||||
|
t.Fatalf("system should be an array, got %s", system.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
blocks := system.Array()
|
||||||
|
if len(blocks) != 3 {
|
||||||
|
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
|
||||||
|
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
|
||||||
|
}
|
||||||
|
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
|
||||||
|
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
|
||||||
|
}
|
||||||
|
if blocks[2].Get("text").String() != "You are a helpful assistant." {
|
||||||
|
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
|
||||||
|
}
|
||||||
|
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
|
||||||
|
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 2: Strict mode drops the string system prompt
|
||||||
|
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
|
||||||
|
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
out := checkSystemInstructionsWithMode(payload, true)
|
||||||
|
|
||||||
|
blocks := gjson.GetBytes(out, "system").Array()
|
||||||
|
if len(blocks) != 2 {
|
||||||
|
t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 3: Empty string system prompt does not produce a spurious block
|
||||||
|
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
|
||||||
|
payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
out := checkSystemInstructionsWithMode(payload, false)
|
||||||
|
|
||||||
|
blocks := gjson.GetBytes(out, "system").Array()
|
||||||
|
if len(blocks) != 2 {
|
||||||
|
t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 4: Array system prompt is unaffected by the string handling
|
||||||
|
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
|
||||||
|
payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
out := checkSystemInstructionsWithMode(payload, false)
|
||||||
|
|
||||||
|
blocks := gjson.GetBytes(out, "system").Array()
|
||||||
|
if len(blocks) != 3 {
|
||||||
|
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
if blocks[2].Get("text").String() != "Be concise." {
|
||||||
|
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 5: Special characters in string system prompt survive conversion
|
||||||
|
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
|
||||||
|
payload := []byte(`{"system":"Use <xml> tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
|
||||||
|
|
||||||
|
out := checkSystemInstructionsWithMode(payload, false)
|
||||||
|
|
||||||
|
blocks := gjson.GetBytes(out, "system").Array()
|
||||||
|
if len(blocks) != 3 {
|
||||||
|
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
if blocks[2].Get("text").String() != `Use <xml> tags & "quotes" in output.` {
|
||||||
|
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -705,21 +706,30 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
|
|||||||
return dialer
|
return dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedURL, errParse := url.Parse(proxyURL)
|
setting, errParse := proxyutil.Parse(proxyURL)
|
||||||
if errParse != nil {
|
if errParse != nil {
|
||||||
log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse)
|
log.Errorf("codex websockets executor: %v", errParse)
|
||||||
return dialer
|
return dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
switch parsedURL.Scheme {
|
switch setting.Mode {
|
||||||
|
case proxyutil.ModeDirect:
|
||||||
|
dialer.Proxy = nil
|
||||||
|
return dialer
|
||||||
|
case proxyutil.ModeProxy:
|
||||||
|
default:
|
||||||
|
return dialer
|
||||||
|
}
|
||||||
|
|
||||||
|
switch setting.URL.Scheme {
|
||||||
case "socks5":
|
case "socks5":
|
||||||
var proxyAuth *proxy.Auth
|
var proxyAuth *proxy.Auth
|
||||||
if parsedURL.User != nil {
|
if setting.URL.User != nil {
|
||||||
username := parsedURL.User.Username()
|
username := setting.URL.User.Username()
|
||||||
password, _ := parsedURL.User.Password()
|
password, _ := setting.URL.User.Password()
|
||||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||||
}
|
}
|
||||||
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
|
socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
|
||||||
if errSOCKS5 != nil {
|
if errSOCKS5 != nil {
|
||||||
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
|
log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||||
return dialer
|
return dialer
|
||||||
@@ -729,9 +739,9 @@ func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *
|
|||||||
return socksDialer.Dial(network, addr)
|
return socksDialer.Dial(network, addr)
|
||||||
}
|
}
|
||||||
case "http", "https":
|
case "http", "https":
|
||||||
dialer.Proxy = http.ProxyURL(parsedURL)
|
dialer.Proxy = http.ProxyURL(setting.URL)
|
||||||
default:
|
default:
|
||||||
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme)
|
log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
|
||||||
}
|
}
|
||||||
|
|
||||||
return dialer
|
return dialer
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -187,3 +188,16 @@ func contextWithGinHeaders(headers map[string]string) context.Context {
|
|||||||
}
|
}
|
||||||
return context.WithValue(context.Background(), "gin", ginCtx)
|
return context.WithValue(context.Background(), "gin", ginCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dialer := newProxyAwareWebsocketDialer(
|
||||||
|
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||||
|
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||||
|
)
|
||||||
|
|
||||||
|
if dialer.Proxy != nil {
|
||||||
|
t.Fatal("expected websocket proxy function to be nil for direct mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -460,7 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
|
|
||||||
// For API key auth, use simpler URL format without project/location
|
// For API key auth, use simpler URL format without project/location
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
baseURL = "https://aiplatform.googleapis.com"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||||
if opts.Alt != "" && action != "countTokens" {
|
if opts.Alt != "" && action != "countTokens" {
|
||||||
@@ -683,7 +683,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
action := getVertexAction(baseModel, true)
|
action := getVertexAction(baseModel, true)
|
||||||
// For API key auth, use simpler URL format without project/location
|
// For API key auth, use simpler URL format without project/location
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
baseURL = "https://aiplatform.googleapis.com"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||||
// Imagen models don't support streaming, skip SSE params
|
// Imagen models don't support streaming, skip SSE params
|
||||||
@@ -883,7 +883,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
|||||||
|
|
||||||
// For API key auth, use simpler URL format without project/location
|
// For API key auth, use simpler URL format without project/location
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
baseURL = "https://aiplatform.googleapis.com"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens")
|
||||||
|
|
||||||
|
|||||||
@@ -2,16 +2,14 @@ package executor
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
|
||||||
@@ -72,45 +70,10 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid
|
// - *http.Transport: A configured transport, or nil if the proxy URL is invalid
|
||||||
func buildProxyTransport(proxyURL string) *http.Transport {
|
func buildProxyTransport(proxyURL string) *http.Transport {
|
||||||
if proxyURL == "" {
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL)
|
||||||
|
if errBuild != nil {
|
||||||
|
log.Errorf("%v", errBuild)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedURL, errParse := url.Parse(proxyURL)
|
|
||||||
if errParse != nil {
|
|
||||||
log.Errorf("parse proxy URL failed: %v", errParse)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var transport *http.Transport
|
|
||||||
|
|
||||||
// Handle different proxy schemes
|
|
||||||
if parsedURL.Scheme == "socks5" {
|
|
||||||
// Configure SOCKS5 proxy with optional authentication
|
|
||||||
var proxyAuth *proxy.Auth
|
|
||||||
if parsedURL.User != nil {
|
|
||||||
username := parsedURL.User.Username()
|
|
||||||
password, _ := parsedURL.User.Password()
|
|
||||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
|
||||||
}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// Set up a custom transport using the SOCKS5 dialer
|
|
||||||
transport = &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else if parsedURL.Scheme == "http" || parsedURL.Scheme == "https" {
|
|
||||||
// Configure HTTP or HTTPS proxy
|
|
||||||
transport = &http.Transport{Proxy: http.ProxyURL(parsedURL)}
|
|
||||||
} else {
|
|
||||||
log.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return transport
|
return transport
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
client := newProxyAwareHTTPClient(
|
||||||
|
context.Background(),
|
||||||
|
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
|
||||||
|
&cliproxyauth.Auth{ProxyURL: "direct"},
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
transport, ok := client.Transport.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
|
||||||
|
}
|
||||||
|
if transport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -257,8 +257,11 @@ func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromForma
|
|||||||
if suffixResult.HasSuffix {
|
if suffixResult.HasSuffix {
|
||||||
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
|
config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
|
||||||
} else {
|
} else {
|
||||||
|
config = extractThinkingConfig(body, fromFormat)
|
||||||
|
if !hasThinkingConfig(config) && fromFormat != toFormat {
|
||||||
config = extractThinkingConfig(body, toFormat)
|
config = extractThinkingConfig(body, toFormat)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !hasThinkingConfig(config) {
|
if !hasThinkingConfig(config) {
|
||||||
log.WithFields(log.Fields{
|
log.WithFields(log.Fields{
|
||||||
@@ -293,6 +296,9 @@ func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat stri
|
|||||||
if config.Mode != ModeLevel {
|
if config.Mode != ModeLevel {
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
if toFormat == "claude" {
|
||||||
|
return config
|
||||||
|
}
|
||||||
if !isBudgetCapableProvider(toFormat) {
|
if !isBudgetCapableProvider(toFormat) {
|
||||||
return config
|
return config
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package thinking_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) {
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
clientID := "test-user-defined-claude-" + t.Name()
|
||||||
|
modelID := "custom-claude-4-6"
|
||||||
|
reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
reg.UnregisterClient(clientID)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model string
|
||||||
|
body []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "claude adaptive effort body",
|
||||||
|
model: modelID,
|
||||||
|
body: []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "suffix level",
|
||||||
|
model: modelID + "(high)",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
out, err := thinking.ApplyThinking(tt.body, tt.model, "openai", "claude", "claude")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ApplyThinking() error = %v", err)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "thinking.type").String(); got != "adaptive" {
|
||||||
|
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "adaptive", string(out))
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(out, "output_config.effort").String(); got != "high" {
|
||||||
|
t.Fatalf("output_config.effort = %q, want %q, body=%s", got, "high", string(out))
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
|
||||||
|
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -477,9 +477,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
|||||||
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
effort = strings.ToLower(strings.TrimSpace(v.String()))
|
||||||
}
|
}
|
||||||
if effort != "" {
|
if effort != "" {
|
||||||
if effort == "max" {
|
|
||||||
effort = "high"
|
|
||||||
}
|
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
|
||||||
} else {
|
} else {
|
||||||
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
|
||||||
|
|||||||
@@ -1235,64 +1235,3 @@ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *t
|
|||||||
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
|
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_EffortLevels(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
effort string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"low", "low", "low"},
|
|
||||||
{"medium", "medium", "medium"},
|
|
||||||
{"high", "high", "high"},
|
|
||||||
{"max", "max", "high"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
inputJSON := []byte(`{
|
|
||||||
"model": "claude-opus-4-6-thinking",
|
|
||||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
|
||||||
"thinking": {"type": "adaptive"},
|
|
||||||
"output_config": {"effort": "` + tt.effort + `"}
|
|
||||||
}`)
|
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
|
|
||||||
outputStr := string(output)
|
|
||||||
|
|
||||||
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
|
|
||||||
if !thinkingConfig.Exists() {
|
|
||||||
t.Fatal("thinkingConfig should exist for adaptive thinking")
|
|
||||||
}
|
|
||||||
if thinkingConfig.Get("thinkingLevel").String() != tt.expected {
|
|
||||||
t.Errorf("Expected thinkingLevel %q, got %q", tt.expected, thinkingConfig.Get("thinkingLevel").String())
|
|
||||||
}
|
|
||||||
if !thinkingConfig.Get("includeThoughts").Bool() {
|
|
||||||
t.Error("includeThoughts should be true")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConvertClaudeRequestToAntigravity_AdaptiveThinking_NoEffort(t *testing.T) {
|
|
||||||
inputJSON := []byte(`{
|
|
||||||
"model": "claude-opus-4-6-thinking",
|
|
||||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
|
||||||
"thinking": {"type": "adaptive"}
|
|
||||||
}`)
|
|
||||||
|
|
||||||
output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, false)
|
|
||||||
outputStr := string(output)
|
|
||||||
|
|
||||||
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
|
|
||||||
if !thinkingConfig.Exists() {
|
|
||||||
t.Fatal("thinkingConfig should exist for adaptive thinking without effort")
|
|
||||||
}
|
|
||||||
if thinkingConfig.Get("thinkingLevel").String() != "high" {
|
|
||||||
t.Errorf("Expected default thinkingLevel \"high\", got %q", thinkingConfig.Get("thinkingLevel").String())
|
|
||||||
}
|
|
||||||
if !thinkingConfig.Get("includeThoughts").Bool() {
|
|
||||||
t.Error("includeThoughts should be true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -256,7 +257,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
|||||||
|
|
||||||
// Create the tool use block with unique ID and function details
|
// Create the tool use block with unique ID and function details
|
||||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
|
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
|
||||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
|
||||||
|
|||||||
@@ -43,23 +43,32 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
|
|
||||||
// Process system messages and convert them to input content format.
|
// Process system messages and convert them to input content format.
|
||||||
systemsResult := rootResult.Get("system")
|
systemsResult := rootResult.Get("system")
|
||||||
if systemsResult.IsArray() {
|
if systemsResult.Exists() {
|
||||||
systemResults := systemsResult.Array()
|
|
||||||
message := `{"type":"message","role":"developer","content":[]}`
|
message := `{"type":"message","role":"developer","content":[]}`
|
||||||
contentIndex := 0
|
contentIndex := 0
|
||||||
for i := 0; i < len(systemResults); i++ {
|
|
||||||
systemResult := systemResults[i]
|
appendSystemText := func(text string) {
|
||||||
systemTypeResult := systemResult.Get("type")
|
if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
|
||||||
if systemTypeResult.String() == "text" {
|
return
|
||||||
text := systemResult.Get("text").String()
|
|
||||||
if strings.HasPrefix(text, "x-anthropic-billing-header: ") {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
|
||||||
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
|
||||||
contentIndex++
|
contentIndex++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if systemsResult.Type == gjson.String {
|
||||||
|
appendSystemText(systemsResult.String())
|
||||||
|
} else if systemsResult.IsArray() {
|
||||||
|
systemResults := systemsResult.Array()
|
||||||
|
for i := 0; i < len(systemResults); i++ {
|
||||||
|
systemResult := systemResults[i]
|
||||||
|
if systemResult.Get("type").String() == "text" {
|
||||||
|
appendSystemText(systemResult.Get("text").String())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if contentIndex > 0 {
|
if contentIndex > 0 {
|
||||||
template, _ = sjson.SetRaw(template, "input.-1", message)
|
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,89 @@
|
|||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
inputJSON string
|
||||||
|
wantHasDeveloper bool
|
||||||
|
wantTexts []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No system field",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasDeveloper: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Empty string system field",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": "",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasDeveloper: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "String system field",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": "Be helpful",
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasDeveloper: true,
|
||||||
|
wantTexts: []string{"Be helpful"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Array system field with filtered billing header",
|
||||||
|
inputJSON: `{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"system": [
|
||||||
|
{"type": "text", "text": "x-anthropic-billing-header: tenant-123"},
|
||||||
|
{"type": "text", "text": "Block 1"},
|
||||||
|
{"type": "text", "text": "Block 2"}
|
||||||
|
],
|
||||||
|
"messages": [{"role": "user", "content": "hello"}]
|
||||||
|
}`,
|
||||||
|
wantHasDeveloper: true,
|
||||||
|
wantTexts: []string{"Block 1", "Block 2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
|
||||||
|
resultJSON := gjson.ParseBytes(result)
|
||||||
|
inputs := resultJSON.Get("input").Array()
|
||||||
|
|
||||||
|
hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer"
|
||||||
|
if hasDeveloper != tt.wantHasDeveloper {
|
||||||
|
t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.wantHasDeveloper {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
content := inputs[0].Get("content").Array()
|
||||||
|
if len(content) != len(tt.wantTexts) {
|
||||||
|
t.Fatalf("got %d system content items, want %d. Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, wantText := range tt.wantTexts {
|
||||||
|
if gotType := content[i].Get("type").String(); gotType != "input_text" {
|
||||||
|
t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text")
|
||||||
|
}
|
||||||
|
if gotText := content[i].Get("text").String(); gotText != wantText {
|
||||||
|
t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -141,7 +142,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
|||||||
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
|
||||||
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||||
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
|
||||||
template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
|
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
|
||||||
{
|
{
|
||||||
// Restore original tool name if shortened
|
// Restore original tool name if shortened
|
||||||
name := itemResult.Get("name").String()
|
name := itemResult.Get("name").String()
|
||||||
@@ -310,7 +311,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
|||||||
}
|
}
|
||||||
|
|
||||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||||
toolBlock, _ = sjson.Set(toolBlock, "id", item.Get("call_id").String())
|
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
|
||||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||||
inputRaw := "{}"
|
inputRaw := "{}"
|
||||||
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
|
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
@@ -209,7 +210,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
|
|||||||
|
|
||||||
// Create the tool use block with unique ID and function details
|
// Create the tool use block with unique ID and function details
|
||||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
||||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
|
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
|
||||||
|
|||||||
@@ -224,7 +224,7 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
|||||||
|
|
||||||
// Create the tool use block with unique ID and function details
|
// Create the tool use block with unique ID and function details
|
||||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, (*param).(*Params).ResponseIndex)
|
||||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1)))
|
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))))
|
||||||
data, _ = sjson.Set(data, "content_block.name", clientToolName)
|
data, _ = sjson.Set(data, "content_block.name", clientToolName)
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
|
||||||
@@ -343,7 +343,7 @@ func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, origina
|
|||||||
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
|
clientToolName := util.MapToolName(toolNameMap, upstreamToolName)
|
||||||
toolIDCounter++
|
toolIDCounter++
|
||||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter))
|
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, toolIDCounter)))
|
||||||
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
|
toolBlock, _ = sjson.Set(toolBlock, "name", clientToolName)
|
||||||
inputRaw := "{}"
|
inputRaw := "{}"
|
||||||
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
|
if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
|
||||||
|
|||||||
@@ -147,21 +147,21 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
content := m.Get("content")
|
content := m.Get("content")
|
||||||
|
|
||||||
if (role == "system" || role == "developer") && len(arr) > 1 {
|
if (role == "system" || role == "developer") && len(arr) > 1 {
|
||||||
// system -> system_instruction as a user message style
|
// system -> systemInstruction as a user message style
|
||||||
if content.Type == gjson.String {
|
if content.Type == gjson.String {
|
||||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.String())
|
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String())
|
||||||
systemPartIndex++
|
systemPartIndex++
|
||||||
} else if content.IsObject() && content.Get("type").String() == "text" {
|
} else if content.IsObject() && content.Get("type").String() == "text" {
|
||||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.Get("text").String())
|
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
|
||||||
systemPartIndex++
|
systemPartIndex++
|
||||||
} else if content.IsArray() {
|
} else if content.IsArray() {
|
||||||
contents := content.Array()
|
contents := content.Array()
|
||||||
if len(contents) > 0 {
|
if len(contents) > 0 {
|
||||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||||
for j := 0; j < len(contents); j++ {
|
for j := 0; j < len(contents); j++ {
|
||||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
|
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
|
||||||
systemPartIndex++
|
systemPartIndex++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
if instructions := root.Get("instructions"); instructions.Exists() {
|
if instructions := root.Get("instructions"); instructions.Exists() {
|
||||||
systemInstr := `{"parts":[{"text":""}]}`
|
systemInstr := `{"parts":[{"text":""}]}`
|
||||||
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
|
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
|
||||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert input messages to Gemini contents format
|
// Convert input messages to Gemini contents format
|
||||||
@@ -119,7 +119,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
if strings.EqualFold(itemRole, "system") {
|
if strings.EqualFold(itemRole, "system") {
|
||||||
if contentArray := item.Get("content"); contentArray.Exists() {
|
if contentArray := item.Get("content"); contentArray.Exists() {
|
||||||
systemInstr := ""
|
systemInstr := ""
|
||||||
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
|
if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() {
|
||||||
systemInstr = systemInstructionResult.Raw
|
systemInstr = systemInstructionResult.Raw
|
||||||
} else {
|
} else {
|
||||||
systemInstr = `{"parts":[]}`
|
systemInstr = `{"parts":[]}`
|
||||||
@@ -140,7 +140,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
|||||||
}
|
}
|
||||||
|
|
||||||
if systemInstr != `{"parts":[]}` {
|
if systemInstr != `{"parts":[]}` {
|
||||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -243,7 +243,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
|||||||
// Send content_block_start for tool_use
|
// Send content_block_start for tool_use
|
||||||
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
|
||||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
|
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
|
||||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", accumulator.ID)
|
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
|
||||||
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
|
contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
|
||||||
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
|
results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
|
||||||
}
|
}
|
||||||
@@ -414,7 +414,7 @@ func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
|
|||||||
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
|
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
|
||||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||||
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
|
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
|
||||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
|
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
|
||||||
|
|
||||||
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
|
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
|
||||||
@@ -612,7 +612,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
|||||||
toolCalls.ForEach(func(_, tc gjson.Result) bool {
|
toolCalls.ForEach(func(_, tc gjson.Result) bool {
|
||||||
hasToolCall = true
|
hasToolCall = true
|
||||||
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||||
toolUse, _ = sjson.Set(toolUse, "id", tc.Get("id").String())
|
toolUse, _ = sjson.Set(toolUse, "id", util.SanitizeClaudeToolID(tc.Get("id").String()))
|
||||||
toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
|
toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
|
||||||
|
|
||||||
argsStr := util.FixJSON(tc.Get("function.arguments").String())
|
argsStr := util.FixJSON(tc.Get("function.arguments").String())
|
||||||
@@ -669,7 +669,7 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina
|
|||||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||||
hasToolCall = true
|
hasToolCall = true
|
||||||
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", toolCall.Get("id").String())
|
toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
|
||||||
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String()))
|
toolUseBlock, _ = sjson.Set(toolUseBlock, "name", util.MapToolName(toolNameMap, toolCall.Get("function.name").String()))
|
||||||
|
|
||||||
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
|
argsStr := util.FixJSON(toolCall.Get("function.arguments").String())
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
claudeToolUseIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)
|
||||||
|
claudeToolUseIDCounter uint64
|
||||||
|
)
|
||||||
|
|
||||||
|
// SanitizeClaudeToolID ensures the given id conforms to Claude's
|
||||||
|
// tool_use.id regex ^[a-zA-Z0-9_-]+$. Non-conforming characters are
|
||||||
|
// replaced with '_'; an empty result gets a generated fallback.
|
||||||
|
func SanitizeClaudeToolID(id string) string {
|
||||||
|
s := claudeToolUseIDSanitizer.ReplaceAllString(id, "_")
|
||||||
|
if s == "" {
|
||||||
|
s = fmt.Sprintf("toolu_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&claudeToolUseIDCounter, 1))
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
+6
-31
@@ -4,50 +4,25 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetProxy configures the provided HTTP client with proxy settings from the configuration.
|
// SetProxy configures the provided HTTP client with proxy settings from the configuration.
|
||||||
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
|
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
|
||||||
// to route requests through the configured proxy server.
|
// to route requests through the configured proxy server.
|
||||||
func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client {
|
func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client {
|
||||||
var transport *http.Transport
|
if cfg == nil || httpClient == nil {
|
||||||
// Attempt to parse the proxy URL from the configuration.
|
|
||||||
proxyURL, errParse := url.Parse(cfg.ProxyURL)
|
|
||||||
if errParse == nil {
|
|
||||||
// Handle different proxy schemes.
|
|
||||||
if proxyURL.Scheme == "socks5" {
|
|
||||||
// Configure SOCKS5 proxy with optional authentication.
|
|
||||||
var proxyAuth *proxy.Auth
|
|
||||||
if proxyURL.User != nil {
|
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
proxyAuth = &proxy.Auth{User: username, Password: password}
|
|
||||||
}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
||||||
return httpClient
|
return httpClient
|
||||||
}
|
}
|
||||||
// Set up a custom transport using the SOCKS5 dialer.
|
|
||||||
transport = &http.Transport{
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
if errBuild != nil {
|
||||||
return dialer.Dial(network, addr)
|
log.Errorf("%v", errBuild)
|
||||||
},
|
|
||||||
}
|
}
|
||||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
|
||||||
// Configure HTTP or HTTPS proxy.
|
|
||||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// If a new transport was created, apply it to the HTTP client.
|
|
||||||
if transport != nil {
|
if transport != nil {
|
||||||
httpClient.Transport = transport
|
httpClient.Transport = transport
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
)
|
)
|
||||||
@@ -149,6 +150,16 @@ func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
|
||||||
|
// For codex auth files, extract plan_type from the JWT id_token.
|
||||||
|
if provider == "codex" {
|
||||||
|
if idTokenRaw, ok := metadata["id_token"].(string); ok && strings.TrimSpace(idTokenRaw) != "" {
|
||||||
|
if claims, errParse := codex.ParseJWTToken(idTokenRaw); errParse == nil && claims != nil {
|
||||||
|
if pt := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); pt != "" {
|
||||||
|
a.Attributes["plan_type"] = pt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if provider == "gemini-cli" {
|
if provider == "gemini-cli" {
|
||||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||||
for _, v := range virtuals {
|
for _, v := range virtuals {
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ const (
|
|||||||
wsTurnStateHeader = "x-codex-turn-state"
|
wsTurnStateHeader = "x-codex-turn-state"
|
||||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||||
wsPayloadLogMaxSize = 2048
|
wsPayloadLogMaxSize = 2048
|
||||||
|
wsBodyLogMaxSize = 64 * 1024
|
||||||
|
wsBodyLogTruncated = "\n[websocket log truncated]\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||||
@@ -825,18 +827,71 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
|
|||||||
if builder == nil {
|
if builder == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if builder.Len() >= wsBodyLogMaxSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
trimmedPayload := bytes.TrimSpace(payload)
|
trimmedPayload := bytes.TrimSpace(payload)
|
||||||
if len(trimmedPayload) == 0 {
|
if len(trimmedPayload) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if builder.Len() > 0 {
|
if builder.Len() > 0 {
|
||||||
builder.WriteString("\n")
|
if !appendWebsocketLogString(builder, "\n") {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
builder.WriteString("websocket.")
|
}
|
||||||
builder.WriteString(eventType)
|
if !appendWebsocketLogString(builder, "websocket.") {
|
||||||
builder.WriteString("\n")
|
return
|
||||||
builder.Write(trimmedPayload)
|
}
|
||||||
builder.WriteString("\n")
|
if !appendWebsocketLogString(builder, eventType) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !appendWebsocketLogString(builder, "\n") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
|
||||||
|
appendWebsocketLogString(builder, wsBodyLogTruncated)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
appendWebsocketLogString(builder, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
|
||||||
|
if builder == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
remaining := wsBodyLogMaxSize - builder.Len()
|
||||||
|
if remaining <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(value) <= remaining {
|
||||||
|
builder.WriteString(value)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
builder.WriteString(value[:remaining])
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
|
||||||
|
if builder == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
remaining := wsBodyLogMaxSize - builder.Len()
|
||||||
|
if remaining <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(value) <= remaining {
|
||||||
|
builder.Write(value)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
limit := remaining - reserveForSuffix
|
||||||
|
if limit < 0 {
|
||||||
|
limit = 0
|
||||||
|
}
|
||||||
|
if limit > len(value) {
|
||||||
|
limit = len(value)
|
||||||
|
}
|
||||||
|
builder.Write(value[:limit])
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func websocketPayloadEventType(payload []byte) string {
|
func websocketPayloadEventType(payload []byte) string {
|
||||||
|
|||||||
@@ -266,6 +266,33 @@ func TestAppendWebsocketEvent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
|
||||||
|
|
||||||
|
appendWebsocketEvent(&builder, "request", payload)
|
||||||
|
|
||||||
|
got := builder.String()
|
||||||
|
if len(got) > wsBodyLogMaxSize {
|
||||||
|
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, wsBodyLogTruncated) {
|
||||||
|
t.Fatalf("expected truncation marker in body log")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
|
||||||
|
initial := builder.String()
|
||||||
|
|
||||||
|
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
|
||||||
|
|
||||||
|
if builder.String() != initial {
|
||||||
|
t.Fatalf("builder grew after reaching limit")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSetWebsocketRequestBody(t *testing.T) {
|
func TestSetWebsocketRequestBody(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|||||||
@@ -287,5 +287,8 @@ func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundl
|
|||||||
FileName: fileName,
|
FileName: fileName,
|
||||||
Storage: tokenStorage,
|
Storage: tokenStorage,
|
||||||
Metadata: metadata,
|
Metadata: metadata,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"plan_type": planType,
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
+537
-74
@@ -134,6 +134,7 @@ type Manager struct {
|
|||||||
hook Hook
|
hook Hook
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
auths map[string]*Auth
|
auths map[string]*Auth
|
||||||
|
scheduler *authScheduler
|
||||||
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
|
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
|
||||||
providerOffsets map[string]int
|
providerOffsets map[string]int
|
||||||
|
|
||||||
@@ -149,6 +150,9 @@ type Manager struct {
|
|||||||
// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
|
// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
|
||||||
apiKeyModelAlias atomic.Value
|
apiKeyModelAlias atomic.Value
|
||||||
|
|
||||||
|
// modelPoolOffsets tracks per-auth alias pool rotation state.
|
||||||
|
modelPoolOffsets map[string]int
|
||||||
|
|
||||||
// runtimeConfig stores the latest application config for request-time decisions.
|
// runtimeConfig stores the latest application config for request-time decisions.
|
||||||
// It is initialized in NewManager; never Load() before first Store().
|
// It is initialized in NewManager; never Load() before first Store().
|
||||||
runtimeConfig atomic.Value
|
runtimeConfig atomic.Value
|
||||||
@@ -176,14 +180,59 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
|||||||
hook: hook,
|
hook: hook,
|
||||||
auths: make(map[string]*Auth),
|
auths: make(map[string]*Auth),
|
||||||
providerOffsets: make(map[string]int),
|
providerOffsets: make(map[string]int),
|
||||||
|
modelPoolOffsets: make(map[string]int),
|
||||||
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
|
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
|
||||||
}
|
}
|
||||||
// atomic.Value requires non-nil initial value.
|
// atomic.Value requires non-nil initial value.
|
||||||
manager.runtimeConfig.Store(&internalconfig.Config{})
|
manager.runtimeConfig.Store(&internalconfig.Config{})
|
||||||
manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
|
manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
|
||||||
|
manager.scheduler = newAuthScheduler(selector)
|
||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isBuiltInSelector(selector Selector) bool {
|
||||||
|
switch selector.(type) {
|
||||||
|
case *RoundRobinSelector, *FillFirstSelector:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) {
|
||||||
|
if m == nil || m.scheduler == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.scheduler.rebuild(auths)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) syncScheduler() {
|
||||||
|
if m == nil || m.scheduler == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.syncSchedulerFromSnapshot(m.snapshotAuths())
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
|
||||||
|
// supportedModelSet is rebuilt from the current global model registry state.
|
||||||
|
// This must be called after models have been registered for a newly added auth,
|
||||||
|
// because the initial scheduler.upsertAuth during Register/Update runs before
|
||||||
|
// registerModelsForAuth and therefore snapshots an empty model set.
|
||||||
|
func (m *Manager) RefreshSchedulerEntry(authID string) {
|
||||||
|
if m == nil || m.scheduler == nil || authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.RLock()
|
||||||
|
auth, ok := m.auths[authID]
|
||||||
|
if !ok || auth == nil {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
snapshot := auth.Clone()
|
||||||
|
m.mu.RUnlock()
|
||||||
|
m.scheduler.upsertAuth(snapshot)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) SetSelector(selector Selector) {
|
func (m *Manager) SetSelector(selector Selector) {
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return
|
return
|
||||||
@@ -194,6 +243,10 @@ func (m *Manager) SetSelector(selector Selector) {
|
|||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
m.selector = selector
|
m.selector = selector
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
if m.scheduler != nil {
|
||||||
|
m.scheduler.setSelector(selector)
|
||||||
|
m.syncScheduler()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetStore swaps the underlying persistence store.
|
// SetStore swaps the underlying persistence store.
|
||||||
@@ -251,16 +304,323 @@ func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) strin
|
|||||||
if resolved == "" {
|
if resolved == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
// Preserve thinking suffix from the client's requested model unless config already has one.
|
return preserveRequestedModelSuffix(requestedModel, resolved)
|
||||||
requestResult := thinking.ParseSuffix(requestedModel)
|
}
|
||||||
if thinking.ParseSuffix(resolved).HasSuffix {
|
|
||||||
return resolved
|
|
||||||
}
|
|
||||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
|
||||||
return resolved + "(" + requestResult.RawSuffix + ")"
|
|
||||||
}
|
|
||||||
return resolved
|
|
||||||
|
|
||||||
|
func isAPIKeyAuth(auth *Auth) bool {
|
||||||
|
if auth == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
kind, _ := auth.AccountInfo()
|
||||||
|
return strings.EqualFold(strings.TrimSpace(kind), "api_key")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isOpenAICompatAPIKeyAuth(auth *Auth) bool {
|
||||||
|
if !isAPIKeyAuth(auth) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if auth.Attributes == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(auth.Attributes["compat_name"]) != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAICompatProviderKey(auth *Auth) string {
|
||||||
|
if auth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
if providerKey := strings.TrimSpace(auth.Attributes["provider_key"]); providerKey != "" {
|
||||||
|
return strings.ToLower(providerKey)
|
||||||
|
}
|
||||||
|
if compatName := strings.TrimSpace(auth.Attributes["compat_name"]); compatName != "" {
|
||||||
|
return strings.ToLower(compatName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAICompatModelPoolKey(auth *Auth, requestedModel string) string {
|
||||||
|
base := strings.TrimSpace(thinking.ParseSuffix(requestedModel).ModelName)
|
||||||
|
if base == "" {
|
||||||
|
base = strings.TrimSpace(requestedModel)
|
||||||
|
}
|
||||||
|
return strings.ToLower(strings.TrimSpace(auth.ID)) + "|" + openAICompatProviderKey(auth) + "|" + strings.ToLower(base)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) nextModelPoolOffset(key string, size int) int {
|
||||||
|
if m == nil || size <= 1 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if m.modelPoolOffsets == nil {
|
||||||
|
m.modelPoolOffsets = make(map[string]int)
|
||||||
|
}
|
||||||
|
offset := m.modelPoolOffsets[key]
|
||||||
|
if offset >= 2_147_483_640 {
|
||||||
|
offset = 0
|
||||||
|
}
|
||||||
|
m.modelPoolOffsets[key] = offset + 1
|
||||||
|
if size <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return offset % size
|
||||||
|
}
|
||||||
|
|
||||||
|
func rotateStrings(values []string, offset int) []string {
|
||||||
|
if len(values) <= 1 {
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
if offset <= 0 {
|
||||||
|
out := make([]string, len(values))
|
||||||
|
copy(out, values)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
offset = offset % len(values)
|
||||||
|
out := make([]string, 0, len(values))
|
||||||
|
out = append(out, values[offset:]...)
|
||||||
|
out = append(out, values[:offset]...)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) resolveOpenAICompatUpstreamModelPool(auth *Auth, requestedModel string) []string {
|
||||||
|
if m == nil || !isOpenAICompatAPIKeyAuth(auth) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
requestedModel = strings.TrimSpace(requestedModel)
|
||||||
|
if requestedModel == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = &internalconfig.Config{}
|
||||||
|
}
|
||||||
|
providerKey := ""
|
||||||
|
compatName := ""
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
|
||||||
|
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
|
||||||
|
}
|
||||||
|
entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
|
||||||
|
if entry == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return resolveModelAliasPoolFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
||||||
|
}
|
||||||
|
|
||||||
|
func preserveRequestedModelSuffix(requestedModel, resolved string) string {
|
||||||
|
return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string {
|
||||||
|
return m.prepareExecutionModels(auth, routeModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
|
||||||
|
requestedModel := rewriteModelForAuth(routeModel, auth)
|
||||||
|
requestedModel = m.applyOAuthModelAlias(auth, requestedModel)
|
||||||
|
if pool := m.resolveOpenAICompatUpstreamModelPool(auth, requestedModel); len(pool) > 0 {
|
||||||
|
if len(pool) == 1 {
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
offset := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, requestedModel), len(pool))
|
||||||
|
return rotateStrings(pool, offset)
|
||||||
|
}
|
||||||
|
resolved := m.applyAPIKeyModelAlias(auth, requestedModel)
|
||||||
|
if strings.TrimSpace(resolved) == "" {
|
||||||
|
resolved = requestedModel
|
||||||
|
}
|
||||||
|
return []string{resolved}
|
||||||
|
}
|
||||||
|
|
||||||
|
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
|
||||||
|
if ch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
for range ch {
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) {
|
||||||
|
if ch == nil {
|
||||||
|
return nil, true, nil
|
||||||
|
}
|
||||||
|
buffered := make([]cliproxyexecutor.StreamChunk, 0, 1)
|
||||||
|
for {
|
||||||
|
var (
|
||||||
|
chunk cliproxyexecutor.StreamChunk
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
if ctx != nil {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, false, ctx.Err()
|
||||||
|
case chunk, ok = <-ch:
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
chunk, ok = <-ch
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return buffered, true, nil
|
||||||
|
}
|
||||||
|
if chunk.Err != nil {
|
||||||
|
return nil, false, chunk.Err
|
||||||
|
}
|
||||||
|
buffered = append(buffered, chunk)
|
||||||
|
if len(chunk.Payload) > 0 {
|
||||||
|
return buffered, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
|
||||||
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
go func() {
|
||||||
|
defer close(out)
|
||||||
|
var failed bool
|
||||||
|
forward := true
|
||||||
|
emit := func(chunk cliproxyexecutor.StreamChunk) bool {
|
||||||
|
if chunk.Err != nil && !failed {
|
||||||
|
failed = true
|
||||||
|
rerr := &Error{Message: chunk.Err.Error()}
|
||||||
|
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
|
||||||
|
rerr.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
|
||||||
|
}
|
||||||
|
if !forward {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
out <- chunk
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
forward = false
|
||||||
|
return false
|
||||||
|
case out <- chunk:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, chunk := range buffered {
|
||||||
|
if ok := emit(chunk); !ok {
|
||||||
|
discardStreamChunks(remaining)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for chunk := range remaining {
|
||||||
|
if ok := emit(chunk); !ok {
|
||||||
|
discardStreamChunks(remaining)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !failed {
|
||||||
|
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true})
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
|
||||||
|
if executor == nil {
|
||||||
|
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||||
|
}
|
||||||
|
execModels := m.prepareExecutionModels(auth, routeModel)
|
||||||
|
var lastErr error
|
||||||
|
for idx, execModel := range execModels {
|
||||||
|
execReq := req
|
||||||
|
execReq.Model = execModel
|
||||||
|
streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
|
||||||
|
if errStream != nil {
|
||||||
|
if errCtx := ctx.Err(); errCtx != nil {
|
||||||
|
return nil, errCtx
|
||||||
|
}
|
||||||
|
rerr := &Error{Message: errStream.Error()}
|
||||||
|
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
|
||||||
|
rerr.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||||
|
result.RetryAfter = retryAfterFromError(errStream)
|
||||||
|
m.MarkResult(ctx, result)
|
||||||
|
if isRequestInvalidError(errStream) {
|
||||||
|
return nil, errStream
|
||||||
|
}
|
||||||
|
lastErr = errStream
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
buffered, closed, bootstrapErr := readStreamBootstrap(ctx, streamResult.Chunks)
|
||||||
|
if bootstrapErr != nil {
|
||||||
|
if errCtx := ctx.Err(); errCtx != nil {
|
||||||
|
discardStreamChunks(streamResult.Chunks)
|
||||||
|
return nil, errCtx
|
||||||
|
}
|
||||||
|
if isRequestInvalidError(bootstrapErr) {
|
||||||
|
rerr := &Error{Message: bootstrapErr.Error()}
|
||||||
|
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
|
||||||
|
rerr.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||||
|
result.RetryAfter = retryAfterFromError(bootstrapErr)
|
||||||
|
m.MarkResult(ctx, result)
|
||||||
|
discardStreamChunks(streamResult.Chunks)
|
||||||
|
return nil, bootstrapErr
|
||||||
|
}
|
||||||
|
if idx < len(execModels)-1 {
|
||||||
|
rerr := &Error{Message: bootstrapErr.Error()}
|
||||||
|
if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
|
||||||
|
rerr.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||||
|
result.RetryAfter = retryAfterFromError(bootstrapErr)
|
||||||
|
m.MarkResult(ctx, result)
|
||||||
|
discardStreamChunks(streamResult.Chunks)
|
||||||
|
lastErr = bootstrapErr
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
|
||||||
|
errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
|
||||||
|
close(errCh)
|
||||||
|
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if closed && len(buffered) == 0 {
|
||||||
|
emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true}
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr}
|
||||||
|
m.MarkResult(ctx, result)
|
||||||
|
if idx < len(execModels)-1 {
|
||||||
|
lastErr = emptyErr
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
errCh := make(chan cliproxyexecutor.StreamChunk, 1)
|
||||||
|
errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
|
||||||
|
close(errCh)
|
||||||
|
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
remaining := streamResult.Chunks
|
||||||
|
if closed {
|
||||||
|
closedCh := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
close(closedCh)
|
||||||
|
remaining = closedCh
|
||||||
|
}
|
||||||
|
return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
|
||||||
|
}
|
||||||
|
if lastErr == nil {
|
||||||
|
lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
|
||||||
|
}
|
||||||
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
|
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
|
||||||
@@ -448,10 +808,14 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
|||||||
auth.ID = uuid.NewString()
|
auth.ID = uuid.NewString()
|
||||||
}
|
}
|
||||||
auth.EnsureIndex()
|
auth.EnsureIndex()
|
||||||
|
authClone := auth.Clone()
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
m.auths[auth.ID] = auth.Clone()
|
m.auths[auth.ID] = authClone
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
||||||
|
if m.scheduler != nil {
|
||||||
|
m.scheduler.upsertAuth(authClone)
|
||||||
|
}
|
||||||
_ = m.persist(ctx, auth)
|
_ = m.persist(ctx, auth)
|
||||||
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
||||||
return auth.Clone(), nil
|
return auth.Clone(), nil
|
||||||
@@ -473,9 +837,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
auth.EnsureIndex()
|
auth.EnsureIndex()
|
||||||
m.auths[auth.ID] = auth.Clone()
|
authClone := auth.Clone()
|
||||||
|
m.auths[auth.ID] = authClone
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
||||||
|
if m.scheduler != nil {
|
||||||
|
m.scheduler.upsertAuth(authClone)
|
||||||
|
}
|
||||||
_ = m.persist(ctx, auth)
|
_ = m.persist(ctx, auth)
|
||||||
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
||||||
return auth.Clone(), nil
|
return auth.Clone(), nil
|
||||||
@@ -484,12 +852,13 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
|||||||
// Load resets manager state from the backing store.
|
// Load resets manager state from the backing store.
|
||||||
func (m *Manager) Load(ctx context.Context) error {
|
func (m *Manager) Load(ctx context.Context) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
|
||||||
if m.store == nil {
|
if m.store == nil {
|
||||||
|
m.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
items, err := m.store.List(ctx)
|
items, err := m.store.List(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
m.mu.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m.auths = make(map[string]*Auth, len(items))
|
m.auths = make(map[string]*Auth, len(items))
|
||||||
@@ -505,6 +874,8 @@ func (m *Manager) Load(ctx context.Context) error {
|
|||||||
cfg = &internalconfig.Config{}
|
cfg = &internalconfig.Config{}
|
||||||
}
|
}
|
||||||
m.rebuildAPIKeyModelAliasLocked(cfg)
|
m.rebuildAPIKeyModelAliasLocked(cfg)
|
||||||
|
m.mu.Unlock()
|
||||||
|
m.syncScheduler()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -634,10 +1005,12 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
|||||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
models := m.prepareExecutionModels(auth, routeModel)
|
||||||
|
var authErr error
|
||||||
|
for _, upstreamModel := range models {
|
||||||
execReq := req
|
execReq := req
|
||||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
execReq.Model = upstreamModel
|
||||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
|
||||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
|
||||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||||
if errExec != nil {
|
if errExec != nil {
|
||||||
@@ -655,12 +1028,20 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
|||||||
if isRequestInvalidError(errExec) {
|
if isRequestInvalidError(errExec) {
|
||||||
return cliproxyexecutor.Response{}, errExec
|
return cliproxyexecutor.Response{}, errExec
|
||||||
}
|
}
|
||||||
lastErr = errExec
|
authErr = errExec
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m.MarkResult(execCtx, result)
|
m.MarkResult(execCtx, result)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
if authErr != nil {
|
||||||
|
if isRequestInvalidError(authErr) {
|
||||||
|
return cliproxyexecutor.Response{}, authErr
|
||||||
|
}
|
||||||
|
lastErr = authErr
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) {
|
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) {
|
||||||
@@ -696,10 +1077,12 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
|||||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
models := m.prepareExecutionModels(auth, routeModel)
|
||||||
|
var authErr error
|
||||||
|
for _, upstreamModel := range models {
|
||||||
execReq := req
|
execReq := req
|
||||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
execReq.Model = upstreamModel
|
||||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
|
||||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
|
||||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||||
if errExec != nil {
|
if errExec != nil {
|
||||||
@@ -717,12 +1100,20 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
|||||||
if isRequestInvalidError(errExec) {
|
if isRequestInvalidError(errExec) {
|
||||||
return cliproxyexecutor.Response{}, errExec
|
return cliproxyexecutor.Response{}, errExec
|
||||||
}
|
}
|
||||||
lastErr = errExec
|
authErr = errExec
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m.hook.OnResult(execCtx, result)
|
m.hook.OnResult(execCtx, result)
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
if authErr != nil {
|
||||||
|
if isRequestInvalidError(authErr) {
|
||||||
|
return cliproxyexecutor.Response{}, authErr
|
||||||
|
}
|
||||||
|
lastErr = authErr
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (*cliproxyexecutor.StreamResult, error) {
|
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (*cliproxyexecutor.StreamResult, error) {
|
||||||
@@ -758,63 +1149,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
|||||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||||
}
|
}
|
||||||
execReq := req
|
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
|
||||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
|
||||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
|
||||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
|
||||||
streamResult, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
|
||||||
if errStream != nil {
|
if errStream != nil {
|
||||||
if errCtx := execCtx.Err(); errCtx != nil {
|
if errCtx := execCtx.Err(); errCtx != nil {
|
||||||
return nil, errCtx
|
return nil, errCtx
|
||||||
}
|
}
|
||||||
rerr := &Error{Message: errStream.Error()}
|
|
||||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
|
|
||||||
rerr.HTTPStatus = se.StatusCode()
|
|
||||||
}
|
|
||||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
|
||||||
result.RetryAfter = retryAfterFromError(errStream)
|
|
||||||
m.MarkResult(execCtx, result)
|
|
||||||
if isRequestInvalidError(errStream) {
|
if isRequestInvalidError(errStream) {
|
||||||
return nil, errStream
|
return nil, errStream
|
||||||
}
|
}
|
||||||
lastErr = errStream
|
lastErr = errStream
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
return streamResult, nil
|
||||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
|
||||||
defer close(out)
|
|
||||||
var failed bool
|
|
||||||
forward := true
|
|
||||||
for chunk := range streamChunks {
|
|
||||||
if chunk.Err != nil && !failed {
|
|
||||||
failed = true
|
|
||||||
rerr := &Error{Message: chunk.Err.Error()}
|
|
||||||
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
|
|
||||||
rerr.HTTPStatus = se.StatusCode()
|
|
||||||
}
|
|
||||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
|
||||||
}
|
|
||||||
if !forward {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if streamCtx == nil {
|
|
||||||
out <- chunk
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-streamCtx.Done():
|
|
||||||
forward = false
|
|
||||||
case out <- chunk:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !failed {
|
|
||||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
|
||||||
}
|
|
||||||
}(execCtx, auth.Clone(), provider, streamResult.Chunks)
|
|
||||||
return &cliproxyexecutor.StreamResult{
|
|
||||||
Headers: streamResult.Headers,
|
|
||||||
Chunks: out,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1245,6 +1591,7 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
|||||||
suspendReason := ""
|
suspendReason := ""
|
||||||
clearModelQuota := false
|
clearModelQuota := false
|
||||||
setModelQuota := false
|
setModelQuota := false
|
||||||
|
var authSnapshot *Auth
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
|
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
|
||||||
@@ -1338,8 +1685,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_ = m.persist(ctx, auth)
|
_ = m.persist(ctx, auth)
|
||||||
|
authSnapshot = auth.Clone()
|
||||||
}
|
}
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
if m.scheduler != nil && authSnapshot != nil {
|
||||||
|
m.scheduler.upsertAuth(authSnapshot)
|
||||||
|
}
|
||||||
|
|
||||||
if clearModelQuota && result.Model != "" {
|
if clearModelQuota && result.Model != "" {
|
||||||
registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
|
registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
|
||||||
@@ -1533,18 +1884,22 @@ func statusCodeFromResult(err *Error) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// isRequestInvalidError returns true if the error represents a client request
|
// isRequestInvalidError returns true if the error represents a client request
|
||||||
// error that should not be retried. Specifically, it checks for 400 Bad Request
|
// error that should not be retried. Specifically, it treats 400 responses with
|
||||||
// with "invalid_request_error" in the message, indicating the request itself is
|
// "invalid_request_error" and all 422 responses as request-shape failures,
|
||||||
// malformed and switching to a different auth will not help.
|
// where switching auths or pooled upstream models will not help.
|
||||||
func isRequestInvalidError(err error) bool {
|
func isRequestInvalidError(err error) bool {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
status := statusCodeFromError(err)
|
status := statusCodeFromError(err)
|
||||||
if status != http.StatusBadRequest {
|
switch status {
|
||||||
|
case http.StatusBadRequest:
|
||||||
|
return strings.Contains(err.Error(), "invalid_request_error")
|
||||||
|
case http.StatusUnprocessableEntity:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return strings.Contains(err.Error(), "invalid_request_error")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
|
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
|
||||||
@@ -1692,7 +2047,29 @@ func (m *Manager) CloseExecutionSession(sessionID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
func (m *Manager) useSchedulerFastPath() bool {
|
||||||
|
if m == nil || m.scheduler == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return isBuiltInSelector(m.selector)
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldRetrySchedulerPick(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var cooldownErr *modelCooldownError
|
||||||
|
if errors.As(err, &cooldownErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
var authErr *Error
|
||||||
|
if !errors.As(err, &authErr) || authErr == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return authErr.Code == "auth_not_found" || authErr.Code == "auth_unavailable"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||||
|
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
@@ -1752,7 +2129,38 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
|||||||
return authCopy, executor, nil
|
return authCopy, executor, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||||
|
if !m.useSchedulerFastPath() {
|
||||||
|
return m.pickNextLegacy(ctx, provider, model, opts, tried)
|
||||||
|
}
|
||||||
|
executor, okExecutor := m.Executor(provider)
|
||||||
|
if !okExecutor {
|
||||||
|
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||||
|
}
|
||||||
|
selected, errPick := m.scheduler.pickSingle(ctx, provider, model, opts, tried)
|
||||||
|
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
|
||||||
|
m.syncScheduler()
|
||||||
|
selected, errPick = m.scheduler.pickSingle(ctx, provider, model, opts, tried)
|
||||||
|
}
|
||||||
|
if errPick != nil {
|
||||||
|
return nil, nil, errPick
|
||||||
|
}
|
||||||
|
if selected == nil {
|
||||||
|
return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
||||||
|
}
|
||||||
|
authCopy := selected.Clone()
|
||||||
|
if !selected.indexAssigned {
|
||||||
|
m.mu.Lock()
|
||||||
|
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
||||||
|
current.EnsureIndex()
|
||||||
|
authCopy = current.Clone()
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
return authCopy, executor, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||||
|
|
||||||
providerSet := make(map[string]struct{}, len(providers))
|
providerSet := make(map[string]struct{}, len(providers))
|
||||||
@@ -1835,6 +2243,58 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
|
|||||||
return authCopy, executor, providerKey, nil
|
return authCopy, executor, providerKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||||
|
if !m.useSchedulerFastPath() {
|
||||||
|
return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
|
||||||
|
}
|
||||||
|
|
||||||
|
eligibleProviders := make([]string, 0, len(providers))
|
||||||
|
seenProviders := make(map[string]struct{}, len(providers))
|
||||||
|
for _, provider := range providers {
|
||||||
|
providerKey := strings.TrimSpace(strings.ToLower(provider))
|
||||||
|
if providerKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, seen := seenProviders[providerKey]; seen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, okExecutor := m.Executor(providerKey); !okExecutor {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenProviders[providerKey] = struct{}{}
|
||||||
|
eligibleProviders = append(eligibleProviders, providerKey)
|
||||||
|
}
|
||||||
|
if len(eligibleProviders) == 0 {
|
||||||
|
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
|
||||||
|
selected, providerKey, errPick := m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
|
||||||
|
if errPick != nil && model != "" && shouldRetrySchedulerPick(errPick) {
|
||||||
|
m.syncScheduler()
|
||||||
|
selected, providerKey, errPick = m.scheduler.pickMixed(ctx, eligibleProviders, model, opts, tried)
|
||||||
|
}
|
||||||
|
if errPick != nil {
|
||||||
|
return nil, nil, "", errPick
|
||||||
|
}
|
||||||
|
if selected == nil {
|
||||||
|
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
||||||
|
}
|
||||||
|
executor, okExecutor := m.Executor(providerKey)
|
||||||
|
if !okExecutor {
|
||||||
|
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||||
|
}
|
||||||
|
authCopy := selected.Clone()
|
||||||
|
if !selected.indexAssigned {
|
||||||
|
m.mu.Lock()
|
||||||
|
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
||||||
|
current.EnsureIndex()
|
||||||
|
authCopy = current.Clone()
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
return authCopy, executor, providerKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
||||||
if m.store == nil || auth == nil {
|
if m.store == nil || auth == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -2186,6 +2646,9 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
|||||||
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
||||||
current.LastError = &Error{Message: err.Error()}
|
current.LastError = &Error{Message: err.Error()}
|
||||||
m.auths[id] = current
|
m.auths[id] = current
|
||||||
|
if m.scheduler != nil {
|
||||||
|
m.scheduler.upsertAuth(current.Clone())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -0,0 +1,163 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type schedulerProviderTestExecutor struct {
|
||||||
|
provider string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) Identifier() string { return e.provider }
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerProviderTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
prime func(*Manager, *Auth) error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "register",
|
||||||
|
prime: func(manager *Manager, auth *Auth) error {
|
||||||
|
_, errRegister := manager.Register(ctx, auth)
|
||||||
|
return errRegister
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update",
|
||||||
|
prime: func(manager *Manager, auth *Auth) error {
|
||||||
|
_, errRegister := manager.Register(ctx, auth)
|
||||||
|
if errRegister != nil {
|
||||||
|
return errRegister
|
||||||
|
}
|
||||||
|
updated := auth.Clone()
|
||||||
|
updated.Metadata = map[string]any{"updated": true}
|
||||||
|
_, errUpdate := manager.Update(ctx, updated)
|
||||||
|
return errUpdate
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, testCase := range testCases {
|
||||||
|
testCase := testCase
|
||||||
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "refresh-entry-" + testCase.name,
|
||||||
|
Provider: "gemini",
|
||||||
|
}
|
||||||
|
if errPrime := testCase.prime(manager, auth); errPrime != nil {
|
||||||
|
t.Fatalf("prime auth %s: %v", testCase.name, errPrime)
|
||||||
|
}
|
||||||
|
|
||||||
|
registerSchedulerModels(t, "gemini", "scheduler-refresh-model", auth.ID)
|
||||||
|
|
||||||
|
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
|
||||||
|
var authErr *Error
|
||||||
|
if !errors.As(errPick, &authErr) || authErr == nil {
|
||||||
|
t.Fatalf("pickSingle() before refresh error = %v, want auth_not_found", errPick)
|
||||||
|
}
|
||||||
|
if authErr.Code != "auth_not_found" {
|
||||||
|
t.Fatalf("pickSingle() before refresh code = %q, want %q", authErr.Code, "auth_not_found")
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("pickSingle() before refresh auth = %v, want nil", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.RefreshSchedulerEntry(auth.ID)
|
||||||
|
|
||||||
|
got, errPick = manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickSingle() after refresh error = %v", errPick)
|
||||||
|
}
|
||||||
|
if got == nil || got.ID != auth.ID {
|
||||||
|
t.Fatalf("pickSingle() after refresh auth = %v, want %q", got, auth.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
manager.RegisterExecutor(schedulerProviderTestExecutor{provider: "gemini"})
|
||||||
|
|
||||||
|
registerSchedulerModels(t, "gemini", "scheduler-cooldown-rebuild-model", "cooldown-stale-old")
|
||||||
|
|
||||||
|
oldAuth := &Auth{
|
||||||
|
ID: "cooldown-stale-old",
|
||||||
|
Provider: "gemini",
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(ctx, oldAuth); errRegister != nil {
|
||||||
|
t.Fatalf("register old auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.MarkResult(ctx, Result{
|
||||||
|
AuthID: oldAuth.ID,
|
||||||
|
Provider: "gemini",
|
||||||
|
Model: "scheduler-cooldown-rebuild-model",
|
||||||
|
Success: false,
|
||||||
|
Error: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"},
|
||||||
|
})
|
||||||
|
|
||||||
|
newAuth := &Auth{
|
||||||
|
ID: "cooldown-stale-new",
|
||||||
|
Provider: "gemini",
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(ctx, newAuth); errRegister != nil {
|
||||||
|
t.Fatalf("register new auth: %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient(newAuth.ID, "gemini", []*registry.ModelInfo{{ID: "scheduler-cooldown-rebuild-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
reg.UnregisterClient(newAuth.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
|
||||||
|
var cooldownErr *modelCooldownError
|
||||||
|
if !errors.As(errPick, &cooldownErr) {
|
||||||
|
t.Fatalf("pickSingle() before sync error = %v, want modelCooldownError", errPick)
|
||||||
|
}
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("pickSingle() before sync auth = %v, want nil", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, executor, errPick := manager.pickNext(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickNext() error = %v", errPick)
|
||||||
|
}
|
||||||
|
if executor == nil {
|
||||||
|
t.Fatal("pickNext() executor = nil")
|
||||||
|
}
|
||||||
|
if got == nil || got.ID != newAuth.ID {
|
||||||
|
t.Fatalf("pickNext() auth = %v, want %q", got, newAuth.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -80,23 +80,24 @@ func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string
|
|||||||
return upstreamModel
|
return upstreamModel
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
|
func modelAliasLookupCandidates(requestedModel string) (thinking.SuffixResult, []string) {
|
||||||
requestedModel = strings.TrimSpace(requestedModel)
|
requestedModel = strings.TrimSpace(requestedModel)
|
||||||
if requestedModel == "" {
|
if requestedModel == "" {
|
||||||
return ""
|
return thinking.SuffixResult{}, nil
|
||||||
}
|
}
|
||||||
if len(models) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
requestResult := thinking.ParseSuffix(requestedModel)
|
requestResult := thinking.ParseSuffix(requestedModel)
|
||||||
base := requestResult.ModelName
|
base := requestResult.ModelName
|
||||||
|
if base == "" {
|
||||||
|
base = requestedModel
|
||||||
|
}
|
||||||
candidates := []string{base}
|
candidates := []string{base}
|
||||||
if base != requestedModel {
|
if base != requestedModel {
|
||||||
candidates = append(candidates, requestedModel)
|
candidates = append(candidates, requestedModel)
|
||||||
}
|
}
|
||||||
|
return requestResult, candidates
|
||||||
|
}
|
||||||
|
|
||||||
preserveSuffix := func(resolved string) string {
|
func preserveResolvedModelSuffix(resolved string, requestResult thinking.SuffixResult) string {
|
||||||
resolved = strings.TrimSpace(resolved)
|
resolved = strings.TrimSpace(resolved)
|
||||||
if resolved == "" {
|
if resolved == "" {
|
||||||
return ""
|
return ""
|
||||||
@@ -108,25 +109,68 @@ func resolveModelAliasFromConfigModels(requestedModel string, models []modelAlia
|
|||||||
return resolved + "(" + requestResult.RawSuffix + ")"
|
return resolved + "(" + requestResult.RawSuffix + ")"
|
||||||
}
|
}
|
||||||
return resolved
|
return resolved
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveModelAliasPoolFromConfigModels(requestedModel string, models []modelAliasEntry) []string {
|
||||||
|
requestedModel = strings.TrimSpace(requestedModel)
|
||||||
|
if requestedModel == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requestResult, candidates := modelAliasLookupCandidates(requestedModel)
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]string, 0)
|
||||||
|
seen := make(map[string]struct{})
|
||||||
for i := range models {
|
for i := range models {
|
||||||
name := strings.TrimSpace(models[i].GetName())
|
name := strings.TrimSpace(models[i].GetName())
|
||||||
alias := strings.TrimSpace(models[i].GetAlias())
|
alias := strings.TrimSpace(models[i].GetAlias())
|
||||||
for _, candidate := range candidates {
|
for _, candidate := range candidates {
|
||||||
if candidate == "" {
|
if candidate == "" || alias == "" || !strings.EqualFold(alias, candidate) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if alias != "" && strings.EqualFold(alias, candidate) {
|
resolved := candidate
|
||||||
if name != "" {
|
if name != "" {
|
||||||
return preserveSuffix(name)
|
resolved = name
|
||||||
}
|
}
|
||||||
return preserveSuffix(candidate)
|
resolved = preserveResolvedModelSuffix(resolved, requestResult)
|
||||||
|
key := strings.ToLower(strings.TrimSpace(resolved))
|
||||||
|
if key == "" {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
if name != "" && strings.EqualFold(name, candidate) {
|
if _, exists := seen[key]; exists {
|
||||||
return preserveSuffix(name)
|
break
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
out = append(out, resolved)
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(out) > 0 {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range models {
|
||||||
|
name := strings.TrimSpace(models[i].GetName())
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
if candidate == "" || name == "" || !strings.EqualFold(name, candidate) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return []string{preserveResolvedModelSuffix(name, requestResult)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
|
||||||
|
resolved := resolveModelAliasPoolFromConfigModels(requestedModel, models)
|
||||||
|
if len(resolved) > 0 {
|
||||||
|
return resolved[0]
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,419 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openAICompatPoolExecutor struct {
|
||||||
|
id string
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
executeModels []string
|
||||||
|
countModels []string
|
||||||
|
streamModels []string
|
||||||
|
executeErrors map[string]error
|
||||||
|
countErrors map[string]error
|
||||||
|
streamFirstErrors map[string]error
|
||||||
|
streamPayloads map[string][]cliproxyexecutor.StreamChunk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *openAICompatPoolExecutor) Identifier() string { return e.id }
|
||||||
|
|
||||||
|
func (e *openAICompatPoolExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
_ = ctx
|
||||||
|
_ = auth
|
||||||
|
_ = opts
|
||||||
|
e.mu.Lock()
|
||||||
|
e.executeModels = append(e.executeModels, req.Model)
|
||||||
|
err := e.executeErrors[req.Model]
|
||||||
|
e.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, err
|
||||||
|
}
|
||||||
|
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *openAICompatPoolExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
|
_ = ctx
|
||||||
|
_ = auth
|
||||||
|
_ = opts
|
||||||
|
e.mu.Lock()
|
||||||
|
e.streamModels = append(e.streamModels, req.Model)
|
||||||
|
err := e.streamFirstErrors[req.Model]
|
||||||
|
payloadChunks, hasCustomChunks := e.streamPayloads[req.Model]
|
||||||
|
chunks := append([]cliproxyexecutor.StreamChunk(nil), payloadChunks...)
|
||||||
|
e.mu.Unlock()
|
||||||
|
ch := make(chan cliproxyexecutor.StreamChunk, max(1, len(chunks)))
|
||||||
|
if err != nil {
|
||||||
|
ch <- cliproxyexecutor.StreamChunk{Err: err}
|
||||||
|
close(ch)
|
||||||
|
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil
|
||||||
|
}
|
||||||
|
if !hasCustomChunks {
|
||||||
|
ch <- cliproxyexecutor.StreamChunk{Payload: []byte(req.Model)}
|
||||||
|
} else {
|
||||||
|
for _, chunk := range chunks {
|
||||||
|
ch <- chunk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {req.Model}}, Chunks: ch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *openAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *openAICompatPoolExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
_ = ctx
|
||||||
|
_ = auth
|
||||||
|
_ = opts
|
||||||
|
e.mu.Lock()
|
||||||
|
e.countModels = append(e.countModels, req.Model)
|
||||||
|
err := e.countErrors[req.Model]
|
||||||
|
e.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, err
|
||||||
|
}
|
||||||
|
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *openAICompatPoolExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
_ = ctx
|
||||||
|
_ = auth
|
||||||
|
_ = req
|
||||||
|
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *openAICompatPoolExecutor) ExecuteModels() []string {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
out := make([]string, len(e.executeModels))
|
||||||
|
copy(out, e.executeModels)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *openAICompatPoolExecutor) CountModels() []string {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
out := make([]string, len(e.countModels))
|
||||||
|
copy(out, e.countModels)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *openAICompatPoolExecutor) StreamModels() []string {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
out := make([]string, len(e.streamModels))
|
||||||
|
copy(out, e.streamModels)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager {
|
||||||
|
t.Helper()
|
||||||
|
cfg := &internalconfig.Config{
|
||||||
|
OpenAICompatibility: []internalconfig.OpenAICompatibility{{
|
||||||
|
Name: "pool",
|
||||||
|
Models: models,
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
m := NewManager(nil, nil, nil)
|
||||||
|
m.SetConfig(cfg)
|
||||||
|
if executor == nil {
|
||||||
|
executor = &openAICompatPoolExecutor{id: "pool"}
|
||||||
|
}
|
||||||
|
m.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "pool-auth-" + t.Name(),
|
||||||
|
Provider: "pool",
|
||||||
|
Status: StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"api_key": "test-key",
|
||||||
|
"compat_name": "pool",
|
||||||
|
"provider_key": "pool",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if _, err := m.Register(context.Background(), auth); err != nil {
|
||||||
|
t.Fatalf("register auth: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient(auth.ID, "pool", []*registry.ModelInfo{{ID: alias}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
reg.UnregisterClient(auth.ID)
|
||||||
|
})
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
|
||||||
|
alias := "claude-opus-4.66"
|
||||||
|
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
|
||||||
|
executor := &openAICompatPoolExecutor{
|
||||||
|
id: "pool",
|
||||||
|
countErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||||
|
}
|
||||||
|
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||||
|
{Name: "qwen3.5-plus", Alias: alias},
|
||||||
|
{Name: "glm-5", Alias: alias},
|
||||||
|
}, executor)
|
||||||
|
|
||||||
|
_, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||||
|
if err == nil || err.Error() != invalidErr.Error() {
|
||||||
|
t.Fatalf("execute count error = %v, want %v", err, invalidErr)
|
||||||
|
}
|
||||||
|
got := executor.CountModels()
|
||||||
|
if len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||||
|
t.Fatalf("count calls = %v, want only first invalid model", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestResolveModelAliasPoolFromConfigModels(t *testing.T) {
|
||||||
|
models := []modelAliasEntry{
|
||||||
|
internalconfig.OpenAICompatibilityModel{Name: "qwen3.5-plus", Alias: "claude-opus-4.66"},
|
||||||
|
internalconfig.OpenAICompatibilityModel{Name: "glm-5", Alias: "claude-opus-4.66"},
|
||||||
|
internalconfig.OpenAICompatibilityModel{Name: "kimi-k2.5", Alias: "claude-opus-4.66"},
|
||||||
|
}
|
||||||
|
got := resolveModelAliasPoolFromConfigModels("claude-opus-4.66(8192)", models)
|
||||||
|
want := []string{"qwen3.5-plus(8192)", "glm-5(8192)", "kimi-k2.5(8192)"}
|
||||||
|
if len(got) != len(want) {
|
||||||
|
t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got)
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if got[i] != want[i] {
|
||||||
|
t.Fatalf("pool[%d] = %q, want %q", i, got[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
|
||||||
|
alias := "claude-opus-4.66"
|
||||||
|
executor := &openAICompatPoolExecutor{id: "pool"}
|
||||||
|
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||||
|
{Name: "qwen3.5-plus", Alias: alias},
|
||||||
|
{Name: "glm-5", Alias: alias},
|
||||||
|
}, executor)
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute %d: %v", i, err)
|
||||||
|
}
|
||||||
|
if len(resp.Payload) == 0 {
|
||||||
|
t.Fatalf("execute %d returned empty payload", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
got := executor.ExecuteModels()
|
||||||
|
want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"}
|
||||||
|
if len(got) != len(want) {
|
||||||
|
t.Fatalf("execute calls = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
for i := range want {
|
||||||
|
if got[i] != want[i] {
|
||||||
|
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
|
||||||
|
alias := "claude-opus-4.66"
|
||||||
|
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
|
||||||
|
executor := &openAICompatPoolExecutor{
|
||||||
|
id: "pool",
|
||||||
|
executeErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||||
|
}
|
||||||
|
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||||
|
{Name: "qwen3.5-plus", Alias: alias},
|
||||||
|
{Name: "glm-5", Alias: alias},
|
||||||
|
}, executor)
|
||||||
|
|
||||||
|
_, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||||
|
if err == nil || err.Error() != invalidErr.Error() {
|
||||||
|
t.Fatalf("execute error = %v, want %v", err, invalidErr)
|
||||||
|
}
|
||||||
|
got := executor.ExecuteModels()
|
||||||
|
if len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||||
|
t.Fatalf("execute calls = %v, want only first invalid model", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) {
|
||||||
|
alias := "claude-opus-4.66"
|
||||||
|
executor := &openAICompatPoolExecutor{
|
||||||
|
id: "pool",
|
||||||
|
executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
|
||||||
|
}
|
||||||
|
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||||
|
{Name: "qwen3.5-plus", Alias: alias},
|
||||||
|
{Name: "glm-5", Alias: alias},
|
||||||
|
}, executor)
|
||||||
|
|
||||||
|
resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute: %v", err)
|
||||||
|
}
|
||||||
|
if string(resp.Payload) != "glm-5" {
|
||||||
|
t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
|
||||||
|
}
|
||||||
|
got := executor.ExecuteModels()
|
||||||
|
want := []string{"qwen3.5-plus", "glm-5"}
|
||||||
|
for i := range want {
|
||||||
|
if got[i] != want[i] {
|
||||||
|
t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *testing.T) {
|
||||||
|
alias := "claude-opus-4.66"
|
||||||
|
executor := &openAICompatPoolExecutor{
|
||||||
|
id: "pool",
|
||||||
|
streamPayloads: map[string][]cliproxyexecutor.StreamChunk{
|
||||||
|
"qwen3.5-plus": {},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||||
|
{Name: "qwen3.5-plus", Alias: alias},
|
||||||
|
{Name: "glm-5", Alias: alias},
|
||||||
|
}, executor)
|
||||||
|
|
||||||
|
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute stream: %v", err)
|
||||||
|
}
|
||||||
|
var payload []byte
|
||||||
|
for chunk := range streamResult.Chunks {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
t.Fatalf("unexpected stream error: %v", chunk.Err)
|
||||||
|
}
|
||||||
|
payload = append(payload, chunk.Payload...)
|
||||||
|
}
|
||||||
|
if string(payload) != "glm-5" {
|
||||||
|
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
|
||||||
|
}
|
||||||
|
got := executor.StreamModels()
|
||||||
|
want := []string{"qwen3.5-plus", "glm-5"}
|
||||||
|
for i := range want {
|
||||||
|
if got[i] != want[i] {
|
||||||
|
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *testing.T) {
|
||||||
|
alias := "claude-opus-4.66"
|
||||||
|
executor := &openAICompatPoolExecutor{
|
||||||
|
id: "pool",
|
||||||
|
streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
|
||||||
|
}
|
||||||
|
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||||
|
{Name: "qwen3.5-plus", Alias: alias},
|
||||||
|
{Name: "glm-5", Alias: alias},
|
||||||
|
}, executor)
|
||||||
|
|
||||||
|
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute stream: %v", err)
|
||||||
|
}
|
||||||
|
var payload []byte
|
||||||
|
for chunk := range streamResult.Chunks {
|
||||||
|
if chunk.Err != nil {
|
||||||
|
t.Fatalf("unexpected stream error: %v", chunk.Err)
|
||||||
|
}
|
||||||
|
payload = append(payload, chunk.Payload...)
|
||||||
|
}
|
||||||
|
if string(payload) != "glm-5" {
|
||||||
|
t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
|
||||||
|
}
|
||||||
|
got := executor.StreamModels()
|
||||||
|
want := []string{"qwen3.5-plus", "glm-5"}
|
||||||
|
for i := range want {
|
||||||
|
if got[i] != want[i] {
|
||||||
|
t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" {
|
||||||
|
t.Fatalf("header X-Model = %q, want %q", gotHeader, "glm-5")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
|
||||||
|
alias := "claude-opus-4.66"
|
||||||
|
invalidErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
|
||||||
|
executor := &openAICompatPoolExecutor{
|
||||||
|
id: "pool",
|
||||||
|
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||||
|
}
|
||||||
|
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||||
|
{Name: "qwen3.5-plus", Alias: alias},
|
||||||
|
{Name: "glm-5", Alias: alias},
|
||||||
|
}, executor)
|
||||||
|
|
||||||
|
_, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||||
|
if err == nil || err.Error() != invalidErr.Error() {
|
||||||
|
t.Fatalf("execute stream error = %v, want %v", err, invalidErr)
|
||||||
|
}
|
||||||
|
got := executor.StreamModels()
|
||||||
|
if len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||||
|
t.Fatalf("stream calls = %v, want only first invalid model", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
|
||||||
|
alias := "claude-opus-4.66"
|
||||||
|
executor := &openAICompatPoolExecutor{id: "pool"}
|
||||||
|
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||||
|
{Name: "qwen3.5-plus", Alias: alias},
|
||||||
|
{Name: "glm-5", Alias: alias},
|
||||||
|
}, executor)
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("execute count %d: %v", i, err)
|
||||||
|
}
|
||||||
|
if len(resp.Payload) == 0 {
|
||||||
|
t.Fatalf("execute count %d returned empty payload", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
got := executor.CountModels()
|
||||||
|
want := []string{"qwen3.5-plus", "glm-5"}
|
||||||
|
for i := range want {
|
||||||
|
if got[i] != want[i] {
|
||||||
|
t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) {
|
||||||
|
alias := "claude-opus-4.66"
|
||||||
|
invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
|
||||||
|
executor := &openAICompatPoolExecutor{
|
||||||
|
id: "pool",
|
||||||
|
streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
|
||||||
|
}
|
||||||
|
m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
|
||||||
|
{Name: "qwen3.5-plus", Alias: alias},
|
||||||
|
{Name: "glm-5", Alias: alias},
|
||||||
|
}, executor)
|
||||||
|
|
||||||
|
streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected invalid request error")
|
||||||
|
}
|
||||||
|
if err != invalidErr {
|
||||||
|
t.Fatalf("error = %v, want %v", err, invalidErr)
|
||||||
|
}
|
||||||
|
if streamResult != nil {
|
||||||
|
t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult)
|
||||||
|
}
|
||||||
|
if got := executor.StreamModels(); len(got) != 1 || got[0] != "qwen3.5-plus" {
|
||||||
|
t.Fatalf("stream calls = %v, want only first upstream model", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,904 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// schedulerStrategy identifies which built-in routing semantics the scheduler should apply.
|
||||||
|
type schedulerStrategy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
schedulerStrategyCustom schedulerStrategy = iota
|
||||||
|
schedulerStrategyRoundRobin
|
||||||
|
schedulerStrategyFillFirst
|
||||||
|
)
|
||||||
|
|
||||||
|
// scheduledState describes how an auth currently participates in a model shard.
|
||||||
|
type scheduledState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
scheduledStateReady scheduledState = iota
|
||||||
|
scheduledStateCooldown
|
||||||
|
scheduledStateBlocked
|
||||||
|
scheduledStateDisabled
|
||||||
|
)
|
||||||
|
|
||||||
|
// authScheduler keeps the incremental provider/model scheduling state used by Manager.
|
||||||
|
type authScheduler struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
strategy schedulerStrategy
|
||||||
|
providers map[string]*providerScheduler
|
||||||
|
authProviders map[string]string
|
||||||
|
mixedCursors map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
// providerScheduler stores auth metadata and model shards for a single provider.
|
||||||
|
type providerScheduler struct {
|
||||||
|
providerKey string
|
||||||
|
auths map[string]*scheduledAuthMeta
|
||||||
|
modelShards map[string]*modelScheduler
|
||||||
|
}
|
||||||
|
|
||||||
|
// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot.
|
||||||
|
type scheduledAuthMeta struct {
|
||||||
|
auth *Auth
|
||||||
|
providerKey string
|
||||||
|
priority int
|
||||||
|
virtualParent string
|
||||||
|
websocketEnabled bool
|
||||||
|
supportedModelSet map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelScheduler tracks ready and blocked auths for one provider/model combination.
|
||||||
|
type modelScheduler struct {
|
||||||
|
modelKey string
|
||||||
|
entries map[string]*scheduledAuth
|
||||||
|
priorityOrder []int
|
||||||
|
readyByPriority map[int]*readyBucket
|
||||||
|
blocked cooldownQueue
|
||||||
|
}
|
||||||
|
|
||||||
|
// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard.
|
||||||
|
type scheduledAuth struct {
|
||||||
|
meta *scheduledAuthMeta
|
||||||
|
auth *Auth
|
||||||
|
state scheduledState
|
||||||
|
nextRetryAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// readyBucket keeps the ready views for one priority level.
|
||||||
|
type readyBucket struct {
|
||||||
|
all readyView
|
||||||
|
ws readyView
|
||||||
|
}
|
||||||
|
|
||||||
|
// readyView holds the selection order for flat or grouped round-robin traversal.
|
||||||
|
type readyView struct {
|
||||||
|
flat []*scheduledAuth
|
||||||
|
cursor int
|
||||||
|
parentOrder []string
|
||||||
|
parentCursor int
|
||||||
|
children map[string]*childBucket
|
||||||
|
}
|
||||||
|
|
||||||
|
// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths.
|
||||||
|
type childBucket struct {
|
||||||
|
items []*scheduledAuth
|
||||||
|
cursor int
|
||||||
|
}
|
||||||
|
|
||||||
|
// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds.
|
||||||
|
type cooldownQueue []*scheduledAuth
|
||||||
|
|
||||||
|
// newAuthScheduler constructs an empty scheduler configured for the supplied selector strategy.
|
||||||
|
func newAuthScheduler(selector Selector) *authScheduler {
|
||||||
|
return &authScheduler{
|
||||||
|
strategy: selectorStrategy(selector),
|
||||||
|
providers: make(map[string]*providerScheduler),
|
||||||
|
authProviders: make(map[string]string),
|
||||||
|
mixedCursors: make(map[string]int),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectorStrategy maps a selector implementation to the scheduler semantics it should emulate.
|
||||||
|
func selectorStrategy(selector Selector) schedulerStrategy {
|
||||||
|
switch selector.(type) {
|
||||||
|
case *FillFirstSelector:
|
||||||
|
return schedulerStrategyFillFirst
|
||||||
|
case nil, *RoundRobinSelector:
|
||||||
|
return schedulerStrategyRoundRobin
|
||||||
|
default:
|
||||||
|
return schedulerStrategyCustom
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSelector updates the active built-in strategy and resets mixed-provider cursors.
|
||||||
|
func (s *authScheduler) setSelector(selector Selector) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.strategy = selectorStrategy(selector)
|
||||||
|
clear(s.mixedCursors)
|
||||||
|
}
|
||||||
|
|
||||||
|
// rebuild recreates the complete scheduler state from an auth snapshot.
|
||||||
|
func (s *authScheduler) rebuild(auths []*Auth) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.providers = make(map[string]*providerScheduler)
|
||||||
|
s.authProviders = make(map[string]string)
|
||||||
|
s.mixedCursors = make(map[string]int)
|
||||||
|
now := time.Now()
|
||||||
|
for _, auth := range auths {
|
||||||
|
s.upsertAuthLocked(auth, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// upsertAuth incrementally synchronizes one auth into the scheduler.
|
||||||
|
func (s *authScheduler) upsertAuth(auth *Auth) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.upsertAuthLocked(auth, time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeAuth deletes one auth from every scheduler shard that references it.
|
||||||
|
func (s *authScheduler) removeAuth(authID string) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authID = strings.TrimSpace(authID)
|
||||||
|
if authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.removeAuthLocked(authID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickSingle returns the next auth for a single provider/model request using scheduler state.
|
||||||
|
func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
providerKey := strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
modelKey := canonicalModelKey(model)
|
||||||
|
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||||
|
preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == ""
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
providerState := s.providers[providerKey]
|
||||||
|
if providerState == nil {
|
||||||
|
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
||||||
|
if shard == nil {
|
||||||
|
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
predicate := func(entry *scheduledAuth) bool {
|
||||||
|
if entry == nil || entry.auth == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if pinnedAuthID != "" && entry.auth.ID != pinnedAuthID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(tried) > 0 {
|
||||||
|
if _, ok := tried[entry.auth.ID]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if picked := shard.pickReadyLocked(preferWebsocket, s.strategy, predicate); picked != nil {
|
||||||
|
return picked, nil
|
||||||
|
}
|
||||||
|
return nil, shard.unavailableErrorLocked(provider, model, predicate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickMixed returns the next auth and provider for a mixed-provider request.
|
||||||
|
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
normalized := normalizeProviderKeys(providers)
|
||||||
|
if len(normalized) == 0 {
|
||||||
|
return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
|
}
|
||||||
|
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||||
|
modelKey := canonicalModelKey(model)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if pinnedAuthID != "" {
|
||||||
|
providerKey := s.authProviders[pinnedAuthID]
|
||||||
|
if providerKey == "" || !containsProvider(normalized, providerKey) {
|
||||||
|
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
providerState := s.providers[providerKey]
|
||||||
|
if providerState == nil {
|
||||||
|
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
||||||
|
predicate := func(entry *scheduledAuth) bool {
|
||||||
|
if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(tried) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
_, ok := tried[pinnedAuthID]
|
||||||
|
return !ok
|
||||||
|
}
|
||||||
|
if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil {
|
||||||
|
return picked, providerKey, nil
|
||||||
|
}
|
||||||
|
return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
|
||||||
|
}
|
||||||
|
|
||||||
|
predicate := triedPredicate(tried)
|
||||||
|
candidateShards := make([]*modelScheduler, len(normalized))
|
||||||
|
bestPriority := 0
|
||||||
|
hasCandidate := false
|
||||||
|
now := time.Now()
|
||||||
|
for providerIndex, providerKey := range normalized {
|
||||||
|
providerState := s.providers[providerKey]
|
||||||
|
if providerState == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
shard := providerState.ensureModelLocked(modelKey, now)
|
||||||
|
candidateShards[providerIndex] = shard
|
||||||
|
if shard == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
priorityReady, okPriority := shard.highestReadyPriorityLocked(false, predicate)
|
||||||
|
if !okPriority {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !hasCandidate || priorityReady > bestPriority {
|
||||||
|
bestPriority = priorityReady
|
||||||
|
hasCandidate = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasCandidate {
|
||||||
|
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.strategy == schedulerStrategyFillFirst {
|
||||||
|
for providerIndex, providerKey := range normalized {
|
||||||
|
shard := candidateShards[providerIndex]
|
||||||
|
if shard == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, s.strategy, predicate)
|
||||||
|
if picked != nil {
|
||||||
|
return picked, providerKey, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||||
|
}
|
||||||
|
|
||||||
|
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
|
||||||
|
start := 0
|
||||||
|
if len(normalized) > 0 {
|
||||||
|
start = s.mixedCursors[cursorKey] % len(normalized)
|
||||||
|
}
|
||||||
|
for offset := 0; offset < len(normalized); offset++ {
|
||||||
|
providerIndex := (start + offset) % len(normalized)
|
||||||
|
providerKey := normalized[providerIndex]
|
||||||
|
shard := candidateShards[providerIndex]
|
||||||
|
if shard == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, schedulerStrategyRoundRobin, predicate)
|
||||||
|
if picked == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.mixedCursors[cursorKey] = providerIndex + 1
|
||||||
|
return picked, providerKey, nil
|
||||||
|
}
|
||||||
|
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mixedUnavailableErrorLocked synthesizes the mixed-provider cooldown or unavailable error.
|
||||||
|
func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error {
|
||||||
|
now := time.Now()
|
||||||
|
total := 0
|
||||||
|
cooldownCount := 0
|
||||||
|
earliest := time.Time{}
|
||||||
|
for _, providerKey := range providers {
|
||||||
|
providerState := s.providers[providerKey]
|
||||||
|
if providerState == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
shard := providerState.ensureModelLocked(canonicalModelKey(model), now)
|
||||||
|
if shard == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
localTotal, localCooldownCount, localEarliest := shard.availabilitySummaryLocked(triedPredicate(tried))
|
||||||
|
total += localTotal
|
||||||
|
cooldownCount += localCooldownCount
|
||||||
|
if !localEarliest.IsZero() && (earliest.IsZero() || localEarliest.Before(earliest)) {
|
||||||
|
earliest = localEarliest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if total == 0 {
|
||||||
|
return &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
if cooldownCount == total && !earliest.IsZero() {
|
||||||
|
resetIn := earliest.Sub(now)
|
||||||
|
if resetIn < 0 {
|
||||||
|
resetIn = 0
|
||||||
|
}
|
||||||
|
return newModelCooldownError(model, "", resetIn)
|
||||||
|
}
|
||||||
|
return &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// triedPredicate builds a filter that excludes auths already attempted for the current request.
|
||||||
|
func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool {
|
||||||
|
if len(tried) == 0 {
|
||||||
|
return func(entry *scheduledAuth) bool { return entry != nil && entry.auth != nil }
|
||||||
|
}
|
||||||
|
return func(entry *scheduledAuth) bool {
|
||||||
|
if entry == nil || entry.auth == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok := tried[entry.auth.ID]
|
||||||
|
return !ok
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeProviderKeys lowercases, trims, and de-duplicates provider keys while preserving order.
|
||||||
|
func normalizeProviderKeys(providers []string) []string {
|
||||||
|
seen := make(map[string]struct{}, len(providers))
|
||||||
|
out := make([]string, 0, len(providers))
|
||||||
|
for _, provider := range providers {
|
||||||
|
providerKey := strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
if providerKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[providerKey]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[providerKey] = struct{}{}
|
||||||
|
out = append(out, providerKey)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsProvider reports whether provider is present in the normalized provider list.
|
||||||
|
func containsProvider(providers []string, provider string) bool {
|
||||||
|
for _, candidate := range providers {
|
||||||
|
if candidate == provider {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// upsertAuthLocked updates one auth in-place while the scheduler mutex is held.
|
||||||
|
func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) {
|
||||||
|
if auth == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authID := strings.TrimSpace(auth.ID)
|
||||||
|
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||||
|
if authID == "" || providerKey == "" || auth.Disabled {
|
||||||
|
s.removeAuthLocked(authID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if previousProvider := s.authProviders[authID]; previousProvider != "" && previousProvider != providerKey {
|
||||||
|
if previousState := s.providers[previousProvider]; previousState != nil {
|
||||||
|
previousState.removeAuthLocked(authID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
meta := buildScheduledAuthMeta(auth)
|
||||||
|
s.authProviders[authID] = providerKey
|
||||||
|
s.ensureProviderLocked(providerKey).upsertAuthLocked(meta, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeAuthLocked removes one auth from the scheduler while the scheduler mutex is held.
|
||||||
|
func (s *authScheduler) removeAuthLocked(authID string) {
|
||||||
|
if authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if providerKey := s.authProviders[authID]; providerKey != "" {
|
||||||
|
if providerState := s.providers[providerKey]; providerState != nil {
|
||||||
|
providerState.removeAuthLocked(authID)
|
||||||
|
}
|
||||||
|
delete(s.authProviders, authID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureProviderLocked returns the provider scheduler for providerKey, creating it when needed.
|
||||||
|
func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler {
|
||||||
|
if s.providers == nil {
|
||||||
|
s.providers = make(map[string]*providerScheduler)
|
||||||
|
}
|
||||||
|
providerState := s.providers[providerKey]
|
||||||
|
if providerState == nil {
|
||||||
|
providerState = &providerScheduler{
|
||||||
|
providerKey: providerKey,
|
||||||
|
auths: make(map[string]*scheduledAuthMeta),
|
||||||
|
modelShards: make(map[string]*modelScheduler),
|
||||||
|
}
|
||||||
|
s.providers[providerKey] = providerState
|
||||||
|
}
|
||||||
|
return providerState
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildScheduledAuthMeta extracts the scheduling metadata needed for shard bookkeeping.
|
||||||
|
func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta {
|
||||||
|
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||||
|
virtualParent := ""
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
virtualParent = strings.TrimSpace(auth.Attributes["gemini_virtual_parent"])
|
||||||
|
}
|
||||||
|
return &scheduledAuthMeta{
|
||||||
|
auth: auth,
|
||||||
|
providerKey: providerKey,
|
||||||
|
priority: authPriority(auth),
|
||||||
|
virtualParent: virtualParent,
|
||||||
|
websocketEnabled: authWebsocketsEnabled(auth),
|
||||||
|
supportedModelSet: supportedModelSetForAuth(auth.ID),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// supportedModelSetForAuth snapshots the registry models currently registered for an auth.
|
||||||
|
func supportedModelSetForAuth(authID string) map[string]struct{} {
|
||||||
|
authID = strings.TrimSpace(authID)
|
||||||
|
if authID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
models := registry.GetGlobalRegistry().GetModelsForClient(authID)
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
set := make(map[string]struct{}, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
modelKey := canonicalModelKey(model.ID)
|
||||||
|
if modelKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
set[modelKey] = struct{}{}
|
||||||
|
}
|
||||||
|
return set
|
||||||
|
}
|
||||||
|
|
||||||
|
// upsertAuthLocked updates every existing model shard that can reference the auth metadata.
|
||||||
|
func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) {
|
||||||
|
if p == nil || meta == nil || meta.auth == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.auths[meta.auth.ID] = meta
|
||||||
|
for modelKey, shard := range p.modelShards {
|
||||||
|
if shard == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !meta.supportsModel(modelKey) {
|
||||||
|
shard.removeEntryLocked(meta.auth.ID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
shard.upsertEntryLocked(meta, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeAuthLocked removes an auth from all model shards owned by the provider scheduler.
|
||||||
|
func (p *providerScheduler) removeAuthLocked(authID string) {
|
||||||
|
if p == nil || authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(p.auths, authID)
|
||||||
|
for _, shard := range p.modelShards {
|
||||||
|
if shard != nil {
|
||||||
|
shard.removeEntryLocked(authID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureModelLocked returns the shard for modelKey, building it lazily from provider auths.
|
||||||
|
func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler {
|
||||||
|
if p == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
modelKey = canonicalModelKey(modelKey)
|
||||||
|
if shard, ok := p.modelShards[modelKey]; ok && shard != nil {
|
||||||
|
shard.promoteExpiredLocked(now)
|
||||||
|
return shard
|
||||||
|
}
|
||||||
|
shard := &modelScheduler{
|
||||||
|
modelKey: modelKey,
|
||||||
|
entries: make(map[string]*scheduledAuth),
|
||||||
|
readyByPriority: make(map[int]*readyBucket),
|
||||||
|
}
|
||||||
|
for _, meta := range p.auths {
|
||||||
|
if meta == nil || !meta.supportsModel(modelKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
shard.upsertEntryLocked(meta, now)
|
||||||
|
}
|
||||||
|
p.modelShards[modelKey] = shard
|
||||||
|
return shard
|
||||||
|
}
|
||||||
|
|
||||||
|
// supportsModel reports whether the auth metadata currently supports modelKey.
|
||||||
|
func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
|
||||||
|
modelKey = canonicalModelKey(modelKey)
|
||||||
|
if modelKey == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if len(m.supportedModelSet) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok := m.supportedModelSet[modelKey]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// upsertEntryLocked updates or inserts one auth entry and rebuilds indexes when ordering changes.
|
||||||
|
func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) {
|
||||||
|
if m == nil || meta == nil || meta.auth == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entry, ok := m.entries[meta.auth.ID]
|
||||||
|
if !ok || entry == nil {
|
||||||
|
entry = &scheduledAuth{}
|
||||||
|
m.entries[meta.auth.ID] = entry
|
||||||
|
}
|
||||||
|
previousState := entry.state
|
||||||
|
previousNextRetryAt := entry.nextRetryAt
|
||||||
|
previousPriority := 0
|
||||||
|
previousParent := ""
|
||||||
|
previousWebsocketEnabled := false
|
||||||
|
if entry.meta != nil {
|
||||||
|
previousPriority = entry.meta.priority
|
||||||
|
previousParent = entry.meta.virtualParent
|
||||||
|
previousWebsocketEnabled = entry.meta.websocketEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.meta = meta
|
||||||
|
entry.auth = meta.auth
|
||||||
|
entry.nextRetryAt = time.Time{}
|
||||||
|
blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now)
|
||||||
|
switch {
|
||||||
|
case !blocked:
|
||||||
|
entry.state = scheduledStateReady
|
||||||
|
case reason == blockReasonCooldown:
|
||||||
|
entry.state = scheduledStateCooldown
|
||||||
|
entry.nextRetryAt = next
|
||||||
|
case reason == blockReasonDisabled:
|
||||||
|
entry.state = scheduledStateDisabled
|
||||||
|
default:
|
||||||
|
entry.state = scheduledStateBlocked
|
||||||
|
entry.nextRetryAt = next
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.rebuildIndexesLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeEntryLocked deletes one auth entry and rebuilds the shard indexes if needed.
|
||||||
|
func (m *modelScheduler) removeEntryLocked(authID string) {
|
||||||
|
if m == nil || authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := m.entries[authID]; !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(m.entries, authID)
|
||||||
|
m.rebuildIndexesLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// promoteExpiredLocked reevaluates blocked auths whose retry time has elapsed.
|
||||||
|
func (m *modelScheduler) promoteExpiredLocked(now time.Time) {
|
||||||
|
if m == nil || len(m.blocked) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
changed := false
|
||||||
|
for _, entry := range m.blocked {
|
||||||
|
if entry == nil || entry.auth == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now)
|
||||||
|
switch {
|
||||||
|
case !blocked:
|
||||||
|
entry.state = scheduledStateReady
|
||||||
|
entry.nextRetryAt = time.Time{}
|
||||||
|
case reason == blockReasonCooldown:
|
||||||
|
entry.state = scheduledStateCooldown
|
||||||
|
entry.nextRetryAt = next
|
||||||
|
case reason == blockReasonDisabled:
|
||||||
|
entry.state = scheduledStateDisabled
|
||||||
|
entry.nextRetryAt = time.Time{}
|
||||||
|
default:
|
||||||
|
entry.state = scheduledStateBlocked
|
||||||
|
entry.nextRetryAt = next
|
||||||
|
}
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
if changed {
|
||||||
|
m.rebuildIndexesLocked()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickReadyLocked selects the next ready auth from the highest available priority bucket.
|
||||||
|
func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.promoteExpiredLocked(time.Now())
|
||||||
|
priorityReady, okPriority := m.highestReadyPriorityLocked(preferWebsocket, predicate)
|
||||||
|
if !okPriority {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.pickReadyAtPriorityLocked(preferWebsocket, priorityReady, strategy, predicate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// highestReadyPriorityLocked returns the highest priority bucket that still has a matching ready auth.
|
||||||
|
// The caller must ensure expired entries are already promoted when needed.
|
||||||
|
func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predicate func(*scheduledAuth) bool) (int, bool) {
|
||||||
|
if m == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
for _, priority := range m.priorityOrder {
|
||||||
|
bucket := m.readyByPriority[priority]
|
||||||
|
if bucket == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
view := &bucket.all
|
||||||
|
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
||||||
|
view = &bucket.ws
|
||||||
|
}
|
||||||
|
if view.pickFirst(predicate) != nil {
|
||||||
|
return priority, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickReadyAtPriorityLocked selects the next ready auth from a specific priority bucket.
|
||||||
|
// The caller must ensure expired entries are already promoted when needed.
|
||||||
|
func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priority int, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bucket := m.readyByPriority[priority]
|
||||||
|
if bucket == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
view := &bucket.all
|
||||||
|
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
||||||
|
view = &bucket.ws
|
||||||
|
}
|
||||||
|
var picked *scheduledAuth
|
||||||
|
if strategy == schedulerStrategyFillFirst {
|
||||||
|
picked = view.pickFirst(predicate)
|
||||||
|
} else {
|
||||||
|
picked = view.pickRoundRobin(predicate)
|
||||||
|
}
|
||||||
|
if picked == nil || picked.auth == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return picked.auth
|
||||||
|
}
|
||||||
|
|
||||||
|
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
|
||||||
|
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
|
||||||
|
now := time.Now()
|
||||||
|
total, cooldownCount, earliest := m.availabilitySummaryLocked(predicate)
|
||||||
|
if total == 0 {
|
||||||
|
return &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
if cooldownCount == total && !earliest.IsZero() {
|
||||||
|
providerForError := provider
|
||||||
|
if providerForError == "mixed" {
|
||||||
|
providerForError = ""
|
||||||
|
}
|
||||||
|
resetIn := earliest.Sub(now)
|
||||||
|
if resetIn < 0 {
|
||||||
|
resetIn = 0
|
||||||
|
}
|
||||||
|
return newModelCooldownError(model, providerForError, resetIn)
|
||||||
|
}
|
||||||
|
return &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// availabilitySummaryLocked summarizes total candidates, cooldown count, and earliest retry time.
|
||||||
|
func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) {
|
||||||
|
if m == nil {
|
||||||
|
return 0, 0, time.Time{}
|
||||||
|
}
|
||||||
|
total := 0
|
||||||
|
cooldownCount := 0
|
||||||
|
earliest := time.Time{}
|
||||||
|
for _, entry := range m.entries {
|
||||||
|
if predicate != nil && !predicate(entry) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
total++
|
||||||
|
if entry == nil || entry.auth == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if entry.state != scheduledStateCooldown {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cooldownCount++
|
||||||
|
if !entry.nextRetryAt.IsZero() && (earliest.IsZero() || entry.nextRetryAt.Before(earliest)) {
|
||||||
|
earliest = entry.nextRetryAt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total, cooldownCount, earliest
|
||||||
|
}
|
||||||
|
|
||||||
|
// rebuildIndexesLocked reconstructs ready and blocked views from the current entry map.
|
||||||
|
func (m *modelScheduler) rebuildIndexesLocked() {
|
||||||
|
m.readyByPriority = make(map[int]*readyBucket)
|
||||||
|
m.priorityOrder = m.priorityOrder[:0]
|
||||||
|
m.blocked = m.blocked[:0]
|
||||||
|
priorityBuckets := make(map[int][]*scheduledAuth)
|
||||||
|
for _, entry := range m.entries {
|
||||||
|
if entry == nil || entry.auth == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch entry.state {
|
||||||
|
case scheduledStateReady:
|
||||||
|
priority := entry.meta.priority
|
||||||
|
priorityBuckets[priority] = append(priorityBuckets[priority], entry)
|
||||||
|
case scheduledStateCooldown, scheduledStateBlocked:
|
||||||
|
m.blocked = append(m.blocked, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for priority, entries := range priorityBuckets {
|
||||||
|
sort.Slice(entries, func(i, j int) bool {
|
||||||
|
return entries[i].auth.ID < entries[j].auth.ID
|
||||||
|
})
|
||||||
|
m.readyByPriority[priority] = buildReadyBucket(entries)
|
||||||
|
m.priorityOrder = append(m.priorityOrder, priority)
|
||||||
|
}
|
||||||
|
sort.Slice(m.priorityOrder, func(i, j int) bool {
|
||||||
|
return m.priorityOrder[i] > m.priorityOrder[j]
|
||||||
|
})
|
||||||
|
sort.Slice(m.blocked, func(i, j int) bool {
|
||||||
|
left := m.blocked[i]
|
||||||
|
right := m.blocked[j]
|
||||||
|
if left == nil || right == nil {
|
||||||
|
return left != nil
|
||||||
|
}
|
||||||
|
if left.nextRetryAt.Equal(right.nextRetryAt) {
|
||||||
|
return left.auth.ID < right.auth.ID
|
||||||
|
}
|
||||||
|
if left.nextRetryAt.IsZero() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if right.nextRetryAt.IsZero() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return left.nextRetryAt.Before(right.nextRetryAt)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildReadyBucket prepares the general and websocket-only ready views for one priority bucket.
|
||||||
|
func buildReadyBucket(entries []*scheduledAuth) *readyBucket {
|
||||||
|
bucket := &readyBucket{}
|
||||||
|
bucket.all = buildReadyView(entries)
|
||||||
|
wsEntries := make([]*scheduledAuth, 0, len(entries))
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry != nil && entry.meta != nil && entry.meta.websocketEnabled {
|
||||||
|
wsEntries = append(wsEntries, entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bucket.ws = buildReadyView(wsEntries)
|
||||||
|
return bucket
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildReadyView creates either a flat view or a grouped parent/child view for rotation.
|
||||||
|
func buildReadyView(entries []*scheduledAuth) readyView {
|
||||||
|
view := readyView{flat: append([]*scheduledAuth(nil), entries...)}
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return view
|
||||||
|
}
|
||||||
|
groups := make(map[string][]*scheduledAuth)
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry == nil || entry.meta == nil || entry.meta.virtualParent == "" {
|
||||||
|
return view
|
||||||
|
}
|
||||||
|
groups[entry.meta.virtualParent] = append(groups[entry.meta.virtualParent], entry)
|
||||||
|
}
|
||||||
|
if len(groups) <= 1 {
|
||||||
|
return view
|
||||||
|
}
|
||||||
|
view.children = make(map[string]*childBucket, len(groups))
|
||||||
|
view.parentOrder = make([]string, 0, len(groups))
|
||||||
|
for parent := range groups {
|
||||||
|
view.parentOrder = append(view.parentOrder, parent)
|
||||||
|
}
|
||||||
|
sort.Strings(view.parentOrder)
|
||||||
|
for _, parent := range view.parentOrder {
|
||||||
|
view.children[parent] = &childBucket{items: append([]*scheduledAuth(nil), groups[parent]...)}
|
||||||
|
}
|
||||||
|
return view
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickFirst returns the first ready entry that satisfies predicate without advancing cursors.
|
||||||
|
func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||||
|
for _, entry := range v.flat {
|
||||||
|
if predicate == nil || predicate(entry) {
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickRoundRobin returns the next ready entry using flat or grouped round-robin traversal.
|
||||||
|
func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||||
|
if len(v.parentOrder) > 1 && len(v.children) > 0 {
|
||||||
|
return v.pickGroupedRoundRobin(predicate)
|
||||||
|
}
|
||||||
|
if len(v.flat) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
start := 0
|
||||||
|
if len(v.flat) > 0 {
|
||||||
|
start = v.cursor % len(v.flat)
|
||||||
|
}
|
||||||
|
for offset := 0; offset < len(v.flat); offset++ {
|
||||||
|
index := (start + offset) % len(v.flat)
|
||||||
|
entry := v.flat[index]
|
||||||
|
if predicate != nil && !predicate(entry) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
v.cursor = index + 1
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickGroupedRoundRobin rotates across parents first and then within the selected parent.
|
||||||
|
func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||||
|
start := 0
|
||||||
|
if len(v.parentOrder) > 0 {
|
||||||
|
start = v.parentCursor % len(v.parentOrder)
|
||||||
|
}
|
||||||
|
for offset := 0; offset < len(v.parentOrder); offset++ {
|
||||||
|
parentIndex := (start + offset) % len(v.parentOrder)
|
||||||
|
parent := v.parentOrder[parentIndex]
|
||||||
|
child := v.children[parent]
|
||||||
|
if child == nil || len(child.items) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemStart := child.cursor % len(child.items)
|
||||||
|
for itemOffset := 0; itemOffset < len(child.items); itemOffset++ {
|
||||||
|
itemIndex := (itemStart + itemOffset) % len(child.items)
|
||||||
|
entry := child.items[itemIndex]
|
||||||
|
if predicate != nil && !predicate(entry) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
child.cursor = itemIndex + 1
|
||||||
|
v.parentCursor = parentIndex + 1
|
||||||
|
return entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,216 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type schedulerBenchmarkExecutor struct {
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerBenchmarkExecutor) Identifier() string { return e.id }
|
||||||
|
|
||||||
|
func (e schedulerBenchmarkExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerBenchmarkExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerBenchmarkExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerBenchmarkExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e schedulerBenchmarkExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkManagerSetup(b *testing.B, total int, mixed bool, withPriority bool) (*Manager, []string, string) {
|
||||||
|
b.Helper()
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
providers := []string{"gemini"}
|
||||||
|
manager.executors["gemini"] = schedulerBenchmarkExecutor{id: "gemini"}
|
||||||
|
if mixed {
|
||||||
|
providers = []string{"gemini", "claude"}
|
||||||
|
manager.executors["claude"] = schedulerBenchmarkExecutor{id: "claude"}
|
||||||
|
}
|
||||||
|
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
model := "bench-model"
|
||||||
|
for index := 0; index < total; index++ {
|
||||||
|
provider := providers[0]
|
||||||
|
if mixed && index%2 == 1 {
|
||||||
|
provider = providers[1]
|
||||||
|
}
|
||||||
|
auth := &Auth{ID: fmt.Sprintf("bench-%s-%04d", provider, index), Provider: provider}
|
||||||
|
if withPriority {
|
||||||
|
priority := "0"
|
||||||
|
if index%2 == 0 {
|
||||||
|
priority = "10"
|
||||||
|
}
|
||||||
|
auth.Attributes = map[string]string{"priority": priority}
|
||||||
|
}
|
||||||
|
_, errRegister := manager.Register(context.Background(), auth)
|
||||||
|
if errRegister != nil {
|
||||||
|
b.Fatalf("Register(%s) error = %v", auth.ID, errRegister)
|
||||||
|
}
|
||||||
|
reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: model}})
|
||||||
|
}
|
||||||
|
manager.syncScheduler()
|
||||||
|
b.Cleanup(func() {
|
||||||
|
for index := 0; index < total; index++ {
|
||||||
|
provider := providers[0]
|
||||||
|
if mixed && index%2 == 1 {
|
||||||
|
provider = providers[1]
|
||||||
|
}
|
||||||
|
reg.UnregisterClient(fmt.Sprintf("bench-%s-%04d", provider, index))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return manager, providers, model
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkManagerPickNext500(b *testing.B) {
|
||||||
|
manager, _, model := benchmarkManagerSetup(b, 500, false, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
opts := cliproxyexecutor.Options{}
|
||||||
|
tried := map[string]struct{}{}
|
||||||
|
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
|
||||||
|
b.Fatalf("warmup pickNext error = %v", errWarm)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
|
||||||
|
if errPick != nil || auth == nil || exec == nil {
|
||||||
|
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkManagerPickNext1000(b *testing.B) {
|
||||||
|
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
opts := cliproxyexecutor.Options{}
|
||||||
|
tried := map[string]struct{}{}
|
||||||
|
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
|
||||||
|
b.Fatalf("warmup pickNext error = %v", errWarm)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
|
||||||
|
if errPick != nil || auth == nil || exec == nil {
|
||||||
|
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkManagerPickNextPriority500(b *testing.B) {
|
||||||
|
manager, _, model := benchmarkManagerSetup(b, 500, false, true)
|
||||||
|
ctx := context.Background()
|
||||||
|
opts := cliproxyexecutor.Options{}
|
||||||
|
tried := map[string]struct{}{}
|
||||||
|
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
|
||||||
|
b.Fatalf("warmup pickNext error = %v", errWarm)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
|
||||||
|
if errPick != nil || auth == nil || exec == nil {
|
||||||
|
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkManagerPickNextPriority1000(b *testing.B) {
|
||||||
|
manager, _, model := benchmarkManagerSetup(b, 1000, false, true)
|
||||||
|
ctx := context.Background()
|
||||||
|
opts := cliproxyexecutor.Options{}
|
||||||
|
tried := map[string]struct{}{}
|
||||||
|
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
|
||||||
|
b.Fatalf("warmup pickNext error = %v", errWarm)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
|
||||||
|
if errPick != nil || auth == nil || exec == nil {
|
||||||
|
b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkManagerPickNextMixed500(b *testing.B) {
|
||||||
|
manager, providers, model := benchmarkManagerSetup(b, 500, true, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
opts := cliproxyexecutor.Options{}
|
||||||
|
tried := map[string]struct{}{}
|
||||||
|
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
|
||||||
|
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
|
||||||
|
if errPick != nil || auth == nil || exec == nil || provider == "" {
|
||||||
|
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkManagerPickNextMixedPriority500(b *testing.B) {
|
||||||
|
manager, providers, model := benchmarkManagerSetup(b, 500, true, true)
|
||||||
|
ctx := context.Background()
|
||||||
|
opts := cliproxyexecutor.Options{}
|
||||||
|
tried := map[string]struct{}{}
|
||||||
|
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
|
||||||
|
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
|
||||||
|
if errPick != nil || auth == nil || exec == nil || provider == "" {
|
||||||
|
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
|
||||||
|
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
|
||||||
|
ctx := context.Background()
|
||||||
|
opts := cliproxyexecutor.Options{}
|
||||||
|
tried := map[string]struct{}{}
|
||||||
|
if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
|
||||||
|
b.Fatalf("warmup pickNext error = %v", errWarm)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
auth, _, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
|
||||||
|
if errPick != nil || auth == nil {
|
||||||
|
b.Fatalf("pickNext failed: auth=%v err=%v", auth, errPick)
|
||||||
|
}
|
||||||
|
manager.MarkResult(ctx, Result{AuthID: auth.ID, Provider: "gemini", Model: model, Success: true})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,503 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type schedulerTestExecutor struct{}
|
||||||
|
|
||||||
|
func (schedulerTestExecutor) Identifier() string { return "test" }
|
||||||
|
|
||||||
|
func (schedulerTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (schedulerTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (schedulerTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (schedulerTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (schedulerTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type trackingSelector struct {
|
||||||
|
calls int
|
||||||
|
lastAuthID []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *trackingSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||||
|
s.calls++
|
||||||
|
s.lastAuthID = s.lastAuthID[:0]
|
||||||
|
for _, auth := range auths {
|
||||||
|
s.lastAuthID = append(s.lastAuthID, auth.ID)
|
||||||
|
}
|
||||||
|
if len(auths) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return auths[len(auths)-1], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSchedulerForTest(selector Selector, auths ...*Auth) *authScheduler {
|
||||||
|
scheduler := newAuthScheduler(selector)
|
||||||
|
scheduler.rebuild(auths)
|
||||||
|
return scheduler
|
||||||
|
}
|
||||||
|
|
||||||
|
func registerSchedulerModels(t *testing.T, provider string, model string, authIDs ...string) {
|
||||||
|
t.Helper()
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
for _, authID := range authIDs {
|
||||||
|
reg.RegisterClient(authID, provider, []*registry.ModelInfo{{ID: model}})
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
for _, authID := range authIDs {
|
||||||
|
reg.UnregisterClient(authID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSchedulerPick_RoundRobinHighestPriority(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
scheduler := newSchedulerForTest(
|
||||||
|
&RoundRobinSelector{},
|
||||||
|
&Auth{ID: "low", Provider: "gemini", Attributes: map[string]string{"priority": "0"}},
|
||||||
|
&Auth{ID: "high-b", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
|
||||||
|
&Auth{ID: "high-a", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
want := []string{"high-a", "high-b", "high-a"}
|
||||||
|
for index, wantID := range want {
|
||||||
|
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickSingle() #%d auth = nil", index)
|
||||||
|
}
|
||||||
|
if got.ID != wantID {
|
||||||
|
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSchedulerPick_FillFirstSticksToFirstReady(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
scheduler := newSchedulerForTest(
|
||||||
|
&FillFirstSelector{},
|
||||||
|
&Auth{ID: "b", Provider: "gemini"},
|
||||||
|
&Auth{ID: "a", Provider: "gemini"},
|
||||||
|
&Auth{ID: "c", Provider: "gemini"},
|
||||||
|
)
|
||||||
|
|
||||||
|
for index := 0; index < 3; index++ {
|
||||||
|
got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickSingle() #%d auth = nil", index)
|
||||||
|
}
|
||||||
|
if got.ID != "a" {
|
||||||
|
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, "a")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
model := "gemini-2.5-pro"
|
||||||
|
registerSchedulerModels(t, "gemini", model, "cooldown-expired")
|
||||||
|
scheduler := newSchedulerForTest(
|
||||||
|
&RoundRobinSelector{},
|
||||||
|
&Auth{
|
||||||
|
ID: "cooldown-expired",
|
||||||
|
Provider: "gemini",
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
model: {
|
||||||
|
Status: StatusError,
|
||||||
|
Unavailable: true,
|
||||||
|
NextRetryAfter: time.Now().Add(-1 * time.Second),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
got, errPick := scheduler.pickSingle(context.Background(), "gemini", model, cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickSingle() error = %v", errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickSingle() auth = nil")
|
||||||
|
}
|
||||||
|
if got.ID != "cooldown-expired" {
|
||||||
|
t.Fatalf("pickSingle() auth.ID = %q, want %q", got.ID, "cooldown-expired")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
registerSchedulerModels(t, "gemini-cli", "gemini-2.5-pro", "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2")
|
||||||
|
scheduler := newSchedulerForTest(
|
||||||
|
&RoundRobinSelector{},
|
||||||
|
&Auth{ID: "cred-a::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
|
||||||
|
&Auth{ID: "cred-a::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-a"}},
|
||||||
|
&Auth{ID: "cred-b::proj-1", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
|
||||||
|
&Auth{ID: "cred-b::proj-2", Provider: "gemini-cli", Attributes: map[string]string{"gemini_virtual_parent": "cred-b"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
wantParents := []string{"cred-a", "cred-b", "cred-a", "cred-b"}
|
||||||
|
wantIDs := []string{"cred-a::proj-1", "cred-b::proj-1", "cred-a::proj-2", "cred-b::proj-2"}
|
||||||
|
for index := range wantIDs {
|
||||||
|
got, errPick := scheduler.pickSingle(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickSingle() #%d auth = nil", index)
|
||||||
|
}
|
||||||
|
if got.ID != wantIDs[index] {
|
||||||
|
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||||
|
}
|
||||||
|
if got.Attributes["gemini_virtual_parent"] != wantParents[index] {
|
||||||
|
t.Fatalf("pickSingle() #%d parent = %q, want %q", index, got.Attributes["gemini_virtual_parent"], wantParents[index])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
scheduler := newSchedulerForTest(
|
||||||
|
&RoundRobinSelector{},
|
||||||
|
&Auth{ID: "codex-http", Provider: "codex"},
|
||||||
|
&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
|
||||||
|
&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: map[string]string{"websockets": "true"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||||
|
want := []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"}
|
||||||
|
for index, wantID := range want {
|
||||||
|
got, errPick := scheduler.pickSingle(ctx, "codex", "", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickSingle() #%d error = %v", index, errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickSingle() #%d auth = nil", index)
|
||||||
|
}
|
||||||
|
if got.ID != wantID {
|
||||||
|
t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
scheduler := newSchedulerForTest(
|
||||||
|
&RoundRobinSelector{},
|
||||||
|
&Auth{ID: "gemini-a", Provider: "gemini"},
|
||||||
|
&Auth{ID: "gemini-b", Provider: "gemini"},
|
||||||
|
&Auth{ID: "claude-a", Provider: "claude"},
|
||||||
|
)
|
||||||
|
|
||||||
|
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
||||||
|
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
||||||
|
for index := range wantProviders {
|
||||||
|
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickMixed() #%d auth = nil", index)
|
||||||
|
}
|
||||||
|
if provider != wantProviders[index] {
|
||||||
|
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||||
|
}
|
||||||
|
if got.ID != wantIDs[index] {
|
||||||
|
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
model := "gpt-default"
|
||||||
|
registerSchedulerModels(t, "provider-low", model, "low")
|
||||||
|
registerSchedulerModels(t, "provider-high-a", model, "high-a")
|
||||||
|
registerSchedulerModels(t, "provider-high-b", model, "high-b")
|
||||||
|
|
||||||
|
scheduler := newSchedulerForTest(
|
||||||
|
&RoundRobinSelector{},
|
||||||
|
&Auth{ID: "low", Provider: "provider-low", Attributes: map[string]string{"priority": "4"}},
|
||||||
|
&Auth{ID: "high-a", Provider: "provider-high-a", Attributes: map[string]string{"priority": "7"}},
|
||||||
|
&Auth{ID: "high-b", Provider: "provider-high-b", Attributes: map[string]string{"priority": "7"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
providers := []string{"provider-low", "provider-high-a", "provider-high-b"}
|
||||||
|
wantProviders := []string{"provider-high-a", "provider-high-b", "provider-high-a", "provider-high-b"}
|
||||||
|
wantIDs := []string{"high-a", "high-b", "high-a", "high-b"}
|
||||||
|
for index := range wantProviders {
|
||||||
|
got, provider, errPick := scheduler.pickMixed(context.Background(), providers, model, cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickMixed() #%d auth = nil", index)
|
||||||
|
}
|
||||||
|
if provider != wantProviders[index] {
|
||||||
|
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||||
|
}
|
||||||
|
if got.ID != wantIDs[index] {
|
||||||
|
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
manager.executors["gemini"] = schedulerTestExecutor{}
|
||||||
|
manager.executors["claude"] = schedulerTestExecutor{}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(gemini-a) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(gemini-b) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
||||||
|
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
||||||
|
for index := range wantProviders {
|
||||||
|
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickNextMixed() #%d auth = nil", index)
|
||||||
|
}
|
||||||
|
if provider != wantProviders[index] {
|
||||||
|
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||||
|
}
|
||||||
|
if got.ID != wantIDs[index] {
|
||||||
|
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerCustomSelector_FallsBackToLegacyPath(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &trackingSelector{}
|
||||||
|
manager := NewManager(nil, selector, nil)
|
||||||
|
manager.executors["gemini"] = schedulerTestExecutor{}
|
||||||
|
manager.auths["auth-a"] = &Auth{ID: "auth-a", Provider: "gemini"}
|
||||||
|
manager.auths["auth-b"] = &Auth{ID: "auth-b", Provider: "gemini"}
|
||||||
|
|
||||||
|
got, _, errPick := manager.pickNext(context.Background(), "gemini", "", cliproxyexecutor.Options{}, map[string]struct{}{})
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickNext() error = %v", errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickNext() auth = nil")
|
||||||
|
}
|
||||||
|
if selector.calls != 1 {
|
||||||
|
t.Fatalf("selector.calls = %d, want %d", selector.calls, 1)
|
||||||
|
}
|
||||||
|
if len(selector.lastAuthID) != 2 {
|
||||||
|
t.Fatalf("len(selector.lastAuthID) = %d, want %d", len(selector.lastAuthID), 2)
|
||||||
|
}
|
||||||
|
if got.ID != selector.lastAuthID[len(selector.lastAuthID)-1] {
|
||||||
|
t.Fatalf("pickNext() auth.ID = %q, want selector-picked %q", got.ID, selector.lastAuthID[len(selector.lastAuthID)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_InitializesSchedulerForBuiltInSelector(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
if manager.scheduler == nil {
|
||||||
|
t.Fatalf("manager.scheduler = nil")
|
||||||
|
}
|
||||||
|
if manager.scheduler.strategy != schedulerStrategyRoundRobin {
|
||||||
|
t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyRoundRobin)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.SetSelector(&FillFirstSelector{})
|
||||||
|
if manager.scheduler.strategy != schedulerStrategyFillFirst {
|
||||||
|
t.Fatalf("manager.scheduler.strategy = %v, want %v", manager.scheduler.strategy, schedulerStrategyFillFirst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SchedulerTracksRegisterAndUpdate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(auth-b) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(auth-a) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("scheduler.pickSingle() error = %v", errPick)
|
||||||
|
}
|
||||||
|
if got == nil || got.ID != "auth-a" {
|
||||||
|
t.Fatalf("scheduler.pickSingle() auth = %v, want auth-a", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, errUpdate := manager.Update(context.Background(), &Auth{ID: "auth-a", Provider: "gemini", Disabled: true}); errUpdate != nil {
|
||||||
|
t.Fatalf("Update(auth-a) error = %v", errUpdate)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("scheduler.pickSingle() after update error = %v", errPick)
|
||||||
|
}
|
||||||
|
if got == nil || got.ID != "auth-b" {
|
||||||
|
t.Fatalf("scheduler.pickSingle() after update auth = %v, want auth-b", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
manager.executors["gemini"] = schedulerTestExecutor{}
|
||||||
|
manager.executors["claude"] = schedulerTestExecutor{}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(gemini-a) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-b", Provider: "gemini"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(gemini-b) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
||||||
|
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
||||||
|
for index := range wantProviders {
|
||||||
|
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickNextMixed() #%d error = %v", index, errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickNextMixed() #%d auth = nil", index)
|
||||||
|
}
|
||||||
|
if provider != wantProviders[index] {
|
||||||
|
t.Fatalf("pickNextMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||||
|
}
|
||||||
|
if got.ID != wantIDs[index] {
|
||||||
|
t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_PickNextMixed_SkipsProvidersWithoutExecutors(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
manager.executors["claude"] = schedulerTestExecutor{}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "gemini-a", Provider: "gemini"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(gemini-a) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "claude-a", Provider: "claude"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("pickNextMixed() error = %v", errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("pickNextMixed() auth = nil")
|
||||||
|
}
|
||||||
|
if provider != "claude" {
|
||||||
|
t.Fatalf("pickNextMixed() provider = %q, want %q", provider, "claude")
|
||||||
|
}
|
||||||
|
if got.ID != "claude-a" {
|
||||||
|
t.Fatalf("pickNextMixed() auth.ID = %q, want %q", got.ID, "claude-a")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_SchedulerTracksMarkResultCooldownAndRecovery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("auth-a", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
reg.RegisterClient("auth-b", "gemini", []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
reg.UnregisterClient("auth-a")
|
||||||
|
reg.UnregisterClient("auth-b")
|
||||||
|
})
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-a", Provider: "gemini"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(auth-a) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
if _, errRegister := manager.Register(context.Background(), &Auth{ID: "auth-b", Provider: "gemini"}); errRegister != nil {
|
||||||
|
t.Fatalf("Register(auth-b) error = %v", errRegister)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.MarkResult(context.Background(), Result{
|
||||||
|
AuthID: "auth-a",
|
||||||
|
Provider: "gemini",
|
||||||
|
Model: "test-model",
|
||||||
|
Success: false,
|
||||||
|
Error: &Error{HTTPStatus: 429, Message: "quota"},
|
||||||
|
})
|
||||||
|
|
||||||
|
got, errPick := manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("scheduler.pickSingle() after cooldown error = %v", errPick)
|
||||||
|
}
|
||||||
|
if got == nil || got.ID != "auth-b" {
|
||||||
|
t.Fatalf("scheduler.pickSingle() after cooldown auth = %v, want auth-b", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.MarkResult(context.Background(), Result{
|
||||||
|
AuthID: "auth-a",
|
||||||
|
Provider: "gemini",
|
||||||
|
Model: "test-model",
|
||||||
|
Success: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
seen := make(map[string]struct{}, 2)
|
||||||
|
for index := 0; index < 2; index++ {
|
||||||
|
got, errPick = manager.scheduler.pickSingle(context.Background(), "gemini", "test-model", cliproxyexecutor.Options{}, nil)
|
||||||
|
if errPick != nil {
|
||||||
|
t.Fatalf("scheduler.pickSingle() after recovery #%d error = %v", index, errPick)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("scheduler.pickSingle() after recovery #%d auth = nil", index)
|
||||||
|
}
|
||||||
|
seen[got.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(seen) != 2 {
|
||||||
|
t.Fatalf("len(seen) = %d, want %d", len(seen), 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,16 +1,13 @@
|
|||||||
package cliproxy
|
package cliproxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// defaultRoundTripperProvider returns a per-auth HTTP RoundTripper based on
|
// defaultRoundTripperProvider returns a per-auth HTTP RoundTripper based on
|
||||||
@@ -39,35 +36,12 @@ func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http.
|
|||||||
if rt != nil {
|
if rt != nil {
|
||||||
return rt
|
return rt
|
||||||
}
|
}
|
||||||
// Parse the proxy URL to determine the scheme.
|
transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
|
||||||
proxyURL, errParse := url.Parse(proxyStr)
|
if errBuild != nil {
|
||||||
if errParse != nil {
|
log.Errorf("%v", errBuild)
|
||||||
log.Errorf("parse proxy URL failed: %v", errParse)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var transport *http.Transport
|
if transport == nil {
|
||||||
// Handle different proxy schemes.
|
|
||||||
if proxyURL.Scheme == "socks5" {
|
|
||||||
// Configure SOCKS5 proxy with optional authentication.
|
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
proxyAuth := &proxy.Auth{User: username, Password: password}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
// Set up a custom transport using the SOCKS5 dialer.
|
|
||||||
transport = &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
|
||||||
// Configure HTTP or HTTPS proxy.
|
|
||||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
|
||||||
} else {
|
|
||||||
log.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
p.mu.Lock()
|
p.mu.Lock()
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package cliproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRoundTripperForDirectBypassesProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
provider := newDefaultRoundTripperProvider()
|
||||||
|
rt := provider.RoundTripperFor(&coreauth.Auth{ProxyURL: "direct"})
|
||||||
|
transport, ok := rt.(*http.Transport)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("transport type = %T, want *http.Transport", rt)
|
||||||
|
}
|
||||||
|
if transport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
+22
-1
@@ -312,6 +312,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
|
|||||||
// This operation may block on network calls, but the auth configuration
|
// This operation may block on network calls, but the auth configuration
|
||||||
// is already effective at this point.
|
// is already effective at this point.
|
||||||
s.registerModelsForAuth(auth)
|
s.registerModelsForAuth(auth)
|
||||||
|
|
||||||
|
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
|
||||||
|
// from the now-populated global model registry. Without this, newly added auths
|
||||||
|
// have an empty supportedModelSet (because Register/Update upserts into the
|
||||||
|
// scheduler before registerModelsForAuth runs) and are invisible to the scheduler.
|
||||||
|
s.coreManager.RefreshSchedulerEntry(auth.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
||||||
@@ -823,7 +829,22 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
}
|
}
|
||||||
models = applyExcludedModels(models, excluded)
|
models = applyExcludedModels(models, excluded)
|
||||||
case "codex":
|
case "codex":
|
||||||
models = registry.GetOpenAIModels()
|
codexPlanType := ""
|
||||||
|
if a.Attributes != nil {
|
||||||
|
codexPlanType = strings.TrimSpace(a.Attributes["plan_type"])
|
||||||
|
}
|
||||||
|
switch strings.ToLower(codexPlanType) {
|
||||||
|
case "pro":
|
||||||
|
models = registry.GetCodexProModels()
|
||||||
|
case "plus":
|
||||||
|
models = registry.GetCodexPlusModels()
|
||||||
|
case "team":
|
||||||
|
models = registry.GetCodexTeamModels()
|
||||||
|
case "free":
|
||||||
|
models = registry.GetCodexFreeModels()
|
||||||
|
default:
|
||||||
|
models = registry.GetCodexProModels()
|
||||||
|
}
|
||||||
if entry := s.resolveConfigCodexKey(a); entry != nil {
|
if entry := s.resolveConfigCodexKey(a); entry != nil {
|
||||||
if len(entry.Models) > 0 {
|
if len(entry.Models) > 0 {
|
||||||
models = buildCodexConfigModels(entry)
|
models = buildCodexConfigModels(entry)
|
||||||
|
|||||||
@@ -0,0 +1,139 @@
|
|||||||
|
package proxyutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mode describes how a proxy setting should be interpreted.
|
||||||
|
type Mode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ModeInherit means no explicit proxy behavior was configured.
|
||||||
|
ModeInherit Mode = iota
|
||||||
|
// ModeDirect means outbound requests must bypass proxies explicitly.
|
||||||
|
ModeDirect
|
||||||
|
// ModeProxy means a concrete proxy URL was configured.
|
||||||
|
ModeProxy
|
||||||
|
// ModeInvalid means the proxy setting is present but malformed or unsupported.
|
||||||
|
ModeInvalid
|
||||||
|
)
|
||||||
|
|
||||||
|
// Setting is the normalized interpretation of a proxy configuration value.
|
||||||
|
type Setting struct {
|
||||||
|
Raw string
|
||||||
|
Mode Mode
|
||||||
|
URL *url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse normalizes a proxy configuration value into inherit, direct, or proxy modes.
|
||||||
|
func Parse(raw string) (Setting, error) {
|
||||||
|
trimmed := strings.TrimSpace(raw)
|
||||||
|
setting := Setting{Raw: trimmed}
|
||||||
|
|
||||||
|
if trimmed == "" {
|
||||||
|
setting.Mode = ModeInherit
|
||||||
|
return setting, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.EqualFold(trimmed, "direct") || strings.EqualFold(trimmed, "none") {
|
||||||
|
setting.Mode = ModeDirect
|
||||||
|
return setting, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedURL, errParse := url.Parse(trimmed)
|
||||||
|
if errParse != nil {
|
||||||
|
setting.Mode = ModeInvalid
|
||||||
|
return setting, fmt.Errorf("parse proxy URL failed: %w", errParse)
|
||||||
|
}
|
||||||
|
if parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||||
|
setting.Mode = ModeInvalid
|
||||||
|
return setting, fmt.Errorf("proxy URL missing scheme/host")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch parsedURL.Scheme {
|
||||||
|
case "socks5", "http", "https":
|
||||||
|
setting.Mode = ModeProxy
|
||||||
|
setting.URL = parsedURL
|
||||||
|
return setting, nil
|
||||||
|
default:
|
||||||
|
setting.Mode = ModeInvalid
|
||||||
|
return setting, fmt.Errorf("unsupported proxy scheme: %s", parsedURL.Scheme)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDirectTransport returns a transport that bypasses environment proxies.
|
||||||
|
func NewDirectTransport() *http.Transport {
|
||||||
|
if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
|
||||||
|
clone := transport.Clone()
|
||||||
|
clone.Proxy = nil
|
||||||
|
return clone
|
||||||
|
}
|
||||||
|
return &http.Transport{Proxy: nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting.
|
||||||
|
func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
|
||||||
|
setting, errParse := Parse(raw)
|
||||||
|
if errParse != nil {
|
||||||
|
return nil, setting.Mode, errParse
|
||||||
|
}
|
||||||
|
|
||||||
|
switch setting.Mode {
|
||||||
|
case ModeInherit:
|
||||||
|
return nil, setting.Mode, nil
|
||||||
|
case ModeDirect:
|
||||||
|
return NewDirectTransport(), setting.Mode, nil
|
||||||
|
case ModeProxy:
|
||||||
|
if setting.URL.Scheme == "socks5" {
|
||||||
|
var proxyAuth *proxy.Auth
|
||||||
|
if setting.URL.User != nil {
|
||||||
|
username := setting.URL.User.Username()
|
||||||
|
password, _ := setting.URL.User.Password()
|
||||||
|
proxyAuth = &proxy.Auth{User: username, Password: password}
|
||||||
|
}
|
||||||
|
dialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
|
||||||
|
if errSOCKS5 != nil {
|
||||||
|
return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
|
||||||
|
}
|
||||||
|
return &http.Transport{
|
||||||
|
Proxy: nil,
|
||||||
|
DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return dialer.Dial(network, addr)
|
||||||
|
},
|
||||||
|
}, setting.Mode, nil
|
||||||
|
}
|
||||||
|
return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil
|
||||||
|
default:
|
||||||
|
return nil, setting.Mode, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildDialer constructs a proxy dialer for settings that operate at the connection layer.
|
||||||
|
func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
|
||||||
|
setting, errParse := Parse(raw)
|
||||||
|
if errParse != nil {
|
||||||
|
return nil, setting.Mode, errParse
|
||||||
|
}
|
||||||
|
|
||||||
|
switch setting.Mode {
|
||||||
|
case ModeInherit:
|
||||||
|
return nil, setting.Mode, nil
|
||||||
|
case ModeDirect:
|
||||||
|
return proxy.Direct, setting.Mode, nil
|
||||||
|
case ModeProxy:
|
||||||
|
dialer, errDialer := proxy.FromURL(setting.URL, proxy.Direct)
|
||||||
|
if errDialer != nil {
|
||||||
|
return nil, setting.Mode, fmt.Errorf("create proxy dialer failed: %w", errDialer)
|
||||||
|
}
|
||||||
|
return dialer, setting.Mode, nil
|
||||||
|
default:
|
||||||
|
return nil, setting.Mode, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,89 @@
|
|||||||
|
package proxyutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParse(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want Mode
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{name: "inherit", input: "", want: ModeInherit},
|
||||||
|
{name: "direct", input: "direct", want: ModeDirect},
|
||||||
|
{name: "none", input: "none", want: ModeDirect},
|
||||||
|
{name: "http", input: "http://proxy.example.com:8080", want: ModeProxy},
|
||||||
|
{name: "https", input: "https://proxy.example.com:8443", want: ModeProxy},
|
||||||
|
{name: "socks5", input: "socks5://proxy.example.com:1080", want: ModeProxy},
|
||||||
|
{name: "invalid", input: "bad-value", want: ModeInvalid, wantErr: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
setting, errParse := Parse(tt.input)
|
||||||
|
if tt.wantErr && errParse == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !tt.wantErr && errParse != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", errParse)
|
||||||
|
}
|
||||||
|
if setting.Mode != tt.want {
|
||||||
|
t.Fatalf("mode = %d, want %d", setting.Mode, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildHTTPTransportDirectBypassesProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
transport, mode, errBuild := BuildHTTPTransport("direct")
|
||||||
|
if errBuild != nil {
|
||||||
|
t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
|
||||||
|
}
|
||||||
|
if mode != ModeDirect {
|
||||||
|
t.Fatalf("mode = %d, want %d", mode, ModeDirect)
|
||||||
|
}
|
||||||
|
if transport == nil {
|
||||||
|
t.Fatal("expected transport, got nil")
|
||||||
|
}
|
||||||
|
if transport.Proxy != nil {
|
||||||
|
t.Fatal("expected direct transport to disable proxy function")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildHTTPTransportHTTPProxy(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
transport, mode, errBuild := BuildHTTPTransport("http://proxy.example.com:8080")
|
||||||
|
if errBuild != nil {
|
||||||
|
t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
|
||||||
|
}
|
||||||
|
if mode != ModeProxy {
|
||||||
|
t.Fatalf("mode = %d, want %d", mode, ModeProxy)
|
||||||
|
}
|
||||||
|
if transport == nil {
|
||||||
|
t.Fatal("expected transport, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
|
||||||
|
if errRequest != nil {
|
||||||
|
t.Fatalf("http.NewRequest returned error: %v", errRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyURL, errProxy := transport.Proxy(req)
|
||||||
|
if errProxy != nil {
|
||||||
|
t.Fatalf("transport.Proxy returned error: %v", errProxy)
|
||||||
|
}
|
||||||
|
if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" {
|
||||||
|
t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user