fix: repair websocket custom tool calls

This commit is contained in:
Kai Wang
2026-04-03 17:11:41 +08:00
parent ab9ebea592
commit 8f0e66b72e
@@ -266,15 +266,15 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
continue continue
} }
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String()) itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
switch itemType { switch {
case "function_call_output": case isResponsesToolCallOutputType(itemType):
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" { if callID == "" {
continue continue
} }
outputPresent[callID] = struct{}{} outputPresent[callID] = struct{}{}
outputCache.record(sessionKey, callID, item) outputCache.record(sessionKey, callID, item)
case "function_call": case isResponsesToolCallType(itemType):
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" { if callID == "" {
continue continue
@@ -293,7 +293,7 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
continue continue
} }
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String()) itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
if itemType == "function_call_output" { if isResponsesToolCallOutputType(itemType) {
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String()) callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
if callID == "" { if callID == "" {
// Upstream rejects tool outputs without a call_id; drop it. // Upstream rejects tool outputs without a call_id; drop it.
@@ -325,7 +325,7 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
// Drop orphaned function_call_output items; upstream rejects transcripts with missing calls. // Drop orphaned function_call_output items; upstream rejects transcripts with missing calls.
continue continue
} }
if itemType != "function_call" { if !isResponsesToolCallType(itemType) {
filtered = append(filtered, item) filtered = append(filtered, item)
continue continue
} }
@@ -376,7 +376,7 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO
return return
} }
for _, item := range output.Array() { for _, item := range output.Array() {
if strings.TrimSpace(item.Get("type").String()) != "function_call" { if !isResponsesToolCallType(item.Get("type").String()) {
continue continue
} }
callID := strings.TrimSpace(item.Get("call_id").String()) callID := strings.TrimSpace(item.Get("call_id").String())
@@ -390,7 +390,7 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO
if !item.Exists() || !item.IsObject() { if !item.Exists() || !item.IsObject() {
return return
} }
if strings.TrimSpace(item.Get("type").String()) != "function_call" { if !isResponsesToolCallType(item.Get("type").String()) {
return return
} }
callID := strings.TrimSpace(item.Get("call_id").String()) callID := strings.TrimSpace(item.Get("call_id").String())
@@ -400,3 +400,21 @@ func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolO
cache.record(sessionKey, callID, json.RawMessage(item.Raw)) cache.record(sessionKey, callID, json.RawMessage(item.Raw))
} }
} }
func isResponsesToolCallType(itemType string) bool {
switch strings.TrimSpace(itemType) {
case "function_call", "custom_tool_call":
return true
default:
return false
}
}
func isResponsesToolCallOutputType(itemType string) bool {
switch strings.TrimSpace(itemType) {
case "function_call_output", "custom_tool_call_output":
return true
default:
return false
}
}