feat(executor): add upstream disconnect handling for Codex WebSocket sessions

- Introduced `UpstreamDisconnectChan` for Codex WebSocket sessions to notify downstream connections of upstream disconnections.
- Implemented `notifyUpstreamDisconnect` to signal errors and close channels on disconnect events.
- Added integration tests to validate WebSocket session behavior on upstream disconnect.
- Updated OpenAI WebSocket response handlers to properly close connections upon upstream disconnect notifications.
This commit is contained in:
Luis Pater
2026-05-06 22:09:33 +08:00
parent ed1458aa6d
commit fb08b92402
4 changed files with 233 additions and 1 deletions
@@ -76,6 +76,9 @@ type codexWebsocketSession struct {
activeCancel context.CancelFunc
readerConn *websocket.Conn
upstreamDisconnectOnce sync.Once
upstreamDisconnectCh chan error
}
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
@@ -151,6 +154,22 @@ func (s *codexWebsocketSession) configureConn(conn *websocket.Conn) {
})
}
func (s *codexWebsocketSession) notifyUpstreamDisconnect(err error) {
if s == nil {
return
}
s.upstreamDisconnectOnce.Do(func() {
if s.upstreamDisconnectCh == nil {
return
}
select {
case s.upstreamDisconnectCh <- err:
default:
}
close(s.upstreamDisconnectCh)
})
}
func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if ctx == nil {
ctx = context.Background()
@@ -1221,11 +1240,22 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
return sess
}
sess := &codexWebsocketSession{sessionID: sessionID}
sess := &codexWebsocketSession{
sessionID: sessionID,
upstreamDisconnectCh: make(chan error, 1),
}
store.sessions[sessionID] = sess
return sess
}
func (e *CodexWebsocketsExecutor) UpstreamDisconnectChan(sessionID string) <-chan error {
sess := e.getOrCreateSession(sessionID)
if sess == nil {
return nil
}
return sess.upstreamDisconnectCh
}
func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
if sess == nil {
return e.dialCodexWebsocket(ctx, auth, wsURL, headers)
@@ -1354,6 +1384,7 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
sess.connMu.Unlock()
logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, err)
sess.notifyUpstreamDisconnect(err)
if errClose := conn.Close(); errClose != nil {
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
}
@@ -1592,6 +1623,13 @@ func (e *CodexAutoExecutor) CloseExecutionSession(sessionID string) {
e.wsExec.CloseExecutionSession(sessionID)
}
func (e *CodexAutoExecutor) UpstreamDisconnectChan(sessionID string) <-chan error {
if e == nil || e.wsExec == nil {
return nil
}
return e.wsExec.UpstreamDisconnectChan(sessionID)
}
func codexWebsocketsEnabled(auth *cliproxyauth.Auth) bool {
if auth == nil {
return false
@@ -3,6 +3,7 @@ package executor
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
@@ -92,6 +93,64 @@ func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T)
}
}
func TestCodexWebsocketsUpstreamDisconnectChanSignalsOnInvalidate(t *testing.T) {
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket: %v", err)
return
}
defer func() { _ = conn.Close() }()
for {
if _, _, errRead := conn.ReadMessage(); errRead != nil {
return
}
}
}))
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() { _ = conn.Close() }()
exec := NewCodexWebsocketsExecutor(&config.Config{})
sessionID := "sess-1"
disconnectCh := exec.UpstreamDisconnectChan(sessionID)
if disconnectCh == nil {
t.Fatal("expected disconnect channel")
}
sess := exec.getOrCreateSession(sessionID)
if sess == nil {
t.Fatal("expected session")
}
sess.connMu.Lock()
sess.conn = conn
sess.authID = "auth-1"
sess.wsURL = "ws://example.test/responses"
sess.readerConn = conn
sess.connMu.Unlock()
upstreamErr := errors.New("upstream gone")
exec.invalidateUpstreamConn(sess, conn, "test_invalidate", upstreamErr)
select {
case errRead, ok := <-disconnectCh:
if !ok {
t.Fatal("expected disconnect channel to deliver error before closing")
}
if errRead == nil || errRead.Error() != upstreamErr.Error() {
t.Fatalf("disconnect error = %v, want %v", errRead, upstreamErr)
}
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for disconnect signal")
}
}
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)