From 8300ee8bbee62ce85389e42e76464f6dcb7d4a26 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 10 May 2026 14:00:13 +0800 Subject: [PATCH] feat(auth): enhance home auth session reuse with scoped caching and ref counting - Added `homeRuntimeAuthSessions` and `homeRuntimeAuthRefs` for scoped caching of home auths per session. - Updated `pickNextViaHome` to prevent reuse of already-tried pinned auths during session retries. - Implemented reference counting for shared auths across multiple sessions to improve memory management. - Enhanced session cleanup logic to clear cached auths only when all referencing sessions are closed. - Added unit tests to validate scoped caching, retry logic, and session cleanup behavior. --- sdk/cliproxy/auth/conductor.go | 110 ++++++++++++++---- .../auth/home_websocket_reuse_test.go | 105 ++++++++++++++++- 2 files changed, 186 insertions(+), 29 deletions(-) diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 939f1d2b..64a28d58 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -153,7 +153,9 @@ type Manager struct { scheduler *authScheduler // homeRuntimeAuths caches auths returned by Home so websocket sessions can // reuse an established upstream credential without dispatching every turn. - homeRuntimeAuths map[string]*Auth + homeRuntimeAuths map[string]*Auth + homeRuntimeAuthSessions map[string]map[string]struct{} + homeRuntimeAuthRefs map[string]int // providerOffsets tracks per-model provider rotation state for multi-provider routing. providerOffsets map[string]int @@ -193,14 +195,16 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { hook = NoopHook{} } manager := &Manager{ - store: store, - executors: make(map[string]ProviderExecutor), - selector: selector, - hook: hook, - auths: make(map[string]*Auth), - homeRuntimeAuths: make(map[string]*Auth), - providerOffsets: make(map[string]int), - modelPoolOffsets: make(map[string]int), + store: store, + executors: make(map[string]ProviderExecutor), + selector: selector, + hook: hook, + auths: make(map[string]*Auth), + homeRuntimeAuths: make(map[string]*Auth), + homeRuntimeAuthSessions: make(map[string]map[string]struct{}), + homeRuntimeAuthRefs: make(map[string]int), + providerOffsets: make(map[string]int), + modelPoolOffsets: make(map[string]int), } // atomic.Value requires non-nil initial value. manager.runtimeConfig.Store(&internalconfig.Config{}) @@ -2764,6 +2768,8 @@ func (m *Manager) CloseExecutionSession(sessionID string) { m.mu.Lock() if sessionID == CloseAllExecutionSessionsID { m.clearHomeRuntimeAuthsLocked() + } else { + m.clearHomeRuntimeAuthsForSessionLocked(sessionID) } executors := make([]ProviderExecutor, 0, len(m.executors)) for _, exec := range m.executors { @@ -2809,7 +2815,7 @@ func (m *Manager) routeAwareSelectionRequired(auth *Auth, routeModel string) boo func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) { if m.HomeEnabled() { - auth, exec, _, err := m.pickNextViaHome(ctx, model, opts) + auth, exec, _, err := m.pickNextViaHome(ctx, model, opts, tried) return auth, exec, err } @@ -2883,7 +2889,7 @@ func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, op func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) { if m.HomeEnabled() { - auth, exec, _, err := m.pickNextViaHome(ctx, model, opts) + auth, exec, _, err := m.pickNextViaHome(ctx, model, opts, tried) return auth, exec, err } @@ -2945,7 +2951,7 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) { if m.HomeEnabled() { - return m.pickNextViaHome(ctx, model, opts) + return m.pickNextViaHome(ctx, model, opts, tried) } pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) @@ -3041,7 +3047,7 @@ func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, m func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) { if m.HomeEnabled() { - return m.pickNextViaHome(ctx, model, opts) + return m.pickNextViaHome(ctx, model, opts, tried) } if !m.useSchedulerFastPath() { @@ -3213,26 +3219,76 @@ func (m *Manager) clearHomeRuntimeAuthsLocked() { return } m.homeRuntimeAuths = make(map[string]*Auth) + m.homeRuntimeAuthSessions = make(map[string]map[string]struct{}) + m.homeRuntimeAuthRefs = make(map[string]int) } -func (m *Manager) rememberHomeRuntimeAuth(auth *Auth) { - if m == nil || auth == nil || strings.TrimSpace(auth.ID) == "" || !authWebsocketsEnabled(auth) { +func (m *Manager) clearHomeRuntimeAuthsForSessionLocked(sessionID string) { + sessionID = strings.TrimSpace(sessionID) + if m == nil || sessionID == "" { + return + } + authIDs := m.homeRuntimeAuthSessions[sessionID] + if len(authIDs) == 0 { + delete(m.homeRuntimeAuthSessions, sessionID) + return + } + for authID := range authIDs { + refCount := m.homeRuntimeAuthRefs[authID] + if refCount <= 1 { + delete(m.homeRuntimeAuthRefs, authID) + delete(m.homeRuntimeAuths, authID) + continue + } + m.homeRuntimeAuthRefs[authID] = refCount - 1 + } + delete(m.homeRuntimeAuthSessions, sessionID) +} + +func (m *Manager) rememberHomeRuntimeAuth(sessionID string, auth *Auth) { + sessionID = strings.TrimSpace(sessionID) + authID := "" + if auth != nil { + authID = strings.TrimSpace(auth.ID) + } + if m == nil || auth == nil || sessionID == "" || authID == "" || !authWebsocketsEnabled(auth) { return } m.mu.Lock() if m.homeRuntimeAuths == nil { m.homeRuntimeAuths = make(map[string]*Auth) } - m.homeRuntimeAuths[auth.ID] = auth.Clone() + if m.homeRuntimeAuthSessions == nil { + m.homeRuntimeAuthSessions = make(map[string]map[string]struct{}) + } + if m.homeRuntimeAuthRefs == nil { + m.homeRuntimeAuthRefs = make(map[string]int) + } + m.homeRuntimeAuths[authID] = auth.Clone() + sessionAuths := m.homeRuntimeAuthSessions[sessionID] + if sessionAuths == nil { + sessionAuths = make(map[string]struct{}) + m.homeRuntimeAuthSessions[sessionID] = sessionAuths + } + if _, exists := sessionAuths[authID]; !exists { + sessionAuths[authID] = struct{}{} + m.homeRuntimeAuthRefs[authID]++ + } m.mu.Unlock() } -func (m *Manager) homeRuntimeAuthByID(authID string) (*Auth, ProviderExecutor, string, bool) { +func (m *Manager) homeRuntimeAuthByID(sessionID string, authID string) (*Auth, ProviderExecutor, string, bool) { + sessionID = strings.TrimSpace(sessionID) authID = strings.TrimSpace(authID) - if m == nil || authID == "" { + if m == nil || sessionID == "" || authID == "" { return nil, nil, "", false } m.mu.RLock() + sessionAuths := m.homeRuntimeAuthSessions[sessionID] + if _, ok := sessionAuths[authID]; !ok { + m.mu.RUnlock() + return nil, nil, "", false + } auth := m.homeRuntimeAuths[authID] m.mu.RUnlock() if auth == nil || !authWebsocketsEnabled(auth) { @@ -3255,17 +3311,22 @@ func (m *Manager) homeRuntimeAuthByID(authID string) (*Auth, ProviderExecutor, s return auth.Clone(), executor, providerKey, true } -func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts cliproxyexecutor.Options) (*Auth, ProviderExecutor, string, error) { +func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) { if m == nil { return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"} } if ctx == nil { ctx = context.Background() } - if cliproxyexecutor.DownstreamWebsocket(ctx) && homeExecutionSessionIDFromMetadata(opts.Metadata) != "" { + executionSessionID := homeExecutionSessionIDFromMetadata(opts.Metadata) + count := homeAuthCountFromMetadata(opts.Metadata) + if cliproxyexecutor.DownstreamWebsocket(ctx) && executionSessionID != "" && count <= 1 { if pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata); pinnedAuthID != "" { - if auth, executor, providerKey, ok := m.homeRuntimeAuthByID(pinnedAuthID); ok { - return auth, executor, providerKey, nil + _, alreadyTried := tried[pinnedAuthID] + if !alreadyTried { + if auth, executor, providerKey, ok := m.homeRuntimeAuthByID(executionSessionID, pinnedAuthID); ok { + return auth, executor, providerKey, nil + } } } } @@ -3277,7 +3338,6 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro requestedModel := requestedModelFromMetadata(opts.Metadata, model) sessionID := ExtractSessionID(opts.Headers, opts.OriginalRequest, opts.Metadata) - count := homeAuthCountFromMetadata(opts.Metadata) raw, err := client.RPopAuth(ctx, requestedModel, sessionID, opts.Headers, count) if err != nil { @@ -3350,8 +3410,8 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro } authCopy := auth.Clone() - if cliproxyexecutor.DownstreamWebsocket(ctx) && homeExecutionSessionIDFromMetadata(opts.Metadata) != "" && authWebsocketsEnabled(authCopy) { - m.rememberHomeRuntimeAuth(authCopy) + if cliproxyexecutor.DownstreamWebsocket(ctx) && executionSessionID != "" && authWebsocketsEnabled(authCopy) { + m.rememberHomeRuntimeAuth(executionSessionID, authCopy) } return authCopy, executor, providerKey, nil } diff --git a/sdk/cliproxy/auth/home_websocket_reuse_test.go b/sdk/cliproxy/auth/home_websocket_reuse_test.go index b3b329ee..284dd076 100644 --- a/sdk/cliproxy/auth/home_websocket_reuse_test.go +++ b/sdk/cliproxy/auth/home_websocket_reuse_test.go @@ -26,7 +26,7 @@ func TestPickNextViaHomeReusesPinnedWebsocketAuthWithoutHomeDispatch(t *testing. Metadata: map[string]any{"email": "home@example.com"}, } auth.EnsureIndex() - manager.rememberHomeRuntimeAuth(auth) + manager.rememberHomeRuntimeAuth("session-1", auth) cachedAuth, ok := manager.GetByID("home-auth-1") if !ok || cachedAuth == nil || !authWebsocketsEnabled(cachedAuth) { t.Fatalf("GetByID() did not expose remembered websocket home auth: auth=%#v ok=%v", cachedAuth, ok) @@ -41,7 +41,7 @@ func TestPickNextViaHomeReusesPinnedWebsocketAuthWithoutHomeDispatch(t *testing. Headers: http.Header{"Authorization": {"Bearer client-key"}}, } - got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts) + got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts, nil) if errPick != nil { t.Fatalf("pickNextViaHome() error = %v", errPick) } @@ -56,6 +56,79 @@ func TestPickNextViaHomeReusesPinnedWebsocketAuthWithoutHomeDispatch(t *testing. } } +func TestPickNextViaHomeDoesNotReuseTriedPinnedWebsocketAuth(t *testing.T) { + manager := NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) + manager.RegisterExecutor(schedulerTestExecutor{}) + + auth := &Auth{ + ID: "home-auth-1", + Provider: "test", + Status: StatusActive, + Attributes: map[string]string{ + "websockets": "true", + }, + } + manager.rememberHomeRuntimeAuth("session-1", auth) + + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + opts := cliproxyexecutor.Options{ + Metadata: map[string]any{ + cliproxyexecutor.ExecutionSessionMetadataKey: "session-1", + cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1", + }, + } + tried := map[string]struct{}{"home-auth-1": {}} + + got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts, tried) + if errPick == nil { + t.Fatal("pickNextViaHome() error is nil, want home unavailable error") + } + var authErr *Error + if !errors.As(errPick, &authErr) || authErr.Code != "home_unavailable" { + t.Fatalf("pickNextViaHome() error = %v, want home_unavailable", errPick) + } + if got != nil || executor != nil || provider != "" { + t.Fatalf("pickNextViaHome() reused tried auth: auth=%#v executor=%#v provider=%q", got, executor, provider) + } +} + +func TestPickNextViaHomeDoesNotReusePinnedWebsocketAuthAfterFirstHomeAttempt(t *testing.T) { + manager := NewManager(nil, nil, nil) + manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) + manager.RegisterExecutor(schedulerTestExecutor{}) + + auth := &Auth{ + ID: "home-auth-1", + Provider: "test", + Status: StatusActive, + Attributes: map[string]string{ + "websockets": "true", + }, + } + manager.rememberHomeRuntimeAuth("session-1", auth) + + ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background()) + opts := withHomeAuthCount(cliproxyexecutor.Options{ + Metadata: map[string]any{ + cliproxyexecutor.ExecutionSessionMetadataKey: "session-1", + cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1", + }, + }, 2) + + got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts, nil) + if errPick == nil { + t.Fatal("pickNextViaHome() error is nil, want home unavailable error") + } + var authErr *Error + if !errors.As(errPick, &authErr) || authErr.Code != "home_unavailable" { + t.Fatalf("pickNextViaHome() error = %v, want home_unavailable", errPick) + } + if got != nil || executor != nil || provider != "" { + t.Fatalf("pickNextViaHome() reused auth after first home attempt: auth=%#v executor=%#v provider=%q", got, executor, provider) + } +} + func TestPickNextViaHomeDoesNotReusePinnedNonWebsocketAuth(t *testing.T) { manager := NewManager(nil, nil, nil) manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) @@ -78,7 +151,7 @@ func TestPickNextViaHomeDoesNotReusePinnedNonWebsocketAuth(t *testing.T) { Headers: http.Header{"Authorization": {"Bearer client-key"}}, } - got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts) + got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts, nil) if errPick == nil { t.Fatal("pickNextViaHome() error is nil, want home unavailable error") } @@ -94,7 +167,7 @@ func TestPickNextViaHomeDoesNotReusePinnedNonWebsocketAuth(t *testing.T) { func TestHomeRuntimeAuthsClearWhenHomeDisabled(t *testing.T) { manager := NewManager(nil, nil, nil) manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}}) - manager.rememberHomeRuntimeAuth(&Auth{ + manager.rememberHomeRuntimeAuth("session-1", &Auth{ ID: "home-auth-1", Provider: "test", Attributes: map[string]string{ @@ -111,3 +184,27 @@ func TestHomeRuntimeAuthsClearWhenHomeDisabled(t *testing.T) { t.Fatal("remembered home auth was not cleared when home was disabled") } } + +func TestCloseExecutionSessionClearsHomeRuntimeAuthForSession(t *testing.T) { + manager := NewManager(nil, nil, nil) + auth := &Auth{ + ID: "home-auth-1", + Provider: "test", + Attributes: map[string]string{ + "websockets": "true", + }, + } + + manager.rememberHomeRuntimeAuth("session-1", auth) + manager.rememberHomeRuntimeAuth("session-2", auth) + + manager.CloseExecutionSession("session-1") + if _, ok := manager.GetByID("home-auth-1"); !ok { + t.Fatal("shared home auth was cleared while another session still referenced it") + } + + manager.CloseExecutionSession("session-2") + if _, ok := manager.GetByID("home-auth-1"); ok { + t.Fatal("home auth was not cleared when its last session closed") + } +}