From c19ae1d5be32537218b61e26ebe7845720966801 Mon Sep 17 00:00:00 2001 From: Kenny Date: Sun, 3 May 2026 15:56:39 -0700 Subject: [PATCH] Align Codex websocket protocol semantics --- internal/runtime/executor/codex_executor.go | 2 +- .../executor/codex_websockets_executor.go | 153 +++++++++++--- .../codex_websockets_executor_test.go | 189 +++++++++++++++++- 3 files changed, 310 insertions(+), 34 deletions(-) diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index aa8223f4..5e892ecd 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -31,7 +31,7 @@ import ( const ( codexUserAgent = "codex-tui/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9 (codex-tui; 0.118.0)" - codexOriginator = "codex-tui" + codexOriginator = "codex_cli_rs" codexDefaultImageToolModel = "gpt-image-2" ) diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 40ba7e92..87ae0efe 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -188,7 +188,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath) body, _ = sjson.SetBytes(body, "model", baseModel) body, _ = sjson.SetBytes(body, "stream", true) - body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "prompt_cache_retention") body, _ = sjson.DeleteBytes(body, "safety_identifier") body = normalizeCodexInstructions(body) @@ -776,6 +775,11 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) { parsed.Scheme = "ws" case "https": parsed.Scheme = "wss" + default: + return "", fmt.Errorf("codex websockets executor: unsupported responses websocket URL scheme %q", parsed.Scheme) + } + if strings.TrimSpace(parsed.Host) == "" { + return "", fmt.Errorf("codex websockets executor: responses websocket URL host is empty") } return parsed.String(), nil } @@ -809,6 +813,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto if cache.ID != "" { rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID) + setHeaderCasePreserved(headers, "session_id", cache.ID) headers.Set("Conversation_id", cache.ID) } @@ -828,13 +833,19 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth * ginHeaders = ginCtx.Request.Header.Clone() } - _, cfgBetaFeatures := codexHeaderDefaults(cfg, auth) + isAPIKey := codexAuthUsesAPIKey(auth) + cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth) ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "") misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "") misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "") misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "") misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "") misc.EnsureHeader(headers, ginHeaders, "Version", "") + if isAPIKey { + ensureHeaderWithPriority(headers, ginHeaders, "User-Agent", "", "") + } else { + ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent) + } betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta")) if betaHeader == "" && ginHeaders != nil { @@ -845,16 +856,9 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth * } headers.Set("OpenAI-Beta", betaHeader) if strings.Contains(headers.Get("User-Agent"), "Mac OS") { - misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString()) - } - headers.Del("User-Agent") - - isAPIKey := false - if auth != nil && auth.Attributes != nil { - if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" { - isAPIKey = true - } + ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", uuid.NewString()) } + ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", "") if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" { headers.Set("Originator", originator) } else if !isAPIKey { @@ -864,7 +868,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth * if auth != nil && auth.Metadata != nil { if accountID, ok := auth.Metadata["account_id"].(string); ok { if trimmed := strings.TrimSpace(accountID); trimmed != "" { - headers.Set("Chatgpt-Account-Id", trimmed) + headers.Set("ChatGPT-Account-ID", trimmed) } } } @@ -879,6 +883,77 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth * return headers } +func codexAuthUsesAPIKey(auth *cliproxyauth.Auth) bool { + if auth == nil || auth.Attributes == nil { + return false + } + return strings.TrimSpace(auth.Attributes["api_key"]) != "" +} + +func ensureHeaderCasePreserved(target http.Header, source http.Header, key, configValue, fallbackValue string) { + if target == nil { + return + } + if strings.TrimSpace(headerValueCaseInsensitive(target, key)) != "" { + return + } + if source != nil { + if val := strings.TrimSpace(headerValueCaseInsensitive(source, key)); val != "" { + setHeaderCasePreserved(target, key, val) + return + } + } + if val := strings.TrimSpace(configValue); val != "" { + setHeaderCasePreserved(target, key, val) + return + } + if val := strings.TrimSpace(fallbackValue); val != "" { + setHeaderCasePreserved(target, key, val) + } +} + +func setHeaderCasePreserved(headers http.Header, key string, value string) { + if headers == nil { + return + } + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) + if key == "" || value == "" { + return + } + deleteHeaderCaseInsensitive(headers, key) + headers[key] = []string{value} +} + +func headerValueCaseInsensitive(headers http.Header, key string) string { + key = strings.TrimSpace(key) + if headers == nil || key == "" { + return "" + } + if val := strings.TrimSpace(headers.Get(key)); val != "" { + return val + } + for existingKey, values := range headers { + if !strings.EqualFold(existingKey, key) { + continue + } + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + } + return "" +} + +func deleteHeaderCaseInsensitive(headers http.Header, key string) { + for existingKey := range headers { + if strings.EqualFold(existingKey, key) { + delete(headers, existingKey) + } + } +} + func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) { if cfg == nil || auth == nil { return "", "" @@ -962,25 +1037,53 @@ func parseCodexWebsocketError(payload []byte) (error, bool) { return nil, false } - out := []byte(`{}`) - if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() { - raw := errNode.Raw - if errNode.Type == gjson.String { - raw = errNode.Raw - } - out, _ = sjson.SetRawBytes(out, "error", []byte(raw)) - } else { - out, _ = sjson.SetBytes(out, "error.type", "server_error") - out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status)) - } - + out := buildCodexWebsocketErrorPayload(payload, status) headers := parseCodexWebsocketErrorHeaders(payload) + statusError := statusErr{code: status, msg: string(out)} + if isCodexWebsocketConnectionLimitError(payload) { + retryAfter := time.Duration(0) + statusError.retryAfter = &retryAfter + } return statusErrWithHeaders{ - statusErr: statusErr{code: status, msg: string(out)}, + statusErr: statusError, headers: headers, }, true } +func buildCodexWebsocketErrorPayload(payload []byte, status int) []byte { + out := []byte(`{}`) + out, _ = sjson.SetBytes(out, "status", status) + + if bodyNode := gjson.GetBytes(payload, "body"); bodyNode.Exists() { + out, _ = sjson.SetRawBytes(out, "body", []byte(bodyNode.Raw)) + if bodyErrorNode := bodyNode.Get("error"); bodyErrorNode.Exists() { + out, _ = sjson.SetRawBytes(out, "error", []byte(bodyErrorNode.Raw)) + return out + } + } + + if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() { + out, _ = sjson.SetRawBytes(out, "error", []byte(errNode.Raw)) + return out + } + + out, _ = sjson.SetBytes(out, "error.type", "server_error") + out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status)) + return out +} + +func isCodexWebsocketConnectionLimitError(payload []byte) bool { + if len(payload) == 0 { + return false + } + for _, path := range []string{"error.code", "error.type", "body.error.code", "body.error.type", "code", "error"} { + if strings.TrimSpace(gjson.GetBytes(payload, path).String()) == "websocket_connection_limit_reached" { + return true + } + } + return false +} + func parseCodexWebsocketErrorHeaders(payload []byte) http.Header { headersNode := gjson.GetBytes(payload, "headers") if !headersNode.Exists() || !headersNode.IsObject() { diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index dec356de..0b7a546e 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -1,15 +1,20 @@ package executor import ( + "bytes" "context" "net/http" "net/http/httptest" "testing" + "time" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" "github.com/tidwall/gjson" ) @@ -32,14 +37,71 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) } } +func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T) { + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + capturedPayload := make(chan []byte, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + t.Fatalf("request path = %s, want /responses", r.URL.Path) + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Fatalf("upgrade websocket: %v", err) + } + defer func() { _ = conn.Close() }() + + msgType, payload, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read upstream websocket message: %v", err) + } + if msgType != websocket.TextMessage { + t.Fatalf("message type = %d, want text", msgType) + } + capturedPayload <- bytes.Clone(payload) + + completed := []byte(`{"type":"response.completed","response":{"id":"resp-2","output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`) + if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil { + t.Fatalf("write completed websocket message: %v", errWrite) + } + })) + defer server.Close() + + exec := NewCodexWebsocketsExecutor(&config.Config{SDKConfig: config.SDKConfig{DisableImageGeneration: config.DisableImageGenerationAll}}) + auth := &cliproxyauth.Auth{Attributes: map[string]string{"api_key": "sk-test", "base_url": server.URL}} + req := cliproxyexecutor.Request{ + Model: "gpt-5-codex", + Payload: []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`), + } + opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("codex")} + + if _, err := exec.Execute(context.Background(), auth, req, opts); err != nil { + t.Fatalf("Execute() error = %v", err) + } + + select { + case payload := <-capturedPayload: + if got := gjson.GetBytes(payload, "type").String(); got != "response.create" { + t.Fatalf("upstream type = %s, want response.create; payload=%s", got, payload) + } + if got := gjson.GetBytes(payload, "previous_response_id").String(); got != "resp-1" { + t.Fatalf("upstream previous_response_id = %s, want resp-1; payload=%s", got, payload) + } + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for upstream websocket payload") + } +} + func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) { headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil) if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue { t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue) } - if got := headers.Get("User-Agent"); got != "" { - t.Fatalf("User-Agent = %s, want empty", got) + if got := headers.Get("User-Agent"); got != codexUserAgent { + t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent) + } + if got := headers.Get("Originator"); got != codexOriginator { + t.Fatalf("Originator = %s, want %s", got, codexOriginator) } if got := headers.Get("Version"); got != "" { t.Fatalf("Version = %q, want empty", got) @@ -62,9 +124,11 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing } ctx := contextWithGinHeaders(map[string]string{ "Originator": "Codex Desktop", + "User-Agent": "codex_cli_rs/0.1.0", "Version": "0.115.0-alpha.27", "X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`, "X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d", + "session_id": "sess-client", }) headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil) @@ -72,6 +136,9 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing if got := headers.Get("Originator"); got != "Codex Desktop" { t.Fatalf("Originator = %s, want %s", got, "Codex Desktop") } + if got := headers.Get("User-Agent"); got != "codex_cli_rs/0.1.0" { + t.Fatalf("User-Agent = %s, want %s", got, "codex_cli_rs/0.1.0") + } if got := headers.Get("Version"); got != "0.115.0-alpha.27" { t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27") } @@ -81,6 +148,12 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" { t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d") } + if got := headerValueCaseInsensitive(headers, "session_id"); got != "sess-client" { + t.Fatalf("session_id = %s, want sess-client", got) + } + if _, ok := headers["session_id"]; !ok { + t.Fatalf("expected lowercase session_id header key, got %#v", headers) + } } func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) { @@ -97,8 +170,8 @@ func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) { headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg) - if got := headers.Get("User-Agent"); got != "" { - t.Fatalf("User-Agent = %s, want empty", got) + if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" { + t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0") } if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" { t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b") @@ -129,8 +202,8 @@ func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t * got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg) - if gotVal := got.Get("User-Agent"); gotVal != "" { - t.Fatalf("User-Agent = %s, want empty", gotVal) + if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" { + t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua") } if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" { t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta") @@ -155,8 +228,8 @@ func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testi headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg) - if got := headers.Get("User-Agent"); got != "" { - t.Fatalf("User-Agent = %s, want empty", got) + if got := headers.Get("User-Agent"); got != "config-ua" { + t.Fatalf("User-Agent = %s, want %s", got, "config-ua") } if got := headers.Get("x-codex-beta-features"); got != "client-beta" { t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta") @@ -183,6 +256,106 @@ func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) { if got := headers.Get("x-codex-beta-features"); got != "" { t.Fatalf("x-codex-beta-features = %q, want empty", got) } + if got := headers.Get("Originator"); got != "" { + t.Fatalf("Originator = %s, want empty", got) + } +} + +func TestApplyCodexWebsocketHeadersPreservesExplicitAPIKeyUserAgent(t *testing.T) { + auth := &cliproxyauth.Auth{Provider: "codex", Attributes: map[string]string{"api_key": "sk-test"}} + ctx := contextWithGinHeaders(map[string]string{"User-Agent": "api-key-client/1.0", "Originator": "explicit-origin"}) + + headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "sk-test", nil) + + if got := headers.Get("User-Agent"); got != "api-key-client/1.0" { + t.Fatalf("User-Agent = %s, want api-key-client/1.0", got) + } + if got := headers.Get("Originator"); got != "explicit-origin" { + t.Fatalf("Originator = %s, want explicit-origin", got) + } +} + +func TestApplyCodexPromptCacheHeadersSetsLowercaseSessionAndLegacyConversation(t *testing.T) { + req := cliproxyexecutor.Request{Model: "gpt-5-codex", Payload: []byte(`{"prompt_cache_key":"cache-1"}`)} + + _, headers := applyCodexPromptCacheHeaders("openai-response", req, []byte(`{"model":"gpt-5-codex"}`)) + + if got := headerValueCaseInsensitive(headers, "session_id"); got != "cache-1" { + t.Fatalf("session_id = %s, want cache-1", got) + } + if _, ok := headers["session_id"]; !ok { + t.Fatalf("expected lowercase session_id key, got %#v", headers) + } + if got := headers.Get("Conversation_id"); got != "cache-1" { + t.Fatalf("Conversation_id = %s, want cache-1", got) + } +} + +func TestApplyCodexWebsocketHeadersUsesCanonicalAccountHeader(t *testing.T) { + auth := &cliproxyauth.Auth{Provider: "codex", Metadata: map[string]any{"account_id": "acct-1"}} + + headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", nil) + + if got := headers.Get("ChatGPT-Account-ID"); got != "acct-1" { + t.Fatalf("ChatGPT-Account-ID = %s, want acct-1", got) + } +} + +func TestBuildCodexResponsesWebsocketURLRequiresHTTPURL(t *testing.T) { + if got, err := buildCodexResponsesWebsocketURL("https://example.com/backend/responses"); err != nil || got != "wss://example.com/backend/responses" { + t.Fatalf("https URL = %q, %v; want wss URL", got, err) + } + if _, err := buildCodexResponsesWebsocketURL("ftp://example.com/responses"); err == nil { + t.Fatalf("expected unsupported scheme error") + } + if _, err := buildCodexResponsesWebsocketURL("https:///responses"); err == nil { + t.Fatalf("expected empty host error") + } +} + +func TestParseCodexWebsocketErrorMarksConnectionLimitRetryable(t *testing.T) { + err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"error":{"code":"websocket_connection_limit_reached","message":"too many websockets"},"headers":{"retry-after":"1"}}`)) + if !ok { + t.Fatalf("expected websocket error") + } + status, ok := err.(interface{ StatusCode() int }) + if !ok || status.StatusCode() != http.StatusTooManyRequests { + t.Fatalf("status = %#v, want 429", err) + } + retryable, ok := err.(interface{ RetryAfter() *time.Duration }) + if !ok || retryable.RetryAfter() == nil { + t.Fatalf("expected retryable websocket connection limit error") + } + withHeaders, ok := err.(interface{ Headers() http.Header }) + if !ok || withHeaders.Headers().Get("retry-after") != "1" { + t.Fatalf("headers = %#v, want retry-after", err) + } +} + +func TestParseCodexWebsocketErrorPreservesWrappedBodyAndHeaders(t *testing.T) { + err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"body":{"error":{"code":"websocket_connection_limit_reached","type":"server_error","message":"too many websocket connections"}},"headers":{"x-request-id":"req-1"}}`)) + if !ok { + t.Fatalf("expected websocket error") + } + + parsed := gjson.Parse(err.Error()) + if got := parsed.Get("status").Int(); got != http.StatusTooManyRequests { + t.Fatalf("wrapped status = %d, want 429; payload=%s", got, err.Error()) + } + if got := parsed.Get("body.error.code").String(); got != "websocket_connection_limit_reached" { + t.Fatalf("wrapped body error code = %s, want websocket_connection_limit_reached; payload=%s", got, err.Error()) + } + if got := parsed.Get("error.code").String(); got != "websocket_connection_limit_reached" { + t.Fatalf("surface error code = %s, want websocket_connection_limit_reached; payload=%s", got, err.Error()) + } + retryable, ok := err.(interface{ RetryAfter() *time.Duration }) + if !ok || retryable.RetryAfter() == nil { + t.Fatalf("expected body.error.code websocket connection limit to be retryable") + } + withHeaders, ok := err.(interface{ Headers() http.Header }) + if !ok || withHeaders.Headers().Get("x-request-id") != "req-1" { + t.Fatalf("headers = %#v, want x-request-id", err) + } } func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {