diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 4e72d7c8..939f1d2b 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -151,6 +151,9 @@ type Manager struct { mu sync.RWMutex auths map[string]*Auth scheduler *authScheduler + // homeRuntimeAuths caches auths returned by Home so websocket sessions can + // reuse an established upstream credential without dispatching every turn. + homeRuntimeAuths map[string]*Auth // providerOffsets tracks per-model provider rotation state for multi-provider routing. providerOffsets map[string]int @@ -195,6 +198,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { selector: selector, hook: hook, auths: make(map[string]*Auth), + homeRuntimeAuths: make(map[string]*Auth), providerOffsets: make(map[string]int), modelPoolOffsets: make(map[string]int), } @@ -376,6 +380,9 @@ func (m *Manager) SetConfig(cfg *internalconfig.Config) { cfg = &internalconfig.Config{} } m.runtimeConfig.Store(cfg) + if !cfg.Home.Enabled { + m.clearHomeRuntimeAuths() + } m.rebuildAPIKeyModelAliasFromRuntimeConfig() } @@ -2713,7 +2720,10 @@ func (m *Manager) GetByID(id string) (*Auth, bool) { defer m.mu.RUnlock() auth, ok := m.auths[id] if !ok { - return nil, false + auth, ok = m.homeRuntimeAuths[id] + if !ok { + return nil, false + } } return auth.Clone(), true } @@ -2751,12 +2761,15 @@ func (m *Manager) CloseExecutionSession(sessionID string) { return } - m.mu.RLock() + m.mu.Lock() + if sessionID == CloseAllExecutionSessionsID { + m.clearHomeRuntimeAuthsLocked() + } executors := make([]ProviderExecutor, 0, len(m.executors)) for _, exec := range m.executors { executors = append(executors, exec) } - m.mu.RUnlock() + m.mu.Unlock() for i := range executors { if closer, ok := executors[i].(ExecutionSessionCloser); ok && closer != nil { @@ -3168,6 +3181,80 @@ func setHomeUserAPIKeyOnGinContext(ctx context.Context, apiKey string) { ginCtx.Set("userApiKey", apiKey) } +func homeExecutionSessionIDFromMetadata(meta map[string]any) string { + if len(meta) == 0 { + return "" + } + raw, ok := meta[cliproxyexecutor.ExecutionSessionMetadataKey] + if !ok || raw == nil { + return "" + } + switch value := raw.(type) { + case string: + return strings.TrimSpace(value) + case []byte: + return strings.TrimSpace(string(value)) + default: + return "" + } +} + +func (m *Manager) clearHomeRuntimeAuths() { + if m == nil { + return + } + m.mu.Lock() + m.clearHomeRuntimeAuthsLocked() + m.mu.Unlock() +} + +func (m *Manager) clearHomeRuntimeAuthsLocked() { + if m == nil { + return + } + m.homeRuntimeAuths = make(map[string]*Auth) +} + +func (m *Manager) rememberHomeRuntimeAuth(auth *Auth) { + if m == nil || auth == nil || strings.TrimSpace(auth.ID) == "" || !authWebsocketsEnabled(auth) { + return + } + m.mu.Lock() + if m.homeRuntimeAuths == nil { + m.homeRuntimeAuths = make(map[string]*Auth) + } + m.homeRuntimeAuths[auth.ID] = auth.Clone() + m.mu.Unlock() +} + +func (m *Manager) homeRuntimeAuthByID(authID string) (*Auth, ProviderExecutor, string, bool) { + authID = strings.TrimSpace(authID) + if m == nil || authID == "" { + return nil, nil, "", false + } + m.mu.RLock() + auth := m.homeRuntimeAuths[authID] + m.mu.RUnlock() + if auth == nil || !authWebsocketsEnabled(auth) { + return nil, nil, "", false + } + providerKey := strings.ToLower(strings.TrimSpace(auth.Provider)) + if providerKey == "" { + return nil, nil, "", false + } + executor, ok := m.Executor(providerKey) + if !ok && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["base_url"]) != "" { + executor, ok = m.Executor("openai-compatibility") + if ok { + providerKey = "openai-compatibility" + } + } + if !ok { + return nil, nil, "", false + } + return auth.Clone(), executor, providerKey, true +} + func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts cliproxyexecutor.Options) (*Auth, ProviderExecutor, string, error) { if m == nil { return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"} @@ -3175,6 +3262,14 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro if ctx == nil { ctx = context.Background() } + if cliproxyexecutor.DownstreamWebsocket(ctx) && homeExecutionSessionIDFromMetadata(opts.Metadata) != "" { + if pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata); pinnedAuthID != "" { + if auth, executor, providerKey, ok := m.homeRuntimeAuthByID(pinnedAuthID); ok { + return auth, executor, providerKey, nil + } + } + } + client := home.Current() if client == nil || !client.HeartbeatOK() { return nil, nil, "", &Error{Code: "home_unavailable", Message: "home control center unavailable", HTTPStatus: http.StatusServiceUnavailable} @@ -3254,7 +3349,11 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered", HTTPStatus: http.StatusBadGateway} } - return auth.Clone(), executor, providerKey, nil + authCopy := auth.Clone() + if cliproxyexecutor.DownstreamWebsocket(ctx) && homeExecutionSessionIDFromMetadata(opts.Metadata) != "" && authWebsocketsEnabled(authCopy) { + m.rememberHomeRuntimeAuth(authCopy) + } + return authCopy, executor, providerKey, nil } func requestedModelFromMetadata(metadata map[string]any, fallback string) string { diff --git a/sdk/cliproxy/auth/home_websocket_reuse_test.go b/sdk/cliproxy/auth/home_websocket_reuse_test.go new file mode 100644 index 00000000..b3b329ee --- /dev/null +++ b/sdk/cliproxy/auth/home_websocket_reuse_test.go @@ -0,0 +1,113 @@ +package auth + +import ( + "context" + "errors" + "net/http" + "testing" + + internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" +) + +func TestPickNextViaHomeReusesPinnedWebsocketAuthWithoutHomeDispatch(t *testing.T) { + manager := NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) + manager.RegisterExecutor(schedulerTestExecutor{}) + + auth := &Auth{ + ID: "home-auth-1", + Provider: "test", + Status: StatusActive, + Attributes: map[string]string{ + "websockets": "true", + homeUpstreamModelAttributeKey: "upstream-model", + }, + Metadata: map[string]any{"email": "home@example.com"}, + } + auth.EnsureIndex() + manager.rememberHomeRuntimeAuth(auth) + cachedAuth, ok := manager.GetByID("home-auth-1") + if !ok || cachedAuth == nil || !authWebsocketsEnabled(cachedAuth) { + t.Fatalf("GetByID() did not expose remembered websocket home auth: auth=%#v ok=%v", cachedAuth, ok) + } + + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + opts := cliproxyexecutor.Options{ + Metadata: map[string]any{ + cliproxyexecutor.ExecutionSessionMetadataKey: "session-1", + cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1", + }, + Headers: http.Header{"Authorization": {"Bearer client-key"}}, + } + + got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts) + if errPick != nil { + t.Fatalf("pickNextViaHome() error = %v", errPick) + } + if got == nil || got.ID != "home-auth-1" { + t.Fatalf("pickNextViaHome() auth = %#v, want home-auth-1", got) + } + if executor == nil { + t.Fatal("pickNextViaHome() executor is nil") + } + if provider != "test" { + t.Fatalf("pickNextViaHome() provider = %q, want test", provider) + } +} + +func TestPickNextViaHomeDoesNotReusePinnedNonWebsocketAuth(t *testing.T) { + manager := NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) + manager.RegisterExecutor(schedulerTestExecutor{}) + + manager.mu.Lock() + manager.homeRuntimeAuths["home-auth-1"] = &Auth{ + ID: "home-auth-1", + Provider: "test", + Status: StatusActive, + } + manager.mu.Unlock() + + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + opts := cliproxyexecutor.Options{ + Metadata: map[string]any{ + cliproxyexecutor.ExecutionSessionMetadataKey: "session-1", + cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1", + }, + Headers: http.Header{"Authorization": {"Bearer client-key"}}, + } + + got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts) + if errPick == nil { + t.Fatal("pickNextViaHome() error is nil, want home unavailable error") + } + var authErr *Error + if !errors.As(errPick, &authErr) || authErr.Code != "home_unavailable" { + t.Fatalf("pickNextViaHome() error = %v, want home_unavailable", errPick) + } + if got != nil || executor != nil || provider != "" { + t.Fatalf("pickNextViaHome() reused non-websocket auth: auth=%#v executor=%#v provider=%q", got, executor, provider) + } +} + +func TestHomeRuntimeAuthsClearWhenHomeDisabled(t *testing.T) { + manager := NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) + manager.rememberHomeRuntimeAuth(&Auth{ + ID: "home-auth-1", + Provider: "test", + Attributes: map[string]string{ + "websockets": "true", + }, + }) + + if _, ok := manager.GetByID("home-auth-1"); !ok { + t.Fatal("expected remembered home auth before disabling home") + } + + manager.SetConfig(&internalconfig.Config{}) + if _, ok := manager.GetByID("home-auth-1"); ok { + t.Fatal("remembered home auth was not cleared when home was disabled") + } +}