feat: support disabling image generation globally

- Added `disable-image-generation` configuration flag to disable the `image_generation` tool globally.
- Updated payload handling to remove `image_generation` tools from request payload arrays when the flag is enabled.
- Modified OpenAI image handlers (`ImagesGenerations`, `ImagesEdits`) to return 404 when the feature is disabled.
- Enhanced configuration diff logging to track changes for the `disable-image-generation` flag.
- Added accompanying unit tests for the new feature in payload helpers and image handler logic.
This commit is contained in:
Luis Pater
2026-04-30 03:42:27 +08:00
parent 359ec30d0c
commit e3e60f914b
11 changed files with 284 additions and 126 deletions
+164 -120
View File
@@ -20,133 +20,137 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
if cfg == nil || len(payload) == 0 {
return payload
}
rules := cfg.Payload
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 {
return payload
}
model = strings.TrimSpace(model)
requestedModel = strings.TrimSpace(requestedModel)
if model == "" && requestedModel == "" {
return payload
}
candidates := payloadModelCandidates(model, requestedModel)
out := payload
source := original
if len(source) == 0 {
source = payload
}
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
rules := cfg.Payload
hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0
if hasPayloadRules {
model = strings.TrimSpace(model)
requestedModel = strings.TrimSpace(requestedModel)
if model != "" || requestedModel != "" {
candidates := payloadModelCandidates(model, requestedModel)
source := original
if len(source) == 0 {
source = payload
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
// Apply default raw rules: first write wins per field across all matching rules.
for i := range rules.DefaultRaw {
rule := &rules.DefaultRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
// Apply override rules: last write wins per field across all matching rules.
for i := range rules.Override {
rule := &rules.Override[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
}
}
// Apply override raw rules: last write wins per field across all matching rules.
for i := range rules.OverrideRaw {
rule := &rules.OverrideRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
}
}
// Apply filter rules: remove matching paths from payload.
for i := range rules.Filter {
rule := &rules.Filter[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for _, path := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errDel := sjson.DeleteBytes(out, fullPath)
if errDel != nil {
continue
}
out = updated
}
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply default raw rules: first write wins per field across all matching rules.
for i := range rules.DefaultRaw {
rule := &rules.DefaultRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply override rules: last write wins per field across all matching rules.
for i := range rules.Override {
rule := &rules.Override[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
}
}
// Apply override raw rules: last write wins per field across all matching rules.
for i := range rules.OverrideRaw {
rule := &rules.OverrideRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
}
}
// Apply filter rules: remove matching paths from payload.
for i := range rules.Filter {
rule := &rules.Filter[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for _, path := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errDel := sjson.DeleteBytes(out, fullPath)
if errDel != nil {
continue
}
out = updated
}
if cfg.DisableImageGeneration {
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
}
return out
}
@@ -226,6 +230,46 @@ func buildPayloadPath(root, path string) string {
return r + "." + p
}
func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
if len(payload) == 0 {
return payload
}
toolType = strings.TrimSpace(toolType)
if toolType == "" {
return payload
}
toolsPath := buildPayloadPath(root, "tools")
return removeToolTypeFromToolsArray(payload, toolsPath, toolType)
}
func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte {
tools := gjson.GetBytes(payload, toolsPath)
if !tools.Exists() || !tools.IsArray() {
return payload
}
removed := false
filtered := []byte(`[]`)
for _, tool := range tools.Array() {
if tool.Get("type").String() == toolType {
removed = true
continue
}
updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw))
if errSet != nil {
continue
}
filtered = updated
}
if !removed {
return payload
}
updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered)
if errSet != nil {
return payload
}
return updated
}
func payloadRawValue(value any) ([]byte, bool) {
if value == nil {
return nil, false