diff --git a/config.example.yaml b/config.example.yaml index 73e2a8ac..2a35fe68 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -90,6 +90,9 @@ ws-auth: false # headers: # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" +# models: +# - name: "gemini-2.5-flash" # upstream model name +# alias: "gemini-flash" # client alias mapped to the upstream model # excluded-models: # - "gemini-2.5-pro" # exclude specific models from this provider (exact match) # - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro) @@ -106,7 +109,7 @@ ws-auth: false # X-Custom-Header: "custom-value" # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # models: -# - name: "gpt-5-codex" # upstream model name +# - name: "gpt-5-codex" # upstream model name # alias: "codex-latest" # client alias mapped to the upstream model # excluded-models: # - "gpt-5.1" # exclude specific models (exact match) @@ -125,7 +128,7 @@ ws-auth: false # proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override # models: # - name: "claude-3-5-sonnet-20241022" # upstream model name -# alias: "claude-sonnet-latest" # client alias mapped to the upstream model +# alias: "claude-sonnet-latest" # client alias mapped to the upstream model # excluded-models: # - "claude-opus-4-5-20251101" # exclude specific models (exact match) # - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219) diff --git a/internal/config/config.go b/internal/config/config.go index 760be600..668764d9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -268,6 +268,9 @@ type ClaudeModel struct { Alias string `yaml:"alias" json:"alias"` } +func (m ClaudeModel) GetName() string { return m.Name } +func (m ClaudeModel) GetAlias() string { return m.Alias } + // CodexKey represents the configuration for a Codex API key, // including the API key itself and an optional base URL for the API endpoint. type CodexKey struct { @@ -303,6 +306,9 @@ type CodexModel struct { Alias string `yaml:"alias" json:"alias"` } +func (m CodexModel) GetName() string { return m.Name } +func (m CodexModel) GetAlias() string { return m.Alias } + // GeminiKey represents the configuration for a Gemini API key, // including optional overrides for upstream base URL, proxy routing, and headers. type GeminiKey struct { @@ -318,6 +324,9 @@ type GeminiKey struct { // ProxyURL optionally overrides the global proxy for this API key. ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + // Models defines upstream model names and aliases for request routing. + Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"` + // Headers optionally adds extra HTTP headers for requests sent with this key. Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` @@ -325,6 +334,18 @@ type GeminiKey struct { ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"` } +// GeminiModel describes a mapping between an alias and the actual upstream model name. +type GeminiModel struct { + // Name is the upstream model identifier used when issuing requests. + Name string `yaml:"name" json:"name"` + + // Alias is the client-facing model name that maps to Name. + Alias string `yaml:"alias" json:"alias"` +} + +func (m GeminiModel) GetName() string { return m.Name } +func (m GeminiModel) GetAlias() string { return m.Alias } + // OpenAICompatibility represents the configuration for OpenAI API compatibility // with external providers, allowing model aliases to be routed through OpenAI API format. type OpenAICompatibility struct { diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go index a14f75bc..94e162b7 100644 --- a/internal/config/vertex_compat.go +++ b/internal/config/vertex_compat.go @@ -42,6 +42,9 @@ type VertexCompatModel struct { Alias string `yaml:"alias" json:"alias"` } +func (m VertexCompatModel) GetName() string { return m.Name } +func (m VertexCompatModel) GetAlias() string { return m.Alias } + // SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials. func (cfg *Config) SanitizeVertexCompatKeys() { if cfg == nil { diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 1c51e898..ed4d1c21 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -781,3 +781,29 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { "gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000}, } } + +// LookupStaticModelInfo searches all static model definitions for a model by ID. +// Returns nil if no matching model is found. +func LookupStaticModelInfo(modelID string) *ModelInfo { + if modelID == "" { + return nil + } + allModels := [][]*ModelInfo{ + GetClaudeModels(), + GetGeminiModels(), + GetGeminiVertexModels(), + GetGeminiCLIModels(), + GetAIStudioModels(), + GetOpenAIModels(), + GetQwenModels(), + GetIFlowModels(), + } + for _, models := range allModels { + for _, m := range models { + if m != nil && m.ID == modelID { + return m + } + } + } + return nil +} diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index f211ba62..da57150d 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -78,6 +78,13 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r defer reporter.trackFailure(ctx, &err) upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { + upstreamModel = modelOverride + } else if !strings.EqualFold(upstreamModel, req.Model) { + if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { + upstreamModel = modelOverride + } + } // Official Gemini API via API key or OAuth bearer from := opts.SourceFormat @@ -174,6 +181,13 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A defer reporter.trackFailure(ctx, &err) upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { + upstreamModel = modelOverride + } else if !strings.EqualFold(upstreamModel, req.Model) { + if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { + upstreamModel = modelOverride + } + } from := opts.SourceFormat to := sdktranslator.FromString("gemini") @@ -287,6 +301,15 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { apiKey, bearer := geminiCreds(auth) + upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata) + if modelOverride := e.resolveUpstreamModel(upstreamModel, auth); modelOverride != "" { + upstreamModel = modelOverride + } else if !strings.EqualFold(upstreamModel, req.Model) { + if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { + upstreamModel = modelOverride + } + } + from := opts.SourceFormat to := sdktranslator.FromString("gemini") translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) @@ -297,9 +320,10 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel) baseURL := resolveGeminiBaseURL(auth) - url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "countTokens") + url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, upstreamModel, "countTokens") requestBody := bytes.NewReader(translatedReq) @@ -398,6 +422,90 @@ func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string { return base } +func (e *GeminiExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string { + trimmed := strings.TrimSpace(alias) + if trimmed == "" { + return "" + } + + entry := e.resolveGeminiConfig(auth) + if entry == nil { + return "" + } + + normalizedModel, metadata := util.NormalizeThinkingModel(trimmed) + + // Candidate names to match against configured aliases/names. + candidates := []string{strings.TrimSpace(normalizedModel)} + if !strings.EqualFold(normalizedModel, trimmed) { + candidates = append(candidates, trimmed) + } + if original := util.ResolveOriginalModel(normalizedModel, metadata); original != "" && !strings.EqualFold(original, normalizedModel) { + candidates = append(candidates, original) + } + + for i := range entry.Models { + model := entry.Models[i] + name := strings.TrimSpace(model.Name) + modelAlias := strings.TrimSpace(model.Alias) + + for _, candidate := range candidates { + if candidate == "" { + continue + } + if modelAlias != "" && strings.EqualFold(modelAlias, candidate) { + if name != "" { + return name + } + return candidate + } + if name != "" && strings.EqualFold(name, candidate) { + return name + } + } + } + return "" +} + +func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey { + if auth == nil || e.cfg == nil { + return nil + } + var attrKey, attrBase string + if auth.Attributes != nil { + attrKey = strings.TrimSpace(auth.Attributes["api_key"]) + attrBase = strings.TrimSpace(auth.Attributes["base_url"]) + } + for i := range e.cfg.GeminiKey { + entry := &e.cfg.GeminiKey[i] + cfgKey := strings.TrimSpace(entry.APIKey) + cfgBase := strings.TrimSpace(entry.BaseURL) + if attrKey != "" && attrBase != "" { + if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) { + return entry + } + continue + } + if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { + if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { + return entry + } + } + if attrKey != "" { + for i := range e.cfg.GeminiKey { + entry := &e.cfg.GeminiKey[i] + if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) { + return entry + } + } + } + return nil +} + func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) { var attrs map[string]string if auth != nil { diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index ae56e4b6..21690f8e 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -710,6 +710,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { case "gemini": models = registry.GetGeminiModels() if entry := s.resolveConfigGeminiKey(a); entry != nil { + if len(entry.Models) > 0 { + models = buildGeminiConfigModels(entry) + } if authKind == "apikey" { excluded = entry.ExcludedModels } @@ -1116,17 +1119,22 @@ func matchWildcard(pattern, value string) bool { return true } -func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo { - if entry == nil || len(entry.Models) == 0 { +type modelEntry interface { + GetName() string + GetAlias() string +} + +func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo { + if len(models) == 0 { return nil } now := time.Now().Unix() - out := make([]*ModelInfo, 0, len(entry.Models)) - seen := make(map[string]struct{}, len(entry.Models)) - for i := range entry.Models { - model := entry.Models[i] - name := strings.TrimSpace(model.Name) - alias := strings.TrimSpace(model.Alias) + out := make([]*ModelInfo, 0, len(models)) + seen := make(map[string]struct{}, len(models)) + for i := range models { + model := models[i] + name := strings.TrimSpace(model.GetName()) + alias := strings.TrimSpace(model.GetAlias()) if alias == "" { alias = name } @@ -1142,18 +1150,52 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo { if display == "" { display = alias } - out = append(out, &ModelInfo{ + info := &ModelInfo{ ID: alias, Object: "model", Created: now, - OwnedBy: "vertex", - Type: "vertex", + OwnedBy: ownedBy, + Type: modelType, DisplayName: display, - }) + } + if name != "" { + if upstream := registry.LookupStaticModelInfo(name); upstream != nil && upstream.Thinking != nil { + info.Thinking = upstream.Thinking + } + } + out = append(out, info) } return out } +func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo { + if entry == nil { + return nil + } + return buildConfigModels(entry.Models, "google", "vertex") +} + +func buildGeminiConfigModels(entry *config.GeminiKey) []*ModelInfo { + if entry == nil { + return nil + } + return buildConfigModels(entry.Models, "google", "gemini") +} + +func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { + if entry == nil { + return nil + } + return buildConfigModels(entry.Models, "anthropic", "claude") +} + +func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo { + if entry == nil { + return nil + } + return buildConfigModels(entry.Models, "openai", "openai") +} + func rewriteModelInfoName(name, oldID, newID string) string { trimmed := strings.TrimSpace(name) if trimmed == "" { @@ -1240,79 +1282,3 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode } return out } - -func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo { - if entry == nil || len(entry.Models) == 0 { - return nil - } - now := time.Now().Unix() - out := make([]*ModelInfo, 0, len(entry.Models)) - seen := make(map[string]struct{}, len(entry.Models)) - for i := range entry.Models { - model := entry.Models[i] - name := strings.TrimSpace(model.Name) - alias := strings.TrimSpace(model.Alias) - if alias == "" { - alias = name - } - if alias == "" { - continue - } - key := strings.ToLower(alias) - if _, exists := seen[key]; exists { - continue - } - seen[key] = struct{}{} - display := name - if display == "" { - display = alias - } - out = append(out, &ModelInfo{ - ID: alias, - Object: "model", - Created: now, - OwnedBy: "claude", - Type: "claude", - DisplayName: display, - }) - } - return out -} - -func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo { - if entry == nil || len(entry.Models) == 0 { - return nil - } - now := time.Now().Unix() - out := make([]*ModelInfo, 0, len(entry.Models)) - seen := make(map[string]struct{}, len(entry.Models)) - for i := range entry.Models { - model := entry.Models[i] - name := strings.TrimSpace(model.Name) - alias := strings.TrimSpace(model.Alias) - if alias == "" { - alias = name - } - if alias == "" { - continue - } - key := strings.ToLower(alias) - if _, exists := seen[key]; exists { - continue - } - seen[key] = struct{}{} - display := name - if display == "" { - display = alias - } - out = append(out, &ModelInfo{ - ID: alias, - Object: "model", - Created: now, - OwnedBy: "openai", - Type: "openai", - DisplayName: display, - }) - } - return out -}