fix(openai): add websocket tool call repair with caching and tests to improve transcript consistency

This commit is contained in:
Luis Pater
2026-04-01 17:16:49 +08:00
parent 105a21548f
commit d1c07a091e
3 changed files with 547 additions and 7 deletions
@@ -33,6 +33,8 @@ const (
wsDoneMarker = "[DONE]"
wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
wsBodyLogMaxSize = 32 * 1024
wsBodyLogTruncated = "\n...[truncated]\n"
)
var responsesWebsocketUpgrader = websocket.Upgrader{
@@ -52,6 +54,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
return
}
passthroughSessionID := uuid.NewString()
downstreamSessionKey := websocketDownstreamSessionKey(c.Request)
clientRemoteAddr := ""
if c != nil && c.Request != nil {
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
@@ -164,6 +167,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
continue
}
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
updatedLastRequest = bytes.Clone(requestJSON)
lastRequest = updatedLastRequest
modelName := gjson.GetBytes(requestJSON, "model").String()
@@ -324,6 +330,10 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
Error: fmt.Errorf("invalid request input: %w", errMerge),
}
}
dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput)
if errDedupeFunctionCalls == nil {
mergedInput = dedupedInput
}
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
@@ -355,7 +365,8 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
}
func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool {
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
if requestType != wsRequestTypeCreate && requestType != wsRequestTypeAppend {
return false
}
if strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" {
@@ -402,6 +413,42 @@ func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte)
return bytes.Clone(normalized)
}
func dedupeFunctionCallsByCallID(rawArray string) (string, error) {
rawArray = strings.TrimSpace(rawArray)
if rawArray == "" {
return "[]", nil
}
var items []json.RawMessage
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
return "", errUnmarshal
}
seenCallIDs := make(map[string]struct{}, len(items))
filtered := make([]json.RawMessage, 0, len(items))
for _, item := range items {
if len(item) == 0 {
continue
}
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
if itemType == "function_call" {
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID != "" {
if _, ok := seenCallIDs[callID]; ok {
continue
}
seenCallIDs[callID] = struct{}{}
}
}
filtered = append(filtered, item)
}
out, errMarshal := json.Marshal(filtered)
if errMarshal != nil {
return "", errMarshal
}
return string(out), nil
}
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
if len(attributes) > 0 {
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
@@ -667,6 +714,10 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
) ([]byte, error) {
completed := false
completedOutput := []byte("[]")
downstreamSessionKey := ""
if c != nil && c.Request != nil {
downstreamSessionKey = websocketDownstreamSessionKey(c.Request)
}
for {
select {
@@ -744,6 +795,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
payloads := websocketJSONPayloadsFromChunk(chunk)
for i := range payloads {
recordResponsesWebsocketToolCallsFromPayload(downstreamSessionKey, payloads[i])
eventType := gjson.GetBytes(payloads[i], "type").String()
if eventType == wsEventTypeCompleted {
completed = true
@@ -891,18 +943,53 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
if builder == nil {
return
}
if builder.Len() >= wsBodyLogMaxSize {
return
}
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return
}
separator := []byte{}
if builder.Len() > 0 {
builder.WriteString("\n")
separator = []byte("\n")
}
builder.WriteString("websocket.")
builder.WriteString(eventType)
builder.WriteString("\n")
builder.Write(trimmedPayload)
builder.WriteString("\n")
header := []byte("websocket." + eventType + "\n")
footer := []byte("\n")
entryLen := len(separator) + len(header) + len(trimmedPayload) + len(footer)
remaining := wsBodyLogMaxSize - builder.Len()
if entryLen <= remaining {
builder.Write(separator)
builder.Write(header)
builder.Write(trimmedPayload)
builder.Write(footer)
return
}
marker := []byte(wsBodyLogTruncated)
if len(marker) > remaining {
builder.Write(marker[:remaining])
return
}
allowed := remaining - len(marker)
parts := [][]byte{separator, header, trimmedPayload, footer}
for _, part := range parts {
if allowed <= 0 {
break
}
if len(part) <= allowed {
builder.Write(part)
allowed -= len(part)
continue
}
builder.Write(part[:allowed])
allowed = 0
break
}
builder.Write(marker)
}
func websocketPayloadEventType(payload []byte) string {