feat: support disabling image generation globally
- Added `disable-image-generation` configuration flag to disable the `image_generation` tool globally. - Updated payload handling to remove `image_generation` tools from request payload arrays when the flag is enabled. - Modified OpenAI image handlers (`ImagesGenerations`, `ImagesEdits`) to return 404 when the feature is disabled. - Enhanced configuration diff logging to track changes for the `disable-image-generation` flag. - Added accompanying unit tests for the new feature in payload helpers and image handler logic.
This commit is contained in:
@@ -1013,6 +1013,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
}
|
||||
|
||||
if oldCfg != nil && oldCfg.DisableImageGeneration != cfg.DisableImageGeneration {
|
||||
log.Infof("disable-image-generation updated: %t -> %t", oldCfg.DisableImageGeneration, cfg.DisableImageGeneration)
|
||||
}
|
||||
|
||||
applySignatureCacheConfig(oldCfg, cfg)
|
||||
|
||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||
|
||||
@@ -610,6 +610,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.DisableImageGeneration = false
|
||||
cfg.Pprof.Enable = false
|
||||
cfg.Pprof.Addr = DefaultPprofAddr
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
|
||||
@@ -9,6 +9,12 @@ type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// DisableImageGeneration disables the built-in image_generation tool when true.
|
||||
// When enabled, the server will avoid injecting image_generation into request payloads,
|
||||
// will remove any existing image_generation tool entries from tools arrays, and will
|
||||
// return 404 for /v1/images/generations and /v1/images/edits.
|
||||
DisableImageGeneration bool `yaml:"disable-image-generation" json:"disable-image-generation"`
|
||||
|
||||
// EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled.
|
||||
// Default is false for safety; when false, /v1internal:* requests are rejected.
|
||||
EnableGeminiCLIEndpoint bool `yaml:"enable-gemini-cli-endpoint" json:"enable-gemini-cli-endpoint"`
|
||||
|
||||
@@ -181,7 +181,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -329,7 +331,9 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.DeleteBytes(body, "stream")
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -424,7 +428,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
|
||||
@@ -20,133 +20,137 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
if cfg == nil || len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
rules := cfg.Payload
|
||||
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 {
|
||||
return payload
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model == "" && requestedModel == "" {
|
||||
return payload
|
||||
}
|
||||
candidates := payloadModelCandidates(model, requestedModel)
|
||||
out := payload
|
||||
source := original
|
||||
if len(source) == 0 {
|
||||
source = payload
|
||||
}
|
||||
appliedDefaults := make(map[string]struct{})
|
||||
// Apply default rules: first write wins per field across all matching rules.
|
||||
for i := range rules.Default {
|
||||
rule := &rules.Default[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
|
||||
rules := cfg.Payload
|
||||
hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0
|
||||
if hasPayloadRules {
|
||||
model = strings.TrimSpace(model)
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model != "" || requestedModel != "" {
|
||||
candidates := payloadModelCandidates(model, requestedModel)
|
||||
source := original
|
||||
if len(source) == 0 {
|
||||
source = payload
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
appliedDefaults := make(map[string]struct{})
|
||||
// Apply default rules: first write wins per field across all matching rules.
|
||||
for i := range rules.Default {
|
||||
rule := &rules.Default[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
// Apply default raw rules: first write wins per field across all matching rules.
|
||||
for i := range rules.DefaultRaw {
|
||||
rule := &rules.DefaultRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
// Apply override rules: last write wins per field across all matching rules.
|
||||
for i := range rules.Override {
|
||||
rule := &rules.Override[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply override raw rules: last write wins per field across all matching rules.
|
||||
for i := range rules.OverrideRaw {
|
||||
rule := &rules.OverrideRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply filter rules: remove matching paths from payload.
|
||||
for i := range rules.Filter {
|
||||
rule := &rules.Filter[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for _, path := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errDel := sjson.DeleteBytes(out, fullPath)
|
||||
if errDel != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Apply default raw rules: first write wins per field across all matching rules.
|
||||
for i := range rules.DefaultRaw {
|
||||
rule := &rules.DefaultRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Apply override rules: last write wins per field across all matching rules.
|
||||
for i := range rules.Override {
|
||||
rule := &rules.Override[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply override raw rules: last write wins per field across all matching rules.
|
||||
for i := range rules.OverrideRaw {
|
||||
rule := &rules.OverrideRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply filter rules: remove matching paths from payload.
|
||||
for i := range rules.Filter {
|
||||
rule := &rules.Filter[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for _, path := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errDel := sjson.DeleteBytes(out, fullPath)
|
||||
if errDel != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
|
||||
if cfg.DisableImageGeneration {
|
||||
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -226,6 +230,46 @@ func buildPayloadPath(root, path string) string {
|
||||
return r + "." + p
|
||||
}
|
||||
|
||||
func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
|
||||
if len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
toolType = strings.TrimSpace(toolType)
|
||||
if toolType == "" {
|
||||
return payload
|
||||
}
|
||||
toolsPath := buildPayloadPath(root, "tools")
|
||||
return removeToolTypeFromToolsArray(payload, toolsPath, toolType)
|
||||
}
|
||||
|
||||
func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte {
|
||||
tools := gjson.GetBytes(payload, toolsPath)
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return payload
|
||||
}
|
||||
removed := false
|
||||
filtered := []byte(`[]`)
|
||||
for _, tool := range tools.Array() {
|
||||
if tool.Get("type").String() == toolType {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw))
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
filtered = updated
|
||||
}
|
||||
if !removed {
|
||||
return payload
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered)
|
||||
if errSet != nil {
|
||||
return payload
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
func payloadRawValue(value any) ([]byte, bool) {
|
||||
if value == nil {
|
||||
return nil, false
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntry(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
|
||||
}
|
||||
payload := []byte(`{"tools":[{"type":"image_generation","output_format":"png"},{"type":"function","name":"f1"}]}`)
|
||||
|
||||
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "")
|
||||
|
||||
tools := gjson.GetBytes(out, "tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
t.Fatalf("expected tools array, got %v", tools.Type)
|
||||
}
|
||||
arr := tools.Array()
|
||||
if len(arr) != 1 {
|
||||
t.Fatalf("expected 1 tool after removal, got %d", len(arr))
|
||||
}
|
||||
if got := arr[0].Get("type").String(); got != "function" {
|
||||
t.Fatalf("expected remaining tool type=function, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWithRoot(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
|
||||
}
|
||||
payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}]}}`)
|
||||
|
||||
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "")
|
||||
|
||||
tools := gjson.GetBytes(out, "request.tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
t.Fatalf("expected request.tools array, got %v", tools.Type)
|
||||
}
|
||||
arr := tools.Array()
|
||||
if len(arr) != 1 {
|
||||
t.Fatalf("expected 1 tool after removal, got %d", len(arr))
|
||||
}
|
||||
if got := arr[0].Get("type").String(); got != "web_search" {
|
||||
t.Fatalf("expected remaining tool type=web_search, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -42,6 +42,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.DisableCooling != newCfg.DisableCooling {
|
||||
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
|
||||
}
|
||||
if oldCfg.DisableImageGeneration != newCfg.DisableImageGeneration {
|
||||
changes = append(changes, fmt.Sprintf("disable-image-generation: %t -> %t", oldCfg.DisableImageGeneration, newCfg.DisableImageGeneration))
|
||||
}
|
||||
if oldCfg.RequestLog != newCfg.RequestLog {
|
||||
changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
|
||||
}
|
||||
|
||||
@@ -279,6 +279,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
APIKeys: []string{" key-1 ", "key-2"},
|
||||
ForceModelPrefix: true,
|
||||
NonStreamKeepAliveInterval: 5,
|
||||
DisableImageGeneration: true,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -287,6 +288,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
expectContains(t, details, "logging-to-file: false -> true")
|
||||
expectContains(t, details, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, details, "disable-cooling: false -> true")
|
||||
expectContains(t, details, "disable-image-generation: false -> true")
|
||||
expectContains(t, details, "request-log: false -> true")
|
||||
expectContains(t, details, "request-retry: 1 -> 2")
|
||||
expectContains(t, details, "max-retry-credentials: 1 -> 3")
|
||||
@@ -403,9 +405,10 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{"keyB"},
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{"keyB"},
|
||||
DisableImageGeneration: true,
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
@@ -431,6 +434,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
expectContains(t, changes, "logging-to-file: false -> true")
|
||||
expectContains(t, changes, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, changes, "disable-cooling: false -> true")
|
||||
expectContains(t, changes, "disable-image-generation: false -> true")
|
||||
expectContains(t, changes, "request-retry: 1 -> 2")
|
||||
expectContains(t, changes, "max-retry-credentials: 1 -> 3")
|
||||
expectContains(t, changes, "max-retry-interval: 1 -> 3")
|
||||
|
||||
Reference in New Issue
Block a user