diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index 17c8170f..38c348f2 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -59,6 +59,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, if err != nil { return resp, err } + endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost, @@ -113,6 +114,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth if err != nil { return nil, err } + endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost, diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 9ade4fbb..950141f0 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -76,11 +76,7 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au // Execute performs a non-streaming request to the Antigravity API. func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + isClaude := strings.Contains(strings.ToLower(req.Model), "claude") if isClaude { return e.executeClaudeNonStream(ctx, auth, req, opts) } @@ -114,7 +110,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, false, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL) if errReq != nil { err = errReq return resp, err @@ -195,11 +191,6 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * to := sdktranslator.FromString("antigravity") translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) @@ -214,7 +205,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) if errReq != nil { err = errReq return resp, err @@ -530,16 +521,12 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) + isClaude := strings.Contains(strings.ToLower(req.Model), "claude") + from := opts.SourceFormat to := sdktranslator.FromString("antigravity") translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") - translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) @@ -554,7 +541,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya var lastErr error for idx, baseURL := range baseURLs { - httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) + httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL) if errReq != nil { err = errReq return nil, err @@ -692,11 +679,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut to := sdktranslator.FromString("antigravity") respCtx := context.WithValue(ctx, "alt", opts.Alt) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude") + isClaude := strings.Contains(strings.ToLower(req.Model), "claude") baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 2fbb235b..f74dc1e0 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -49,36 +49,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override + } from := opts.SourceFormat to := sdktranslator.FromString("claude") // Use streaming translation to preserve function calling, except for claude. stream := from != to - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } - } - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream) + body, _ = sjson.SetBytes(body, "model", model) // Inject thinking config based on model metadata for thinking variants - body = e.injectThinkingConfig(req.Model, req.Metadata, body) + body = e.injectThinkingConfig(model, req.Metadata, body) - if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { + if !strings.HasPrefix(model, "claude-3-5-haiku") { body = checkSystemInstructions(body) } - body = applyPayloadConfig(e.cfg, req.Model, body) + body = applyPayloadConfig(e.cfg, model, body) // Disable thinking if tool_choice forces tool use (Anthropic API constraint) body = disableThinkingIfToolChoiceForced(body) // Ensure max_tokens > thinking.budget_tokens when thinking is enabled - body = ensureMaxTokensForThinking(req.Model, body) + body = ensureMaxTokensForThinking(model, body) // Extract betas from body and convert to header var extraBetas []string @@ -170,29 +163,22 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A defer reporter.trackFailure(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("claude") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } - } - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) + body, _ = sjson.SetBytes(body, "model", model) // Inject thinking config based on model metadata for thinking variants - body = e.injectThinkingConfig(req.Model, req.Metadata, body) + body = e.injectThinkingConfig(model, req.Metadata, body) body = checkSystemInstructions(body) - body = applyPayloadConfig(e.cfg, req.Model, body) + body = applyPayloadConfig(e.cfg, model, body) // Disable thinking if tool_choice forces tool use (Anthropic API constraint) body = disableThinkingIfToolChoiceForced(body) // Ensure max_tokens > thinking.budget_tokens when thinking is enabled - body = ensureMaxTokensForThinking(req.Model, body) + body = ensureMaxTokensForThinking(model, body) // Extract betas from body and convert to header var extraBetas []string @@ -316,21 +302,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut to := sdktranslator.FromString("claude") // Use streaming translation to preserve function calling, except for claude. stream := from != to - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } - } - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream) + body, _ = sjson.SetBytes(body, "model", model) - if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { + if !strings.HasPrefix(model, "claude-3-5-haiku") { body = checkSystemInstructions(body) } diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 310988c1..98678c4d 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -49,28 +49,21 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("codex") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false) + body = NormalizeThinkingConfig(body, model, false) + if errValidate := ValidateThinkingConfig(body, model); errValidate != nil { return resp, errValidate } - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) body, _ = sjson.SetBytes(body, "stream", true) body, _ = sjson.DeleteBytes(body, "previous_response_id") @@ -156,30 +149,23 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("codex") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) - body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false) + body = NormalizeThinkingConfig(body, model, false) + if errValidate := ValidateThinkingConfig(body, model); errValidate != nil { return nil, errValidate } - body = applyPayloadConfig(e.cfg, req.Model, body) + body = applyPayloadConfig(e.cfg, model, body) body, _ = sjson.DeleteBytes(body, "previous_response_id") - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body, _ = sjson.SetBytes(body, "model", model) url := strings.TrimSuffix(baseURL, "/") + "/responses" httpReq, err := e.cacheHelper(ctx, from, url, req, body) @@ -266,30 +252,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel == "" { - upstreamModel = req.Model - } - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("codex") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) - modelForCounting := upstreamModel - - body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false) + body, _ = sjson.SetBytes(body, "model", model) body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.SetBytes(body, "stream", false) - enc, err := tokenizerForCodexModel(modelForCounting) + enc, err := tokenizerForCodexModel(model) if err != nil { return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err) } diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index b171041a..a3b75839 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -318,7 +318,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut out := make(chan cliproxyexecutor.StreamChunk) stream = out - go func(resp *http.Response, reqBody []byte, attempt string) { + go func(resp *http.Response, reqBody []byte, attemptModel string) { defer close(out) defer func() { if errClose := resp.Body.Close(); errClose != nil { @@ -336,14 +336,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut reporter.publish(ctx, detail) } if bytes.HasPrefix(line, dataTag) { - segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m) + segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m) for i := range segments { out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } } } - segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m) + segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m) for i := range segments { out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } @@ -365,12 +365,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut appendAPIResponseChunk(ctx, e.cfg, data) reporter.publish(ctx, parseGeminiCLIUsage(data)) var param any - segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m) + segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m) for i := range segments { out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } - segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m) + segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m) for i := range segments { out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} } @@ -417,6 +417,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. var lastStatus int var lastBody []byte + // The loop variable attemptModel is only used as the concrete model id sent to the upstream + // Gemini CLI endpoint when iterating fallback variants. for _, attemptModel := range models { payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false) payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) @@ -425,7 +427,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "request.safetySettings") payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) - payload = fixGeminiCLIImageAspectRatio(attemptModel, payload) + payload = fixGeminiCLIImageAspectRatio(req.Model, payload) tok, errTok := tokenSource.Token() if errTok != nil { diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index da57150d..d69044b8 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -77,26 +77,22 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(model, auth); override != "" { + model = override } // Official Gemini API via API key or OAuth bearer from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - body = ApplyThinkingMetadata(body, req.Metadata, req.Model) - body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) - body = util.NormalizeGeminiThinkingBudget(req.Model, body) - body = util.StripThinkingConfigIfUnsupported(req.Model, body) - body = fixGeminiImageAspectRatio(req.Model, body) - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + body = ApplyThinkingMetadata(body, req.Metadata, model) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) action := "generateContent" if req.Metadata != nil { @@ -105,7 +101,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } } baseURL := resolveGeminiBaseURL(auth) - url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action) + url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action) if opts.Alt != "" && action != "countTokens" { url = url + fmt.Sprintf("?$alt=%s", opts.Alt) } @@ -180,28 +176,24 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - body = ApplyThinkingMetadata(body, req.Metadata, req.Model) - body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) - body = util.NormalizeGeminiThinkingBudget(req.Model, body) - body = util.StripThinkingConfigIfUnsupported(req.Model, body) - body = fixGeminiImageAspectRatio(req.Model, body) - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) + body = ApplyThinkingMetadata(body, req.Metadata, model) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) baseURL := resolveGeminiBaseURL(auth) - url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent") + url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent") if opts.Alt == "" { url = url + "?alt=sse" } else { @@ -301,29 +293,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { apiKey, bearer := geminiCreds(auth) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { - upstreamModel = modelOverride - } else if !strings.EqualFold(upstreamModel, req.Model) { - if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { - upstreamModel = modelOverride - } + model := req.Model + if override := e.resolveUpstreamModel(model, auth); override != "" { + model = override } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model) - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model) + translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(model, translatedReq) respCtx := context.WithValue(ctx, "alt", opts.Alt) translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") - translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) + translatedReq, _ = sjson.SetBytes(translatedReq, "model", model) baseURL := resolveGeminiBaseURL(auth) - url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "countTokens") + url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens") requestBody := bytes.NewReader(translatedReq) diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index df8ee506..f8f4a63a 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -120,8 +120,6 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - from := opts.SourceFormat to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) @@ -137,7 +135,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body, _ = sjson.SetBytes(body, "model", req.Model) action := "generateContent" if req.Metadata != nil { @@ -146,7 +144,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au } } baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action) + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action) if opts.Alt != "" && action != "countTokens" { url = url + fmt.Sprintf("?$alt=%s", opts.Alt) } @@ -220,24 +218,27 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override + } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + norm := util.NormalizeThinkingBudget(model, *budgetOverride) budgetOverride = &norm } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) - body = util.NormalizeGeminiThinkingBudget(req.Model, body) - body = util.StripThinkingConfigIfUnsupported(req.Model, body) - body = fixGeminiImageAspectRatio(req.Model, body) - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) action := "generateContent" if req.Metadata != nil { @@ -250,7 +251,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip if baseURL == "" { baseURL = "https://generativelanguage.googleapis.com" } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action) + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action) if opts.Alt != "" && action != "countTokens" { url = url + fmt.Sprintf("?$alt=%s", opts.Alt) } @@ -321,8 +322,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - from := opts.SourceFormat to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) @@ -338,10 +337,10 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body, _ = sjson.SetBytes(body, "model", req.Model) baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent") + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent") if opts.Alt == "" { url = url + "?alt=sse" } else { @@ -438,30 +437,33 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override + } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + norm := util.NormalizeThinkingBudget(model, *budgetOverride) budgetOverride = &norm } body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) - body = util.NormalizeGeminiThinkingBudget(req.Model, body) - body = util.StripThinkingConfigIfUnsupported(req.Model, body) - body = fixGeminiImageAspectRatio(req.Model, body) - body = applyPayloadConfig(e.cfg, req.Model, body) - body, _ = sjson.SetBytes(body, "model", upstreamModel) + body = util.ApplyDefaultThinkingIfNeeded(model, body) + body = util.NormalizeGeminiThinkingBudget(model, body) + body = util.StripThinkingConfigIfUnsupported(model, body) + body = fixGeminiImageAspectRatio(model, body) + body = applyPayloadConfig(e.cfg, model, body) + body, _ = sjson.SetBytes(body, "model", model) // For API key auth, use simpler URL format without project/location if baseURL == "" { baseURL = "https://generativelanguage.googleapis.com" } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent") + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent") if opts.Alt == "" { url = url + "?alt=sse" } else { @@ -552,8 +554,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth // countTokensWithServiceAccount counts tokens using service account credentials. func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - from := opts.SourceFormat to := sdktranslator.FromString("gemini") translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) @@ -566,14 +566,14 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context } translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) - translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) + translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model) respCtx := context.WithValue(ctx, "alt", opts.Alt) translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens") + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) if errNewReq != nil { @@ -641,21 +641,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context // countTokensWithAPIKey handles token counting using API key credentials. func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + model := req.Model + if override := e.resolveUpstreamModel(req.Model, auth); override != "" { + model = override + } from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) - if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) { if budgetOverride != nil { - norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + norm := util.NormalizeThinkingBudget(model, *budgetOverride) budgetOverride = &norm } translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) } - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) - translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) + translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(model, translatedReq) + translatedReq, _ = sjson.SetBytes(translatedReq, "model", model) respCtx := context.WithValue(ctx, "alt", opts.Alt) translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") @@ -665,7 +668,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth * if baseURL == "" { baseURL = "https://generativelanguage.googleapis.com" } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens") + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens") httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) if errNewReq != nil { @@ -808,3 +811,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau } return tok.AccessToken, nil } + +// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration. +// It matches the requested model alias against configured models and returns the actual upstream name. +func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string { + trimmed := strings.TrimSpace(alias) + if trimmed == "" { + return "" + } + + entry := e.resolveVertexConfig(auth) + if entry == nil { + return "" + } + + normalizedModel, metadata := util.NormalizeThinkingModel(trimmed) + + // Candidate names to match against configured aliases/names. + candidates := []string{strings.TrimSpace(normalizedModel)} + if !strings.EqualFold(normalizedModel, trimmed) { + candidates = append(candidates, trimmed) + } + if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) { + candidates = append(candidates, original) + } + + for i := range entry.Models { + model := entry.Models[i] + name := strings.TrimSpace(model.Name) + modelAlias := strings.TrimSpace(model.Alias) + + for _, candidate := range candidates { + if candidate == "" { + continue + } + if modelAlias != "" && strings.EqualFold(modelAlias, candidate) { + if name != "" { + return name + } + return candidate + } + if name != "" && strings.EqualFold(name, candidate) { + return name + } + } + } + return "" +} + +// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth. +func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey { + if auth == nil || e.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range e.cfg.VertexCompatAPIKey { + entry := &e.cfg.VertexCompatAPIKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && attrBase != "" { + if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey != "" { + for i := range e.cfg.VertexCompatAPIKey { + entry := &e.cfg.VertexCompatAPIKey[i] + if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) { + return entry + } + } + } + return nil +} diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index 124a984e..49fd4eb7 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -58,12 +58,9 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re to := sdktranslator.FromString("openai") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return resp, errValidate } body = applyIFlowThinkingConfig(body) @@ -151,12 +148,9 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return nil, errValidate } body = applyIFlowThinkingConfig(body) diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index 1c57c9b7..81fc31a1 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -61,12 +61,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) allowCompat := e.allowCompatReasoningEffort(req.Model, auth) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" && modelOverride == "" { - translated, _ = sjson.SetBytes(translated, "model", upstreamModel) - } - translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat) - if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil { + translated = NormalizeThinkingConfig(translated, req.Model, allowCompat) + if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil { return resp, errValidate } @@ -157,12 +153,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) allowCompat := e.allowCompatReasoningEffort(req.Model, auth) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" && modelOverride == "" { - translated, _ = sjson.SetBytes(translated, "model", upstreamModel) - } - translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat) - if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil { + translated = NormalizeThinkingConfig(translated, req.Model, allowCompat) + if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil { return nil, errValidate } diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index 1d4ef52d..ff6fa414 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -12,7 +12,6 @@ import ( qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/util" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -52,12 +51,9 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req to := sdktranslator.FromString("openai") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return resp, errValidate } body = applyPayloadConfig(e.cfg, req.Model, body) @@ -132,12 +128,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) - upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) - if upstreamModel != "" { - body, _ = sjson.SetBytes(body, "model", upstreamModel) - } - body = NormalizeThinkingConfig(body, upstreamModel, false) - if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { + body, _ = sjson.SetBytes(body, "model", req.Model) + body = NormalizeThinkingConfig(body, req.Model, false) + if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil { return nil, errValidate } toolsResult := gjson.GetBytes(body, "tools") diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index a6eaf3c5..c480d965 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -413,7 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) - execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) resp, errExec := executor.Execute(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -475,7 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) - execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { @@ -537,7 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string } execReq := req execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) - execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) + execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata) chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) if errStream != nil { rerr := &Error{Message: errStream.Error()} diff --git a/sdk/cliproxy/auth/model_name_mappings.go b/sdk/cliproxy/auth/model_name_mappings.go index 483cb9c9..f1b31aa5 100644 --- a/sdk/cliproxy/auth/model_name_mappings.go +++ b/sdk/cliproxy/auth/model_name_mappings.go @@ -65,17 +65,14 @@ func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.Mod m.modelNameMappings.Store(table) } -func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any { - original := m.resolveOAuthUpstreamModel(auth, requestedModel) - if original == "" { - return metadata - } - if metadata != nil { - if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok { - if s, okStr := v.(string); okStr && strings.EqualFold(s, original) { - return metadata - } - } +// applyOAuthModelMapping resolves the upstream model from OAuth model mappings +// and returns the resolved model along with updated metadata. If a mapping exists, +// the returned model is the upstream model and metadata contains the original +// requested model for response translation. +func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) { + upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel) + if upstreamModel == "" { + return requestedModel, metadata } out := make(map[string]any, 1) if len(metadata) > 0 { @@ -84,8 +81,8 @@ func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel stri out[k] = v } } - out[util.ModelMappingOriginalModelMetadataKey] = original - return out + out[util.ModelMappingOriginalModelMetadataKey] = upstreamModel + return upstreamModel, out } func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {