fix(websocket): gate compact replay by downstream support
This commit is contained in:
@@ -116,6 +116,19 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
allowCompactionReplayBypass := false
|
||||||
|
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||||
|
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||||
|
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||||
|
if requestModelName == "" {
|
||||||
|
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||||
|
}
|
||||||
|
allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName)
|
||||||
|
}
|
||||||
|
|
||||||
var requestJSON []byte
|
var requestJSON []byte
|
||||||
var updatedLastRequest []byte
|
var updatedLastRequest []byte
|
||||||
var errMsg *interfaces.ErrorMessage
|
var errMsg *interfaces.ErrorMessage
|
||||||
@@ -124,6 +137,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
lastRequest,
|
lastRequest,
|
||||||
lastResponseOutput,
|
lastResponseOutput,
|
||||||
allowIncrementalInputWithPreviousResponseID,
|
allowIncrementalInputWithPreviousResponseID,
|
||||||
|
allowCompactionReplayBypass,
|
||||||
)
|
)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||||
@@ -222,10 +236,10 @@ func websocketUpgradeHeaders(req *http.Request) http.Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||||
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
|
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||||
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
||||||
switch requestType {
|
switch requestType {
|
||||||
case wsRequestTypeCreate:
|
case wsRequestTypeCreate:
|
||||||
@@ -233,10 +247,10 @@ func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []by
|
|||||||
if len(lastRequest) == 0 {
|
if len(lastRequest) == 0 {
|
||||||
return normalizeResponseCreateRequest(rawJSON)
|
return normalizeResponseCreateRequest(rawJSON)
|
||||||
}
|
}
|
||||||
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
|
||||||
case wsRequestTypeAppend:
|
case wsRequestTypeAppend:
|
||||||
// log.Infof("responses websocket: response.append request")
|
// log.Infof("responses websocket: response.append request")
|
||||||
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
|
||||||
default:
|
default:
|
||||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
StatusCode: http.StatusBadRequest,
|
StatusCode: http.StatusBadRequest,
|
||||||
@@ -265,7 +279,7 @@ func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces
|
|||||||
return normalized, bytes.Clone(normalized), nil
|
return normalized, bytes.Clone(normalized), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||||
if len(lastRequest) == 0 {
|
if len(lastRequest) == 0 {
|
||||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
StatusCode: http.StatusBadRequest,
|
StatusCode: http.StatusBadRequest,
|
||||||
@@ -315,16 +329,21 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// When the client sends a full conversation transcript (e.g. after compact),
|
// When the client sends a compact replay for a downstream that can consume it
|
||||||
// the input already contains the complete history including assistant messages.
|
// directly, the input already carries the canonical history. In that case,
|
||||||
// In that case, skip merging with stale lastRequest/lastResponseOutput to avoid
|
// skip merging with stale lastRequest/lastResponseOutput to avoid breaking
|
||||||
// breaking function_call / function_call_output pairings.
|
// function_call / function_call_output pairings.
|
||||||
// See: https://github.com/router-for-me/CLIProxyAPI/issues/2207
|
// See: https://github.com/router-for-me/CLIProxyAPI/issues/2207
|
||||||
var mergedInput string
|
var mergedInput string
|
||||||
if inputContainsFullTranscript(nextInput) {
|
if allowCompactionReplayBypass && inputContainsFullTranscript(nextInput) {
|
||||||
log.Infof("responses websocket: full transcript detected, skipping stale merge (input items=%d)", len(nextInput.Array()))
|
log.Infof("responses websocket: full transcript detected, skipping stale merge (input items=%d)", len(nextInput.Array()))
|
||||||
mergedInput = nextInput.Raw
|
mergedInput = nextInput.Raw
|
||||||
} else {
|
} else {
|
||||||
|
appendInputRaw := nextInput.Raw
|
||||||
|
if inputContainsFullTranscript(nextInput) {
|
||||||
|
appendInputRaw = inputWithoutCompactionItems(nextInput)
|
||||||
|
}
|
||||||
|
|
||||||
existingInput := gjson.GetBytes(lastRequest, "input")
|
existingInput := gjson.GetBytes(lastRequest, "input")
|
||||||
var errMerge error
|
var errMerge error
|
||||||
mergedInput, errMerge = mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
|
mergedInput, errMerge = mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
|
||||||
@@ -335,7 +354,7 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
|
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, appendInputRaw)
|
||||||
if errMerge != nil {
|
if errMerge != nil {
|
||||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
StatusCode: http.StatusBadRequest,
|
StatusCode: http.StatusBadRequest,
|
||||||
@@ -492,72 +511,104 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
|
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
|
||||||
if h == nil || h.AuthManager == nil {
|
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
|
||||||
|
for _, auth := range auths {
|
||||||
|
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsCompactionReplayForModel(modelName string) bool {
|
||||||
|
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
|
||||||
|
if len(auths) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
for _, auth := range auths {
|
||||||
|
if !responsesWebsocketAuthSupportsCompactionReplay(auth) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
resolvedModelName := modelName
|
func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(modelName string) ([]*coreauth.Auth, string) {
|
||||||
|
if h == nil || h.AuthManager == nil {
|
||||||
|
return nil, ""
|
||||||
|
}
|
||||||
|
resolvedModelName := responsesWebsocketResolvedModelName(modelName)
|
||||||
|
providerSet, modelKey := responsesWebsocketProviderSetForModel(resolvedModelName)
|
||||||
|
if len(providerSet) == 0 {
|
||||||
|
return nil, modelKey
|
||||||
|
}
|
||||||
|
|
||||||
|
registryRef := registry.GetGlobalRegistry()
|
||||||
|
now := time.Now()
|
||||||
|
auths := h.AuthManager.List()
|
||||||
|
available := make([]*coreauth.Auth, 0, len(auths))
|
||||||
|
for _, auth := range auths {
|
||||||
|
if !responsesWebsocketAuthMatchesModel(auth, providerSet, modelKey, registryRef, now) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
available = append(available, auth)
|
||||||
|
}
|
||||||
|
return available, modelKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesWebsocketResolvedModelName(modelName string) string {
|
||||||
initialSuffix := thinking.ParseSuffix(modelName)
|
initialSuffix := thinking.ParseSuffix(modelName)
|
||||||
if initialSuffix.ModelName == "auto" {
|
if initialSuffix.ModelName == "auto" {
|
||||||
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
|
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
|
||||||
if initialSuffix.HasSuffix {
|
if initialSuffix.HasSuffix {
|
||||||
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
|
return fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
|
||||||
} else {
|
|
||||||
resolvedModelName = resolvedBase
|
|
||||||
}
|
}
|
||||||
} else {
|
return resolvedBase
|
||||||
resolvedModelName = util.ResolveAutoModel(modelName)
|
|
||||||
}
|
}
|
||||||
|
return util.ResolveAutoModel(modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesWebsocketProviderSetForModel(resolvedModelName string) (map[string]struct{}, string) {
|
||||||
parsed := thinking.ParseSuffix(resolvedModelName)
|
parsed := thinking.ParseSuffix(resolvedModelName)
|
||||||
baseModel := strings.TrimSpace(parsed.ModelName)
|
baseModel := strings.TrimSpace(parsed.ModelName)
|
||||||
providers := util.GetProviderName(baseModel)
|
providers := util.GetProviderName(baseModel)
|
||||||
if len(providers) == 0 && baseModel != resolvedModelName {
|
if len(providers) == 0 && baseModel != resolvedModelName {
|
||||||
providers = util.GetProviderName(resolvedModelName)
|
providers = util.GetProviderName(resolvedModelName)
|
||||||
}
|
}
|
||||||
if len(providers) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
providerSet := make(map[string]struct{}, len(providers))
|
providerSet := make(map[string]struct{}, len(providers))
|
||||||
for i := 0; i < len(providers); i++ {
|
for _, provider := range providers {
|
||||||
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
|
providerKey := strings.TrimSpace(strings.ToLower(provider))
|
||||||
if providerKey == "" {
|
if providerKey == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
providerSet[providerKey] = struct{}{}
|
providerSet[providerKey] = struct{}{}
|
||||||
}
|
}
|
||||||
if len(providerSet) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
modelKey := baseModel
|
modelKey := baseModel
|
||||||
if modelKey == "" {
|
if modelKey == "" {
|
||||||
modelKey = strings.TrimSpace(resolvedModelName)
|
modelKey = strings.TrimSpace(resolvedModelName)
|
||||||
}
|
}
|
||||||
registryRef := registry.GetGlobalRegistry()
|
return providerSet, modelKey
|
||||||
now := time.Now()
|
}
|
||||||
auths := h.AuthManager.List()
|
|
||||||
for i := 0; i < len(auths); i++ {
|
func responsesWebsocketAuthMatchesModel(auth *coreauth.Auth, providerSet map[string]struct{}, modelKey string, registryRef *registry.ModelRegistry, now time.Time) bool {
|
||||||
auth := auths[i]
|
if auth == nil {
|
||||||
if auth == nil {
|
return false
|
||||||
continue
|
|
||||||
}
|
|
||||||
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
|
||||||
if _, ok := providerSet[providerKey]; !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return false
|
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
||||||
|
if _, ok := providerSet[providerKey]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return responsesWebsocketAuthAvailableForModel(auth, modelKey, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesWebsocketAuthSupportsCompactionReplay(auth *coreauth.Auth) bool {
|
||||||
|
if auth == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(strings.TrimSpace(auth.Provider), "codex")
|
||||||
}
|
}
|
||||||
|
|
||||||
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
|
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
|
||||||
@@ -724,6 +775,21 @@ func inputContainsFullTranscript(input gjson.Result) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func inputWithoutCompactionItems(input gjson.Result) string {
|
||||||
|
if !input.IsArray() {
|
||||||
|
return normalizeJSONArrayRaw([]byte(input.Raw))
|
||||||
|
}
|
||||||
|
filtered := make([]string, 0, len(input.Array()))
|
||||||
|
for _, item := range input.Array() {
|
||||||
|
t := item.Get("type").String()
|
||||||
|
if t == "compaction" || t == "compaction_summary" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, item.Raw)
|
||||||
|
}
|
||||||
|
return "[" + strings.Join(filtered, ",") + "]"
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeJSONArrayRaw(raw []byte) string {
|
func normalizeJSONArrayRaw(raw []byte) string {
|
||||||
trimmed := strings.TrimSpace(string(raw))
|
trimmed := strings.TrimSpace(string(raw))
|
||||||
if trimmed == "" {
|
if trimmed == "" {
|
||||||
|
|||||||
@@ -242,7 +242,7 @@ func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *
|
|||||||
]`)
|
]`)
|
||||||
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||||
|
|
||||||
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true)
|
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true, false)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
}
|
}
|
||||||
@@ -278,7 +278,7 @@ func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncre
|
|||||||
]`)
|
]`)
|
||||||
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||||
|
|
||||||
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false)
|
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, false)
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
}
|
}
|
||||||
@@ -867,6 +867,53 @@ func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebsocketUpstreamSupportsCompactionReplayForModel(t *testing.T) {
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
auth := &coreauth.Auth{
|
||||||
|
ID: "auth-codex",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||||
|
t.Fatalf("Register auth: %v", err)
|
||||||
|
}
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
if !h.websocketUpstreamSupportsCompactionReplayForModel("test-model") {
|
||||||
|
t.Fatalf("expected codex upstream to support compaction replay")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketUpstreamSupportsCompactionReplayForModelFalseWhenMixedBackends(t *testing.T) {
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
auths := []*coreauth.Auth{
|
||||||
|
{ID: "auth-codex", Provider: "codex", Status: coreauth.StatusActive},
|
||||||
|
{ID: "auth-claude", Provider: "claude", Status: coreauth.StatusActive},
|
||||||
|
}
|
||||||
|
for _, auth := range auths {
|
||||||
|
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||||
|
t.Fatalf("Register auth %s: %v", auth.ID, err)
|
||||||
|
}
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
for _, auth := range auths {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
if h.websocketUpstreamSupportsCompactionReplayForModel("test-model") {
|
||||||
|
t.Fatalf("expected mixed backend model to disable compaction replay bypass")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
@@ -1469,6 +1516,45 @@ func TestNormalizeSubsequentRequestCompactSkipsMerge(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeSubsequentRequestCompactMergesWhenCompactionReplayUnsupported(t *testing.T) {
|
||||||
|
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
|
||||||
|
{"type":"message","role":"user","id":"msg-1","content":"original long prompt"},
|
||||||
|
{"type":"message","role":"assistant","id":"msg-2","content":"original long response"},
|
||||||
|
{"type":"function_call","id":"fc-1","call_id":"call-old","name":"bash","arguments":"{}"},
|
||||||
|
{"type":"function_call_output","id":"fco-1","call_id":"call-old","output":"old result"}
|
||||||
|
]}`)
|
||||||
|
lastResponseOutput := []byte(`[
|
||||||
|
{"type":"message","role":"assistant","id":"msg-3","content":"another assistant reply"},
|
||||||
|
{"type":"function_call","id":"fc-2","call_id":"call-stale","name":"read","arguments":"{}"}
|
||||||
|
]`)
|
||||||
|
raw := []byte(`{"type":"response.create","input":[
|
||||||
|
{"type":"message","role":"user","id":"msg-1c","content":"compacted user msg"},
|
||||||
|
{"type":"compaction","encrypted_content":"conversation summary"}
|
||||||
|
]}`)
|
||||||
|
|
||||||
|
normalized, _, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, false)
|
||||||
|
if errMsg != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
input := gjson.GetBytes(normalized, "input").Array()
|
||||||
|
if len(input) != 7 {
|
||||||
|
t.Fatalf("input len = %d, want 7 (merged fallback without compaction items)", len(input))
|
||||||
|
}
|
||||||
|
wantIDs := []string{"msg-1", "msg-2", "fc-1", "fco-1", "msg-3", "fc-2", "msg-1c"}
|
||||||
|
for i, want := range wantIDs {
|
||||||
|
got := input[i].Get("id").String()
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("input[%d].id = %q, want %q", i, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, item := range input {
|
||||||
|
if item.Get("type").String() == "compaction" || item.Get("type").String() == "compaction_summary" {
|
||||||
|
t.Fatalf("compaction items must be stripped for unsupported downstream fallback: %s", item.Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeSubsequentRequestIncrementalInputStillMerges(t *testing.T) {
|
func TestNormalizeSubsequentRequestIncrementalInputStillMerges(t *testing.T) {
|
||||||
// Normal incremental flow: user sends function_call_output (no assistant message).
|
// Normal incremental flow: user sends function_call_output (no assistant message).
|
||||||
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
|
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
|
||||||
@@ -1502,7 +1588,9 @@ func TestNormalizeSubsequentRequestIncrementalInputStillMerges(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNormalizeSubsequentRequestAssistantIncrementalInputStillMerges(t *testing.T) {
|
func TestNormalizeSubsequentRequestAssistantInputTriggersTranscriptReplacement(t *testing.T) {
|
||||||
|
// After dev's shouldReplaceWebsocketTranscript, assistant messages in input
|
||||||
|
// trigger transcript replacement (no merge with prior state).
|
||||||
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
|
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
|
||||||
{"type":"message","role":"user","id":"msg-1","content":"hello"}
|
{"type":"message","role":"user","id":"msg-1","content":"hello"}
|
||||||
]}`)
|
]}`)
|
||||||
@@ -1520,14 +1608,10 @@ func TestNormalizeSubsequentRequestAssistantIncrementalInputStillMerges(t *testi
|
|||||||
}
|
}
|
||||||
|
|
||||||
input := gjson.GetBytes(normalized, "input").Array()
|
input := gjson.GetBytes(normalized, "input").Array()
|
||||||
if len(input) != 4 {
|
if len(input) != 1 {
|
||||||
t.Fatalf("input len = %d, want 4 (merged)", len(input))
|
t.Fatalf("input len = %d, want 1 (transcript replacement, not merge)", len(input))
|
||||||
}
|
}
|
||||||
wantIDs := []string{"msg-1", "msg-2", "fc-1", "msg-3"}
|
if input[0].Get("id").String() != "msg-3" {
|
||||||
for i, want := range wantIDs {
|
t.Fatalf("input[0].id = %q, want %q", input[0].Get("id").String(), "msg-3")
|
||||||
got := input[i].Get("id").String()
|
|
||||||
if got != want {
|
|
||||||
t.Fatalf("input[%d].id = %q, want %q", i, got, want)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user