feat: add tri-state support for disable-image-generation configuration
- Introduced `DisableImageGenerationMode` with support for `false`, `true`, and `chat` values. - Updated payload handling to preserve `image_generation` on images endpoints when `chat` mode is enabled. - Modified OpenAI image handlers (`ImagesGenerations`, `ImagesEdits`) to respect tri-state logic. - Added unit tests for `DisableImageGenerationMode` behavior and endpoint-specific handling. - Enhanced configuration diff logging to support `DisableImageGenerationMode`.
This commit is contained in:
+3
-2
@@ -90,8 +90,9 @@ max-retry-interval: 30
|
|||||||
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
|
# When true, disable auth/model cooldown scheduling globally (prevents blackout windows after failure states).
|
||||||
disable-cooling: false
|
disable-cooling: false
|
||||||
|
|
||||||
# When true, disable the built-in image_generation tool globally.
|
# disable-image-generation supports: false (default), true, or "chat".
|
||||||
# The server will stop injecting image_generation and will also remove it from request payload tools arrays.
|
# - true: disable image_generation everywhere (also returns 404 for /v1/images/generations and /v1/images/edits).
|
||||||
|
# - "chat": disable image_generation injection on non-images endpoints, but keep /v1/images/generations and /v1/images/edits enabled.
|
||||||
disable-image-generation: false
|
disable-image-generation: false
|
||||||
|
|
||||||
# Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh).
|
# Core auth auto-refresh worker pool size (OAuth/file-based auth token refresh).
|
||||||
|
|||||||
@@ -1014,7 +1014,7 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if oldCfg != nil && oldCfg.DisableImageGeneration != cfg.DisableImageGeneration {
|
if oldCfg != nil && oldCfg.DisableImageGeneration != cfg.DisableImageGeneration {
|
||||||
log.Infof("disable-image-generation updated: %t -> %t", oldCfg.DisableImageGeneration, cfg.DisableImageGeneration)
|
log.Infof("disable-image-generation updated: %v -> %v", oldCfg.DisableImageGeneration, cfg.DisableImageGeneration)
|
||||||
}
|
}
|
||||||
|
|
||||||
applySignatureCacheConfig(oldCfg, cfg)
|
applySignatureCacheConfig(oldCfg, cfg)
|
||||||
|
|||||||
@@ -610,7 +610,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
cfg.ErrorLogsMaxFiles = 10
|
cfg.ErrorLogsMaxFiles = 10
|
||||||
cfg.UsageStatisticsEnabled = false
|
cfg.UsageStatisticsEnabled = false
|
||||||
cfg.DisableCooling = false
|
cfg.DisableCooling = false
|
||||||
cfg.DisableImageGeneration = false
|
cfg.DisableImageGeneration = DisableImageGenerationOff
|
||||||
cfg.Pprof.Enable = false
|
cfg.Pprof.Enable = false
|
||||||
cfg.Pprof.Addr = DefaultPprofAddr
|
cfg.Pprof.Addr = DefaultPprofAddr
|
||||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||||
|
|||||||
@@ -0,0 +1,136 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DisableImageGenerationMode is a tri-state config value for disable-image-generation.
|
||||||
|
//
|
||||||
|
// It supports:
|
||||||
|
// - false: enabled
|
||||||
|
// - true: disabled everywhere (including /v1/images/* endpoints)
|
||||||
|
// - "chat": disabled for all non-images endpoints, but enabled for /v1/images/generations and /v1/images/edits
|
||||||
|
type DisableImageGenerationMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
DisableImageGenerationOff DisableImageGenerationMode = iota
|
||||||
|
DisableImageGenerationAll
|
||||||
|
DisableImageGenerationChat
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m DisableImageGenerationMode) String() string {
|
||||||
|
switch m {
|
||||||
|
case DisableImageGenerationOff:
|
||||||
|
return "false"
|
||||||
|
case DisableImageGenerationAll:
|
||||||
|
return "true"
|
||||||
|
case DisableImageGenerationChat:
|
||||||
|
return "chat"
|
||||||
|
default:
|
||||||
|
return "false"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m DisableImageGenerationMode) MarshalYAML() (any, error) {
|
||||||
|
switch m {
|
||||||
|
case DisableImageGenerationAll:
|
||||||
|
return true, nil
|
||||||
|
case DisableImageGenerationChat:
|
||||||
|
return "chat", nil
|
||||||
|
default:
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DisableImageGenerationMode) UnmarshalYAML(value *yaml.Node) error {
|
||||||
|
mode, err := parseDisableImageGenerationNode(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*m = mode
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m DisableImageGenerationMode) MarshalJSON() ([]byte, error) {
|
||||||
|
switch m {
|
||||||
|
case DisableImageGenerationAll:
|
||||||
|
return []byte("true"), nil
|
||||||
|
case DisableImageGenerationChat:
|
||||||
|
return json.Marshal("chat")
|
||||||
|
default:
|
||||||
|
return []byte("false"), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DisableImageGenerationMode) UnmarshalJSON(data []byte) error {
|
||||||
|
mode, err := parseDisableImageGenerationJSON(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*m = mode
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDisableImageGenerationNode(value *yaml.Node) (DisableImageGenerationMode, error) {
|
||||||
|
if value == nil {
|
||||||
|
return DisableImageGenerationOff, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// First try a typed bool decode (covers unquoted true/false and YAML 1.1 bools).
|
||||||
|
var b bool
|
||||||
|
if err := value.Decode(&b); err == nil && value.Kind == yaml.ScalarNode && value.ShortTag() == "!!bool" {
|
||||||
|
if b {
|
||||||
|
return DisableImageGenerationAll, nil
|
||||||
|
}
|
||||||
|
return DisableImageGenerationOff, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to string decoding (covers quoted "true"/"false" and "chat").
|
||||||
|
var s string
|
||||||
|
if err := value.Decode(&s); err != nil {
|
||||||
|
return DisableImageGenerationOff, fmt.Errorf("invalid disable-image-generation value")
|
||||||
|
}
|
||||||
|
return parseDisableImageGenerationString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDisableImageGenerationJSON(data []byte) (DisableImageGenerationMode, error) {
|
||||||
|
trimmed := bytes.TrimSpace(data)
|
||||||
|
if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) {
|
||||||
|
return DisableImageGenerationOff, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// bool
|
||||||
|
var b bool
|
||||||
|
if err := json.Unmarshal(trimmed, &b); err == nil {
|
||||||
|
if b {
|
||||||
|
return DisableImageGenerationAll, nil
|
||||||
|
}
|
||||||
|
return DisableImageGenerationOff, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// string
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(trimmed, &s); err != nil {
|
||||||
|
return DisableImageGenerationOff, fmt.Errorf("invalid disable-image-generation value")
|
||||||
|
}
|
||||||
|
return parseDisableImageGenerationString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDisableImageGenerationString(s string) (DisableImageGenerationMode, error) {
|
||||||
|
s = strings.TrimSpace(strings.ToLower(s))
|
||||||
|
switch s {
|
||||||
|
case "", "false", "0", "off", "no":
|
||||||
|
return DisableImageGenerationOff, nil
|
||||||
|
case "true", "1", "on", "yes":
|
||||||
|
return DisableImageGenerationAll, nil
|
||||||
|
case "chat":
|
||||||
|
return DisableImageGenerationChat, nil
|
||||||
|
default:
|
||||||
|
return DisableImageGenerationOff, fmt.Errorf("invalid disable-image-generation value %q (allowed: true, false, chat)", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDisableImageGenerationMode_UnmarshalYAML(t *testing.T) {
|
||||||
|
type wrapper struct {
|
||||||
|
V DisableImageGenerationMode `yaml:"disable-image-generation"`
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
var w wrapper
|
||||||
|
if err := yaml.Unmarshal([]byte("disable-image-generation: false\n"), &w); err != nil {
|
||||||
|
t.Fatalf("unmarshal false: %v", err)
|
||||||
|
}
|
||||||
|
if w.V != DisableImageGenerationOff {
|
||||||
|
t.Fatalf("false => %v, want %v", w.V, DisableImageGenerationOff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
var w wrapper
|
||||||
|
if err := yaml.Unmarshal([]byte("disable-image-generation: true\n"), &w); err != nil {
|
||||||
|
t.Fatalf("unmarshal true: %v", err)
|
||||||
|
}
|
||||||
|
if w.V != DisableImageGenerationAll {
|
||||||
|
t.Fatalf("true => %v, want %v", w.V, DisableImageGenerationAll)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
var w wrapper
|
||||||
|
if err := yaml.Unmarshal([]byte("disable-image-generation: chat\n"), &w); err != nil {
|
||||||
|
t.Fatalf("unmarshal chat: %v", err)
|
||||||
|
}
|
||||||
|
if w.V != DisableImageGenerationChat {
|
||||||
|
t.Fatalf("chat => %v, want %v", w.V, DisableImageGenerationChat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDisableImageGenerationMode_UnmarshalJSON(t *testing.T) {
|
||||||
|
{
|
||||||
|
var v DisableImageGenerationMode
|
||||||
|
if err := json.Unmarshal([]byte("false"), &v); err != nil {
|
||||||
|
t.Fatalf("unmarshal false: %v", err)
|
||||||
|
}
|
||||||
|
if v != DisableImageGenerationOff {
|
||||||
|
t.Fatalf("false => %v, want %v", v, DisableImageGenerationOff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
var v DisableImageGenerationMode
|
||||||
|
if err := json.Unmarshal([]byte("true"), &v); err != nil {
|
||||||
|
t.Fatalf("unmarshal true: %v", err)
|
||||||
|
}
|
||||||
|
if v != DisableImageGenerationAll {
|
||||||
|
t.Fatalf("true => %v, want %v", v, DisableImageGenerationAll)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
var v DisableImageGenerationMode
|
||||||
|
if err := json.Unmarshal([]byte(`"chat"`), &v); err != nil {
|
||||||
|
t.Fatalf("unmarshal chat: %v", err)
|
||||||
|
}
|
||||||
|
if v != DisableImageGenerationChat {
|
||||||
|
t.Fatalf("chat => %v, want %v", v, DisableImageGenerationChat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,11 +9,15 @@ type SDKConfig struct {
|
|||||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||||
|
|
||||||
// DisableImageGeneration disables the built-in image_generation tool when true.
|
// DisableImageGeneration controls whether the built-in image_generation tool is injected/allowed.
|
||||||
// When enabled, the server will avoid injecting image_generation into request payloads,
|
//
|
||||||
// will remove any existing image_generation tool entries from tools arrays, and will
|
// Supported values:
|
||||||
// return 404 for /v1/images/generations and /v1/images/edits.
|
// - false (default): image_generation is enabled everywhere (normal behavior).
|
||||||
DisableImageGeneration bool `yaml:"disable-image-generation" json:"disable-image-generation"`
|
// - true: image_generation is disabled everywhere. The server stops injecting it, removes it from request payloads,
|
||||||
|
// and returns 404 for /v1/images/generations and /v1/images/edits.
|
||||||
|
// - "chat": disable image_generation injection for all non-images endpoints (e.g. /v1/responses, /v1/chat/completions),
|
||||||
|
// while keeping /v1/images/generations and /v1/images/edits enabled and preserving image_generation there.
|
||||||
|
DisableImageGeneration DisableImageGenerationMode `yaml:"disable-image-generation" json:"disable-image-generation"`
|
||||||
|
|
||||||
// EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled.
|
// EnableGeminiCLIEndpoint controls whether Gemini CLI internal endpoints (/v1internal:*) are enabled.
|
||||||
// Default is false for safety; when false, /v1internal:* requests are rejected.
|
// Default is false for safety; when false, /v1internal:* requests are rejected.
|
||||||
|
|||||||
@@ -428,7 +428,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
|||||||
}
|
}
|
||||||
payload = fixGeminiImageAspectRatio(baseModel, payload)
|
payload = fixGeminiImageAspectRatio(baseModel, payload)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
payload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
payload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", payload, originalTranslated, requestedModel, requestPath)
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
||||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
||||||
|
|||||||
@@ -521,7 +521,8 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel, requestPath)
|
||||||
|
|
||||||
useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg)
|
useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg)
|
||||||
|
|
||||||
@@ -718,7 +719,8 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel, requestPath)
|
||||||
|
|
||||||
useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg)
|
useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg)
|
||||||
|
|
||||||
@@ -1178,7 +1180,8 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel, requestPath)
|
||||||
|
|
||||||
useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg)
|
useCredits := cliproxyauth.AntigravityCreditsRequested(ctx) && antigravityCreditsRetryEnabled(e.cfg)
|
||||||
|
|
||||||
|
|||||||
@@ -164,7 +164,8 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body = ensureModelMaxTokens(body, baseModel)
|
body = ensureModelMaxTokens(body, baseModel)
|
||||||
|
|
||||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||||
@@ -349,7 +350,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body = ensureModelMaxTokens(body, baseModel)
|
body = ensureModelMaxTokens(body, baseModel)
|
||||||
|
|
||||||
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
|
||||||
|
|||||||
@@ -173,7 +173,8 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
body, _ = sjson.SetBytes(body, "stream", true)
|
body, _ = sjson.SetBytes(body, "stream", true)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
@@ -181,7 +182,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
|||||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||||
body = normalizeCodexInstructions(body)
|
body = normalizeCodexInstructions(body)
|
||||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
if e.cfg == nil || e.cfg.DisableImageGeneration == config.DisableImageGenerationOff {
|
||||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -327,11 +328,12 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
body, _ = sjson.DeleteBytes(body, "stream")
|
body, _ = sjson.DeleteBytes(body, "stream")
|
||||||
body = normalizeCodexInstructions(body)
|
body = normalizeCodexInstructions(body)
|
||||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
if e.cfg == nil || e.cfg.DisableImageGeneration == config.DisableImageGenerationOff {
|
||||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -421,14 +423,15 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
body = normalizeCodexInstructions(body)
|
body = normalizeCodexInstructions(body)
|
||||||
if e.cfg == nil || !e.cfg.DisableImageGeneration {
|
if e.cfg == nil || e.cfg.DisableImageGeneration == config.DisableImageGenerationOff {
|
||||||
body = ensureImageGenerationTool(body, baseModel, auth)
|
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -184,14 +184,16 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
body, _ = sjson.SetBytes(body, "stream", true)
|
body, _ = sjson.SetBytes(body, "stream", true)
|
||||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
if !gjson.GetBytes(body, "instructions").Exists() {
|
body = normalizeCodexInstructions(body)
|
||||||
body, _ = sjson.SetBytes(body, "instructions", "")
|
if e.cfg == nil || e.cfg.DisableImageGeneration == config.DisableImageGenerationOff {
|
||||||
|
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
|
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||||
@@ -387,7 +389,12 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel, requestPath)
|
||||||
|
body = normalizeCodexInstructions(body)
|
||||||
|
if e.cfg == nil || e.cfg.DisableImageGeneration == config.DisableImageGenerationOff {
|
||||||
|
body = ensureImageGenerationTool(body, baseModel, auth)
|
||||||
|
}
|
||||||
|
|
||||||
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
|
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||||
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
|
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
|
||||||
|
|||||||
@@ -139,7 +139,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
|
|
||||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel, requestPath)
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
@@ -294,7 +295,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
|
|
||||||
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
basePayload = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel, requestPath)
|
||||||
|
|
||||||
projectID := resolveGeminiProjectID(auth)
|
projectID := resolveGeminiProjectID(auth)
|
||||||
|
|
||||||
|
|||||||
@@ -132,7 +132,8 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := "generateContent"
|
action := "generateContent"
|
||||||
@@ -239,7 +240,8 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
baseURL := resolveGeminiBaseURL(auth)
|
baseURL := resolveGeminiBaseURL(auth)
|
||||||
|
|||||||
@@ -335,7 +335,8 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -455,7 +456,8 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := getVertexAction(baseModel, false)
|
action := getVertexAction(baseModel, false)
|
||||||
@@ -565,7 +567,8 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := getVertexAction(baseModel, true)
|
action := getVertexAction(baseModel, true)
|
||||||
@@ -694,7 +697,8 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := getVertexAction(baseModel, true)
|
action := getVertexAction(baseModel, true)
|
||||||
|
|||||||
@@ -16,7 +16,8 @@ import (
|
|||||||
// and restricts matches to the given protocol when supplied. Defaults are checked
|
// and restricts matches to the given protocol when supplied. Defaults are checked
|
||||||
// against the original payload when provided. requestedModel carries the client-visible
|
// against the original payload when provided. requestedModel carries the client-visible
|
||||||
// model name before alias resolution so payload rules can target aliases precisely.
|
// model name before alias resolution so payload rules can target aliases precisely.
|
||||||
func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
|
// requestPath is the inbound HTTP request path (when available) used for endpoint-scoped gates.
|
||||||
|
func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string, requestPath string) []byte {
|
||||||
if cfg == nil || len(payload) == 0 {
|
if cfg == nil || len(payload) == 0 {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
@@ -149,13 +150,34 @@ func ApplyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.DisableImageGeneration {
|
if cfg.DisableImageGeneration != config.DisableImageGenerationOff {
|
||||||
|
if cfg.DisableImageGeneration == config.DisableImageGenerationChat && isImagesEndpointRequestPath(requestPath) {
|
||||||
|
return out
|
||||||
|
}
|
||||||
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
|
out = removeToolTypeFromPayloadWithRoot(out, root, "image_generation")
|
||||||
out = removeToolChoiceFromPayloadWithRoot(out, root, "image_generation")
|
out = removeToolChoiceFromPayloadWithRoot(out, root, "image_generation")
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isImagesEndpointRequestPath(path string) bool {
|
||||||
|
path = strings.TrimSpace(path)
|
||||||
|
if path == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if path == "/v1/images/generations" || path == "/v1/images/edits" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
// Be tolerant of prefix routers that may report a longer matched route.
|
||||||
|
if strings.HasSuffix(path, "/v1/images/generations") || strings.HasSuffix(path, "/v1/images/edits") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(path, "/images/generations") || strings.HasSuffix(path, "/images/edits") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, models []string) bool {
|
func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, models []string) bool {
|
||||||
if len(rules) == 0 || len(models) == 0 {
|
if len(rules) == 0 || len(models) == 0 {
|
||||||
return false
|
return false
|
||||||
@@ -367,6 +389,24 @@ func PayloadRequestedModel(opts cliproxyexecutor.Options, fallback string) strin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func PayloadRequestPath(opts cliproxyexecutor.Options) string {
|
||||||
|
if len(opts.Metadata) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw, ok := opts.Metadata[cliproxyexecutor.RequestPathMetadataKey]
|
||||||
|
if !ok || raw == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(v)
|
||||||
|
case []byte:
|
||||||
|
return strings.TrimSpace(string(v))
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
|
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
|
||||||
// Examples:
|
// Examples:
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ import (
|
|||||||
|
|
||||||
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntry(t *testing.T) {
|
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntry(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
|
SDKConfig: config.SDKConfig{DisableImageGeneration: config.DisableImageGenerationAll},
|
||||||
}
|
}
|
||||||
payload := []byte(`{"tools":[{"type":"image_generation","output_format":"png"},{"type":"function","name":"f1"}]}`)
|
payload := []byte(`{"tools":[{"type":"image_generation","output_format":"png"},{"type":"function","name":"f1"}]}`)
|
||||||
|
|
||||||
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "")
|
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "", "")
|
||||||
|
|
||||||
tools := gjson.GetBytes(out, "tools")
|
tools := gjson.GetBytes(out, "tools")
|
||||||
if !tools.Exists() || !tools.IsArray() {
|
if !tools.Exists() || !tools.IsArray() {
|
||||||
@@ -30,11 +30,11 @@ func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntry(t *
|
|||||||
|
|
||||||
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWithRoot(t *testing.T) {
|
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWithRoot(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
|
SDKConfig: config.SDKConfig{DisableImageGeneration: config.DisableImageGenerationAll},
|
||||||
}
|
}
|
||||||
payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}]}}`)
|
payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}]}}`)
|
||||||
|
|
||||||
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "")
|
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "", "")
|
||||||
|
|
||||||
tools := gjson.GetBytes(out, "request.tools")
|
tools := gjson.GetBytes(out, "request.tools")
|
||||||
if !tools.Exists() || !tools.IsArray() {
|
if !tools.Exists() || !tools.IsArray() {
|
||||||
@@ -51,11 +51,11 @@ func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolsEntryWith
|
|||||||
|
|
||||||
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolChoiceByType(t *testing.T) {
|
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolChoiceByType(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
|
SDKConfig: config.SDKConfig{DisableImageGeneration: config.DisableImageGenerationAll},
|
||||||
}
|
}
|
||||||
payload := []byte(`{"tools":[{"type":"image_generation"},{"type":"function","name":"f1"}],"tool_choice":{"type":"image_generation"}}`)
|
payload := []byte(`{"tools":[{"type":"image_generation"},{"type":"function","name":"f1"}],"tool_choice":{"type":"image_generation"}}`)
|
||||||
|
|
||||||
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "")
|
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "", "")
|
||||||
|
|
||||||
if gjson.GetBytes(out, "tool_choice").Exists() {
|
if gjson.GetBytes(out, "tool_choice").Exists() {
|
||||||
t.Fatalf("expected tool_choice to be removed")
|
t.Fatalf("expected tool_choice to be removed")
|
||||||
@@ -64,13 +64,34 @@ func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolChoiceByTy
|
|||||||
|
|
||||||
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolChoiceByNameWithRoot(t *testing.T) {
|
func TestApplyPayloadConfigWithRoot_DisableImageGeneration_RemovesToolChoiceByNameWithRoot(t *testing.T) {
|
||||||
cfg := &config.Config{
|
cfg := &config.Config{
|
||||||
SDKConfig: config.SDKConfig{DisableImageGeneration: true},
|
SDKConfig: config.SDKConfig{DisableImageGeneration: config.DisableImageGenerationAll},
|
||||||
}
|
}
|
||||||
payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}],"tool_choice":{"type":"tool","name":"image_generation"}}}`)
|
payload := []byte(`{"request":{"tools":[{"type":"image_generation"},{"type":"web_search"}],"tool_choice":{"type":"tool","name":"image_generation"}}}`)
|
||||||
|
|
||||||
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "")
|
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "gemini-cli", "request", payload, nil, "", "")
|
||||||
|
|
||||||
if gjson.GetBytes(out, "request.tool_choice").Exists() {
|
if gjson.GetBytes(out, "request.tool_choice").Exists() {
|
||||||
t.Fatalf("expected request.tool_choice to be removed")
|
t.Fatalf("expected request.tool_choice to be removed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyPayloadConfigWithRoot_DisableImageGenerationChat_KeepsImageGenerationOnImagesEndpoints(t *testing.T) {
|
||||||
|
cfg := &config.Config{
|
||||||
|
SDKConfig: config.SDKConfig{DisableImageGeneration: config.DisableImageGenerationChat},
|
||||||
|
}
|
||||||
|
payload := []byte(`{"tools":[{"type":"image_generation"},{"type":"function","name":"f1"}],"tool_choice":{"type":"image_generation"}}`)
|
||||||
|
|
||||||
|
out := ApplyPayloadConfigWithRoot(cfg, "gpt-5.4", "openai-response", "", payload, nil, "", "/v1/images/generations")
|
||||||
|
|
||||||
|
tools := gjson.GetBytes(out, "tools")
|
||||||
|
if !tools.Exists() || !tools.IsArray() {
|
||||||
|
t.Fatalf("expected tools array, got %v", tools.Type)
|
||||||
|
}
|
||||||
|
arr := tools.Array()
|
||||||
|
if len(arr) != 2 {
|
||||||
|
t.Fatalf("expected 2 tools (no removal), got %d", len(arr))
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(out, "tool_choice").Exists() {
|
||||||
|
t.Fatalf("expected tool_choice to be kept on images endpoint")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -108,7 +108,8 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
|||||||
}
|
}
|
||||||
|
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, err = normalizeKimiToolMessageLinks(body)
|
body, err = normalizeKimiToolMessageLinks(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return resp, err
|
return resp, err
|
||||||
@@ -217,7 +218,8 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
|
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
|
||||||
}
|
}
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
body, err = normalizeKimiToolMessageLinks(body)
|
body, err = normalizeKimiToolMessageLinks(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -97,7 +97,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
|
||||||
if opts.Alt == "responses/compact" {
|
if opts.Alt == "responses/compact" {
|
||||||
if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil {
|
if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil {
|
||||||
translated = updated
|
translated = updated
|
||||||
@@ -199,7 +200,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
|
||||||
|
|
||||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
|
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
|
||||||
}
|
}
|
||||||
if oldCfg.DisableImageGeneration != newCfg.DisableImageGeneration {
|
if oldCfg.DisableImageGeneration != newCfg.DisableImageGeneration {
|
||||||
changes = append(changes, fmt.Sprintf("disable-image-generation: %t -> %t", oldCfg.DisableImageGeneration, newCfg.DisableImageGeneration))
|
changes = append(changes, fmt.Sprintf("disable-image-generation: %v -> %v", oldCfg.DisableImageGeneration, newCfg.DisableImageGeneration))
|
||||||
}
|
}
|
||||||
if oldCfg.RequestLog != newCfg.RequestLog {
|
if oldCfg.RequestLog != newCfg.RequestLog {
|
||||||
changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
|
changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
|||||||
APIKeys: []string{" key-1 ", "key-2"},
|
APIKeys: []string{" key-1 ", "key-2"},
|
||||||
ForceModelPrefix: true,
|
ForceModelPrefix: true,
|
||||||
NonStreamKeepAliveInterval: 5,
|
NonStreamKeepAliveInterval: 5,
|
||||||
DisableImageGeneration: true,
|
DisableImageGeneration: config.DisableImageGenerationAll,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -408,7 +408,7 @@ func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
|||||||
RequestLog: true,
|
RequestLog: true,
|
||||||
ProxyURL: "http://new-proxy",
|
ProxyURL: "http://new-proxy",
|
||||||
APIKeys: []string{"keyB"},
|
APIKeys: []string{"keyB"},
|
||||||
DisableImageGeneration: true,
|
DisableImageGeneration: config.DisableImageGenerationAll,
|
||||||
},
|
},
|
||||||
OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}},
|
OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}},
|
||||||
OpenAICompatibility: []config.OpenAICompatibility{
|
OpenAICompatibility: []config.OpenAICompatibility{
|
||||||
|
|||||||
@@ -198,9 +198,14 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
|||||||
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
||||||
// Only include it if the client explicitly provides it.
|
// Only include it if the client explicitly provides it.
|
||||||
key := ""
|
key := ""
|
||||||
|
requestPath := ""
|
||||||
if ctx != nil {
|
if ctx != nil {
|
||||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||||
key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key"))
|
key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key"))
|
||||||
|
requestPath = strings.TrimSpace(ginCtx.FullPath())
|
||||||
|
if requestPath == "" && ginCtx.Request.URL != nil {
|
||||||
|
requestPath = strings.TrimSpace(ginCtx.Request.URL.Path)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -208,6 +213,9 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
|||||||
if key != "" {
|
if key != "" {
|
||||||
meta[idempotencyKeyMetadataKey] = key
|
meta[idempotencyKeyMetadataKey] = key
|
||||||
}
|
}
|
||||||
|
if requestPath != "" {
|
||||||
|
meta[coreexecutor.RequestPathMetadataKey] = requestPath
|
||||||
|
}
|
||||||
if pinnedAuthID := pinnedAuthIDFromContext(ctx); pinnedAuthID != "" {
|
if pinnedAuthID := pinnedAuthIDFromContext(ctx); pinnedAuthID != "" {
|
||||||
meta[coreexecutor.PinnedAuthMetadataKey] = pinnedAuthID
|
meta[coreexecutor.PinnedAuthMetadataKey] = pinnedAuthID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -198,7 +199,7 @@ func parseBoolField(raw string, fallback bool) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
|
func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
|
||||||
if h != nil && h.BaseAPIHandler != nil && h.BaseAPIHandler.Cfg != nil && h.BaseAPIHandler.Cfg.DisableImageGeneration {
|
if h != nil && h.BaseAPIHandler != nil && h.BaseAPIHandler.Cfg != nil && h.BaseAPIHandler.Cfg.DisableImageGeneration == internalconfig.DisableImageGenerationAll {
|
||||||
c.AbortWithStatus(http.StatusNotFound)
|
c.AbortWithStatus(http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -286,7 +287,7 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *OpenAIAPIHandler) ImagesEdits(c *gin.Context) {
|
func (h *OpenAIAPIHandler) ImagesEdits(c *gin.Context) {
|
||||||
if h != nil && h.BaseAPIHandler != nil && h.BaseAPIHandler.Cfg != nil && h.BaseAPIHandler.Cfg.DisableImageGeneration {
|
if h != nil && h.BaseAPIHandler != nil && h.BaseAPIHandler.Cfg != nil && h.BaseAPIHandler.Cfg.DisableImageGeneration == internalconfig.DisableImageGenerationAll {
|
||||||
c.AbortWithStatus(http.StatusNotFound)
|
c.AbortWithStatus(http.StatusNotFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -97,7 +98,7 @@ func TestImagesEditsMultipartRejectsUnsupportedModel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestImagesGenerations_DisableImageGeneration_Returns404(t *testing.T) {
|
func TestImagesGenerations_DisableImageGeneration_Returns404(t *testing.T) {
|
||||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: true}, nil)
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: internalconfig.DisableImageGenerationAll}, nil)
|
||||||
handler := NewOpenAIAPIHandler(base)
|
handler := NewOpenAIAPIHandler(base)
|
||||||
body := strings.NewReader(`{"prompt":"draw a square"}`)
|
body := strings.NewReader(`{"prompt":"draw a square"}`)
|
||||||
|
|
||||||
@@ -109,7 +110,7 @@ func TestImagesGenerations_DisableImageGeneration_Returns404(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestImagesEdits_DisableImageGeneration_Returns404(t *testing.T) {
|
func TestImagesEdits_DisableImageGeneration_Returns404(t *testing.T) {
|
||||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: true}, nil)
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: internalconfig.DisableImageGenerationAll}, nil)
|
||||||
handler := NewOpenAIAPIHandler(base)
|
handler := NewOpenAIAPIHandler(base)
|
||||||
body := strings.NewReader(`{"prompt":"edit this","images":[{"image_url":"data:image/png;base64,AA=="}]}`)
|
body := strings.NewReader(`{"prompt":"edit this","images":[{"image_url":"data:image/png;base64,AA=="}]}`)
|
||||||
|
|
||||||
@@ -119,3 +120,27 @@ func TestImagesEdits_DisableImageGeneration_Returns404(t *testing.T) {
|
|||||||
t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusNotFound, resp.Body.String())
|
t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusNotFound, resp.Body.String())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestImagesGenerations_DisableImageGenerationChat_DoesNotReturn404(t *testing.T) {
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: internalconfig.DisableImageGenerationChat}, nil)
|
||||||
|
handler := NewOpenAIAPIHandler(base)
|
||||||
|
body := strings.NewReader(`{"model":"gpt-5.4-mini","prompt":"draw a square"}`)
|
||||||
|
|
||||||
|
resp := performImagesEndpointRequest(t, imagesGenerationsPath, "application/json", body, handler.ImagesGenerations)
|
||||||
|
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusBadRequest, resp.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestImagesEdits_DisableImageGenerationChat_DoesNotReturn404(t *testing.T) {
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{DisableImageGeneration: internalconfig.DisableImageGenerationChat}, nil)
|
||||||
|
handler := NewOpenAIAPIHandler(base)
|
||||||
|
body := strings.NewReader(`{"model":"gpt-5.4-mini","prompt":"edit this","images":[{"image_url":"data:image/png;base64,AA=="}]}`)
|
||||||
|
|
||||||
|
resp := performImagesEndpointRequest(t, imagesEditsPath, "application/json", body, handler.ImagesEdits)
|
||||||
|
|
||||||
|
if resp.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusBadRequest, resp.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,10 @@ import (
|
|||||||
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
|
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
|
||||||
const RequestedModelMetadataKey = "requested_model"
|
const RequestedModelMetadataKey = "requested_model"
|
||||||
|
|
||||||
|
// RequestPathMetadataKey stores the inbound HTTP request path (e.g. "/v1/images/generations") in Options.Metadata.
|
||||||
|
// It is optional and may be absent for non-HTTP executions.
|
||||||
|
const RequestPathMetadataKey = "request_path"
|
||||||
|
|
||||||
// DisallowFreeAuthMetadataKey instructs auth selection to skip known free-tier credentials.
|
// DisallowFreeAuthMetadataKey instructs auth selection to skip known free-tier credentials.
|
||||||
const DisallowFreeAuthMetadataKey = "disallow_free_auth"
|
const DisallowFreeAuthMetadataKey = "disallow_free_auth"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user