refactor(executor): remove legacy connCreateSent logic and standardize response.create usage for all websocket events
- Simplified connection logic by removing `connCreateSent` and related state handling. - Updated `buildCodexWebsocketRequestBody` to always use `response.create`. - Added unit tests to validate `response.create` behavior and beta header preservation. - Dropped unsupported `response.append` and outdated `response.done` event types.
This commit is contained in:
@@ -31,7 +31,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04"
|
codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
|
||||||
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
|
codexResponsesWebsocketIdleTimeout = 5 * time.Minute
|
||||||
codexResponsesWebsocketHandshakeTO = 30 * time.Second
|
codexResponsesWebsocketHandshakeTO = 30 * time.Second
|
||||||
)
|
)
|
||||||
@@ -57,11 +57,6 @@ type codexWebsocketSession struct {
|
|||||||
wsURL string
|
wsURL string
|
||||||
authID string
|
authID string
|
||||||
|
|
||||||
// connCreateSent tracks whether a `response.create` message has been successfully sent
|
|
||||||
// on the current websocket connection. The upstream expects the first message on each
|
|
||||||
// connection to be `response.create`.
|
|
||||||
connCreateSent bool
|
|
||||||
|
|
||||||
writeMu sync.Mutex
|
writeMu sync.Mutex
|
||||||
|
|
||||||
activeMu sync.Mutex
|
activeMu sync.Mutex
|
||||||
@@ -212,13 +207,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
defer sess.reqMu.Unlock()
|
defer sess.reqMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
allowAppend := true
|
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||||
if sess != nil {
|
|
||||||
sess.connMu.Lock()
|
|
||||||
allowAppend = sess.connCreateSent
|
|
||||||
sess.connMu.Unlock()
|
|
||||||
}
|
|
||||||
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
|
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
URL: wsURL,
|
URL: wsURL,
|
||||||
Method: "WEBSOCKET",
|
Method: "WEBSOCKET",
|
||||||
@@ -280,10 +269,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
// execution session.
|
// execution session.
|
||||||
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
|
||||||
if errDialRetry == nil && connRetry != nil {
|
if errDialRetry == nil && connRetry != nil {
|
||||||
sess.connMu.Lock()
|
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||||
allowAppend = sess.connCreateSent
|
|
||||||
sess.connMu.Unlock()
|
|
||||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
|
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
URL: wsURL,
|
URL: wsURL,
|
||||||
Method: "WEBSOCKET",
|
Method: "WEBSOCKET",
|
||||||
@@ -312,7 +298,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
return resp, errSend
|
return resp, errSend
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if ctx != nil && ctx.Err() != nil {
|
if ctx != nil && ctx.Err() != nil {
|
||||||
@@ -403,26 +388,20 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
|
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey)
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
var authID, authLabel, authType, authValue string
|
||||||
if auth != nil {
|
authID = auth.ID
|
||||||
authID = auth.ID
|
authLabel = auth.Label
|
||||||
authLabel = auth.Label
|
authType, authValue = auth.AccountInfo()
|
||||||
authType, authValue = auth.AccountInfo()
|
|
||||||
}
|
|
||||||
|
|
||||||
executionSessionID := executionSessionIDFromOptions(opts)
|
executionSessionID := executionSessionIDFromOptions(opts)
|
||||||
var sess *codexWebsocketSession
|
var sess *codexWebsocketSession
|
||||||
if executionSessionID != "" {
|
if executionSessionID != "" {
|
||||||
sess = e.getOrCreateSession(executionSessionID)
|
sess = e.getOrCreateSession(executionSessionID)
|
||||||
sess.reqMu.Lock()
|
if sess != nil {
|
||||||
|
sess.reqMu.Lock()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
allowAppend := true
|
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||||
if sess != nil {
|
|
||||||
sess.connMu.Lock()
|
|
||||||
allowAppend = sess.connCreateSent
|
|
||||||
sess.connMu.Unlock()
|
|
||||||
}
|
|
||||||
wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend)
|
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
URL: wsURL,
|
URL: wsURL,
|
||||||
Method: "WEBSOCKET",
|
Method: "WEBSOCKET",
|
||||||
@@ -483,10 +462,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
sess.reqMu.Unlock()
|
sess.reqMu.Unlock()
|
||||||
return nil, errDialRetry
|
return nil, errDialRetry
|
||||||
}
|
}
|
||||||
sess.connMu.Lock()
|
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
|
||||||
allowAppend = sess.connCreateSent
|
|
||||||
sess.connMu.Unlock()
|
|
||||||
wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend)
|
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
URL: wsURL,
|
URL: wsURL,
|
||||||
Method: "WEBSOCKET",
|
Method: "WEBSOCKET",
|
||||||
@@ -515,7 +491,6 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
return nil, errSend
|
return nil, errSend
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
markCodexWebsocketCreateSent(sess, conn, wsReqBody)
|
|
||||||
|
|
||||||
out := make(chan cliproxyexecutor.StreamChunk)
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -657,31 +632,14 @@ func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Con
|
|||||||
return conn.WriteMessage(websocket.TextMessage, payload)
|
return conn.WriteMessage(websocket.TextMessage, payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte {
|
func buildCodexWebsocketRequestBody(body []byte) []byte {
|
||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns.
|
// Match codex-rs websocket v2 semantics: every request is `response.create`.
|
||||||
// The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation).
|
// Incremental follow-up turns continue on the same websocket using
|
||||||
// Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive.
|
// `previous_response_id` + incremental `input`, not `response.append`.
|
||||||
//
|
|
||||||
// NOTE: The upstream expects the first websocket event on each connection to be `response.create`,
|
|
||||||
// so we only use `response.append` after we have initialized the current connection.
|
|
||||||
if allowAppend {
|
|
||||||
if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" {
|
|
||||||
inputNode := gjson.GetBytes(body, "input")
|
|
||||||
wsReqBody := []byte(`{}`)
|
|
||||||
wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append")
|
|
||||||
if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" {
|
|
||||||
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw))
|
|
||||||
return wsReqBody
|
|
||||||
}
|
|
||||||
wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]"))
|
|
||||||
return wsReqBody
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
|
wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
|
||||||
if errSet == nil && len(wsReqBody) > 0 {
|
if errSet == nil && len(wsReqBody) > 0 {
|
||||||
return wsReqBody
|
return wsReqBody
|
||||||
@@ -725,21 +683,6 @@ func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) {
|
|
||||||
if sess == nil || conn == nil || len(payload) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sess.connMu.Lock()
|
|
||||||
if sess.conn == conn {
|
|
||||||
sess.connCreateSent = true
|
|
||||||
}
|
|
||||||
sess.connMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
|
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
|
||||||
dialer := &websocket.Dialer{
|
dialer := &websocket.Dialer{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
@@ -1017,36 +960,6 @@ func closeHTTPResponseBody(resp *http.Response, logPrefix string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
|
|
||||||
done := make(chan struct{})
|
|
||||||
if ctx == nil || conn == nil {
|
|
||||||
return done
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-ctx.Done():
|
|
||||||
_ = conn.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return done
|
|
||||||
}
|
|
||||||
|
|
||||||
func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} {
|
|
||||||
done := make(chan struct{})
|
|
||||||
if ctx == nil || conn == nil {
|
|
||||||
return done
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-ctx.Done():
|
|
||||||
_ = conn.SetReadDeadline(time.Now())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return done
|
|
||||||
}
|
|
||||||
|
|
||||||
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
|
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
|
||||||
if len(opts.Metadata) == 0 {
|
if len(opts.Metadata) == 0 {
|
||||||
return ""
|
return ""
|
||||||
@@ -1120,7 +1033,6 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *
|
|||||||
sess.conn = conn
|
sess.conn = conn
|
||||||
sess.wsURL = wsURL
|
sess.wsURL = wsURL
|
||||||
sess.authID = authID
|
sess.authID = authID
|
||||||
sess.connCreateSent = false
|
|
||||||
sess.readerConn = conn
|
sess.readerConn = conn
|
||||||
sess.connMu.Unlock()
|
sess.connMu.Unlock()
|
||||||
|
|
||||||
@@ -1206,7 +1118,6 @@ func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSes
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
sess.conn = nil
|
sess.conn = nil
|
||||||
sess.connCreateSent = false
|
|
||||||
if sess.readerConn == conn {
|
if sess.readerConn == conn {
|
||||||
sess.readerConn = nil
|
sess.readerConn = nil
|
||||||
}
|
}
|
||||||
@@ -1273,7 +1184,6 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess
|
|||||||
authID := sess.authID
|
authID := sess.authID
|
||||||
wsURL := sess.wsURL
|
wsURL := sess.wsURL
|
||||||
sess.conn = nil
|
sess.conn = nil
|
||||||
sess.connCreateSent = false
|
|
||||||
if sess.readerConn == conn {
|
if sess.readerConn == conn {
|
||||||
sess.readerConn = nil
|
sess.readerConn = nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`)
|
||||||
|
|
||||||
|
wsReqBody := buildCodexWebsocketRequestBody(body)
|
||||||
|
|
||||||
|
if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" {
|
||||||
|
t.Fatalf("type = %s, want response.create", got)
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" {
|
||||||
|
t.Fatalf("previous_response_id = %s, want resp-1", got)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" {
|
||||||
|
t.Fatalf("input item id mismatch")
|
||||||
|
}
|
||||||
|
if got := gjson.GetBytes(wsReqBody, "type").String(); got == "response.append" {
|
||||||
|
t.Fatalf("unexpected websocket request type: %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||||
|
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "")
|
||||||
|
|
||||||
|
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||||
|
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,7 +26,6 @@ const (
|
|||||||
wsRequestTypeAppend = "response.append"
|
wsRequestTypeAppend = "response.append"
|
||||||
wsEventTypeError = "error"
|
wsEventTypeError = "error"
|
||||||
wsEventTypeCompleted = "response.completed"
|
wsEventTypeCompleted = "response.completed"
|
||||||
wsEventTypeDone = "response.done"
|
|
||||||
wsDoneMarker = "[DONE]"
|
wsDoneMarker = "[DONE]"
|
||||||
wsTurnStateHeader = "x-codex-turn-state"
|
wsTurnStateHeader = "x-codex-turn-state"
|
||||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||||
@@ -469,9 +468,6 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
|||||||
for i := range payloads {
|
for i := range payloads {
|
||||||
eventType := gjson.GetBytes(payloads[i], "type").String()
|
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||||
if eventType == wsEventTypeCompleted {
|
if eventType == wsEventTypeCompleted {
|
||||||
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
|
|
||||||
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
|
|
||||||
|
|
||||||
completed = true
|
completed = true
|
||||||
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,15 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -247,3 +250,79 @@ func TestSetWebsocketRequestBody(t *testing.T) {
|
|||||||
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
|
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
errClose := conn.Close()
|
||||||
|
if errClose != nil {
|
||||||
|
serverErrCh <- errClose
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
|
||||||
|
ctx.Request = r
|
||||||
|
|
||||||
|
data := make(chan []byte, 1)
|
||||||
|
errCh := make(chan *interfaces.ErrorMessage)
|
||||||
|
data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
|
||||||
|
close(data)
|
||||||
|
close(errCh)
|
||||||
|
|
||||||
|
var bodyLog strings.Builder
|
||||||
|
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
|
||||||
|
ctx,
|
||||||
|
conn,
|
||||||
|
func(...interface{}) {},
|
||||||
|
data,
|
||||||
|
errCh,
|
||||||
|
&bodyLog,
|
||||||
|
"session-1",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
|
||||||
|
serverErrCh <- errors.New("completed output not captured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
serverErrCh <- nil
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
|
||||||
|
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial websocket: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
errClose := conn.Close()
|
||||||
|
if errClose != nil {
|
||||||
|
t.Fatalf("close websocket: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, payload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
t.Fatalf("read websocket message: %v", errReadMessage)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted {
|
||||||
|
t.Fatalf("payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted)
|
||||||
|
}
|
||||||
|
if strings.Contains(string(payload), "response.done") {
|
||||||
|
t.Fatalf("payload unexpectedly rewrote completed event: %s", payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
if errServer := <-serverErrCh; errServer != nil {
|
||||||
|
t.Fatalf("server error: %v", errServer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user