feat(executor): add upstream disconnect handling for Codex WebSocket sessions
- Introduced `UpstreamDisconnectChan` for Codex WebSocket sessions to notify downstream connections of upstream disconnections. - Implemented `notifyUpstreamDisconnect` to signal errors and close channels on disconnect events. - Added integration tests to validate WebSocket session behavior on upstream disconnect. - Updated OpenAI WebSocket response handlers to properly close connections upon upstream disconnect notifications.
This commit is contained in:
@@ -56,6 +56,31 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
retainResponsesWebsocketToolCaches(downstreamSessionKey)
|
||||
clientIP := websocketClientAddress(c)
|
||||
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientIP)
|
||||
|
||||
wsDone := make(chan struct{})
|
||||
defer close(wsDone)
|
||||
|
||||
if h != nil && h.AuthManager != nil {
|
||||
if exec, ok := h.AuthManager.Executor("codex"); ok && exec != nil {
|
||||
type upstreamDisconnectSubscriber interface {
|
||||
UpstreamDisconnectChan(sessionID string) <-chan error
|
||||
}
|
||||
if subscriber, ok := exec.(upstreamDisconnectSubscriber); ok && subscriber != nil {
|
||||
disconnectCh := subscriber.UpstreamDisconnectChan(passthroughSessionID)
|
||||
if disconnectCh != nil {
|
||||
go func() {
|
||||
select {
|
||||
case <-wsDone:
|
||||
return
|
||||
case <-disconnectCh:
|
||||
_ = conn.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var wsTerminateErr error
|
||||
var wsTimelineLog strings.Builder
|
||||
defer func() {
|
||||
|
||||
@@ -85,6 +85,79 @@ func (e websocketPinnedFailoverStatusError) Error() string { return e.msg }
|
||||
|
||||
func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status }
|
||||
|
||||
type websocketUpstreamDisconnectExecutor struct {
|
||||
mu sync.Mutex
|
||||
subscribed chan string
|
||||
sessions map[string]chan error
|
||||
}
|
||||
|
||||
func (e *websocketUpstreamDisconnectExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *websocketUpstreamDisconnectExecutor) UpstreamDisconnectChan(sessionID string) <-chan error {
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if sessionID == "" {
|
||||
return nil
|
||||
}
|
||||
e.mu.Lock()
|
||||
if e.sessions == nil {
|
||||
e.sessions = make(map[string]chan error)
|
||||
}
|
||||
ch, ok := e.sessions[sessionID]
|
||||
if !ok {
|
||||
ch = make(chan error, 1)
|
||||
e.sessions[sessionID] = ch
|
||||
}
|
||||
subscribed := e.subscribed
|
||||
e.mu.Unlock()
|
||||
|
||||
if subscribed != nil {
|
||||
select {
|
||||
case subscribed <- sessionID:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return ch
|
||||
}
|
||||
|
||||
func (e *websocketUpstreamDisconnectExecutor) TriggerDisconnect(sessionID string, err error) {
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if sessionID == "" {
|
||||
return
|
||||
}
|
||||
e.mu.Lock()
|
||||
ch := e.sessions[sessionID]
|
||||
delete(e.sessions, sessionID)
|
||||
e.mu.Unlock()
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case ch <- err:
|
||||
default:
|
||||
}
|
||||
close(ch)
|
||||
}
|
||||
|
||||
func (e *websocketUpstreamDisconnectExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketUpstreamDisconnectExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketUpstreamDisconnectExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *websocketUpstreamDisconnectExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketUpstreamDisconnectExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" }
|
||||
|
||||
func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -934,6 +1007,43 @@ func TestResponsesWebsocketTimelineRecordsDisconnectEvent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketClosesOnCodexUpstreamDisconnect(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
executor := &websocketUpstreamDisconnectExecutor{subscribed: make(chan string, 1)}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
|
||||
router := gin.New()
|
||||
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("dial websocket: %v", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
var sessionID string
|
||||
select {
|
||||
case sessionID = <-executor.subscribed:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for upstream disconnect subscription")
|
||||
}
|
||||
|
||||
executor.TriggerDisconnect(sessionID, errors.New("upstream disconnected"))
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
_, _, err = conn.ReadMessage()
|
||||
if err == nil {
|
||||
t.Fatalf("expected downstream websocket to close after upstream disconnect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auth := &coreauth.Auth{
|
||||
|
||||
Reference in New Issue
Block a user