package middleware

import (
	"bytes"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/klauspost/compress/zstd"
)

// TestShouldSkipMethodForRequestLogging checks the skip decision returned by
// shouldSkipMethodForRequestLogging for a nil request, a POST request, plain
// GET requests, and a GET on /v1/responses carrying a websocket Upgrade header.
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
	tests := []struct {
		name string
		req  *http.Request
		skip bool
	}{
		{
			name: "nil request",
			req:  nil,
			skip: true,
		},
		{
			name: "post request should not skip",
			req: &http.Request{
				Method: http.MethodPost,
				URL:    &url.URL{Path: "/v1/responses"},
			},
			skip: false,
		},
		{
			name: "plain get should skip",
			req: &http.Request{
				Method: http.MethodGet,
				URL:    &url.URL{Path: "/v1/models"},
				Header: http.Header{},
			},
			skip: true,
		},
		{
			name: "responses websocket upgrade should not skip",
			req: &http.Request{
				Method: http.MethodGet,
				URL:    &url.URL{Path: "/v1/responses"},
				Header: http.Header{"Upgrade": []string{"websocket"}},
			},
			skip: false,
		},
		{
			name: "responses get without upgrade should skip",
			req: &http.Request{
				Method: http.MethodGet,
				URL:    &url.URL{Path: "/v1/responses"},
				Header: http.Header{},
			},
			skip: true,
		},
	}

	// Each case runs as its own subtest so that one failing case does not
	// prevent the remaining cases from being evaluated and reported.
	for i := range tests {
		tc := &tests[i]
		t.Run(tc.name, func(t *testing.T) {
			got := shouldSkipMethodForRequestLogging(tc.req)
			if got != tc.skip {
				t.Fatalf("%s: got skip=%t, want %t", tc.name, got, tc.skip)
			}
		})
	}
}

// TestShouldCaptureRequestBody checks shouldCaptureRequestBody with the logger
// enabled and with it disabled (the "error-only mode" cases below), varying
// the request's ContentLength and Content-Type.
func TestShouldCaptureRequestBody(t *testing.T) {
	tests := []struct {
		name          string
		loggerEnabled bool
		req           *http.Request
		want          bool
	}{
		{
			name:          "logger enabled always captures",
			loggerEnabled: true,
			req: &http.Request{
				Body:          io.NopCloser(strings.NewReader("{}")),
				ContentLength: -1,
				Header:        http.Header{"Content-Type": []string{"application/json"}},
			},
			want: true,
		},
		{
			name:          "nil request",
			loggerEnabled: false,
			req:           nil,
			want:          false,
		},
		{
			name:          "small known size json in error-only mode",
			loggerEnabled: false,
			req: &http.Request{
				Body:          io.NopCloser(strings.NewReader("{}")),
				ContentLength: 2,
				Header:        http.Header{"Content-Type": []string{"application/json"}},
			},
			want: true,
		},
		{
			name:          "large known size skipped in error-only mode",
			loggerEnabled: false,
			req: &http.Request{
				Body: io.NopCloser(strings.NewReader("x")),
				// One byte past the limit; the declared length, not the
				// actual body, is what this case varies.
				ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
				Header:        http.Header{"Content-Type": []string{"application/json"}},
			},
			want: false,
		},
		{
			name:          "unknown size skipped in error-only mode",
			loggerEnabled: false,
			req: &http.Request{
				Body:          io.NopCloser(strings.NewReader("x")),
				ContentLength: -1,
				Header:        http.Header{"Content-Type": []string{"application/json"}},
			},
			want: false,
		},
		{
			name:          "multipart skipped in error-only mode",
			loggerEnabled: false,
			req: &http.Request{
				Body:          io.NopCloser(strings.NewReader("x")),
				ContentLength: 1,
				Header:        http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
			},
			want: false,
		},
	}

	// Each case runs as its own subtest so that one failing case does not
	// prevent the remaining cases from being evaluated and reported.
	for i := range tests {
		tc := &tests[i]
		t.Run(tc.name, func(t *testing.T) {
			got := shouldCaptureRequestBody(tc.loggerEnabled, tc.req)
			if got != tc.want {
				t.Fatalf("%s: got %t, want %t", tc.name, got, tc.want)
			}
		})
	}
}

// TestCaptureRequestInfoDecodesZstdRequestBodyForLog sends a zstd-compressed
// JSON body with "Content-Encoding: zstd" through captureRequestInfo and
// asserts two things: the captured info.Body equals the uncompressed payload,
// and c.Request.Body afterwards still yields the original compressed bytes.
func TestCaptureRequestInfoDecodesZstdRequestBodyForLog(t *testing.T) {
	gin.SetMode(gin.TestMode)

	payload := []byte(`{"model":"test-model","stream":true}`)

	// Compress the payload; Close must succeed before the buffer is read so
	// that the encoder has flushed the complete frame.
	var compressed bytes.Buffer
	encoder, errNewWriter := zstd.NewWriter(&compressed)
	if errNewWriter != nil {
		t.Fatalf("zstd.NewWriter: %v", errNewWriter)
	}
	if _, errWrite := encoder.Write(payload); errWrite != nil {
		t.Fatalf("zstd write: %v", errWrite)
	}
	if errClose := encoder.Close(); errClose != nil {
		t.Fatalf("zstd close: %v", errClose)
	}
	compressedBytes := compressed.Bytes()

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressedBytes))
	req.Header.Set("Content-Encoding", "zstd")
	c.Request = req

	// NOTE(review): the second argument is presumably the "capture body"
	// switch — confirm against captureRequestInfo's signature.
	info, errCapture := captureRequestInfo(c, true)
	if errCapture != nil {
		t.Fatalf("captureRequestInfo: %v", errCapture)
	}
	if !bytes.Equal(info.Body, payload) {
		t.Fatalf("logged request body = %q, want %q", string(info.Body), string(payload))
	}

	// The request body must remain readable after capture and must still be
	// the compressed wire bytes, not the decoded payload.
	restoredBody, errRead := io.ReadAll(c.Request.Body)
	if errRead != nil {
		t.Fatalf("read restored request body: %v", errRead)
	}
	if !bytes.Equal(restoredBody, compressedBytes) {
		t.Fatal("request body was not restored with the original compressed bytes")
	}
}