package helps import ( "encoding/json" "strings" "github.com/router-for-me/CLIProxyAPI/v7/internal/config" "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) // ApplyPayloadConfigWithRoot behaves like applyPayloadConfig but treats all parameter // paths as relative to the provided root path (for example, "request" for Gemini CLI) // and restricts matches to the given protocol when supplied. Defaults are checked // against the original payload when provided. requestedModel carries the client-visible // model name before alias resolution so payload rules can target aliases precisely. // requestPath is the inbound HTTP request path (when available) used for endpoint-scoped gates. func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string, requestPath string) []byte { if cfg == nil || len(payload) == 0 { return payload } out := payload // Apply disable-image-generation filtering before payload rules so config payload // overrides can explicitly re-enable image_generation when desired. if cfg.DisableImageGeneration != config.DisableImageGenerationOff { if cfg.DisableImageGeneration != config.DisableImageGenerationChat || !isImagesEndpointRequestPath(requestPath) { out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation") out = removeToolChoiceFromPayloadWithRoot(out, root, "image_generation") } } rules := cfg.Payload hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0 if hasPayloadRules { model = strings.TrimSpace(model) requestedModel = strings.TrimSpace(requestedModel) if model != "" || requestedModel != "" { candidates := payloadModelCandidates(model, requestedModel) source := original if len(source) == 0 { source = payload } appliedDefaults := make(map[string]struct{}) // Apply default rules: first write wins per field across all matching rules. for i := range rules.Default { rule := &rules.Default[i] if !payloadModelRulesMatch(rule.Models, protocol, candidates) { continue } for path, value := range rule.Params { fullPath := buildPayloadPath(root, path) if fullPath == "" { continue } if gjson.GetBytes(source, fullPath).Exists() { continue } if _, ok := appliedDefaults[fullPath]; ok { continue } updated, errSet := sjson.SetBytes(out, fullPath, value) if errSet != nil { continue } out = updated appliedDefaults[fullPath] = struct{}{} } } // Apply default raw rules: first write wins per field across all matching rules. for i := range rules.DefaultRaw { rule := &rules.DefaultRaw[i] if !payloadModelRulesMatch(rule.Models, protocol, candidates) { continue } for path, value := range rule.Params { fullPath := buildPayloadPath(root, path) if fullPath == "" { continue } if gjson.GetBytes(source, fullPath).Exists() { continue } if _, ok := appliedDefaults[fullPath]; ok { continue } rawValue, ok := payloadRawValue(value) if !ok { continue } updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue) if errSet != nil { continue } out = updated appliedDefaults[fullPath] = struct{}{} } } // Apply override rules: last write wins per field across all matching rules. for i := range rules.Override { rule := &rules.Override[i] if !payloadModelRulesMatch(rule.Models, protocol, candidates) { continue } for path, value := range rule.Params { fullPath := buildPayloadPath(root, path) if fullPath == "" { continue } updated, errSet := sjson.SetBytes(out, fullPath, value) if errSet != nil { continue } out = updated } } // Apply override raw rules: last write wins per field across all matching rules. for i := range rules.OverrideRaw { rule := &rules.OverrideRaw[i] if !payloadModelRulesMatch(rule.Models, protocol, candidates) { continue } for path, value := range rule.Params { fullPath := buildPayloadPath(root, path) if fullPath == "" { continue } rawValue, ok := payloadRawValue(value) if !ok { continue } updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue) if errSet != nil { continue } out = updated } } // Apply filter rules: remove matching paths from payload. for i := range rules.Filter { rule := &rules.Filter[i] if !payloadModelRulesMatch(rule.Models, protocol, candidates) { continue } for _, path := range rule.Params { fullPath := buildPayloadPath(root, path) if fullPath == "" { continue } updated, errDel := sjson.DeleteBytes(out, fullPath) if errDel != nil { continue } out = updated } } } } return out } func isImagesEndpointRequestPath(path string) bool { path = strings.TrimSpace(path) if path == "" { return false } if path == "/v1/images/generations" || path == "/v1/images/edits" { return true } // Be tolerant of prefix routers that may report a longer matched route. if strings.HasSuffix(path, "/v1/images/generations") || strings.HasSuffix(path, "/v1/images/edits") { return true } if strings.HasSuffix(path, "/images/generations") || strings.HasSuffix(path, "/images/edits") { return true } return false } func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, models []string) bool { if len(rules) == 0 || len(models) == 0 { return false } for _, model := range models { for _, entry := range rules { name := strings.TrimSpace(entry.Name) if name == "" { continue } if ep := strings.TrimSpace(entry.Protocol); ep != "" && protocol != "" && !strings.EqualFold(ep, protocol) { continue } if matchModelPattern(name, model) { return true } } } return false } func payloadModelCandidates(model, requestedModel string) []string { model = strings.TrimSpace(model) requestedModel = strings.TrimSpace(requestedModel) if model == "" && requestedModel == "" { return nil } candidates := make([]string, 0, 3) seen := make(map[string]struct{}, 3) addCandidate := func(value string) { value = strings.TrimSpace(value) if value == "" { return } key := strings.ToLower(value) if _, ok := seen[key]; ok { return } seen[key] = struct{}{} candidates = append(candidates, value) } if model != "" { addCandidate(model) } if requestedModel != "" { parsed := thinking.ParseSuffix(requestedModel) base := strings.TrimSpace(parsed.ModelName) if base != "" { addCandidate(base) } if parsed.HasSuffix { addCandidate(requestedModel) } } return candidates } // buildPayloadPath combines an optional root path with a relative parameter path. // When root is empty, the parameter path is used as-is. When root is non-empty, // the parameter path is treated as relative to root. func buildPayloadPath(root, path string) string { r := strings.TrimSpace(root) p := strings.TrimSpace(path) if r == "" { return p } if p == "" { return r } if strings.HasPrefix(p, ".") { p = p[1:] } return r + "." + p } func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte { if len(payload) == 0 { return payload } toolType = strings.TrimSpace(toolType) if toolType == "" { return payload } toolsPath := buildPayloadPath(root, "tools") return removeToolTypeFromToolsArray(payload, toolsPath, toolType) } func removeToolChoiceFromPayloadWithRoot(payload []byte, root string, toolType string) []byte { if len(payload) == 0 { return payload } toolType = strings.TrimSpace(toolType) if toolType == "" { return payload } toolChoicePath := buildPayloadPath(root, "tool_choice") return removeToolChoiceFromPayload(payload, toolChoicePath, toolType) } func removeToolChoiceFromPayload(payload []byte, toolChoicePath string, toolType string) []byte { choice := gjson.GetBytes(payload, toolChoicePath) if !choice.Exists() { return payload } if choice.Type == gjson.String { if strings.EqualFold(strings.TrimSpace(choice.String()), toolType) { updated, errDel := sjson.DeleteBytes(payload, toolChoicePath) if errDel == nil { return updated } } return payload } if choice.Type != gjson.JSON { return payload } choiceType := strings.TrimSpace(choice.Get("type").String()) if strings.EqualFold(choiceType, toolType) { updated, errDel := sjson.DeleteBytes(payload, toolChoicePath) if errDel == nil { return updated } return payload } if strings.EqualFold(choiceType, "tool") { name := strings.TrimSpace(choice.Get("name").String()) if strings.EqualFold(name, toolType) { updated, errDel := sjson.DeleteBytes(payload, toolChoicePath) if errDel == nil { return updated } } } return payload } func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte { tools := gjson.GetBytes(payload, toolsPath) if !tools.Exists() || !tools.IsArray() { return payload } removed := false filtered := []byte(`[]`) for _, tool := range tools.Array() { if tool.Get("type").String() == toolType { removed = true continue } updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw)) if errSet != nil { continue } filtered = updated } if !removed { return payload } updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered) if errSet != nil { return payload } return updated } func payloadRawValue(value any) ([]byte, bool) { if value == nil { return nil, false } switch typed := value.(type) { case string: return []byte(typed), true case []byte: return typed, true default: raw, errMarshal := json.Marshal(typed) if errMarshal != nil { return nil, false } return raw, true } } func PayloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string { fallback = strings.TrimSpace(fallback) if len(opts.Metadata) == 0 { return fallback } raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey] if !ok || raw == nil { return fallback } switch v := raw.(type) { case string: if strings.TrimSpace(v) == "" { return fallback } return strings.TrimSpace(v) case []byte: if len(v) == 0 { return fallback } trimmed := strings.TrimSpace(string(v)) if trimmed == "" { return fallback } return trimmed default: return fallback } } func PayloadRequestPath(opts cliproxyexecutor.Options) string { if len(opts.Metadata) == 0 { return "" } raw, ok := opts.Metadata[cliproxyexecutor.RequestPathMetadataKey] if !ok || raw == nil { return "" } switch v := raw.(type) { case string: return strings.TrimSpace(v) case []byte: return strings.TrimSpace(string(v)) default: return "" } } // matchModelPattern performs simple wildcard matching where '*' matches zero or more characters. // Examples: // // "*-5" matches "gpt-5" // "gpt-*" matches "gpt-5" and "gpt-4" // "gemini-*-pro" matches "gemini-2.5-pro" and "gemini-3-pro". func matchModelPattern(pattern, model string) bool { pattern = strings.TrimSpace(pattern) model = strings.TrimSpace(model) if pattern == "" { return false } if pattern == "*" { return true } // Iterative glob-style matcher supporting only '*' wildcard. pi, si := 0, 0 starIdx := -1 matchIdx := 0 for si < len(model) { if pi < len(pattern) && (pattern[pi] == model[si]) { pi++ si++ continue } if pi < len(pattern) && pattern[pi] == '*' { starIdx = pi matchIdx = si pi++ continue } if starIdx != -1 { pi = starIdx + 1 matchIdx++ si = matchIdx continue } return false } for pi < len(pattern) && pattern[pi] == '*' { pi++ } return pi == len(pattern) }