diff --git a/config.example.yaml b/config.example.yaml index e93d71b6..563dd06c 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -48,6 +48,9 @@ usage-statistics-enabled: false # Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/ proxy-url: "" +# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name). +force-model-prefix: false + # Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504. request-retry: 3 @@ -65,6 +68,7 @@ ws-auth: false # Gemini API keys # gemini-api-key: # - api-key: "AIzaSy...01" +# prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential # base-url: "https://generativelanguage.googleapis.com" # headers: # X-Custom-Header: "custom-value" @@ -79,6 +83,7 @@ ws-auth: false # Codex API keys # codex-api-key: # - api-key: "sk-atSM..." +# prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential # base-url: "https://www.example.com" # use the custom codex API endpoint # headers: # X-Custom-Header: "custom-value" @@ -93,6 +98,7 @@ ws-auth: false # claude-api-key: # - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url # - api-key: "sk-atSM..." +# prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential # base-url: "https://www.example.com" # use the custom claude API endpoint # headers: # X-Custom-Header: "custom-value" @@ -109,6 +115,7 @@ ws-auth: false # OpenAI compatibility providers # openai-compatibility: # - name: "openrouter" # The name of the provider; it will be used in the user agent and other places. +# prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials # base-url: "https://openrouter.ai/api/v1" # The base URL of the provider. # headers: # X-Custom-Header: "custom-value" @@ -123,6 +130,7 @@ ws-auth: false # Vertex API keys (Vertex-compatible endpoints, use API key + base URL) # vertex-api-key: # - api-key: "vk-123..." # x-goog-api-key header +# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential # base-url: "https://example.com/api" # e.g. https://zenmux.ai/api # proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override # headers: diff --git a/internal/api/server.go b/internal/api/server.go index af28e6ad..5ffffa1d 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -230,13 +230,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk envManagementSecret := envAdminPasswordSet && envAdminPassword != "" // Create server instance - providerNames := make([]string, 0, len(cfg.OpenAICompatibility)) - for _, p := range cfg.OpenAICompatibility { - providerNames = append(providerNames, p.Name) - } s := &Server{ engine: engine, - handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames), + handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager), cfg: cfg, accessManager: accessManager, requestLogger: requestLogger, @@ -919,12 +915,6 @@ func (s *Server) UpdateClients(cfg *config.Config) { // Save YAML snapshot for next comparison s.oldConfigYaml, _ = yaml.Marshal(cfg) - providerNames := make([]string, 0, len(cfg.OpenAICompatibility)) - for _, p := range cfg.OpenAICompatibility { - providerNames = append(providerNames, p.Name) - } - s.handlers.OpenAICompatProviders = providerNames - s.handlers.UpdateClients(&cfg.SDKConfig) if !cfg.RemoteManagement.DisableControlPanel { diff --git a/internal/config/config.go b/internal/config/config.go index 2310d7c2..63ac1cb0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -187,6 +187,9 @@ type ClaudeKey struct { // APIKey is the authentication key for accessing Claude API services. APIKey string `yaml:"api-key" json:"api-key"` + // Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4"). + Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` + // BaseURL is the base URL for the Claude API endpoint. // If empty, the default Claude API URL will be used. BaseURL string `yaml:"base-url" json:"base-url"` @@ -219,6 +222,9 @@ type CodexKey struct { // APIKey is the authentication key for accessing Codex API services. APIKey string `yaml:"api-key" json:"api-key"` + // Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex"). + Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` + // BaseURL is the base URL for the Codex API endpoint. // If empty, the default Codex API URL will be used. BaseURL string `yaml:"base-url" json:"base-url"` @@ -239,6 +245,9 @@ type GeminiKey struct { // APIKey is the authentication key for accessing Gemini API services. APIKey string `yaml:"api-key" json:"api-key"` + // Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview"). + Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` + // BaseURL optionally overrides the Gemini API endpoint. BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"` @@ -258,6 +267,9 @@ type OpenAICompatibility struct { // Name is the identifier for this OpenAI compatibility configuration. Name string `yaml:"name" json:"name"` + // Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2"). + Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` + // BaseURL is the base URL for the external OpenAI-compatible API endpoint. BaseURL string `yaml:"base-url" json:"base-url"` @@ -422,6 +434,7 @@ func (cfg *Config) SanitizeOpenAICompatibility() { for i := range cfg.OpenAICompatibility { e := cfg.OpenAICompatibility[i] e.Name = strings.TrimSpace(e.Name) + e.Prefix = normalizeModelPrefix(e.Prefix) e.BaseURL = strings.TrimSpace(e.BaseURL) e.Headers = NormalizeHeaders(e.Headers) if e.BaseURL == "" { @@ -442,6 +455,7 @@ func (cfg *Config) SanitizeCodexKeys() { out := make([]CodexKey, 0, len(cfg.CodexKey)) for i := range cfg.CodexKey { e := cfg.CodexKey[i] + e.Prefix = normalizeModelPrefix(e.Prefix) e.BaseURL = strings.TrimSpace(e.BaseURL) e.Headers = NormalizeHeaders(e.Headers) e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels) @@ -460,6 +474,7 @@ func (cfg *Config) SanitizeClaudeKeys() { } for i := range cfg.ClaudeKey { entry := &cfg.ClaudeKey[i] + entry.Prefix = normalizeModelPrefix(entry.Prefix) entry.Headers = NormalizeHeaders(entry.Headers) entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels) } @@ -479,6 +494,7 @@ func (cfg *Config) SanitizeGeminiKeys() { if entry.APIKey == "" { continue } + entry.Prefix = normalizeModelPrefix(entry.Prefix) entry.BaseURL = strings.TrimSpace(entry.BaseURL) entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) entry.Headers = NormalizeHeaders(entry.Headers) @@ -492,6 +508,18 @@ func (cfg *Config) SanitizeGeminiKeys() { cfg.GeminiKey = out } +func normalizeModelPrefix(prefix string) string { + trimmed := strings.TrimSpace(prefix) + trimmed = strings.Trim(trimmed, "/") + if trimmed == "" { + return "" + } + if strings.Contains(trimmed, "/") { + return "" + } + return trimmed +} + func syncInlineAccessProvider(cfg *Config) { if cfg == nil { return diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go index 1257dd62..a14f75bc 100644 --- a/internal/config/vertex_compat.go +++ b/internal/config/vertex_compat.go @@ -13,6 +13,9 @@ type VertexCompatKey struct { // Maps to the x-goog-api-key header. APIKey string `yaml:"api-key" json:"api-key"` + // Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro"). + Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` + // BaseURL is the base URL for the Vertex-compatible API endpoint. // The executor will append "/v1/publishers/google/models/{model}:action" to this. // Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..." @@ -53,6 +56,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() { if entry.APIKey == "" { continue } + entry.Prefix = normalizeModelPrefix(entry.Prefix) entry.BaseURL = strings.TrimSpace(entry.BaseURL) if entry.BaseURL == "" { // BaseURL is required for Vertex API key entries diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 43a3a3dc..68ff5394 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -183,7 +183,7 @@ func (w *Watcher) Start(ctx context.Context) error { go w.processEvents(ctx) // Perform an initial full reload based on current config and auth dir - w.reloadClients(true, nil) + w.reloadClients(true, nil, false) return nil } @@ -276,7 +276,7 @@ func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool { return true } -func (w *Watcher) refreshAuthState() { +func (w *Watcher) refreshAuthState(force bool) { auths := w.SnapshotCoreAuths() w.clientsMutex.Lock() if len(w.runtimeAuths) > 0 { @@ -286,12 +286,12 @@ func (w *Watcher) refreshAuthState() { } } } - updates := w.prepareAuthUpdatesLocked(auths) + updates := w.prepareAuthUpdatesLocked(auths, force) w.clientsMutex.Unlock() w.dispatchAuthUpdates(updates) } -func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth) []AuthUpdate { +func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate { newState := make(map[string]*coreauth.Auth, len(auths)) for _, auth := range auths { if auth == nil || auth.ID == "" { @@ -318,7 +318,7 @@ func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth) []AuthUpdate for id, auth := range newState { if existing, ok := w.currentAuths[id]; !ok { updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()}) - } else if !authEqual(existing, auth) { + } else if force || !authEqual(existing, auth) { updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: auth.Clone()}) } } @@ -949,15 +949,16 @@ func (w *Watcher) reloadConfig() bool { } authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir + forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix log.Infof("config successfully reloaded, triggering client reload") // Reload clients with new config - w.reloadClients(authDirChanged, affectedOAuthProviders) + w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh) return true } // reloadClients performs a full scan and reload of all clients. -func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string) { +func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) { log.Debugf("starting full client load process") w.clientsMutex.RLock() @@ -1048,7 +1049,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.reloadCallback(cfg) } - w.refreshAuthState() + w.refreshAuthState(forceAuthRefresh) log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", totalNewClients, @@ -1099,7 +1100,7 @@ func (w *Watcher) addOrUpdateClient(path string) { w.clientsMutex.Unlock() // Unlock before the callback - w.refreshAuthState() + w.refreshAuthState(false) if w.reloadCallback != nil { log.Debugf("triggering server update callback after add/update") @@ -1118,7 +1119,7 @@ func (w *Watcher) removeClient(path string) { w.clientsMutex.Unlock() // Release the lock before the callback - w.refreshAuthState() + w.refreshAuthState(false) if w.reloadCallback != nil { log.Debugf("triggering server update callback after removal") @@ -1147,6 +1148,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if key == "" { continue } + prefix := strings.TrimSpace(entry.Prefix) base := strings.TrimSpace(entry.BaseURL) proxyURL := strings.TrimSpace(entry.ProxyURL) id, token := idGen.next("gemini:apikey", key, base) @@ -1162,6 +1164,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { ID: id, Provider: "gemini", Label: "gemini-apikey", + Prefix: prefix, Status: coreauth.StatusActive, ProxyURL: proxyURL, Attributes: attrs, @@ -1179,6 +1182,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if key == "" { continue } + prefix := strings.TrimSpace(ck.Prefix) base := strings.TrimSpace(ck.BaseURL) id, token := idGen.next("claude:apikey", key, base) attrs := map[string]string{ @@ -1197,6 +1201,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { ID: id, Provider: "claude", Label: "claude-apikey", + Prefix: prefix, Status: coreauth.StatusActive, ProxyURL: proxyURL, Attributes: attrs, @@ -1213,6 +1218,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { if key == "" { continue } + prefix := strings.TrimSpace(ck.Prefix) id, token := idGen.next("codex:apikey", key, ck.BaseURL) attrs := map[string]string{ "source": fmt.Sprintf("config:codex[%s]", token), @@ -1227,6 +1233,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { ID: id, Provider: "codex", Label: "codex-apikey", + Prefix: prefix, Status: coreauth.StatusActive, ProxyURL: proxyURL, Attributes: attrs, @@ -1238,6 +1245,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } for i := range cfg.OpenAICompatibility { compat := &cfg.OpenAICompatibility[i] + prefix := strings.TrimSpace(compat.Prefix) providerName := strings.ToLower(strings.TrimSpace(compat.Name)) if providerName == "" { providerName = "openai-compatibility" @@ -1269,6 +1277,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { ID: id, Provider: providerName, Label: compat.Name, + Prefix: prefix, Status: coreauth.StatusActive, ProxyURL: proxyURL, Attributes: attrs, @@ -1295,6 +1304,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { ID: id, Provider: providerName, Label: compat.Name, + Prefix: prefix, Status: coreauth.StatusActive, Attributes: attrs, CreatedAt: now, @@ -1312,6 +1322,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { base := strings.TrimSpace(compat.BaseURL) key := strings.TrimSpace(compat.APIKey) + prefix := strings.TrimSpace(compat.Prefix) proxyURL := strings.TrimSpace(compat.ProxyURL) idKind := fmt.Sprintf("vertex:apikey:%s", base) id, token := idGen.next(idKind, key, base, proxyURL) @@ -1331,6 +1342,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { ID: id, Provider: providerName, Label: "vertex-apikey", + Prefix: prefix, Status: coreauth.StatusActive, ProxyURL: proxyURL, Attributes: attrs, @@ -1383,10 +1395,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { proxyURL = p } + prefix := "" + if rawPrefix, ok := metadata["prefix"].(string); ok { + trimmed := strings.TrimSpace(rawPrefix) + trimmed = strings.Trim(trimmed, "/") + if trimmed != "" && !strings.Contains(trimmed, "/") { + prefix = trimmed + } + } + a := &coreauth.Auth{ ID: id, Provider: provider, Label: label, + Prefix: prefix, Status: coreauth.StatusActive, Attributes: map[string]string{ "source": full, @@ -1473,6 +1495,7 @@ func synthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]an Attributes: attrs, Metadata: metadataCopy, ProxyURL: primary.ProxyURL, + Prefix: primary.Prefix, CreatedAt: now, UpdatedAt: now, Runtime: geminicli.NewVirtualCredential(projectID, shared), @@ -1742,6 +1765,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if oldCfg.WebsocketAuth != newCfg.WebsocketAuth { changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth)) } + if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix { + changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix)) + } // Quota-exceeded behavior if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject { diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index a17e54aa..e5b4fc93 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -49,9 +49,6 @@ type BaseAPIHandler struct { // Cfg holds the current application configuration. Cfg *config.SDKConfig - - // OpenAICompatProviders is a list of provider names for OpenAI compatibility. - OpenAICompatProviders []string } // NewBaseAPIHandlers creates a new API handlers instance. @@ -63,11 +60,10 @@ type BaseAPIHandler struct { // // Returns: // - *BaseAPIHandler: A new API handlers instance -func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler { +func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler { return &BaseAPIHandler{ - Cfg: cfg, - AuthManager: authManager, - OpenAICompatProviders: openAICompatProviders, + Cfg: cfg, + AuthManager: authManager, } } @@ -342,30 +338,19 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string // Resolve "auto" model to an actual available model first resolvedModelName := util.ResolveAutoModel(modelName) - providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName) - - targetModelName := resolvedModelName - if isDynamic { - targetModelName = extractedModelName - } - // Normalize the model name to handle dynamic thinking suffixes before determining the provider. - normalizedModel, metadata = normalizeModelMetadata(targetModelName) + normalizedModel, metadata = normalizeModelMetadata(resolvedModelName) - if isDynamic { - providers = []string{providerName} - } else { - // For non-dynamic models, use the normalizedModel to get the provider name. - providers = util.GetProviderName(normalizedModel) - if len(providers) == 0 && metadata != nil { - if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok { - if originalModel, okStr := originalRaw.(string); okStr { - originalModel = strings.TrimSpace(originalModel) - if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) { - if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 { - providers = altProviders - normalizedModel = originalModel - } + // Use the normalizedModel to get the provider name. + providers = util.GetProviderName(normalizedModel) + if len(providers) == 0 && metadata != nil { + if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok { + if originalModel, okStr := originalRaw.(string); okStr { + originalModel = strings.TrimSpace(originalModel) + if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) { + if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 { + providers = altProviders + normalizedModel = originalModel } } } @@ -383,30 +368,6 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string return providers, normalizedModel, metadata, nil } -func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) { - var providerPart, modelPart string - for _, sep := range []string{"://"} { - if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 { - providerPart = parts[0] - modelPart = parts[1] - break - } - } - - if providerPart == "" { - return "", modelName, false - } - - // Check if the provider is a configured openai-compatibility provider - for _, pName := range h.OpenAICompatProviders { - if pName == providerPart { - return providerPart, modelPart, true - } - } - - return "", modelName, false -} - func cloneBytes(src []byte) []byte { if len(src) == 0 { return nil diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index 9f247bb9..c345cd15 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -363,10 +363,11 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req if provider == "" { return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"} } + routeModel := req.Model tried := make(map[string]struct{}) var lastErr error for { - auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried) + auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried) if errPick != nil { if lastErr != nil { return cliproxyexecutor.Response{}, lastErr @@ -396,8 +397,10 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - resp, errExec := executor.Execute(execCtx, auth, req, opts) - result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil} + execReq := req + execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + resp, errExec := executor.Execute(execCtx, auth, execReq, opts) + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { result.Error = &Error{Message: errExec.Error()} var se cliproxyexecutor.StatusError @@ -420,10 +423,11 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, if provider == "" { return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"} } + routeModel := req.Model tried := make(map[string]struct{}) var lastErr error for { - auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried) + auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried) if errPick != nil { if lastErr != nil { return cliproxyexecutor.Response{}, lastErr @@ -453,8 +457,10 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - resp, errExec := executor.CountTokens(execCtx, auth, req, opts) - result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil} + execReq := req + execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts) + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil} if errExec != nil { result.Error = &Error{Message: errExec.Error()} var se cliproxyexecutor.StatusError @@ -477,10 +483,11 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string if provider == "" { return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"} } + routeModel := req.Model tried := make(map[string]struct{}) var lastErr error for { - auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried) + auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried) if errPick != nil { if lastErr != nil { return nil, lastErr @@ -510,14 +517,16 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) } - chunks, errStream := executor.ExecuteStream(execCtx, auth, req, opts) + execReq := req + execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth) + chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts) if errStream != nil { rerr := &Error{Message: errStream.Error()} var se cliproxyexecutor.StatusError if errors.As(errStream, &se) && se != nil { rerr.HTTPStatus = se.StatusCode() } - result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: false, Error: rerr} + result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr} result.RetryAfter = retryAfterFromError(errStream) m.MarkResult(execCtx, result) lastErr = errStream @@ -535,18 +544,66 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string if errors.As(chunk.Err, &se) && se != nil { rerr.HTTPStatus = se.StatusCode() } - m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: false, Error: rerr}) + m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr}) } out <- chunk } if !failed { - m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: true}) + m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true}) } }(execCtx, auth.Clone(), provider, chunks) return out, nil } } +func rewriteModelForAuth(model string, metadata map[string]any, auth *Auth) (string, map[string]any) { + if auth == nil || model == "" { + return model, metadata + } + prefix := strings.TrimSpace(auth.Prefix) + if prefix == "" { + return model, metadata + } + needle := prefix + "/" + if !strings.HasPrefix(model, needle) { + return model, metadata + } + rewritten := strings.TrimPrefix(model, needle) + return rewritten, stripPrefixFromMetadata(metadata, needle) +} + +func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]any { + if len(metadata) == 0 || needle == "" { + return metadata + } + keys := []string{ + util.ThinkingOriginalModelMetadataKey, + util.GeminiOriginalModelMetadataKey, + } + var out map[string]any + for _, key := range keys { + raw, ok := metadata[key] + if !ok { + continue + } + value, okStr := raw.(string) + if !okStr || !strings.HasPrefix(value, needle) { + continue + } + if out == nil { + out = make(map[string]any, len(metadata)) + for k, v := range metadata { + out[k] = v + } + } + out[key] = strings.TrimPrefix(value, needle) + } + if out == nil { + return metadata + } + return out +} + func (m *Manager) normalizeProviders(providers []string) []string { if len(providers) == 0 { return nil diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go index efba6981..5a2d216d 100644 --- a/sdk/cliproxy/auth/types.go +++ b/sdk/cliproxy/auth/types.go @@ -19,6 +19,8 @@ type Auth struct { Index uint64 `json:"-"` // Provider is the upstream provider key (e.g. "gemini", "claude"). Provider string `json:"provider"` + // Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview"). + Prefix string `json:"prefix,omitempty"` // FileName stores the relative or absolute path of the backing auth file. FileName string `json:"-"` // Storage holds the token persistence implementation used during login flows. diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 1ef829d1..f3cbf484 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -787,7 +787,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { if providerKey == "" { providerKey = "openai-compatibility" } - GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms) + GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix)) } else { // Ensure stale registrations are cleared when model list becomes empty. GlobalModelRegistry().UnregisterClient(a.ID) @@ -807,7 +807,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { if key == "" { key = strings.ToLower(strings.TrimSpace(a.Provider)) } - GlobalModelRegistry().RegisterClient(a.ID, key, models) + GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix)) return } @@ -987,6 +987,48 @@ func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo { return filtered } +func applyModelPrefixes(models []*ModelInfo, prefix string, forceModelPrefix bool) []*ModelInfo { + trimmedPrefix := strings.TrimSpace(prefix) + if trimmedPrefix == "" || len(models) == 0 { + return models + } + + out := make([]*ModelInfo, 0, len(models)*2) + seen := make(map[string]struct{}, len(models)*2) + + addModel := func(model *ModelInfo) { + if model == nil { + return + } + id := strings.TrimSpace(model.ID) + if id == "" { + return + } + if _, exists := seen[id]; exists { + return + } + seen[id] = struct{}{} + out = append(out, model) + } + + for _, model := range models { + if model == nil { + continue + } + baseID := strings.TrimSpace(model.ID) + if baseID == "" { + continue + } + if !forceModelPrefix || trimmedPrefix == baseID { + addModel(model) + } + clone := *model + clone.ID = trimmedPrefix + "/" + baseID + addModel(&clone) + } + return out +} + // matchWildcard performs case-insensitive wildcard matching where '*' matches any substring. func matchWildcard(pattern, value string) bool { if pattern == "" { diff --git a/sdk/config/config.go b/sdk/config/config.go index acb340ef..f6f20d5c 100644 --- a/sdk/config/config.go +++ b/sdk/config/config.go @@ -9,6 +9,11 @@ type SDKConfig struct { // ProxyURL is the URL of an optional proxy server to use for outbound requests. ProxyURL string `yaml:"proxy-url" json:"proxy-url"` + // ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview") + // to target prefixed credentials. When false, unprefixed model requests may use prefixed + // credentials as well. + ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"` + // RequestLog enables or disables detailed request logging functionality. RequestLog bool `yaml:"request-log" json:"request-log"`