Merge pull request #2424 from possible055/fix/websocket-transcript-replacement
fix(openai): handle transcript replacement after websocket v2 compaction
This commit is contained in:
@@ -277,6 +277,15 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compaction can cause clients to replace local websocket history with a new
|
||||||
|
// compact transcript on the next `response.create`. When the input already
|
||||||
|
// contains historical model output items, treating it as an incremental append
|
||||||
|
// duplicates stale turn-state and can leave late orphaned function_call items.
|
||||||
|
if shouldReplaceWebsocketTranscript(rawJSON, nextInput) {
|
||||||
|
normalized := normalizeResponseTranscriptReplacement(rawJSON, lastRequest)
|
||||||
|
return normalized, bytes.Clone(normalized), nil
|
||||||
|
}
|
||||||
|
|
||||||
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
|
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
|
||||||
// Do not expand it into a full input transcript; upstream expects the incremental payload.
|
// Do not expand it into a full input transcript; upstream expects the incremental payload.
|
||||||
if allowIncrementalInputWithPreviousResponseID {
|
if allowIncrementalInputWithPreviousResponseID {
|
||||||
@@ -348,6 +357,54 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
|||||||
return normalized, bytes.Clone(normalized), nil
|
return normalized, bytes.Clone(normalized), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bool {
|
||||||
|
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !nextInput.Exists() || !nextInput.IsArray() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, item := range nextInput.Array() {
|
||||||
|
switch strings.TrimSpace(item.Get("type").String()) {
|
||||||
|
case "function_call":
|
||||||
|
return true
|
||||||
|
case "message":
|
||||||
|
role := strings.TrimSpace(item.Get("role").String())
|
||||||
|
if role == "assistant" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte) []byte {
|
||||||
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||||
|
if errDelete != nil {
|
||||||
|
normalized = bytes.Clone(rawJSON)
|
||||||
|
}
|
||||||
|
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
|
||||||
|
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||||
|
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||||
|
if modelName != "" {
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||||
|
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||||
|
if instructions.Exists() {
|
||||||
|
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||||
|
return bytes.Clone(normalized)
|
||||||
|
}
|
||||||
|
|
||||||
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
|
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
|
||||||
if len(attributes) > 0 {
|
if len(attributes) > 0 {
|
||||||
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
|
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
|
||||||
|
|||||||
@@ -27,6 +27,12 @@ type websocketCaptureExecutor struct {
|
|||||||
payloads [][]byte
|
payloads [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type websocketCompactionCaptureExecutor struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
streamPayloads [][]byte
|
||||||
|
compactPayload []byte
|
||||||
|
}
|
||||||
|
|
||||||
type orderedWebsocketSelector struct {
|
type orderedWebsocketSelector struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
order []string
|
order []string
|
||||||
@@ -126,6 +132,52 @@ func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth,
|
|||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *websocketCompactionCaptureExecutor) Identifier() string { return "test-provider" }
|
||||||
|
|
||||||
|
func (e *websocketCompactionCaptureExecutor) Execute(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
e.mu.Lock()
|
||||||
|
e.compactPayload = bytes.Clone(req.Payload)
|
||||||
|
e.mu.Unlock()
|
||||||
|
if opts.Alt != "responses/compact" {
|
||||||
|
return coreexecutor.Response{}, fmt.Errorf("unexpected non-compact execute alt: %q", opts.Alt)
|
||||||
|
}
|
||||||
|
return coreexecutor.Response{Payload: []byte(`{"id":"cmp-1","object":"response.compaction"}`)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCompactionCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
|
e.mu.Lock()
|
||||||
|
callIndex := len(e.streamPayloads)
|
||||||
|
e.streamPayloads = append(e.streamPayloads, bytes.Clone(req.Payload))
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
var payload []byte
|
||||||
|
switch callIndex {
|
||||||
|
case 0:
|
||||||
|
payload = []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}]}}`)
|
||||||
|
case 1:
|
||||||
|
payload = []byte(`{"type":"response.completed","response":{"id":"resp-2","output":[{"type":"message","id":"assistant-1"}]}}`)
|
||||||
|
default:
|
||||||
|
payload = []byte(`{"type":"response.completed","response":{"id":"resp-3","output":[{"type":"message","id":"assistant-2"}]}}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||||
|
chunks <- coreexecutor.StreamChunk{Payload: payload}
|
||||||
|
close(chunks)
|
||||||
|
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCompactionCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCompactionCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCompactionCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
||||||
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||||
|
|
||||||
@@ -662,3 +714,160 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
|
|||||||
t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got)
|
t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesWebsocketRequestTreatsTranscriptReplacementAsReset(t *testing.T) {
|
||||||
|
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
|
||||||
|
lastResponseOutput := []byte(`[
|
||||||
|
{"type":"message","id":"assistant-1","role":"assistant"}
|
||||||
|
]`)
|
||||||
|
raw := []byte(`{"type":"response.create","input":[{"type":"function_call","id":"fc-compact","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-2"}]}`)
|
||||||
|
|
||||||
|
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||||
|
if errMsg != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(normalized, "previous_response_id").Exists() {
|
||||||
|
t.Fatalf("previous_response_id must not exist in transcript replacement mode")
|
||||||
|
}
|
||||||
|
items := gjson.GetBytes(normalized, "input").Array()
|
||||||
|
if len(items) != 2 {
|
||||||
|
t.Fatalf("replacement input len = %d, want 2: %s", len(items), normalized)
|
||||||
|
}
|
||||||
|
if items[0].Get("id").String() != "fc-compact" || items[1].Get("id").String() != "msg-2" {
|
||||||
|
t.Fatalf("replacement transcript was not preserved as-is: %s", normalized)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(next, normalized) {
|
||||||
|
t.Fatalf("next request snapshot should match replacement request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesWebsocketRequestDoesNotTreatDeveloperMessageAsReplacement(t *testing.T) {
|
||||||
|
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||||
|
lastResponseOutput := []byte(`[
|
||||||
|
{"type":"message","id":"assistant-1","role":"assistant"}
|
||||||
|
]`)
|
||||||
|
raw := []byte(`{"type":"response.create","input":[{"type":"message","id":"dev-1","role":"developer"},{"type":"message","id":"msg-2"}]}`)
|
||||||
|
|
||||||
|
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||||
|
if errMsg != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
|
}
|
||||||
|
items := gjson.GetBytes(normalized, "input").Array()
|
||||||
|
if len(items) != 4 {
|
||||||
|
t.Fatalf("merged input len = %d, want 4: %s", len(items), normalized)
|
||||||
|
}
|
||||||
|
if items[0].Get("id").String() != "msg-1" ||
|
||||||
|
items[1].Get("id").String() != "assistant-1" ||
|
||||||
|
items[2].Get("id").String() != "dev-1" ||
|
||||||
|
items[3].Get("id").String() != "msg-2" {
|
||||||
|
t.Fatalf("developer follow-up should preserve merge behavior: %s", normalized)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(next, normalized) {
|
||||||
|
t.Fatalf("next request snapshot should match merged request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
executor := &websocketCompactionCaptureExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||||
|
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||||
|
t.Fatalf("Register auth: %v", err)
|
||||||
|
}
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||||
|
router.POST("/v1/responses/compact", h.Compact)
|
||||||
|
|
||||||
|
server := httptest.NewServer(router)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial websocket: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := conn.Close(); errClose != nil {
|
||||||
|
t.Fatalf("close websocket: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
requests := []string{
|
||||||
|
`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`,
|
||||||
|
`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`,
|
||||||
|
}
|
||||||
|
for i := range requests {
|
||||||
|
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
|
||||||
|
t.Fatalf("write websocket message %d: %v", i+1, errWrite)
|
||||||
|
}
|
||||||
|
_, payload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
|
||||||
|
t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
compactResp, errPost := server.Client().Post(
|
||||||
|
server.URL+"/v1/responses/compact",
|
||||||
|
"application/json",
|
||||||
|
strings.NewReader(`{"model":"test-model","input":[{"type":"message","id":"summary-1"}]}`),
|
||||||
|
)
|
||||||
|
if errPost != nil {
|
||||||
|
t.Fatalf("compact request failed: %v", errPost)
|
||||||
|
}
|
||||||
|
if errClose := compactResp.Body.Close(); errClose != nil {
|
||||||
|
t.Fatalf("close compact response body: %v", errClose)
|
||||||
|
}
|
||||||
|
if compactResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("compact status = %d, want %d", compactResp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate a post-compaction client turn that replaces local history with a compacted transcript.
|
||||||
|
// The websocket handler must treat this as a state reset, not append it to stale pre-compaction state.
|
||||||
|
postCompact := `{"type":"response.create","input":[{"type":"function_call","id":"fc-compact","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-2"}]}`
|
||||||
|
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(postCompact)); errWrite != nil {
|
||||||
|
t.Fatalf("write post-compact websocket message: %v", errWrite)
|
||||||
|
}
|
||||||
|
_, payload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
t.Fatalf("read post-compact websocket message: %v", errReadMessage)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
|
||||||
|
t.Fatalf("post-compact payload type = %s, want %s", got, wsEventTypeCompleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
executor.mu.Lock()
|
||||||
|
defer executor.mu.Unlock()
|
||||||
|
|
||||||
|
if executor.compactPayload == nil {
|
||||||
|
t.Fatalf("compact payload was not captured")
|
||||||
|
}
|
||||||
|
if len(executor.streamPayloads) != 3 {
|
||||||
|
t.Fatalf("stream payload count = %d, want 3", len(executor.streamPayloads))
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := executor.streamPayloads[2]
|
||||||
|
items := gjson.GetBytes(merged, "input").Array()
|
||||||
|
if len(items) != 2 {
|
||||||
|
t.Fatalf("merged input len = %d, want 2: %s", len(items), merged)
|
||||||
|
}
|
||||||
|
if items[0].Get("id").String() != "fc-compact" ||
|
||||||
|
items[1].Get("id").String() != "msg-2" {
|
||||||
|
t.Fatalf("unexpected post-compact input order: %s", merged)
|
||||||
|
}
|
||||||
|
if items[0].Get("call_id").String() != "call-1" {
|
||||||
|
t.Fatalf("post-compact function call id = %s, want call-1", items[0].Get("call_id").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user