fix(openai): add websocket tool call repair with caching and tests to improve transcript consistency
This commit is contained in:
@@ -33,6 +33,8 @@ const (
|
||||
wsDoneMarker = "[DONE]"
|
||||
wsTurnStateHeader = "x-codex-turn-state"
|
||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||
wsBodyLogMaxSize = 32 * 1024
|
||||
wsBodyLogTruncated = "\n...[truncated]\n"
|
||||
)
|
||||
|
||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||
@@ -52,6 +54,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
passthroughSessionID := uuid.NewString()
|
||||
downstreamSessionKey := websocketDownstreamSessionKey(c.Request)
|
||||
clientRemoteAddr := ""
|
||||
if c != nil && c.Request != nil {
|
||||
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
||||
@@ -164,6 +167,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
|
||||
updatedLastRequest = bytes.Clone(requestJSON)
|
||||
lastRequest = updatedLastRequest
|
||||
|
||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||
@@ -324,6 +330,10 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
||||
Error: fmt.Errorf("invalid request input: %w", errMerge),
|
||||
}
|
||||
}
|
||||
dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput)
|
||||
if errDedupeFunctionCalls == nil {
|
||||
mergedInput = dedupedInput
|
||||
}
|
||||
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
@@ -355,7 +365,8 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
||||
}
|
||||
|
||||
func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool {
|
||||
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
|
||||
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
||||
if requestType != wsRequestTypeCreate && requestType != wsRequestTypeAppend {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" {
|
||||
@@ -402,6 +413,42 @@ func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte)
|
||||
return bytes.Clone(normalized)
|
||||
}
|
||||
|
||||
func dedupeFunctionCallsByCallID(rawArray string) (string, error) {
|
||||
rawArray = strings.TrimSpace(rawArray)
|
||||
if rawArray == "" {
|
||||
return "[]", nil
|
||||
}
|
||||
var items []json.RawMessage
|
||||
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
|
||||
return "", errUnmarshal
|
||||
}
|
||||
|
||||
seenCallIDs := make(map[string]struct{}, len(items))
|
||||
filtered := make([]json.RawMessage, 0, len(items))
|
||||
for _, item := range items {
|
||||
if len(item) == 0 {
|
||||
continue
|
||||
}
|
||||
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
|
||||
if itemType == "function_call" {
|
||||
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||
if callID != "" {
|
||||
if _, ok := seenCallIDs[callID]; ok {
|
||||
continue
|
||||
}
|
||||
seenCallIDs[callID] = struct{}{}
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
|
||||
out, errMarshal := json.Marshal(filtered)
|
||||
if errMarshal != nil {
|
||||
return "", errMarshal
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
|
||||
if len(attributes) > 0 {
|
||||
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
|
||||
@@ -667,6 +714,10 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
) ([]byte, error) {
|
||||
completed := false
|
||||
completedOutput := []byte("[]")
|
||||
downstreamSessionKey := ""
|
||||
if c != nil && c.Request != nil {
|
||||
downstreamSessionKey = websocketDownstreamSessionKey(c.Request)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -744,6 +795,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
|
||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||
for i := range payloads {
|
||||
recordResponsesWebsocketToolCallsFromPayload(downstreamSessionKey, payloads[i])
|
||||
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||
if eventType == wsEventTypeCompleted {
|
||||
completed = true
|
||||
@@ -891,18 +943,53 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
|
||||
if builder == nil {
|
||||
return
|
||||
}
|
||||
if builder.Len() >= wsBodyLogMaxSize {
|
||||
return
|
||||
}
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
separator := []byte{}
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n")
|
||||
separator = []byte("\n")
|
||||
}
|
||||
builder.WriteString("websocket.")
|
||||
builder.WriteString(eventType)
|
||||
builder.WriteString("\n")
|
||||
builder.Write(trimmedPayload)
|
||||
builder.WriteString("\n")
|
||||
header := []byte("websocket." + eventType + "\n")
|
||||
footer := []byte("\n")
|
||||
entryLen := len(separator) + len(header) + len(trimmedPayload) + len(footer)
|
||||
remaining := wsBodyLogMaxSize - builder.Len()
|
||||
|
||||
if entryLen <= remaining {
|
||||
builder.Write(separator)
|
||||
builder.Write(header)
|
||||
builder.Write(trimmedPayload)
|
||||
builder.Write(footer)
|
||||
return
|
||||
}
|
||||
|
||||
marker := []byte(wsBodyLogTruncated)
|
||||
if len(marker) > remaining {
|
||||
builder.Write(marker[:remaining])
|
||||
return
|
||||
}
|
||||
|
||||
allowed := remaining - len(marker)
|
||||
parts := [][]byte{separator, header, trimmedPayload, footer}
|
||||
for _, part := range parts {
|
||||
if allowed <= 0 {
|
||||
break
|
||||
}
|
||||
if len(part) <= allowed {
|
||||
builder.Write(part)
|
||||
allowed -= len(part)
|
||||
continue
|
||||
}
|
||||
builder.Write(part[:allowed])
|
||||
allowed = 0
|
||||
break
|
||||
}
|
||||
builder.Write(marker)
|
||||
}
|
||||
|
||||
func websocketPayloadEventType(payload []byte) string {
|
||||
|
||||
Reference in New Issue
Block a user