feat(executor): add upstream disconnect handling for Codex WebSocket sessions
- Introduced `UpstreamDisconnectChan` for Codex WebSocket sessions to notify downstream connections of upstream disconnections. - Implemented `notifyUpstreamDisconnect` to signal errors and close channels on disconnect events. - Added integration tests to validate WebSocket session behavior on upstream disconnect. - Updated OpenAI WebSocket response handlers to properly close connections upon upstream disconnect notifications.
This commit is contained in:
@@ -76,6 +76,9 @@ type codexWebsocketSession struct {
|
|||||||
activeCancel context.CancelFunc
|
activeCancel context.CancelFunc
|
||||||
|
|
||||||
readerConn *websocket.Conn
|
readerConn *websocket.Conn
|
||||||
|
|
||||||
|
upstreamDisconnectOnce sync.Once
|
||||||
|
upstreamDisconnectCh chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
|
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
|
||||||
@@ -151,6 +154,22 @@ func (s *codexWebsocketSession) configureConn(conn *websocket.Conn) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *codexWebsocketSession) notifyUpstreamDisconnect(err error) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.upstreamDisconnectOnce.Do(func() {
|
||||||
|
if s.upstreamDisconnectCh == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case s.upstreamDisconnectCh <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
close(s.upstreamDisconnectCh)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
ctx = context.Background()
|
ctx = context.Background()
|
||||||
@@ -1221,11 +1240,22 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb
|
|||||||
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
|
if sess, ok := store.sessions[sessionID]; ok && sess != nil {
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
sess := &codexWebsocketSession{sessionID: sessionID}
|
sess := &codexWebsocketSession{
|
||||||
|
sessionID: sessionID,
|
||||||
|
upstreamDisconnectCh: make(chan error, 1),
|
||||||
|
}
|
||||||
store.sessions[sessionID] = sess
|
store.sessions[sessionID] = sess
|
||||||
return sess
|
return sess
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *CodexWebsocketsExecutor) UpstreamDisconnectChan(sessionID string) <-chan error {
|
||||||
|
sess := e.getOrCreateSession(sessionID)
|
||||||
|
if sess == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return sess.upstreamDisconnectCh
|
||||||
|
}
|
||||||
|
|
||||||
func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
|
func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
|
||||||
if sess == nil {
|
if sess == nil {
|
||||||
return e.dialCodexWebsocket(ctx, auth, wsURL, headers)
|
return e.dialCodexWebsocket(ctx, auth, wsURL, headers)
|
||||||
@@ -1354,6 +1384,7 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
|
|||||||
sess.connMu.Unlock()
|
sess.connMu.Unlock()
|
||||||
|
|
||||||
logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, err)
|
logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, err)
|
||||||
|
sess.notifyUpstreamDisconnect(err)
|
||||||
if errClose := conn.Close(); errClose != nil {
|
if errClose := conn.Close(); errClose != nil {
|
||||||
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
|
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
|
||||||
}
|
}
|
||||||
@@ -1592,6 +1623,13 @@ func (e *CodexAutoExecutor) CloseExecutionSession(sessionID string) {
|
|||||||
e.wsExec.CloseExecutionSession(sessionID)
|
e.wsExec.CloseExecutionSession(sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *CodexAutoExecutor) UpstreamDisconnectChan(sessionID string) <-chan error {
|
||||||
|
if e == nil || e.wsExec == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return e.wsExec.UpstreamDisconnectChan(sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
func codexWebsocketsEnabled(auth *cliproxyauth.Auth) bool {
|
func codexWebsocketsEnabled(auth *cliproxyauth.Auth) bool {
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package executor
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -92,6 +93,64 @@ func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCodexWebsocketsUpstreamDisconnectChanSignalsOnInvalidate(t *testing.T) {
|
||||||
|
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("upgrade websocket: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
for {
|
||||||
|
if _, _, errRead := conn.ReadMessage(); errRead != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial websocket: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
exec := NewCodexWebsocketsExecutor(&config.Config{})
|
||||||
|
sessionID := "sess-1"
|
||||||
|
disconnectCh := exec.UpstreamDisconnectChan(sessionID)
|
||||||
|
if disconnectCh == nil {
|
||||||
|
t.Fatal("expected disconnect channel")
|
||||||
|
}
|
||||||
|
|
||||||
|
sess := exec.getOrCreateSession(sessionID)
|
||||||
|
if sess == nil {
|
||||||
|
t.Fatal("expected session")
|
||||||
|
}
|
||||||
|
sess.connMu.Lock()
|
||||||
|
sess.conn = conn
|
||||||
|
sess.authID = "auth-1"
|
||||||
|
sess.wsURL = "ws://example.test/responses"
|
||||||
|
sess.readerConn = conn
|
||||||
|
sess.connMu.Unlock()
|
||||||
|
|
||||||
|
upstreamErr := errors.New("upstream gone")
|
||||||
|
exec.invalidateUpstreamConn(sess, conn, "test_invalidate", upstreamErr)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case errRead, ok := <-disconnectCh:
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected disconnect channel to deliver error before closing")
|
||||||
|
}
|
||||||
|
if errRead == nil || errRead.Error() != upstreamErr.Error() {
|
||||||
|
t.Fatalf("disconnect error = %v, want %v", errRead, upstreamErr)
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for disconnect signal")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
|
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
|
||||||
|
|
||||||
|
|||||||
@@ -56,6 +56,31 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
retainResponsesWebsocketToolCaches(downstreamSessionKey)
|
retainResponsesWebsocketToolCaches(downstreamSessionKey)
|
||||||
clientIP := websocketClientAddress(c)
|
clientIP := websocketClientAddress(c)
|
||||||
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP)
|
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP)
|
||||||
|
|
||||||
|
wsDone := make(chan struct{})
|
||||||
|
defer close(wsDone)
|
||||||
|
|
||||||
|
if h != nil && h.AuthManager != nil {
|
||||||
|
if exec, ok := h.AuthManager.Executor("codex"); ok && exec != nil {
|
||||||
|
type upstreamDisconnectSubscriber interface {
|
||||||
|
UpstreamDisconnectChan(sessionID string) <-chan error
|
||||||
|
}
|
||||||
|
if subscriber, ok := exec.(upstreamDisconnectSubscriber); ok && subscriber != nil {
|
||||||
|
disconnectCh := subscriber.UpstreamDisconnectChan(passthroughSessionID)
|
||||||
|
if disconnectCh != nil {
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-wsDone:
|
||||||
|
return
|
||||||
|
case <-disconnectCh:
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var wsTerminateErr error
|
var wsTerminateErr error
|
||||||
var wsTimelineLog strings.Builder
|
var wsTimelineLog strings.Builder
|
||||||
defer func() {
|
defer func() {
|
||||||
|
|||||||
@@ -85,6 +85,79 @@ func (e websocketPinnedFailoverStatusError) Error() string { return e.msg }
|
|||||||
|
|
||||||
func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status }
|
func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status }
|
||||||
|
|
||||||
|
type websocketUpstreamDisconnectExecutor struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
subscribed chan string
|
||||||
|
sessions map[string]chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketUpstreamDisconnectExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
|
func (e *websocketUpstreamDisconnectExecutor) UpstreamDisconnectChan(sessionID string) <-chan error {
|
||||||
|
sessionID = strings.TrimSpace(sessionID)
|
||||||
|
if sessionID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
e.mu.Lock()
|
||||||
|
if e.sessions == nil {
|
||||||
|
e.sessions = make(map[string]chan error)
|
||||||
|
}
|
||||||
|
ch, ok := e.sessions[sessionID]
|
||||||
|
if !ok {
|
||||||
|
ch = make(chan error, 1)
|
||||||
|
e.sessions[sessionID] = ch
|
||||||
|
}
|
||||||
|
subscribed := e.subscribed
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
if subscribed != nil {
|
||||||
|
select {
|
||||||
|
case subscribed <- sessionID:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketUpstreamDisconnectExecutor) TriggerDisconnect(sessionID string, err error) {
|
||||||
|
sessionID = strings.TrimSpace(sessionID)
|
||||||
|
if sessionID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
e.mu.Lock()
|
||||||
|
ch := e.sessions[sessionID]
|
||||||
|
delete(e.sessions, sessionID)
|
||||||
|
e.mu.Unlock()
|
||||||
|
if ch == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case ch <- err:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketUpstreamDisconnectExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketUpstreamDisconnectExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketUpstreamDisconnectExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketUpstreamDisconnectExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketUpstreamDisconnectExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" }
|
func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" }
|
||||||
|
|
||||||
func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
@@ -934,6 +1007,43 @@ func TestResponsesWebsocketTimelineRecordsDisconnectEvent(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestResponsesWebsocketClosesOnCodexUpstreamDisconnect(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
executor := &websocketUpstreamDisconnectExecutor{subscribed: make(chan string, 1)}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||||
|
server := httptest.NewServer(router)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial websocket: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
var sessionID string
|
||||||
|
select {
|
||||||
|
case sessionID = <-executor.subscribed:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for upstream disconnect subscription")
|
||||||
|
}
|
||||||
|
|
||||||
|
executor.TriggerDisconnect(sessionID, errors.New("upstream disconnected"))
|
||||||
|
|
||||||
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
_, _, err = conn.ReadMessage()
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected downstream websocket to close after upstream disconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||||
manager := coreauth.NewManager(nil, nil, nil)
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
auth := &coreauth.Auth{
|
auth := &coreauth.Auth{
|
||||||
|
|||||||
Reference in New Issue
Block a user