fix(websocket): gate compact replay by downstream support
This commit is contained in:
@@ -242,7 +242,7 @@ func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true)
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true, false)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
@@ -278,7 +278,7 @@ func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncre
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false)
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, false)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
@@ -867,6 +867,53 @@ func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsCompactionReplayForModel(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-codex",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth: %v", err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
if !h.websocketUpstreamSupportsCompactionReplayForModel("test-model") {
|
||||
t.Fatalf("expected codex upstream to support compaction replay")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketUpstreamSupportsCompactionReplayForModelFalseWhenMixedBackends(t *testing.T) {
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
auths := []*coreauth.Auth{
|
||||
{ID: "auth-codex", Provider: "codex", Status: coreauth.StatusActive},
|
||||
{ID: "auth-claude", Provider: "claude", Status: coreauth.StatusActive},
|
||||
}
|
||||
for _, auth := range auths {
|
||||
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||
t.Fatalf("Register auth %s: %v", auth.ID, err)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
for _, auth := range auths {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||
}
|
||||
})
|
||||
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
if h.websocketUpstreamSupportsCompactionReplayForModel("test-model") {
|
||||
t.Fatalf("expected mixed backend model to disable compaction replay bypass")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
@@ -1469,6 +1516,45 @@ func TestNormalizeSubsequentRequestCompactSkipsMerge(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSubsequentRequestCompactMergesWhenCompactionReplayUnsupported(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
|
||||
{"type":"message","role":"user","id":"msg-1","content":"original long prompt"},
|
||||
{"type":"message","role":"assistant","id":"msg-2","content":"original long response"},
|
||||
{"type":"function_call","id":"fc-1","call_id":"call-old","name":"bash","arguments":"{}"},
|
||||
{"type":"function_call_output","id":"fco-1","call_id":"call-old","output":"old result"}
|
||||
]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"message","role":"assistant","id":"msg-3","content":"another assistant reply"},
|
||||
{"type":"function_call","id":"fc-2","call_id":"call-stale","name":"read","arguments":"{}"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","input":[
|
||||
{"type":"message","role":"user","id":"msg-1c","content":"compacted user msg"},
|
||||
{"type":"compaction","encrypted_content":"conversation summary"}
|
||||
]}`)
|
||||
|
||||
normalized, _, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, false)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
|
||||
input := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(input) != 7 {
|
||||
t.Fatalf("input len = %d, want 7 (merged fallback without compaction items)", len(input))
|
||||
}
|
||||
wantIDs := []string{"msg-1", "msg-2", "fc-1", "fco-1", "msg-3", "fc-2", "msg-1c"}
|
||||
for i, want := range wantIDs {
|
||||
got := input[i].Get("id").String()
|
||||
if got != want {
|
||||
t.Fatalf("input[%d].id = %q, want %q", i, got, want)
|
||||
}
|
||||
}
|
||||
for _, item := range input {
|
||||
if item.Get("type").String() == "compaction" || item.Get("type").String() == "compaction_summary" {
|
||||
t.Fatalf("compaction items must be stripped for unsupported downstream fallback: %s", item.Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSubsequentRequestIncrementalInputStillMerges(t *testing.T) {
|
||||
// Normal incremental flow: user sends function_call_output (no assistant message).
|
||||
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
|
||||
@@ -1502,7 +1588,9 @@ func TestNormalizeSubsequentRequestIncrementalInputStillMerges(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeSubsequentRequestAssistantIncrementalInputStillMerges(t *testing.T) {
|
||||
func TestNormalizeSubsequentRequestAssistantInputTriggersTranscriptReplacement(t *testing.T) {
|
||||
// After dev's shouldReplaceWebsocketTranscript, assistant messages in input
|
||||
// trigger transcript replacement (no merge with prior state).
|
||||
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
|
||||
{"type":"message","role":"user","id":"msg-1","content":"hello"}
|
||||
]}`)
|
||||
@@ -1520,14 +1608,10 @@ func TestNormalizeSubsequentRequestAssistantIncrementalInputStillMerges(t *testi
|
||||
}
|
||||
|
||||
input := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(input) != 4 {
|
||||
t.Fatalf("input len = %d, want 4 (merged)", len(input))
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("input len = %d, want 1 (transcript replacement, not merge)", len(input))
|
||||
}
|
||||
wantIDs := []string{"msg-1", "msg-2", "fc-1", "msg-3"}
|
||||
for i, want := range wantIDs {
|
||||
got := input[i].Get("id").String()
|
||||
if got != want {
|
||||
t.Fatalf("input[%d].id = %q, want %q", i, got, want)
|
||||
}
|
||||
if input[0].Get("id").String() != "msg-3" {
|
||||
t.Fatalf("input[0].id = %q, want %q", input[0].Get("id").String(), "msg-3")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user