Refactor websocket logging and error handling

- Introduced new logging functions for websocket requests, handshakes, errors, and responses in `logging_helpers.go`.
- Updated `CodexWebsocketsExecutor` to utilize the new logging functions for improved clarity and consistency in websocket operations.
- Modified the handling of websocket upgrade rejections to log relevant metadata.
- Changed the request body key to a timeline body key in `openai_responses_websocket.go` to better reflect its purpose.
- Enhanced tests to verify the correct logging of websocket events and responses, including disconnect events and error handling scenarios.
This commit is contained in:
hkfires
2026-04-02 17:30:51 +08:00
parent 4f99bc54f1
commit 34339f61ee
8 changed files with 911 additions and 120 deletions
+62 -16
View File
@@ -15,6 +15,8 @@ import (
)
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
const responseBodyOverrideContextKey = "RESPONSE_BODY_OVERRIDE"
const websocketTimelineOverrideContextKey = "WEBSOCKET_TIMELINE_OVERRIDE"
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
@@ -304,6 +306,10 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
if len(apiResponse) > 0 {
_ = w.streamWriter.WriteAPIResponse(apiResponse)
}
apiWebsocketTimeline := w.extractAPIWebsocketTimeline(c)
if len(apiWebsocketTimeline) > 0 {
_ = w.streamWriter.WriteAPIWebsocketTimeline(apiWebsocketTimeline)
}
if err := w.streamWriter.Close(); err != nil {
w.streamWriter = nil
return err
@@ -312,7 +318,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil
}
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.extractResponseBody(c), w.extractWebsocketTimeline(c), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIWebsocketTimeline(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
}
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -352,6 +358,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
return data
}
func (w *ResponseWriterWrapper) extractAPIWebsocketTimeline(c *gin.Context) []byte {
apiTimeline, isExist := c.Get("API_WEBSOCKET_TIMELINE")
if !isExist {
return nil
}
data, ok := apiTimeline.([]byte)
if !ok || len(data) == 0 {
return nil
}
return bytes.Clone(data)
}
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
if !isExist {
@@ -364,19 +382,8 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
}
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
if c != nil {
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
}
if body := extractBodyOverride(c, requestBodyOverrideContextKey); len(body) > 0 {
return body
}
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
return w.requestInfo.Body
@@ -384,13 +391,48 @@ func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
func (w *ResponseWriterWrapper) extractResponseBody(c *gin.Context) []byte {
if body := extractBodyOverride(c, responseBodyOverrideContextKey); len(body) > 0 {
return body
}
if w.body == nil || w.body.Len() == 0 {
return nil
}
return bytes.Clone(w.body.Bytes())
}
func (w *ResponseWriterWrapper) extractWebsocketTimeline(c *gin.Context) []byte {
return extractBodyOverride(c, websocketTimelineOverrideContextKey)
}
func extractBodyOverride(c *gin.Context, key string) []byte {
if c == nil {
return nil
}
bodyOverride, isExist := c.Get(key)
if !isExist {
return nil
}
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body, websocketTimeline, apiRequestBody, apiResponseBody, apiWebsocketTimeline []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok {
return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL,
@@ -400,8 +442,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
statusCode,
headers,
body,
websocketTimeline,
apiRequestBody,
apiResponseBody,
apiWebsocketTimeline,
apiResponseErrors,
forceLog,
w.requestInfo.RequestID,
@@ -418,8 +462,10 @@ func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, h
statusCode,
headers,
body,
websocketTimeline,
apiRequestBody,
apiResponseBody,
apiWebsocketTimeline,
apiResponseErrors,
w.requestInfo.RequestID,
w.requestInfo.Timestamp,
+160 -1
View File
@@ -1,10 +1,14 @@
package middleware
import (
"bytes"
"net/http/httptest"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
@@ -33,7 +37,7 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
c.Set(requestBodyOverrideContextKey, "override-as-string")
body := wrapper.extractRequestBody(c)
@@ -41,3 +45,158 @@ func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
}
}
func TestExtractResponseBodyPrefersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{body: &bytes.Buffer{}}
wrapper.body.WriteString("original-response")
body := wrapper.extractResponseBody(c)
if string(body) != "original-response" {
t.Fatalf("response body = %q, want %q", string(body), "original-response")
}
c.Set(responseBodyOverrideContextKey, []byte("override-response"))
body = wrapper.extractResponseBody(c)
if string(body) != "override-response" {
t.Fatalf("response body = %q, want %q", string(body), "override-response")
}
body[0] = 'X'
if got := wrapper.extractResponseBody(c); string(got) != "override-response" {
t.Fatalf("response override should be cloned, got %q", string(got))
}
}
func TestExtractResponseBodySupportsStringOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
c.Set(responseBodyOverrideContextKey, "override-response-as-string")
body := wrapper.extractResponseBody(c)
if string(body) != "override-response-as-string" {
t.Fatalf("response body = %q, want %q", string(body), "override-response-as-string")
}
}
func TestExtractBodyOverrideClonesBytes(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
override := []byte("body-override")
c.Set(requestBodyOverrideContextKey, override)
body := extractBodyOverride(c, requestBodyOverrideContextKey)
if !bytes.Equal(body, override) {
t.Fatalf("body override = %q, want %q", string(body), string(override))
}
body[0] = 'X'
if !bytes.Equal(override, []byte("body-override")) {
t.Fatalf("override mutated: %q", string(override))
}
}
func TestExtractWebsocketTimelineUsesOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
if got := wrapper.extractWebsocketTimeline(c); got != nil {
t.Fatalf("expected nil websocket timeline, got %q", string(got))
}
c.Set(websocketTimelineOverrideContextKey, []byte("timeline"))
body := wrapper.extractWebsocketTimeline(c)
if string(body) != "timeline" {
t.Fatalf("websocket timeline = %q, want %q", string(body), "timeline")
}
}
func TestFinalizeStreamingWritesAPIWebsocketTimeline(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
streamWriter := &testStreamingLogWriter{}
wrapper := &ResponseWriterWrapper{
ResponseWriter: c.Writer,
logger: &testRequestLogger{enabled: true},
requestInfo: &RequestInfo{
URL: "/v1/responses",
Method: "POST",
Headers: map[string][]string{"Content-Type": {"application/json"}},
RequestID: "req-1",
Timestamp: time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC),
},
isStreaming: true,
streamWriter: streamWriter,
}
c.Set("API_WEBSOCKET_TIMELINE", []byte("Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}"))
if err := wrapper.Finalize(c); err != nil {
t.Fatalf("Finalize error: %v", err)
}
if string(streamWriter.apiWebsocketTimeline) != "Timestamp: 2026-04-01T12:00:00Z\nEvent: api.websocket.request\n{}" {
t.Fatalf("stream writer websocket timeline = %q", string(streamWriter.apiWebsocketTimeline))
}
if !streamWriter.closed {
t.Fatal("expected stream writer to be closed")
}
}
type testRequestLogger struct {
enabled bool
}
func (l *testRequestLogger) LogRequest(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []byte, []byte, []*interfaces.ErrorMessage, string, time.Time, time.Time) error {
return nil
}
func (l *testRequestLogger) LogStreamingRequest(string, string, map[string][]string, []byte, string) (logging.StreamingLogWriter, error) {
return &testStreamingLogWriter{}, nil
}
func (l *testRequestLogger) IsEnabled() bool {
return l.enabled
}
type testStreamingLogWriter struct {
apiWebsocketTimeline []byte
closed bool
}
func (w *testStreamingLogWriter) WriteChunkAsync([]byte) {}
func (w *testStreamingLogWriter) WriteStatus(int, map[string][]string) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIRequest([]byte) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIResponse([]byte) error {
return nil
}
func (w *testStreamingLogWriter) WriteAPIWebsocketTimeline(apiWebsocketTimeline []byte) error {
w.apiWebsocketTimeline = bytes.Clone(apiWebsocketTimeline)
return nil
}
func (w *testStreamingLogWriter) SetFirstChunkTimestamp(time.Time) {}
func (w *testStreamingLogWriter) Close() error {
w.closed = true
return nil
}
+2
View File
@@ -172,6 +172,8 @@ func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
nil,
nil,
nil,
nil,
nil,
true,
"issue-1711",
time.Now(),