// Package middleware provides HTTP middleware components for the CLI Proxy API server. // This file contains the request logging middleware that captures comprehensive // request and response data when enabled through configuration. package middleware import ( "bytes" "fmt" "io" "net/http" "strings" "time" "github.com/gin-gonic/gin" "github.com/klauspost/compress/zstd" "github.com/router-for-me/CLIProxyAPI/v7/internal/logging" "github.com/router-for-me/CLIProxyAPI/v7/internal/util" ) const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB // RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses. // It captures detailed information about the request and response, including headers and body, // and uses the provided RequestLogger to record this data. When full request logging is disabled, // body capture is limited to small known-size payloads to avoid large per-request memory spikes. func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc { return func(c *gin.Context) { if logger == nil { c.Next() return } if shouldSkipMethodForRequestLogging(c.Request) { c.Next() return } path := c.Request.URL.Path if !shouldLogRequest(path) { c.Next() return } loggerEnabled := logger.IsEnabled() // Capture request information requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request)) if err != nil { // Log error but continue processing // In a real implementation, you might want to use a proper logger here c.Next() return } // Create response writer wrapper wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo) if !loggerEnabled { wrapper.logOnErrorOnly = true } c.Writer = wrapper // Process the request c.Next() // Finalize logging after request processing if err = wrapper.Finalize(c); err != nil { // Log error but don't interrupt the response // In a real implementation, you might want to use a proper logger here } } } func shouldSkipMethodForRequestLogging(req 
*http.Request) bool { if req == nil { return true } if req.Method != http.MethodGet { return false } return !isResponsesWebsocketUpgrade(req) } func isResponsesWebsocketUpgrade(req *http.Request) bool { if req == nil || req.URL == nil { return false } if req.URL.Path != "/v1/responses" { return false } return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket") } func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool { if loggerEnabled { return true } if req == nil || req.Body == nil { return false } contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type"))) if strings.HasPrefix(contentType, "multipart/form-data") { return false } if req.ContentLength <= 0 { return false } return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes } // captureRequestInfo extracts relevant information from the incoming HTTP request. // It captures the URL, method, headers, and body. The request body is read and then // restored so that it can be processed by subsequent handlers. func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) { // Capture URL with sensitive query parameters masked maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery) url := c.Request.URL.Path if maskedQuery != "" { url += "?" 
+ maskedQuery } // Capture method method := c.Request.Method // Capture headers headers := make(map[string][]string) for key, values := range c.Request.Header { headers[key] = values } // Capture request body var body []byte if captureBody && c.Request.Body != nil { // Read the body bodyBytes, err := io.ReadAll(c.Request.Body) if err != nil { return nil, err } // Restore the body for the actual request processing c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) body = decodeCapturedRequestBodyForLog(bodyBytes, c.Request.Header.Get("Content-Encoding")) } return &RequestInfo{ URL: url, Method: method, Headers: headers, Body: body, RequestID: logging.GetGinRequestID(c), Timestamp: time.Now(), }, nil } func decodeCapturedRequestBodyForLog(raw []byte, encoding string) []byte { if len(raw) == 0 { return raw } decoded, errDecode := decodeCapturedRequestBody(raw, encoding) if errDecode != nil { return raw } return decoded } func decodeCapturedRequestBody(raw []byte, encoding string) ([]byte, error) { encoding = strings.TrimSpace(encoding) if encoding == "" || strings.EqualFold(encoding, "identity") { return raw, nil } parts := strings.Split(encoding, ",") body := raw for i := len(parts) - 1; i >= 0; i-- { enc := strings.ToLower(strings.TrimSpace(parts[i])) switch enc { case "", "identity": continue case "zstd": decoded, errDecode := decodeCapturedZstdRequestBody(body) if errDecode != nil { return nil, errDecode } body = decoded default: return nil, fmt.Errorf("unsupported request content encoding: %s", enc) } } return body, nil } func decodeCapturedZstdRequestBody(raw []byte) ([]byte, error) { decoder, errNewReader := zstd.NewReader(bytes.NewReader(raw)) if errNewReader != nil { return nil, fmt.Errorf("failed to create zstd request decoder: %w", errNewReader) } defer decoder.Close() decoded, errRead := io.ReadAll(decoder) if errRead != nil { return nil, fmt.Errorf("failed to decode zstd request body: %w", errRead) } return decoded, nil } // shouldLogRequest 
// shouldLogRequest determines whether the request should be logged.
// It skips management endpoints to avoid leaking secrets but allows
// all other routes, including module-provided ones, to honor request-log.
//
// Paths starting with "/v0/management" or "/management" are never logged.
// Under the "/api" prefix only paths starting with "/api/provider" are logged.
func shouldLogRequest(path string) bool {
	switch {
	case strings.HasPrefix(path, "/v0/management"), strings.HasPrefix(path, "/management"):
		return false
	case strings.HasPrefix(path, "/api"):
		return strings.HasPrefix(path, "/api/provider")
	default:
		return true
	}
}