feat(executor): add upstream disconnect handling for Codex WebSocket sessions

- Introduced `UpstreamDisconnectChan` for Codex WebSocket sessions to notify downstream connections of upstream disconnections.
- Implemented `notifyUpstreamDisconnect` to signal errors and close channels on disconnect events.
- Added integration tests to validate WebSocket session behavior on upstream disconnect.
- Updated OpenAI WebSocket response handlers to properly close connections upon upstream disconnect notifications.
This commit is contained in:
Luis Pater
2026-05-06 22:09:33 +08:00
parent ed1458aa6d
commit fb08b92402
4 changed files with 233 additions and 1 deletions
@@ -3,6 +3,7 @@ package executor
import (
"bytes"
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
@@ -92,6 +93,64 @@ func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T)
}
}
func TestCodexWebsocketsUpstreamDisconnectChanSignalsOnInvalidate(t *testing.T) {
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Errorf("upgrade websocket: %v", err)
return
}
defer func() { _ = conn.Close() }()
for {
if _, _, errRead := conn.ReadMessage(); errRead != nil {
return
}
}
}))
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() { _ = conn.Close() }()
exec := NewCodexWebsocketsExecutor(&config.Config{})
sessionID := "sess-1"
disconnectCh := exec.UpstreamDisconnectChan(sessionID)
if disconnectCh == nil {
t.Fatal("expected disconnect channel")
}
sess := exec.getOrCreateSession(sessionID)
if sess == nil {
t.Fatal("expected session")
}
sess.connMu.Lock()
sess.conn = conn
sess.authID = "auth-1"
sess.wsURL = "ws://example.test/responses"
sess.readerConn = conn
sess.connMu.Unlock()
upstreamErr := errors.New("upstream gone")
exec.invalidateUpstreamConn(sess, conn, "test_invalidate", upstreamErr)
select {
case errRead, ok := <-disconnectCh:
if !ok {
t.Fatal("expected disconnect channel to deliver error before closing")
}
if errRead == nil || errRead.Error() != upstreamErr.Error() {
t.Fatalf("disconnect error = %v, want %v", errRead, upstreamErr)
}
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for disconnect signal")
}
}
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)