feat(auth): implement weighted provider rotation for improved scheduling fairness
This commit is contained in:
@@ -293,12 +293,46 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
|
||||
}
|
||||
|
||||
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
|
||||
start := 0
|
||||
if len(normalized) > 0 {
|
||||
start = s.mixedCursors[cursorKey] % len(normalized)
|
||||
weights := make([]int, len(normalized))
|
||||
segmentStarts := make([]int, len(normalized))
|
||||
segmentEnds := make([]int, len(normalized))
|
||||
totalWeight := 0
|
||||
for providerIndex, shard := range candidateShards {
|
||||
segmentStarts[providerIndex] = totalWeight
|
||||
if shard != nil {
|
||||
weights[providerIndex] = shard.readyCountAtPriorityLocked(false, bestPriority)
|
||||
}
|
||||
totalWeight += weights[providerIndex]
|
||||
segmentEnds[providerIndex] = totalWeight
|
||||
}
|
||||
if totalWeight == 0 {
|
||||
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||
}
|
||||
|
||||
startSlot := s.mixedCursors[cursorKey] % totalWeight
|
||||
startProviderIndex := -1
|
||||
for providerIndex := range normalized {
|
||||
if weights[providerIndex] == 0 {
|
||||
continue
|
||||
}
|
||||
if startSlot < segmentEnds[providerIndex] {
|
||||
startProviderIndex = providerIndex
|
||||
break
|
||||
}
|
||||
}
|
||||
if startProviderIndex < 0 {
|
||||
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||
}
|
||||
|
||||
slot := startSlot
|
||||
for offset := 0; offset < len(normalized); offset++ {
|
||||
providerIndex := (start + offset) % len(normalized)
|
||||
providerIndex := (startProviderIndex + offset) % len(normalized)
|
||||
if weights[providerIndex] == 0 {
|
||||
continue
|
||||
}
|
||||
if providerIndex != startProviderIndex {
|
||||
slot = segmentStarts[providerIndex]
|
||||
}
|
||||
providerKey := normalized[providerIndex]
|
||||
shard := candidateShards[providerIndex]
|
||||
if shard == nil {
|
||||
@@ -308,7 +342,7 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
|
||||
if picked == nil {
|
||||
continue
|
||||
}
|
||||
s.mixedCursors[cursorKey] = providerIndex + 1
|
||||
s.mixedCursors[cursorKey] = slot + 1
|
||||
return picked, providerKey, nil
|
||||
}
|
||||
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||
@@ -704,6 +738,20 @@ func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priorit
|
||||
return picked.auth
|
||||
}
|
||||
|
||||
func (m *modelScheduler) readyCountAtPriorityLocked(preferWebsocket bool, priority int) int {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
bucket := m.readyByPriority[priority]
|
||||
if bucket == nil {
|
||||
return 0
|
||||
}
|
||||
if preferWebsocket && len(bucket.ws.flat) > 0 {
|
||||
return len(bucket.ws.flat)
|
||||
}
|
||||
return len(bucket.all.flat)
|
||||
}
|
||||
|
||||
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
|
||||
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
|
||||
now := time.Now()
|
||||
|
||||
@@ -208,7 +208,7 @@ func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
|
||||
func TestSchedulerPick_MixedProvidersUsesWeightedProviderRotationOverReadyCandidates(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
scheduler := newSchedulerForTest(
|
||||
@@ -218,8 +218,8 @@ func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *
|
||||
&Auth{ID: "claude-a", Provider: "claude"},
|
||||
)
|
||||
|
||||
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
||||
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
||||
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
|
||||
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
|
||||
for index := range wantProviders {
|
||||
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
@@ -272,7 +272,7 @@ func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
|
||||
func TestManager_PickNextMixed_UsesWeightedProviderRotationBeforeCredentialRotation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
@@ -288,8 +288,8 @@ func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *t
|
||||
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||
}
|
||||
|
||||
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
||||
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
||||
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
|
||||
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
|
||||
for index := range wantProviders {
|
||||
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
|
||||
if errPick != nil {
|
||||
@@ -399,8 +399,8 @@ func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
|
||||
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||
}
|
||||
|
||||
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
||||
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
||||
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
|
||||
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
|
||||
for index := range wantProviders {
|
||||
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||
if errPick != nil {
|
||||
|
||||
Reference in New Issue
Block a user