diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index caf26f13..7a9d2224 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -79,6 +79,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { var lastRequest []byte lastResponseOutput := []byte("[]") pinnedAuthID := "" + forceTranscriptReplayNextRequest := false for { msgType, payload, errReadMessage := conn.ReadMessage() @@ -115,6 +116,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) } + if forceTranscriptReplayNextRequest { + allowIncrementalInputWithPreviousResponseID = false + } allowCompactionReplayBypass := false if pinnedAuthID != "" && h != nil && h.AuthManager != nil { @@ -179,7 +183,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON) updatedLastRequest = bytes.Clone(requestJSON) + previousLastRequest := bytes.Clone(lastRequest) + previousLastResponseOutput := bytes.Clone(lastResponseOutput) + forcedTranscriptReplay := forceTranscriptReplayNextRequest lastRequest = updatedLastRequest + if forcedTranscriptReplay { + forceTranscriptReplayNextRequest = false + } modelName := gjson.GetBytes(requestJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) @@ -204,12 +214,19 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") - completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID) + completedOutput, forwardErrMsg, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID) if errForward != nil { wsTerminateErr = errForward log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) return } + if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) { + pinnedAuthID = "" + forceTranscriptReplayNextRequest = true + lastRequest = previousLastRequest + lastResponseOutput = previousLastResponseOutput + continue + } lastResponseOutput = completedOutput } } @@ -810,7 +827,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errs <-chan *interfaces.ErrorMessage, wsTimelineLog *strings.Builder, sessionID string, -) ([]byte, error) { +) ([]byte, *interfaces.ErrorMessage, error) { completed := false completedOutput := []byte("[]") downstreamSessionKey := "" @@ -822,7 +839,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( select { case <-c.Request.Context().Done(): cancel(c.Request.Context().Err()) - return completedOutput, c.Request.Context().Err() + return completedOutput, nil, c.Request.Context().Err() case errMsg, ok := <-errs: if !ok { errs = nil @@ -847,7 +864,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( // errWrite, // ) cancel(errMsg.Error) - return completedOutput, errWrite + return completedOutput, errMsg, errWrite } } if errMsg != nil { @@ -855,7 +872,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( } else { cancel(nil) } - return completedOutput, nil + return completedOutput, errMsg, nil case chunk, ok := <-data: if !ok { if !completed { @@ -881,13 +898,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errWrite, ) cancel(errMsg.Error) - return completedOutput, errWrite + return completedOutput, errMsg, errWrite } cancel(errMsg.Error) - return completedOutput, nil + return completedOutput, errMsg, nil } cancel(nil) - return completedOutput, nil + return completedOutput, nil, nil } payloads := websocketJSONPayloadsFromChunk(chunk) @@ -914,13 +931,31 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errWrite, ) cancel(errWrite) - return completedOutput, errWrite + return completedOutput, nil, errWrite } } } } } +func shouldReleaseResponsesWebsocketPinnedAuth(errMsg *interfaces.ErrorMessage) bool { + if errMsg == nil { + return false + } + status := errMsg.StatusCode + if status <= 0 && errMsg.Error != nil { + if se, ok := errMsg.Error.(interface{ StatusCode() int }); ok && se != nil { + status = se.StatusCode() + } + } + switch status { + case http.StatusUnauthorized, http.StatusPaymentRequired, http.StatusForbidden, http.StatusTooManyRequests: + return true + default: + return false + } +} + func responseCompletedOutputFromPayload(payload []byte) []byte { output := gjson.GetBytes(payload, "response.output") if output.Exists() && output.IsArray() { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index f2c4319e..1d397ecd 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -69,6 +69,22 @@ type websocketAuthCaptureExecutor struct { authIDs []string } +type websocketPinnedFailoverExecutor struct { + mu sync.Mutex + authIDs []string + calls map[string]int + payloads map[string][][]byte +} + +type websocketPinnedFailoverStatusError struct { + status int + msg string +} + +func (e websocketPinnedFailoverStatusError) Error() string { return e.msg } + +func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status } + func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -106,6 +122,76 @@ func (e *websocketAuthCaptureExecutor) AuthIDs() []string { return append([]string(nil), e.authIDs...) } +func (e *websocketPinnedFailoverExecutor) Identifier() string { return "test-provider" } + +func (e *websocketPinnedFailoverExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketPinnedFailoverExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) { + authID := "" + if auth != nil { + authID = auth.ID + } + + e.mu.Lock() + if e.calls == nil { + e.calls = make(map[string]int) + } + if e.payloads == nil { + e.payloads = make(map[string][][]byte) + } + e.authIDs = append(e.authIDs, authID) + e.calls[authID]++ + call := e.calls[authID] + e.payloads[authID] = append(e.payloads[authID], bytes.Clone(req.Payload)) + e.mu.Unlock() + + if authID == "auth-a" && call == 2 { + chunks := make(chan coreexecutor.StreamChunk, 1) + chunks <- coreexecutor.StreamChunk{Err: websocketPinnedFailoverStatusError{ + status: http.StatusTooManyRequests, + msg: `{"error":{"message":"quota exhausted","type":"rate_limit_error","code":"rate_limit_exceeded"}}`, + }} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil + } + + chunks := make(chan coreexecutor.StreamChunk, 1) + chunks <- coreexecutor.StreamChunk{Payload: []byte(fmt.Sprintf(`{"type":"response.completed","response":{"id":"resp-%s-%d","output":[{"type":"message","id":"out-%s-%d"}]}}`, authID, call, authID, call))} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil +} + +func (e *websocketPinnedFailoverExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *websocketPinnedFailoverExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketPinnedFailoverExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + +func (e *websocketPinnedFailoverExecutor) AuthIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + return append([]string(nil), e.authIDs...) +} + +func (e *websocketPinnedFailoverExecutor) Payloads(authID string) [][]byte { + e.mu.Lock() + defer e.mu.Unlock() + src := e.payloads[authID] + out := make([][]byte, len(src)) + for i := range src { + out[i] = bytes.Clone(src[i]) + } + return out +} + func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -681,7 +767,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { close(errCh) var timelineLog strings.Builder - completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( + completedOutput, errMsg, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( ctx, conn, func(...interface{}) {}, @@ -694,6 +780,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { serverErrCh <- err return } + if errMsg != nil { + serverErrCh <- fmt.Errorf("unexpected websocket error message: %v", errMsg.Error) + return + } if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" { serverErrCh <- errors.New("completed output not captured") return @@ -760,7 +850,7 @@ func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing return } - _, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( + _, _, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( ctx, conn, func(...interface{}) {}, @@ -1113,6 +1203,99 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) { } } +func TestResponsesWebsocketReleasesPinnedAuthAfterQuotaError(t *testing.T) { + gin.SetMode(gin.TestMode) + + selector := &orderedWebsocketSelector{order: []string{"auth-a", "auth-b"}} + executor := &websocketPinnedFailoverExecutor{} + manager := coreauth.NewManager(nil, selector, nil) + manager.RegisterExecutor(executor) + + authA := &coreauth.Auth{ + ID: "auth-a", + Provider: executor.Identifier(), + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), authA); err != nil { + t.Fatalf("Register auth A: %v", err) + } + authB := &coreauth.Auth{ + ID: "auth-b", + Provider: executor.Identifier(), + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), authB); err != nil { + t.Fatalf("Register auth B: %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(authA.ID, authA.Provider, []*registry.ModelInfo{{ID: "quota-model"}}) + registry.GetGlobalRegistry().RegisterClient(authB.ID, authB.Provider, []*registry.ModelInfo{{ID: "quota-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(authA.ID) + registry.GetGlobalRegistry().UnregisterClient(authB.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + if errClose := conn.Close(); errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + requests := []string{ + `{"type":"response.create","model":"quota-model","input":[{"type":"message","id":"msg-1"}]}`, + `{"type":"response.create","previous_response_id":"resp-auth-a-1","input":[{"type":"message","id":"msg-2"}]}`, + `{"type":"response.create","previous_response_id":"resp-auth-a-1","input":[{"type":"message","id":"msg-3"}]}`, + } + wantTypes := []string{wsEventTypeCompleted, wsEventTypeError, wsEventTypeCompleted} + for i := range requests { + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil { + t.Fatalf("write websocket message %d: %v", i+1, errWrite) + } + _, payload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + t.Fatalf("read websocket message %d: %v", i+1, errReadMessage) + } + if got := gjson.GetBytes(payload, "type").String(); got != wantTypes[i] { + t.Fatalf("message %d payload type = %s, want %s: %s", i+1, got, wantTypes[i], payload) + } + if i == 1 && int(gjson.GetBytes(payload, "status").Int()) != http.StatusTooManyRequests { + t.Fatalf("quota payload status = %d, want %d: %s", gjson.GetBytes(payload, "status").Int(), http.StatusTooManyRequests, payload) + } + } + + if got := executor.AuthIDs(); len(got) != 3 || got[0] != "auth-a" || got[1] != "auth-a" || got[2] != "auth-b" { + t.Fatalf("selected auth IDs = %v, want [auth-a auth-a auth-b]", got) + } + + authBPayloads := executor.Payloads("auth-b") + if len(authBPayloads) != 1 { + t.Fatalf("auth-b payload count = %d, want 1", len(authBPayloads)) + } + authBPayload := authBPayloads[0] + if gjson.GetBytes(authBPayload, "previous_response_id").Exists() { + t.Fatalf("previous_response_id leaked after auth failover: %s", authBPayload) + } + authBInput := gjson.GetBytes(authBPayload, "input").Raw + if !strings.Contains(authBInput, `"id":"msg-1"`) || !strings.Contains(authBInput, `"id":"msg-3"`) { + t.Fatalf("auth-b replay input missing expected transcript items: %s", authBInput) + } +} + func TestNormalizeResponsesWebsocketRequestTreatsTranscriptReplacementAsReset(t *testing.T) { lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`) lastResponseOutput := []byte(`[