Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ce53d3a287 | ||
|
|
4cc99e7449 | ||
|
|
71773fe032 | ||
|
|
a1e0fa0f39 | ||
|
|
fc2f0b6983 | ||
|
|
5c9997cdac | ||
|
|
f5941a411c | ||
|
|
ba672bbd07 | ||
|
|
d9c6627a53 | ||
|
|
2e9907c3ac | ||
|
|
90afb9cb73 | ||
|
|
d0cc0cd9a5 | ||
|
|
338321e553 |
@@ -150,6 +150,10 @@ A Windows tray application implemented using PowerShell scripts, without relying
|
||||
|
||||
A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
|
||||
|
||||
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
|
||||
|
||||
Browser extension for one-stop management of New API-compatible relay site accounts, featuring balance and usage dashboards, auto check-in, one-click key export to common apps, in-page API availability testing, and channel/model sync and redirection. It integrates with CLIProxyAPI through the Management API for one-click provider import and config sync.
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
|
||||
@@ -149,6 +149,10 @@ Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方
|
||||
|
||||
一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
|
||||
|
||||
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
|
||||
|
||||
用于一站式管理 New API 兼容中转站账号的浏览器扩展,提供余额与用量看板、自动签到、密钥一键导出到常用应用、网页内 API 可用性测试,以及渠道与模型同步和重定向。支持通过 CLIProxyAPI Management API 一键导入 Provider 与同步配置。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
|
||||
@@ -1266,6 +1266,10 @@ func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else if system.Type == gjson.String && system.String() != "" {
|
||||
partJSON := `{"type":"text","cache_control":{"type":"ephemeral"}}`
|
||||
partJSON, _ = sjson.Set(partJSON, "text", system.String())
|
||||
result += "," + partJSON
|
||||
}
|
||||
result += "]"
|
||||
|
||||
|
||||
@@ -980,3 +980,87 @@ func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *te
|
||||
t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 1: String system prompt is preserved and converted to a content block
|
||||
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
|
||||
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
system := gjson.GetBytes(out, "system")
|
||||
if !system.IsArray() {
|
||||
t.Fatalf("system should be an array, got %s", system.Type)
|
||||
}
|
||||
|
||||
blocks := system.Array()
|
||||
if len(blocks) != 3 {
|
||||
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
|
||||
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
|
||||
}
|
||||
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
|
||||
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
|
||||
}
|
||||
if blocks[2].Get("text").String() != "You are a helpful assistant." {
|
||||
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
|
||||
}
|
||||
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
|
||||
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 2: Strict mode drops the string system prompt
|
||||
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
|
||||
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, true)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 2 {
|
||||
t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks))
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 3: Empty string system prompt does not produce a spurious block
|
||||
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
|
||||
payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 2 {
|
||||
t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks))
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 4: Array system prompt is unaffected by the string handling
|
||||
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
|
||||
payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 3 {
|
||||
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||
}
|
||||
if blocks[2].Get("text").String() != "Be concise." {
|
||||
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
// Test case 5: Special characters in string system prompt survive conversion
|
||||
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
|
||||
payload := []byte(`{"system":"Use <xml> tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
out := checkSystemInstructionsWithMode(payload, false)
|
||||
|
||||
blocks := gjson.GetBytes(out, "system").Array()
|
||||
if len(blocks) != 3 {
|
||||
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
|
||||
}
|
||||
if blocks[2].Get("text").String() != `Use <xml> tags & "quotes" in output.` {
|
||||
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,21 +147,21 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
content := m.Get("content")
|
||||
|
||||
if (role == "system" || role == "developer") && len(arr) > 1 {
|
||||
// system -> system_instruction as a user message style
|
||||
// system -> systemInstruction as a user message style
|
||||
if content.Type == gjson.String {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.String())
|
||||
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String())
|
||||
systemPartIndex++
|
||||
} else if content.IsObject() && content.Get("type").String() == "text" {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), content.Get("text").String())
|
||||
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
|
||||
systemPartIndex++
|
||||
} else if content.IsArray() {
|
||||
contents := content.Array()
|
||||
if len(contents) > 0 {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.role", "user")
|
||||
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
|
||||
for j := 0; j < len(contents); j++ {
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("system_instruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
|
||||
systemPartIndex++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
if instructions := root.Get("instructions"); instructions.Exists() {
|
||||
systemInstr := `{"parts":[{"text":""}]}`
|
||||
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
|
||||
}
|
||||
|
||||
// Convert input messages to Gemini contents format
|
||||
@@ -119,7 +119,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
if strings.EqualFold(itemRole, "system") {
|
||||
if contentArray := item.Get("content"); contentArray.Exists() {
|
||||
systemInstr := ""
|
||||
if systemInstructionResult := gjson.Get(out, "system_instruction"); systemInstructionResult.Exists() {
|
||||
if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() {
|
||||
systemInstr = systemInstructionResult.Raw
|
||||
} else {
|
||||
systemInstr = `{"parts":[]}`
|
||||
@@ -140,7 +140,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
if systemInstr != `{"parts":[]}` {
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
|
||||
}
|
||||
}
|
||||
continue
|
||||
|
||||
@@ -34,6 +34,8 @@ const (
|
||||
wsTurnStateHeader = "x-codex-turn-state"
|
||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||
wsPayloadLogMaxSize = 2048
|
||||
wsBodyLogMaxSize = 64 * 1024
|
||||
wsBodyLogTruncated = "\n[websocket log truncated]\n"
|
||||
)
|
||||
|
||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||
@@ -825,18 +827,71 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
|
||||
if builder == nil {
|
||||
return
|
||||
}
|
||||
if builder.Len() >= wsBodyLogMaxSize {
|
||||
return
|
||||
}
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return
|
||||
}
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n")
|
||||
if !appendWebsocketLogString(builder, "\n") {
|
||||
return
|
||||
}
|
||||
}
|
||||
builder.WriteString("websocket.")
|
||||
builder.WriteString(eventType)
|
||||
builder.WriteString("\n")
|
||||
builder.Write(trimmedPayload)
|
||||
builder.WriteString("\n")
|
||||
if !appendWebsocketLogString(builder, "websocket.") {
|
||||
return
|
||||
}
|
||||
if !appendWebsocketLogString(builder, eventType) {
|
||||
return
|
||||
}
|
||||
if !appendWebsocketLogString(builder, "\n") {
|
||||
return
|
||||
}
|
||||
if !appendWebsocketLogBytes(builder, trimmedPayload, len(wsBodyLogTruncated)) {
|
||||
appendWebsocketLogString(builder, wsBodyLogTruncated)
|
||||
return
|
||||
}
|
||||
appendWebsocketLogString(builder, "\n")
|
||||
}
|
||||
|
||||
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
|
||||
if builder == nil {
|
||||
return false
|
||||
}
|
||||
remaining := wsBodyLogMaxSize - builder.Len()
|
||||
if remaining <= 0 {
|
||||
return false
|
||||
}
|
||||
if len(value) <= remaining {
|
||||
builder.WriteString(value)
|
||||
return true
|
||||
}
|
||||
builder.WriteString(value[:remaining])
|
||||
return false
|
||||
}
|
||||
|
||||
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
|
||||
if builder == nil {
|
||||
return false
|
||||
}
|
||||
remaining := wsBodyLogMaxSize - builder.Len()
|
||||
if remaining <= 0 {
|
||||
return false
|
||||
}
|
||||
if len(value) <= remaining {
|
||||
builder.Write(value)
|
||||
return true
|
||||
}
|
||||
limit := remaining - reserveForSuffix
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
if limit > len(value) {
|
||||
limit = len(value)
|
||||
}
|
||||
builder.Write(value[:limit])
|
||||
return false
|
||||
}
|
||||
|
||||
func websocketPayloadEventType(payload []byte) string {
|
||||
|
||||
@@ -266,6 +266,34 @@ func TestAppendWebsocketEvent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
|
||||
|
||||
appendWebsocketEvent(&builder, "request", payload)
|
||||
|
||||
got := builder.String()
|
||||
if len(got) > wsBodyLogMaxSize {
|
||||
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
|
||||
}
|
||||
if !strings.Contains(got, wsBodyLogTruncated) {
|
||||
t.Fatalf("expected truncation marker in body log")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
|
||||
initial := builder.String()
|
||||
|
||||
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
|
||||
|
||||
if builder.String() != initial {
|
||||
t.Fatalf("builder grew after reaching limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
@@ -213,6 +213,26 @@ func (m *Manager) syncScheduler() {
|
||||
m.syncSchedulerFromSnapshot(m.snapshotAuths())
|
||||
}
|
||||
|
||||
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
|
||||
// supportedModelSet is rebuilt from the current global model registry state.
|
||||
// This must be called after models have been registered for a newly added auth,
|
||||
// because the initial scheduler.upsertAuth during Register/Update runs before
|
||||
// registerModelsForAuth and therefore snapshots an empty model set.
|
||||
func (m *Manager) RefreshSchedulerEntry(authID string) {
|
||||
if m == nil || m.scheduler == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.RLock()
|
||||
auth, ok := m.auths[authID]
|
||||
if !ok || auth == nil {
|
||||
m.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
snapshot := auth.Clone()
|
||||
m.mu.RUnlock()
|
||||
m.scheduler.upsertAuth(snapshot)
|
||||
}
|
||||
|
||||
func (m *Manager) SetSelector(selector Selector) {
|
||||
if m == nil {
|
||||
return
|
||||
@@ -2038,6 +2058,10 @@ func shouldRetrySchedulerPick(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var cooldownErr *modelCooldownError
|
||||
if errors.As(err, &cooldownErr) {
|
||||
return true
|
||||
}
|
||||
var authErr *Error
|
||||
if !errors.As(err, &authErr) || authErr == nil {
|
||||
return false
|
||||
|
||||
163
sdk/cliproxy/auth/conductor_scheduler_refresh_test.go
Normal file
163
sdk/cliproxy/auth/conductor_scheduler_refresh_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
type schedulerProviderTestExecutor struct {
|
||||
provider string
|
||||
}
|
||||
|
||||
func (e schedulerProviderTestExecutor) Identifier() string { return e.provider }
|
||||
|
||||
func (e schedulerProviderTestExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e schedulerProviderTestExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e schedulerProviderTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e schedulerProviderTestExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e schedulerProviderTestExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
prime func(*Manager, *Auth) error
|
||||
}{
|
||||
{
|
||||
name: "register",
|
||||
prime: func(manager *Manager, auth *Auth) error {
|
||||
_, errRegister := manager.Register(ctx, auth)
|
||||
return errRegister
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update",
|
||||
prime: func(manager *Manager, auth *Auth) error {
|
||||
_, errRegister := manager.Register(ctx, auth)
|
||||
if errRegister != nil {
|
||||
return errRegister
|
||||
}
|
||||
updated := auth.Clone()
|
||||
updated.Metadata = map[string]any{"updated": true}
|
||||
_, errUpdate := manager.Update(ctx, updated)
|
||||
return errUpdate
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
auth := &Auth{
|
||||
ID: "refresh-entry-" + testCase.name,
|
||||
Provider: "gemini",
|
||||
}
|
||||
if errPrime := testCase.prime(manager, auth); errPrime != nil {
|
||||
t.Fatalf("prime auth %s: %v", testCase.name, errPrime)
|
||||
}
|
||||
|
||||
registerSchedulerModels(t, "gemini", "scheduler-refresh-model", auth.ID)
|
||||
|
||||
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
|
||||
var authErr *Error
|
||||
if !errors.As(errPick, &authErr) || authErr == nil {
|
||||
t.Fatalf("pickSingle() before refresh error = %v, want auth_not_found", errPick)
|
||||
}
|
||||
if authErr.Code != "auth_not_found" {
|
||||
t.Fatalf("pickSingle() before refresh code = %q, want %q", authErr.Code, "auth_not_found")
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("pickSingle() before refresh auth = %v, want nil", got)
|
||||
}
|
||||
|
||||
manager.RefreshSchedulerEntry(auth.ID)
|
||||
|
||||
got, errPick = manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickSingle() after refresh error = %v", errPick)
|
||||
}
|
||||
if got == nil || got.ID != auth.ID {
|
||||
t.Fatalf("pickSingle() after refresh auth = %v, want %q", got, auth.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
manager.RegisterExecutor(schedulerProviderTestExecutor{provider: "gemini"})
|
||||
|
||||
registerSchedulerModels(t, "gemini", "scheduler-cooldown-rebuild-model", "cooldown-stale-old")
|
||||
|
||||
oldAuth := &Auth{
|
||||
ID: "cooldown-stale-old",
|
||||
Provider: "gemini",
|
||||
}
|
||||
if _, errRegister := manager.Register(ctx, oldAuth); errRegister != nil {
|
||||
t.Fatalf("register old auth: %v", errRegister)
|
||||
}
|
||||
|
||||
manager.MarkResult(ctx, Result{
|
||||
AuthID: oldAuth.ID,
|
||||
Provider: "gemini",
|
||||
Model: "scheduler-cooldown-rebuild-model",
|
||||
Success: false,
|
||||
Error: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"},
|
||||
})
|
||||
|
||||
newAuth := &Auth{
|
||||
ID: "cooldown-stale-new",
|
||||
Provider: "gemini",
|
||||
}
|
||||
if _, errRegister := manager.Register(ctx, newAuth); errRegister != nil {
|
||||
t.Fatalf("register new auth: %v", errRegister)
|
||||
}
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient(newAuth.ID, "gemini", []*registry.ModelInfo{{ID: "scheduler-cooldown-rebuild-model"}})
|
||||
t.Cleanup(func() {
|
||||
reg.UnregisterClient(newAuth.ID)
|
||||
})
|
||||
|
||||
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
|
||||
var cooldownErr *modelCooldownError
|
||||
if !errors.As(errPick, &cooldownErr) {
|
||||
t.Fatalf("pickSingle() before sync error = %v, want modelCooldownError", errPick)
|
||||
}
|
||||
if got != nil {
|
||||
t.Fatalf("pickSingle() before sync auth = %v, want nil", got)
|
||||
}
|
||||
|
||||
got, executor, errPick := manager.pickNext(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickNext() error = %v", errPick)
|
||||
}
|
||||
if executor == nil {
|
||||
t.Fatal("pickNext() executor = nil")
|
||||
}
|
||||
if got == nil || got.ID != newAuth.ID {
|
||||
t.Fatalf("pickNext() auth = %v, want %q", got, newAuth.ID)
|
||||
}
|
||||
}
|
||||
@@ -250,17 +250,41 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
|
||||
return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
|
||||
}
|
||||
|
||||
predicate := triedPredicate(tried)
|
||||
candidateShards := make([]*modelScheduler, len(normalized))
|
||||
bestPriority := 0
|
||||
hasCandidate := false
|
||||
now := time.Now()
|
||||
for providerIndex, providerKey := range normalized {
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
continue
|
||||
}
|
||||
shard := providerState.ensureModelLocked(modelKey, now)
|
||||
candidateShards[providerIndex] = shard
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
priorityReady, okPriority := shard.highestReadyPriorityLocked(false, predicate)
|
||||
if !okPriority {
|
||||
continue
|
||||
}
|
||||
if !hasCandidate || priorityReady > bestPriority {
|
||||
bestPriority = priorityReady
|
||||
hasCandidate = true
|
||||
}
|
||||
}
|
||||
if !hasCandidate {
|
||||
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||
}
|
||||
|
||||
if s.strategy == schedulerStrategyFillFirst {
|
||||
for _, providerKey := range normalized {
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
continue
|
||||
}
|
||||
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
||||
for providerIndex, providerKey := range normalized {
|
||||
shard := candidateShards[providerIndex]
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
picked := shard.pickReadyLocked(false, s.strategy, triedPredicate(tried))
|
||||
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, s.strategy, predicate)
|
||||
if picked != nil {
|
||||
return picked, providerKey, nil
|
||||
}
|
||||
@@ -276,15 +300,11 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
|
||||
for offset := 0; offset < len(normalized); offset++ {
|
||||
providerIndex := (start + offset) % len(normalized)
|
||||
providerKey := normalized[providerIndex]
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
continue
|
||||
}
|
||||
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
||||
shard := candidateShards[providerIndex]
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
picked := shard.pickReadyLocked(false, schedulerStrategyRoundRobin, triedPredicate(tried))
|
||||
picked := shard.pickReadyAtPriorityLocked(false, bestPriority, schedulerStrategyRoundRobin, predicate)
|
||||
if picked == nil {
|
||||
continue
|
||||
}
|
||||
@@ -629,6 +649,19 @@ func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedule
|
||||
return nil
|
||||
}
|
||||
m.promoteExpiredLocked(time.Now())
|
||||
priorityReady, okPriority := m.highestReadyPriorityLocked(preferWebsocket, predicate)
|
||||
if !okPriority {
|
||||
return nil
|
||||
}
|
||||
return m.pickReadyAtPriorityLocked(preferWebsocket, priorityReady, strategy, predicate)
|
||||
}
|
||||
|
||||
// highestReadyPriorityLocked returns the highest priority bucket that still has a matching ready auth.
|
||||
// The caller must ensure expired entries are already promoted when needed.
|
||||
func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predicate func(*scheduledAuth) bool) (int, bool) {
|
||||
if m == nil {
|
||||
return 0, false
|
||||
}
|
||||
for _, priority := range m.priorityOrder {
|
||||
bucket := m.readyByPriority[priority]
|
||||
if bucket == nil {
|
||||
@@ -638,17 +671,37 @@ func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedule
|
||||
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
||||
view = &bucket.ws
|
||||
}
|
||||
var picked *scheduledAuth
|
||||
if strategy == schedulerStrategyFillFirst {
|
||||
picked = view.pickFirst(predicate)
|
||||
} else {
|
||||
picked = view.pickRoundRobin(predicate)
|
||||
}
|
||||
if picked != nil && picked.auth != nil {
|
||||
return picked.auth
|
||||
if view.pickFirst(predicate) != nil {
|
||||
return priority, true
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// pickReadyAtPriorityLocked selects the next ready auth from a specific priority bucket.
|
||||
// The caller must ensure expired entries are already promoted when needed.
|
||||
func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priority int, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
bucket := m.readyByPriority[priority]
|
||||
if bucket == nil {
|
||||
return nil
|
||||
}
|
||||
view := &bucket.all
|
||||
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
||||
view = &bucket.ws
|
||||
}
|
||||
var picked *scheduledAuth
|
||||
if strategy == schedulerStrategyFillFirst {
|
||||
picked = view.pickFirst(predicate)
|
||||
} else {
|
||||
picked = view.pickRoundRobin(predicate)
|
||||
}
|
||||
if picked == nil || picked.auth == nil {
|
||||
return nil
|
||||
}
|
||||
return picked.auth
|
||||
}
|
||||
|
||||
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
|
||||
|
||||
@@ -176,6 +176,25 @@ func BenchmarkManagerPickNextMixed500(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerPickNextMixedPriority500(b *testing.B) {
|
||||
manager, providers, model := benchmarkManagerSetup(b, 500, true, true)
|
||||
ctx := context.Background()
|
||||
opts := cliproxyexecutor.Options{}
|
||||
tried := map[string]struct{}{}
|
||||
if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
|
||||
b.Fatalf("warmup pickNextMixed error = %v", errWarm)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
|
||||
if errPick != nil || auth == nil || exec == nil || provider == "" {
|
||||
b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
|
||||
manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -237,6 +237,41 @@ func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := "gpt-default"
|
||||
registerSchedulerModels(t, "provider-low", model, "low")
|
||||
registerSchedulerModels(t, "provider-high-a", model, "high-a")
|
||||
registerSchedulerModels(t, "provider-high-b", model, "high-b")
|
||||
|
||||
scheduler := newSchedulerForTest(
|
||||
&RoundRobinSelector{},
|
||||
&Auth{ID: "low", Provider: "provider-low", Attributes: map[string]string{"priority": "4"}},
|
||||
&Auth{ID: "high-a", Provider: "provider-high-a", Attributes: map[string]string{"priority": "7"}},
|
||||
&Auth{ID: "high-b", Provider: "provider-high-b", Attributes: map[string]string{"priority": "7"}},
|
||||
)
|
||||
|
||||
providers := []string{"provider-low", "provider-high-a", "provider-high-b"}
|
||||
wantProviders := []string{"provider-high-a", "provider-high-b", "provider-high-a", "provider-high-b"}
|
||||
wantIDs := []string{"high-a", "high-b", "high-a", "high-b"}
|
||||
for index := range wantProviders {
|
||||
got, provider, errPick := scheduler.pickMixed(context.Background(), providers, model, cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickMixed() #%d error = %v", index, errPick)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("pickMixed() #%d auth = nil", index)
|
||||
}
|
||||
if provider != wantProviders[index] {
|
||||
t.Fatalf("pickMixed() #%d provider = %q, want %q", index, provider, wantProviders[index])
|
||||
}
|
||||
if got.ID != wantIDs[index] {
|
||||
t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", index, got.ID, wantIDs[index])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -312,6 +312,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
|
||||
// This operation may block on network calls, but the auth configuration
|
||||
// is already effective at this point.
|
||||
s.registerModelsForAuth(auth)
|
||||
|
||||
// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
|
||||
// from the now-populated global model registry. Without this, newly added auths
|
||||
// have an empty supportedModelSet (because Register/Update upserts into the
|
||||
// scheduler before registerModelsForAuth runs) and are invisible to the scheduler.
|
||||
s.coreManager.RefreshSchedulerEntry(auth.ID)
|
||||
}
|
||||
|
||||
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
||||
|
||||
Reference in New Issue
Block a user