test(websocket): add tests for incremental input and prewarm handling logic - Added test cases for incremental input support based on upstream capabilities. - Introduced validation for prewarm handling of `response.create` messages locally. - Enhanced test coverage for websocket executor behavior, including payload forwarding checks. - Updated websocket implementation with prewarm and incremental input logic for better testability.
887 lines
28 KiB
Go
887 lines
28 KiB
Go
package openai
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// Constants shared by the /v1/responses websocket handler.
const (
	// Request types accepted from the downstream client (the "type" field).
	wsRequestTypeCreate = "response.create"
	wsRequestTypeAppend = "response.append"

	// Event types written to or recognized on the downstream stream.
	wsEventTypeError     = "error"
	wsEventTypeCompleted = "response.completed"

	// wsDoneMarker is the stream terminator that is dropped instead of being
	// forwarded as a websocket message.
	wsDoneMarker = "[DONE]"

	// wsTurnStateHeader is echoed back on the upgrade response when the client
	// sends it.
	wsTurnStateHeader = "x-codex-turn-state"

	// wsRequestBodyKey is the gin context key that receives the accumulated
	// websocket event log when the session ends.
	wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"

	// wsPayloadLogMaxSize caps the number of payload bytes shown in log previews.
	wsPayloadLogMaxSize = 2048
)
|
|
|
|
var responsesWebsocketUpgrader = websocket.Upgrader{
|
|
ReadBufferSize: 4096,
|
|
WriteBufferSize: 4096,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true
|
|
},
|
|
}
|
|
|
|
// ResponsesWebsocket handles websocket requests for /v1/responses.
|
|
// It accepts `response.create` and `response.append` requests and streams
|
|
// response events back as JSON websocket text messages.
|
|
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|
conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request))
|
|
if err != nil {
|
|
return
|
|
}
|
|
passthroughSessionID := uuid.NewString()
|
|
clientRemoteAddr := ""
|
|
if c != nil && c.Request != nil {
|
|
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
|
}
|
|
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
|
|
var wsTerminateErr error
|
|
var wsBodyLog strings.Builder
|
|
defer func() {
|
|
if wsTerminateErr != nil {
|
|
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
|
|
} else {
|
|
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
|
|
}
|
|
if h != nil && h.AuthManager != nil {
|
|
h.AuthManager.CloseExecutionSession(passthroughSessionID)
|
|
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
|
|
}
|
|
setWebsocketRequestBody(c, wsBodyLog.String())
|
|
if errClose := conn.Close(); errClose != nil {
|
|
log.Warnf("responses websocket: close connection error: %v", errClose)
|
|
}
|
|
}()
|
|
|
|
var lastRequest []byte
|
|
lastResponseOutput := []byte("[]")
|
|
pinnedAuthID := ""
|
|
|
|
for {
|
|
msgType, payload, errReadMessage := conn.ReadMessage()
|
|
if errReadMessage != nil {
|
|
wsTerminateErr = errReadMessage
|
|
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
|
|
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
|
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
|
|
} else {
|
|
// log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage)
|
|
}
|
|
return
|
|
}
|
|
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
|
continue
|
|
}
|
|
// log.Infof(
|
|
// "responses websocket: downstream_in id=%s type=%d event=%s payload=%s",
|
|
// passthroughSessionID,
|
|
// msgType,
|
|
// websocketPayloadEventType(payload),
|
|
// websocketPayloadPreview(payload),
|
|
// )
|
|
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
|
|
|
allowIncrementalInputWithPreviousResponseID := false
|
|
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
|
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
|
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
|
}
|
|
} else {
|
|
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
|
if requestModelName == "" {
|
|
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
|
}
|
|
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
|
}
|
|
|
|
var requestJSON []byte
|
|
var updatedLastRequest []byte
|
|
var errMsg *interfaces.ErrorMessage
|
|
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
|
|
payload,
|
|
lastRequest,
|
|
lastResponseOutput,
|
|
allowIncrementalInputWithPreviousResponseID,
|
|
)
|
|
if errMsg != nil {
|
|
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
|
markAPIResponseTimestamp(c)
|
|
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
|
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
|
|
log.Infof(
|
|
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
|
passthroughSessionID,
|
|
websocket.TextMessage,
|
|
websocketPayloadEventType(errorPayload),
|
|
websocketPayloadPreview(errorPayload),
|
|
)
|
|
if errWrite != nil {
|
|
log.Warnf(
|
|
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
|
passthroughSessionID,
|
|
websocketPayloadEventType(errorPayload),
|
|
errWrite,
|
|
)
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
|
|
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
|
|
requestJSON = updated
|
|
}
|
|
if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil {
|
|
updatedLastRequest = updated
|
|
}
|
|
lastRequest = updatedLastRequest
|
|
lastResponseOutput = []byte("[]")
|
|
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
|
|
wsTerminateErr = errWrite
|
|
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
lastRequest = updatedLastRequest
|
|
|
|
modelName := gjson.GetBytes(requestJSON, "model").String()
|
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
|
cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx)
|
|
cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID)
|
|
if pinnedAuthID != "" {
|
|
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
|
|
} else {
|
|
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
|
|
pinnedAuthID = strings.TrimSpace(authID)
|
|
})
|
|
}
|
|
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
|
|
|
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
|
if errForward != nil {
|
|
wsTerminateErr = errForward
|
|
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
|
|
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
|
|
return
|
|
}
|
|
lastResponseOutput = completedOutput
|
|
}
|
|
}
|
|
|
|
func websocketUpgradeHeaders(req *http.Request) http.Header {
|
|
headers := http.Header{}
|
|
if req == nil {
|
|
return headers
|
|
}
|
|
|
|
// Keep the same sticky turn-state across reconnects when provided by the client.
|
|
turnState := strings.TrimSpace(req.Header.Get(wsTurnStateHeader))
|
|
if turnState != "" {
|
|
headers.Set(wsTurnStateHeader, turnState)
|
|
}
|
|
return headers
|
|
}
|
|
|
|
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
|
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
|
|
}
|
|
|
|
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
|
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
|
switch requestType {
|
|
case wsRequestTypeCreate:
|
|
// log.Infof("responses websocket: response.create request")
|
|
if len(lastRequest) == 0 {
|
|
return normalizeResponseCreateRequest(rawJSON)
|
|
}
|
|
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
|
case wsRequestTypeAppend:
|
|
// log.Infof("responses websocket: response.append request")
|
|
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
|
default:
|
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
|
StatusCode: http.StatusBadRequest,
|
|
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
|
|
}
|
|
}
|
|
}
|
|
|
|
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
|
if errDelete != nil {
|
|
normalized = bytes.Clone(rawJSON)
|
|
}
|
|
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
|
if !gjson.GetBytes(normalized, "input").Exists() {
|
|
normalized, _ = sjson.SetRawBytes(normalized, "input", []byte("[]"))
|
|
}
|
|
|
|
modelName := strings.TrimSpace(gjson.GetBytes(normalized, "model").String())
|
|
if modelName == "" {
|
|
return nil, nil, &interfaces.ErrorMessage{
|
|
StatusCode: http.StatusBadRequest,
|
|
Error: fmt.Errorf("missing model in response.create request"),
|
|
}
|
|
}
|
|
return normalized, bytes.Clone(normalized), nil
|
|
}
|
|
|
|
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
|
if len(lastRequest) == 0 {
|
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
|
StatusCode: http.StatusBadRequest,
|
|
Error: fmt.Errorf("websocket request received before response.create"),
|
|
}
|
|
}
|
|
|
|
nextInput := gjson.GetBytes(rawJSON, "input")
|
|
if !nextInput.Exists() || !nextInput.IsArray() {
|
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
|
StatusCode: http.StatusBadRequest,
|
|
Error: fmt.Errorf("websocket request requires array field: input"),
|
|
}
|
|
}
|
|
|
|
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
|
|
// Do not expand it into a full input transcript; upstream expects the incremental payload.
|
|
if allowIncrementalInputWithPreviousResponseID {
|
|
if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" {
|
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
|
if errDelete != nil {
|
|
normalized = bytes.Clone(rawJSON)
|
|
}
|
|
if !gjson.GetBytes(normalized, "model").Exists() {
|
|
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
|
if modelName != "" {
|
|
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
|
}
|
|
}
|
|
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
|
instructions := gjson.GetBytes(lastRequest, "instructions")
|
|
if instructions.Exists() {
|
|
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
|
}
|
|
}
|
|
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
|
return normalized, bytes.Clone(normalized), nil
|
|
}
|
|
}
|
|
|
|
existingInput := gjson.GetBytes(lastRequest, "input")
|
|
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
|
|
if errMerge != nil {
|
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
|
StatusCode: http.StatusBadRequest,
|
|
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
|
|
}
|
|
}
|
|
|
|
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
|
|
if errMerge != nil {
|
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
|
StatusCode: http.StatusBadRequest,
|
|
Error: fmt.Errorf("invalid request input: %w", errMerge),
|
|
}
|
|
}
|
|
|
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
|
if errDelete != nil {
|
|
normalized = bytes.Clone(rawJSON)
|
|
}
|
|
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
|
|
var errSet error
|
|
normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput))
|
|
if errSet != nil {
|
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
|
StatusCode: http.StatusBadRequest,
|
|
Error: fmt.Errorf("failed to merge websocket input: %w", errSet),
|
|
}
|
|
}
|
|
if !gjson.GetBytes(normalized, "model").Exists() {
|
|
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
|
if modelName != "" {
|
|
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
|
}
|
|
}
|
|
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
|
instructions := gjson.GetBytes(lastRequest, "instructions")
|
|
if instructions.Exists() {
|
|
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
|
}
|
|
}
|
|
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
|
return normalized, bytes.Clone(normalized), nil
|
|
}
|
|
|
|
// websocketUpstreamSupportsIncrementalInput reports whether an auth entry is
// flagged with "websockets".
//
// The "websockets" attribute wins when it is non-blank and parses as a boolean.
// Otherwise the "websockets" metadata value is consulted: a bool is used
// directly and a string is parsed as a boolean. Anything else, including a
// missing or unparsable value, means false.
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
	if flag := strings.TrimSpace(attributes["websockets"]); flag != "" {
		if enabled, err := strconv.ParseBool(flag); err == nil {
			return enabled
		}
	}

	switch flag := metadata["websockets"].(type) {
	case bool:
		return flag
	case string:
		enabled, err := strconv.ParseBool(strings.TrimSpace(flag))
		return err == nil && enabled
	}
	return false
}
|
|
|
|
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
|
|
if h == nil || h.AuthManager == nil {
|
|
return false
|
|
}
|
|
|
|
resolvedModelName := modelName
|
|
initialSuffix := thinking.ParseSuffix(modelName)
|
|
if initialSuffix.ModelName == "auto" {
|
|
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
|
|
if initialSuffix.HasSuffix {
|
|
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
|
|
} else {
|
|
resolvedModelName = resolvedBase
|
|
}
|
|
} else {
|
|
resolvedModelName = util.ResolveAutoModel(modelName)
|
|
}
|
|
|
|
parsed := thinking.ParseSuffix(resolvedModelName)
|
|
baseModel := strings.TrimSpace(parsed.ModelName)
|
|
providers := util.GetProviderName(baseModel)
|
|
if len(providers) == 0 && baseModel != resolvedModelName {
|
|
providers = util.GetProviderName(resolvedModelName)
|
|
}
|
|
if len(providers) == 0 {
|
|
return false
|
|
}
|
|
|
|
providerSet := make(map[string]struct{}, len(providers))
|
|
for i := 0; i < len(providers); i++ {
|
|
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
|
|
if providerKey == "" {
|
|
continue
|
|
}
|
|
providerSet[providerKey] = struct{}{}
|
|
}
|
|
if len(providerSet) == 0 {
|
|
return false
|
|
}
|
|
|
|
modelKey := baseModel
|
|
if modelKey == "" {
|
|
modelKey = strings.TrimSpace(resolvedModelName)
|
|
}
|
|
registryRef := registry.GetGlobalRegistry()
|
|
now := time.Now()
|
|
auths := h.AuthManager.List()
|
|
for i := 0; i < len(auths); i++ {
|
|
auth := auths[i]
|
|
if auth == nil {
|
|
continue
|
|
}
|
|
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
|
if _, ok := providerSet[providerKey]; !ok {
|
|
continue
|
|
}
|
|
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
|
|
continue
|
|
}
|
|
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
|
|
continue
|
|
}
|
|
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
|
|
if auth == nil {
|
|
return false
|
|
}
|
|
if auth.Disabled || auth.Status == coreauth.StatusDisabled {
|
|
return false
|
|
}
|
|
if modelName != "" && len(auth.ModelStates) > 0 {
|
|
state, ok := auth.ModelStates[modelName]
|
|
if (!ok || state == nil) && modelName != "" {
|
|
baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName)
|
|
if baseModel != "" && baseModel != modelName {
|
|
state, ok = auth.ModelStates[baseModel]
|
|
}
|
|
}
|
|
if ok && state != nil {
|
|
if state.Status == coreauth.StatusDisabled {
|
|
return false
|
|
}
|
|
if state.Unavailable && !state.NextRetryAfter.IsZero() && state.NextRetryAfter.After(now) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
}
|
|
if auth.Unavailable && !auth.NextRetryAfter.IsZero() && auth.NextRetryAfter.After(now) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool {
|
|
if allowIncrementalInputWithPreviousResponseID || len(lastRequest) != 0 {
|
|
return false
|
|
}
|
|
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
|
|
return false
|
|
}
|
|
generateResult := gjson.GetBytes(rawJSON, "generate")
|
|
return generateResult.Exists() && !generateResult.Bool()
|
|
}
|
|
|
|
func writeResponsesWebsocketSyntheticPrewarm(
|
|
c *gin.Context,
|
|
conn *websocket.Conn,
|
|
requestJSON []byte,
|
|
wsBodyLog *strings.Builder,
|
|
sessionID string,
|
|
) error {
|
|
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
|
|
if errPayloads != nil {
|
|
return errPayloads
|
|
}
|
|
for i := 0; i < len(payloads); i++ {
|
|
markAPIResponseTimestamp(c)
|
|
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
|
// log.Infof(
|
|
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
|
// sessionID,
|
|
// websocket.TextMessage,
|
|
// websocketPayloadEventType(payloads[i]),
|
|
// websocketPayloadPreview(payloads[i]),
|
|
// )
|
|
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
|
log.Warnf(
|
|
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
|
sessionID,
|
|
websocketPayloadEventType(payloads[i]),
|
|
errWrite,
|
|
)
|
|
return errWrite
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) {
|
|
responseID := "resp_prewarm_" + uuid.NewString()
|
|
createdAt := time.Now().Unix()
|
|
modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String())
|
|
|
|
createdPayload := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
|
var errSet error
|
|
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.id", responseID)
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.created_at", createdAt)
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
if modelName != "" {
|
|
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.model", modelName)
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
}
|
|
|
|
completedPayload := []byte(`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
|
|
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.id", responseID)
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.created_at", createdAt)
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
if modelName != "" {
|
|
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.model", modelName)
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
}
|
|
|
|
return [][]byte{createdPayload, completedPayload}, nil
|
|
}
|
|
|
|
// mergeJSONArrayRaw concatenates two raw JSON arrays and returns the combined
// array as compact JSON.
//
// Blank operands and JSON `null` are treated as empty arrays, so the result is
// always a JSON array ("[]" when there is nothing to merge) and never "null".
// An operand that is not a JSON array (or null) yields the decoding error.
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
	// decode parses one operand into its raw elements; blank input and JSON
	// null both produce no elements.
	decode := func(raw string) ([]json.RawMessage, error) {
		raw = strings.TrimSpace(raw)
		if raw == "" {
			return nil, nil
		}
		var items []json.RawMessage
		if err := json.Unmarshal([]byte(raw), &items); err != nil {
			return nil, err
		}
		return items, nil
	}

	existing, err := decode(existingRaw)
	if err != nil {
		return "", err
	}
	appendItems, err := decode(appendRaw)
	if err != nil {
		return "", err
	}

	// Start from a non-nil slice so an empty merge encodes as [] rather than
	// null, and so the operands' backing arrays are never written to.
	merged := make([]json.RawMessage, 0, len(existing)+len(appendItems))
	merged = append(merged, existing...)
	merged = append(merged, appendItems...)

	out, err := json.Marshal(merged)
	if err != nil {
		return "", err
	}
	return string(out), nil
}
|
|
|
|
func normalizeJSONArrayRaw(raw []byte) string {
|
|
trimmed := strings.TrimSpace(string(raw))
|
|
if trimmed == "" {
|
|
return "[]"
|
|
}
|
|
result := gjson.Parse(trimmed)
|
|
if result.Type == gjson.JSON && result.IsArray() {
|
|
return trimmed
|
|
}
|
|
return "[]"
|
|
}
|
|
|
|
func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
|
c *gin.Context,
|
|
conn *websocket.Conn,
|
|
cancel handlers.APIHandlerCancelFunc,
|
|
data <-chan []byte,
|
|
errs <-chan *interfaces.ErrorMessage,
|
|
wsBodyLog *strings.Builder,
|
|
sessionID string,
|
|
) ([]byte, error) {
|
|
completed := false
|
|
completedOutput := []byte("[]")
|
|
|
|
for {
|
|
select {
|
|
case <-c.Request.Context().Done():
|
|
cancel(c.Request.Context().Err())
|
|
return completedOutput, c.Request.Context().Err()
|
|
case errMsg, ok := <-errs:
|
|
if !ok {
|
|
errs = nil
|
|
continue
|
|
}
|
|
if errMsg != nil {
|
|
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
|
markAPIResponseTimestamp(c)
|
|
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
|
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
|
log.Infof(
|
|
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
|
sessionID,
|
|
websocket.TextMessage,
|
|
websocketPayloadEventType(errorPayload),
|
|
websocketPayloadPreview(errorPayload),
|
|
)
|
|
if errWrite != nil {
|
|
// log.Warnf(
|
|
// "responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
|
// sessionID,
|
|
// websocketPayloadEventType(errorPayload),
|
|
// errWrite,
|
|
// )
|
|
cancel(errMsg.Error)
|
|
return completedOutput, errWrite
|
|
}
|
|
}
|
|
if errMsg != nil {
|
|
cancel(errMsg.Error)
|
|
} else {
|
|
cancel(nil)
|
|
}
|
|
return completedOutput, nil
|
|
case chunk, ok := <-data:
|
|
if !ok {
|
|
if !completed {
|
|
errMsg := &interfaces.ErrorMessage{
|
|
StatusCode: http.StatusRequestTimeout,
|
|
Error: fmt.Errorf("stream closed before response.completed"),
|
|
}
|
|
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
|
markAPIResponseTimestamp(c)
|
|
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
|
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
|
log.Infof(
|
|
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
|
sessionID,
|
|
websocket.TextMessage,
|
|
websocketPayloadEventType(errorPayload),
|
|
websocketPayloadPreview(errorPayload),
|
|
)
|
|
if errWrite != nil {
|
|
log.Warnf(
|
|
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
|
sessionID,
|
|
websocketPayloadEventType(errorPayload),
|
|
errWrite,
|
|
)
|
|
cancel(errMsg.Error)
|
|
return completedOutput, errWrite
|
|
}
|
|
cancel(errMsg.Error)
|
|
return completedOutput, nil
|
|
}
|
|
cancel(nil)
|
|
return completedOutput, nil
|
|
}
|
|
|
|
payloads := websocketJSONPayloadsFromChunk(chunk)
|
|
for i := range payloads {
|
|
eventType := gjson.GetBytes(payloads[i], "type").String()
|
|
if eventType == wsEventTypeCompleted {
|
|
completed = true
|
|
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
|
}
|
|
markAPIResponseTimestamp(c)
|
|
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
|
// log.Infof(
|
|
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
|
// sessionID,
|
|
// websocket.TextMessage,
|
|
// websocketPayloadEventType(payloads[i]),
|
|
// websocketPayloadPreview(payloads[i]),
|
|
// )
|
|
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
|
log.Warnf(
|
|
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
|
sessionID,
|
|
websocketPayloadEventType(payloads[i]),
|
|
errWrite,
|
|
)
|
|
cancel(errWrite)
|
|
return completedOutput, errWrite
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func responseCompletedOutputFromPayload(payload []byte) []byte {
|
|
output := gjson.GetBytes(payload, "response.output")
|
|
if output.Exists() && output.IsArray() {
|
|
return bytes.Clone([]byte(output.Raw))
|
|
}
|
|
return []byte("[]")
|
|
}
|
|
|
|
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
|
|
payloads := make([][]byte, 0, 2)
|
|
lines := bytes.Split(chunk, []byte("\n"))
|
|
for i := range lines {
|
|
line := bytes.TrimSpace(lines[i])
|
|
if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) {
|
|
continue
|
|
}
|
|
if bytes.HasPrefix(line, []byte("data:")) {
|
|
line = bytes.TrimSpace(line[len("data:"):])
|
|
}
|
|
if len(line) == 0 || bytes.Equal(line, []byte(wsDoneMarker)) {
|
|
continue
|
|
}
|
|
if json.Valid(line) {
|
|
payloads = append(payloads, bytes.Clone(line))
|
|
}
|
|
}
|
|
|
|
if len(payloads) > 0 {
|
|
return payloads
|
|
}
|
|
|
|
trimmed := bytes.TrimSpace(chunk)
|
|
if bytes.HasPrefix(trimmed, []byte("data:")) {
|
|
trimmed = bytes.TrimSpace(trimmed[len("data:"):])
|
|
}
|
|
if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte(wsDoneMarker)) && json.Valid(trimmed) {
|
|
payloads = append(payloads, bytes.Clone(trimmed))
|
|
}
|
|
return payloads
|
|
}
|
|
|
|
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
|
|
status := http.StatusInternalServerError
|
|
errText := http.StatusText(status)
|
|
if errMsg != nil {
|
|
if errMsg.StatusCode > 0 {
|
|
status = errMsg.StatusCode
|
|
errText = http.StatusText(status)
|
|
}
|
|
if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
|
|
errText = errMsg.Error.Error()
|
|
}
|
|
}
|
|
|
|
body := handlers.BuildErrorResponseBody(status, errText)
|
|
payload := []byte(`{}`)
|
|
var errSet error
|
|
payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError)
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
payload, errSet = sjson.SetBytes(payload, "status", status)
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
|
|
if errMsg != nil && errMsg.Addon != nil {
|
|
headers := []byte(`{}`)
|
|
hasHeaders := false
|
|
for key, values := range errMsg.Addon {
|
|
if len(values) == 0 {
|
|
continue
|
|
}
|
|
headerPath := strings.ReplaceAll(strings.ReplaceAll(key, `\\`, `\\\\`), ".", `\\.`)
|
|
headers, errSet = sjson.SetBytes(headers, headerPath, values[0])
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
hasHeaders = true
|
|
}
|
|
if hasHeaders {
|
|
payload, errSet = sjson.SetRawBytes(payload, "headers", headers)
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(body) > 0 && json.Valid(body) {
|
|
errorNode := gjson.GetBytes(body, "error")
|
|
if errorNode.Exists() {
|
|
payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw))
|
|
} else {
|
|
payload, errSet = sjson.SetRawBytes(payload, "error", body)
|
|
}
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
}
|
|
|
|
if !gjson.GetBytes(payload, "error").Exists() {
|
|
payload, errSet = sjson.SetBytes(payload, "error.type", "server_error")
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
payload, errSet = sjson.SetBytes(payload, "error.message", errText)
|
|
if errSet != nil {
|
|
return nil, errSet
|
|
}
|
|
}
|
|
|
|
return payload, conn.WriteMessage(websocket.TextMessage, payload)
|
|
}
|
|
|
|
// appendWebsocketEvent appends one entry to the session event log held in
// builder. An entry is a "websocket.<eventType>" line followed by the trimmed
// payload on its own line; entries are separated by a blank line. Nothing is
// written when builder is nil or the payload is blank.
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
	if builder == nil {
		return
	}
	body := bytes.TrimSpace(payload)
	if len(body) == 0 {
		return
	}

	separator := ""
	if builder.Len() > 0 {
		separator = "\n"
	}
	builder.WriteString(separator + "websocket." + eventType + "\n")
	builder.Write(body)
	builder.WriteByte('\n')
}
|
|
|
|
func websocketPayloadEventType(payload []byte) string {
|
|
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
|
if eventType == "" {
|
|
return "-"
|
|
}
|
|
return eventType
|
|
}
|
|
|
|
func websocketPayloadPreview(payload []byte) string {
|
|
trimmedPayload := bytes.TrimSpace(payload)
|
|
if len(trimmedPayload) == 0 {
|
|
return "<empty>"
|
|
}
|
|
preview := trimmedPayload
|
|
if len(preview) > wsPayloadLogMaxSize {
|
|
preview = preview[:wsPayloadLogMaxSize]
|
|
}
|
|
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
|
|
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
|
|
if len(trimmedPayload) > wsPayloadLogMaxSize {
|
|
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
|
|
}
|
|
return previewText
|
|
}
|
|
|
|
func setWebsocketRequestBody(c *gin.Context, body string) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
trimmedBody := strings.TrimSpace(body)
|
|
if trimmedBody == "" {
|
|
return
|
|
}
|
|
c.Set(wsRequestBodyKey, []byte(trimmedBody))
|
|
}
|
|
|
|
func markAPIResponseTimestamp(c *gin.Context) {
|
|
if c == nil {
|
|
return
|
|
}
|
|
if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); exists {
|
|
return
|
|
}
|
|
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
|
}
|