fix(codex): centralize session management with global store and add tests for executor session lifecycle
This commit is contained in:
@@ -46,10 +46,18 @@ const (
|
|||||||
type CodexWebsocketsExecutor struct {
|
type CodexWebsocketsExecutor struct {
|
||||||
*CodexExecutor
|
*CodexExecutor
|
||||||
|
|
||||||
sessMu sync.Mutex
|
store *codexWebsocketSessionStore
|
||||||
|
}
|
||||||
|
|
||||||
|
type codexWebsocketSessionStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
sessions map[string]*codexWebsocketSession
|
sessions map[string]*codexWebsocketSession
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var globalCodexWebsocketSessionStore = &codexWebsocketSessionStore{
|
||||||
|
sessions: make(map[string]*codexWebsocketSession),
|
||||||
|
}
|
||||||
|
|
||||||
type codexWebsocketSession struct {
|
type codexWebsocketSession struct {
|
||||||
sessionID string
|
sessionID string
|
||||||
|
|
||||||
@@ -73,7 +81,7 @@ type codexWebsocketSession struct {
|
|||||||
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
|
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
|
||||||
return &CodexWebsocketsExecutor{
|
return &CodexWebsocketsExecutor{
|
||||||
CodexExecutor: NewCodexExecutor(cfg),
|
CodexExecutor: NewCodexExecutor(cfg),
|
||||||
sessions: make(map[string]*codexWebsocketSession),
|
store: globalCodexWebsocketSessionStore,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1058,16 +1066,23 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
|
|||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
e.sessMu.Lock()
|
if e == nil {
|
||||||
defer e.sessMu.Unlock()
|
return nil
|
||||||
if e.sessions == nil {
|
|
||||||
e.sessions = make(map[string]*codexWebsocketSession)
|
|
||||||
}
|
}
|
||||||
if sess, ok := e.sessions[sessionID]; ok && sess != nil {
|
store := e.store
|
||||||
|
if store == nil {
|
||||||
|
store = globalCodexWebsocketSessionStore
|
||||||
|
}
|
||||||
|
store.mu.Lock()
|
||||||
|
defer store.mu.Unlock()
|
||||||
|
if store.sessions == nil {
|
||||||
|
store.sessions = make(map[string]*codexWebsocketSession)
|
||||||
|
}
|
||||||
|
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
sess := &codexWebsocketSession{sessionID: sessionID}
|
sess := &codexWebsocketSession{sessionID: sessionID}
|
||||||
e.sessions[sessionID] = sess
|
store.sessions[sessionID] = sess
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1213,14 +1228,20 @@ func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
|
if sessionID == cliproxyauth.CloseAllExecutionSessionsID {
|
||||||
e.closeAllExecutionSessions("executor_replaced")
|
// Executor replacement can happen during hot reload (config/credential changes).
|
||||||
|
// Do not force-close upstream websocket sessions here, otherwise in-flight
|
||||||
|
// downstream websocket requests get interrupted.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
e.sessMu.Lock()
|
store := e.store
|
||||||
sess := e.sessions[sessionID]
|
if store == nil {
|
||||||
delete(e.sessions, sessionID)
|
store = globalCodexWebsocketSessionStore
|
||||||
e.sessMu.Unlock()
|
}
|
||||||
|
store.mu.Lock()
|
||||||
|
sess := store.sessions[sessionID]
|
||||||
|
delete(store.sessions, sessionID)
|
||||||
|
store.mu.Unlock()
|
||||||
|
|
||||||
e.closeExecutionSession(sess, "session_closed")
|
e.closeExecutionSession(sess, "session_closed")
|
||||||
}
|
}
|
||||||
@@ -1230,15 +1251,19 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
e.sessMu.Lock()
|
store := e.store
|
||||||
sessions := make([]*codexWebsocketSession, 0, len(e.sessions))
|
if store == nil {
|
||||||
for sessionID, sess := range e.sessions {
|
store = globalCodexWebsocketSessionStore
|
||||||
delete(e.sessions, sessionID)
|
}
|
||||||
|
store.mu.Lock()
|
||||||
|
sessions := make([]*codexWebsocketSession, 0, len(store.sessions))
|
||||||
|
for sessionID, sess := range store.sessions {
|
||||||
|
delete(store.sessions, sessionID)
|
||||||
if sess != nil {
|
if sess != nil {
|
||||||
sessions = append(sessions, sess)
|
sessions = append(sessions, sess)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
e.sessMu.Unlock()
|
store.mu.Unlock()
|
||||||
|
|
||||||
for i := range sessions {
|
for i := range sessions {
|
||||||
e.closeExecutionSession(sessions[i], reason)
|
e.closeExecutionSession(sessions[i], reason)
|
||||||
@@ -1246,6 +1271,10 @@ func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
|
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
|
||||||
|
closeCodexWebsocketSession(sess, reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeCodexWebsocketSession(sess *codexWebsocketSession, reason string) {
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1286,6 +1315,69 @@ func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string
|
|||||||
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
|
log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CloseCodexWebsocketSessionsForAuthID closes all active Codex upstream websocket sessions
|
||||||
|
// associated with the supplied auth ID.
|
||||||
|
func CloseCodexWebsocketSessionsForAuthID(authID string, reason string) {
|
||||||
|
authID = strings.TrimSpace(authID)
|
||||||
|
if authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reason = strings.TrimSpace(reason)
|
||||||
|
if reason == "" {
|
||||||
|
reason = "auth_removed"
|
||||||
|
}
|
||||||
|
|
||||||
|
store := globalCodexWebsocketSessionStore
|
||||||
|
if store == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type sessionItem struct {
|
||||||
|
sessionID string
|
||||||
|
sess *codexWebsocketSession
|
||||||
|
}
|
||||||
|
|
||||||
|
store.mu.Lock()
|
||||||
|
items := make([]sessionItem, 0, len(store.sessions))
|
||||||
|
for sessionID, sess := range store.sessions {
|
||||||
|
items = append(items, sessionItem{sessionID: sessionID, sess: sess})
|
||||||
|
}
|
||||||
|
store.mu.Unlock()
|
||||||
|
|
||||||
|
matches := make([]sessionItem, 0)
|
||||||
|
for i := range items {
|
||||||
|
sess := items[i].sess
|
||||||
|
if sess == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
sess.connMu.Lock()
|
||||||
|
sessAuthID := strings.TrimSpace(sess.authID)
|
||||||
|
sess.connMu.Unlock()
|
||||||
|
if sessAuthID == authID {
|
||||||
|
matches = append(matches, items[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
toClose := make([]*codexWebsocketSession, 0, len(matches))
|
||||||
|
store.mu.Lock()
|
||||||
|
for i := range matches {
|
||||||
|
current, ok := store.sessions[matches[i].sessionID]
|
||||||
|
if !ok || current == nil || current != matches[i].sess {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(store.sessions, matches[i].sessionID)
|
||||||
|
toClose = append(toClose, current)
|
||||||
|
}
|
||||||
|
store.mu.Unlock()
|
||||||
|
|
||||||
|
for i := range toClose {
|
||||||
|
closeCodexWebsocketSession(toClose[i], reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CodexAutoExecutor routes Codex requests to the websocket transport only when:
|
// CodexAutoExecutor routes Codex requests to the websocket transport only when:
|
||||||
// 1. The downstream transport is websocket, and
|
// 1. The downstream transport is websocket, and
|
||||||
// 2. The selected auth enables websockets.
|
// 2. The selected auth enables websockets.
|
||||||
|
|||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCodexWebsocketsExecutor_SessionStoreSurvivesExecutorReplacement(t *testing.T) {
|
||||||
|
sessionID := "test-session-store-survives-replace"
|
||||||
|
|
||||||
|
globalCodexWebsocketSessionStore.mu.Lock()
|
||||||
|
delete(globalCodexWebsocketSessionStore.sessions, sessionID)
|
||||||
|
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||||
|
|
||||||
|
exec1 := NewCodexWebsocketsExecutor(nil)
|
||||||
|
sess1 := exec1.getOrCreateSession(sessionID)
|
||||||
|
if sess1 == nil {
|
||||||
|
t.Fatalf("expected session to be created")
|
||||||
|
}
|
||||||
|
|
||||||
|
exec2 := NewCodexWebsocketsExecutor(nil)
|
||||||
|
sess2 := exec2.getOrCreateSession(sessionID)
|
||||||
|
if sess2 == nil {
|
||||||
|
t.Fatalf("expected session to be available across executors")
|
||||||
|
}
|
||||||
|
if sess1 != sess2 {
|
||||||
|
t.Fatalf("expected the same session instance across executors")
|
||||||
|
}
|
||||||
|
|
||||||
|
exec1.CloseExecutionSession(cliproxyauth.CloseAllExecutionSessionsID)
|
||||||
|
|
||||||
|
globalCodexWebsocketSessionStore.mu.Lock()
|
||||||
|
_, stillPresent := globalCodexWebsocketSessionStore.sessions[sessionID]
|
||||||
|
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||||
|
if !stillPresent {
|
||||||
|
t.Fatalf("expected session to remain after executor replacement close marker")
|
||||||
|
}
|
||||||
|
|
||||||
|
exec2.CloseExecutionSession(sessionID)
|
||||||
|
|
||||||
|
globalCodexWebsocketSessionStore.mu.Lock()
|
||||||
|
_, presentAfterClose := globalCodexWebsocketSessionStore.sessions[sessionID]
|
||||||
|
globalCodexWebsocketSessionStore.mu.Unlock()
|
||||||
|
if presentAfterClose {
|
||||||
|
t.Fatalf("expected session to be removed after explicit close")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -335,6 +335,7 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
|||||||
log.Errorf("failed to disable auth %s: %v", id, err)
|
log.Errorf("failed to disable auth %s: %v", id, err)
|
||||||
}
|
}
|
||||||
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
|
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
|
||||||
|
executor.CloseCodexWebsocketSessionsForAuthID(existing.ID, "auth_removed")
|
||||||
s.ensureExecutorsForAuth(existing)
|
s.ensureExecutorsForAuth(existing)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user