feat(api, watcher): add zstd decoding for request logs and payload diff support
- Added `zstd` decoding support in request logging, including helper functions to process `Content-Encoding` headers. - Enhanced config diff logic to compare payload-specific rules and track changes in payload configurations. - Added tests to validate `zstd` decoding and payload diff behavior.
This commit is contained in:
@@ -5,12 +5,14 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/klauspost/compress/zstd"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||||
)
|
)
|
||||||
@@ -136,7 +138,7 @@ func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error)
|
|||||||
|
|
||||||
// Restore the body for the actual request processing
|
// Restore the body for the actual request processing
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
body = bodyBytes
|
body = decodeCapturedRequestBodyForLog(bodyBytes, c.Request.Header.Get("Content-Encoding"))
|
||||||
}
|
}
|
||||||
|
|
||||||
return &RequestInfo{
|
return &RequestInfo{
|
||||||
@@ -149,6 +151,58 @@ func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error)
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func decodeCapturedRequestBodyForLog(raw []byte, encoding string) []byte {
|
||||||
|
if len(raw) == 0 {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, errDecode := decodeCapturedRequestBody(raw, encoding)
|
||||||
|
if errDecode != nil {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
return decoded
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeCapturedRequestBody(raw []byte, encoding string) ([]byte, error) {
|
||||||
|
encoding = strings.TrimSpace(encoding)
|
||||||
|
if encoding == "" || strings.EqualFold(encoding, "identity") {
|
||||||
|
return raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(encoding, ",")
|
||||||
|
body := raw
|
||||||
|
for i := len(parts) - 1; i >= 0; i-- {
|
||||||
|
enc := strings.ToLower(strings.TrimSpace(parts[i]))
|
||||||
|
switch enc {
|
||||||
|
case "", "identity":
|
||||||
|
continue
|
||||||
|
case "zstd":
|
||||||
|
decoded, errDecode := decodeCapturedZstdRequestBody(body)
|
||||||
|
if errDecode != nil {
|
||||||
|
return nil, errDecode
|
||||||
|
}
|
||||||
|
body = decoded
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported request content encoding: %s", enc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeCapturedZstdRequestBody(raw []byte) ([]byte, error) {
|
||||||
|
decoder, errNewReader := zstd.NewReader(bytes.NewReader(raw))
|
||||||
|
if errNewReader != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create zstd request decoder: %w", errNewReader)
|
||||||
|
}
|
||||||
|
defer decoder.Close()
|
||||||
|
|
||||||
|
decoded, errRead := io.ReadAll(decoder)
|
||||||
|
if errRead != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode zstd request body: %w", errRead)
|
||||||
|
}
|
||||||
|
return decoded, nil
|
||||||
|
}
|
||||||
|
|
||||||
// shouldLogRequest determines whether the request should be logged.
|
// shouldLogRequest determines whether the request should be logged.
|
||||||
// It skips management endpoints to avoid leaking secrets but allows
|
// It skips management endpoints to avoid leaking secrets but allows
|
||||||
// all other routes, including module-provided ones, to honor request-log.
|
// all other routes, including module-provided ones, to honor request-log.
|
||||||
|
|||||||
@@ -1,11 +1,16 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/klauspost/compress/zstd"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
||||||
@@ -136,3 +141,43 @@ func TestShouldCaptureRequestBody(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCaptureRequestInfoDecodesZstdRequestBodyForLog(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
payload := []byte(`{"model":"test-model","stream":true}`)
|
||||||
|
var compressed bytes.Buffer
|
||||||
|
encoder, errNewWriter := zstd.NewWriter(&compressed)
|
||||||
|
if errNewWriter != nil {
|
||||||
|
t.Fatalf("zstd.NewWriter: %v", errNewWriter)
|
||||||
|
}
|
||||||
|
if _, errWrite := encoder.Write(payload); errWrite != nil {
|
||||||
|
t.Fatalf("zstd write: %v", errWrite)
|
||||||
|
}
|
||||||
|
if errClose := encoder.Close(); errClose != nil {
|
||||||
|
t.Fatalf("zstd close: %v", errClose)
|
||||||
|
}
|
||||||
|
compressedBytes := compressed.Bytes()
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(compressedBytes))
|
||||||
|
req.Header.Set("Content-Encoding", "zstd")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
info, errCapture := captureRequestInfo(c, true)
|
||||||
|
if errCapture != nil {
|
||||||
|
t.Fatalf("captureRequestInfo: %v", errCapture)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(info.Body, payload) {
|
||||||
|
t.Fatalf("logged request body = %q, want %q", string(info.Body), string(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
restoredBody, errRead := io.ReadAll(c.Request.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
t.Fatalf("read restored request body: %v", errRead)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(restoredBody, compressedBytes) {
|
||||||
|
t.Fatal("request body was not restored with the original compressed bytes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -93,6 +93,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if oldCfg.Routing.Strategy != newCfg.Routing.Strategy {
|
if oldCfg.Routing.Strategy != newCfg.Routing.Strategy {
|
||||||
changes = append(changes, fmt.Sprintf("routing.strategy: %s -> %s", oldCfg.Routing.Strategy, newCfg.Routing.Strategy))
|
changes = append(changes, fmt.Sprintf("routing.strategy: %s -> %s", oldCfg.Routing.Strategy, newCfg.Routing.Strategy))
|
||||||
}
|
}
|
||||||
|
if !reflect.DeepEqual(oldCfg.Payload, newCfg.Payload) {
|
||||||
|
changes = appendPayloadConfigChanges(changes, oldCfg.Payload, newCfg.Payload)
|
||||||
|
}
|
||||||
|
|
||||||
// API keys (redacted) and counts
|
// API keys (redacted) and counts
|
||||||
if len(oldCfg.APIKeys) != len(newCfg.APIKeys) {
|
if len(oldCfg.APIKeys) != len(newCfg.APIKeys) {
|
||||||
@@ -338,6 +341,29 @@ func trimStrings(in []string) []string {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func appendPayloadConfigChanges(changes []string, oldPayload, newPayload config.PayloadConfig) []string {
|
||||||
|
changes = appendPayloadRuleChanges(changes, "default", oldPayload.Default, newPayload.Default)
|
||||||
|
changes = appendPayloadRuleChanges(changes, "default-raw", oldPayload.DefaultRaw, newPayload.DefaultRaw)
|
||||||
|
changes = appendPayloadRuleChanges(changes, "override", oldPayload.Override, newPayload.Override)
|
||||||
|
changes = appendPayloadRuleChanges(changes, "override-raw", oldPayload.OverrideRaw, newPayload.OverrideRaw)
|
||||||
|
changes = appendPayloadFilterRuleChanges(changes, "filter", oldPayload.Filter, newPayload.Filter)
|
||||||
|
return changes
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendPayloadRuleChanges(changes []string, section string, oldRules, newRules []config.PayloadRule) []string {
|
||||||
|
if reflect.DeepEqual(oldRules, newRules) {
|
||||||
|
return changes
|
||||||
|
}
|
||||||
|
return append(changes, fmt.Sprintf("payload.%s: updated (%d -> %d rules)", section, len(oldRules), len(newRules)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendPayloadFilterRuleChanges(changes []string, section string, oldRules, newRules []config.PayloadFilterRule) []string {
|
||||||
|
if reflect.DeepEqual(oldRules, newRules) {
|
||||||
|
return changes
|
||||||
|
}
|
||||||
|
return append(changes, fmt.Sprintf("payload.%s: updated (%d -> %d rules)", section, len(oldRules), len(newRules)))
|
||||||
|
}
|
||||||
|
|
||||||
func equalStringMap(a, b map[string]string) bool {
|
func equalStringMap(a, b map[string]string) bool {
|
||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -555,6 +555,9 @@ func (s *Service) applyConfigUpdate(newCfg *config.Config) {
|
|||||||
s.coreManager.SetConfig(newCfg)
|
s.coreManager.SetConfig(newCfg)
|
||||||
s.coreManager.SetOAuthModelAlias(newCfg.OAuthModelAlias)
|
s.coreManager.SetOAuthModelAlias(newCfg.OAuthModelAlias)
|
||||||
}
|
}
|
||||||
|
if newCfg.Home.Enabled {
|
||||||
|
s.registerHomeExecutors()
|
||||||
|
}
|
||||||
s.rebindExecutors()
|
s.rebindExecutors()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user