feat(runtime): enhance payload rule resolution with dynamic path support
- Introduced `resolvePayloadRulePaths` function to dynamically resolve rule paths supporting array queries and complex logic. - Updated payload processing logic (`apply defaults`, `overrides`, `filters`) to handle resolved paths for better flexibility. - Added helper functions for path parsing, query matching, and logical resolution to improve modularity and reusability. - Introduced payload condition match logic, including `match`, `not-match`, `exist`, and `not-exist` rules in `PayloadConfig`. - Enhanced `payloadModelRulesMatch` function to support conditional checks at various levels. - Added helper methods for evaluating JSON path conditions and values. - Updated tests to validate new conditional rules against different payload scenarios.
This commit is contained in:
@@ -2,6 +2,8 @@ package helps
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -19,6 +21,11 @@ import (
|
||||
// model name before alias resolution so payload rules can target aliases precisely.
|
||||
// requestPath is the inbound HTTP request path (when available) used for endpoint-scoped gates.
|
||||
func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string, requestPath string) []byte {
|
||||
return ApplyPayloadConfigWithRequest(cfg, model, protocol, "", root, payload, original, requestedModel, requestPath, nil)
|
||||
}
|
||||
|
||||
// ApplyPayloadConfigWithRequest applies payload config using source protocol and request header gates.
|
||||
func ApplyPayloadConfigWithRequest(cfg *config.Config, model, protocol, formProtocol, root string, payload, original []byte, requestedModel string, requestPath string, headers http.Header) []byte {
|
||||
if cfg == nil || len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
@@ -48,7 +55,7 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
// Apply default rules: first write wins per field across all matching rules.
|
||||
for i := range rules.Default {
|
||||
rule := &rules.Default[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, formProtocol, headers, out, root, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
@@ -75,7 +82,7 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
// Apply default raw rules: first write wins per field across all matching rules.
|
||||
for i := range rules.DefaultRaw {
|
||||
rule := &rules.DefaultRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, formProtocol, headers, out, root, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
@@ -106,7 +113,7 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
// Apply override rules: last write wins per field across all matching rules.
|
||||
for i := range rules.Override {
|
||||
rule := &rules.Override[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, formProtocol, headers, out, root, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
@@ -126,7 +133,7 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
// Apply override raw rules: last write wins per field across all matching rules.
|
||||
for i := range rules.OverrideRaw {
|
||||
rule := &rules.OverrideRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, formProtocol, headers, out, root, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
@@ -150,7 +157,7 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
// Apply filter rules: remove matching paths from payload.
|
||||
for i := range rules.Filter {
|
||||
rule := &rules.Filter[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, formProtocol, headers, out, root, candidates) {
|
||||
continue
|
||||
}
|
||||
for _, path := range rule.Params {
|
||||
@@ -192,7 +199,7 @@ func isImagesEndpointRequestPath(path string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, models []string) bool {
|
||||
func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, formProtocol string, headers http.Header, payload []byte, root string, models []string) bool {
|
||||
if len(rules) == 0 || len(models) == 0 {
|
||||
return false
|
||||
}
|
||||
@@ -205,7 +212,16 @@ func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, mo
|
||||
if ep := strings.TrimSpace(entry.Protocol); ep != "" && protocol != "" && !strings.EqualFold(ep, protocol) {
|
||||
continue
|
||||
}
|
||||
if matchModelPattern(name, model) {
|
||||
if !payloadFormProtocolMatches(entry.FormProtocol, formProtocol) {
|
||||
continue
|
||||
}
|
||||
if !payloadHeadersMatch(headers, entry.Headers) {
|
||||
continue
|
||||
}
|
||||
if !matchModelPattern(name, model) {
|
||||
continue
|
||||
}
|
||||
if payloadModelRuleConditionsMatch(payload, root, entry) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -213,6 +229,207 @@ func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, mo
|
||||
return false
|
||||
}
|
||||
|
||||
func payloadModelRuleConditionsMatch(payload []byte, root string, rule config.PayloadModelRule) bool {
|
||||
if !payloadMatchConditionsMatch(payload, root, rule.Match) {
|
||||
return false
|
||||
}
|
||||
if !payloadNotMatchConditionsMatch(payload, root, rule.NotMatch) {
|
||||
return false
|
||||
}
|
||||
if !payloadExistConditionsMatch(payload, root, rule.Exist) {
|
||||
return false
|
||||
}
|
||||
if !payloadNotExistConditionsMatch(payload, root, rule.NotExist) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func payloadMatchConditionsMatch(payload []byte, root string, conditions []map[string]any) bool {
|
||||
for _, condition := range conditions {
|
||||
for path, value := range condition {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
continue
|
||||
}
|
||||
if !payloadPathMatchesValue(payload, buildPayloadPath(root, path), value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func payloadNotMatchConditionsMatch(payload []byte, root string, conditions []map[string]any) bool {
|
||||
for _, condition := range conditions {
|
||||
for path, value := range condition {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
continue
|
||||
}
|
||||
if payloadPathMatchesValue(payload, buildPayloadPath(root, path), value) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func payloadExistConditionsMatch(payload []byte, root string, paths []string) bool {
|
||||
for _, path := range paths {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
continue
|
||||
}
|
||||
if !payloadPathExists(payload, buildPayloadPath(root, path)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func payloadNotExistConditionsMatch(payload []byte, root string, paths []string) bool {
|
||||
for _, path := range paths {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
continue
|
||||
}
|
||||
if payloadPathExists(payload, buildPayloadPath(root, path)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func payloadPathMatchesValue(payload []byte, path string, value any) bool {
|
||||
for _, resolvedPath := range resolvePayloadRulePaths(payload, path) {
|
||||
result := gjson.GetBytes(payload, resolvedPath)
|
||||
if !result.Exists() {
|
||||
continue
|
||||
}
|
||||
if payloadResultEquals(result, value) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func payloadPathExists(payload []byte, path string) bool {
|
||||
for _, resolvedPath := range resolvePayloadRulePaths(payload, path) {
|
||||
result := gjson.GetBytes(payload, resolvedPath)
|
||||
if result.Exists() && result.Type != gjson.Null {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func payloadResultEquals(result gjson.Result, value any) bool {
|
||||
actual, ok := normalizedPayloadResult(result)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
expected, ok := normalizedPayloadValue(value)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return reflect.DeepEqual(actual, expected)
|
||||
}
|
||||
|
||||
func normalizedPayloadResult(result gjson.Result) (any, bool) {
|
||||
if !result.Exists() {
|
||||
return nil, false
|
||||
}
|
||||
raw := strings.TrimSpace(result.Raw)
|
||||
if raw == "" {
|
||||
encoded, errMarshal := json.Marshal(result.Value())
|
||||
if errMarshal != nil {
|
||||
return nil, false
|
||||
}
|
||||
raw = string(encoded)
|
||||
}
|
||||
return normalizedPayloadJSON([]byte(raw))
|
||||
}
|
||||
|
||||
func normalizedPayloadValue(value any) (any, bool) {
|
||||
encoded, errMarshal := json.Marshal(value)
|
||||
if errMarshal != nil {
|
||||
return nil, false
|
||||
}
|
||||
return normalizedPayloadJSON(encoded)
|
||||
}
|
||||
|
||||
func normalizedPayloadJSON(data []byte) (any, bool) {
|
||||
if len(strings.TrimSpace(string(data))) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
var out any
|
||||
if errUnmarshal := json.Unmarshal(data, &out); errUnmarshal != nil {
|
||||
return nil, false
|
||||
}
|
||||
return out, true
|
||||
}
|
||||
|
||||
func payloadFormProtocolMatches(pattern, formProtocol string) bool {
|
||||
pattern = normalizePayloadFormProtocol(pattern)
|
||||
if pattern == "" {
|
||||
return true
|
||||
}
|
||||
formProtocol = normalizePayloadFormProtocol(formProtocol)
|
||||
if formProtocol == "" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(pattern, formProtocol)
|
||||
}
|
||||
|
||||
func normalizePayloadFormProtocol(protocol string) string {
|
||||
protocol = strings.ToLower(strings.TrimSpace(protocol))
|
||||
switch protocol {
|
||||
case "openai-response", "openai-responses", "response":
|
||||
return "responses"
|
||||
case "gemini-cli":
|
||||
return "gemini"
|
||||
default:
|
||||
return protocol
|
||||
}
|
||||
}
|
||||
|
||||
func payloadHeadersMatch(headers http.Header, rules map[string]string) bool {
|
||||
if len(rules) == 0 {
|
||||
return true
|
||||
}
|
||||
for key, pattern := range rules {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
values := payloadHeaderValues(headers, key)
|
||||
if len(values) == 0 {
|
||||
return false
|
||||
}
|
||||
matched := false
|
||||
for _, value := range values {
|
||||
if matchModelPattern(pattern, value) {
|
||||
matched = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func payloadHeaderValues(headers http.Header, key string) []string {
|
||||
if headers == nil {
|
||||
return nil
|
||||
}
|
||||
var values []string
|
||||
for headerKey, headerValues := range headers {
|
||||
if strings.EqualFold(headerKey, key) {
|
||||
values = append(values, headerValues...)
|
||||
}
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
func payloadModelCandidates(model, requestedModel string) []string {
|
||||
model = strings.TrimSpace(model)
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
|
||||
Reference in New Issue
Block a user