feat(auth): implement weighted provider rotation for improved scheduling fairness
This commit is contained in:
@@ -293,12 +293,46 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
|
|||||||
}
|
}
|
||||||
|
|
||||||
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
|
cursorKey := strings.Join(normalized, ",") + ":" + modelKey
|
||||||
start := 0
|
weights := make([]int, len(normalized))
|
||||||
if len(normalized) > 0 {
|
segmentStarts := make([]int, len(normalized))
|
||||||
start = s.mixedCursors[cursorKey] % len(normalized)
|
segmentEnds := make([]int, len(normalized))
|
||||||
|
totalWeight := 0
|
||||||
|
for providerIndex, shard := range candidateShards {
|
||||||
|
segmentStarts[providerIndex] = totalWeight
|
||||||
|
if shard != nil {
|
||||||
|
weights[providerIndex] = shard.readyCountAtPriorityLocked(false, bestPriority)
|
||||||
|
}
|
||||||
|
totalWeight += weights[providerIndex]
|
||||||
|
segmentEnds[providerIndex] = totalWeight
|
||||||
}
|
}
|
||||||
|
if totalWeight == 0 {
|
||||||
|
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||||
|
}
|
||||||
|
|
||||||
|
startSlot := s.mixedCursors[cursorKey] % totalWeight
|
||||||
|
startProviderIndex := -1
|
||||||
|
for providerIndex := range normalized {
|
||||||
|
if weights[providerIndex] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if startSlot < segmentEnds[providerIndex] {
|
||||||
|
startProviderIndex = providerIndex
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if startProviderIndex < 0 {
|
||||||
|
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||||
|
}
|
||||||
|
|
||||||
|
slot := startSlot
|
||||||
for offset := 0; offset < len(normalized); offset++ {
|
for offset := 0; offset < len(normalized); offset++ {
|
||||||
providerIndex := (start + offset) % len(normalized)
|
providerIndex := (startProviderIndex + offset) % len(normalized)
|
||||||
|
if weights[providerIndex] == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if providerIndex != startProviderIndex {
|
||||||
|
slot = segmentStarts[providerIndex]
|
||||||
|
}
|
||||||
providerKey := normalized[providerIndex]
|
providerKey := normalized[providerIndex]
|
||||||
shard := candidateShards[providerIndex]
|
shard := candidateShards[providerIndex]
|
||||||
if shard == nil {
|
if shard == nil {
|
||||||
@@ -308,7 +342,7 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
|
|||||||
if picked == nil {
|
if picked == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.mixedCursors[cursorKey] = providerIndex + 1
|
s.mixedCursors[cursorKey] = slot + 1
|
||||||
return picked, providerKey, nil
|
return picked, providerKey, nil
|
||||||
}
|
}
|
||||||
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
|
||||||
@@ -704,6 +738,20 @@ func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priorit
|
|||||||
return picked.auth
|
return picked.auth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// readyCountAtPriorityLocked reports the number of entries held in the ready bucket for the given priority.
// It returns 0 when the receiver is nil or when readyByPriority has no bucket for that priority.
// When preferWebsocket is true and bucket.ws.flat is non-empty, the length of bucket.ws.flat is returned;
// in every other case the length of bucket.all.flat is returned.
// NOTE(review): the "Locked" suffix suggests the caller must already hold the scheduler lock — confirm against callers.
func (m *modelScheduler) readyCountAtPriorityLocked(preferWebsocket bool, priority int) int {
|
||||||
|
if m == nil { // a nil shard contributes no ready entries
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
bucket := m.readyByPriority[priority]
|
||||||
|
if bucket == nil { // nothing registered at this priority
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if preferWebsocket && len(bucket.ws.flat) > 0 { // fall through to the full list when the ws list is empty
|
||||||
|
return len(bucket.ws.flat)
|
||||||
|
}
|
||||||
|
return len(bucket.all.flat)
|
||||||
|
}
|
||||||
|
|
||||||
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
|
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
|
||||||
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
|
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|||||||
@@ -208,7 +208,7 @@ func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
|
func TestSchedulerPick_MixedProvidersUsesWeightedProviderRotationOverReadyCandidates(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
scheduler := newSchedulerForTest(
|
scheduler := newSchedulerForTest(
|
||||||
@@ -218,8 +218,8 @@ func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *
|
|||||||
&Auth{ID: "claude-a", Provider: "claude"},
|
&Auth{ID: "claude-a", Provider: "claude"},
|
||||||
)
|
)
|
||||||
|
|
||||||
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
|
||||||
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
|
||||||
for index := range wantProviders {
|
for index := range wantProviders {
|
||||||
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||||
if errPick != nil {
|
if errPick != nil {
|
||||||
@@ -272,7 +272,7 @@ func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
|
func TestManager_PickNextMixed_UsesWeightedProviderRotationBeforeCredentialRotation(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||||
@@ -288,8 +288,8 @@ func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *t
|
|||||||
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||||
}
|
}
|
||||||
|
|
||||||
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
|
||||||
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
|
||||||
for index := range wantProviders {
|
for index := range wantProviders {
|
||||||
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
|
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
|
||||||
if errPick != nil {
|
if errPick != nil {
|
||||||
@@ -399,8 +399,8 @@ func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
|
|||||||
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
t.Fatalf("Register(claude-a) error = %v", errRegister)
|
||||||
}
|
}
|
||||||
|
|
||||||
wantProviders := []string{"gemini", "claude", "gemini", "claude"}
|
wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
|
||||||
wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
|
wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
|
||||||
for index := range wantProviders {
|
for index := range wantProviders {
|
||||||
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
|
||||||
if errPick != nil {
|
if errPick != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user