refactor(auth): simplify home auth session management and remove ref counting
- Consolidated `homeRuntimeAuths` to store a map of session-scoped auth maps, replacing `homeRuntimeAuthSessions` and `homeRuntimeAuthRefs`. - Adjusted session cleanup logic to directly remove session-scoped auths without reference counting. - Added `GetExecutionSessionAuthByID` to retrieve auths scoped to a specific execution session. - Updated tests to reflect the new session-scoped caching behavior.
This commit is contained in:
@@ -104,6 +104,15 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
var lastRequest []byte
|
var lastRequest []byte
|
||||||
lastResponseOutput := []byte("[]")
|
lastResponseOutput := []byte("[]")
|
||||||
pinnedAuthID := ""
|
pinnedAuthID := ""
|
||||||
|
sessionAuthByID := func(authID string) (*coreauth.Auth, bool) {
|
||||||
|
if h == nil || h.AuthManager == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if auth, ok := h.AuthManager.GetExecutionSessionAuthByID(passthroughSessionID, authID); ok {
|
||||||
|
return auth, true
|
||||||
|
}
|
||||||
|
return h.AuthManager.GetByID(authID)
|
||||||
|
}
|
||||||
forceTranscriptReplayNextRequest := false
|
forceTranscriptReplayNextRequest := false
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@@ -130,8 +139,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
appendWebsocketTimelineEvent(&wsTimelineLog, "request", payload, time.Now())
|
appendWebsocketTimelineEvent(&wsTimelineLog, "request", payload, time.Now())
|
||||||
|
|
||||||
allowIncrementalInputWithPreviousResponseID := false
|
allowIncrementalInputWithPreviousResponseID := false
|
||||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
if pinnedAuthID != "" {
|
||||||
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -146,8 +155,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
allowCompactionReplayBypass := false
|
allowCompactionReplayBypass := false
|
||||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
if pinnedAuthID != "" {
|
||||||
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
if pinnedAuth, ok := sessionAuthByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||||
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
|
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -228,7 +237,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
if authID == "" || h == nil || h.AuthManager == nil {
|
if authID == "" || h == nil || h.AuthManager == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
selectedAuth, ok := h.AuthManager.GetByID(authID)
|
selectedAuth, ok := sessionAuthByID(authID)
|
||||||
if !ok || selectedAuth == nil {
|
if !ok || selectedAuth == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -153,9 +153,7 @@ type Manager struct {
|
|||||||
scheduler *authScheduler
|
scheduler *authScheduler
|
||||||
// homeRuntimeAuths caches auths returned by Home so websocket sessions can
|
// homeRuntimeAuths caches auths returned by Home so websocket sessions can
|
||||||
// reuse an established upstream credential without dispatching every turn.
|
// reuse an established upstream credential without dispatching every turn.
|
||||||
homeRuntimeAuths map[string]*Auth
|
homeRuntimeAuths map[string]map[string]*Auth
|
||||||
homeRuntimeAuthSessions map[string]map[string]struct{}
|
|
||||||
homeRuntimeAuthRefs map[string]int
|
|
||||||
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
|
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
|
||||||
providerOffsets map[string]int
|
providerOffsets map[string]int
|
||||||
|
|
||||||
@@ -195,16 +193,14 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
|||||||
hook = NoopHook{}
|
hook = NoopHook{}
|
||||||
}
|
}
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
store: store,
|
store: store,
|
||||||
executors: make(map[string]ProviderExecutor),
|
executors: make(map[string]ProviderExecutor),
|
||||||
selector: selector,
|
selector: selector,
|
||||||
hook: hook,
|
hook: hook,
|
||||||
auths: make(map[string]*Auth),
|
auths: make(map[string]*Auth),
|
||||||
homeRuntimeAuths: make(map[string]*Auth),
|
homeRuntimeAuths: make(map[string]map[string]*Auth),
|
||||||
homeRuntimeAuthSessions: make(map[string]map[string]struct{}),
|
providerOffsets: make(map[string]int),
|
||||||
homeRuntimeAuthRefs: make(map[string]int),
|
modelPoolOffsets: make(map[string]int),
|
||||||
providerOffsets: make(map[string]int),
|
|
||||||
modelPoolOffsets: make(map[string]int),
|
|
||||||
}
|
}
|
||||||
// atomic.Value requires non-nil initial value.
|
// atomic.Value requires non-nil initial value.
|
||||||
manager.runtimeConfig.Store(&internalconfig.Config{})
|
manager.runtimeConfig.Store(&internalconfig.Config{})
|
||||||
@@ -2724,10 +2720,24 @@ func (m *Manager) GetByID(id string) (*Auth, bool) {
|
|||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
auth, ok := m.auths[id]
|
auth, ok := m.auths[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
auth, ok = m.homeRuntimeAuths[id]
|
return nil, false
|
||||||
if !ok {
|
}
|
||||||
return nil, false
|
return auth.Clone(), true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetExecutionSessionAuthByID retrieves a Home runtime auth scoped to an execution session.
|
||||||
|
func (m *Manager) GetExecutionSessionAuthByID(sessionID string, authID string) (*Auth, bool) {
|
||||||
|
sessionID = strings.TrimSpace(sessionID)
|
||||||
|
authID = strings.TrimSpace(authID)
|
||||||
|
if m == nil || sessionID == "" || authID == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
sessionAuths := m.homeRuntimeAuths[sessionID]
|
||||||
|
auth := sessionAuths[authID]
|
||||||
|
if auth == nil {
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
return auth.Clone(), true
|
return auth.Clone(), true
|
||||||
}
|
}
|
||||||
@@ -3218,9 +3228,7 @@ func (m *Manager) clearHomeRuntimeAuthsLocked() {
|
|||||||
if m == nil {
|
if m == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.homeRuntimeAuths = make(map[string]*Auth)
|
m.homeRuntimeAuths = make(map[string]map[string]*Auth)
|
||||||
m.homeRuntimeAuthSessions = make(map[string]map[string]struct{})
|
|
||||||
m.homeRuntimeAuthRefs = make(map[string]int)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) clearHomeRuntimeAuthsForSessionLocked(sessionID string) {
|
func (m *Manager) clearHomeRuntimeAuthsForSessionLocked(sessionID string) {
|
||||||
@@ -3228,21 +3236,7 @@ func (m *Manager) clearHomeRuntimeAuthsForSessionLocked(sessionID string) {
|
|||||||
if m == nil || sessionID == "" {
|
if m == nil || sessionID == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authIDs := m.homeRuntimeAuthSessions[sessionID]
|
delete(m.homeRuntimeAuths, sessionID)
|
||||||
if len(authIDs) == 0 {
|
|
||||||
delete(m.homeRuntimeAuthSessions, sessionID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for authID := range authIDs {
|
|
||||||
refCount := m.homeRuntimeAuthRefs[authID]
|
|
||||||
if refCount <= 1 {
|
|
||||||
delete(m.homeRuntimeAuthRefs, authID)
|
|
||||||
delete(m.homeRuntimeAuths, authID)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
m.homeRuntimeAuthRefs[authID] = refCount - 1
|
|
||||||
}
|
|
||||||
delete(m.homeRuntimeAuthSessions, sessionID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) rememberHomeRuntimeAuth(sessionID string, auth *Auth) {
|
func (m *Manager) rememberHomeRuntimeAuth(sessionID string, auth *Auth) {
|
||||||
@@ -3256,24 +3250,14 @@ func (m *Manager) rememberHomeRuntimeAuth(sessionID string, auth *Auth) {
|
|||||||
}
|
}
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
if m.homeRuntimeAuths == nil {
|
if m.homeRuntimeAuths == nil {
|
||||||
m.homeRuntimeAuths = make(map[string]*Auth)
|
m.homeRuntimeAuths = make(map[string]map[string]*Auth)
|
||||||
}
|
}
|
||||||
if m.homeRuntimeAuthSessions == nil {
|
sessionAuths := m.homeRuntimeAuths[sessionID]
|
||||||
m.homeRuntimeAuthSessions = make(map[string]map[string]struct{})
|
|
||||||
}
|
|
||||||
if m.homeRuntimeAuthRefs == nil {
|
|
||||||
m.homeRuntimeAuthRefs = make(map[string]int)
|
|
||||||
}
|
|
||||||
m.homeRuntimeAuths[authID] = auth.Clone()
|
|
||||||
sessionAuths := m.homeRuntimeAuthSessions[sessionID]
|
|
||||||
if sessionAuths == nil {
|
if sessionAuths == nil {
|
||||||
sessionAuths = make(map[string]struct{})
|
sessionAuths = make(map[string]*Auth)
|
||||||
m.homeRuntimeAuthSessions[sessionID] = sessionAuths
|
m.homeRuntimeAuths[sessionID] = sessionAuths
|
||||||
}
|
|
||||||
if _, exists := sessionAuths[authID]; !exists {
|
|
||||||
sessionAuths[authID] = struct{}{}
|
|
||||||
m.homeRuntimeAuthRefs[authID]++
|
|
||||||
}
|
}
|
||||||
|
sessionAuths[authID] = auth.Clone()
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3284,12 +3268,8 @@ func (m *Manager) homeRuntimeAuthByID(sessionID string, authID string) (*Auth, P
|
|||||||
return nil, nil, "", false
|
return nil, nil, "", false
|
||||||
}
|
}
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
sessionAuths := m.homeRuntimeAuthSessions[sessionID]
|
sessionAuths := m.homeRuntimeAuths[sessionID]
|
||||||
if _, ok := sessionAuths[authID]; !ok {
|
auth := sessionAuths[authID]
|
||||||
m.mu.RUnlock()
|
|
||||||
return nil, nil, "", false
|
|
||||||
}
|
|
||||||
auth := m.homeRuntimeAuths[authID]
|
|
||||||
m.mu.RUnlock()
|
m.mu.RUnlock()
|
||||||
if auth == nil || !authWebsocketsEnabled(auth) {
|
if auth == nil || !authWebsocketsEnabled(auth) {
|
||||||
return nil, nil, "", false
|
return nil, nil, "", false
|
||||||
|
|||||||
@@ -27,9 +27,9 @@ func TestPickNextViaHomeReusesPinnedWebsocketAuthWithoutHomeDispatch(t *testing.
|
|||||||
}
|
}
|
||||||
auth.EnsureIndex()
|
auth.EnsureIndex()
|
||||||
manager.rememberHomeRuntimeAuth("session-1", auth)
|
manager.rememberHomeRuntimeAuth("session-1", auth)
|
||||||
cachedAuth, ok := manager.GetByID("home-auth-1")
|
cachedAuth, ok := manager.GetExecutionSessionAuthByID("session-1", "home-auth-1")
|
||||||
if !ok || cachedAuth == nil || !authWebsocketsEnabled(cachedAuth) {
|
if !ok || cachedAuth == nil || !authWebsocketsEnabled(cachedAuth) {
|
||||||
t.Fatalf("GetByID() did not expose remembered websocket home auth: auth=%#v ok=%v", cachedAuth, ok)
|
t.Fatalf("GetExecutionSessionAuthByID() did not expose remembered websocket home auth: auth=%#v ok=%v", cachedAuth, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||||
@@ -56,6 +56,61 @@ func TestPickNextViaHomeReusesPinnedWebsocketAuthWithoutHomeDispatch(t *testing.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPickNextViaHomeKeepsSameAuthIDPayloadSessionScoped(t *testing.T) {
|
||||||
|
manager := NewManager(nil, nil, nil)
|
||||||
|
manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}})
|
||||||
|
manager.RegisterExecutor(schedulerTestExecutor{})
|
||||||
|
|
||||||
|
manager.rememberHomeRuntimeAuth("session-1", &Auth{
|
||||||
|
ID: "home-auth-1",
|
||||||
|
Provider: "test",
|
||||||
|
Status: StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"websockets": "true",
|
||||||
|
homeUpstreamModelAttributeKey: "upstream-model-a",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
manager.rememberHomeRuntimeAuth("session-2", &Auth{
|
||||||
|
ID: "home-auth-1",
|
||||||
|
Provider: "test",
|
||||||
|
Status: StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"websockets": "true",
|
||||||
|
homeUpstreamModelAttributeKey: "upstream-model-b",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||||
|
optsSession1 := cliproxyexecutor.Options{
|
||||||
|
Metadata: map[string]any{
|
||||||
|
cliproxyexecutor.ExecutionSessionMetadataKey: "session-1",
|
||||||
|
cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
optsSession2 := cliproxyexecutor.Options{
|
||||||
|
Metadata: map[string]any{
|
||||||
|
cliproxyexecutor.ExecutionSessionMetadataKey: "session-2",
|
||||||
|
cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
gotSession1, _, _, errSession1 := manager.pickNextViaHome(ctx, "gpt-5.4", optsSession1, nil)
|
||||||
|
if errSession1 != nil {
|
||||||
|
t.Fatalf("pickNextViaHome(session-1) error = %v", errSession1)
|
||||||
|
}
|
||||||
|
if got := gotSession1.Attributes[homeUpstreamModelAttributeKey]; got != "upstream-model-a" {
|
||||||
|
t.Fatalf("pickNextViaHome(session-1) upstream model = %q, want upstream-model-a", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotSession2, _, _, errSession2 := manager.pickNextViaHome(ctx, "gpt-5.4", optsSession2, nil)
|
||||||
|
if errSession2 != nil {
|
||||||
|
t.Fatalf("pickNextViaHome(session-2) error = %v", errSession2)
|
||||||
|
}
|
||||||
|
if got := gotSession2.Attributes[homeUpstreamModelAttributeKey]; got != "upstream-model-b" {
|
||||||
|
t.Fatalf("pickNextViaHome(session-2) upstream model = %q, want upstream-model-b", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestPickNextViaHomeDoesNotReuseTriedPinnedWebsocketAuth(t *testing.T) {
|
func TestPickNextViaHomeDoesNotReuseTriedPinnedWebsocketAuth(t *testing.T) {
|
||||||
manager := NewManager(nil, nil, nil)
|
manager := NewManager(nil, nil, nil)
|
||||||
manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}})
|
manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}})
|
||||||
@@ -135,10 +190,12 @@ func TestPickNextViaHomeDoesNotReusePinnedNonWebsocketAuth(t *testing.T) {
|
|||||||
manager.RegisterExecutor(schedulerTestExecutor{})
|
manager.RegisterExecutor(schedulerTestExecutor{})
|
||||||
|
|
||||||
manager.mu.Lock()
|
manager.mu.Lock()
|
||||||
manager.homeRuntimeAuths["home-auth-1"] = &Auth{
|
manager.homeRuntimeAuths["session-1"] = map[string]*Auth{
|
||||||
ID: "home-auth-1",
|
"home-auth-1": &Auth{
|
||||||
Provider: "test",
|
ID: "home-auth-1",
|
||||||
Status: StatusActive,
|
Provider: "test",
|
||||||
|
Status: StatusActive,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
manager.mu.Unlock()
|
manager.mu.Unlock()
|
||||||
|
|
||||||
@@ -175,12 +232,12 @@ func TestHomeRuntimeAuthsClearWhenHomeDisabled(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if _, ok := manager.GetByID("home-auth-1"); !ok {
|
if _, ok := manager.GetExecutionSessionAuthByID("session-1", "home-auth-1"); !ok {
|
||||||
t.Fatal("expected remembered home auth before disabling home")
|
t.Fatal("expected remembered home auth before disabling home")
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.SetConfig(&internalconfig.Config{})
|
manager.SetConfig(&internalconfig.Config{})
|
||||||
if _, ok := manager.GetByID("home-auth-1"); ok {
|
if _, ok := manager.GetExecutionSessionAuthByID("session-1", "home-auth-1"); ok {
|
||||||
t.Fatal("remembered home auth was not cleared when home was disabled")
|
t.Fatal("remembered home auth was not cleared when home was disabled")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -199,12 +256,15 @@ func TestCloseExecutionSessionClearsHomeRuntimeAuthForSession(t *testing.T) {
|
|||||||
manager.rememberHomeRuntimeAuth("session-2", auth)
|
manager.rememberHomeRuntimeAuth("session-2", auth)
|
||||||
|
|
||||||
manager.CloseExecutionSession("session-1")
|
manager.CloseExecutionSession("session-1")
|
||||||
if _, ok := manager.GetByID("home-auth-1"); !ok {
|
if _, ok := manager.GetExecutionSessionAuthByID("session-1", "home-auth-1"); ok {
|
||||||
t.Fatal("shared home auth was cleared while another session still referenced it")
|
t.Fatal("home auth for closed session was not cleared")
|
||||||
|
}
|
||||||
|
if _, ok := manager.GetExecutionSessionAuthByID("session-2", "home-auth-1"); !ok {
|
||||||
|
t.Fatal("home auth for another session was cleared")
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.CloseExecutionSession("session-2")
|
manager.CloseExecutionSession("session-2")
|
||||||
if _, ok := manager.GetByID("home-auth-1"); ok {
|
if _, ok := manager.GetExecutionSessionAuthByID("session-2", "home-auth-1"); ok {
|
||||||
t.Fatal("home auth was not cleared when its last session closed")
|
t.Fatal("home auth was not cleared when its last session closed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user