Files
CLIProxyAPI/sdk/cliproxy/auth/conductor.go
VooDisss 6962e09dd9 fix(auth): scope affinity by provider
Keep sticky auth affinity limited to matching providers and stop persisting execution-session IDs as long-lived affinity keys so provider switching and normal streaming traffic do not create incorrect pins or stale affinity state.
2026-03-27 18:52:58 +02:00

3163 lines
88 KiB
Go

package auth
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
)
// ProviderExecutor defines the contract required by Manager to execute provider calls.
// Manager keeps one executor per provider key; the key is the trimmed value
// returned by Identifier (see RegisterExecutor).
type ProviderExecutor interface {
// Identifier returns the provider key handled by this executor.
// RegisterExecutor ignores executors whose identifier is empty after trimming.
Identifier() string
// Execute handles non-streaming execution and returns the provider response payload.
Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
// ExecuteStream handles streaming execution and returns a StreamResult containing
// upstream headers and a channel of provider chunks.
// Manager reads the channel until the first chunk with a non-empty payload
// before handing the stream to the caller (see readStreamBootstrap).
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error)
// Refresh attempts to refresh provider credentials and returns the updated auth state.
Refresh(ctx context.Context, auth *Auth) (*Auth, error)
// CountTokens returns the token count for the given request.
CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
// HttpRequest injects provider credentials into the supplied HTTP request and executes it.
// Callers must close the response body when non-nil.
HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
}
// ExecutionSessionCloser allows executors to release per-session runtime resources.
// It is an optional interface: RegisterExecutor type-asserts a replaced executor
// against it and, when satisfied, asks it to close every session.
type ExecutionSessionCloser interface {
// CloseExecutionSession releases the resources tied to sessionID.
// The special value CloseAllExecutionSessionsID requests release of all sessions.
CloseExecutionSession(sessionID string)
}
const (
// CloseAllExecutionSessionsID asks an executor to release all active execution sessions.
// Executors that do not support this marker may ignore it.
// RegisterExecutor passes it to a replaced executor that implements ExecutionSessionCloser.
CloseAllExecutionSessionsID = "__all_execution_sessions__"
)
// RefreshEvaluator allows runtime state to override refresh decisions.
// NOTE(review): the code that consults this interface is outside this section —
// confirm how its answer is combined with the default refresh rules.
type RefreshEvaluator interface {
// ShouldRefresh reports whether auth should be refreshed as of now.
ShouldRefresh(now time.Time, auth *Auth) bool
}
// Tuning constants for credential refresh and quota cooldown.
// NOTE(review): except for refreshMaxConcurrency, the consumers of these values
// are outside this section; the per-constant notes are inferred from the names
// and should be confirmed against the refresh loop and MarkResult.
const (
// presumably the polling period of the auto-refresh loop — confirm
refreshCheckInterval = 5 * time.Second
// capacity of Manager.refreshSemaphore (see NewManager)
refreshMaxConcurrency = 16
// presumably the delay before re-checking an auth whose refresh is in flight — confirm
refreshPendingBackoff = time.Minute
// presumably the delay before retrying a failed refresh — confirm
refreshFailureBackoff = 5 * time.Minute
// presumably the first step of the quota cooldown backoff — confirm
quotaBackoffBase = time.Second
// presumably the upper bound of the quota cooldown backoff — confirm
quotaBackoffMax = 30 * time.Minute
)
// quotaCooldownDisabled is the process-wide switch written by
// SetQuotaCooldownDisabled and read by quotaCooldownDisabledForAuth when an
// auth carries no override of its own.
var quotaCooldownDisabled atomic.Bool
// SetQuotaCooldownDisabled toggles quota cooldown scheduling globally.
// Passing true disables cooldown; the value is stored atomically, so the
// function may be called while requests are in flight.
func SetQuotaCooldownDisabled(disable bool) {
quotaCooldownDisabled.Store(disable)
}
// quotaCooldownDisabledForAuth reports whether quota cooldown is disabled for
// auth. A per-auth override (Auth.DisableCoolingOverride) takes precedence;
// a nil auth or an auth without an override falls back to the global switch.
func quotaCooldownDisabledForAuth(auth *Auth) bool {
	if auth == nil {
		return quotaCooldownDisabled.Load()
	}
	override, hasOverride := auth.DisableCoolingOverride()
	if hasOverride {
		return override
	}
	return quotaCooldownDisabled.Load()
}
// Result captures execution outcome used to adjust auth state.
// Manager builds one Result per upstream attempt and hands it to MarkResult.
type Result struct {
// AuthID references the auth that produced this result.
AuthID string
// Provider is copied for convenience when emitting hooks.
Provider string
// Model is the upstream model identifier used for the request.
// For pooled aliases this is the concrete upstream model; otherwise the
// routed model name is preferred (see executionResultModel).
Model string
// Success marks whether the execution succeeded.
Success bool
// RetryAfter carries a provider supplied retry hint (e.g. 429 retryDelay).
// It is nil when the provider gave no hint.
RetryAfter *time.Duration
// Error describes the failure when Success is false.
Error *Error
}
// Selector chooses an auth candidate for execution.
// NewManager and SetSelector fall back to *RoundRobinSelector when given nil.
type Selector interface {
// Pick returns one auth out of auths for the given provider and model, or an
// error when none can be chosen.
Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error)
}
// Hook captures lifecycle callbacks for observing auth changes.
// Register and Update pass a clone of the auth to the hook, so implementations
// receive their own copy.
type Hook interface {
// OnAuthRegistered fires when a new auth is registered.
OnAuthRegistered(ctx context.Context, auth *Auth)
// OnAuthUpdated fires when an existing auth changes state.
OnAuthUpdated(ctx context.Context, auth *Auth)
// OnResult fires when execution result is recorded.
OnResult(ctx context.Context, result Result)
}
// NoopHook provides optional hook defaults.
// Every method is a no-op; NewManager installs it when the caller passes a nil Hook.
type NoopHook struct{}
// OnAuthRegistered implements Hook and does nothing.
func (NoopHook) OnAuthRegistered(context.Context, *Auth) {}
// OnAuthUpdated implements Hook and does nothing.
func (NoopHook) OnAuthUpdated(context.Context, *Auth) {}
// OnResult implements Hook and does nothing.
func (NoopHook) OnResult(context.Context, Result) {}
// Manager orchestrates auth lifecycle, selection, execution, and persistence.
// Construct it with NewManager so the maps and atomic values are initialized.
type Manager struct {
// store is the persistence backend; replaced by SetStore and read by Load.
store Store
// executors maps a provider key (trimmed ProviderExecutor.Identifier) to its executor.
executors map[string]ProviderExecutor
// selector is the auth selection strategy; replaced by SetSelector.
selector Selector
// hook receives lifecycle callbacks; never nil after NewManager.
hook Hook
// mu is held by the methods in this section when touching store, executors,
// selector, rtProvider, auths and modelPoolOffsets.
mu sync.RWMutex
// auths holds the manager's own clone of every registered auth, keyed by Auth.ID.
auths map[string]*Auth
// scheduler is created from the selector in NewManager and kept in sync on
// Register, Update, Load and SetSelector.
scheduler *authScheduler
// affinityMu and affinity are not used in this section.
// NOTE(review): presumably the sticky auth-affinity table and its lock — confirm key/value meaning.
affinityMu sync.RWMutex
affinity map[string]string
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
providerOffsets map[string]int
// Retry controls request retry behavior.
// All three are written together by SetRetryConfig; maxRetryInterval stores nanoseconds.
requestRetry atomic.Int32
maxRetryCredentials atomic.Int32
maxRetryInterval atomic.Int64
// oauthModelAlias stores global OAuth model alias mappings (alias -> upstream name) keyed by channel.
oauthModelAlias atomic.Value
// apiKeyModelAlias caches resolved model alias mappings for API-key auths.
// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
apiKeyModelAlias atomic.Value
// modelPoolOffsets tracks per-auth alias pool rotation state.
// Keys are produced by openAICompatModelPoolKey; see nextModelPoolOffset.
modelPoolOffsets map[string]int
// runtimeConfig stores the latest application config for request-time decisions.
// It is initialized in NewManager; never Load() before first Store().
runtimeConfig atomic.Value
// Optional HTTP RoundTripper provider injected by host.
rtProvider RoundTripperProvider
// Auto refresh state
// refreshSemaphore is created with capacity refreshMaxConcurrency in NewManager.
refreshCancel context.CancelFunc
refreshSemaphore chan struct{}
}
// NewManager builds a Manager around store. A nil selector defaults to
// *RoundRobinSelector and a nil hook to NoopHook. All maps, the refresh
// semaphore, the scheduler and the atomic config/alias snapshots are
// initialized so the returned manager is immediately usable.
func NewManager(store Store, selector Selector, hook Hook) *Manager {
	if hook == nil {
		hook = NoopHook{}
	}
	if selector == nil {
		selector = &RoundRobinSelector{}
	}

	m := &Manager{
		store:            store,
		selector:         selector,
		hook:             hook,
		executors:        make(map[string]ProviderExecutor),
		auths:            make(map[string]*Auth),
		affinity:         make(map[string]string),
		providerOffsets:  make(map[string]int),
		modelPoolOffsets: make(map[string]int),
		refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
	}

	// atomic.Value must hold a concrete value before the first Load.
	m.runtimeConfig.Store(&internalconfig.Config{})
	m.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))

	m.scheduler = newAuthScheduler(selector)
	return m
}
// isBuiltInSelector reports whether selector is one of the selectors shipped
// with this package (*RoundRobinSelector or *FillFirstSelector).
func isBuiltInSelector(selector Selector) bool {
	if _, isRoundRobin := selector.(*RoundRobinSelector); isRoundRobin {
		return true
	}
	_, isFillFirst := selector.(*FillFirstSelector)
	return isFillFirst
}
// syncSchedulerFromSnapshot rebuilds the scheduler from the supplied auth
// snapshot. It is a no-op on a nil manager or when no scheduler is attached.
func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) {
	if m != nil && m.scheduler != nil {
		m.scheduler.rebuild(auths)
	}
}
// syncScheduler rebuilds the scheduler from a fresh snapshot of the manager's
// auths. It is a no-op on a nil manager or when no scheduler is attached.
func (m *Manager) syncScheduler() {
	if m != nil && m.scheduler != nil {
		snapshot := m.snapshotAuths()
		m.syncSchedulerFromSnapshot(snapshot)
	}
}
// RefreshSchedulerEntry pushes the current state of one auth back into the
// scheduler so its supported-model set is recomputed from the global model
// registry. Call it after models were registered for a newly added auth: the
// upsert performed by Register/Update happens before registerModelsForAuth and
// therefore captures an empty model set.
//
// Unknown IDs, an empty ID, a nil manager or a missing scheduler are ignored.
func (m *Manager) RefreshSchedulerEntry(authID string) {
	if authID == "" || m == nil || m.scheduler == nil {
		return
	}

	var (
		snapshot *Auth
		found    bool
	)
	m.mu.RLock()
	if current, ok := m.auths[authID]; ok && current != nil {
		// Clone under the read lock so the scheduler never sees a shared pointer.
		snapshot, found = current.Clone(), true
	}
	m.mu.RUnlock()

	if !found {
		return
	}
	m.scheduler.upsertAuth(snapshot)
}
// SetSelector replaces the auth selection strategy. A nil selector resets the
// manager to *RoundRobinSelector. When a scheduler is attached it is given the
// new selector and rebuilt from the current auths.
func (m *Manager) SetSelector(selector Selector) {
	if m == nil {
		return
	}
	next := selector
	if next == nil {
		next = &RoundRobinSelector{}
	}

	m.mu.Lock()
	m.selector = next
	m.mu.Unlock()

	if m.scheduler == nil {
		return
	}
	m.scheduler.setSelector(next)
	m.syncScheduler()
}
// SetStore replaces the persistence backend used by the manager.
// The swap happens under the manager's write lock.
func (m *Manager) SetStore(store Store) {
	m.mu.Lock()
	m.store = store
	m.mu.Unlock()
}
// SetRoundTripperProvider registers the provider that supplies a per-auth
// http.RoundTripper. The assignment happens under the manager's write lock.
func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.rtProvider = p
}
// SetConfig installs cfg as the runtime config snapshot consulted by
// request-time helpers and rebuilds the API-key model alias table from it.
// Callers should pass the latest config on every reload so per-credential
// alias mappings stay in sync. A nil cfg is replaced by an empty config.
func (m *Manager) SetConfig(cfg *internalconfig.Config) {
	if m == nil {
		return
	}
	snapshot := cfg
	if snapshot == nil {
		snapshot = &internalconfig.Config{}
	}
	m.runtimeConfig.Store(snapshot)
	m.rebuildAPIKeyModelAliasFromRuntimeConfig()
}
// lookupAPIKeyUpstreamModel resolves requestedModel to the upstream model
// configured for the API-key auth identified by authID. The lookup key is the
// lower-cased model name without its thinking suffix; the suffix of the
// request is re-applied to the resolved name. It returns "" when the manager
// is nil, either argument is blank, or no alias is configured.
func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) string {
	if m == nil {
		return ""
	}
	id := strings.TrimSpace(authID)
	model := strings.TrimSpace(requestedModel)
	if id == "" || model == "" {
		return ""
	}

	// A failed assertion leaves table nil; indexing a nil map yields no aliases.
	table, _ := m.apiKeyModelAlias.Load().(apiKeyModelAliasTable)
	aliases := table[id]
	if len(aliases) == 0 {
		return ""
	}

	lookup := strings.ToLower(thinking.ParseSuffix(model).ModelName)
	if lookup == "" {
		lookup = strings.ToLower(model)
	}
	upstream := strings.TrimSpace(aliases[lookup])
	if upstream == "" {
		return ""
	}
	return preserveRequestedModelSuffix(model, upstream)
}
// isAPIKeyAuth reports whether auth is an API-key credential, i.e. the account
// kind returned by Auth.AccountInfo equals "api_key" (case-insensitive,
// surrounding whitespace ignored). A nil auth is not an API-key auth.
func isAPIKeyAuth(auth *Auth) bool {
	if auth != nil {
		kind, _ := auth.AccountInfo()
		return strings.EqualFold(strings.TrimSpace(kind), "api_key")
	}
	return false
}
// isOpenAICompatAPIKeyAuth reports whether auth is an API-key credential that
// belongs to an OpenAI-compatible provider: either its Provider is
// "openai-compatibility" or it carries a non-blank "compat_name" attribute.
func isOpenAICompatAPIKeyAuth(auth *Auth) bool {
	switch {
	case !isAPIKeyAuth(auth):
		return false
	case strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility"):
		return true
	default:
		// Indexing a nil Attributes map yields "", which reports false.
		return strings.TrimSpace(auth.Attributes["compat_name"]) != ""
	}
}
// openAICompatProviderKey returns the lower-cased key identifying the
// OpenAI-compatible provider behind auth. Preference order: the
// "provider_key" attribute, the "compat_name" attribute, then auth.Provider.
// A nil auth yields "".
func openAICompatProviderKey(auth *Auth) string {
	if auth == nil {
		return ""
	}
	// Reading from a nil Attributes map is safe and simply yields "".
	for _, attr := range [...]string{"provider_key", "compat_name"} {
		if value := strings.TrimSpace(auth.Attributes[attr]); value != "" {
			return strings.ToLower(value)
		}
	}
	return strings.ToLower(strings.TrimSpace(auth.Provider))
}
// openAICompatModelPoolKey builds the rotation-state key for a model pool:
// "<auth id>|<provider key>|<model without suffix>", all lower-cased.
// auth must be non-nil.
func openAICompatModelPoolKey(auth *Auth, requestedModel string) string {
	model := strings.TrimSpace(thinking.ParseSuffix(requestedModel).ModelName)
	if model == "" {
		model = strings.TrimSpace(requestedModel)
	}
	parts := []string{
		strings.ToLower(strings.TrimSpace(auth.ID)),
		openAICompatProviderKey(auth),
		strings.ToLower(model),
	}
	return strings.Join(parts, "|")
}
// nextModelPoolOffset returns the rotation offset (in [0, size)) to use for
// the pool identified by key and advances the stored counter by one.
// It returns 0 without touching state when the manager is nil, the pool has
// at most one entry, or key is blank.
func (m *Manager) nextModelPoolOffset(key string, size int) int {
	if m == nil || size <= 1 {
		return 0
	}
	poolKey := strings.TrimSpace(key)
	if poolKey == "" {
		return 0
	}

	m.mu.Lock()
	defer m.mu.Unlock()

	if m.modelPoolOffsets == nil {
		m.modelPoolOffsets = make(map[string]int)
	}
	current := m.modelPoolOffsets[poolKey]
	// Restart the counter well before it could overflow a 32-bit int.
	if current >= 2_147_483_640 {
		current = 0
	}
	m.modelPoolOffsets[poolKey] = current + 1
	// size > 1 is guaranteed by the guard at the top.
	return current % size
}
// rotateStrings returns values rotated left by offset positions (modulo the
// length). Slices with fewer than two elements are returned as-is; otherwise
// the result is a fresh slice, so the input is never modified. A non-positive
// offset yields an unrotated copy.
func rotateStrings(values []string, offset int) []string {
	n := len(values)
	if n <= 1 {
		return values
	}
	shift := 0
	if offset > 0 {
		shift = offset % n
	}
	rotated := make([]string, n)
	head := copy(rotated, values[shift:])
	copy(rotated[head:], values[:shift])
	return rotated
}
// resolveOpenAICompatUpstreamModelPool returns the upstream models configured
// behind requestedModel for an OpenAI-compatible API-key auth. It returns nil
// when the manager is nil, auth is not such a credential, the model is blank,
// or no matching provider entry exists in the runtime config.
func (m *Manager) resolveOpenAICompatUpstreamModelPool(auth *Auth, requestedModel string) []string {
	if m == nil || !isOpenAICompatAPIKeyAuth(auth) {
		return nil
	}
	model := strings.TrimSpace(requestedModel)
	if model == "" {
		return nil
	}

	cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
	if cfg == nil {
		cfg = &internalconfig.Config{}
	}

	// Reading from a nil Attributes map is safe and yields empty strings.
	providerKey := strings.TrimSpace(auth.Attributes["provider_key"])
	compatName := strings.TrimSpace(auth.Attributes["compat_name"])

	entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
	if entry == nil {
		return nil
	}
	return resolveModelAliasPoolFromConfigModels(model, asModelAliasEntries(entry.Models))
}
func preserveRequestedModelSuffix(requestedModel, resolved string) string {
return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel))
}
// executionModelCandidates returns the upstream model names to try, in order,
// when executing routeModel with auth. The route model is first rewritten for
// the auth and passed through the OAuth alias table. If an OpenAI-compatible
// model pool matches, the pool is returned (rotated when it has more than one
// entry so consecutive requests start at different members); otherwise the
// single API-key alias resolution, or the rewritten model itself, is returned.
func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string {
	model := m.applyOAuthModelAlias(auth, rewriteModelForAuth(routeModel, auth))

	pool := m.resolveOpenAICompatUpstreamModelPool(auth, model)
	switch len(pool) {
	case 0:
		upstream := m.applyAPIKeyModelAlias(auth, model)
		if strings.TrimSpace(upstream) == "" {
			upstream = model
		}
		return []string{upstream}
	case 1:
		return pool
	default:
		start := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, model), len(pool))
		return rotateStrings(pool, start)
	}
}
// executionResultModel picks the model name under which an execution result
// is recorded. For pooled executions the (trimmed) upstream model wins when
// non-blank; otherwise the trimmed route model is used, falling back to the
// trimmed upstream model when the route model is blank.
func executionResultModel(routeModel, upstreamModel string, pooled bool) string {
	upstream := strings.TrimSpace(upstreamModel)
	if pooled && upstream != "" {
		return upstream
	}
	if route := strings.TrimSpace(routeModel); route != "" {
		return route
	}
	return upstream
}
// filterExecutionModels drops the candidates for which auth is currently
// blocked (per isAuthBlockedForModel, evaluated at a single timestamp) and
// returns the rest in their original order. An empty input yields nil.
func filterExecutionModels(auth *Auth, routeModel string, candidates []string, pooled bool) []string {
	if len(candidates) == 0 {
		return nil
	}
	checkedAt := time.Now()
	allowed := make([]string, 0, len(candidates))
	for i := range candidates {
		candidate := candidates[i]
		// Block state is tracked under the result-model name, not the raw candidate.
		stateModel := executionResultModel(routeModel, candidate, pooled)
		if blocked, _, _ := isAuthBlockedForModel(auth, stateModel, checkedAt); !blocked {
			allowed = append(allowed, candidate)
		}
	}
	return allowed
}
// preparedExecutionModels returns the unblocked upstream models to try for
// routeModel with auth, plus whether the candidates came from a pool (more
// than one candidate before filtering).
func (m *Manager) preparedExecutionModels(auth *Auth, routeModel string) ([]string, bool) {
	all := m.executionModelCandidates(auth, routeModel)
	isPool := len(all) > 1
	usable := filterExecutionModels(auth, routeModel, all, isPool)
	return usable, isPool
}
// prepareExecutionModels is preparedExecutionModels without the pooled flag.
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
	usable, _ := m.preparedExecutionModels(auth, routeModel)
	return usable
}
// discardStreamChunks drains ch in a background goroutine until the channel
// is closed, so an abandoned producer is never blocked on send. A nil channel
// is ignored.
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
	if ch != nil {
		go func() {
			for {
				if _, open := <-ch; !open {
					return
				}
			}
		}()
	}
}
// streamBootstrapError wraps an error raised before a stream produced its
// first payload together with a private copy of the upstream response headers.
type streamBootstrapError struct {
	cause   error
	headers http.Header
}

