fix(websocket): ensure state consistency on auth errors in streaming

- Added logic to reset `pinnedAuthID` and force a full transcript replay on the next request after unauthorized (401), payment-required (402), forbidden (403), or throttling (429) errors.
- Changed `forwardResponsesWebsocket` to also return the upstream error message so the caller can inspect its status code.
- Introduced `shouldReleaseResponsesWebsocketPinnedAuth` to determine auth reset conditions.
- Updated state management to preserve prior request and response data during forced replay.

Fixed: #2230
This commit is contained in:
Luis Pater
2026-05-04 05:23:23 +08:00
parent a1487b0958
commit 8e6ef3fa64
2 changed files with 229 additions and 11 deletions
@@ -79,6 +79,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
var lastRequest []byte
lastResponseOutput := []byte("[]")
pinnedAuthID := ""
forceTranscriptReplayNextRequest := false
for {
msgType, payload, errReadMessage := conn.ReadMessage()
@@ -115,6 +116,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}
if forceTranscriptReplayNextRequest {
allowIncrementalInputWithPreviousResponseID = false
}
allowCompactionReplayBypass := false
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
@@ -179,7 +183,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
updatedLastRequest = bytes.Clone(requestJSON)
previousLastRequest := bytes.Clone(lastRequest)
previousLastResponseOutput := bytes.Clone(lastResponseOutput)
forcedTranscriptReplay := forceTranscriptReplayNextRequest
lastRequest = updatedLastRequest
if forcedTranscriptReplay {
forceTranscriptReplayNextRequest = false
}
modelName := gjson.GetBytes(requestJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
@@ -204,12 +214,19 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
completedOutput, forwardErrMsg, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
if errForward != nil {
wsTerminateErr = errForward
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return
}
if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) {
pinnedAuthID = ""
forceTranscriptReplayNextRequest = true
lastRequest = previousLastRequest
lastResponseOutput = previousLastResponseOutput
continue
}
lastResponseOutput = completedOutput
}
}
@@ -810,7 +827,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errs <-chan *interfaces.ErrorMessage,
wsTimelineLog *strings.Builder,
sessionID string,
) ([]byte, error) {
) ([]byte, *interfaces.ErrorMessage, error) {
completed := false
completedOutput := []byte("[]")
downstreamSessionKey := ""
@@ -822,7 +839,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return completedOutput, c.Request.Context().Err()
return completedOutput, nil, c.Request.Context().Err()
case errMsg, ok := <-errs:
if !ok {
errs = nil
@@ -847,7 +864,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
// errWrite,
// )
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, errMsg, errWrite
}
}
if errMsg != nil {
@@ -855,7 +872,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
} else {
cancel(nil)
}
return completedOutput, nil
return completedOutput, errMsg, nil
case chunk, ok := <-data:
if !ok {
if !completed {
@@ -881,13 +898,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, errMsg, errWrite
}
cancel(errMsg.Error)
return completedOutput, nil
return completedOutput, errMsg, nil
}
cancel(nil)
return completedOutput, nil
return completedOutput, nil, nil
}
payloads := websocketJSONPayloadsFromChunk(chunk)
@@ -914,13 +931,31 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errWrite)
return completedOutput, errWrite
return completedOutput, nil, errWrite
}
}
}
}
}
// shouldReleaseResponsesWebsocketPinnedAuth reports whether errMsg carries an
// HTTP status of 401, 402, 403 or 429.
//
// The status is taken from errMsg.StatusCode. When that value is not positive
// and errMsg.Error is set, the status is instead read from the error's own
// StatusCode() method, provided the error implements one. A nil errMsg, or any
// other status, yields false.
func shouldReleaseResponsesWebsocketPinnedAuth(errMsg *interfaces.ErrorMessage) bool {
	if errMsg == nil {
		return false
	}

	code := errMsg.StatusCode
	if code <= 0 && errMsg.Error != nil {
		// Fall back to a status exposed by the error value itself.
		type statusCoder interface{ StatusCode() int }
		if coder, isCoder := errMsg.Error.(statusCoder); isCoder {
			code = coder.StatusCode()
		}
	}

	return code == http.StatusUnauthorized ||
		code == http.StatusPaymentRequired ||
		code == http.StatusForbidden ||
		code == http.StatusTooManyRequests
}
func responseCompletedOutputFromPayload(payload []byte) []byte {
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && output.IsArray() {