fix(openai): add websocket tool call repair with caching and tests to improve transcript consistency

This commit is contained in:
Luis Pater
2026-04-01 17:16:49 +08:00
parent 105a21548f
commit d1c07a091e
3 changed files with 547 additions and 7 deletions
@@ -10,6 +10,7 @@ import (
"strings"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
@@ -442,6 +443,108 @@ func TestSetWebsocketRequestBody(t *testing.T) {
}
}
// TestRepairResponsesWebsocketToolCallsInsertsCachedOutput runs the repair
// helper twice against the same cache and session key. The first request
// carries a function_call_output for call-1 and is expected to keep it; the
// second carries only the matching function_call followed by a message, and
// the repaired input is expected to contain the call, then an output for
// call-1, then the message.
func TestRepairResponsesWebsocketToolCallsInsertsCachedOutput(t *testing.T) {
	const session = "session-1"
	outputs := newWebsocketToolOutputCache(time.Minute, 10)

	// Warm-up pass: presumably this is what records the output in the cache
	// (the recording itself is not visible from this test).
	warmup := repairResponsesWebsocketToolCallsWithCache(
		outputs,
		session,
		[]byte(`{"previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","output":"ok"}]}`),
	)
	if firstCallID := gjson.GetBytes(warmup, "input.0.call_id").String(); firstCallID != "call-1" {
		t.Fatalf("expected warmup output to remain")
	}

	// Second pass: the call arrives without its output.
	request := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
	result := repairResponsesWebsocketToolCallsWithCache(outputs, session, request)

	items := gjson.GetBytes(result, "input").Array()
	if count := len(items); count != 3 {
		t.Fatalf("repaired input len = %d, want 3", count)
	}

	call, output, message := items[0], items[1], items[2]
	if !(call.Get("type").String() == "function_call" && call.Get("call_id").String() == "call-1") {
		t.Fatalf("unexpected first item: %s", call.Raw)
	}
	if !(output.Get("type").String() == "function_call_output" && output.Get("call_id").String() == "call-1") {
		t.Fatalf("missing inserted output: %s", output.Raw)
	}
	if !(message.Get("type").String() == "message" && message.Get("id").String() == "msg-1") {
		t.Fatalf("unexpected trailing item: %s", message.Raw)
	}
}
// TestRepairResponsesWebsocketToolCallsDropsOrphanFunctionCall sends a
// function_call for call-1 followed by a message through the repair helper
// with a freshly created cache. The repaired input is expected to contain
// only the message.
func TestRepairResponsesWebsocketToolCallsDropsOrphanFunctionCall(t *testing.T) {
	const session = "session-1"
	outputs := newWebsocketToolOutputCache(time.Minute, 10)

	request := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
	result := repairResponsesWebsocketToolCallsWithCache(outputs, session, request)

	items := gjson.GetBytes(result, "input").Array()
	if count := len(items); count != 1 {
		t.Fatalf("repaired input len = %d, want 1", count)
	}

	// Only the message should be left.
	remaining := items[0]
	if !(remaining.Get("type").String() == "message" && remaining.Get("id").String() == "msg-1") {
		t.Fatalf("unexpected remaining item: %s", remaining.Raw)
	}
}
// TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput
// records a function_call for call-1 in the call cache, then repairs a
// request holding only the function_call_output for call-1 and a message.
// The repaired input is expected to contain the call, then the output, then
// the message.
func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *testing.T) {
	const session = "session-1"
	outputs := newWebsocketToolOutputCache(time.Minute, 10)
	calls := newWebsocketToolOutputCache(time.Minute, 10)

	// Seed the call cache directly so the output below has a counterpart.
	seededCall := []byte(`{"type":"function_call","call_id":"call-1","name":"tool"}`)
	calls.record(session, "call-1", seededCall)

	request := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
	result := repairResponsesWebsocketToolCallsWithCaches(outputs, calls, session, request)

	items := gjson.GetBytes(result, "input").Array()
	if count := len(items); count != 3 {
		t.Fatalf("repaired input len = %d, want 3", count)
	}

	call, output, message := items[0], items[1], items[2]
	if !(call.Get("type").String() == "function_call" && call.Get("call_id").String() == "call-1") {
		t.Fatalf("missing inserted call: %s", call.Raw)
	}
	if !(output.Get("type").String() == "function_call_output" && output.Get("call_id").String() == "call-1") {
		t.Fatalf("unexpected output item: %s", output.Raw)
	}
	if !(message.Get("type").String() == "message" && message.Get("id").String() == "msg-1") {
		t.Fatalf("unexpected trailing item: %s", message.Raw)
	}
}
// TestRepairResponsesWebsocketToolCallsDropsOrphanOutputWhenCallMissing
// repairs a request holding a function_call_output for call-1 and a message
// while both caches are freshly created. The repaired input is expected to
// contain only the message.
func TestRepairResponsesWebsocketToolCallsDropsOrphanOutputWhenCallMissing(t *testing.T) {
	const session = "session-1"
	outputs := newWebsocketToolOutputCache(time.Minute, 10)
	calls := newWebsocketToolOutputCache(time.Minute, 10)

	request := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
	result := repairResponsesWebsocketToolCallsWithCaches(outputs, calls, session, request)

	items := gjson.GetBytes(result, "input").Array()
	if count := len(items); count != 1 {
		t.Fatalf("repaired input len = %d, want 1", count)
	}

	// Only the message should be left.
	remaining := items[0]
	if !(remaining.Get("type").String() == "message" && remaining.Get("id").String() == "msg-1") {
		t.Fatalf("unexpected remaining item: %s", remaining.Raw)
	}
}
// TestRecordResponsesWebsocketToolCallsFromPayloadWithCache feeds a
// response.completed payload whose output holds one function_call (call-1)
// to the recording helper, then expects the cache to return an entry for
// call-1 whose type is function_call and whose call_id is call-1.
func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) {
	const session = "session-1"
	toolCalls := newWebsocketToolOutputCache(time.Minute, 10)

	completed := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool","arguments":"{}"}]}}`)
	recordResponsesWebsocketToolCallsFromPayloadWithCache(toolCalls, session, completed)

	entry, found := toolCalls.get(session, "call-1")
	if !found {
		t.Fatalf("expected cached tool call")
	}

	// The cached entry must still identify itself as the call for call-1.
	entryType := gjson.GetBytes(entry, "type").String()
	entryCallID := gjson.GetBytes(entry, "call_id").String()
	if !(entryType == "function_call" && entryCallID == "call-1") {
		t.Fatalf("unexpected cached tool call: %s", entry)
	}
}
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -767,6 +870,29 @@ func TestNormalizeResponsesWebsocketRequestDoesNotTreatDeveloperMessageAsReplace
}
}
// TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID
// normalizes a response.create request against a previous request and a
// previous response output that both contain the function_call fc-1 /
// call-1. The merged input is expected to list fc-1 once, followed by
// tool-out-1 and the new message msg-2, in that order.
func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t *testing.T) {
	previousRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"}]}`)
	// The same call appears again in the last response's output.
	previousOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}
]`)
	incoming := []byte(`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`)

	normalized, _, errMsg := normalizeResponsesWebsocketRequest(incoming, previousRequest, previousOutput)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}

	merged := gjson.GetBytes(normalized, "input").Array()
	if count := len(merged); count != 3 {
		t.Fatalf("merged input len = %d, want 3: %s", count, normalized)
	}

	// Positional check: each merged item must carry the expected id.
	expectedIDs := [3]string{"fc-1", "tool-out-1", "msg-2"}
	for position, wantID := range expectedIDs {
		if merged[position].Get("id").String() != wantID {
			t.Fatalf("unexpected merged input order: %s", normalized)
		}
	}
}
func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) {
gin.SetMode(gin.TestMode)