Merge pull request #795 from router-for-me/modelmappings

refactor(executor): resolve upstream model at conductor level before execution
This commit is contained in:
Chén Mù
2025-12-30 05:31:19 -08:00
committed by GitHub
12 changed files with 260 additions and 263 deletions
@@ -59,6 +59,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
if err != nil { if err != nil {
return resp, err return resp, err
} }
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{ wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost, Method: http.MethodPost,
@@ -113,6 +114,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
if err != nil { if err != nil {
return nil, err return nil, err
} }
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{ wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost, Method: http.MethodPost,
@@ -76,11 +76,7 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
// Execute performs a non-streaming request to the Antigravity API. // Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
if upstreamModel == "" {
upstreamModel = req.Model
}
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
if isClaude { if isClaude {
return e.executeClaudeNonStream(ctx, auth, req, opts) return e.executeClaudeNonStream(ctx, auth, req, opts)
} }
@@ -114,7 +110,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, false, opts.Alt, baseURL) httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return resp, err return resp, err
@@ -195,11 +191,6 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
@@ -214,7 +205,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return resp, err return resp, err
@@ -530,16 +521,12 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
if upstreamModel == "" {
upstreamModel = req.Model
}
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model) translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
@@ -554,7 +541,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
var lastErr error var lastErr error
for idx, baseURL := range baseURLs { for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, upstreamModel, translated, true, opts.Alt, baseURL) httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
if errReq != nil { if errReq != nil {
err = errReq err = errReq
return nil, err return nil, err
@@ -692,11 +679,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
to := sdktranslator.FromString("antigravity") to := sdktranslator.FromString("antigravity")
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) isClaude := strings.Contains(strings.ToLower(req.Model), "claude")
if upstreamModel == "" {
upstreamModel = req.Model
}
isClaude := strings.Contains(strings.ToLower(upstreamModel), "claude")
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
+24 -45
View File
@@ -49,36 +49,29 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} }
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("claude") to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude. // Use streaming translation to preserve function calling, except for claude.
stream := from != to stream := from != to
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) body, _ = sjson.SetBytes(body, "model", model)
if upstreamModel == "" {
upstreamModel = req.Model
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
// Inject thinking config based on model metadata for thinking variants // Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(req.Model, req.Metadata, body) body = e.injectThinkingConfig(model, req.Metadata, body)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body) body = checkSystemInstructions(body)
} }
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint) // Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body) body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled // Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body) body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header // Extract betas from body and convert to header
var extraBetas []string var extraBetas []string
@@ -170,29 +163,22 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("claude") to := sdktranslator.FromString("claude")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) model := req.Model
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
if upstreamModel == "" { model = override
upstreamModel = req.Model
} }
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
upstreamModel = modelOverride body, _ = sjson.SetBytes(body, "model", model)
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
// Inject thinking config based on model metadata for thinking variants // Inject thinking config based on model metadata for thinking variants
body = e.injectThinkingConfig(req.Model, req.Metadata, body) body = e.injectThinkingConfig(model, req.Metadata, body)
body = checkSystemInstructions(body) body = checkSystemInstructions(body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint) // Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body) body = disableThinkingIfToolChoiceForced(body)
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled // Ensure max_tokens > thinking.budget_tokens when thinking is enabled
body = ensureMaxTokensForThinking(req.Model, body) body = ensureMaxTokensForThinking(model, body)
// Extract betas from body and convert to header // Extract betas from body and convert to header
var extraBetas []string var extraBetas []string
@@ -316,21 +302,14 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
to := sdktranslator.FromString("claude") to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude. // Use streaming translation to preserve function calling, except for claude.
stream := from != to stream := from != to
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) model := req.Model
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
if upstreamModel == "" { model = override
upstreamModel = req.Model
} }
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), stream)
upstreamModel = modelOverride body, _ = sjson.SetBytes(body, "model", model)
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
}
body, _ = sjson.SetBytes(body, "model", upstreamModel)
if !strings.HasPrefix(upstreamModel, "claude-3-5-haiku") { if !strings.HasPrefix(model, "claude-3-5-haiku") {
body = checkSystemInstructions(body) body = checkSystemInstructions(body)
} }
+25 -48
View File
@@ -49,28 +49,21 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false) body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
body, _ = sjson.SetBytes(body, "stream", true) body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
@@ -156,30 +149,23 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body = NormalizeThinkingConfig(body, upstreamModel, false) body = NormalizeThinkingConfig(body, model, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil { if errValidate := ValidateThinkingConfig(body, model); errValidate != nil {
return nil, errValidate return nil, errValidate
} }
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
url := strings.TrimSuffix(baseURL, "/") + "/responses" url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body) httpReq, err := e.cacheHelper(ctx, from, url, req, body)
@@ -266,30 +252,21 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
} }
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if upstreamModel == "" { if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
upstreamModel = req.Model model = override
}
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" {
upstreamModel = modelOverride
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("codex") to := sdktranslator.FromString("codex")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
modelForCounting := upstreamModel body = ApplyReasoningEffortMetadata(body, req.Metadata, model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", model)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
body, _ = sjson.SetBytes(body, "model", upstreamModel)
body, _ = sjson.DeleteBytes(body, "previous_response_id") body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.SetBytes(body, "stream", false) body, _ = sjson.SetBytes(body, "stream", false)
enc, err := tokenizerForCodexModel(modelForCounting) enc, err := tokenizerForCodexModel(model)
if err != nil { if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err) return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
} }
@@ -318,7 +318,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
out := make(chan cliproxyexecutor.StreamChunk) out := make(chan cliproxyexecutor.StreamChunk)
stream = out stream = out
go func(resp *http.Response, reqBody []byte, attempt string) { go func(resp *http.Response, reqBody []byte, attemptModel string) {
defer close(out) defer close(out)
defer func() { defer func() {
if errClose := resp.Body.Close(); errClose != nil { if errClose := resp.Body.Close(); errClose != nil {
@@ -336,14 +336,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
reporter.publish(ctx, detail) reporter.publish(ctx, detail)
} }
if bytes.HasPrefix(line, dataTag) { if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
} }
} }
} }
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
} }
@@ -365,12 +365,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
appendAPIResponseChunk(ctx, e.cfg, data) appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiCLIUsage(data)) reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, data, &param) segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
} }
segments = sdktranslator.TranslateStream(respCtx, to, from, attempt, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param) segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), &param)
for i := range segments { for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])} out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
} }
@@ -417,6 +417,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
var lastStatus int var lastStatus int
var lastBody []byte var lastBody []byte
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
// Gemini CLI endpoint when iterating fallback variants.
for _, attemptModel := range models { for _, attemptModel := range models {
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false) payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model) payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
@@ -425,7 +427,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
payload = deleteJSONField(payload, "model") payload = deleteJSONField(payload, "model")
payload = deleteJSONField(payload, "request.safetySettings") payload = deleteJSONField(payload, "request.safetySettings")
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiCLIImageAspectRatio(attemptModel, payload) payload = fixGeminiCLIImageAspectRatio(req.Model, payload)
tok, errTok := tokenSource.Token() tok, errTok := tokenSource.Token()
if errTok != nil { if errTok != nil {
+33 -45
View File
@@ -77,26 +77,22 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { if override := e.resolveUpstreamModel(model, auth); override != "" {
upstreamModel = modelOverride model = override
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
// Official Gemini API via API key or OAuth bearer // Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model) body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -105,7 +101,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} }
} }
baseURL := resolveGeminiBaseURL(auth) baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, action) url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" { if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt) url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
} }
@@ -180,28 +176,24 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { if override := e.resolveUpstreamModel(model, auth); override != "" {
upstreamModel = modelOverride model = override
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
body = ApplyThinkingMetadata(body, req.Metadata, req.Model) body = ApplyThinkingMetadata(body, req.Metadata, model)
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
baseURL := resolveGeminiBaseURL(auth) baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "streamGenerateContent") url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" { if opts.Alt == "" {
url = url + "?alt=sse" url = url + "?alt=sse"
} else { } else {
@@ -301,29 +293,25 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
apiKey, bearer := geminiCreds(auth) apiKey, bearer := geminiCreds(auth)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { if override := e.resolveUpstreamModel(model, auth); override != "" {
upstreamModel = modelOverride model = override
} else if !strings.EqualFold(upstreamModel, req.Model) {
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
upstreamModel = modelOverride
}
} }
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model) translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, model)
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
baseURL := resolveGeminiBaseURL(auth) baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "countTokens") url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, model, "countTokens")
requestBody := bytes.NewReader(translatedReq) requestBody := bytes.NewReader(translatedReq)
@@ -120,8 +120,6 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
@@ -137,7 +135,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", req.Model)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -146,7 +144,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
} }
} }
baseURL := vertexBaseURL(location) baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action) url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
if opts.Alt != "" && action != "countTokens" { if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt) url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
} }
@@ -220,24 +218,27 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
} }
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
action := "generateContent" action := "generateContent"
if req.Metadata != nil { if req.Metadata != nil {
@@ -250,7 +251,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://generativelanguage.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action) url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, action)
if opts.Alt != "" && action != "countTokens" { if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt) url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
} }
@@ -321,8 +322,6 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
@@ -338,10 +337,10 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(req.Model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", req.Model)
baseURL := vertexBaseURL(location) baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent") url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
if opts.Alt == "" { if opts.Alt == "" {
url = url + "?alt=sse" url = url + "?alt=sse"
} else { } else {
@@ -438,30 +437,33 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err) defer reporter.trackFailure(ctx, &err)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
} }
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body) body = util.ApplyDefaultThinkingIfNeeded(model, body)
body = util.NormalizeGeminiThinkingBudget(req.Model, body) body = util.NormalizeGeminiThinkingBudget(model, body)
body = util.StripThinkingConfigIfUnsupported(req.Model, body) body = util.StripThinkingConfigIfUnsupported(model, body)
body = fixGeminiImageAspectRatio(req.Model, body) body = fixGeminiImageAspectRatio(model, body)
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, model, body)
body, _ = sjson.SetBytes(body, "model", upstreamModel) body, _ = sjson.SetBytes(body, "model", model)
// For API key auth, use simpler URL format without project/location // For API key auth, use simpler URL format without project/location
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://generativelanguage.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent") url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "streamGenerateContent")
if opts.Alt == "" { if opts.Alt == "" {
url = url + "?alt=sse" url = url + "?alt=sse"
} else { } else {
@@ -552,8 +554,6 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
// countTokensWithServiceAccount counts tokens using service account credentials. // countTokensWithServiceAccount counts tokens using service account credentials.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
@@ -566,14 +566,14 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
} }
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) translatedReq, _ = sjson.SetBytes(translatedReq, "model", req.Model)
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location) baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens") url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil { if errNewReq != nil {
@@ -641,21 +641,24 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
// countTokensWithAPIKey handles token counting using API key credentials. // countTokensWithAPIKey handles token counting using API key credentials.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) model := req.Model
if override := e.resolveUpstreamModel(req.Model, auth); override != "" {
model = override
}
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("gemini") to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) translatedReq := sdktranslator.TranslateRequest(from, to, model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, req.Metadata); ok && util.ModelSupportsThinking(model) {
if budgetOverride != nil { if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) norm := util.NormalizeThinkingBudget(model, *budgetOverride)
budgetOverride = &norm budgetOverride = &norm
} }
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
} }
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) translatedReq = util.StripThinkingConfigIfUnsupported(model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) translatedReq = fixGeminiImageAspectRatio(model, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) translatedReq, _ = sjson.SetBytes(translatedReq, "model", model)
respCtx := context.WithValue(ctx, "alt", opts.Alt) respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -665,7 +668,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
if baseURL == "" { if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com" baseURL = "https://generativelanguage.googleapis.com"
} }
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens") url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil { if errNewReq != nil {
@@ -808,3 +811,90 @@ func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyau
} }
return tok.AccessToken, nil return tok.AccessToken, nil
} }
// resolveUpstreamModel resolves the upstream model name from vertex-api-key configuration.
// It matches the requested model alias against configured models and returns the actual upstream name.
func (e *GeminiVertexExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
trimmed := strings.TrimSpace(alias)
if trimmed == "" {
return ""
}
entry := e.resolveVertexConfig(auth)
if entry == nil {
return ""
}
normalizedModel, metadata := util.NormalizeThinkingModel(trimmed)
// Candidate names to match against configured aliases/names.
candidates := []string{strings.TrimSpace(normalizedModel)}
if !strings.EqualFold(normalizedModel, trimmed) {
candidates = append(candidates, trimmed)
}
if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) {
candidates = append(candidates, original)
}
for i := range entry.Models {
model := entry.Models[i]
name := strings.TrimSpace(model.Name)
modelAlias := strings.TrimSpace(model.Alias)
for _, candidate := range candidates {
if candidate == "" {
continue
}
if modelAlias != "" && strings.EqualFold(modelAlias, candidate) {
if name != "" {
return name
}
return candidate
}
if name != "" && strings.EqualFold(name, candidate) {
return name
}
}
}
return ""
}
// resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth.
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.VertexCompatAPIKey {
entry := &e.cfg.VertexCompatAPIKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range e.cfg.VertexCompatAPIKey {
entry := &e.cfg.VertexCompatAPIKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}
+6 -12
View File
@@ -58,12 +58,9 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) body, _ = sjson.SetBytes(body, "model", req.Model)
if upstreamModel != "" { body = NormalizeThinkingConfig(body, req.Model, false)
body, _ = sjson.SetBytes(body, "model", upstreamModel) if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
body = applyIFlowThinkingConfig(body) body = applyIFlowThinkingConfig(body)
@@ -151,12 +148,9 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) body, _ = sjson.SetBytes(body, "model", req.Model)
if upstreamModel != "" { body = NormalizeThinkingConfig(body, req.Model, false)
body, _ = sjson.SetBytes(body, "model", upstreamModel) if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return nil, errValidate return nil, errValidate
} }
body = applyIFlowThinkingConfig(body) body = applyIFlowThinkingConfig(body)
@@ -61,12 +61,8 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth) allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if upstreamModel != "" && modelOverride == "" { if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
@@ -157,12 +153,8 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
allowCompat := e.allowCompatReasoningEffort(req.Model, auth) allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat) translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) translated = NormalizeThinkingConfig(translated, req.Model, allowCompat)
if upstreamModel != "" && modelOverride == "" { if errValidate := ValidateThinkingConfig(translated, req.Model); errValidate != nil {
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
}
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
return nil, errValidate return nil, errValidate
} }
+6 -13
View File
@@ -12,7 +12,6 @@ import (
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen" qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -52,12 +51,9 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) body, _ = sjson.SetBytes(body, "model", req.Model)
if upstreamModel != "" { body = NormalizeThinkingConfig(body, req.Model, false)
body, _ = sjson.SetBytes(body, "model", upstreamModel) if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return resp, errValidate return resp, errValidate
} }
body = applyPayloadConfig(e.cfg, req.Model, body) body = applyPayloadConfig(e.cfg, req.Model, body)
@@ -132,12 +128,9 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false) body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) body, _ = sjson.SetBytes(body, "model", req.Model)
if upstreamModel != "" { body = NormalizeThinkingConfig(body, req.Model, false)
body, _ = sjson.SetBytes(body, "model", upstreamModel) if errValidate := ValidateThinkingConfig(body, req.Model); errValidate != nil {
}
body = NormalizeThinkingConfig(body, upstreamModel, false)
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
return nil, errValidate return nil, errValidate
} }
toolsResult := gjson.GetBytes(body, "tools") toolsResult := gjson.GetBytes(body, "tools")
+3 -3
View File
@@ -413,7 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
} }
execReq := req execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts) resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil { if errExec != nil {
@@ -475,7 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
} }
execReq := req execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil { if errExec != nil {
@@ -537,7 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
} }
execReq := req execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata) execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil { if errStream != nil {
rerr := &Error{Message: errStream.Error()} rerr := &Error{Message: errStream.Error()}
+10 -13
View File
@@ -65,17 +65,14 @@ func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.Mod
m.modelNameMappings.Store(table) m.modelNameMappings.Store(table)
} }
func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any { // applyOAuthModelMapping resolves the upstream model from OAuth model mappings
original := m.resolveOAuthUpstreamModel(auth, requestedModel) // and returns the resolved model along with updated metadata. If a mapping exists,
if original == "" { // the returned model is the upstream model and metadata contains the original
return metadata // requested model for response translation.
} func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
if metadata != nil { upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok { if upstreamModel == "" {
if s, okStr := v.(string); okStr && strings.EqualFold(s, original) { return requestedModel, metadata
return metadata
}
}
} }
out := make(map[string]any, 1) out := make(map[string]any, 1)
if len(metadata) > 0 { if len(metadata) > 0 {
@@ -84,8 +81,8 @@ func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel stri
out[k] = v out[k] = v
} }
} }
out[util.ModelMappingOriginalModelMetadataKey] = original out[util.ModelMappingOriginalModelMetadataKey] = upstreamModel
return out return upstreamModel, out
} }
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string { func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {