feat: support disabling image generation globally

- Added `disable-image-generation` configuration flag to disable the `image_generation` tool globally.
- Updated payload handling to remove `image_generation` tools from request payload arrays when the flag is enabled.
- Modified OpenAI image handlers (`ImagesGenerations`, `ImagesEdits`) to return 404 when the feature is disabled.
- Enhanced configuration diff logging to track changes for the `disable-image-generation` flag.
- Added accompanying unit tests for the new feature in payload helpers and image handler logic.
This commit is contained in:
Luis Pater
2026-04-30 03:42:27 +08:00
parent 359ec30d0c
commit e3e60f914b
11 changed files with 284 additions and 126 deletions
+9 -3
View File
@@ -181,7 +181,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body = normalizeCodexInstructions(body)
body = ensureImageGenerationTool(body, baseModel, auth)
if e.cfg == nil || !e.cfg.DisableImageGeneration {
body = ensureImageGenerationTool(body, baseModel, auth)
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -329,7 +331,9 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.DeleteBytes(body, "stream")
body = normalizeCodexInstructions(body)
body = ensureImageGenerationTool(body, baseModel, auth)
if e.cfg == nil || !e.cfg.DisableImageGeneration {
body = ensureImageGenerationTool(body, baseModel, auth)
}
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -424,7 +428,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
body, _ = sjson.DeleteBytes(body, "stream_options")
body, _ = sjson.SetBytes(body, "model", baseModel)
body = normalizeCodexInstructions(body)
body = ensureImageGenerationTool(body, baseModel, auth)
if e.cfg == nil || !e.cfg.DisableImageGeneration {
body = ensureImageGenerationTool(body, baseModel, auth)
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
+164 -120
View File
@@ -20,133 +20,137 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
if cfg == nil || len(payload) == 0 {
return payload
}
rules := cfg.Payload
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 {
return payload
}
model = strings.TrimSpace(model)
requestedModel = strings.TrimSpace(requestedModel)
if model == "" && requestedModel == "" {
return payload
}
candidates := payloadModelCandidates(model, requestedModel)
out := payload
source := original
if len(source) == 0 {
source = payload
}
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
rules := cfg.Payload
hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0
if hasPayloadRules {
model = strings.TrimSpace(model)
requestedModel = strings.TrimSpace(requestedModel)
if model != "" || requestedModel != "" {
candidates := payloadModelCandidates(model, requestedModel)
source := original
if len(source) == 0 {
source = payload
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
appliedDefaults := make(map[string]struct{})
// Apply default rules: first write wins per field across all matching rules.
for i := range rules.Default {
rule := &rules.Default[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
// Apply default raw rules: first write wins per field across all matching rules.
for i := range rules.DefaultRaw {
rule := &rules.DefaultRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
// Apply override rules: last write wins per field across all matching rules.
for i := range rules.Override {
rule := &rules.Override[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
}
}
// Apply override raw rules: last write wins per field across all matching rules.
for i := range rules.OverrideRaw {
rule := &rules.OverrideRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
}
}
// Apply filter rules: remove matching paths from payload.
for i := range rules.Filter {
rule := &rules.Filter[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for _, path := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errDel := sjson.DeleteBytes(out, fullPath)
if errDel != nil {
continue
}
out = updated
}
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply default raw rules: first write wins per field across all matching rules.
for i := range rules.DefaultRaw {
rule := &rules.DefaultRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
if gjson.GetBytes(source, fullPath).Exists() {
continue
}
if _, ok := appliedDefaults[fullPath]; ok {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
appliedDefaults[fullPath] = struct{}{}
}
}
// Apply override rules: last write wins per field across all matching rules.
for i := range rules.Override {
rule := &rules.Override[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errSet := sjson.SetBytes(out, fullPath, value)
if errSet != nil {
continue
}
out = updated
}
}
// Apply override raw rules: last write wins per field across all matching rules.
for i := range rules.OverrideRaw {
rule := &rules.OverrideRaw[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for path, value := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
rawValue, ok := payloadRawValue(value)
if !ok {
continue
}
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
if errSet != nil {
continue
}
out = updated
}
}
// Apply filter rules: remove matching paths from payload.
for i := range rules.Filter {
rule := &rules.Filter[i]
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
continue
}
for _, path := range rule.Params {
fullPath := buildPayloadPath(root, path)
if fullPath == "" {
continue
}
updated, errDel := sjson.DeleteBytes(out, fullPath)
if errDel != nil {
continue
}
out = updated
}
if cfg.DisableImageGeneration {
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
}
return out
}
@@ -226,6 +230,46 @@ func buildPayloadPath(root, path string) string {
return r + "." + p
}
func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
if len(payload) == 0 {
return payload
}
toolType = strings.TrimSpace(toolType)
if toolType == "" {
return payload
}
toolsPath := buildPayloadPath(root, "tools")
return removeToolTypeFromToolsArray(payload, toolsPath, toolType)
}
func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte {
tools := gjson.GetBytes(payload, toolsPath)
if !tools.Exists() || !tools.IsArray() {
return payload
}
removed := false
filtered := []byte(`[]`)
for _, tool := range tools.Array() {
if tool.Get("type").String() == toolType {
removed = true
continue
}
updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw))
if errSet != nil {
continue
}
filtered = updated
}
if !removed {
return payload
}
updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered)
if errSet != nil {
return payload
}
return updated
}
func payloadRawValue(value any) ([]byte, bool) {
if value == nil {
return nil, false
@@ -0,0 +1,50 @@
package helps
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/tidwall/gjson"
)
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntry(t *testing.T) {
cfg := &config.Config{
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
}
payload := []byte(`{"tools":[{"type":"image_generation","output_format":"png"},{"type":"function","name":"f1"}]}`)
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "")
tools := gjson.GetBytes(out, "tools")
if !tools.Exists() || !tools.IsArray() {
t.Fatalf("expected tools array, got %v", tools.Type)
}
arr := tools.Array()
if len(arr) != 1 {
t.Fatalf("expected 1 tool after removal, got %d", len(arr))
}
if got := arr[0].Get("type").String(); got != "function" {
t.Fatalf("expected remaining tool type=function, got %q", got)
}
}
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWithRoot(t *testing.T) {
cfg := &config.Config{
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
}
payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}]}}`)
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "")
tools := gjson.GetBytes(out, "request.tools")
if !tools.Exists() || !tools.IsArray() {
t.Fatalf("expected request.tools array, got %v", tools.Type)
}
arr := tools.Array()
if len(arr) != 1 {
t.Fatalf("expected 1 tool after removal, got %d", len(arr))
}
if got := arr[0].Get("type").String(); got != "web_search" {
t.Fatalf("expected remaining tool type=web_search, got %q", got)
}
}