fix(openai): handle transcript replacement after websocket compaction

- Add shouldReplaceWebsocketTranscript() to detect historical model output in input
- Add normalizeResponseTranscriptReplacement() for full transcript reset handling
- Prevent duplicate stale turn-state when clients replace local history post-compaction
- Avoid orphaned function_call items from incremental append on compact transcripts
- Add unit tests for transcript replacement detection and state reset behavior
This commit is contained in:
apparition
2026-03-30 22:44:58 +08:00
parent 6570692291
commit c1d7599829
2 changed files with 240 additions and 0 deletions
@@ -277,6 +277,15 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
}
}
// Compaction can cause clients to replace local websocket history with a new
// compact transcript on the next `response.create`. When the input already
// contains historical model output items, treating it as an incremental append
// duplicates stale turn-state and can leave late orphaned function_call items.
if shouldReplaceWebsocketTranscript(rawJSON, nextInput) {
normalized := normalizeResponseTranscriptReplacement(rawJSON, lastRequest)
return normalized, bytes.Clone(normalized), nil
}
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
// Do not expand it into a full input transcript; upstream expects the incremental payload.
if allowIncrementalInputWithPreviousResponseID {
@@ -348,6 +357,54 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
return normalized, bytes.Clone(normalized), nil
}
// shouldReplaceWebsocketTranscript reports whether a websocket request should be
// handled as a full transcript replacement rather than an incremental append.
//
// It returns true only when all of the following hold:
//   - the request "type" (whitespace-trimmed) equals wsRequestTypeCreate;
//   - the request carries no non-blank "previous_response_id";
//   - nextInput is a JSON array containing at least one item whose type is
//     "function_call", or whose type is "message" with role "assistant" or
//     "developer".
func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool {
	requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
	previousID := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String())
	if requestType != wsRequestTypeCreate || previousID != "" {
		return false
	}
	if !(nextInput.Exists() && nextInput.IsArray()) {
		return false
	}

	// Scan for the first item that marks the input as a replayed transcript.
	for _, entry := range nextInput.Array() {
		kind := strings.TrimSpace(entry.Get("type").String())
		if kind == "function_call" {
			return true
		}
		if kind != "message" {
			continue
		}
		speaker := strings.TrimSpace(entry.Get("role").String())
		if speaker == "assistant" || speaker == "developer" {
			return true
		}
	}
	return false
}
// normalizeResponseTranscriptReplacement rewrites a websocket request that carries
// a full replacement transcript into a standalone request payload.
//
// It removes the "type" and "previous_response_id" fields from rawJSON, fills in
// "model" (whitespace-trimmed, only when non-empty) and "instructions" (copied as
// raw JSON) from lastRequest when rawJSON does not already have them, and forces
// "stream" to true. The returned slice does not share memory with rawJSON.
//
// Every rewrite step is best-effort: when sjson reports an error the step is
// skipped and the payload built so far is kept, so one failed edit can never
// replace the request with whatever partial value the library returned.
func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte) []byte {
	// Work on a private copy so every fallback below still has a valid payload
	// that is detached from the caller's buffer.
	normalized := bytes.Clone(rawJSON)

	// apply commits the result of an sjson edit only when the edit succeeded.
	apply := func(updated []byte, err error) {
		if err == nil {
			normalized = updated
		}
	}

	// Strip the websocket envelope fields so the payload no longer references
	// the request type or the previous response.
	apply(sjson.DeleteBytes(normalized, "type"))
	apply(sjson.DeleteBytes(normalized, "previous_response_id"))

	// Inherit the model from the previous request only when this one omits it.
	if !gjson.GetBytes(normalized, "model").Exists() {
		if modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()); modelName != "" {
			apply(sjson.SetBytes(normalized, "model", modelName))
		}
	}

	// Same for instructions; copied as raw JSON so the original value type is kept.
	if !gjson.GetBytes(normalized, "instructions").Exists() {
		if instructions := gjson.GetBytes(lastRequest, "instructions"); instructions.Exists() {
			apply(sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw)))
		}
	}

	apply(sjson.SetBytes(normalized, "stream", true))
	return normalized
}
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
if len(attributes) > 0 {
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {