fix(websocket): ensure state consistency on auth errors in streaming
- Added logic to reset `pinnedAuthID` and force a transcript replay on the next request after unauthorized (401), payment-required (402), forbidden (403), or throttling (429) errors.
- Enhanced error handling in `forwardResponsesWebsocket`, which now also returns the upstream error message so the caller can inspect its status.
- Introduced `shouldReleaseResponsesWebsocketPinnedAuth` to determine auth reset conditions.
- Updated state management to restore the prior request and response data when a replay is forced.

Fixed: #2230
This commit is contained in:
@@ -79,6 +79,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
var lastRequest []byte
|
||||
lastResponseOutput := []byte("[]")
|
||||
pinnedAuthID := ""
|
||||
forceTranscriptReplayNextRequest := false
|
||||
|
||||
for {
|
||||
msgType, payload, errReadMessage := conn.ReadMessage()
|
||||
@@ -115,6 +116,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
||||
}
|
||||
if forceTranscriptReplayNextRequest {
|
||||
allowIncrementalInputWithPreviousResponseID = false
|
||||
}
|
||||
|
||||
allowCompactionReplayBypass := false
|
||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||
@@ -179,7 +183,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
|
||||
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
|
||||
updatedLastRequest = bytes.Clone(requestJSON)
|
||||
previousLastRequest := bytes.Clone(lastRequest)
|
||||
previousLastResponseOutput := bytes.Clone(lastResponseOutput)
|
||||
forcedTranscriptReplay := forceTranscriptReplayNextRequest
|
||||
lastRequest = updatedLastRequest
|
||||
if forcedTranscriptReplay {
|
||||
forceTranscriptReplayNextRequest = false
|
||||
}
|
||||
|
||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
@@ -204,12 +214,19 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
|
||||
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
|
||||
completedOutput, forwardErrMsg, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
|
||||
if errForward != nil {
|
||||
wsTerminateErr = errForward
|
||||
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
|
||||
return
|
||||
}
|
||||
if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) {
|
||||
pinnedAuthID = ""
|
||||
forceTranscriptReplayNextRequest = true
|
||||
lastRequest = previousLastRequest
|
||||
lastResponseOutput = previousLastResponseOutput
|
||||
continue
|
||||
}
|
||||
lastResponseOutput = completedOutput
|
||||
}
|
||||
}
|
||||
@@ -810,7 +827,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
errs <-chan *interfaces.ErrorMessage,
|
||||
wsTimelineLog *strings.Builder,
|
||||
sessionID string,
|
||||
) ([]byte, error) {
|
||||
) ([]byte, *interfaces.ErrorMessage, error) {
|
||||
completed := false
|
||||
completedOutput := []byte("[]")
|
||||
downstreamSessionKey := ""
|
||||
@@ -822,7 +839,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return completedOutput, c.Request.Context().Err()
|
||||
return completedOutput, nil, c.Request.Context().Err()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
errs = nil
|
||||
@@ -847,7 +864,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
// errWrite,
|
||||
// )
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, errWrite
|
||||
return completedOutput, errMsg, errWrite
|
||||
}
|
||||
}
|
||||
if errMsg != nil {
|
||||
@@ -855,7 +872,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
} else {
|
||||
cancel(nil)
|
||||
}
|
||||
return completedOutput, nil
|
||||
return completedOutput, errMsg, nil
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
if !completed {
|
||||
@@ -881,13 +898,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
errWrite,
|
||||
)
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, errWrite
|
||||
return completedOutput, errMsg, errWrite
|
||||
}
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, nil
|
||||
return completedOutput, errMsg, nil
|
||||
}
|
||||
cancel(nil)
|
||||
return completedOutput, nil
|
||||
return completedOutput, nil, nil
|
||||
}
|
||||
|
||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||
@@ -914,13 +931,31 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
errWrite,
|
||||
)
|
||||
cancel(errWrite)
|
||||
return completedOutput, errWrite
|
||||
return completedOutput, nil, errWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shouldReleaseResponsesWebsocketPinnedAuth(errMsg *interfaces.ErrorMessage) bool {
|
||||
if errMsg == nil {
|
||||
return false
|
||||
}
|
||||
status := errMsg.StatusCode
|
||||
if status <= 0 && errMsg.Error != nil {
|
||||
if se, ok := errMsg.Error.(interface{ StatusCode() int }); ok && se != nil {
|
||||
status = se.StatusCode()
|
||||
}
|
||||
}
|
||||
switch status {
|
||||
case http.StatusUnauthorized, http.StatusPaymentRequired, http.StatusForbidden, http.StatusTooManyRequests:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func responseCompletedOutputFromPayload(payload []byte) []byte {
|
||||
output := gjson.GetBytes(payload, "response.output")
|
||||
if output.Exists() && output.IsArray() {
|
||||
|
||||
Reference in New Issue
Block a user