// cloneHTTPHeader returns a deep copy of headers, or nil when headers is nil.
func cloneHTTPHeader(headers http.Header) http.Header {
	if headers != nil {
		return headers.Clone()
	}
	return nil
}

// newStreamBootstrapError wraps err with a copy of headers. It returns an
// untyped nil when err is nil so callers can compare the result against nil.
func newStreamBootstrapError(err error, headers http.Header) error {
	if err == nil {
		return nil
	}
	wrapped := new(streamBootstrapError)
	wrapped.cause = err
	wrapped.headers = cloneHTTPHeader(headers)
	return wrapped
}

// Error returns the message of the wrapped error, or "" when the receiver or
// its cause is nil.
func (e *streamBootstrapError) Error() string {
	if e != nil && e.cause != nil {
		return e.cause.Error()
	}
	return ""
}

// Unwrap exposes the wrapped error to errors.Is / errors.As.
func (e *streamBootstrapError) Unwrap() error {
	if e != nil {
		return e.cause
	}
	return nil
}

// Headers returns a copy of the captured upstream headers; mutating the
// result does not affect the error. A nil receiver yields nil.
func (e *streamBootstrapError) Headers() http.Header {
	if e != nil {
		return cloneHTTPHeader(e.headers)
	}
	return nil
}
// streamErrorResult builds a StreamResult whose channel delivers exactly one
// chunk carrying err and is then closed. headers are copied into the result.
func streamErrorResult(headers http.Header, err error) *cliproxyexecutor.StreamResult {
	// Capacity 1 lets the single chunk be queued without a receiver.
	chunks := make(chan cliproxyexecutor.StreamChunk, 1)
	chunks <- cliproxyexecutor.StreamChunk{Err: err}
	close(chunks)

	result := &cliproxyexecutor.StreamResult{Chunks: chunks}
	result.Headers = cloneHTTPHeader(headers)
	return result
}
// readStreamBootstrap reads from ch until the first chunk with a non-empty
// payload arrives and returns every chunk read so far.
//
// Returns:
//   - buffered: the chunks consumed, in order (nil on error)
//   - closed:   true when ch was nil or closed before a payload chunk arrived
//   - err:      the first chunk error, or ctx.Err() if ctx was cancelled
//     while waiting
//
// A nil ctx disables cancellation and makes the read purely blocking.
func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) {
	if ch == nil {
		return nil, true, nil
	}

	// A nil done channel never becomes ready, which turns the select below
	// into a plain blocking receive when no context is supplied.
	var done <-chan struct{}
	if ctx != nil {
		done = ctx.Done()
	}

	pending := make([]cliproxyexecutor.StreamChunk, 0, 1)
	for {
		select {
		case <-done:
			return nil, false, ctx.Err()
		case chunk, open := <-ch:
			switch {
			case !open:
				return pending, true, nil
			case chunk.Err != nil:
				return nil, false, chunk.Err
			}
			pending = append(pending, chunk)
			if len(chunk.Payload) > 0 {
				return pending, false, nil
			}
		}
	}
}
// wrapStreamResult returns a StreamResult that replays the already-read
// buffered chunks followed by everything still arriving on remaining, while
// recording the outcome of the stream through MarkResult.
//
// A background goroutine does the forwarding and closes the output channel
// when it finishes. The first chunk carrying an error marks the auth/model as
// failed; if every chunk is forwarded without an error chunk the stream is
// marked successful. When ctx is cancelled mid-stream forwarding stops, the
// rest of remaining is drained in the background, and no success is recorded.
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, resultModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
// Unbuffered: each chunk is handed over only when the consumer is ready.
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
// failed latches after the first error chunk so the failure is recorded once.
var failed bool
// forward turns false once ctx is cancelled; later emits are refused.
forward := true
// emit records a failure for error chunks and then tries to deliver the
// chunk. It returns false when the chunk could not be delivered.
emit := func(chunk cliproxyexecutor.StreamChunk) bool {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr})
}
if !forward {
return false
}
// Without a context there is nothing to cancel on: block until delivered.
if ctx == nil {
out <- chunk
return true
}
select {
case <-ctx.Done():
forward = false
return false
case out <- chunk:
return true
}
}
// Replay what readStreamBootstrap already consumed.
for _, chunk := range buffered {
if ok := emit(chunk); !ok {
// Keep the producer from blocking on a stream nobody reads any more.
discardStreamChunks(remaining)
return
}
}
for chunk := range remaining {
if ok := emit(chunk); !ok {
discardStreamChunks(remaining)
return
}
}
// Reached only when remaining was closed and every chunk was delivered.
if !failed {
m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: true})
}
}()
return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
}
// executeStreamWithModelPool tries each upstream model in execModels, in
// order, until one yields a stream that produces a first payload.
//
// For every attempt the request model is replaced by the candidate and the
// outcome is recorded through MarkResult under the name chosen by
// executionResultModel. The stream is "bootstrapped" (read until the first
// non-empty payload) before it is returned, so failures that surface before
// any data was sent can still fall over to the next candidate.
//
// Returns:
//   - the wrapped stream on the first successful bootstrap;
//   - ctx.Err() as soon as the context is found cancelled after a failure;
//   - the raw upstream error when isRequestInvalidError reports it (no
//     further candidates are tried);
//   - a streamBootstrapError carrying the upstream headers when the last
//     candidate fails during bootstrap or closes without a payload;
//   - otherwise the last ExecuteStream error, or an "auth_not_found" error
//     when execModels is empty.
//
// An executor that returns a nil StreamResult together with a nil error is
// treated as an empty stream instead of causing a nil-pointer panic.
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string, execModels []string, pooled bool) (*cliproxyexecutor.StreamResult, error) {
	if executor == nil {
		return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
	}

	// recordFailure reports a failed attempt for resultModel, copying the HTTP
	// status and retry hint out of cause when it provides them.
	recordFailure := func(resultModel string, cause error) {
		rerr := &Error{Message: cause.Error()}
		if se, ok := errors.AsType[cliproxyexecutor.StatusError](cause); ok && se != nil {
			rerr.HTTPStatus = se.StatusCode()
		}
		result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr}
		result.RetryAfter = retryAfterFromError(cause)
		m.MarkResult(ctx, result)
	}

	var lastErr error
	for idx, execModel := range execModels {
		resultModel := executionResultModel(routeModel, execModel, pooled)
		hasMoreCandidates := idx < len(execModels)-1

		execReq := req
		execReq.Model = execModel
		streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
		if errStream != nil {
			// A cancelled request is not the credential's fault: do not record it.
			if errCtx := ctx.Err(); errCtx != nil {
				return nil, errCtx
			}
			recordFailure(resultModel, errStream)
			if isRequestInvalidError(errStream) {
				return nil, errStream
			}
			lastErr = errStream
			continue
		}

		// Tolerate a (nil, nil) return: with a nil channel the bootstrap below
		// reports a closed, empty stream.
		var (
			chunks  <-chan cliproxyexecutor.StreamChunk
			headers http.Header
		)
		if streamResult != nil {
			chunks = streamResult.Chunks
			headers = streamResult.Headers
		}

		buffered, closed, bootstrapErr := readStreamBootstrap(ctx, chunks)
		if bootstrapErr != nil {
			if errCtx := ctx.Err(); errCtx != nil {
				discardStreamChunks(chunks)
				return nil, errCtx
			}
			recordFailure(resultModel, bootstrapErr)
			// The stream is abandoned on every path below; keep its producer unblocked.
			discardStreamChunks(chunks)
			switch {
			case isRequestInvalidError(bootstrapErr):
				// The request itself is bad; another model will not help.
				return nil, bootstrapErr
			case hasMoreCandidates:
				lastErr = bootstrapErr
				continue
			default:
				return nil, newStreamBootstrapError(bootstrapErr, headers)
			}
		}

		if closed && len(buffered) == 0 {
			emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true}
			result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: emptyErr}
			m.MarkResult(ctx, result)
			if hasMoreCandidates {
				lastErr = emptyErr
				continue
			}
			return nil, newStreamBootstrapError(emptyErr, headers)
		}

		remaining := chunks
		if closed {
			// The upstream channel is exhausted; hand the wrapper an already
			// closed channel so it only replays the buffered chunks.
			closedCh := make(chan cliproxyexecutor.StreamChunk)
			close(closedCh)
			remaining = closedCh
		}
		return m.wrapStreamResult(ctx, auth.Clone(), provider, resultModel, headers, buffered, remaining), nil
	}

	if lastErr == nil {
		lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
	}
	return nil, lastErr
}
// rebuildAPIKeyModelAliasFromRuntimeConfig recomputes the API-key model alias
// table from the current runtime config snapshot. It takes the manager's
// write lock for the duration of the rebuild; a nil manager is ignored.
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
	if m == nil {
		return
	}
	snapshot, _ := m.runtimeConfig.Load().(*internalconfig.Config)
	if snapshot == nil {
		snapshot = &internalconfig.Config{}
	}

	m.mu.Lock()
	m.rebuildAPIKeyModelAliasLocked(snapshot)
	m.mu.Unlock()
}
// rebuildAPIKeyModelAliasLocked recomputes the alias table (auth ID ->
// lower-cased alias -> upstream model) for every API-key auth and publishes
// it atomically. Auths without any alias are omitted from the table.
//
// The caller must hold m.mu, because m.auths is iterated here. A nil cfg is
// treated as an empty config.
func (m *Manager) rebuildAPIKeyModelAliasLocked(cfg *internalconfig.Config) {
	if m == nil {
		return
	}
	if cfg == nil {
		cfg = &internalconfig.Config{}
	}

	table := make(apiKeyModelAliasTable)
	for _, auth := range m.auths {
		if auth == nil || strings.TrimSpace(auth.ID) == "" || !isAPIKeyAuth(auth) {
			continue
		}

		aliases := make(map[string]string)
		switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
		case "gemini":
			if entry := resolveGeminiAPIKeyConfig(cfg, auth); entry != nil {
				compileAPIKeyModelAliasForModels(aliases, entry.Models)
			}
		case "claude":
			if entry := resolveClaudeAPIKeyConfig(cfg, auth); entry != nil {
				compileAPIKeyModelAliasForModels(aliases, entry.Models)
			}
		case "codex":
			if entry := resolveCodexAPIKeyConfig(cfg, auth); entry != nil {
				compileAPIKeyModelAliasForModels(aliases, entry.Models)
			}
		case "vertex":
			if entry := resolveVertexAPIKeyConfig(cfg, auth); entry != nil {
				compileAPIKeyModelAliasForModels(aliases, entry.Models)
			}
		default:
			// OpenAI-compatible providers are selected through auth attributes;
			// reading a nil Attributes map simply yields empty strings.
			providerKey := strings.TrimSpace(auth.Attributes["provider_key"])
			compatName := strings.TrimSpace(auth.Attributes["compat_name"])
			isCompat := compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility")
			if isCompat {
				if entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider); entry != nil {
					compileAPIKeyModelAliasForModels(aliases, entry.Models)
				}
			}
		}

		if len(aliases) > 0 {
			table[auth.ID] = aliases
		}
	}
	m.apiKeyModelAlias.Store(table)
}
// compileAPIKeyModelAliasForModels adds the alias mappings declared by models
// to out. Keys are lower-cased model names without thinking suffix; values
// are the configured upstream names (suffix included).
//
// For each model with a non-blank alias and name, in config order:
//   - alias -> name is added, unless the alias key already exists, in which
//     case the whole entry is skipped (first alias wins);
//   - the upstream name itself is made resolvable, so looking up an
//     already-upstream model is a cheap no-op;
//   - when the upstream name carries a suffix, its base name is seeded too,
//     preserving the configured suffix.
//
// Existing keys are never overwritten. A nil out is ignored.
func compileAPIKeyModelAliasForModels[T interface {
	GetName() string
	GetAlias() string
}](out map[string]string, models []T) {
	if out == nil {
		return
	}
	for _, model := range models {
		alias := strings.TrimSpace(model.GetAlias())
		name := strings.TrimSpace(model.GetName())
		if alias == "" || name == "" {
			continue
		}

		aliasKey := strings.ToLower(thinking.ParseSuffix(alias).ModelName)
		if aliasKey == "" {
			aliasKey = strings.ToLower(alias)
		}
		if _, taken := out[aliasKey]; taken {
			continue
		}
		out[aliasKey] = name

		// seed maps key to this model's upstream name unless key is empty or
		// already claimed by an earlier entry.
		seed := func(key string) {
			if key == "" {
				return
			}
			if _, taken := out[key]; !taken {
				out[key] = name
			}
		}

		parsedName := thinking.ParseSuffix(name)
		nameKey := strings.ToLower(parsedName.ModelName)
		if nameKey == "" {
			nameKey = strings.ToLower(name)
		}
		seed(nameKey)

		if parsedName.HasSuffix {
			seed(strings.ToLower(strings.TrimSpace(parsedName.ModelName)))
		}
	}
}
// SetRetryConfig updates retry attempts, credential retry limit and cooldown wait interval.
//
// Negative values are treated as 0. retry and maxRetryCredentials are stored
// as 32-bit integers, so values above the int32 range are clamped to the
// maximum int32 instead of silently wrapping around (which could turn a huge
// setting into a negative or zero one). maxRetryInterval is stored in
// nanoseconds. A nil manager is ignored.
func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration, maxRetryCredentials int) {
	if m == nil {
		return
	}

	const maxInt32 = 1<<31 - 1
	// clampToInt32 maps v into [0, maxInt32] before the narrowing conversion.
	clampToInt32 := func(v int) int32 {
		switch {
		case v < 0:
			return 0
		case v > maxInt32:
			return maxInt32
		default:
			return int32(v)
		}
	}

	if maxRetryInterval < 0 {
		maxRetryInterval = 0
	}
	m.requestRetry.Store(clampToInt32(retry))
	m.maxRetryCredentials.Store(clampToInt32(maxRetryCredentials))
	m.maxRetryInterval.Store(maxRetryInterval.Nanoseconds())
}
// RegisterExecutor installs executor under its trimmed Identifier, replacing
// any executor previously registered for that key. A nil executor or one with
// a blank identifier is ignored.
//
// When a different executor is replaced and it implements
// ExecutionSessionCloser, it is asked (outside the lock) to release all of
// its execution sessions.
func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
	if executor == nil {
		return
	}
	key := strings.TrimSpace(executor.Identifier())
	if key == "" {
		return
	}

	m.mu.Lock()
	previous := m.executors[key]
	m.executors[key] = executor
	m.mu.Unlock()

	if previous == nil || previous == executor {
		return
	}
	closer, ok := previous.(ExecutionSessionCloser)
	if !ok || closer == nil {
		return
	}
	closer.CloseExecutionSession(CloseAllExecutionSessionsID)
}
// UnregisterExecutor removes the executor associated with the provider key.
//
// RegisterExecutor stores executors under the trimmed, case-preserved
// identifier, so both the lower-cased key and the key exactly as given are
// removed; otherwise an executor registered with a mixed-case identifier
// could never be unregistered. A blank provider is ignored.
func (m *Manager) UnregisterExecutor(provider string) {
	trimmed := strings.TrimSpace(provider)
	if trimmed == "" {
		return
	}
	m.mu.Lock()
	delete(m.executors, strings.ToLower(trimmed))
	delete(m.executors, trimmed)
	m.mu.Unlock()
}
// Register adds auth to the manager and returns a copy of it.
//
// An empty auth.ID is filled with a fresh UUID and the auth's index is
// ensured, both on the caller's value. The manager keeps its own clone,
// rebuilds the API-key alias table, upserts the clone into the scheduler,
// persists the auth (best effort) and fires Hook.OnAuthRegistered with
// another clone. A nil auth yields (nil, nil).
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
	if auth == nil {
		return nil, nil
	}
	if auth.ID == "" {
		auth.ID = uuid.NewString()
	}
	auth.EnsureIndex()

	stored := auth.Clone()
	m.mu.Lock()
	m.auths[auth.ID] = stored
	m.mu.Unlock()

	m.rebuildAPIKeyModelAliasFromRuntimeConfig()
	if scheduler := m.scheduler; scheduler != nil {
		scheduler.upsertAuth(stored)
	}

	// Persistence is best effort: a storage failure must not fail registration.
	_ = m.persist(ctx, auth)
	m.hook.OnAuthRegistered(ctx, auth.Clone())
	return auth.Clone(), nil
}
// Update stores auth as the new state for auth.ID and returns a copy of it.
//
// When an entry already exists, auth inherits its index (if auth has none
// assigned) and its model states (if auth carries none). The manager keeps
// its own clone, rebuilds the API-key alias table, upserts the clone into the
// scheduler, persists the auth (best effort) and fires Hook.OnAuthUpdated
// with another clone. A nil auth or an empty ID yields (nil, nil).
func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
	if auth == nil || auth.ID == "" {
		return nil, nil
	}

	m.mu.Lock()
	if prev := m.auths[auth.ID]; prev != nil {
		if auth.Index == "" && !auth.indexAssigned {
			auth.Index, auth.indexAssigned = prev.Index, prev.indexAssigned
		}
		if len(prev.ModelStates) > 0 && len(auth.ModelStates) == 0 {
			auth.ModelStates = prev.ModelStates
		}
	}
	auth.EnsureIndex()
	stored := auth.Clone()
	m.auths[auth.ID] = stored
	m.mu.Unlock()

	m.rebuildAPIKeyModelAliasFromRuntimeConfig()
	if scheduler := m.scheduler; scheduler != nil {
		scheduler.upsertAuth(stored)
	}

	// Persistence is best effort: a storage failure must not fail the update.
	_ = m.persist(ctx, auth)
	m.hook.OnAuthUpdated(ctx, auth.Clone())
	return auth.Clone(), nil
}
// Load replaces the in-memory auth set with the contents of the backing
// store, then rebuilds the API-key alias table and the scheduler.
//
// Without a store it is a no-op. A listing error is returned and leaves the
// current state untouched. Entries that are nil or have an empty ID are
// skipped. The manager's write lock is held while the store is listed.
func (m *Manager) Load(ctx context.Context) error {
	m.mu.Lock()
	if m.store == nil {
		m.mu.Unlock()
		return nil
	}
	items, err := m.store.List(ctx)
	if err != nil {
		m.mu.Unlock()
		return err
	}

	loaded := make(map[string]*Auth, len(items))
	for _, item := range items {
		if item == nil || item.ID == "" {
			continue
		}
		item.EnsureIndex()
		loaded[item.ID] = item.Clone()
	}
	m.auths = loaded

	cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
	if cfg == nil {
		cfg = &internalconfig.Config{}
	}
	m.rebuildAPIKeyModelAliasLocked(cfg)
	m.mu.Unlock()

	// Must run after the lock is released.
	m.syncScheduler()
	return nil
}
// Execute runs a non-streaming request against one of the given providers.
// Several providers may serve the same model; each round tries credentials
// across them via executeMixedOnce. After a failed round
// shouldRetryAfterError decides whether to wait for a cooldown and try again.
//
// It returns the first successful response, the error of the last round, or
// the context error if the wait between rounds is interrupted. An empty
// provider list (after normalization) yields a "provider_not_found" error.
func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	targets := m.normalizeProviders(providers)
	if len(targets) == 0 {
		return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	_, credentialLimit, maxWait := m.retrySettings()

	var lastErr error
	attempt := 0
	for {
		resp, err := m.executeMixedOnce(ctx, targets, req, opts, credentialLimit)
		if err == nil {
			return resp, nil
		}
		lastErr = err

		wait, retry := m.shouldRetryAfterError(err, attempt, targets, req.Model, maxWait)
		if !retry {
			break
		}
		if errWait := waitForCooldown(ctx, wait); errWait != nil {
			return cliproxyexecutor.Response{}, errWait
		}
		attempt++
	}

	if lastErr == nil {
		return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	return cliproxyexecutor.Response{}, lastErr
}
// ExecuteCount runs a token-count request against one of the given providers.
// It mirrors Execute but delegates each round to executeCountMixedOnce.
// After a failed round shouldRetryAfterError decides whether to wait for a
// cooldown and try again.
//
// It returns the first successful response, the error of the last round, or
// the context error if the wait between rounds is interrupted. An empty
// provider list (after normalization) yields a "provider_not_found" error.
func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	targets := m.normalizeProviders(providers)
	if len(targets) == 0 {
		return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	_, credentialLimit, maxWait := m.retrySettings()

	var lastErr error
	attempt := 0
	for {
		resp, err := m.executeCountMixedOnce(ctx, targets, req, opts, credentialLimit)
		if err == nil {
			return resp, nil
		}
		lastErr = err

		wait, retry := m.shouldRetryAfterError(err, attempt, targets, req.Model, maxWait)
		if !retry {
			break
		}
		if errWait := waitForCooldown(ctx, wait); errWait != nil {
			return cliproxyexecutor.Response{}, errWait
		}
		attempt++
	}

	if lastErr == nil {
		return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	return cliproxyexecutor.Response{}, lastErr
}
// ExecuteStream runs a streaming request against one of the given providers.
// Several providers may serve the same model; each round tries credentials
// across them via executeStreamMixedOnce. After a failed round
// shouldRetryAfterError decides whether to wait for a cooldown and try again.
//
// It returns the first stream that could be established, the error of the
// last round, or the context error if the wait between rounds is
// interrupted. An empty provider list (after normalization) yields a
// "provider_not_found" error.
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	targets := m.normalizeProviders(providers)
	if len(targets) == 0 {
		return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	_, credentialLimit, maxWait := m.retrySettings()

	var lastErr error
	attempt := 0
	for {
		stream, err := m.executeStreamMixedOnce(ctx, targets, req, opts, credentialLimit)
		if err == nil {
			return stream, nil
		}
		lastErr = err

		wait, retry := m.shouldRetryAfterError(err, attempt, targets, req.Model, maxWait)
		if !retry {
			break
		}
		if errWait := waitForCooldown(ctx, wait); errWait != nil {
			return nil, errWait
		}
		attempt++
	}

	if lastErr == nil {
		return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	return nil, lastErr
}
// executeMixedOnce performs one round of a non-streaming execution: it keeps
// picking credentials across the given providers until one succeeds, the
// picker runs out of candidates, or maxRetryCredentials credentials have been
// attempted (0 or less means no limit).
//
// For each picked auth every prepared upstream model is tried in order and
// each attempt is recorded with MarkResult. On success the auth affinity is
// persisted and the response returned. An error classified by
// isRequestInvalidError is returned immediately; a cancelled execution context
// returns its error without recording a result. When the round ends without
// success the last upstream error is returned, or the picker's error if no
// attempt produced one.
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
// tried holds every auth ID handed out by the picker in this round so it is
// not offered again; attempted holds only those that had at least one model
// to execute and is what counts against maxRetryCredentials.
tried := make(map[string]struct{})
attempted := make(map[string]struct{})
var lastErr error
for {
if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
}
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
// Prefer the real upstream failure over the picker's "nothing left" error.
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
// The round tripper is exposed under both the typed key and a plain string key.
// NOTE(review): the string key is presumably kept for consumers that cannot
// import roundTripperContextKey — confirm before removing.
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
models, pooled := m.preparedExecutionModels(auth, routeModel)
if len(models) == 0 {
// Every candidate model is currently blocked for this auth; pick another one.
continue
}
attempted[auth.ID] = struct{}{}
var authErr error
for _, upstreamModel := range models {
resultModel := executionResultModel(routeModel, upstreamModel, pooled)
// req is copied by value so the caller's request keeps its routed model name.
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: errExec == nil}
if errExec != nil {
// A cancelled request is returned as-is and not recorded against the auth.
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
// An invalid request will fail on any credential; stop right away.
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
}
authErr = errExec
continue
}
m.MarkResult(execCtx, result)
m.persistAuthAffinity(entry, opts, auth.ID, provider, req.Model)
return resp, nil
}
if authErr != nil {
// Invalid-request errors already returned inside the loop above, so this
// check is a defensive repeat.
if isRequestInvalidError(authErr) {
return cliproxyexecutor.Response{}, authErr
}
lastErr = authErr
continue
}
}
}
// executeCountMixedOnce makes one pass over the auths that pickNextMixed can
// supply for providers and calls CountTokens on each until one call succeeds.
//
// The pass ends when a call succeeds, when the execution context reports an
// error, when an error is classified by isRequestInvalidError, when
// pickNextMixed returns an error, or when maxRetryCredentials distinct auths
// have been attempted (a value <= 0 disables that cap). When the pass ends
// without success the most recent execution error is returned if there is one;
// otherwise the pick error or an "auth_not_found" error is returned.
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
// tried records every auth ID returned by pickNextMixed in this pass and is
// handed back to it on each pick. attempted records only auths that produced
// at least one upstream model; it is what the credential cap counts.
tried := make(map[string]struct{})
attempted := make(map[string]struct{})
var lastErr error
for {
if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
}
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
// Prefer the last upstream failure over the pick error.
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
// The round tripper is stored under both the typed key and a plain string key.
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
models, pooled := m.preparedExecutionModels(auth, routeModel)
if len(models) == 0 {
// Nothing to execute for this auth; it stays in tried but does not count
// against the credential cap.
continue
}
attempted[auth.ID] = struct{}{}
var authErr error
for _, upstreamModel := range models {
resultModel := executionResultModel(routeModel, upstreamModel, pooled)
execReq := req
execReq.Model = upstreamModel
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: errExec == nil}
if errExec != nil {
// A context error ends the pass immediately without recording a result.
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
// Request-invalid errors are returned as-is; no other model or auth is tried.
if isRequestInvalidError(errExec) {
return cliproxyexecutor.Response{}, errExec
}
authErr = errExec
continue
}
m.MarkResult(execCtx, result)
m.persistAuthAffinity(entry, opts, auth.ID, provider, req.Model)
return resp, nil
}
if authErr != nil {
if isRequestInvalidError(authErr) {
return cliproxyexecutor.Response{}, authErr
}
lastErr = authErr
continue
}
}
}
// executeStreamMixedOnce makes one pass over the auths that pickNextMixed can
// supply for providers and starts a stream through executeStreamWithModelPool
// on each until one start succeeds.
//
// The pass ends on success, on a context error, on an error classified by
// isRequestInvalidError, when pickNextMixed returns an error, or when
// maxRetryCredentials distinct auths have been attempted (a value <= 0
// disables that cap). If the pass ends with a recorded error that is a
// *streamBootstrapError, the function returns the result of streamErrorResult
// built from that error's headers and cause, together with a nil error,
// instead of returning the error itself.
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (*cliproxyexecutor.StreamResult, error) {
if len(providers) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
// tried records every auth ID returned by pickNextMixed in this pass and is
// handed back to it on each pick. attempted records only auths that produced
// at least one upstream model; it is what the credential cap counts.
tried := make(map[string]struct{})
attempted := make(map[string]struct{})
var lastErr error
for {
if maxRetryCredentials > 0 && len(attempted) >= maxRetryCredentials {
if lastErr != nil {
var bootstrapErr *streamBootstrapError
if errors.As(lastErr, &bootstrapErr) && bootstrapErr != nil {
return streamErrorResult(bootstrapErr.Headers(), bootstrapErr.cause), nil
}
return nil, lastErr
}
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
// Prefer the last upstream failure over the pick error.
if lastErr != nil {
var bootstrapErr *streamBootstrapError
if errors.As(lastErr, &bootstrapErr) && bootstrapErr != nil {
return streamErrorResult(bootstrapErr.Headers(), bootstrapErr.cause), nil
}
return nil, lastErr
}
return nil, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
// The round tripper is stored under both the typed key and a plain string key.
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
models, pooled := m.preparedExecutionModels(auth, routeModel)
if len(models) == 0 {
// Nothing to execute for this auth; it stays in tried but does not count
// against the credential cap.
continue
}
attempted[auth.ID] = struct{}{}
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled)
if errStream != nil {
// A context error ends the pass immediately.
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
}
// Request-invalid errors are returned as-is; no other auth is tried.
if isRequestInvalidError(errStream) {
return nil, errStream
}
lastErr = errStream
continue
}
m.persistAuthAffinity(entry, opts, auth.ID, provider, req.Model)
return streamResult, nil
}
}
// ensureRequestedModelMetadata returns opts with the requested model recorded
// under cliproxyexecutor.RequestedModelMetadataKey. opts is returned unchanged
// when requestedModel is blank or the metadata already carries a non-blank
// value for that key. Otherwise the metadata map is copied before the key is
// added, so the map the caller passed in is not modified.
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
	name := strings.TrimSpace(requestedModel)
	if name == "" || hasRequestedModelMetadata(opts.Metadata) {
		return opts
	}
	cloned := make(map[string]any, len(opts.Metadata)+1)
	for key, value := range opts.Metadata {
		cloned[key] = value
	}
	cloned[cliproxyexecutor.RequestedModelMetadataKey] = name
	opts.Metadata = cloned
	return opts
}
// hasRequestedModelMetadata reports whether meta holds a non-blank string or
// []byte under cliproxyexecutor.RequestedModelMetadataKey.
func hasRequestedModelMetadata(meta map[string]any) bool {
	// A nil/empty map or a missing key yields a nil interface value, which
	// matches neither case below.
	switch value := meta[cliproxyexecutor.RequestedModelMetadataKey].(type) {
	case string:
		return strings.TrimSpace(value) != ""
	case []byte:
		return strings.TrimSpace(string(value)) != ""
	}
	return false
}
// pinnedAuthIDFromMetadata returns the whitespace-trimmed value stored under
// cliproxyexecutor.PinnedAuthMetadataKey when it is a string or []byte, and ""
// in every other case.
func pinnedAuthIDFromMetadata(meta map[string]any) string {
	// A nil/empty map or a missing key yields a nil interface value, which
	// matches neither case below.
	switch pinned := meta[cliproxyexecutor.PinnedAuthMetadataKey].(type) {
	case string:
		return strings.TrimSpace(pinned)
	case []byte:
		return strings.TrimSpace(string(pinned))
	}
	return ""
}
// publishSelectedAuthMetadata writes the trimmed authID into meta under
// cliproxyexecutor.SelectedAuthMetadataKey and, when meta holds a non-nil
// func(string) under cliproxyexecutor.SelectedAuthCallbackMetadataKey, calls
// it with that ID. It does nothing when meta is empty or authID is blank.
func publishSelectedAuthMetadata(meta map[string]any, authID string) {
	selected := strings.TrimSpace(authID)
	if selected == "" || len(meta) == 0 {
		return
	}
	meta[cliproxyexecutor.SelectedAuthMetadataKey] = selected
	notify, isFunc := meta[cliproxyexecutor.SelectedAuthCallbackMetadataKey].(func(string))
	if !isFunc || notify == nil {
		return
	}
	notify(selected)
}
// rewriteModelForAuth strips a leading "<prefix>/" from model, where prefix is
// the trimmed auth.Prefix. The model is returned unchanged when auth is nil,
// model is empty, the prefix is blank, or model does not start with it.
func rewriteModelForAuth(model string, auth *Auth) string {
	if model == "" || auth == nil {
		return model
	}
	prefix := strings.TrimSpace(auth.Prefix)
	if prefix == "" {
		return model
	}
	if lead := prefix + "/"; strings.HasPrefix(model, lead) {
		return model[len(lead):]
	}
	return model
}
// applyAPIKeyModelAlias maps requestedModel to the upstream model configured
// for an API-key auth. It returns requestedModel untouched when m or auth is
// nil or the auth's account kind is not "api_key"; otherwise it returns the
// resolved upstream model, or the trimmed requested model when nothing
// resolves.
func (m *Manager) applyAPIKeyModelAlias(auth *Auth, requestedModel string) string {
	if m == nil || auth == nil {
		return requestedModel
	}
	if kind, _ := auth.AccountInfo(); !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
		return requestedModel
	}
	model := strings.TrimSpace(requestedModel)
	if model == "" {
		return model
	}

	// Fast path: the per-auth mapping table, keyed by auth.ID.
	if mapped := m.lookupAPIKeyUpstreamModel(auth.ID, model); mapped != "" {
		return mapped
	}

	// Slow path: look the credential up in the runtime config and resolve the
	// alias there. This is the safety net for stale mappings or a missing ID.
	cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
	if cfg == nil {
		cfg = &internalconfig.Config{}
	}

	var resolve func(*internalconfig.Config, *Auth, string) string
	switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
	case "gemini":
		resolve = resolveUpstreamModelForGeminiAPIKey
	case "claude":
		resolve = resolveUpstreamModelForClaudeAPIKey
	case "codex":
		resolve = resolveUpstreamModelForCodexAPIKey
	case "vertex":
		resolve = resolveUpstreamModelForVertexAPIKey
	default:
		resolve = resolveUpstreamModelForOpenAICompatAPIKey
	}
	if upstream := resolve(cfg, auth, model); upstream != "" {
		return upstream
	}
	return model
}
// APIKeyConfigEntry is a generic interface for API key configurations.
// It is the type constraint of resolveAPIKeyConfig, which compares the two
// accessor values against an auth's "api_key" and "base_url" attributes.
type APIKeyConfigEntry interface {
// GetAPIKey returns the API key configured for the entry.
GetAPIKey() string
// GetBaseURL returns the base URL configured for the entry.
GetBaseURL() string
}
// resolveAPIKeyConfig returns a pointer to the element of entries that matches
// the "api_key" and "base_url" attributes of auth, or nil when auth is nil,
// entries is empty, or nothing matches. All comparisons are case-insensitive
// and made on whitespace-trimmed values. The returned pointer aliases the
// entries slice.
func resolveAPIKeyConfig[T APIKeyConfigEntry](entries []T, auth *Auth) *T {
if auth == nil || len(entries) == 0 {
return nil
}
attrKey, attrBase := "", ""
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range entries {
entry := &entries[i]
cfgKey := strings.TrimSpace((*entry).GetAPIKey())
cfgBase := strings.TrimSpace((*entry).GetBaseURL())
// Both attributes present: the entry must match on key and base URL.
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
// Only the key is present (attrBase is empty here): accept an entry with
// the same key whose base URL is empty or equal to attrBase.
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
// Only the base URL is present: match on base URL alone.
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
// Fallback: the first entry with the same key, whatever its base URL.
if attrKey != "" {
for i := range entries {
entry := &entries[i]
if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) {
return entry
}
}
}
return nil
}
// resolveGeminiAPIKeyConfig returns the cfg.GeminiKey entry that
// resolveAPIKeyConfig selects for auth, or nil when cfg is nil.
func resolveGeminiAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.GeminiKey {
	if cfg != nil {
		return resolveAPIKeyConfig(cfg.GeminiKey, auth)
	}
	return nil
}
// resolveClaudeAPIKeyConfig returns the cfg.ClaudeKey entry that
// resolveAPIKeyConfig selects for auth, or nil when cfg is nil.
func resolveClaudeAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.ClaudeKey {
	if cfg != nil {
		return resolveAPIKeyConfig(cfg.ClaudeKey, auth)
	}
	return nil
}
// resolveCodexAPIKeyConfig returns the cfg.CodexKey entry that
// resolveAPIKeyConfig selects for auth, or nil when cfg is nil.
func resolveCodexAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.CodexKey {
	if cfg != nil {
		return resolveAPIKeyConfig(cfg.CodexKey, auth)
	}
	return nil
}
// resolveVertexAPIKeyConfig returns the cfg.VertexCompatAPIKey entry that
// resolveAPIKeyConfig selects for auth, or nil when cfg is nil.
func resolveVertexAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.VertexCompatKey {
	if cfg != nil {
		return resolveAPIKeyConfig(cfg.VertexCompatAPIKey, auth)
	}
	return nil
}
// resolveUpstreamModelForGeminiAPIKey resolves requestedModel against the
// Models list of the Gemini key entry matching auth. It returns "" when no
// entry matches.
func resolveUpstreamModelForGeminiAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
	if keyCfg := resolveGeminiAPIKeyConfig(cfg, auth); keyCfg != nil {
		return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(keyCfg.Models))
	}
	return ""
}
// resolveUpstreamModelForClaudeAPIKey resolves requestedModel against the
// Models list of the Claude key entry matching auth. It returns "" when no
// entry matches.
func resolveUpstreamModelForClaudeAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
	if keyCfg := resolveClaudeAPIKeyConfig(cfg, auth); keyCfg != nil {
		return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(keyCfg.Models))
	}
	return ""
}
// resolveUpstreamModelForCodexAPIKey resolves requestedModel against the
// Models list of the Codex key entry matching auth. It returns "" when no
// entry matches.
func resolveUpstreamModelForCodexAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
	if keyCfg := resolveCodexAPIKeyConfig(cfg, auth); keyCfg != nil {
		return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(keyCfg.Models))
	}
	return ""
}
// resolveUpstreamModelForVertexAPIKey resolves requestedModel against the
// Models list of the Vertex-compatible key entry matching auth. It returns ""
// when no entry matches.
func resolveUpstreamModelForVertexAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
	if keyCfg := resolveVertexAPIKeyConfig(cfg, auth); keyCfg != nil {
		return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(keyCfg.Models))
	}
	return ""
}
// resolveUpstreamModelForOpenAICompatAPIKey resolves requestedModel against
// the Models list of the OpenAI-compatibility config entry associated with
// auth.
//
// The entry is looked up through resolveOpenAICompatConfig using the auth's
// "compat_name" and "provider_key" attributes and its Provider. It returns ""
// when auth is nil, when the auth has no compat name and its provider is not
// "openai-compatibility", or when no config entry matches.
func resolveUpstreamModelForOpenAICompatAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
	// Guard up front: the provider checks below dereference auth.
	if auth == nil {
		return ""
	}
	providerKey := ""
	compatName := ""
	if len(auth.Attributes) > 0 {
		providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
		compatName = strings.TrimSpace(auth.Attributes["compat_name"])
	}
	if compatName == "" && !strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
		return ""
	}
	entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
	if entry == nil {
		return ""
	}
	return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
}
// apiKeyModelAliasTable is a two-level string map of model aliases.
// NOTE(review): presumably auth ID -> requested model -> upstream model, as
// suggested by the lookupAPIKeyUpstreamModel(auth.ID, requestedModel) call in
// applyAPIKeyModelAlias; confirm against the code that builds the table.
type apiKeyModelAliasTable map[string]map[string]string
// resolveOpenAICompatConfig returns the first cfg.OpenAICompatibility entry
// whose Name equals, ignoring case, any of the trimmed non-empty values of
// compatName, providerKey or authProvider. Entries are scanned in config
// order. It returns nil when cfg is nil or nothing matches.
func resolveOpenAICompatConfig(cfg *internalconfig.Config, providerKey, compatName, authProvider string) *internalconfig.OpenAICompatibility {
	if cfg == nil {
		return nil
	}
	names := make([]string, 0, 3)
	for _, raw := range [...]string{compatName, providerKey, authProvider} {
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			names = append(names, trimmed)
		}
	}
	for idx := range cfg.OpenAICompatibility {
		for _, name := range names {
			if strings.EqualFold(name, cfg.OpenAICompatibility[idx].Name) {
				return &cfg.OpenAICompatibility[idx]
			}
		}
	}
	return nil
}
// asModelAliasEntries converts a slice of values exposing GetName and GetAlias
// into a []modelAliasEntry of the same length and order. It returns nil for an
// empty input.
func asModelAliasEntries[T interface {
	GetName() string
	GetAlias() string
}](models []T) []modelAliasEntry {
	if len(models) == 0 {
		return nil
	}
	entries := make([]modelAliasEntry, len(models))
	for idx := range models {
		entries[idx] = models[idx]
	}
	return entries
}
// normalizeProviders lower-cases and trims each provider name, drops blanks
// and duplicates, and keeps the order of first occurrence. It returns nil for
// an empty input.
func (m *Manager) normalizeProviders(providers []string) []string {
	if len(providers) == 0 {
		return nil
	}
	unique := make([]string, 0, len(providers))
	known := make(map[string]struct{}, len(providers))
	for _, raw := range providers {
		name := strings.TrimSpace(strings.ToLower(raw))
		if name == "" {
			continue
		}
		if _, dup := known[name]; !dup {
			known[name] = struct{}{}
			unique = append(unique, name)
		}
	}
	return unique
}
// retrySettings returns the current values of requestRetry,
// maxRetryCredentials and maxRetryInterval, in that order. All three are zero
// when m is nil.
func (m *Manager) retrySettings() (int, int, time.Duration) {
	if m == nil {
		return 0, 0, 0
	}
	retries := int(m.requestRetry.Load())
	credentials := int(m.maxRetryCredentials.Load())
	interval := time.Duration(m.maxRetryInterval.Load())
	return retries, credentials, interval
}
// closestCooldownWait returns the shortest remaining block duration for model
// among the auths that belong to providers, and whether such an auth exists.
//
// An auth is skipped when attempt has reached its retry budget (the per-auth
// override if present, otherwise the manager's requestRetry, with negative
// values treated as 0), when isAuthBlockedForModel reports it is not blocked,
// when the block has no end time, or when the block reason is
// blockReasonDisabled. The auth map is read under m.mu's read lock.
func (m *Manager) closestCooldownWait(providers []string, model string, attempt int) (time.Duration, bool) {
if m == nil || len(providers) == 0 {
return 0, false
}
now := time.Now()
defaultRetry := int(m.requestRetry.Load())
if defaultRetry < 0 {
defaultRetry = 0
}
// Set of lower-cased, trimmed provider names for matching below.
providerSet := make(map[string]struct{}, len(providers))
for i := range providers {
key := strings.TrimSpace(strings.ToLower(providers[i]))
if key == "" {
continue
}
providerSet[key] = struct{}{}
}
m.mu.RLock()
defer m.mu.RUnlock()
var (
found bool
minWait time.Duration
)
for _, auth := range m.auths {
if auth == nil {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
if _, ok := providerSet[providerKey]; !ok {
continue
}
effectiveRetry := defaultRetry
if override, ok := auth.RequestRetryOverride(); ok {
effectiveRetry = override
}
if effectiveRetry < 0 {
effectiveRetry = 0
}
// This auth's retry budget is already used up at this attempt number.
if attempt >= effectiveRetry {
continue
}
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
if !blocked || next.IsZero() || reason == blockReasonDisabled {
continue
}
wait := next.Sub(now)
if wait < 0 {
continue
}
if !found || wait < minWait {
minWait = wait
found = true
}
}
return minWait, found
}
// shouldRetryAfterError decides whether the caller should wait and retry after
// err. It returns (wait, true) only when err is non-nil, maxWait is positive,
// the error's status code is not 200, the error is not a request-invalid
// error, and closestCooldownWait finds a wait no longer than maxWait. In every
// other case it returns (0, false).
func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
	switch {
	case err == nil, maxWait <= 0:
		return 0, false
	case statusCodeFromError(err) == http.StatusOK:
		return 0, false
	case isRequestInvalidError(err):
		return 0, false
	}
	wait, found := m.closestCooldownWait(providers, model, attempt)
	if found && wait <= maxWait {
		return wait, true
	}
	return 0, false
}
// waitForCooldown blocks for the given duration or until ctx is done,
// whichever comes first. It returns ctx.Err() when the context ends the wait
// and nil otherwise. A non-positive wait returns nil immediately without
// consulting ctx.
func waitForCooldown(ctx context.Context, wait time.Duration) error {
	if wait > 0 {
		timer := time.NewTimer(wait)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}
// MarkResult records an execution result and notifies hooks.
//
// Results with an empty AuthID are ignored. While holding m.mu it updates the
// stored auth identified by result.AuthID (per-model state when result.Model
// is set, auth-level state otherwise), persists the auth and takes a clone.
// After releasing the lock it pushes the clone to the scheduler if one is
// configured, updates quota and suspension flags in the global registry, and
// finally calls m.hook.OnResult. The hook is called even when the auth ID is
// not known to the manager.
func (m *Manager) MarkResult(ctx context.Context, result Result) {
if result.AuthID == "" {
return
}
// Registry side effects are decided under the lock and applied after it is released.
shouldResumeModel := false
shouldSuspendModel := false
suspendReason := ""
clearModelQuota := false
setModelQuota := false
var authSnapshot *Auth
m.mu.Lock()
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
now := time.Now()
if result.Success {
if result.Model != "" {
// Success for a specific model: reset that model and re-aggregate.
state := ensureModelState(auth, result.Model)
resetModelState(state, now)
updateAggregatedAvailability(auth, now)
// Clear the auth-level error only when no model still reports one.
if !hasModelError(auth, now) {
auth.LastError = nil
auth.StatusMessage = ""
auth.Status = StatusActive
}
auth.UpdatedAt = now
shouldResumeModel = true
clearModelQuota = true
} else {
clearAuthStateOnSuccess(auth, now)
}
} else {
if result.Model != "" {
// Failure for a specific model: mark it unavailable and pick a cooldown.
state := ensureModelState(auth, result.Model)
state.Unavailable = true
state.Status = StatusError
state.UpdatedAt = now
if result.Error != nil {
state.LastError = cloneError(result.Error)
state.StatusMessage = result.Error.Message
auth.LastError = cloneError(result.Error)
auth.StatusMessage = result.Error.Message
}
statusCode := statusCodeFromResult(result.Error)
if isModelSupportResultError(result.Error) {
// Model-support errors get a 12 hour cooldown.
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "model_not_supported"
shouldSuspendModel = true
} else {
switch statusCode {
case 401:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "unauthorized"
shouldSuspendModel = true
case 402, 403:
next := now.Add(30 * time.Minute)
state.NextRetryAfter = next
suspendReason = "payment_required"
shouldSuspendModel = true
case 404:
next := now.Add(12 * time.Hour)
state.NextRetryAfter = next
suspendReason = "not_found"
shouldSuspendModel = true
case 429:
// Use the upstream retry-after when given; otherwise take the next
// backoff step. The backoff level advances only in the latter case.
var next time.Time
backoffLevel := state.Quota.BackoffLevel
if result.RetryAfter != nil {
next = now.Add(*result.RetryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
}
backoffLevel = nextLevel
}
state.NextRetryAfter = next
state.Quota = QuotaState{
Exceeded: true,
Reason: "quota",
NextRecoverAt: next,
BackoffLevel: backoffLevel,
}
suspendReason = "quota"
shouldSuspendModel = true
setModelQuota = true
case 408, 500, 502, 503, 504:
// One minute cooldown unless cooldowns are disabled for this auth.
// The model is not suspended in the registry for these statuses.
if quotaCooldownDisabledForAuth(auth) {
state.NextRetryAfter = time.Time{}
} else {
next := now.Add(1 * time.Minute)
state.NextRetryAfter = next
}
default:
state.NextRetryAfter = time.Time{}
}
}
auth.Status = StatusError
auth.UpdatedAt = now
updateAggregatedAvailability(auth, now)
} else {
applyAuthFailureState(auth, result.Error, result.RetryAfter, now)
}
}
// The error returned by persist is deliberately discarded here.
_ = m.persist(ctx, auth)
authSnapshot = auth.Clone()
}
m.mu.Unlock()
if m.scheduler != nil && authSnapshot != nil {
m.scheduler.upsertAuth(authSnapshot)
}
if clearModelQuota && result.Model != "" {
registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
}
if setModelQuota && result.Model != "" {
registry.GetGlobalRegistry().SetModelQuotaExceeded(result.AuthID, result.Model)
}
if shouldResumeModel {
registry.GetGlobalRegistry().ResumeClientModel(result.AuthID, result.Model)
} else if shouldSuspendModel {
registry.GetGlobalRegistry().SuspendClientModel(result.AuthID, result.Model, suspendReason)
}
m.hook.OnResult(ctx, result)
}
// ensureModelState returns the ModelState stored for model on auth, creating
// the map and an active state when either is missing or the stored state is
// nil. It returns nil when auth is nil or model is empty.
func ensureModelState(auth *Auth, model string) *ModelState {
	if model == "" || auth == nil {
		return nil
	}
	if auth.ModelStates == nil {
		auth.ModelStates = make(map[string]*ModelState)
	}
	current := auth.ModelStates[model]
	if current == nil {
		current = &ModelState{Status: StatusActive}
		auth.ModelStates[model] = current
	}
	return current
}
// resetModelState returns state to a clean active condition: error, status
// message, retry time and quota are cleared and UpdatedAt is set to now. A nil
// state is ignored.
func resetModelState(state *ModelState, now time.Time) {
	if state != nil {
		state.Status = StatusActive
		state.Unavailable = false
		state.LastError = nil
		state.StatusMessage = ""
		state.Quota = QuotaState{}
		state.NextRetryAfter = time.Time{}
		state.UpdatedAt = now
	}
}
// updateAggregatedAvailability recomputes the auth-level Unavailable,
// NextRetryAfter and Quota fields from the per-model states. It does nothing
// when auth is nil or has no model states.
//
// The auth is marked unavailable only when every non-nil model state is
// unavailable; NextRetryAfter then becomes the earliest pending model retry
// time. As a side effect, a model state whose retry time has already passed
// has its Unavailable flag and NextRetryAfter cleared.
func updateAggregatedAvailability(auth *Auth, now time.Time) {
if auth == nil || len(auth.ModelStates) == 0 {
return
}
allUnavailable := true
earliestRetry := time.Time{}
quotaExceeded := false
quotaRecover := time.Time{}
maxBackoffLevel := 0
for _, state := range auth.ModelStates {
if state == nil {
continue
}
stateUnavailable := false
if state.Status == StatusDisabled {
stateUnavailable = true
} else if state.Unavailable {
if state.NextRetryAfter.IsZero() {
// Flagged unavailable but with no retry time: counted as available.
stateUnavailable = false
} else if state.NextRetryAfter.After(now) {
stateUnavailable = true
if earliestRetry.IsZero() || state.NextRetryAfter.Before(earliestRetry) {
earliestRetry = state.NextRetryAfter
}
} else {
// The cooldown has elapsed; clear the flag on the model state itself.
state.Unavailable = false
state.NextRetryAfter = time.Time{}
}
}
if !stateUnavailable {
allUnavailable = false
}
if state.Quota.Exceeded {
quotaExceeded = true
// Keep the earliest non-zero recovery time seen and the highest backoff level.
if quotaRecover.IsZero() || (!state.Quota.NextRecoverAt.IsZero() && state.Quota.NextRecoverAt.Before(quotaRecover)) {
quotaRecover = state.Quota.NextRecoverAt
}
if state.Quota.BackoffLevel > maxBackoffLevel {
maxBackoffLevel = state.Quota.BackoffLevel
}
}
}
auth.Unavailable = allUnavailable
if allUnavailable {
auth.NextRetryAfter = earliestRetry
} else {
auth.NextRetryAfter = time.Time{}
}
if quotaExceeded {
auth.Quota.Exceeded = true
auth.Quota.Reason = "quota"
auth.Quota.NextRecoverAt = quotaRecover
auth.Quota.BackoffLevel = maxBackoffLevel
} else {
auth.Quota.Exceeded = false
auth.Quota.Reason = ""
auth.Quota.NextRecoverAt = time.Time{}
auth.Quota.BackoffLevel = 0
}
}
// hasModelError reports whether any model state on auth still carries an
// error: either a recorded LastError, or an error status that is flagged
// unavailable with a retry time that is unset or still in the future.
func hasModelError(auth *Auth, now time.Time) bool {
	if auth == nil {
		return false
	}
	for _, ms := range auth.ModelStates {
		switch {
		case ms == nil:
			continue
		case ms.LastError != nil:
			return true
		case ms.Status != StatusError || !ms.Unavailable:
			continue
		case ms.NextRetryAfter.IsZero() || ms.NextRetryAfter.After(now):
			return true
		}
	}
	return false
}
// clearAuthStateOnSuccess resets the auth-level failure bookkeeping after a
// successful call: the auth becomes active and available, its error, status
// message, retry time and quota fields are cleared, and UpdatedAt is set to
// now. A nil auth is ignored.
func clearAuthStateOnSuccess(auth *Auth, now time.Time) {
	if auth != nil {
		auth.Status = StatusActive
		auth.StatusMessage = ""
		auth.LastError = nil
		auth.Unavailable = false
		auth.NextRetryAfter = time.Time{}

		auth.Quota.Exceeded = false
		auth.Quota.Reason = ""
		auth.Quota.BackoffLevel = 0
		auth.Quota.NextRecoverAt = time.Time{}

		auth.UpdatedAt = now
	}
}
// cloneError returns a new Error carrying the Code, Message, Retryable and
// HTTPStatus of err, or nil when err is nil.
func cloneError(err *Error) *Error {
	if err == nil {
		return nil
	}
	dup := new(Error)
	dup.Code = err.Code
	dup.Message = err.Message
	dup.Retryable = err.Retryable
	dup.HTTPStatus = err.HTTPStatus
	return dup
}
// statusCodeFromError returns the StatusCode() of the first error in err's
// chain that exposes that method, or 0 when err is nil or no such error
// exists.
func statusCodeFromError(err error) int {
	if err == nil {
		return 0
	}
	var coder interface{ StatusCode() int }
	if !errors.As(err, &coder) || coder == nil {
		return 0
	}
	return coder.StatusCode()
}
// retryAfterFromError returns a copy of the retry-after hint carried by err,
// or nil when err is nil, no error in its chain exposes RetryAfter(), or the
// hint itself is nil.
//
// The chain is searched with errors.As, matching statusCodeFromError, so a
// hint on a wrapped error is still found. The returned pointer refers to a
// fresh value, so callers cannot mutate the duration held by the error.
func retryAfterFromError(err error) *time.Duration {
	if err == nil {
		return nil
	}
	var provider interface{ RetryAfter() *time.Duration }
	if !errors.As(err, &provider) || provider == nil {
		return nil
	}
	hint := provider.RetryAfter()
	if hint == nil {
		return nil
	}
	value := *hint
	return &value
}
// statusCodeFromResult returns err.StatusCode(), or 0 when err is nil.
func statusCodeFromResult(err *Error) int {
	if err != nil {
		return err.StatusCode()
	}
	return 0
}
// isModelSupportErrorMessage reports whether message, compared
// case-insensitively, contains one of the known phrases indicating that the
// requested model is unsupported or unavailable.
func isModelSupportErrorMessage(message string) bool {
	normalized := strings.ToLower(strings.TrimSpace(message))
	if normalized == "" {
		return false
	}
	for _, marker := range []string{
		"model_not_supported",
		"requested model is not supported",
		"requested model is unsupported",
		"requested model is unavailable",
		"model is not supported",
		"model not supported",
		"unsupported model",
		"model unavailable",
		"not available for your plan",
		"not available for your account",
	} {
		if strings.Contains(normalized, marker) {
			return true
		}
	}
	return false
}
// isModelSupportError reports whether err carries a 400 or 422 status and its
// text matches isModelSupportErrorMessage.
func isModelSupportError(err error) bool {
	if err == nil {
		return false
	}
	switch statusCodeFromError(err) {
	case http.StatusBadRequest, http.StatusUnprocessableEntity:
		return isModelSupportErrorMessage(err.Error())
	default:
		return false
	}
}
// isModelSupportResultError reports whether a recorded result error carries a
// 400 or 422 status and its Message matches isModelSupportErrorMessage.
func isModelSupportResultError(err *Error) bool {
	if err == nil {
		return false
	}
	switch statusCodeFromResult(err) {
	case http.StatusBadRequest, http.StatusUnprocessableEntity:
		return isModelSupportErrorMessage(err.Message)
	default:
		return false
	}
}
// isRequestInvalidError reports whether err is a client request error that
// must not be retried: any 422 response, or a 400 response whose text contains
// "invalid_request_error". For these request-shape failures, switching auths
// or pooled upstream models will not help. Model-support errors are never
// treated as request-invalid, so routing can still fall through to another
// auth or upstream model.
func isRequestInvalidError(err error) bool {
	if err == nil || isModelSupportError(err) {
		return false
	}
	code := statusCodeFromError(err)
	if code == http.StatusUnprocessableEntity {
		return true
	}
	return code == http.StatusBadRequest && strings.Contains(err.Error(), "invalid_request_error")
}
// applyAuthFailureState records an auth-level failure (one without a model)
// on auth. It marks the auth unavailable with StatusError, copies resultErr
// into LastError, and sets StatusMessage and NextRetryAfter according to the
// error's status code. A nil auth is ignored.
//
// For the status codes handled explicitly below, StatusMessage is overwritten
// with a fixed label even when resultErr carries its own message. For any
// other code NextRetryAfter is left as it was.
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
if auth == nil {
return
}
auth.Unavailable = true
auth.Status = StatusError
auth.UpdatedAt = now
if resultErr != nil {
auth.LastError = cloneError(resultErr)
if resultErr.Message != "" {
auth.StatusMessage = resultErr.Message
}
}
statusCode := statusCodeFromResult(resultErr)
switch statusCode {
case 401:
auth.StatusMessage = "unauthorized"
auth.NextRetryAfter = now.Add(30 * time.Minute)
case 402, 403:
auth.StatusMessage = "payment_required"
auth.NextRetryAfter = now.Add(30 * time.Minute)
case 404:
auth.StatusMessage = "not_found"
auth.NextRetryAfter = now.Add(12 * time.Hour)
case 429:
auth.StatusMessage = "quota exhausted"
auth.Quota.Exceeded = true
auth.Quota.Reason = "quota"
// Use the supplied retry-after when given; otherwise take the next backoff
// step. The backoff level advances only in the latter case.
var next time.Time
if retryAfter != nil {
next = now.Add(*retryAfter)
} else {
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth))
if cooldown > 0 {
next = now.Add(cooldown)
}
auth.Quota.BackoffLevel = nextLevel
}
auth.Quota.NextRecoverAt = next
auth.NextRetryAfter = next
case 408, 500, 502, 503, 504:
auth.StatusMessage = "transient upstream error"
// One minute cooldown unless cooldowns are disabled for this auth.
if quotaCooldownDisabledForAuth(auth) {
auth.NextRetryAfter = time.Time{}
} else {
auth.NextRetryAfter = now.Add(1 * time.Minute)
}
default:
if auth.StatusMessage == "" {
auth.StatusMessage = "request failed"
}
}
}
// nextQuotaCooldown returns the next cooldown duration and updated backoff
// level for repeated quota errors.
//
// The cooldown is quotaBackoffBase doubled prevLevel times and capped at
// quotaBackoffMax. A negative prevLevel is treated as 0. When disableCooling
// is true the cooldown is 0 and the level is returned unchanged. The level
// advances by one only while the cooldown is still below the cap.
//
// The doubling is done step by step and clamped at the cap, so an arbitrarily
// large prevLevel yields quotaBackoffMax instead of overflowing the shift and
// multiplication and falling back to the base cooldown.
func nextQuotaCooldown(prevLevel int, disableCooling bool) (time.Duration, int) {
	if prevLevel < 0 {
		prevLevel = 0
	}
	if disableCooling {
		return 0, prevLevel
	}
	cooldown := time.Duration(quotaBackoffBase)
	for level := 0; level < prevLevel && cooldown < quotaBackoffMax; level++ {
		if cooldown > quotaBackoffMax/2 {
			// Doubling would pass the cap; clamp rather than risk overflow.
			cooldown = quotaBackoffMax
			break
		}
		cooldown *= 2
	}
	if cooldown >= quotaBackoffMax {
		return quotaBackoffMax, prevLevel
	}
	return cooldown, prevLevel + 1
}
// List returns clones of all auth entries currently known by the manager. The
// auth map is read under the manager's read lock.
func (m *Manager) List() []*Auth {
	m.mu.RLock()
	defer m.mu.RUnlock()

	snapshot := make([]*Auth, len(m.auths))
	next := 0
	for _, stored := range m.auths {
		snapshot[next] = stored.Clone()
		next++
	}
	return snapshot
}
// GetByID retrieves a clone of the auth entry stored under id. The second
// result is false when id is empty or no entry exists for it.
func (m *Manager) GetByID(id string) (*Auth, bool) {
	if id == "" {
		return nil, false
	}
	m.mu.RLock()
	defer m.mu.RUnlock()
	if stored, found := m.auths[id]; found {
		return stored.Clone(), true
	}
	return nil, false
}
// Executor returns the registered provider executor for a provider key. The
// key is trimmed and looked up as given first, then in lower case when that
// differs. The second result is false when m is nil, the key is blank, or no
// non-nil executor is registered.
func (m *Manager) Executor(provider string) (ProviderExecutor, bool) {
	if m == nil {
		return nil, false
	}
	name := strings.TrimSpace(provider)
	if name == "" {
		return nil, false
	}

	m.mu.RLock()
	registered, found := m.executors[name]
	if !found {
		if lowered := strings.ToLower(name); lowered != name {
			registered, found = m.executors[lowered]
		}
	}
	m.mu.RUnlock()

	if found && registered != nil {
		return registered, true
	}
	return nil, false
}
// CloseExecutionSession asks all registered executors to release the supplied
// execution session. The executor list is copied under the read lock and the
// close calls are made after the lock is released; only executors that
// implement ExecutionSessionCloser are called. A blank session ID or a nil
// manager is a no-op.
func (m *Manager) CloseExecutionSession(sessionID string) {
	id := strings.TrimSpace(sessionID)
	if id == "" || m == nil {
		return
	}

	m.mu.RLock()
	snapshot := make([]ProviderExecutor, 0, len(m.executors))
	for _, registered := range m.executors {
		snapshot = append(snapshot, registered)
	}
	m.mu.RUnlock()

	for _, registered := range snapshot {
		closer, supported := registered.(ExecutionSessionCloser)
		if !supported || closer == nil {
			continue
		}
		closer.CloseExecutionSession(id)
	}
}
// authAffinityKeyFromMetadata returns the whitespace-trimmed value stored
// under "auth_affinity_key" when it is a string or []byte, and "" in every
// other case.
func authAffinityKeyFromMetadata(meta map[string]any) string {
	// A nil/empty map or a missing key yields a nil interface value, which
	// matches neither case below.
	switch key := meta["auth_affinity_key"].(type) {
	case string:
		return strings.TrimSpace(key)
	case []byte:
		return strings.TrimSpace(string(key))
	}
	return ""
}
// scopedAuthAffinityKey builds the affinity map key "<provider>|<key>" from
// the lower-cased, trimmed provider and the trimmed key. It returns "" when
// either part is blank.
func scopedAuthAffinityKey(provider, key string) string {
	scope, id := strings.TrimSpace(strings.ToLower(provider)), strings.TrimSpace(key)
	if scope != "" && id != "" {
		return scope + "|" + id
	}
	return ""
}
// AuthAffinity returns the trimmed auth ID recorded for the provider-scoped
// affinity key, or "" when m is nil, the scoped key is empty, or nothing is
// recorded. The affinity map is read under affinityMu's read lock.
func (m *Manager) AuthAffinity(provider, key string) string {
	scoped := scopedAuthAffinityKey(provider, key)
	if scoped == "" || m == nil {
		return ""
	}
	m.affinityMu.RLock()
	authID := m.affinity[scoped]
	m.affinityMu.RUnlock()
	return strings.TrimSpace(authID)
}
// applyAuthAffinity pins the request to the auth previously recorded for its
// affinity key and provider. It writes that auth ID into opts.Metadata under
// cliproxyexecutor.PinnedAuthMetadataKey. Nothing happens when m or opts is
// nil, when the metadata already pins an auth, when it carries no affinity
// key, or when no auth is recorded for that key.
func (m *Manager) applyAuthAffinity(provider string, opts *cliproxyexecutor.Options) {
	if m == nil || opts == nil {
		return
	}
	if pinnedAuthIDFromMetadata(opts.Metadata) != "" {
		return
	}
	affinityKey := authAffinityKeyFromMetadata(opts.Metadata)
	if affinityKey == "" {
		return
	}
	stickyAuthID := m.AuthAffinity(provider, affinityKey)
	if stickyAuthID == "" {
		return
	}
	if opts.Metadata == nil {
		opts.Metadata = make(map[string]any)
	}
	opts.Metadata[cliproxyexecutor.PinnedAuthMetadataKey] = stickyAuthID
}
// persistAuthAffinity records authID as the sticky auth for the affinity key
// found in opts.Metadata, scoped to provider, and writes a debug log line when
// entry is non-nil and debug logging is enabled. It does nothing when m is nil
// or the metadata carries no affinity key.
func (m *Manager) persistAuthAffinity(entry *log.Entry, opts cliproxyexecutor.Options, authID, provider, model string) {
	if m == nil {
		return
	}
	affinityKey := authAffinityKeyFromMetadata(opts.Metadata)
	if affinityKey == "" {
		return
	}
	m.SetAuthAffinity(provider, affinityKey, authID)
	if entry == nil || !log.IsLevelEnabled(log.DebugLevel) {
		return
	}
	entry.Debugf("auth affinity pinned auth_id=%s provider=%s model=%s", authID, provider, model)
}
// SetAuthAffinity stores the trimmed authID under the provider-scoped affinity
// key, creating the affinity map on first use. It does nothing when m is nil,
// the scoped key is empty, or authID is blank.
func (m *Manager) SetAuthAffinity(provider, key, authID string) {
	scoped := scopedAuthAffinityKey(provider, key)
	pinned := strings.TrimSpace(authID)
	if m == nil || scoped == "" || pinned == "" {
		return
	}
	m.affinityMu.Lock()
	defer m.affinityMu.Unlock()
	if m.affinity == nil {
		m.affinity = make(map[string]string)
	}
	m.affinity[scoped] = pinned
}
// ClearAuthAffinity removes the entry stored under the provider-scoped
// affinity key. It does nothing when m is nil or the scoped key is empty.
func (m *Manager) ClearAuthAffinity(provider, key string) {
	scoped := scopedAuthAffinityKey(provider, key)
	if scoped == "" || m == nil {
		return
	}
	m.affinityMu.Lock()
	defer m.affinityMu.Unlock()
	delete(m.affinity, scoped)
}
// useSchedulerFastPath reports whether picks may go through the scheduler:
// the manager and its scheduler must be non-nil and the configured selector
// must be one that isBuiltInSelector accepts.
func (m *Manager) useSchedulerFastPath() bool {
	return m != nil && m.scheduler != nil && isBuiltInSelector(m.selector)
}
// shouldRetrySchedulerPick reports whether a failed scheduler pick is worth
// retrying: the error chain contains a *modelCooldownError, or an *Error whose
// Code is "auth_not_found" or "auth_unavailable".
func shouldRetrySchedulerPick(err error) bool {
	if err == nil {
		return false
	}
	var cooling *modelCooldownError
	if errors.As(err, &cooling) {
		return true
	}
	var pickErr *Error
	if errors.As(err, &pickErr) && pickErr != nil {
		switch pickErr.Code {
		case "auth_not_found", "auth_unavailable":
			return true
		}
	}
	return false
}
// pickNextLegacy selects an auth for provider and model through m.selector and
// returns a clone of it together with the provider's executor.
//
// Candidates are the stored auths of the given provider that are not disabled,
// match the pinned auth ID from opts.Metadata when one is set, are not in
// tried, and, when a model is given and the global registry is available, are
// reported by the registry as supporting the model's base name. It returns an
// "executor_not_found" error when no executor is registered under provider,
// and an "auth_not_found" error when no candidate remains or the selector
// returns nil.
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
m.mu.RLock()
executor, okExecutor := m.executors[provider]
if !okExecutor {
m.mu.RUnlock()
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
candidates := make([]*Auth, 0, len(m.auths))
modelKey := strings.TrimSpace(model)
// Always use base model name (without thinking suffix) for auth matching.
if modelKey != "" {
parsed := thinking.ParseSuffix(modelKey)
if parsed.ModelName != "" {
modelKey = strings.TrimSpace(parsed.ModelName)
}
}
registryRef := registry.GetGlobalRegistry()
for _, candidate := range m.auths {
if candidate.Provider != provider || candidate.Disabled {
continue
}
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
continue
}
if _, used := tried[candidate.ID]; used {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
continue
}
candidates = append(candidates, candidate)
}
if len(candidates) == 0 {
m.mu.RUnlock()
return nil, nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
// The selector receives the original model string, not the base-name key.
selected, errPick := m.selector.Pick(ctx, provider, model, opts, candidates)
if errPick != nil {
m.mu.RUnlock()
return nil, nil, errPick
}
if selected == nil {
m.mu.RUnlock()
return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
}
authCopy := selected.Clone()
m.mu.RUnlock()
// NOTE(review): selected.indexAssigned is read here after the read lock has
// been released — confirm this cannot race with writers of the stored auth.
if !selected.indexAssigned {
// Assign the index on the stored auth under the write lock, re-checking the
// flag there, and return a fresh clone that reflects it.
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, nil
}
// pickNext selects an auth and executor for a single provider.
//
// Auth affinity is applied to opts first. When the scheduler fast path is not
// usable the call is delegated to pickNextLegacy. Otherwise the scheduler picks
// the auth; if that fails with a retryable error (see shouldRetrySchedulerPick)
// and a model was supplied, the scheduler is re-synced and asked once more.
// The returned auth is a clone, re-cloned from m.auths when an index had to be
// assigned first.
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
	m.applyAuthAffinity(provider, &opts)
	if !m.useSchedulerFastPath() {
		return m.pickNextLegacy(ctx, provider, model, opts, tried)
	}

	exec, found := m.Executor(provider)
	if !found {
		return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
	}

	picked, err := m.scheduler.pickSingle(ctx, provider, model, opts, tried)
	if err != nil && model != "" && shouldRetrySchedulerPick(err) {
		// The scheduler's view may be stale; rebuild it and try once more.
		m.syncScheduler()
		picked, err = m.scheduler.pickSingle(ctx, provider, model, opts, tried)
	}
	switch {
	case err != nil:
		return nil, nil, err
	case picked == nil:
		return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
	}

	result := picked.Clone()
	if picked.indexAssigned {
		return result, exec, nil
	}

	m.mu.Lock()
	if live := m.auths[result.ID]; live != nil && !live.indexAssigned {
		live.EnsureIndex()
		result = live.Clone()
	}
	m.mu.Unlock()
	return result, exec, nil
}
// pickNextMixedLegacy selects an auth across several providers by filtering
// m.auths and delegating the final choice to m.selector (called with the
// provider name "mixed").
//
// Provider names are compared case-insensitively after trimming. Candidates
// must be non-nil and enabled, match the pinned auth ID from opts.Metadata when
// one is set, belong to one of the requested providers that has a registered
// executor, not appear in tried, and (when a model is given and a global
// registry exists) be reported by the registry as supporting the base model
// name. It returns a clone of the selected auth, its executor and the
// normalized provider key, or an *Error with code "provider_not_found",
// "auth_not_found" or "executor_not_found". Selector errors are passed through.
func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
	pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
	providerSet := make(map[string]struct{}, len(providers))
	for _, provider := range providers {
		p := strings.TrimSpace(strings.ToLower(provider))
		if p == "" {
			continue
		}
		providerSet[p] = struct{}{}
	}
	if len(providerSet) == 0 {
		return nil, nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	m.mu.RLock()
	candidates := make([]*Auth, 0, len(m.auths))
	modelKey := strings.TrimSpace(model)
	// Always use base model name (without thinking suffix) for auth matching.
	if modelKey != "" {
		parsed := thinking.ParseSuffix(modelKey)
		if parsed.ModelName != "" {
			modelKey = strings.TrimSpace(parsed.ModelName)
		}
	}
	registryRef := registry.GetGlobalRegistry()
	for _, candidate := range m.auths {
		if candidate == nil || candidate.Disabled {
			continue
		}
		if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
			continue
		}
		providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
		if providerKey == "" {
			continue
		}
		if _, ok := providerSet[providerKey]; !ok {
			continue
		}
		if _, used := tried[candidate.ID]; used {
			continue
		}
		if _, ok := m.executors[providerKey]; !ok {
			continue
		}
		if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
			continue
		}
		candidates = append(candidates, candidate)
	}
	if len(candidates) == 0 {
		m.mu.RUnlock()
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
	if errPick != nil {
		m.mu.RUnlock()
		return nil, nil, "", errPick
	}
	if selected == nil {
		m.mu.RUnlock()
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
	}
	providerKey := strings.TrimSpace(strings.ToLower(selected.Provider))
	executor, okExecutor := m.executors[providerKey]
	if !okExecutor {
		m.mu.RUnlock()
		return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
	}
	authCopy := selected.Clone()
	// Read indexAssigned while the read lock is still held: other goroutines
	// set it under m.mu.Lock, so reading it after RUnlock would be a data race.
	needsIndex := !selected.indexAssigned
	m.mu.RUnlock()
	if needsIndex {
		m.mu.Lock()
		// Re-check under the write lock; another goroutine may have assigned
		// the index between the two critical sections.
		if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
			current.EnsureIndex()
			authCopy = current.Clone()
		}
		m.mu.Unlock()
	}
	return authCopy, executor, providerKey, nil
}
// pickNextMixed selects an auth and executor across several providers.
//
// When opts.Metadata carries no pinned auth, affinity is applied provider by
// provider (in the order given) until one of them produces a pin. If the
// scheduler fast path is not usable the call is delegated to
// pickNextMixedLegacy. Otherwise the provider list is normalized (trimmed,
// lower-cased, de-duplicated, limited to providers with a registered executor)
// and handed to the scheduler; a retryable failure with a non-empty model
// triggers one scheduler re-sync and a second attempt. The returned auth is a
// clone, accompanied by the executor and provider key reported by the scheduler.
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
	if pinnedAuthIDFromMetadata(opts.Metadata) == "" {
		for _, raw := range providers {
			name := strings.TrimSpace(strings.ToLower(raw))
			if name == "" {
				continue
			}
			m.applyAuthAffinity(name, &opts)
			if pinnedAuthIDFromMetadata(opts.Metadata) != "" {
				// The first provider that yields a pin wins.
				break
			}
		}
	}
	if !m.useSchedulerFastPath() {
		return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
	}

	seen := make(map[string]struct{}, len(providers))
	eligible := make([]string, 0, len(providers))
	for _, raw := range providers {
		name := strings.TrimSpace(strings.ToLower(raw))
		if name == "" {
			continue
		}
		if _, dup := seen[name]; dup {
			continue
		}
		if _, registered := m.Executor(name); !registered {
			continue
		}
		seen[name] = struct{}{}
		eligible = append(eligible, name)
	}
	if len(eligible) == 0 {
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
	}

	picked, pickedProvider, err := m.scheduler.pickMixed(ctx, eligible, model, opts, tried)
	if err != nil && model != "" && shouldRetrySchedulerPick(err) {
		// The scheduler's view may be stale; rebuild it and try once more.
		m.syncScheduler()
		picked, pickedProvider, err = m.scheduler.pickMixed(ctx, eligible, model, opts, tried)
	}
	switch {
	case err != nil:
		return nil, nil, "", err
	case picked == nil:
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
	}

	exec, registered := m.Executor(pickedProvider)
	if !registered {
		return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
	}

	result := picked.Clone()
	if picked.indexAssigned {
		return result, exec, pickedProvider, nil
	}
	m.mu.Lock()
	if live := m.auths[result.ID]; live != nil && !live.indexAssigned {
		live.EnsureIndex()
		result = live.Clone()
	}
	m.mu.Unlock()
	return result, exec, pickedProvider, nil
}
// persist saves auth through m.store and returns the store's error.
//
// Nothing is saved (and nil is returned) when there is no store, auth is nil,
// shouldSkipPersist(ctx) is true, the "runtime_only" attribute equals "true"
// (case-insensitive, surrounding whitespace ignored), or auth.Metadata is nil.
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
	switch {
	case m.store == nil, auth == nil:
		return nil
	case shouldSkipPersist(ctx):
		return nil
	}
	// Indexing a nil Attributes map is safe and yields "".
	runtimeOnly := strings.TrimSpace(auth.Attributes["runtime_only"])
	if strings.ToLower(runtimeOnly) == "true" {
		return nil
	}
	// Skip persistence when metadata is absent (e.g., runtime-only auths).
	if auth.Metadata == nil {
		return nil
	}
	_, errSave := m.store.Save(ctx, auth)
	return errSave
}
// StartAutoRefresh starts a background goroutine that runs checkRefreshes once
// immediately and then on every tick of interval. A non-positive interval falls
// back to refreshCheckInterval. Any loop started by an earlier call is
// cancelled first, so at most one loop is active. The loop ends when parent is
// cancelled or when the stored cancel function is invoked.
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
	if interval <= 0 {
		interval = refreshCheckInterval
	}
	if prev := m.refreshCancel; prev != nil {
		prev()
		m.refreshCancel = nil
	}
	loopCtx, stop := context.WithCancel(parent)
	m.refreshCancel = stop
	go func() {
		tick := time.NewTicker(interval)
		defer tick.Stop()
		// Run one pass right away instead of waiting for the first tick.
		m.checkRefreshes(loopCtx)
		for {
			select {
			case <-tick.C:
				m.checkRefreshes(loopCtx)
			case <-loopCtx.Done():
				return
			}
		}
	}()
}
// StopAutoRefresh cancels the background refresh loop started by
// StartAutoRefresh and clears the stored cancel function. It is a no-op when no
// loop is registered.
func (m *Manager) StopAutoRefresh() {
	stop := m.refreshCancel
	if stop == nil {
		return
	}
	stop()
	m.refreshCancel = nil
}
// checkRefreshes walks a snapshot of all auths and starts a refresh goroutine
// for each one that is due. Auths whose account type is "api_key" are skipped,
// as are auths that shouldRefresh rejects, that have no executor registered for
// their provider, or that markRefreshPending declines to mark.
func (m *Manager) checkRefreshes(ctx context.Context) {
	now := time.Now()
	for _, candidate := range m.snapshotAuths() {
		kind, _ := candidate.AccountInfo()
		if kind == "api_key" {
			continue
		}
		if !m.shouldRefresh(candidate, now) {
			continue
		}
		log.Debugf("checking refresh for %s, %s, %s", candidate.Provider, candidate.ID, kind)
		if m.executorFor(candidate.Provider) == nil {
			continue
		}
		if !m.markRefreshPending(candidate.ID, now) {
			continue
		}
		go m.refreshAuthWithLimit(ctx, candidate.ID)
	}
}
// refreshAuthWithLimit runs refreshAuth for id, first acquiring a slot on
// m.refreshSemaphore when one is configured. If ctx is done before a slot
// becomes available the refresh is abandoned. The slot is released when the
// refresh returns.
func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) {
	if m.refreshSemaphore != nil {
		select {
		case <-ctx.Done():
			return
		case m.refreshSemaphore <- struct{}{}:
		}
		defer func() { <-m.refreshSemaphore }()
	}
	m.refreshAuth(ctx, id)
}
// snapshotAuths returns a slice holding a clone of every entry in m.auths,
// taken under the read lock. Order follows map iteration and is unspecified.
func (m *Manager) snapshotAuths() []*Auth {
	m.mu.RLock()
	defer m.mu.RUnlock()
	clones := make([]*Auth, len(m.auths))
	next := 0
	for _, entry := range m.auths {
		clones[next] = entry.Clone()
		next++
	}
	return clones
}
// shouldRefresh reports whether auth a is due for a credential refresh at the
// instant now.
//
// Decision order:
//  1. nil or disabled auths are never refreshed.
//  2. While now is before a.NextRefreshAfter (when set) no refresh is due.
//  3. If a.Runtime implements RefreshEvaluator, its verdict is returned.
//  4. If the auth declares a preferred interval (see authPreferredInterval), a
//     refresh is due when the expiry is within that interval of now (or already
//     past), when no last-refresh time is known, or when the interval has
//     elapsed since the last refresh.
//  5. Otherwise ProviderRefreshLead decides: a nil lead disables refreshing; a
//     non-positive lead refreshes only after expiry; a positive lead refreshes
//     once the expiry is within the lead of now, or, without an expiry, once
//     the lead has elapsed since the last refresh (immediately if that time is
//     unknown).
//
// All comparisons use the supplied now rather than the wall clock so a single
// refresh pass evaluates every auth against the same instant.
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
	if a == nil || a.Disabled {
		return false
	}
	if !a.NextRefreshAfter.IsZero() && now.Before(a.NextRefreshAfter) {
		return false
	}
	if evaluator, ok := a.Runtime.(RefreshEvaluator); ok && evaluator != nil {
		return evaluator.ShouldRefresh(now, a)
	}
	lastRefresh := a.LastRefreshedAt
	if lastRefresh.IsZero() {
		// Fall back to a timestamp stored in metadata or attributes.
		if ts, ok := authLastRefreshTimestamp(a); ok {
			lastRefresh = ts
		}
	}
	expiry, hasExpiry := a.ExpirationTime()
	if interval := authPreferredInterval(a); interval > 0 {
		if hasExpiry && !expiry.IsZero() {
			if !expiry.After(now) {
				return true
			}
			if expiry.Sub(now) <= interval {
				return true
			}
		}
		if lastRefresh.IsZero() {
			return true
		}
		return now.Sub(lastRefresh) >= interval
	}
	provider := strings.ToLower(a.Provider)
	lead := ProviderRefreshLead(provider, a.Runtime)
	if lead == nil {
		return false
	}
	if *lead <= 0 {
		if hasExpiry && !expiry.IsZero() {
			return now.After(expiry)
		}
		return false
	}
	if hasExpiry && !expiry.IsZero() {
		// Measure against the caller's now, not time.Now(), to stay consistent
		// with every other branch of this function.
		return expiry.Sub(now) <= *lead
	}
	if !lastRefresh.IsZero() {
		return now.Sub(lastRefresh) >= *lead
	}
	return true
}
// authPreferredInterval returns the refresh interval configured on the auth,
// looking first in Metadata and then in Attributes under the keys
// "refresh_interval_seconds", "refreshIntervalSeconds", "refresh_interval" and
// "refreshInterval". It returns 0 when a is nil or no positive value is found.
func authPreferredInterval(a *Auth) time.Duration {
	if a == nil {
		return 0
	}
	fromMeta := durationFromMetadata(a.Metadata, "refresh_interval_seconds", "refreshIntervalSeconds", "refresh_interval", "refreshInterval")
	if fromMeta > 0 {
		return fromMeta
	}
	fromAttrs := durationFromAttributes(a.Attributes, "refresh_interval_seconds", "refreshIntervalSeconds", "refresh_interval", "refreshInterval")
	if fromAttrs > 0 {
		return fromAttrs
	}
	return 0
}
// durationFromMetadata returns the first positive duration that
// parseDurationValue extracts from meta, trying keys in order. It returns 0
// when meta is empty or no key yields a positive duration.
func durationFromMetadata(meta map[string]any, keys ...string) time.Duration {
	if len(meta) == 0 {
		return 0
	}
	for _, name := range keys {
		raw, present := meta[name]
		if !present {
			continue
		}
		if parsed := parseDurationValue(raw); parsed > 0 {
			return parsed
		}
	}
	return 0
}
// durationFromAttributes returns the first positive duration that
// parseDurationString extracts from attrs, trying keys in order. It returns 0
// when attrs is empty or no key yields a positive duration.
func durationFromAttributes(attrs map[string]string, keys ...string) time.Duration {
	if len(attrs) == 0 {
		return 0
	}
	for _, name := range keys {
		raw, present := attrs[name]
		if !present {
			continue
		}
		if parsed := parseDurationString(raw); parsed > 0 {
			return parsed
		}
	}
	return 0
}
// parseDurationValue converts a loosely typed value into a duration.
//
// A time.Duration is returned unchanged when positive. Integer and floating
// point values (including json.Number) are interpreted as a number of seconds.
// Strings are handed to parseDurationString. Non-positive values and
// unsupported types yield 0.
func parseDurationValue(val any) time.Duration {
	// signedSeconds handles signed whole-second counts.
	signedSeconds := func(secs int64) time.Duration {
		if secs <= 0 {
			return 0
		}
		return time.Duration(secs) * time.Second
	}
	// unsignedSeconds handles unsigned whole-second counts.
	unsignedSeconds := func(secs uint64) time.Duration {
		if secs == 0 {
			return 0
		}
		return time.Duration(secs) * time.Second
	}
	// fractionalSeconds handles floating point second counts.
	fractionalSeconds := func(secs float64) time.Duration {
		if secs <= 0 {
			return 0
		}
		return time.Duration(secs * float64(time.Second))
	}

	switch typed := val.(type) {
	case time.Duration:
		if typed <= 0 {
			return 0
		}
		return typed
	case int:
		return signedSeconds(int64(typed))
	case int32:
		return signedSeconds(int64(typed))
	case int64:
		return signedSeconds(typed)
	case uint:
		return unsignedSeconds(uint64(typed))
	case uint32:
		return unsignedSeconds(uint64(typed))
	case uint64:
		return unsignedSeconds(typed)
	case float32:
		return fractionalSeconds(float64(typed))
	case float64:
		return fractionalSeconds(typed)
	case json.Number:
		// Prefer the exact integer form; fall back to a float for values
		// such as "1.5".
		if whole, err := typed.Int64(); err == nil {
			return signedSeconds(whole)
		}
		if frac, err := typed.Float64(); err == nil && frac > 0 {
			return time.Duration(frac * float64(time.Second))
		}
	case string:
		return parseDurationString(typed)
	}
	return 0
}
// parseDurationString converts raw into a positive duration.
//
// Surrounding whitespace is ignored. The value is first tried as a Go duration
// literal (e.g. "90s", "2m"); if that fails or is not positive, it is tried as
// a plain number of seconds (e.g. "90", "1.5"). The function returns 0 for
// empty, unparsable or non-positive input, and also for numeric values too
// large to be represented as a time.Duration (including "Inf"), since
// converting such a float to an integer is implementation-defined in Go and
// could otherwise produce a negative or arbitrary duration.
func parseDurationString(raw string) time.Duration {
	// maxDurationNanos is 2^63, the first float64 value that no longer fits in
	// the int64 backing time.Duration.
	const maxDurationNanos = float64(1 << 63)

	s := strings.TrimSpace(raw)
	if s == "" {
		return 0
	}
	if dur, err := time.ParseDuration(s); err == nil && dur > 0 {
		return dur
	}
	secs, err := strconv.ParseFloat(s, 64)
	// The negated comparison also rejects NaN.
	if err != nil || !(secs > 0) {
		return 0
	}
	nanos := secs * float64(time.Second)
	if nanos >= maxDurationNanos {
		return 0
	}
	return time.Duration(nanos)
}
// authLastRefreshTimestamp looks for a last-refresh timestamp stored on the
// auth under the keys "last_refresh", "lastRefresh", "last_refreshed_at" or
// "lastRefreshedAt". Metadata is consulted first, then Attributes (ignoring
// blank values). The boolean result is false when a is nil or nothing parses.
func authLastRefreshTimestamp(a *Auth) (time.Time, bool) {
	if a == nil {
		return time.Time{}, false
	}
	candidates := []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
	if a.Metadata != nil {
		if ts, found := lookupMetadataTime(a.Metadata, candidates...); found {
			return ts, true
		}
	}
	// Indexing a nil Attributes map is safe and yields "".
	for _, name := range candidates {
		text := strings.TrimSpace(a.Attributes[name])
		if text == "" {
			continue
		}
		if ts, parsed := parseTimeValue(text); parsed {
			return ts, true
		}
	}
	return time.Time{}, false
}
// lookupMetadataTime returns the first timestamp that parseTimeValue can
// extract from meta, trying keys in order. The boolean result is false when no
// key is present with a parsable value.
func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) {
	for _, name := range keys {
		raw, present := meta[name]
		if !present {
			continue
		}
		if ts, parsed := parseTimeValue(raw); parsed {
			return ts, true
		}
	}
	return time.Time{}, false
}
// markRefreshPending claims the auth identified by id for a refresh attempt by
// pushing its NextRefreshAfter to now+refreshPendingBackoff. It returns false,
// leaving the auth untouched, when the auth is missing, nil or disabled, or
// when now is still before an already set NextRefreshAfter.
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	entry := m.auths[id]
	if entry == nil || entry.Disabled {
		return false
	}
	if notBefore := entry.NextRefreshAfter; !notBefore.IsZero() && now.Before(notBefore) {
		return false
	}
	entry.NextRefreshAfter = now.Add(refreshPendingBackoff)
	m.auths[id] = entry
	return true
}
// refreshAuth refreshes the credentials of the auth identified by id using the
// executor registered for its provider.
//
// A nil ctx is replaced by context.Background(). The call returns silently when
// the auth or its executor is missing, or when the refresh was cancelled. On
// any other refresh error the stored auth gets NextRefreshAfter set to
// now+refreshFailureBackoff and LastError set to the error text, and the
// scheduler (if any) is updated. On success the refreshed auth (or the
// pre-refresh clone when the executor returned nil) is stamped with the refresh
// time, cleared of its error and back-off, and handed to m.Update.
func (m *Manager) refreshAuth(ctx context.Context, id string) {
	if ctx == nil {
		ctx = context.Background()
	}
	// Snapshot everything needed from the shared auth while holding the read
	// lock. Other goroutines mutate the stored *Auth in place under m.mu.Lock
	// (e.g. markRefreshPending), so cloning or reading fields after RUnlock
	// would be a data race.
	m.mu.RLock()
	auth := m.auths[id]
	if auth == nil {
		m.mu.RUnlock()
		return
	}
	exec := m.executors[auth.Provider]
	cloned := auth.Clone()
	provider, authID := auth.Provider, auth.ID
	prevRuntime := auth.Runtime
	m.mu.RUnlock()
	if exec == nil {
		return
	}
	updated, err := exec.Refresh(ctx, cloned)
	if err != nil && errors.Is(err, context.Canceled) {
		log.Debugf("refresh canceled for %s, %s", provider, authID)
		return
	}
	log.Debugf("refreshed %s, %s, %v", provider, authID, err)
	now := time.Now()
	if err != nil {
		m.mu.Lock()
		if current := m.auths[id]; current != nil {
			current.NextRefreshAfter = now.Add(refreshFailureBackoff)
			current.LastError = &Error{Message: err.Error()}
			m.auths[id] = current
			if m.scheduler != nil {
				m.scheduler.upsertAuth(current.Clone())
			}
		}
		m.mu.Unlock()
		return
	}
	if updated == nil {
		updated = cloned
	}
	// Preserve runtime created by the executor during Refresh.
	// If executor didn't set one, fall back to the previous runtime.
	if updated.Runtime == nil {
		updated.Runtime = prevRuntime
	}
	updated.LastRefreshedAt = now
	updated.NextRefreshAfter = time.Time{}
	updated.LastError = nil
	updated.UpdatedAt = now
	// Best effort: the result of Update is intentionally ignored here, as
	// there is no caller to report it to.
	_, _ = m.Update(ctx, updated)
}
// executorFor returns the executor registered under provider, or nil when none
// is registered. The lookup uses the provider string exactly as given.
func (m *Manager) executorFor(provider string) ProviderExecutor {
	m.mu.RLock()
	exec := m.executors[provider]
	m.mu.RUnlock()
	return exec
}
// roundTripperContextKey is an unexported context key type to avoid collisions.
// Being a distinct empty struct type, a value of it cannot compare equal to a
// context key defined by any other package.
type roundTripperContextKey struct{}
// roundTripperFor asks the configured RoundTripperProvider for the HTTP
// transport to use with auth. It returns nil when no provider is configured or
// auth is nil.
func (m *Manager) roundTripperFor(auth *Auth) http.RoundTripper {
	m.mu.RLock()
	provider := m.rtProvider
	m.mu.RUnlock()
	if auth != nil && provider != nil {
		return provider.RoundTripperFor(auth)
	}
	return nil
}
// RoundTripperProvider defines a minimal provider of per-auth HTTP transports.
type RoundTripperProvider interface {
// RoundTripperFor returns the HTTP transport to use for requests made on
// behalf of auth. Manager.roundTripperFor only calls it with a non-nil auth.
RoundTripperFor(auth *Auth) http.RoundTripper
}
// RequestPreparer is an optional interface that provider executors can implement
// to mutate outbound HTTP requests with provider credentials.
type RequestPreparer interface {
// PrepareRequest modifies req in place using the credentials held by auth
// and returns an error when the request cannot be prepared.
PrepareRequest(req *http.Request, auth *Auth) error
}
// executorKeyFromAuth derives the lower-cased key used to look up the executor
// for auth. When the "compat_name" attribute is set, the key is the
// "provider_key" attribute, falling back to "compat_name" itself; otherwise it
// is the trimmed auth.Provider. A nil auth yields "".
func executorKeyFromAuth(auth *Auth) string {
	if auth == nil {
		return ""
	}
	// Indexing a nil Attributes map is safe and yields "".
	if compat := strings.TrimSpace(auth.Attributes["compat_name"]); compat != "" {
		key := strings.TrimSpace(auth.Attributes["provider_key"])
		if key == "" {
			key = compat
		}
		return strings.ToLower(key)
	}
	return strings.ToLower(strings.TrimSpace(auth.Provider))
}
// logEntryWithRequestID returns a logrus entry on the standard logger. When ctx
// is non-nil and carries a request ID, the entry has a "request_id" field set
// to it; otherwise the entry has no extra fields.
func logEntryWithRequestID(ctx context.Context) *log.Entry {
	if ctx != nil {
		if id := logging.GetRequestID(ctx); id != "" {
			return log.WithField("request_id", id)
		}
	}
	return log.NewEntry(log.StandardLogger())
}
// debugLogAuthSelection writes a debug line naming the credential chosen for
// model. API keys are logged through util.HideAPIKey; OAuth auths are logged
// through formatOauthIdentity. Proxy information, when present, is appended.
// Nothing is logged when debug logging is disabled, entry or auth is nil, or
// the account type is neither "api_key" nor "oauth".
func debugLogAuthSelection(entry *log.Entry, auth *Auth, provider string, model string) {
	if entry == nil || auth == nil || !log.IsLevelEnabled(log.DebugLevel) {
		return
	}
	kind, info := auth.AccountInfo()
	var proxySuffix string
	if proxy := auth.ProxyInfo(); proxy != "" {
		proxySuffix = " " + proxy
	}
	if kind == "api_key" {
		entry.Debugf("Use API key %s for model %s%s", util.HideAPIKey(info), model, proxySuffix)
	} else if kind == "oauth" {
		entry.Debugf("Use OAuth %s for model %s%s", formatOauthIdentity(auth, provider, info), model, proxySuffix)
	}
}
// formatOauthIdentity builds the identity string used when logging an OAuth
// auth. The result is "provider=<name>" and/or "auth_file=<basename>" joined by
// a space. The provider name comes from auth.Provider, falling back to the
// provider argument; the file comes from auth.FileName, falling back to
// auth.ID. When neither part is available accountInfo is returned instead, and
// a nil auth yields "".
func formatOauthIdentity(auth *Auth, provider string, accountInfo string) string {
	if auth == nil {
		return ""
	}
	// Prefer the auth's own provider over the one passed by the caller.
	name := strings.TrimSpace(auth.Provider)
	if name == "" {
		name = strings.TrimSpace(provider)
	}
	// FileName may be unset for some auth backends; fall back to ID.
	file := strings.TrimSpace(auth.FileName)
	if file == "" {
		file = strings.TrimSpace(auth.ID)
	}
	// Only the basename is logged to avoid leaking host paths.
	switch {
	case name != "" && file != "":
		return "provider=" + name + " auth_file=" + filepath.Base(file)
	case name != "":
		return "provider=" + name
	case file != "":
		return "auth_file=" + filepath.Base(file)
	default:
		return accountInfo
	}
}
// InjectCredentials lets the executor of the auth identified by authID add its
// credentials to req. The executor is looked up via executorKeyFromAuth and is
// only invoked when it implements RequestPreparer. The call returns nil without
// touching req when req is nil, authID is empty, the auth or executor is
// unknown, or the executor does not support request preparation.
func (m *Manager) InjectCredentials(req *http.Request, authID string) error {
	if req == nil || authID == "" {
		return nil
	}

	m.mu.RLock()
	target := m.auths[authID]
	var exec ProviderExecutor
	if target != nil {
		exec = m.executors[executorKeyFromAuth(target)]
	}
	m.mu.RUnlock()

	if target == nil || exec == nil {
		return nil
	}
	preparer, supported := exec.(RequestPreparer)
	if !supported || preparer == nil {
		return nil
	}
	return preparer.PrepareRequest(req, target)
}
// PrepareHttpRequest injects provider credentials into the supplied HTTP
// request by delegating to the RequestPreparer implemented by auth's executor.
// When ctx is non-nil it is attached to req in place before preparation.
//
// It returns an *Error with code "provider_not_found" (nil manager, empty
// provider key, or no executor registered), "auth_not_found" (nil auth),
// "invalid_request" (nil req) or "not_supported" (executor is not a
// RequestPreparer); otherwise the preparer's own result is returned.
func (m *Manager) PrepareHttpRequest(ctx context.Context, auth *Auth, req *http.Request) error {
	switch {
	case m == nil:
		return &Error{Code: "provider_not_found", Message: "manager is nil"}
	case auth == nil:
		return &Error{Code: "auth_not_found", Message: "auth is nil"}
	case req == nil:
		return &Error{Code: "invalid_request", Message: "http request is nil"}
	}
	if ctx != nil {
		// Overwrite the caller's request so it carries ctx from here on.
		*req = *req.WithContext(ctx)
	}

	key := executorKeyFromAuth(auth)
	if key == "" {
		return &Error{Code: "provider_not_found", Message: "auth provider is empty"}
	}
	exec := m.executorFor(key)
	if exec == nil {
		return &Error{Code: "provider_not_found", Message: "executor not registered for provider: " + key}
	}
	if preparer, supported := exec.(RequestPreparer); supported && preparer != nil {
		return preparer.PrepareRequest(req, auth)
	}
	return &Error{Code: "not_supported", Message: "executor does not support http request preparation"}
}
// NewHttpRequest builds an HTTP request for targetURL and injects provider
// credentials into it via PrepareHttpRequest.
//
// A nil ctx becomes context.Background(), and a blank method becomes GET. A
// non-nil body is sent as the request body, and non-nil headers are cloned onto
// the request. The error from request construction or from credential
// injection is returned unchanged.
func (m *Manager) NewHttpRequest(ctx context.Context, auth *Auth, method, targetURL string, body []byte, headers http.Header) (*http.Request, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	if method = strings.TrimSpace(method); method == "" {
		method = http.MethodGet
	}
	// Keep the reader as a nil interface when there is no body so the request
	// is built without one.
	var payload io.Reader
	if body != nil {
		payload = bytes.NewReader(body)
	}
	request, errBuild := http.NewRequestWithContext(ctx, method, targetURL, payload)
	if errBuild != nil {
		return nil, errBuild
	}
	if headers != nil {
		request.Header = headers.Clone()
	}
	if errPrepare := m.PrepareHttpRequest(ctx, auth, request); errPrepare != nil {
		return nil, errPrepare
	}
	return request, nil
}
// HttpRequest hands req to the HttpRequest method of the executor registered
// for auth (looked up via executorKeyFromAuth) and returns its result.
//
// It returns an *Error with code "provider_not_found" (nil manager, empty
// provider key, or no executor registered), "auth_not_found" (nil auth) or
// "invalid_request" (nil req) before the executor is reached.
func (m *Manager) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
	switch {
	case m == nil:
		return nil, &Error{Code: "provider_not_found", Message: "manager is nil"}
	case auth == nil:
		return nil, &Error{Code: "auth_not_found", Message: "auth is nil"}
	case req == nil:
		return nil, &Error{Code: "invalid_request", Message: "http request is nil"}
	}

	key := executorKeyFromAuth(auth)
	if key == "" {
		return nil, &Error{Code: "provider_not_found", Message: "auth provider is empty"}
	}
	if exec := m.executorFor(key); exec != nil {
		return exec.HttpRequest(ctx, auth, req)
	}
	return nil, &Error{Code: "provider_not_found", Message: "executor not registered for provider: " + key}
}