fix(codex): centralize session management with global store and add tests for executor session lifecycle

This commit is contained in:
Luis Pater
2026-04-01 13:17:10 +08:00
parent 1734aa1664
commit 105a21548f
3 changed files with 159 additions and 18 deletions
@@ -46,10 +46,18 @@ const (
type CodexWebsocketsExecutor struct { type CodexWebsocketsExecutor struct {
*CodexExecutor *CodexExecutor
sessMu sync.Mutex store *codexWebsocketSessionStore
}
type codexWebsocketSessionStore struct {
mu sync.Mutex
sessions map[string]*codexWebsocketSession sessions map[string]*codexWebsocketSession
} }
var globalCodexWebsocketSessionStore = &codexWebsocketSessionStore{
sessions: make(map[string]*codexWebsocketSession),
}
type codexWebsocketSession struct { type codexWebsocketSession struct {
sessionID string sessionID string
@@ -73,7 +81,7 @@ type codexWebsocketSession struct {
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor { func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
return &CodexWebsocketsExecutor{ return &CodexWebsocketsExecutor{
CodexExecutor: NewCodexExecutor(cfg), CodexExecutor: NewCodexExecutor(cfg),
sessions: make(map[string]*codexWebsocketSession), store: globalCodexWebsocketSessionStore,
} }
} }
@@ -1058,16 +1066,23 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
if sessionID == "" { if sessionID == "" {
return nil return nil
} }
e.sessMu.Lock() if e == nil {
defer e.sessMu.Unlock() return nil
if e.sessions == nil {
e.sessions = make(map[string]*codexWebsocketSession)
} }
if sess, ok := e.sessions[sessionID]; ok && sess != nil { store := e.store
if store == nil {
store = globalCodexWebsocketSessionStore
}
store.mu.Lock()
defer store.mu.Unlock()
if store.sessions == nil {
store.sessions = make(map[string]*codexWebsocketSession)
}
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
return sess return sess
} }
sess := &codexWebsocketSession{sessionID: sessionID} sess := &codexWebsocketSession{sessionID: sessionID}
e.sessions[sessionID] = sess store.sessions[sessionID] = sess
return sess return sess
} }
@@ -1213,14 +1228,20 @@ func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
return return
} }
if sessionID == cliproxyauth.CloseAllExecutionSessionsID { if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
e.closeAllExecutionSessions("executor_replaced") // Executor replacement can happen during hot reload (config/credential changes).
// Do not force-close upstream websocket sessions here, otherwise in-flight
// downstream websocket requests get interrupted.
return return
} }
e.sessMu.Lock() store := e.store
sess := e.sessions[sessionID] if store == nil {
delete(e.sessions, sessionID) store = globalCodexWebsocketSessionStore
e.sessMu.Unlock() }
store.mu.Lock()
sess := store.sessions[sessionID]
delete(store.sessions, sessionID)
store.mu.Unlock()
e.closeExecutionSession(sess, "session_closed") e.closeExecutionSession(sess, "session_closed")
} }
@@ -1230,15 +1251,19 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
return return
} }
e.sessMu.Lock() store := e.store
sessions := make([]*codexWebsocketSession, 0, len(e.sessions)) if store == nil {
for sessionID, sess := range e.sessions { store = globalCodexWebsocketSessionStore
delete(e.sessions, sessionID) }
store.mu.Lock()
sessions := make([]*codexWebsocketSession, 0, len(store.sessions))
for sessionID, sess := range store.sessions {
delete(store.sessions, sessionID)
if sess != nil { if sess != nil {
sessions = append(sessions, sess) sessions = append(sessions, sess)
} }
} }
e.sessMu.Unlock() store.mu.Unlock()
for i := range sessions { for i := range sessions {
e.closeExecutionSession(sessions[i], reason) e.closeExecutionSession(sessions[i], reason)
@@ -1246,6 +1271,10 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
} }
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) { func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
closeCodexWebsocketSession(sess, reason)
}
func closeCodexWebsocketSession(sess *codexWebsocketSession, reason string) {
if sess == nil { if sess == nil {
return return
} }
@@ -1286,6 +1315,69 @@ func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason)) log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
} }
// CloseCodexWebsocketSessionsForAuthID closes all active Codex upstream websocket sessions
// associated with the supplied auth ID.
func CloseCodexWebsocketSessionsForAuthID(authID string, reason string) {
authID = strings.TrimSpace(authID)
if authID == "" {
return
}
reason = strings.TrimSpace(reason)
if reason == "" {
reason = "auth_removed"
}
store := globalCodexWebsocketSessionStore
if store == nil {
return
}
type sessionItem struct {
sessionID string
sess *codexWebsocketSession
}
store.mu.Lock()
items := make([]sessionItem, 0, len(store.sessions))
for sessionID, sess := range store.sessions {
items = append(items, sessionItem{sessionID: sessionID, sess: sess})
}
store.mu.Unlock()
matches := make([]sessionItem, 0)
for i := range items {
sess := items[i].sess
if sess == nil {
continue
}
sess.connMu.Lock()
sessAuthID := strings.TrimSpace(sess.authID)
sess.connMu.Unlock()
if sessAuthID == authID {
matches = append(matches, items[i])
}
}
if len(matches) == 0 {
return
}
toClose := make([]*codexWebsocketSession, 0, len(matches))
store.mu.Lock()
for i := range matches {
current, ok := store.sessions[matches[i].sessionID]
if !ok || current == nil || current != matches[i].sess {
continue
}
delete(store.sessions, matches[i].sessionID)
toClose = append(toClose, current)
}
store.mu.Unlock()
for i := range toClose {
closeCodexWebsocketSession(toClose[i], reason)
}
}
// CodexAutoExecutor routes Codex requests to the websocket transport only when: // CodexAutoExecutor routes Codex requests to the websocket transport only when:
// 1. The downstream transport is websocket, and // 1. The downstream transport is websocket, and
// 2. The selected auth enables websockets. // 2. The selected auth enables websockets.
@@ -0,0 +1,48 @@
package executor
import (
"testing"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestCodexWebsocketsExecutor_SessionStoreSurvivesExecutorReplacement(t *testing.T) {
sessionID := "test-session-store-survives-replace"
globalCodexWebsocketSessionStore.mu.Lock()
delete(globalCodexWebsocketSessionStore.sessions, sessionID)
globalCodexWebsocketSessionStore.mu.Unlock()
exec1 := NewCodexWebsocketsExecutor(nil)
sess1 := exec1.getOrCreateSession(sessionID)
if sess1 == nil {
t.Fatalf("expected session to be created")
}
exec2 := NewCodexWebsocketsExecutor(nil)
sess2 := exec2.getOrCreateSession(sessionID)
if sess2 == nil {
t.Fatalf("expected session to be available across executors")
}
if sess1 != sess2 {
t.Fatalf("expected the same session instance across executors")
}
exec1.CloseExecutionSession(cliproxyauth.CloseAllExecutionSessionsID)
globalCodexWebsocketSessionStore.mu.Lock()
_, stillPresent := globalCodexWebsocketSessionStore.sessions[sessionID]
globalCodexWebsocketSessionStore.mu.Unlock()
if !stillPresent {
t.Fatalf("expected session to remain after executor replacement close marker")
}
exec2.CloseExecutionSession(sessionID)
globalCodexWebsocketSessionStore.mu.Lock()
_, presentAfterClose := globalCodexWebsocketSessionStore.sessions[sessionID]
globalCodexWebsocketSessionStore.mu.Unlock()
if presentAfterClose {
t.Fatalf("expected session to be removed after explicit close")
}
}
+1
View File
@@ -335,6 +335,7 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
log.Errorf("failed to disable auth %s: %v", id, err) log.Errorf("failed to disable auth %s: %v", id, err)
} }
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") { if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
executor.CloseCodexWebsocketSessionsForAuthID(existing.ID, "auth_removed")
s.ensureExecutorsForAuth(existing) s.ensureExecutorsForAuth(existing)
} }
} }