feat: support disabling image generation globally
- Added `disable-image-generation` configuration flag to disable the `image_generation` tool globally. - Updated payload handling to remove `image_generation` tools from request payload arrays when the flag is enabled. - Modified OpenAI image handlers (`ImagesGenerations`, `ImagesEdits`) to return 404 when the feature is disabled. - Enhanced configuration diff logging to track changes for the `disable-image-generation` flag. - Added accompanying unit tests for the new feature in payload helpers and image handler logic.
This commit is contained in:
@@ -181,7 +181,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -329,7 +331,9 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.DeleteBytes(body, "stream")
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
@@ -424,7 +428,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
|
||||
|
||||
@@ -20,133 +20,137 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
||||
if cfg == nil || len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
rules := cfg.Payload
|
||||
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 {
|
||||
return payload
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model == "" && requestedModel == "" {
|
||||
return payload
|
||||
}
|
||||
candidates := payloadModelCandidates(model, requestedModel)
|
||||
out := payload
|
||||
source := original
|
||||
if len(source) == 0 {
|
||||
source = payload
|
||||
}
|
||||
appliedDefaults := make(map[string]struct{})
|
||||
// Apply default rules: first write wins per field across all matching rules.
|
||||
for i := range rules.Default {
|
||||
rule := &rules.Default[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
|
||||
rules := cfg.Payload
|
||||
hasPayloadRules := len(rules.Default) != 0 || len(rules.DefaultRaw) != 0 || len(rules.Override) != 0 || len(rules.OverrideRaw) != 0 || len(rules.Filter) != 0
|
||||
if hasPayloadRules {
|
||||
model = strings.TrimSpace(model)
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if model != "" || requestedModel != "" {
|
||||
candidates := payloadModelCandidates(model, requestedModel)
|
||||
source := original
|
||||
if len(source) == 0 {
|
||||
source = payload
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
appliedDefaults := make(map[string]struct{})
|
||||
// Apply default rules: first write wins per field across all matching rules.
|
||||
for i := range rules.Default {
|
||||
rule := &rules.Default[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
// Apply default raw rules: first write wins per field across all matching rules.
|
||||
for i := range rules.DefaultRaw {
|
||||
rule := &rules.DefaultRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
// Apply override rules: last write wins per field across all matching rules.
|
||||
for i := range rules.Override {
|
||||
rule := &rules.Override[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply override raw rules: last write wins per field across all matching rules.
|
||||
for i := range rules.OverrideRaw {
|
||||
rule := &rules.OverrideRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply filter rules: remove matching paths from payload.
|
||||
for i := range rules.Filter {
|
||||
rule := &rules.Filter[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for _, path := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errDel := sjson.DeleteBytes(out, fullPath)
|
||||
if errDel != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Apply default raw rules: first write wins per field across all matching rules.
|
||||
for i := range rules.DefaultRaw {
|
||||
rule := &rules.DefaultRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
if gjson.GetBytes(source, fullPath).Exists() {
|
||||
continue
|
||||
}
|
||||
if _, ok := appliedDefaults[fullPath]; ok {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
appliedDefaults[fullPath] = struct{}{}
|
||||
}
|
||||
}
|
||||
// Apply override rules: last write wins per field across all matching rules.
|
||||
for i := range rules.Override {
|
||||
rule := &rules.Override[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetBytes(out, fullPath, value)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply override raw rules: last write wins per field across all matching rules.
|
||||
for i := range rules.OverrideRaw {
|
||||
rule := &rules.OverrideRaw[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for path, value := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
rawValue, ok := payloadRawValue(value)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue)
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
// Apply filter rules: remove matching paths from payload.
|
||||
for i := range rules.Filter {
|
||||
rule := &rules.Filter[i]
|
||||
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||
continue
|
||||
}
|
||||
for _, path := range rule.Params {
|
||||
fullPath := buildPayloadPath(root, path)
|
||||
if fullPath == "" {
|
||||
continue
|
||||
}
|
||||
updated, errDel := sjson.DeleteBytes(out, fullPath)
|
||||
if errDel != nil {
|
||||
continue
|
||||
}
|
||||
out = updated
|
||||
}
|
||||
|
||||
if cfg.DisableImageGeneration {
|
||||
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -226,6 +230,46 @@ func buildPayloadPath(root, path string) string {
|
||||
return r + "." + p
|
||||
}
|
||||
|
||||
func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType string) []byte {
|
||||
if len(payload) == 0 {
|
||||
return payload
|
||||
}
|
||||
toolType = strings.TrimSpace(toolType)
|
||||
if toolType == "" {
|
||||
return payload
|
||||
}
|
||||
toolsPath := buildPayloadPath(root, "tools")
|
||||
return removeToolTypeFromToolsArray(payload, toolsPath, toolType)
|
||||
}
|
||||
|
||||
func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte {
|
||||
tools := gjson.GetBytes(payload, toolsPath)
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return payload
|
||||
}
|
||||
removed := false
|
||||
filtered := []byte(`[]`)
|
||||
for _, tool := range tools.Array() {
|
||||
if tool.Get("type").String() == toolType {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(filtered, "-1", []byte(tool.Raw))
|
||||
if errSet != nil {
|
||||
continue
|
||||
}
|
||||
filtered = updated
|
||||
}
|
||||
if !removed {
|
||||
return payload
|
||||
}
|
||||
updated, errSet := sjson.SetRawBytes(payload, toolsPath, filtered)
|
||||
if errSet != nil {
|
||||
return payload
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
func payloadRawValue(value any) ([]byte, bool) {
|
||||
if value == nil {
|
||||
return nil, false
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntry(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
|
||||
}
|
||||
payload := []byte(`{"tools":[{"type":"image_generation","output_format":"png"},{"type":"function","name":"f1"}]}`)
|
||||
|
||||
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "")
|
||||
|
||||
tools := gjson.GetBytes(out, "tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
t.Fatalf("expected tools array, got %v", tools.Type)
|
||||
}
|
||||
arr := tools.Array()
|
||||
if len(arr) != 1 {
|
||||
t.Fatalf("expected 1 tool after removal, got %d", len(arr))
|
||||
}
|
||||
if got := arr[0].Get("type").String(); got != "function" {
|
||||
t.Fatalf("expected remaining tool type=function, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWithRoot(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
|
||||
}
|
||||
payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}]}}`)
|
||||
|
||||
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "")
|
||||
|
||||
tools := gjson.GetBytes(out, "request.tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
t.Fatalf("expected request.tools array, got %v", tools.Type)
|
||||
}
|
||||
arr := tools.Array()
|
||||
if len(arr) != 1 {
|
||||
t.Fatalf("expected 1 tool after removal, got %d", len(arr))
|
||||
}
|
||||
if got := arr[0].Get("type").String(); got != "web_search" {
|
||||
t.Fatalf("expected remaining tool type=web_search, got %q", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user