test(auth-scheduler): add unit tests and scheduler implementation
- Added comprehensive unit tests for `authScheduler` and related components. - Implemented `authScheduler` with support for Round Robin, Fill First, and custom selector strategies. - Improved tracking of auth states, cooldowns, and recovery logic in scheduler.
This commit is contained in:
851
sdk/cliproxy/auth/scheduler.go
Normal file
851
sdk/cliproxy/auth/scheduler.go
Normal file
@@ -0,0 +1,851 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// schedulerStrategy identifies which built-in routing semantics the scheduler should apply.
|
||||
type schedulerStrategy int
|
||||
|
||||
const (
|
||||
schedulerStrategyCustom schedulerStrategy = iota
|
||||
schedulerStrategyRoundRobin
|
||||
schedulerStrategyFillFirst
|
||||
)
|
||||
|
||||
// scheduledState describes how an auth currently participates in a model shard.
|
||||
type scheduledState int
|
||||
|
||||
const (
|
||||
scheduledStateReady scheduledState = iota
|
||||
scheduledStateCooldown
|
||||
scheduledStateBlocked
|
||||
scheduledStateDisabled
|
||||
)
|
||||
|
||||
// authScheduler keeps the incremental provider/model scheduling state used by Manager.
|
||||
type authScheduler struct {
|
||||
mu sync.Mutex
|
||||
strategy schedulerStrategy
|
||||
providers map[string]*providerScheduler
|
||||
authProviders map[string]string
|
||||
mixedCursors map[string]int
|
||||
}
|
||||
|
||||
// providerScheduler stores auth metadata and model shards for a single provider.
|
||||
type providerScheduler struct {
|
||||
providerKey string
|
||||
auths map[string]*scheduledAuthMeta
|
||||
modelShards map[string]*modelScheduler
|
||||
}
|
||||
|
||||
// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot.
|
||||
type scheduledAuthMeta struct {
|
||||
auth *Auth
|
||||
providerKey string
|
||||
priority int
|
||||
virtualParent string
|
||||
websocketEnabled bool
|
||||
supportedModelSet map[string]struct{}
|
||||
}
|
||||
|
||||
// modelScheduler tracks ready and blocked auths for one provider/model combination.
|
||||
type modelScheduler struct {
|
||||
modelKey string
|
||||
entries map[string]*scheduledAuth
|
||||
priorityOrder []int
|
||||
readyByPriority map[int]*readyBucket
|
||||
blocked cooldownQueue
|
||||
}
|
||||
|
||||
// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard.
|
||||
type scheduledAuth struct {
|
||||
meta *scheduledAuthMeta
|
||||
auth *Auth
|
||||
state scheduledState
|
||||
nextRetryAt time.Time
|
||||
}
|
||||
|
||||
// readyBucket keeps the ready views for one priority level.
|
||||
type readyBucket struct {
|
||||
all readyView
|
||||
ws readyView
|
||||
}
|
||||
|
||||
// readyView holds the selection order for flat or grouped round-robin traversal.
|
||||
type readyView struct {
|
||||
flat []*scheduledAuth
|
||||
cursor int
|
||||
parentOrder []string
|
||||
parentCursor int
|
||||
children map[string]*childBucket
|
||||
}
|
||||
|
||||
// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths.
|
||||
type childBucket struct {
|
||||
items []*scheduledAuth
|
||||
cursor int
|
||||
}
|
||||
|
||||
// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds.
|
||||
type cooldownQueue []*scheduledAuth
|
||||
|
||||
// newAuthScheduler constructs an empty scheduler configured for the supplied selector strategy.
|
||||
func newAuthScheduler(selector Selector) *authScheduler {
|
||||
return &authScheduler{
|
||||
strategy: selectorStrategy(selector),
|
||||
providers: make(map[string]*providerScheduler),
|
||||
authProviders: make(map[string]string),
|
||||
mixedCursors: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// selectorStrategy maps a selector implementation to the scheduler semantics it should emulate.
|
||||
func selectorStrategy(selector Selector) schedulerStrategy {
|
||||
switch selector.(type) {
|
||||
case *FillFirstSelector:
|
||||
return schedulerStrategyFillFirst
|
||||
case nil, *RoundRobinSelector:
|
||||
return schedulerStrategyRoundRobin
|
||||
default:
|
||||
return schedulerStrategyCustom
|
||||
}
|
||||
}
|
||||
|
||||
// setSelector updates the active built-in strategy and resets mixed-provider cursors.
|
||||
func (s *authScheduler) setSelector(selector Selector) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.strategy = selectorStrategy(selector)
|
||||
clear(s.mixedCursors)
|
||||
}
|
||||
|
||||
// rebuild recreates the complete scheduler state from an auth snapshot.
|
||||
func (s *authScheduler) rebuild(auths []*Auth) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.providers = make(map[string]*providerScheduler)
|
||||
s.authProviders = make(map[string]string)
|
||||
s.mixedCursors = make(map[string]int)
|
||||
now := time.Now()
|
||||
for _, auth := range auths {
|
||||
s.upsertAuthLocked(auth, now)
|
||||
}
|
||||
}
|
||||
|
||||
// upsertAuth incrementally synchronizes one auth into the scheduler.
|
||||
func (s *authScheduler) upsertAuth(auth *Auth) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.upsertAuthLocked(auth, time.Now())
|
||||
}
|
||||
|
||||
// removeAuth deletes one auth from every scheduler shard that references it.
|
||||
func (s *authScheduler) removeAuth(authID string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.removeAuthLocked(authID)
|
||||
}
|
||||
|
||||
// pickSingle returns the next auth for a single provider/model request using scheduler state.
|
||||
func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) {
|
||||
if s == nil {
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
providerKey := strings.ToLower(strings.TrimSpace(provider))
|
||||
modelKey := canonicalModelKey(model)
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == ""
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
||||
if shard == nil {
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
predicate := func(entry *scheduledAuth) bool {
|
||||
if entry == nil || entry.auth == nil {
|
||||
return false
|
||||
}
|
||||
if pinnedAuthID != "" && entry.auth.ID != pinnedAuthID {
|
||||
return false
|
||||
}
|
||||
if len(tried) > 0 {
|
||||
if _, ok := tried[entry.auth.ID]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
if picked := shard.pickReadyLocked(preferWebsocket, s.strategy, predicate); picked != nil {
|
||||
return picked, nil
|
||||
}
|
||||
return nil, shard.unavailableErrorLocked(provider, model, predicate)
|
||||
}
|
||||
|
||||
// pickMixed returns the next auth and provider for a mixed-provider request.
|
||||
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
|
||||
if s == nil {
|
||||
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
normalized := normalizeProviderKeys(providers)
|
||||
if len(normalized) == 0 {
|
||||
return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
modelKey := canonicalModelKey(model)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if pinnedAuthID != "" {
|
||||
providerKey := s.authProviders[pinnedAuthID]
|
||||
if providerKey == "" || !containsProvider(normalized, providerKey) {
|
||||
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
||||
predicate := func(entry *scheduledAuth) bool {
|
||||
if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID {
|
||||
return false
|
||||
}
|
||||
if len(tried) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := tried[pinnedAuthID]
|
||||
return !ok
|
||||
}
|
||||
if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil {
|
||||
return picked, providerKey, nil
|
||||
}
|
||||
return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
|
||||
}
|
||||
|
||||
if s.strategy == schedulerStrategyFillFirst {
|
||||
for _, providerKey := range normalized {
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
continue
|
||||
}
|
||||
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
picked := shard.pickReadyLocked(false, s.strategy, triedPredicate(tried))
|
||||
if picked != nil {
|
||||
return picked, providerKey, nil
|
||||
}
|
||||
}
|
||||
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||
}
|
||||
|
||||
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
|
||||
start := 0
|
||||
if len(normalized) > 0 {
|
||||
start = s.mixedCursors[cursorKey] % len(normalized)
|
||||
}
|
||||
for offset := 0; offset < len(normalized); offset++ {
|
||||
providerIndex := (start + offset) % len(normalized)
|
||||
providerKey := normalized[providerIndex]
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
continue
|
||||
}
|
||||
shard := providerState.ensureModelLocked(modelKey, time.Now())
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
picked := shard.pickReadyLocked(false, schedulerStrategyRoundRobin, triedPredicate(tried))
|
||||
if picked == nil {
|
||||
continue
|
||||
}
|
||||
s.mixedCursors[cursorKey] = providerIndex + 1
|
||||
return picked, providerKey, nil
|
||||
}
|
||||
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||
}
|
||||
|
||||
// mixedUnavailableErrorLocked synthesizes the mixed-provider cooldown or unavailable error.
|
||||
func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error {
|
||||
now := time.Now()
|
||||
total := 0
|
||||
cooldownCount := 0
|
||||
earliest := time.Time{}
|
||||
for _, providerKey := range providers {
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
continue
|
||||
}
|
||||
shard := providerState.ensureModelLocked(canonicalModelKey(model), now)
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
localTotal, localCooldownCount, localEarliest := shard.availabilitySummaryLocked(triedPredicate(tried))
|
||||
total += localTotal
|
||||
cooldownCount += localCooldownCount
|
||||
if !localEarliest.IsZero() && (earliest.IsZero() || localEarliest.Before(earliest)) {
|
||||
earliest = localEarliest
|
||||
}
|
||||
}
|
||||
if total == 0 {
|
||||
return &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
if cooldownCount == total && !earliest.IsZero() {
|
||||
resetIn := earliest.Sub(now)
|
||||
if resetIn < 0 {
|
||||
resetIn = 0
|
||||
}
|
||||
return newModelCooldownError(model, "", resetIn)
|
||||
}
|
||||
return &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// triedPredicate builds a filter that excludes auths already attempted for the current request.
|
||||
func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool {
|
||||
if len(tried) == 0 {
|
||||
return func(entry *scheduledAuth) bool { return entry != nil && entry.auth != nil }
|
||||
}
|
||||
return func(entry *scheduledAuth) bool {
|
||||
if entry == nil || entry.auth == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := tried[entry.auth.ID]
|
||||
return !ok
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeProviderKeys lowercases, trims, and de-duplicates provider keys while preserving order.
|
||||
func normalizeProviderKeys(providers []string) []string {
|
||||
seen := make(map[string]struct{}, len(providers))
|
||||
out := make([]string, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(provider))
|
||||
if providerKey == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[providerKey]; ok {
|
||||
continue
|
||||
}
|
||||
seen[providerKey] = struct{}{}
|
||||
out = append(out, providerKey)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// containsProvider reports whether provider is present in the normalized provider list.
|
||||
func containsProvider(providers []string, provider string) bool {
|
||||
for _, candidate := range providers {
|
||||
if candidate == provider {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// upsertAuthLocked updates one auth in-place while the scheduler mutex is held.
|
||||
func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) {
|
||||
if auth == nil {
|
||||
return
|
||||
}
|
||||
authID := strings.TrimSpace(auth.ID)
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if authID == "" || providerKey == "" || auth.Disabled {
|
||||
s.removeAuthLocked(authID)
|
||||
return
|
||||
}
|
||||
if previousProvider := s.authProviders[authID]; previousProvider != "" && previousProvider != providerKey {
|
||||
if previousState := s.providers[previousProvider]; previousState != nil {
|
||||
previousState.removeAuthLocked(authID)
|
||||
}
|
||||
}
|
||||
meta := buildScheduledAuthMeta(auth)
|
||||
s.authProviders[authID] = providerKey
|
||||
s.ensureProviderLocked(providerKey).upsertAuthLocked(meta, now)
|
||||
}
|
||||
|
||||
// removeAuthLocked removes one auth from the scheduler while the scheduler mutex is held.
|
||||
func (s *authScheduler) removeAuthLocked(authID string) {
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
if providerKey := s.authProviders[authID]; providerKey != "" {
|
||||
if providerState := s.providers[providerKey]; providerState != nil {
|
||||
providerState.removeAuthLocked(authID)
|
||||
}
|
||||
delete(s.authProviders, authID)
|
||||
}
|
||||
}
|
||||
|
||||
// ensureProviderLocked returns the provider scheduler for providerKey, creating it when needed.
|
||||
func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler {
|
||||
if s.providers == nil {
|
||||
s.providers = make(map[string]*providerScheduler)
|
||||
}
|
||||
providerState := s.providers[providerKey]
|
||||
if providerState == nil {
|
||||
providerState = &providerScheduler{
|
||||
providerKey: providerKey,
|
||||
auths: make(map[string]*scheduledAuthMeta),
|
||||
modelShards: make(map[string]*modelScheduler),
|
||||
}
|
||||
s.providers[providerKey] = providerState
|
||||
}
|
||||
return providerState
|
||||
}
|
||||
|
||||
// buildScheduledAuthMeta extracts the scheduling metadata needed for shard bookkeeping.
|
||||
func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
virtualParent := ""
|
||||
if auth.Attributes != nil {
|
||||
virtualParent = strings.TrimSpace(auth.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
return &scheduledAuthMeta{
|
||||
auth: auth,
|
||||
providerKey: providerKey,
|
||||
priority: authPriority(auth),
|
||||
virtualParent: virtualParent,
|
||||
websocketEnabled: authWebsocketsEnabled(auth),
|
||||
supportedModelSet: supportedModelSetForAuth(auth.ID),
|
||||
}
|
||||
}
|
||||
|
||||
// supportedModelSetForAuth snapshots the registry models currently registered for an auth.
|
||||
func supportedModelSetForAuth(authID string) map[string]struct{} {
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return nil
|
||||
}
|
||||
models := registry.GetGlobalRegistry().GetModelsForClient(authID)
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
set := make(map[string]struct{}, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
modelKey := canonicalModelKey(model.ID)
|
||||
if modelKey == "" {
|
||||
continue
|
||||
}
|
||||
set[modelKey] = struct{}{}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
// upsertAuthLocked updates every existing model shard that can reference the auth metadata.
|
||||
func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) {
|
||||
if p == nil || meta == nil || meta.auth == nil {
|
||||
return
|
||||
}
|
||||
p.auths[meta.auth.ID] = meta
|
||||
for modelKey, shard := range p.modelShards {
|
||||
if shard == nil {
|
||||
continue
|
||||
}
|
||||
if !meta.supportsModel(modelKey) {
|
||||
shard.removeEntryLocked(meta.auth.ID)
|
||||
continue
|
||||
}
|
||||
shard.upsertEntryLocked(meta, now)
|
||||
}
|
||||
}
|
||||
|
||||
// removeAuthLocked removes an auth from all model shards owned by the provider scheduler.
|
||||
func (p *providerScheduler) removeAuthLocked(authID string) {
|
||||
if p == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
delete(p.auths, authID)
|
||||
for _, shard := range p.modelShards {
|
||||
if shard != nil {
|
||||
shard.removeEntryLocked(authID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ensureModelLocked returns the shard for modelKey, building it lazily from provider auths.
|
||||
func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler {
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
modelKey = canonicalModelKey(modelKey)
|
||||
if shard, ok := p.modelShards[modelKey]; ok && shard != nil {
|
||||
shard.promoteExpiredLocked(now)
|
||||
return shard
|
||||
}
|
||||
shard := &modelScheduler{
|
||||
modelKey: modelKey,
|
||||
entries: make(map[string]*scheduledAuth),
|
||||
readyByPriority: make(map[int]*readyBucket),
|
||||
}
|
||||
for _, meta := range p.auths {
|
||||
if meta == nil || !meta.supportsModel(modelKey) {
|
||||
continue
|
||||
}
|
||||
shard.upsertEntryLocked(meta, now)
|
||||
}
|
||||
p.modelShards[modelKey] = shard
|
||||
return shard
|
||||
}
|
||||
|
||||
// supportsModel reports whether the auth metadata currently supports modelKey.
|
||||
func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
|
||||
modelKey = canonicalModelKey(modelKey)
|
||||
if modelKey == "" {
|
||||
return true
|
||||
}
|
||||
if len(m.supportedModelSet) == 0 {
|
||||
return false
|
||||
}
|
||||
_, ok := m.supportedModelSet[modelKey]
|
||||
return ok
|
||||
}
|
||||
|
||||
// upsertEntryLocked updates or inserts one auth entry and rebuilds indexes when ordering changes.
|
||||
func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) {
|
||||
if m == nil || meta == nil || meta.auth == nil {
|
||||
return
|
||||
}
|
||||
entry, ok := m.entries[meta.auth.ID]
|
||||
if !ok || entry == nil {
|
||||
entry = &scheduledAuth{}
|
||||
m.entries[meta.auth.ID] = entry
|
||||
}
|
||||
previousState := entry.state
|
||||
previousNextRetryAt := entry.nextRetryAt
|
||||
previousPriority := 0
|
||||
previousParent := ""
|
||||
previousWebsocketEnabled := false
|
||||
if entry.meta != nil {
|
||||
previousPriority = entry.meta.priority
|
||||
previousParent = entry.meta.virtualParent
|
||||
previousWebsocketEnabled = entry.meta.websocketEnabled
|
||||
}
|
||||
|
||||
entry.meta = meta
|
||||
entry.auth = meta.auth
|
||||
entry.nextRetryAt = time.Time{}
|
||||
blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now)
|
||||
switch {
|
||||
case !blocked:
|
||||
entry.state = scheduledStateReady
|
||||
case reason == blockReasonCooldown:
|
||||
entry.state = scheduledStateCooldown
|
||||
entry.nextRetryAt = next
|
||||
case reason == blockReasonDisabled:
|
||||
entry.state = scheduledStateDisabled
|
||||
default:
|
||||
entry.state = scheduledStateBlocked
|
||||
entry.nextRetryAt = next
|
||||
}
|
||||
|
||||
if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled {
|
||||
return
|
||||
}
|
||||
m.rebuildIndexesLocked()
|
||||
}
|
||||
|
||||
// removeEntryLocked deletes one auth entry and rebuilds the shard indexes if needed.
|
||||
func (m *modelScheduler) removeEntryLocked(authID string) {
|
||||
if m == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := m.entries[authID]; !ok {
|
||||
return
|
||||
}
|
||||
delete(m.entries, authID)
|
||||
m.rebuildIndexesLocked()
|
||||
}
|
||||
|
||||
// promoteExpiredLocked reevaluates blocked auths whose retry time has elapsed.
|
||||
func (m *modelScheduler) promoteExpiredLocked(now time.Time) {
|
||||
if m == nil || len(m.blocked) == 0 {
|
||||
return
|
||||
}
|
||||
changed := false
|
||||
for _, entry := range m.blocked {
|
||||
if entry == nil || entry.auth == nil {
|
||||
continue
|
||||
}
|
||||
if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) {
|
||||
continue
|
||||
}
|
||||
blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now)
|
||||
switch {
|
||||
case !blocked:
|
||||
entry.state = scheduledStateReady
|
||||
entry.nextRetryAt = time.Time{}
|
||||
case reason == blockReasonCooldown:
|
||||
entry.state = scheduledStateCooldown
|
||||
entry.nextRetryAt = next
|
||||
case reason == blockReasonDisabled:
|
||||
entry.state = scheduledStateDisabled
|
||||
entry.nextRetryAt = time.Time{}
|
||||
default:
|
||||
entry.state = scheduledStateBlocked
|
||||
entry.nextRetryAt = next
|
||||
}
|
||||
changed = true
|
||||
}
|
||||
if changed {
|
||||
m.rebuildIndexesLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// pickReadyLocked selects the next ready auth from the highest available priority bucket.
|
||||
func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.promoteExpiredLocked(time.Now())
|
||||
for _, priority := range m.priorityOrder {
|
||||
bucket := m.readyByPriority[priority]
|
||||
if bucket == nil {
|
||||
continue
|
||||
}
|
||||
view := &bucket.all
|
||||
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
||||
view = &bucket.ws
|
||||
}
|
||||
var picked *scheduledAuth
|
||||
if strategy == schedulerStrategyFillFirst {
|
||||
picked = view.pickFirst(predicate)
|
||||
} else {
|
||||
picked = view.pickRoundRobin(predicate)
|
||||
}
|
||||
if picked != nil && picked.auth != nil {
|
||||
return picked.auth
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
|
||||
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
|
||||
now := time.Now()
|
||||
total, cooldownCount, earliest := m.availabilitySummaryLocked(predicate)
|
||||
if total == 0 {
|
||||
return &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
if cooldownCount == total && !earliest.IsZero() {
|
||||
providerForError := provider
|
||||
if providerForError == "mixed" {
|
||||
providerForError = ""
|
||||
}
|
||||
resetIn := earliest.Sub(now)
|
||||
if resetIn < 0 {
|
||||
resetIn = 0
|
||||
}
|
||||
return newModelCooldownError(model, providerForError, resetIn)
|
||||
}
|
||||
return &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// availabilitySummaryLocked summarizes total candidates, cooldown count, and earliest retry time.
|
||||
func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) {
|
||||
if m == nil {
|
||||
return 0, 0, time.Time{}
|
||||
}
|
||||
total := 0
|
||||
cooldownCount := 0
|
||||
earliest := time.Time{}
|
||||
for _, entry := range m.entries {
|
||||
if predicate != nil && !predicate(entry) {
|
||||
continue
|
||||
}
|
||||
total++
|
||||
if entry == nil || entry.auth == nil {
|
||||
continue
|
||||
}
|
||||
if entry.state != scheduledStateCooldown {
|
||||
continue
|
||||
}
|
||||
cooldownCount++
|
||||
if !entry.nextRetryAt.IsZero() && (earliest.IsZero() || entry.nextRetryAt.Before(earliest)) {
|
||||
earliest = entry.nextRetryAt
|
||||
}
|
||||
}
|
||||
return total, cooldownCount, earliest
|
||||
}
|
||||
|
||||
// rebuildIndexesLocked reconstructs ready and blocked views from the current entry map.
|
||||
func (m *modelScheduler) rebuildIndexesLocked() {
|
||||
m.readyByPriority = make(map[int]*readyBucket)
|
||||
m.priorityOrder = m.priorityOrder[:0]
|
||||
m.blocked = m.blocked[:0]
|
||||
priorityBuckets := make(map[int][]*scheduledAuth)
|
||||
for _, entry := range m.entries {
|
||||
if entry == nil || entry.auth == nil {
|
||||
continue
|
||||
}
|
||||
switch entry.state {
|
||||
case scheduledStateReady:
|
||||
priority := entry.meta.priority
|
||||
priorityBuckets[priority] = append(priorityBuckets[priority], entry)
|
||||
case scheduledStateCooldown, scheduledStateBlocked:
|
||||
m.blocked = append(m.blocked, entry)
|
||||
}
|
||||
}
|
||||
for priority, entries := range priorityBuckets {
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
return entries[i].auth.ID < entries[j].auth.ID
|
||||
})
|
||||
m.readyByPriority[priority] = buildReadyBucket(entries)
|
||||
m.priorityOrder = append(m.priorityOrder, priority)
|
||||
}
|
||||
sort.Slice(m.priorityOrder, func(i, j int) bool {
|
||||
return m.priorityOrder[i] > m.priorityOrder[j]
|
||||
})
|
||||
sort.Slice(m.blocked, func(i, j int) bool {
|
||||
left := m.blocked[i]
|
||||
right := m.blocked[j]
|
||||
if left == nil || right == nil {
|
||||
return left != nil
|
||||
}
|
||||
if left.nextRetryAt.Equal(right.nextRetryAt) {
|
||||
return left.auth.ID < right.auth.ID
|
||||
}
|
||||
if left.nextRetryAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
if right.nextRetryAt.IsZero() {
|
||||
return true
|
||||
}
|
||||
return left.nextRetryAt.Before(right.nextRetryAt)
|
||||
})
|
||||
}
|
||||
|
||||
// buildReadyBucket prepares the general and websocket-only ready views for one priority bucket.
|
||||
func buildReadyBucket(entries []*scheduledAuth) *readyBucket {
|
||||
bucket := &readyBucket{}
|
||||
bucket.all = buildReadyView(entries)
|
||||
wsEntries := make([]*scheduledAuth, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry != nil && entry.meta != nil && entry.meta.websocketEnabled {
|
||||
wsEntries = append(wsEntries, entry)
|
||||
}
|
||||
}
|
||||
bucket.ws = buildReadyView(wsEntries)
|
||||
return bucket
|
||||
}
|
||||
|
||||
// buildReadyView creates either a flat view or a grouped parent/child view for rotation.
|
||||
func buildReadyView(entries []*scheduledAuth) readyView {
|
||||
view := readyView{flat: append([]*scheduledAuth(nil), entries...)}
|
||||
if len(entries) == 0 {
|
||||
return view
|
||||
}
|
||||
groups := make(map[string][]*scheduledAuth)
|
||||
for _, entry := range entries {
|
||||
if entry == nil || entry.meta == nil || entry.meta.virtualParent == "" {
|
||||
return view
|
||||
}
|
||||
groups[entry.meta.virtualParent] = append(groups[entry.meta.virtualParent], entry)
|
||||
}
|
||||
if len(groups) <= 1 {
|
||||
return view
|
||||
}
|
||||
view.children = make(map[string]*childBucket, len(groups))
|
||||
view.parentOrder = make([]string, 0, len(groups))
|
||||
for parent := range groups {
|
||||
view.parentOrder = append(view.parentOrder, parent)
|
||||
}
|
||||
sort.Strings(view.parentOrder)
|
||||
for _, parent := range view.parentOrder {
|
||||
view.children[parent] = &childBucket{items: append([]*scheduledAuth(nil), groups[parent]...)}
|
||||
}
|
||||
return view
|
||||
}
|
||||
|
||||
// pickFirst returns the first ready entry that satisfies predicate without advancing cursors.
|
||||
func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||
for _, entry := range v.flat {
|
||||
if predicate == nil || predicate(entry) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pickRoundRobin returns the next ready entry using flat or grouped round-robin traversal.
|
||||
func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||
if len(v.parentOrder) > 1 && len(v.children) > 0 {
|
||||
return v.pickGroupedRoundRobin(predicate)
|
||||
}
|
||||
if len(v.flat) == 0 {
|
||||
return nil
|
||||
}
|
||||
start := 0
|
||||
if len(v.flat) > 0 {
|
||||
start = v.cursor % len(v.flat)
|
||||
}
|
||||
for offset := 0; offset < len(v.flat); offset++ {
|
||||
index := (start + offset) % len(v.flat)
|
||||
entry := v.flat[index]
|
||||
if predicate != nil && !predicate(entry) {
|
||||
continue
|
||||
}
|
||||
v.cursor = index + 1
|
||||
return entry
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pickGroupedRoundRobin rotates across parents first and then within the selected parent.
|
||||
func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
|
||||
start := 0
|
||||
if len(v.parentOrder) > 0 {
|
||||
start = v.parentCursor % len(v.parentOrder)
|
||||
}
|
||||
for offset := 0; offset < len(v.parentOrder); offset++ {
|
||||
parentIndex := (start + offset) % len(v.parentOrder)
|
||||
parent := v.parentOrder[parentIndex]
|
||||
child := v.children[parent]
|
||||
if child == nil || len(child.items) == 0 {
|
||||
continue
|
||||
}
|
||||
itemStart := child.cursor % len(child.items)
|
||||
for itemOffset := 0; itemOffset < len(child.items); itemOffset++ {
|
||||
itemIndex := (itemStart + itemOffset) % len(child.items)
|
||||
entry := child.items[itemIndex]
|
||||
if predicate != nil && !predicate(entry) {
|
||||
continue
|
||||
}
|
||||
child.cursor = itemIndex + 1
|
||||
v.parentCursor = parentIndex + 1
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user