feat(auth): implement weighted provider rotation for improved scheduling fairness

This commit is contained in:
Luis Pater
2026-03-29 13:49:01 +08:00
parent 1587ff5e74
commit 6d8de0ade4
2 changed files with 61 additions and 13 deletions

View File

@@ -293,12 +293,46 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
}
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
start := 0
if len(normalized) > 0 {
start = s.mixedCursors[cursorKey] % len(normalized)
weights := make([]int, len(normalized))
segmentStarts := make([]int, len(normalized))
segmentEnds := make([]int, len(normalized))
totalWeight := 0
for providerIndex, shard := range candidateShards {
segmentStarts[providerIndex] = totalWeight
if shard != nil {
weights[providerIndex] = shard.readyCountAtPriorityLocked(false, bestPriority)
}
totalWeight += weights[providerIndex]
segmentEnds[providerIndex] = totalWeight
}
if totalWeight == 0 {
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
startSlot := s.mixedCursors[cursorKey] % totalWeight
startProviderIndex := -1
for providerIndex := range normalized {
if weights[providerIndex] == 0 {
continue
}
if startSlot < segmentEnds[providerIndex] {
startProviderIndex = providerIndex
break
}
}
if startProviderIndex < 0 {
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
slot := startSlot
for offset := 0; offset < len(normalized); offset++ {
providerIndex := (start + offset) % len(normalized)
providerIndex := (startProviderIndex + offset) % len(normalized)
if weights[providerIndex] == 0 {
continue
}
if providerIndex != startProviderIndex {
slot = segmentStarts[providerIndex]
}
providerKey := normalized[providerIndex]
shard := candidateShards[providerIndex]
if shard == nil {
@@ -308,7 +342,7 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
if picked == nil {
continue
}
s.mixedCursors[cursorKey] = providerIndex + 1
s.mixedCursors[cursorKey] = slot + 1
return picked, providerKey, nil
}
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
@@ -704,6 +738,20 @@ func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priorit
return picked.auth
}
func (m *modelScheduler) readyCountAtPriorityLocked(preferWebsocket bool, priority int) int {
if m == nil {
return 0
}
bucket := m.readyByPriority[priority]
if bucket == nil {
return 0
}
if preferWebsocket && len(bucket.ws.flat) > 0 {
return len(bucket.ws.flat)
}
return len(bucket.all.flat)
}
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
now := time.Now()

View File

@@ -208,7 +208,7 @@ func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T)
}
}
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
func TestSchedulerPick_MixedProvidersUsesWeightedProviderRotationOverReadyCandidates(t *testing.T) {
t.Parallel()
scheduler := newSchedulerForTest(
@@ -218,8 +218,8 @@ func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *
&Auth{ID: "claude-a", Provider: "claude"},
)
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
for index := range wantProviders {
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {
@@ -272,7 +272,7 @@ func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
}
}
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
func TestManager_PickNextMixed_UsesWeightedProviderRotationBeforeCredentialRotation(t *testing.T) {
t.Parallel()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
@@ -288,8 +288,8 @@ func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *t
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
if errPick != nil {
@@ -399,8 +399,8 @@ func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
t.Fatalf("Register(claude-a) error = %v", errRegister)
}
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
for index := range wantProviders {
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
if errPick != nil {