fix(openai): add websocket tool call repair with caching and tests to improve transcript consistency
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
@@ -442,6 +443,108 @@ func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairResponsesWebsocketToolCallsInsertsCachedOutput(t *testing.T) {
|
||||
cache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
sessionKey := "session-1"
|
||||
|
||||
cacheWarm := []byte(`{"previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","output":"ok"}]}`)
|
||||
warmed := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, cacheWarm)
|
||||
if gjson.GetBytes(warmed, "input.0.call_id").String() != "call-1" {
|
||||
t.Fatalf("expected warmup output to remain")
|
||||
}
|
||||
|
||||
raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
|
||||
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
|
||||
|
||||
input := gjson.GetBytes(repaired, "input").Array()
|
||||
if len(input) != 3 {
|
||||
t.Fatalf("repaired input len = %d, want 3", len(input))
|
||||
}
|
||||
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
|
||||
t.Fatalf("unexpected first item: %s", input[0].Raw)
|
||||
}
|
||||
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
|
||||
t.Fatalf("missing inserted output: %s", input[1].Raw)
|
||||
}
|
||||
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairResponsesWebsocketToolCallsDropsOrphanFunctionCall(t *testing.T) {
|
||||
cache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
sessionKey := "session-1"
|
||||
|
||||
raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
|
||||
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
|
||||
|
||||
input := gjson.GetBytes(repaired, "input").Array()
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("repaired input len = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *testing.T) {
|
||||
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
callCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
sessionKey := "session-1"
|
||||
|
||||
callCache.record(sessionKey, "call-1", []byte(`{"type":"function_call","call_id":"call-1","name":"tool"}`))
|
||||
|
||||
raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
|
||||
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
|
||||
|
||||
input := gjson.GetBytes(repaired, "input").Array()
|
||||
if len(input) != 3 {
|
||||
t.Fatalf("repaired input len = %d, want 3", len(input))
|
||||
}
|
||||
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
|
||||
t.Fatalf("missing inserted call: %s", input[0].Raw)
|
||||
}
|
||||
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
|
||||
t.Fatalf("unexpected output item: %s", input[1].Raw)
|
||||
}
|
||||
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairResponsesWebsocketToolCallsDropsOrphanOutputWhenCallMissing(t *testing.T) {
|
||||
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
callCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
sessionKey := "session-1"
|
||||
|
||||
raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
|
||||
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
|
||||
|
||||
input := gjson.GetBytes(repaired, "input").Array()
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("repaired input len = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
|
||||
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) {
|
||||
cache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||
sessionKey := "session-1"
|
||||
|
||||
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool","arguments":"{}"}]}}`)
|
||||
recordResponsesWebsocketToolCallsFromPayloadWithCache(cache, sessionKey, payload)
|
||||
|
||||
cached, ok := cache.get(sessionKey, "call-1")
|
||||
if !ok {
|
||||
t.Fatalf("expected cached tool call")
|
||||
}
|
||||
if gjson.GetBytes(cached, "type").String() != "function_call" || gjson.GetBytes(cached, "call_id").String() != "call-1" {
|
||||
t.Fatalf("unexpected cached tool call: %s", cached)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -767,6 +870,29 @@ func TestNormalizeResponsesWebsocketRequestDoesNotTreatDeveloperMessageAsReplace
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`)
|
||||
|
||||
normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
|
||||
items := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(items) != 3 {
|
||||
t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized)
|
||||
}
|
||||
if items[0].Get("id").String() != "fc-1" ||
|
||||
items[1].Get("id").String() != "tool-out-1" ||
|
||||
items[2].Get("id").String() != "msg-2" {
|
||||
t.Fatalf("unexpected merged input order: %s", normalized)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user