diff --git a/sdk/cliproxy/auth/scheduler.go b/sdk/cliproxy/auth/scheduler.go index bfff53bf..fd8c9490 100644 --- a/sdk/cliproxy/auth/scheduler.go +++ b/sdk/cliproxy/auth/scheduler.go @@ -293,12 +293,46 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model } cursorKey := strings.Join(normalized, ",") + ":" + modelKey - start := 0 - if len(normalized) > 0 { - start = s.mixedCursors[cursorKey] % len(normalized) + weights := make([]int, len(normalized)) + segmentStarts := make([]int, len(normalized)) + segmentEnds := make([]int, len(normalized)) + totalWeight := 0 + for providerIndex, shard := range candidateShards { + segmentStarts[providerIndex] = totalWeight + if shard != nil { + weights[providerIndex] = shard.readyCountAtPriorityLocked(false, bestPriority) + } + totalWeight += weights[providerIndex] + segmentEnds[providerIndex] = totalWeight } + if totalWeight == 0 { + return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried) + } + + startSlot := s.mixedCursors[cursorKey] % totalWeight + startProviderIndex := -1 + for providerIndex := range normalized { + if weights[providerIndex] == 0 { + continue + } + if startSlot < segmentEnds[providerIndex] { + startProviderIndex = providerIndex + break + } + } + if startProviderIndex < 0 { + return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried) + } + + slot := startSlot for offset := 0; offset < len(normalized); offset++ { - providerIndex := (start + offset) % len(normalized) + providerIndex := (startProviderIndex + offset) % len(normalized) + if weights[providerIndex] == 0 { + continue + } + if providerIndex != startProviderIndex { + slot = segmentStarts[providerIndex] + } providerKey := normalized[providerIndex] shard := candidateShards[providerIndex] if shard == nil { @@ -308,7 +342,7 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model if picked == nil { continue } - s.mixedCursors[cursorKey] = providerIndex + 1 + s.mixedCursors[cursorKey] = slot + 1 return picked, providerKey, nil } return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried) @@ -704,6 +738,20 @@ func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priorit return picked.auth } +func (m *modelScheduler) readyCountAtPriorityLocked(preferWebsocket bool, priority int) int { + if m == nil { + return 0 + } + bucket := m.readyByPriority[priority] + if bucket == nil { + return 0 + } + if preferWebsocket && len(bucket.ws.flat) > 0 { + return len(bucket.ws.flat) + } + return len(bucket.all.flat) +} + // unavailableErrorLocked returns the correct unavailable or cooldown error for the shard. func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error { now := time.Now() diff --git a/sdk/cliproxy/auth/scheduler_test.go b/sdk/cliproxy/auth/scheduler_test.go index e7d435a9..3988c90a 100644 --- a/sdk/cliproxy/auth/scheduler_test.go +++ b/sdk/cliproxy/auth/scheduler_test.go @@ -208,7 +208,7 @@ func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) } } -func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) { +func TestSchedulerPick_MixedProvidersUsesWeightedProviderRotationOverReadyCandidates(t *testing.T) { t.Parallel() scheduler := newSchedulerForTest( @@ -218,8 +218,8 @@ func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t * &Auth{ID: "claude-a", Provider: "claude"}, ) - wantProviders := []string{"gemini", "claude", "gemini", "claude"} - wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"} + wantProviders := []string{"gemini", "gemini", "claude", "gemini"} + wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"} for index := range wantProviders { got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil) if errPick != nil { @@ -272,7 +272,7 @@ func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) { } } -func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) { +func TestManager_PickNextMixed_UsesWeightedProviderRotationBeforeCredentialRotation(t *testing.T) { t.Parallel() manager := NewManager(nil, &RoundRobinSelector{}, nil) @@ -288,8 +288,8 @@ func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *t t.Fatalf("Register(claude-a) error = %v", errRegister) } - wantProviders := []string{"gemini", "claude", "gemini", "claude"} - wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"} + wantProviders := []string{"gemini", "gemini", "claude", "gemini"} + wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"} for index := range wantProviders { got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{}) if errPick != nil { @@ -399,8 +399,8 @@ func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) { t.Fatalf("Register(claude-a) error = %v", errRegister) } - wantProviders := []string{"gemini", "claude", "gemini", "claude"} - wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"} + wantProviders := []string{"gemini", "gemini", "claude", "gemini"} + wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"} for index := range wantProviders { got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil) if errPick != nil {