feat(auth): implement auto-refresh loop for managing auth token schedule
- Introduced `authAutoRefreshLoop` to handle token refresh scheduling. - Replaced semaphore-based refresh logic in `Manager` with the new loop. - Added unit tests to verify refresh schedule logic and edge cases.
This commit is contained in:
@@ -0,0 +1,444 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/heap"
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type authAutoRefreshLoop struct {
|
||||||
|
manager *Manager
|
||||||
|
interval time.Duration
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
queue refreshMinHeap
|
||||||
|
index map[string]*refreshHeapItem
|
||||||
|
dirty map[string]struct{}
|
||||||
|
|
||||||
|
wakeCh chan struct{}
|
||||||
|
jobs chan string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthAutoRefreshLoop(manager *Manager, interval time.Duration) *authAutoRefreshLoop {
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = refreshCheckInterval
|
||||||
|
}
|
||||||
|
jobBuffer := refreshMaxConcurrency * 4
|
||||||
|
if jobBuffer < 64 {
|
||||||
|
jobBuffer = 64
|
||||||
|
}
|
||||||
|
return &authAutoRefreshLoop{
|
||||||
|
manager: manager,
|
||||||
|
interval: interval,
|
||||||
|
index: make(map[string]*refreshHeapItem),
|
||||||
|
dirty: make(map[string]struct{}),
|
||||||
|
wakeCh: make(chan struct{}, 1),
|
||||||
|
jobs: make(chan string, jobBuffer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) queueReschedule(authID string) {
|
||||||
|
if l == nil || authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.mu.Lock()
|
||||||
|
l.dirty[authID] = struct{}{}
|
||||||
|
l.mu.Unlock()
|
||||||
|
select {
|
||||||
|
case l.wakeCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) run(ctx context.Context) {
|
||||||
|
if l == nil || l.manager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < refreshMaxConcurrency; i++ {
|
||||||
|
go l.worker(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.loop(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) worker(ctx context.Context) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case authID := <-l.jobs:
|
||||||
|
if authID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
l.manager.refreshAuth(ctx, authID)
|
||||||
|
l.queueReschedule(authID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) rebuild(now time.Time) {
|
||||||
|
type entry struct {
|
||||||
|
id string
|
||||||
|
next time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]entry, 0)
|
||||||
|
|
||||||
|
l.manager.mu.RLock()
|
||||||
|
for id, auth := range l.manager.auths {
|
||||||
|
next, ok := nextRefreshCheckAt(now, auth, l.interval)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
entries = append(entries, entry{id: id, next: next})
|
||||||
|
}
|
||||||
|
l.manager.mu.RUnlock()
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
l.queue = l.queue[:0]
|
||||||
|
l.index = make(map[string]*refreshHeapItem, len(entries))
|
||||||
|
for _, e := range entries {
|
||||||
|
item := &refreshHeapItem{id: e.id, next: e.next}
|
||||||
|
heap.Push(&l.queue, item)
|
||||||
|
l.index[e.id] = item
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) loop(ctx context.Context) {
|
||||||
|
timer := time.NewTimer(time.Hour)
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
var timerCh <-chan time.Time
|
||||||
|
l.resetTimer(timer, &timerCh, time.Now())
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-l.wakeCh:
|
||||||
|
now := time.Now()
|
||||||
|
l.applyDirty(now)
|
||||||
|
l.resetTimer(timer, &timerCh, now)
|
||||||
|
case <-timerCh:
|
||||||
|
now := time.Now()
|
||||||
|
l.handleDue(ctx, now)
|
||||||
|
l.applyDirty(now)
|
||||||
|
l.resetTimer(timer, &timerCh, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) resetTimer(timer *time.Timer, timerCh *<-chan time.Time, now time.Time) {
|
||||||
|
next, ok := l.peek()
|
||||||
|
if !ok {
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
*timerCh = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wait := next.Sub(now)
|
||||||
|
if wait < 0 {
|
||||||
|
wait = 0
|
||||||
|
}
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timer.Reset(wait)
|
||||||
|
*timerCh = timer.C
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) peek() (time.Time, bool) {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
if len(l.queue) == 0 {
|
||||||
|
return time.Time{}, false
|
||||||
|
}
|
||||||
|
return l.queue[0].next, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) handleDue(ctx context.Context, now time.Time) {
|
||||||
|
due := l.popDue(now)
|
||||||
|
if len(due) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if log.IsLevelEnabled(log.DebugLevel) {
|
||||||
|
log.Debugf("auto-refresh scheduler due auths: %d", len(due))
|
||||||
|
}
|
||||||
|
for _, authID := range due {
|
||||||
|
l.handleDueAuth(ctx, now, authID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) popDue(now time.Time) []string {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
|
var due []string
|
||||||
|
for len(l.queue) > 0 {
|
||||||
|
item := l.queue[0]
|
||||||
|
if item == nil || item.next.After(now) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
popped := heap.Pop(&l.queue).(*refreshHeapItem)
|
||||||
|
if popped == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(l.index, popped.id)
|
||||||
|
due = append(due, popped.id)
|
||||||
|
}
|
||||||
|
return due
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) handleDueAuth(ctx context.Context, now time.Time, authID string) {
|
||||||
|
if authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := l.manager
|
||||||
|
|
||||||
|
manager.mu.RLock()
|
||||||
|
auth := manager.auths[authID]
|
||||||
|
if auth == nil {
|
||||||
|
manager.mu.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next, shouldSchedule := nextRefreshCheckAt(now, auth, l.interval)
|
||||||
|
shouldRefresh := manager.shouldRefresh(auth, now)
|
||||||
|
exec := manager.executors[auth.Provider]
|
||||||
|
manager.mu.RUnlock()
|
||||||
|
|
||||||
|
if !shouldSchedule {
|
||||||
|
l.remove(authID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !shouldRefresh {
|
||||||
|
l.upsert(authID, next)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if exec == nil {
|
||||||
|
l.upsert(authID, now.Add(l.interval))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !manager.markRefreshPending(authID, now) {
|
||||||
|
manager.mu.RLock()
|
||||||
|
auth = manager.auths[authID]
|
||||||
|
next, shouldSchedule = nextRefreshCheckAt(now, auth, l.interval)
|
||||||
|
manager.mu.RUnlock()
|
||||||
|
if shouldSchedule {
|
||||||
|
l.upsert(authID, next)
|
||||||
|
} else {
|
||||||
|
l.remove(authID)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case l.jobs <- authID:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) applyDirty(now time.Time) {
|
||||||
|
dirty := l.drainDirty()
|
||||||
|
if len(dirty) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, authID := range dirty {
|
||||||
|
l.manager.mu.RLock()
|
||||||
|
auth := l.manager.auths[authID]
|
||||||
|
next, ok := nextRefreshCheckAt(now, auth, l.interval)
|
||||||
|
l.manager.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
l.remove(authID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
l.upsert(authID, next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) drainDirty() []string {
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
if len(l.dirty) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(l.dirty))
|
||||||
|
for authID := range l.dirty {
|
||||||
|
out = append(out, authID)
|
||||||
|
delete(l.dirty, authID)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) upsert(authID string, next time.Time) {
|
||||||
|
if authID == "" || next.IsZero() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
if item, ok := l.index[authID]; ok && item != nil {
|
||||||
|
item.next = next
|
||||||
|
heap.Fix(&l.queue, item.index)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
item := &refreshHeapItem{id: authID, next: next}
|
||||||
|
heap.Push(&l.queue, item)
|
||||||
|
l.index[authID] = item
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *authAutoRefreshLoop) remove(authID string) {
|
||||||
|
if authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
item, ok := l.index[authID]
|
||||||
|
if !ok || item == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
heap.Remove(&l.queue, item.index)
|
||||||
|
delete(l.index, authID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func nextRefreshCheckAt(now time.Time, auth *Auth, interval time.Duration) (time.Time, bool) {
|
||||||
|
if auth == nil || auth.Disabled {
|
||||||
|
return time.Time{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
accountType, _ := auth.AccountInfo()
|
||||||
|
if accountType == "api_key" {
|
||||||
|
return time.Time{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
|
||||||
|
return auth.NextRefreshAfter, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if evaluator, ok := auth.Runtime.(RefreshEvaluator); ok && evaluator != nil {
|
||||||
|
if interval <= 0 {
|
||||||
|
interval = refreshCheckInterval
|
||||||
|
}
|
||||||
|
return now.Add(interval), true
|
||||||
|
}
|
||||||
|
|
||||||
|
lastRefresh := auth.LastRefreshedAt
|
||||||
|
if lastRefresh.IsZero() {
|
||||||
|
if ts, ok := authLastRefreshTimestamp(auth); ok {
|
||||||
|
lastRefresh = ts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expiry, hasExpiry := auth.ExpirationTime()
|
||||||
|
|
||||||
|
if pref := authPreferredInterval(auth); pref > 0 {
|
||||||
|
candidates := make([]time.Time, 0, 2)
|
||||||
|
if hasExpiry && !expiry.IsZero() {
|
||||||
|
if !expiry.After(now) || expiry.Sub(now) <= pref {
|
||||||
|
return now, true
|
||||||
|
}
|
||||||
|
candidates = append(candidates, expiry.Add(-pref))
|
||||||
|
}
|
||||||
|
if lastRefresh.IsZero() {
|
||||||
|
return now, true
|
||||||
|
}
|
||||||
|
candidates = append(candidates, lastRefresh.Add(pref))
|
||||||
|
next := candidates[0]
|
||||||
|
for _, candidate := range candidates[1:] {
|
||||||
|
if candidate.Before(next) {
|
||||||
|
next = candidate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !next.After(now) {
|
||||||
|
return now, true
|
||||||
|
}
|
||||||
|
return next, true
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := strings.ToLower(auth.Provider)
|
||||||
|
lead := ProviderRefreshLead(provider, auth.Runtime)
|
||||||
|
if lead == nil {
|
||||||
|
return time.Time{}, false
|
||||||
|
}
|
||||||
|
if hasExpiry && !expiry.IsZero() {
|
||||||
|
dueAt := expiry.Add(-*lead)
|
||||||
|
if !dueAt.After(now) {
|
||||||
|
return now, true
|
||||||
|
}
|
||||||
|
return dueAt, true
|
||||||
|
}
|
||||||
|
if !lastRefresh.IsZero() {
|
||||||
|
dueAt := lastRefresh.Add(*lead)
|
||||||
|
if !dueAt.After(now) {
|
||||||
|
return now, true
|
||||||
|
}
|
||||||
|
return dueAt, true
|
||||||
|
}
|
||||||
|
return now, true
|
||||||
|
}
|
||||||
|
|
||||||
|
type refreshHeapItem struct {
|
||||||
|
id string
|
||||||
|
next time.Time
|
||||||
|
index int
|
||||||
|
}
|
||||||
|
|
||||||
|
type refreshMinHeap []*refreshHeapItem
|
||||||
|
|
||||||
|
func (h refreshMinHeap) Len() int { return len(h) }
|
||||||
|
|
||||||
|
func (h refreshMinHeap) Less(i, j int) bool {
|
||||||
|
return h[i].next.Before(h[j].next)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h refreshMinHeap) Swap(i, j int) {
|
||||||
|
h[i], h[j] = h[j], h[i]
|
||||||
|
h[i].index = i
|
||||||
|
h[j].index = j
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *refreshMinHeap) Push(x any) {
|
||||||
|
item, ok := x.(*refreshHeapItem)
|
||||||
|
if !ok || item == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
item.index = len(*h)
|
||||||
|
*h = append(*h, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *refreshMinHeap) Pop() any {
|
||||||
|
old := *h
|
||||||
|
n := len(old)
|
||||||
|
if n == 0 {
|
||||||
|
return (*refreshHeapItem)(nil)
|
||||||
|
}
|
||||||
|
item := old[n-1]
|
||||||
|
item.index = -1
|
||||||
|
*h = old[:n-1]
|
||||||
|
return item
|
||||||
|
}
|
||||||
@@ -0,0 +1,137 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testRefreshEvaluator struct{}
|
||||||
|
|
||||||
|
func (testRefreshEvaluator) ShouldRefresh(time.Time, *Auth) bool { return false }
|
||||||
|
|
||||||
|
func setRefreshLeadFactory(t *testing.T, provider string, factory func() *time.Duration) {
|
||||||
|
t.Helper()
|
||||||
|
key := strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
refreshLeadMu.Lock()
|
||||||
|
prev, hadPrev := refreshLeadFactories[key]
|
||||||
|
if factory == nil {
|
||||||
|
delete(refreshLeadFactories, key)
|
||||||
|
} else {
|
||||||
|
refreshLeadFactories[key] = factory
|
||||||
|
}
|
||||||
|
refreshLeadMu.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
refreshLeadMu.Lock()
|
||||||
|
if hadPrev {
|
||||||
|
refreshLeadFactories[key] = prev
|
||||||
|
} else {
|
||||||
|
delete(refreshLeadFactories, key)
|
||||||
|
}
|
||||||
|
refreshLeadMu.Unlock()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextRefreshCheckAt_DisabledUnschedule(t *testing.T) {
|
||||||
|
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
|
||||||
|
auth := &Auth{ID: "a1", Provider: "test", Disabled: true}
|
||||||
|
if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok {
|
||||||
|
t.Fatalf("nextRefreshCheckAt() ok = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextRefreshCheckAt_APIKeyUnschedule(t *testing.T) {
|
||||||
|
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
|
||||||
|
auth := &Auth{ID: "a1", Provider: "test", Attributes: map[string]string{"api_key": "k"}}
|
||||||
|
if _, ok := nextRefreshCheckAt(now, auth, 15*time.Minute); ok {
|
||||||
|
t.Fatalf("nextRefreshCheckAt() ok = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextRefreshCheckAt_NextRefreshAfterGate(t *testing.T) {
|
||||||
|
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
|
||||||
|
nextAfter := now.Add(30 * time.Minute)
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "a1",
|
||||||
|
Provider: "test",
|
||||||
|
NextRefreshAfter: nextAfter,
|
||||||
|
Metadata: map[string]any{"email": "x@example.com"},
|
||||||
|
}
|
||||||
|
got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
|
||||||
|
}
|
||||||
|
if !got.Equal(nextAfter) {
|
||||||
|
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, nextAfter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextRefreshCheckAt_PreferredInterval_PicksEarliestCandidate(t *testing.T) {
|
||||||
|
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
|
||||||
|
expiry := now.Add(20 * time.Minute)
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "a1",
|
||||||
|
Provider: "test",
|
||||||
|
LastRefreshedAt: now,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"email": "x@example.com",
|
||||||
|
"expires_at": expiry.Format(time.RFC3339),
|
||||||
|
"refresh_interval_seconds": 900, // 15m
|
||||||
|
},
|
||||||
|
}
|
||||||
|
got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
|
||||||
|
}
|
||||||
|
want := expiry.Add(-15 * time.Minute)
|
||||||
|
if !got.Equal(want) {
|
||||||
|
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextRefreshCheckAt_ProviderLead_Expiry(t *testing.T) {
|
||||||
|
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
|
||||||
|
expiry := now.Add(time.Hour)
|
||||||
|
lead := 10 * time.Minute
|
||||||
|
setRefreshLeadFactory(t, "provider-lead-expiry", func() *time.Duration {
|
||||||
|
d := lead
|
||||||
|
return &d
|
||||||
|
})
|
||||||
|
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "a1",
|
||||||
|
Provider: "provider-lead-expiry",
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"email": "x@example.com",
|
||||||
|
"expires_at": expiry.Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
got, ok := nextRefreshCheckAt(now, auth, 15*time.Minute)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
|
||||||
|
}
|
||||||
|
want := expiry.Add(-lead)
|
||||||
|
if !got.Equal(want) {
|
||||||
|
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextRefreshCheckAt_RefreshEvaluatorFallback(t *testing.T) {
|
||||||
|
now := time.Date(2026, 4, 12, 0, 0, 0, 0, time.UTC)
|
||||||
|
interval := 15 * time.Minute
|
||||||
|
auth := &Auth{
|
||||||
|
ID: "a1",
|
||||||
|
Provider: "test",
|
||||||
|
Metadata: map[string]any{"email": "x@example.com"},
|
||||||
|
Runtime: testRefreshEvaluator{},
|
||||||
|
}
|
||||||
|
got, ok := nextRefreshCheckAt(now, auth, interval)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("nextRefreshCheckAt() ok = false, want true")
|
||||||
|
}
|
||||||
|
want := now.Add(interval)
|
||||||
|
if !got.Equal(want) {
|
||||||
|
t.Fatalf("nextRefreshCheckAt() = %s, want %s", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -162,8 +162,8 @@ type Manager struct {
|
|||||||
rtProvider RoundTripperProvider
|
rtProvider RoundTripperProvider
|
||||||
|
|
||||||
// Auto refresh state
|
// Auto refresh state
|
||||||
refreshCancel context.CancelFunc
|
refreshCancel context.CancelFunc
|
||||||
refreshSemaphore chan struct{}
|
refreshLoop *authAutoRefreshLoop
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager constructs a manager with optional custom selector and hook.
|
// NewManager constructs a manager with optional custom selector and hook.
|
||||||
@@ -182,7 +182,6 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
|||||||
auths: make(map[string]*Auth),
|
auths: make(map[string]*Auth),
|
||||||
providerOffsets: make(map[string]int),
|
providerOffsets: make(map[string]int),
|
||||||
modelPoolOffsets: make(map[string]int),
|
modelPoolOffsets: make(map[string]int),
|
||||||
refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
|
|
||||||
}
|
}
|
||||||
// atomic.Value requires non-nil initial value.
|
// atomic.Value requires non-nil initial value.
|
||||||
manager.runtimeConfig.Store(&internalconfig.Config{})
|
manager.runtimeConfig.Store(&internalconfig.Config{})
|
||||||
@@ -214,6 +213,16 @@ func (m *Manager) syncScheduler() {
|
|||||||
m.syncSchedulerFromSnapshot(m.snapshotAuths())
|
m.syncSchedulerFromSnapshot(m.snapshotAuths())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) snapshotAuths() []*Auth {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
out := make([]*Auth, 0, len(m.auths))
|
||||||
|
for _, a := range m.auths {
|
||||||
|
out = append(out, a.Clone())
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
|
// RefreshSchedulerEntry re-upserts a single auth into the scheduler so that its
|
||||||
// supportedModelSet is rebuilt from the current global model registry state.
|
// supportedModelSet is rebuilt from the current global model registry state.
|
||||||
// This must be called after models have been registered for a newly added auth,
|
// This must be called after models have been registered for a newly added auth,
|
||||||
@@ -1088,6 +1097,7 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
|||||||
if m.scheduler != nil {
|
if m.scheduler != nil {
|
||||||
m.scheduler.upsertAuth(authClone)
|
m.scheduler.upsertAuth(authClone)
|
||||||
}
|
}
|
||||||
|
m.queueRefreshReschedule(auth.ID)
|
||||||
_ = m.persist(ctx, auth)
|
_ = m.persist(ctx, auth)
|
||||||
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
||||||
return auth.Clone(), nil
|
return auth.Clone(), nil
|
||||||
@@ -1118,6 +1128,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
|||||||
if m.scheduler != nil {
|
if m.scheduler != nil {
|
||||||
m.scheduler.upsertAuth(authClone)
|
m.scheduler.upsertAuth(authClone)
|
||||||
}
|
}
|
||||||
|
m.queueRefreshReschedule(auth.ID)
|
||||||
_ = m.persist(ctx, auth)
|
_ = m.persist(ctx, auth)
|
||||||
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
||||||
return auth.Clone(), nil
|
return auth.Clone(), nil
|
||||||
@@ -2890,80 +2901,51 @@ func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duratio
|
|||||||
if interval <= 0 {
|
if interval <= 0 {
|
||||||
interval = refreshCheckInterval
|
interval = refreshCheckInterval
|
||||||
}
|
}
|
||||||
if m.refreshCancel != nil {
|
|
||||||
m.refreshCancel()
|
m.mu.Lock()
|
||||||
m.refreshCancel = nil
|
cancel := m.refreshCancel
|
||||||
|
m.refreshCancel = nil
|
||||||
|
m.refreshLoop = nil
|
||||||
|
m.mu.Unlock()
|
||||||
|
if cancel != nil {
|
||||||
|
cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(parent)
|
ctx, cancel := context.WithCancel(parent)
|
||||||
|
loop := newAuthAutoRefreshLoop(m, interval)
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
m.refreshCancel = cancel
|
m.refreshCancel = cancel
|
||||||
go func() {
|
m.refreshLoop = loop
|
||||||
ticker := time.NewTicker(interval)
|
m.mu.Unlock()
|
||||||
defer ticker.Stop()
|
|
||||||
m.checkRefreshes(ctx)
|
loop.rebuild(time.Now())
|
||||||
for {
|
go loop.run(ctx)
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
m.checkRefreshes(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// StopAutoRefresh cancels the background refresh loop, if running.
|
// StopAutoRefresh cancels the background refresh loop, if running.
|
||||||
func (m *Manager) StopAutoRefresh() {
|
func (m *Manager) StopAutoRefresh() {
|
||||||
if m.refreshCancel != nil {
|
m.mu.Lock()
|
||||||
m.refreshCancel()
|
cancel := m.refreshCancel
|
||||||
m.refreshCancel = nil
|
m.refreshCancel = nil
|
||||||
|
m.refreshLoop = nil
|
||||||
|
m.mu.Unlock()
|
||||||
|
if cancel != nil {
|
||||||
|
cancel()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) checkRefreshes(ctx context.Context) {
|
func (m *Manager) queueRefreshReschedule(authID string) {
|
||||||
// log.Debugf("checking refreshes")
|
if m == nil || authID == "" {
|
||||||
now := time.Now()
|
|
||||||
snapshot := m.snapshotAuths()
|
|
||||||
for _, a := range snapshot {
|
|
||||||
typ, _ := a.AccountInfo()
|
|
||||||
if typ != "api_key" {
|
|
||||||
if !m.shouldRefresh(a, now) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
log.Debugf("checking refresh for %s, %s, %s", a.Provider, a.ID, typ)
|
|
||||||
|
|
||||||
if exec := m.executorFor(a.Provider); exec == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !m.markRefreshPending(a.ID, now) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go m.refreshAuthWithLimit(ctx, a.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) {
|
|
||||||
if m.refreshSemaphore == nil {
|
|
||||||
m.refreshAuth(ctx, id)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
select {
|
|
||||||
case m.refreshSemaphore <- struct{}{}:
|
|
||||||
defer func() { <-m.refreshSemaphore }()
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.refreshAuth(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) snapshotAuths() []*Auth {
|
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
loop := m.refreshLoop
|
||||||
out := make([]*Auth, 0, len(m.auths))
|
m.mu.RUnlock()
|
||||||
for _, a := range m.auths {
|
if loop == nil {
|
||||||
out = append(out, a.Clone())
|
return
|
||||||
}
|
}
|
||||||
return out
|
loop.queueReschedule(authID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
|
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
|
||||||
@@ -3173,16 +3155,20 @@ func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) {
|
|||||||
|
|
||||||
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
|
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
|
||||||
auth, ok := m.auths[id]
|
auth, ok := m.auths[id]
|
||||||
if !ok || auth == nil || auth.Disabled {
|
if !ok || auth == nil || auth.Disabled {
|
||||||
|
m.mu.Unlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
|
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
|
||||||
|
m.mu.Unlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
auth.NextRefreshAfter = now.Add(refreshPendingBackoff)
|
auth.NextRefreshAfter = now.Add(refreshPendingBackoff)
|
||||||
m.auths[id] = auth
|
m.auths[id] = auth
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
m.queueRefreshReschedule(id)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -3209,16 +3195,21 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
|||||||
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
|
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
shouldReschedule := false
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
if current := m.auths[id]; current != nil {
|
if current := m.auths[id]; current != nil {
|
||||||
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
||||||
current.LastError = &Error{Message: err.Error()}
|
current.LastError = &Error{Message: err.Error()}
|
||||||
m.auths[id] = current
|
m.auths[id] = current
|
||||||
|
shouldReschedule = true
|
||||||
if m.scheduler != nil {
|
if m.scheduler != nil {
|
||||||
m.scheduler.upsertAuth(current.Clone())
|
m.scheduler.upsertAuth(current.Clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
if shouldReschedule {
|
||||||
|
m.queueRefreshReschedule(id)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if updated == nil {
|
if updated == nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user