From fb08b92402187cd87f888d371ee5d6f5107f6d36 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 6 May 2026 22:09:33 +0800 Subject: [PATCH] feat(executor): add upstream disconnect handling for Codex WebSocket sessions - Introduced `UpstreamDisconnectChan` for Codex WebSocket sessions to notify downstream connections of upstream disconnections. - Implemented `notifyUpstreamDisconnect` to signal errors and close channels on disconnect events. - Added integration tests to validate WebSocket session behavior on upstream disconnect. - Updated OpenAI WebSocket response handlers to properly close connections upon upstream disconnect notifications. --- .../executor/codex_websockets_executor.go | 40 ++++++- .../codex_websockets_executor_test.go | 59 ++++++++++ .../openai/openai_responses_websocket.go | 25 ++++ .../openai/openai_responses_websocket_test.go | 110 ++++++++++++++++++ 4 files changed, 233 insertions(+), 1 deletion(-) diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index d6f1de86..94b78b66 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -76,6 +76,9 @@ type codexWebsocketSession struct { activeCancel context.CancelFunc readerConn *websocket.Conn + + upstreamDisconnectOnce sync.Once + upstreamDisconnectCh chan error } func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor { @@ -151,6 +154,22 @@ func (s *codexWebsocketSession) configureConn(conn *websocket.Conn) { }) } +func (s *codexWebsocketSession) notifyUpstreamDisconnect(err error) { + if s == nil { + return + } + s.upstreamDisconnectOnce.Do(func() { + if s.upstreamDisconnectCh == nil { + return + } + select { + case s.upstreamDisconnectCh <- err: + default: + } + close(s.upstreamDisconnectCh) + }) +} + func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { if ctx == nil { ctx = context.Background() @@ -1221,11 +1240,22 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb if sess, ok := store.sessions[sessionID]; ok && sess != nil { return sess } - sess := &codexWebsocketSession{sessionID: sessionID} + sess := &codexWebsocketSession{ + sessionID: sessionID, + upstreamDisconnectCh: make(chan error, 1), + } store.sessions[sessionID] = sess return sess } +func (e *CodexWebsocketsExecutor) UpstreamDisconnectChan(sessionID string) <-chan error { + sess := e.getOrCreateSession(sessionID) + if sess == nil { + return nil + } + return sess.upstreamDisconnectCh +} + func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { if sess == nil { return e.dialCodexWebsocket(ctx, auth, wsURL, headers) @@ -1354,6 +1384,7 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes sess.connMu.Unlock() logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, err) + sess.notifyUpstreamDisconnect(err) if errClose := conn.Close(); errClose != nil { log.Errorf("codex websockets executor: close websocket error: %v", errClose) } @@ -1592,6 +1623,13 @@ func (e *CodexAutoExecutor) CloseExecutionSession(sessionID string) { e.wsExec.CloseExecutionSession(sessionID) } +func (e *CodexAutoExecutor) UpstreamDisconnectChan(sessionID string) <-chan error { + if e == nil || e.wsExec == nil { + return nil + } + return e.wsExec.UpstreamDisconnectChan(sessionID) +} + func codexWebsocketsEnabled(auth *cliproxyauth.Auth) bool { if auth == nil { return false diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index 9c7bb591..fbcf9c45 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -3,6 +3,7 @@ package executor import ( "bytes" "context" + "errors" "net/http" "net/http/httptest" "strings" @@ -92,6 +93,64 @@ func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T) } } +func TestCodexWebsocketsUpstreamDisconnectChanSignalsOnInvalidate(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket: %v", err) + return + } + defer func() { _ = conn.Close() }() + for { + if _, _, errRead := conn.ReadMessage(); errRead != nil { + return + } + } + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { _ = conn.Close() }() + + exec := NewCodexWebsocketsExecutor(&config.Config{}) + sessionID := "sess-1" + disconnectCh := exec.UpstreamDisconnectChan(sessionID) + if disconnectCh == nil { + t.Fatal("expected disconnect channel") + } + + sess := exec.getOrCreateSession(sessionID) + if sess == nil { + t.Fatal("expected session") + } + sess.connMu.Lock() + sess.conn = conn + sess.authID = "auth-1" + sess.wsURL = "ws://example.test/responses" + sess.readerConn = conn + sess.connMu.Unlock() + + upstreamErr := errors.New("upstream gone") + exec.invalidateUpstreamConn(sess, conn, "test_invalidate", upstreamErr) + + select { + case errRead, ok := <-disconnectCh: + if !ok { + t.Fatal("expected disconnect channel to deliver error before closing") + } + if errRead == nil || errRead.Error() != upstreamErr.Error() { + t.Fatalf("disconnect error = %v, want %v", errRead, upstreamErr) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for disconnect signal") + } +} + func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) { headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil) diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 7a9d2224..c617c946 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -56,6 +56,31 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { retainResponsesWebsocketToolCaches(downstreamSessionKey) clientIP := websocketClientAddress(c) log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP) + + wsDone := make(chan struct{}) + defer close(wsDone) + + if h != nil && h.AuthManager != nil { + if exec, ok := h.AuthManager.Executor("codex"); ok && exec != nil { + type upstreamDisconnectSubscriber interface { + UpstreamDisconnectChan(sessionID string) <-chan error + } + if subscriber, ok := exec.(upstreamDisconnectSubscriber); ok && subscriber != nil { + disconnectCh := subscriber.UpstreamDisconnectChan(passthroughSessionID) + if disconnectCh != nil { + go func() { + select { + case <-wsDone: + return + case <-disconnectCh: + _ = conn.Close() + } + }() + } + } + } + } + var wsTerminateErr error var wsTimelineLog strings.Builder defer func() { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index 1d397ecd..319127f0 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -85,6 +85,79 @@ func (e websocketPinnedFailoverStatusError) Error() string { return e.msg } func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status } +type websocketUpstreamDisconnectExecutor struct { + mu sync.Mutex + subscribed chan string + sessions map[string]chan error +} + +func (e *websocketUpstreamDisconnectExecutor) Identifier() string { return "codex" } + +func (e *websocketUpstreamDisconnectExecutor) UpstreamDisconnectChan(sessionID string) <-chan error { + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return nil + } + e.mu.Lock() + if e.sessions == nil { + e.sessions = make(map[string]chan error) + } + ch, ok := e.sessions[sessionID] + if !ok { + ch = make(chan error, 1) + e.sessions[sessionID] = ch + } + subscribed := e.subscribed + e.mu.Unlock() + + if subscribed != nil { + select { + case subscribed <- sessionID: + default: + } + } + return ch +} + +func (e *websocketUpstreamDisconnectExecutor) TriggerDisconnect(sessionID string, err error) { + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return + } + e.mu.Lock() + ch := e.sessions[sessionID] + delete(e.sessions, sessionID) + e.mu.Unlock() + if ch == nil { + return + } + select { + case ch <- err: + default: + } + close(ch) +} + +func (e *websocketUpstreamDisconnectExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketUpstreamDisconnectExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { + return nil, errors.New("not implemented") +} + +func (e *websocketUpstreamDisconnectExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *websocketUpstreamDisconnectExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketUpstreamDisconnectExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -934,6 +1007,43 @@ func TestResponsesWebsocketTimelineRecordsDisconnectEvent(t *testing.T) { } } +func TestResponsesWebsocketClosesOnCodexUpstreamDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + + executor := &websocketUpstreamDisconnectExecutor{subscribed: make(chan string, 1)} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { _ = conn.Close() }() + + var sessionID string + select { + case sessionID = <-executor.subscribed: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for upstream disconnect subscription") + } + + executor.TriggerDisconnect(sessionID, errors.New("upstream disconnected")) + + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, _, err = conn.ReadMessage() + if err == nil { + t.Fatalf("expected downstream websocket to close after upstream disconnect") + } +} + func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) { manager := coreauth.NewManager(nil, nil, nil) auth := &coreauth.Auth{