diff --git a/internal/runtime/executor/helps/payload_helpers.go b/internal/runtime/executor/helps/payload_helpers.go index b868d445..5377a8c1 100644 --- a/internal/runtime/executor/helps/payload_helpers.go +++ b/internal/runtime/executor/helps/payload_helpers.go @@ -151,6 +151,7 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string if cfg.DisableImageGeneration { out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation") + out = removeToolChoiceFromPayloadWithRoot(out, root, "image_generation") } return out } @@ -242,6 +243,55 @@ func removeToolTypeFromPayloadWithRoot(payload []byte, root string, toolType str return removeToolTypeFromToolsArray(payload, toolsPath, toolType) } +func removeToolChoiceFromPayloadWithRoot(payload []byte, root string, toolType string) []byte { + if len(payload) == 0 { + return payload + } + toolType = strings.TrimSpace(toolType) + if toolType == "" { + return payload + } + toolChoicePath := buildPayloadPath(root, "tool_choice") + return removeToolChoiceFromPayload(payload, toolChoicePath, toolType) +} + +func removeToolChoiceFromPayload(payload []byte, toolChoicePath string, toolType string) []byte { + choice := gjson.GetBytes(payload, toolChoicePath) + if !choice.Exists() { + return payload + } + if choice.Type == gjson.String { + if strings.EqualFold(strings.TrimSpace(choice.String()), toolType) { + updated, errDel := sjson.DeleteBytes(payload, toolChoicePath) + if errDel == nil { + return updated + } + } + return payload + } + if choice.Type != gjson.JSON { + return payload + } + choiceType := strings.TrimSpace(choice.Get("type").String()) + if strings.EqualFold(choiceType, toolType) { + updated, errDel := sjson.DeleteBytes(payload, toolChoicePath) + if errDel == nil { + return updated + } + return payload + } + if strings.EqualFold(choiceType, "tool") { + name := strings.TrimSpace(choice.Get("name").String()) + if strings.EqualFold(name, toolType) { + updated, errDel := sjson.DeleteBytes(payload, toolChoicePath) + if errDel == nil { + return updated + } + } + } + return payload +} + func removeToolTypeFromToolsArray(payload []byte, toolsPath string, toolType string) []byte { tools := gjson.GetBytes(payload, toolsPath) if !tools.Exists() || !tools.IsArray() { diff --git a/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go b/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go index 143393dc..ae75f450 100644 --- a/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go +++ b/internal/runtime/executor/helps/payload_helpers_disable_image_generation_test.go @@ -48,3 +48,29 @@ func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWith t.Fatalf("expected remaining tool type=web_search, got %q", got) } } + +func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolChoiceByType(t *testing.T) { + cfg := &config.Config{ + SDKConfig: config.SDKConfig{DisableImageGeneration: true}, + } + payload := []byte(`{"tools":[{"type":"image_generation"},{"type":"function","name":"f1"}],"tool_choice":{"type":"image_generation"}}`) + + out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "") + + if gjson.GetBytes(out, "tool_choice").Exists() { + t.Fatalf("expected tool_choice to be removed") + } +} + +func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolChoiceByNameWithRoot(t *testing.T) { + cfg := &config.Config{ + SDKConfig: config.SDKConfig{DisableImageGeneration: true}, + } + payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}],"tool_choice":{"type":"tool","name":"image_generation"}}}`) + + out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "") + + if gjson.GetBytes(out, "request.tool_choice").Exists() { + t.Fatalf("expected request.tool_choice to be removed") + } +}