fix(openai): add websocket tool call repair with caching and tests to improve transcript consistency
This commit is contained in:
@@ -33,6 +33,8 @@ const (
|
|||||||
wsDoneMarker = "[DONE]"
|
wsDoneMarker = "[DONE]"
|
||||||
wsTurnStateHeader = "x-codex-turn-state"
|
wsTurnStateHeader = "x-codex-turn-state"
|
||||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||||
|
wsBodyLogMaxSize = 32 * 1024
|
||||||
|
wsBodyLogTruncated = "\n...[truncated]\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||||
@@ -52,6 +54,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
passthroughSessionID := uuid.NewString()
|
passthroughSessionID := uuid.NewString()
|
||||||
|
downstreamSessionKey := websocketDownstreamSessionKey(c.Request)
|
||||||
clientRemoteAddr := ""
|
clientRemoteAddr := ""
|
||||||
if c != nil && c.Request != nil {
|
if c != nil && c.Request != nil {
|
||||||
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
||||||
@@ -164,6 +167,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
|
||||||
|
updatedLastRequest = bytes.Clone(requestJSON)
|
||||||
lastRequest = updatedLastRequest
|
lastRequest = updatedLastRequest
|
||||||
|
|
||||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||||
@@ -324,6 +330,10 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
|||||||
Error: fmt.Errorf("invalid request input: %w", errMerge),
|
Error: fmt.Errorf("invalid request input: %w", errMerge),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput)
|
||||||
|
if errDedupeFunctionCalls == nil {
|
||||||
|
mergedInput = dedupedInput
|
||||||
|
}
|
||||||
|
|
||||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||||
if errDelete != nil {
|
if errDelete != nil {
|
||||||
@@ -355,7 +365,8 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
|||||||
}
|
}
|
||||||
|
|
||||||
func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool {
|
func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool {
|
||||||
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
|
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
||||||
|
if requestType != wsRequestTypeCreate && requestType != wsRequestTypeAppend {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" {
|
if strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" {
|
||||||
@@ -402,6 +413,42 @@ func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte)
|
|||||||
return bytes.Clone(normalized)
|
return bytes.Clone(normalized)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func dedupeFunctionCallsByCallID(rawArray string) (string, error) {
|
||||||
|
rawArray = strings.TrimSpace(rawArray)
|
||||||
|
if rawArray == "" {
|
||||||
|
return "[]", nil
|
||||||
|
}
|
||||||
|
var items []json.RawMessage
|
||||||
|
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
|
||||||
|
return "", errUnmarshal
|
||||||
|
}
|
||||||
|
|
||||||
|
seenCallIDs := make(map[string]struct{}, len(items))
|
||||||
|
filtered := make([]json.RawMessage, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
if len(item) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
|
||||||
|
if itemType == "function_call" {
|
||||||
|
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||||
|
if callID != "" {
|
||||||
|
if _, ok := seenCallIDs[callID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenCallIDs[callID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
out, errMarshal := json.Marshal(filtered)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return "", errMarshal
|
||||||
|
}
|
||||||
|
return string(out), nil
|
||||||
|
}
|
||||||
|
|
||||||
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
|
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
|
||||||
if len(attributes) > 0 {
|
if len(attributes) > 0 {
|
||||||
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
|
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
|
||||||
@@ -667,6 +714,10 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
|||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
completed := false
|
completed := false
|
||||||
completedOutput := []byte("[]")
|
completedOutput := []byte("[]")
|
||||||
|
downstreamSessionKey := ""
|
||||||
|
if c != nil && c.Request != nil {
|
||||||
|
downstreamSessionKey = websocketDownstreamSessionKey(c.Request)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -744,6 +795,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
|||||||
|
|
||||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||||
for i := range payloads {
|
for i := range payloads {
|
||||||
|
recordResponsesWebsocketToolCallsFromPayload(downstreamSessionKey, payloads[i])
|
||||||
eventType := gjson.GetBytes(payloads[i], "type").String()
|
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||||
if eventType == wsEventTypeCompleted {
|
if eventType == wsEventTypeCompleted {
|
||||||
completed = true
|
completed = true
|
||||||
@@ -891,18 +943,53 @@ func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []
|
|||||||
if builder == nil {
|
if builder == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if builder.Len() >= wsBodyLogMaxSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
trimmedPayload := bytes.TrimSpace(payload)
|
trimmedPayload := bytes.TrimSpace(payload)
|
||||||
if len(trimmedPayload) == 0 {
|
if len(trimmedPayload) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
separator := []byte{}
|
||||||
if builder.Len() > 0 {
|
if builder.Len() > 0 {
|
||||||
builder.WriteString("\n")
|
separator = []byte("\n")
|
||||||
}
|
}
|
||||||
builder.WriteString("websocket.")
|
header := []byte("websocket." + eventType + "\n")
|
||||||
builder.WriteString(eventType)
|
footer := []byte("\n")
|
||||||
builder.WriteString("\n")
|
entryLen := len(separator) + len(header) + len(trimmedPayload) + len(footer)
|
||||||
|
remaining := wsBodyLogMaxSize - builder.Len()
|
||||||
|
|
||||||
|
if entryLen <= remaining {
|
||||||
|
builder.Write(separator)
|
||||||
|
builder.Write(header)
|
||||||
builder.Write(trimmedPayload)
|
builder.Write(trimmedPayload)
|
||||||
builder.WriteString("\n")
|
builder.Write(footer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
marker := []byte(wsBodyLogTruncated)
|
||||||
|
if len(marker) > remaining {
|
||||||
|
builder.Write(marker[:remaining])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed := remaining - len(marker)
|
||||||
|
parts := [][]byte{separator, header, trimmedPayload, footer}
|
||||||
|
for _, part := range parts {
|
||||||
|
if allowed <= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if len(part) <= allowed {
|
||||||
|
builder.Write(part)
|
||||||
|
allowed -= len(part)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
builder.Write(part[:allowed])
|
||||||
|
allowed = 0
|
||||||
|
break
|
||||||
|
}
|
||||||
|
builder.Write(marker)
|
||||||
}
|
}
|
||||||
|
|
||||||
func websocketPayloadEventType(payload []byte) string {
|
func websocketPayloadEventType(payload []byte) string {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
@@ -442,6 +443,108 @@ func TestSetWebsocketRequestBody(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRepairResponsesWebsocketToolCallsInsertsCachedOutput(t *testing.T) {
|
||||||
|
cache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||||
|
sessionKey := "session-1"
|
||||||
|
|
||||||
|
cacheWarm := []byte(`{"previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","output":"ok"}]}`)
|
||||||
|
warmed := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, cacheWarm)
|
||||||
|
if gjson.GetBytes(warmed, "input.0.call_id").String() != "call-1" {
|
||||||
|
t.Fatalf("expected warmup output to remain")
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
|
||||||
|
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
|
||||||
|
|
||||||
|
input := gjson.GetBytes(repaired, "input").Array()
|
||||||
|
if len(input) != 3 {
|
||||||
|
t.Fatalf("repaired input len = %d, want 3", len(input))
|
||||||
|
}
|
||||||
|
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
|
||||||
|
t.Fatalf("unexpected first item: %s", input[0].Raw)
|
||||||
|
}
|
||||||
|
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
|
||||||
|
t.Fatalf("missing inserted output: %s", input[1].Raw)
|
||||||
|
}
|
||||||
|
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
|
||||||
|
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRepairResponsesWebsocketToolCallsDropsOrphanFunctionCall(t *testing.T) {
|
||||||
|
cache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||||
|
sessionKey := "session-1"
|
||||||
|
|
||||||
|
raw := []byte(`{"input":[{"type":"function_call","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
|
||||||
|
repaired := repairResponsesWebsocketToolCallsWithCache(cache, sessionKey, raw)
|
||||||
|
|
||||||
|
input := gjson.GetBytes(repaired, "input").Array()
|
||||||
|
if len(input) != 1 {
|
||||||
|
t.Fatalf("repaired input len = %d, want 1", len(input))
|
||||||
|
}
|
||||||
|
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
|
||||||
|
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *testing.T) {
|
||||||
|
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||||
|
callCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||||
|
sessionKey := "session-1"
|
||||||
|
|
||||||
|
callCache.record(sessionKey, "call-1", []byte(`{"type":"function_call","call_id":"call-1","name":"tool"}`))
|
||||||
|
|
||||||
|
raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
|
||||||
|
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
|
||||||
|
|
||||||
|
input := gjson.GetBytes(repaired, "input").Array()
|
||||||
|
if len(input) != 3 {
|
||||||
|
t.Fatalf("repaired input len = %d, want 3", len(input))
|
||||||
|
}
|
||||||
|
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
|
||||||
|
t.Fatalf("missing inserted call: %s", input[0].Raw)
|
||||||
|
}
|
||||||
|
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
|
||||||
|
t.Fatalf("unexpected output item: %s", input[1].Raw)
|
||||||
|
}
|
||||||
|
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
|
||||||
|
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRepairResponsesWebsocketToolCallsDropsOrphanOutputWhenCallMissing(t *testing.T) {
|
||||||
|
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||||
|
callCache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||||
|
sessionKey := "session-1"
|
||||||
|
|
||||||
|
raw := []byte(`{"input":[{"type":"function_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
|
||||||
|
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
|
||||||
|
|
||||||
|
input := gjson.GetBytes(repaired, "input").Array()
|
||||||
|
if len(input) != 1 {
|
||||||
|
t.Fatalf("repaired input len = %d, want 1", len(input))
|
||||||
|
}
|
||||||
|
if input[0].Get("type").String() != "message" || input[0].Get("id").String() != "msg-1" {
|
||||||
|
t.Fatalf("unexpected remaining item: %s", input[0].Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordResponsesWebsocketToolCallsFromPayloadWithCache(t *testing.T) {
|
||||||
|
cache := newWebsocketToolOutputCache(time.Minute, 10)
|
||||||
|
sessionKey := "session-1"
|
||||||
|
|
||||||
|
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool","arguments":"{}"}]}}`)
|
||||||
|
recordResponsesWebsocketToolCallsFromPayloadWithCache(cache, sessionKey, payload)
|
||||||
|
|
||||||
|
cached, ok := cache.get(sessionKey, "call-1")
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected cached tool call")
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(cached, "type").String() != "function_call" || gjson.GetBytes(cached, "call_id").String() != "call-1" {
|
||||||
|
t.Fatalf("unexpected cached tool call: %s", cached)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
@@ -767,6 +870,29 @@ func TestNormalizeResponsesWebsocketRequestDoesNotTreatDeveloperMessageAsReplace
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t *testing.T) {
|
||||||
|
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"}]}`)
|
||||||
|
lastResponseOutput := []byte(`[
|
||||||
|
{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}
|
||||||
|
]`)
|
||||||
|
raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`)
|
||||||
|
|
||||||
|
normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||||
|
if errMsg != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
items := gjson.GetBytes(normalized, "input").Array()
|
||||||
|
if len(items) != 3 {
|
||||||
|
t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized)
|
||||||
|
}
|
||||||
|
if items[0].Get("id").String() != "fc-1" ||
|
||||||
|
items[1].Get("id").String() != "tool-out-1" ||
|
||||||
|
items[2].Get("id").String() != "msg-2" {
|
||||||
|
t.Fatalf("unexpected merged input order: %s", normalized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) {
|
func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,327 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Fallback limits used by newWebsocketToolOutputCache when it is given a
// non-positive TTL or per-session capacity.
const (
	websocketToolOutputCacheMaxPerSession = 256
	websocketToolOutputCacheTTL           = 30 * time.Minute
)

// Package-level cache instances: one holds function_call_output items, the
// other holds function_call items. Both use the same cache type and limits.
var (
	defaultWebsocketToolOutputCache = newWebsocketToolOutputCache(websocketToolOutputCacheTTL, websocketToolOutputCacheMaxPerSession)
	defaultWebsocketToolCallCache   = newWebsocketToolOutputCache(websocketToolOutputCacheTTL, websocketToolOutputCacheMaxPerSession)
)

// websocketToolOutputCache stores raw JSON items per session, keyed by call_id.
// All access to sessions goes through mu.
type websocketToolOutputCache struct {
	mu            sync.Mutex
	ttl           time.Duration // idle time after which a whole session is dropped
	maxPerSession int           // maximum number of call_ids kept per session
	sessions      map[string]*websocketToolOutputSession
}

// websocketToolOutputSession is the per-session state of a
// websocketToolOutputCache.
type websocketToolOutputSession struct {
	lastSeen time.Time                  // refreshed on every record and get
	outputs  map[string]json.RawMessage // call_id -> stored item
	order    []string                   // call_ids in first-insertion order, oldest first
}

// newWebsocketToolOutputCache builds an empty cache. Non-positive ttl or
// maxPerSession values are replaced by the package defaults.
func newWebsocketToolOutputCache(ttl time.Duration, maxPerSession int) *websocketToolOutputCache {
	cache := &websocketToolOutputCache{
		ttl:           websocketToolOutputCacheTTL,
		maxPerSession: websocketToolOutputCacheMaxPerSession,
		sessions:      make(map[string]*websocketToolOutputSession),
	}
	if ttl > 0 {
		cache.ttl = ttl
	}
	if maxPerSession > 0 {
		cache.maxPerSession = maxPerSession
	}
	return cache
}

// record stores a private copy of item under (sessionKey, callID), replacing
// any previous value for that pair. Keys are whitespace-trimmed; the call is a
// no-op on a nil cache or when either key is blank. When a session exceeds
// maxPerSession entries, the oldest-inserted call_ids are evicted.
func (c *websocketToolOutputCache) record(sessionKey string, callID string, item json.RawMessage) {
	if c == nil {
		return
	}
	sessionKey, callID = strings.TrimSpace(sessionKey), strings.TrimSpace(callID)
	if sessionKey == "" || callID == "" {
		return
	}

	now := time.Now()
	c.mu.Lock()
	defer c.mu.Unlock()

	c.cleanupLocked(now)

	session := c.sessions[sessionKey]
	if session == nil {
		session = &websocketToolOutputSession{outputs: make(map[string]json.RawMessage)}
		c.sessions[sessionKey] = session
	}
	session.lastSeen = now

	// Overwriting an existing call_id keeps its original position in order.
	if _, known := session.outputs[callID]; !known {
		session.order = append(session.order, callID)
	}
	session.outputs[callID] = append(json.RawMessage(nil), item...)

	if overflow := len(session.order) - c.maxPerSession; overflow > 0 {
		for _, stale := range session.order[:overflow] {
			delete(session.outputs, stale)
		}
		session.order = session.order[overflow:]
	}
}

// get returns a copy of the item stored under (sessionKey, callID) and true,
// or nil and false when the cache is nil, a key is blank, the session is
// unknown or expired, or the stored item is empty. A lookup that finds the
// session refreshes its lastSeen time.
func (c *websocketToolOutputCache) get(sessionKey string, callID string) (json.RawMessage, bool) {
	if c == nil {
		return nil, false
	}
	sessionKey, callID = strings.TrimSpace(sessionKey), strings.TrimSpace(callID)
	if sessionKey == "" || callID == "" {
		return nil, false
	}

	now := time.Now()
	c.mu.Lock()
	defer c.mu.Unlock()

	c.cleanupLocked(now)

	session := c.sessions[sessionKey]
	if session == nil {
		return nil, false
	}
	session.lastSeen = now

	stored := session.outputs[callID]
	if len(stored) == 0 {
		return nil, false
	}
	// Hand out a copy so callers cannot mutate the cached bytes.
	duplicate := make(json.RawMessage, len(stored))
	copy(duplicate, stored)
	return duplicate, true
}

// cleanupLocked deletes nil sessions and sessions whose lastSeen is more than
// ttl before now. The caller must hold c.mu.
func (c *websocketToolOutputCache) cleanupLocked(now time.Time) {
	if c == nil || c.ttl <= 0 {
		return
	}
	for key, session := range c.sessions {
		if session == nil || now.Sub(session.lastSeen) > c.ttl {
			delete(c.sessions, key)
		}
	}
}
|
||||||
|
|
||||||
|
func websocketDownstreamSessionKey(req *http.Request) string {
|
||||||
|
if req == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if sessionID := strings.TrimSpace(req.Header.Get("Session_id")); sessionID != "" {
|
||||||
|
return sessionID
|
||||||
|
}
|
||||||
|
if requestID := strings.TrimSpace(req.Header.Get("X-Client-Request-Id")); requestID != "" {
|
||||||
|
return requestID
|
||||||
|
}
|
||||||
|
if raw := strings.TrimSpace(req.Header.Get("X-Codex-Turn-Metadata")); raw != "" {
|
||||||
|
if sessionID := strings.TrimSpace(gjson.Get(raw, "session_id").String()); sessionID != "" {
|
||||||
|
return sessionID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func repairResponsesWebsocketToolCalls(sessionKey string, payload []byte) []byte {
|
||||||
|
return repairResponsesWebsocketToolCallsWithCaches(defaultWebsocketToolOutputCache, defaultWebsocketToolCallCache, sessionKey, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func repairResponsesWebsocketToolCallsWithCache(cache *websocketToolOutputCache, sessionKey string, payload []byte) []byte {
|
||||||
|
return repairResponsesWebsocketToolCallsWithCaches(cache, nil, sessionKey, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache *websocketToolOutputCache, sessionKey string, payload []byte) []byte {
|
||||||
|
sessionKey = strings.TrimSpace(sessionKey)
|
||||||
|
if sessionKey == "" || outputCache == nil || len(payload) == 0 {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
input := gjson.GetBytes(payload, "input")
|
||||||
|
if !input.Exists() || !input.IsArray() {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
allowOrphanOutputs := strings.TrimSpace(gjson.GetBytes(payload, "previous_response_id").String()) != ""
|
||||||
|
updatedRaw, errRepair := repairResponsesToolCallsArray(outputCache, callCache, sessionKey, input.Raw, allowOrphanOutputs)
|
||||||
|
if errRepair != nil || updatedRaw == "" || updatedRaw == input.Raw {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, errSet := sjson.SetRawBytes(payload, "input", []byte(updatedRaw))
|
||||||
|
if errSet != nil {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
return updated
|
||||||
|
}
|
||||||
|
|
||||||
|
func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCache, sessionKey string, rawArray string, allowOrphanOutputs bool) (string, error) {
|
||||||
|
rawArray = strings.TrimSpace(rawArray)
|
||||||
|
if rawArray == "" {
|
||||||
|
return "[]", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var items []json.RawMessage
|
||||||
|
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
|
||||||
|
return "", errUnmarshal
|
||||||
|
}
|
||||||
|
|
||||||
|
// First pass: record tool outputs and remember which call_ids have outputs in this payload.
|
||||||
|
outputPresent := make(map[string]struct{}, len(items))
|
||||||
|
callPresent := make(map[string]struct{}, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
if len(item) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
|
||||||
|
switch itemType {
|
||||||
|
case "function_call_output":
|
||||||
|
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||||
|
if callID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
outputPresent[callID] = struct{}{}
|
||||||
|
outputCache.record(sessionKey, callID, item)
|
||||||
|
case "function_call":
|
||||||
|
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||||
|
if callID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callPresent[callID] = struct{}{}
|
||||||
|
if callCache != nil {
|
||||||
|
callCache.record(sessionKey, callID, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := make([]json.RawMessage, 0, len(items))
|
||||||
|
insertedCalls := make(map[string]struct{}, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
if len(item) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
itemType := strings.TrimSpace(gjson.GetBytes(item, "type").String())
|
||||||
|
if itemType == "function_call_output" {
|
||||||
|
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||||
|
if callID == "" {
|
||||||
|
// Upstream rejects tool outputs without a call_id; drop it.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if allowOrphanOutputs {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := callPresent[callID]; ok {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if callCache != nil {
|
||||||
|
if cached, ok := callCache.get(sessionKey, callID); ok {
|
||||||
|
if _, already := insertedCalls[callID]; !already {
|
||||||
|
filtered = append(filtered, cached)
|
||||||
|
insertedCalls[callID] = struct{}{}
|
||||||
|
callPresent[callID] = struct{}{}
|
||||||
|
}
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop orphaned function_call_output items; upstream rejects transcripts with missing calls.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if itemType != "function_call" {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
callID := strings.TrimSpace(gjson.GetBytes(item, "call_id").String())
|
||||||
|
if callID == "" {
|
||||||
|
// Upstream rejects tool calls without a call_id; drop it.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := outputPresent[callID]; ok {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if cached, ok := outputCache.get(sessionKey, callID); ok {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
filtered = append(filtered, cached)
|
||||||
|
outputPresent[callID] = struct{}{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop orphaned function_call items; upstream rejects transcripts with missing outputs.
|
||||||
|
}
|
||||||
|
|
||||||
|
out, errMarshal := json.Marshal(filtered)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return "", errMarshal
|
||||||
|
}
|
||||||
|
return string(out), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordResponsesWebsocketToolCallsFromPayload(sessionKey string, payload []byte) {
|
||||||
|
recordResponsesWebsocketToolCallsFromPayloadWithCache(defaultWebsocketToolCallCache, sessionKey, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func recordResponsesWebsocketToolCallsFromPayloadWithCache(cache *websocketToolOutputCache, sessionKey string, payload []byte) {
|
||||||
|
sessionKey = strings.TrimSpace(sessionKey)
|
||||||
|
if sessionKey == "" || cache == nil || len(payload) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
||||||
|
switch eventType {
|
||||||
|
case "response.completed":
|
||||||
|
output := gjson.GetBytes(payload, "response.output")
|
||||||
|
if !output.Exists() || !output.IsArray() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, item := range output.Array() {
|
||||||
|
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
callID := strings.TrimSpace(item.Get("call_id").String())
|
||||||
|
if callID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cache.record(sessionKey, callID, json.RawMessage(item.Raw))
|
||||||
|
}
|
||||||
|
case "response.output_item.added", "response.output_item.done":
|
||||||
|
item := gjson.GetBytes(payload, "item")
|
||||||
|
if !item.Exists() || !item.IsObject() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(item.Get("type").String()) != "function_call" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
callID := strings.TrimSpace(item.Get("call_id").String())
|
||||||
|
if callID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cache.record(sessionKey, callID, json.RawMessage(item.Raw))
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user