feat(executor): add upstream disconnect handling for Codex WebSocket sessions
- Introduced `UpstreamDisconnectChan` for Codex WebSocket sessions to notify downstream connections of upstream disconnections. - Implemented `notifyUpstreamDisconnect` to signal errors and close channels on disconnect events. - Added integration tests to validate WebSocket session behavior on upstream disconnect. - Updated OpenAI WebSocket response handlers to properly close connections upon upstream disconnect notifications.
This commit is contained in:
@@ -3,6 +3,7 @@ package executor
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
@@ -92,6 +93,64 @@ func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexWebsocketsUpstreamDisconnectChanSignalsOnInvalidate(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Errorf("upgrade websocket: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
for {
|
||||
if _, _, errRead := conn.ReadMessage(); errRead != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
exec := NewCodexWebsocketsExecutor(&config.Config{})
|
||||
sessionID := "sess-1"
|
||||
disconnectCh := exec.UpstreamDisconnectChan(sessionID)
|
||||
if disconnectCh == nil {
|
||||
t.Fatal("expected disconnect channel")
|
||||
}
|
||||
|
||||
sess := exec.getOrCreateSession(sessionID)
|
||||
if sess == nil {
|
||||
t.Fatal("expected session")
|
||||
}
|
||||
sess.connMu.Lock()
|
||||
sess.conn = conn
|
||||
sess.authID = "auth-1"
|
||||
sess.wsURL = "ws://example.test/responses"
|
||||
sess.readerConn = conn
|
||||
sess.connMu.Unlock()
|
||||
|
||||
upstreamErr := errors.New("upstream gone")
|
||||
exec.invalidateUpstreamConn(sess, conn, "test_invalidate", upstreamErr)
|
||||
|
||||
select {
|
||||
case errRead, ok := <-disconnectCh:
|
||||
if !ok {
|
||||
t.Fatal("expected disconnect channel to deliver error before closing")
|
||||
}
|
||||
if errRead == nil || errRead.Error() != upstreamErr.Error() {
|
||||
t.Fatalf("disconnect error = %v, want %v", errRead, upstreamErr)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for disconnect signal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user