Fixed: #1901
test(websocket): add tests for incremental input and prewarm handling logic - Added test cases for incremental input support based on upstream capabilities. - Introduced validation for prewarm handling of `response.create` messages locally. - Enhanced test coverage for websocket executor behavior, including payload forwarding checks. - Updated websocket implementation with prewarm and incremental input logic for better testability.
This commit is contained in:
@@ -14,7 +14,11 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -100,11 +104,17 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
// )
|
// )
|
||||||
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||||
|
|
||||||
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
|
allowIncrementalInputWithPreviousResponseID := false
|
||||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||||
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
|
||||||
|
if requestModelName == "" {
|
||||||
|
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||||
|
}
|
||||||
|
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestJSON []byte
|
var requestJSON []byte
|
||||||
@@ -139,6 +149,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
|
||||||
|
if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
|
||||||
|
requestJSON = updated
|
||||||
|
}
|
||||||
|
if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil {
|
||||||
|
updatedLastRequest = updated
|
||||||
|
}
|
||||||
|
lastRequest = updatedLastRequest
|
||||||
|
lastResponseOutput = []byte("[]")
|
||||||
|
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
|
||||||
|
wsTerminateErr = errWrite
|
||||||
|
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
lastRequest = updatedLastRequest
|
lastRequest = updatedLastRequest
|
||||||
|
|
||||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||||
@@ -339,6 +365,192 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
|
||||||
|
if h == nil || h.AuthManager == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvedModelName := modelName
|
||||||
|
initialSuffix := thinking.ParseSuffix(modelName)
|
||||||
|
if initialSuffix.ModelName == "auto" {
|
||||||
|
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
|
||||||
|
if initialSuffix.HasSuffix {
|
||||||
|
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
|
||||||
|
} else {
|
||||||
|
resolvedModelName = resolvedBase
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
resolvedModelName = util.ResolveAutoModel(modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed := thinking.ParseSuffix(resolvedModelName)
|
||||||
|
baseModel := strings.TrimSpace(parsed.ModelName)
|
||||||
|
providers := util.GetProviderName(baseModel)
|
||||||
|
if len(providers) == 0 && baseModel != resolvedModelName {
|
||||||
|
providers = util.GetProviderName(resolvedModelName)
|
||||||
|
}
|
||||||
|
if len(providers) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
providerSet := make(map[string]struct{}, len(providers))
|
||||||
|
for i := 0; i < len(providers); i++ {
|
||||||
|
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
|
||||||
|
if providerKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
providerSet[providerKey] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(providerSet) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
modelKey := baseModel
|
||||||
|
if modelKey == "" {
|
||||||
|
modelKey = strings.TrimSpace(resolvedModelName)
|
||||||
|
}
|
||||||
|
registryRef := registry.GetGlobalRegistry()
|
||||||
|
now := time.Now()
|
||||||
|
auths := h.AuthManager.List()
|
||||||
|
for i := 0; i < len(auths); i++ {
|
||||||
|
auth := auths[i]
|
||||||
|
if auth == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
||||||
|
if _, ok := providerSet[providerKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
|
||||||
|
if auth == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if auth.Disabled || auth.Status == coreauth.StatusDisabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if modelName != "" && len(auth.ModelStates) > 0 {
|
||||||
|
state, ok := auth.ModelStates[modelName]
|
||||||
|
if (!ok || state == nil) && modelName != "" {
|
||||||
|
baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName)
|
||||||
|
if baseModel != "" && baseModel != modelName {
|
||||||
|
state, ok = auth.ModelStates[baseModel]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if ok && state != nil {
|
||||||
|
if state.Status == coreauth.StatusDisabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if state.Unavailable && !state.NextRetryAfter.IsZero() && state.NextRetryAfter.After(now) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if auth.Unavailable && !auth.NextRetryAfter.IsZero() && auth.NextRetryAfter.After(now) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool {
|
||||||
|
if allowIncrementalInputWithPreviousResponseID || len(lastRequest) != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
generateResult := gjson.GetBytes(rawJSON, "generate")
|
||||||
|
return generateResult.Exists() && !generateResult.Bool()
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeResponsesWebsocketSyntheticPrewarm(
|
||||||
|
c *gin.Context,
|
||||||
|
conn *websocket.Conn,
|
||||||
|
requestJSON []byte,
|
||||||
|
wsBodyLog *strings.Builder,
|
||||||
|
sessionID string,
|
||||||
|
) error {
|
||||||
|
payloads, errPayloads := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
|
||||||
|
if errPayloads != nil {
|
||||||
|
return errPayloads
|
||||||
|
}
|
||||||
|
for i := 0; i < len(payloads); i++ {
|
||||||
|
markAPIResponseTimestamp(c)
|
||||||
|
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||||
|
// log.Infof(
|
||||||
|
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||||
|
// sessionID,
|
||||||
|
// websocket.TextMessage,
|
||||||
|
// websocketPayloadEventType(payloads[i]),
|
||||||
|
// websocketPayloadPreview(payloads[i]),
|
||||||
|
// )
|
||||||
|
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||||
|
log.Warnf(
|
||||||
|
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||||
|
sessionID,
|
||||||
|
websocketPayloadEventType(payloads[i]),
|
||||||
|
errWrite,
|
||||||
|
)
|
||||||
|
return errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) {
|
||||||
|
responseID := "resp_prewarm_" + uuid.NewString()
|
||||||
|
createdAt := time.Now().Unix()
|
||||||
|
modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String())
|
||||||
|
|
||||||
|
createdPayload := []byte(`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`)
|
||||||
|
var errSet error
|
||||||
|
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.id", responseID)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.created_at", createdAt)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
if modelName != "" {
|
||||||
|
createdPayload, errSet = sjson.SetBytes(createdPayload, "response.model", modelName)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
completedPayload := []byte(`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
|
||||||
|
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.id", responseID)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.created_at", createdAt)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
if modelName != "" {
|
||||||
|
completedPayload, errSet = sjson.SetBytes(completedPayload, "response.model", modelName)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return [][]byte{createdPayload, completedPayload}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
||||||
existingRaw = strings.TrimSpace(existingRaw)
|
existingRaw = strings.TrimSpace(existingRaw)
|
||||||
appendRaw = strings.TrimSpace(appendRaw)
|
appendRaw = strings.TrimSpace(appendRaw)
|
||||||
@@ -550,47 +762,63 @@ func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.Error
|
|||||||
}
|
}
|
||||||
|
|
||||||
body := handlers.BuildErrorResponseBody(status, errText)
|
body := handlers.BuildErrorResponseBody(status, errText)
|
||||||
payload := map[string]any{
|
payload := []byte(`{}`)
|
||||||
"type": wsEventTypeError,
|
var errSet error
|
||||||
"status": status,
|
payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
payload, errSet = sjson.SetBytes(payload, "status", status)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
}
|
}
|
||||||
|
|
||||||
if errMsg != nil && errMsg.Addon != nil {
|
if errMsg != nil && errMsg.Addon != nil {
|
||||||
headers := map[string]any{}
|
headers := []byte(`{}`)
|
||||||
|
hasHeaders := false
|
||||||
for key, values := range errMsg.Addon {
|
for key, values := range errMsg.Addon {
|
||||||
if len(values) == 0 {
|
if len(values) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
headers[key] = values[0]
|
headerPath := strings.ReplaceAll(strings.ReplaceAll(key, `\\`, `\\\\`), ".", `\\.`)
|
||||||
|
headers, errSet = sjson.SetBytes(headers, headerPath, values[0])
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
hasHeaders = true
|
||||||
}
|
}
|
||||||
if len(headers) > 0 {
|
if hasHeaders {
|
||||||
payload["headers"] = headers
|
payload, errSet = sjson.SetRawBytes(payload, "headers", headers)
|
||||||
}
|
if errSet != nil {
|
||||||
}
|
return nil, errSet
|
||||||
|
|
||||||
if len(body) > 0 && json.Valid(body) {
|
|
||||||
var decoded map[string]any
|
|
||||||
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
|
|
||||||
if inner, ok := decoded["error"]; ok {
|
|
||||||
payload["error"] = inner
|
|
||||||
} else {
|
|
||||||
payload["error"] = decoded
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := payload["error"]; !ok {
|
if len(body) > 0 && json.Valid(body) {
|
||||||
payload["error"] = map[string]any{
|
errorNode := gjson.GetBytes(body, "error")
|
||||||
"type": "server_error",
|
if errorNode.Exists() {
|
||||||
"message": errText,
|
payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw))
|
||||||
|
} else {
|
||||||
|
payload, errSet = sjson.SetRawBytes(payload, "error", body)
|
||||||
|
}
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(payload)
|
if !gjson.GetBytes(payload, "error").Exists() {
|
||||||
if err != nil {
|
payload, errSet = sjson.SetBytes(payload, "error.type", "server_error")
|
||||||
return nil, err
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
|
payload, errSet = sjson.SetBytes(payload, "error.message", errText)
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, errSet
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return data, conn.WriteMessage(websocket.TextMessage, data)
|
|
||||||
|
return payload, conn.WriteMessage(websocket.TextMessage, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -11,9 +13,46 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type websocketCaptureExecutor struct {
|
||||||
|
streamCalls int
|
||||||
|
payloads [][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
|
e.streamCalls++
|
||||||
|
e.payloads = append(e.payloads, bytes.Clone(req.Payload))
|
||||||
|
chunks := make(chan coreexecutor.StreamChunk, 1)
|
||||||
|
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
|
||||||
|
close(chunks)
|
||||||
|
return &coreexecutor.StreamResult{Chunks: chunks}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
||||||
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||||
|
|
||||||
@@ -326,3 +365,130 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
|||||||
t.Fatalf("server error: %v", errServer)
|
t.Fatalf("server error: %v", errServer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
auth := &coreauth.Auth{
|
||||||
|
ID: "auth-ws",
|
||||||
|
Provider: "test-provider",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{"websockets": "true"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||||
|
t.Fatalf("Register auth: %v", err)
|
||||||
|
}
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
if !h.websocketUpstreamSupportsIncrementalInputForModel("test-model") {
|
||||||
|
t.Fatalf("expected websocket-capable upstream for test-model")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
executor := &websocketCaptureExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
|
||||||
|
if _, err := manager.Register(context.Background(), auth); err != nil {
|
||||||
|
t.Fatalf("Register auth: %v", err)
|
||||||
|
}
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
router := gin.New()
|
||||||
|
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
|
||||||
|
|
||||||
|
server := httptest.NewServer(router)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial websocket: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
errClose := conn.Close()
|
||||||
|
if errClose != nil {
|
||||||
|
t.Fatalf("close websocket: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","generate":false}`))
|
||||||
|
if errWrite != nil {
|
||||||
|
t.Fatalf("write prewarm websocket message: %v", errWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, createdPayload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
t.Fatalf("read prewarm created message: %v", errReadMessage)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(createdPayload, "type").String() != "response.created" {
|
||||||
|
t.Fatalf("created payload type = %s, want response.created", gjson.GetBytes(createdPayload, "type").String())
|
||||||
|
}
|
||||||
|
prewarmResponseID := gjson.GetBytes(createdPayload, "response.id").String()
|
||||||
|
if prewarmResponseID == "" {
|
||||||
|
t.Fatalf("prewarm response id is empty")
|
||||||
|
}
|
||||||
|
if executor.streamCalls != 0 {
|
||||||
|
t.Fatalf("stream calls after prewarm = %d, want 0", executor.streamCalls)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, completedPayload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
t.Fatalf("read prewarm completed message: %v", errReadMessage)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(completedPayload, "type").String() != wsEventTypeCompleted {
|
||||||
|
t.Fatalf("completed payload type = %s, want %s", gjson.GetBytes(completedPayload, "type").String(), wsEventTypeCompleted)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(completedPayload, "response.id").String() != prewarmResponseID {
|
||||||
|
t.Fatalf("completed response id = %s, want %s", gjson.GetBytes(completedPayload, "response.id").String(), prewarmResponseID)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int() != 0 {
|
||||||
|
t.Fatalf("prewarm total tokens = %d, want 0", gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-1"}]}`, prewarmResponseID)
|
||||||
|
errWrite = conn.WriteMessage(websocket.TextMessage, []byte(secondRequest))
|
||||||
|
if errWrite != nil {
|
||||||
|
t.Fatalf("write follow-up websocket message: %v", errWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, upstreamPayload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
t.Fatalf("read upstream completed message: %v", errReadMessage)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(upstreamPayload, "type").String() != wsEventTypeCompleted {
|
||||||
|
t.Fatalf("upstream payload type = %s, want %s", gjson.GetBytes(upstreamPayload, "type").String(), wsEventTypeCompleted)
|
||||||
|
}
|
||||||
|
if executor.streamCalls != 1 {
|
||||||
|
t.Fatalf("stream calls after follow-up = %d, want 1", executor.streamCalls)
|
||||||
|
}
|
||||||
|
if len(executor.payloads) != 1 {
|
||||||
|
t.Fatalf("captured upstream payloads = %d, want 1", len(executor.payloads))
|
||||||
|
}
|
||||||
|
forwarded := executor.payloads[0]
|
||||||
|
if gjson.GetBytes(forwarded, "previous_response_id").Exists() {
|
||||||
|
t.Fatalf("previous_response_id leaked upstream: %s", forwarded)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(forwarded, "generate").Exists() {
|
||||||
|
t.Fatalf("generate leaked upstream: %s", forwarded)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(forwarded, "model").String() != "test-model" {
|
||||||
|
t.Fatalf("forwarded model = %s, want test-model", gjson.GetBytes(forwarded, "model").String())
|
||||||
|
}
|
||||||
|
input := gjson.GetBytes(forwarded, "input").Array()
|
||||||
|
if len(input) != 1 || input[0].Get("id").String() != "msg-1" {
|
||||||
|
t.Fatalf("unexpected forwarded input: %s", forwarded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user