feat(auth): implement auto-refresh loop for managing auth token schedule
- Introduced `authAutoRefreshLoop` to handle token refresh scheduling. - Replaced semaphore-based refresh logic in `Manager` with the new loop. - Added unit tests to verify refresh schedule logic and edge cases.
This commit is contained in:
@@ -0,0 +1,444 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type authAutoRefreshLoop struct {
|
||||
manager *Manager
|
||||
interval time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
queue refreshMinHeap
|
||||
index map[string]*refreshHeapItem
|
||||
dirty map[string]struct{}
|
||||
|
||||
wakeCh chan struct{}
|
||||
jobs chan string
|
||||
}
|
||||
|
||||
func newAuthAutoRefreshLoop(manager *Manager, interval time.Duration) *authAutoRefreshLoop {
|
||||
if interval <= 0 {
|
||||
interval = refreshCheckInterval
|
||||
}
|
||||
jobBuffer := refreshMaxConcurrency * 4
|
||||
if jobBuffer < 64 {
|
||||
jobBuffer = 64
|
||||
}
|
||||
return &authAutoRefreshLoop{
|
||||
manager: manager,
|
||||
interval: interval,
|
||||
index: make(map[string]*refreshHeapItem),
|
||||
dirty: make(map[string]struct{}),
|
||||
wakeCh: make(chan struct{}, 1),
|
||||
jobs: make(chan string, jobBuffer),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) queueReschedule(authID string) {
|
||||
if l == nil || authID == "" {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
l.dirty[authID] = struct{}{}
|
||||
l.mu.Unlock()
|
||||
select {
|
||||
case l.wakeCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) run(ctx context.Context) {
|
||||
if l == nil || l.manager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < refreshMaxConcurrency; i++ {
|
||||
go l.worker(ctx)
|
||||
}
|
||||
|
||||
l.loop(ctx)
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) worker(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case authID := <-l.jobs:
|
||||
if authID == "" {
|
||||
continue
|
||||
}
|
||||
l.manager.refreshAuth(ctx, authID)
|
||||
l.queueReschedule(authID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) rebuild(now time.Time) {
|
||||
type entry struct {
|
||||
id string
|
||||
next time.Time
|
||||
}
|
||||
|
||||
entries := make([]entry, 0)
|
||||
|
||||
l.manager.mu.RLock()
|
||||
for id, auth := range l.manager.auths {
|
||||
next, ok := nextRefreshCheckAt(now, auth, l.interval)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, entry{id: id, next: next})
|
||||
}
|
||||
l.manager.mu.RUnlock()
|
||||
|
||||
l.mu.Lock()
|
||||
l.queue = l.queue[:0]
|
||||
l.index = make(map[string]*refreshHeapItem, len(entries))
|
||||
for _, e := range entries {
|
||||
item := &refreshHeapItem{id: e.id, next: e.next}
|
||||
heap.Push(&l.queue, item)
|
||||
l.index[e.id] = item
|
||||
}
|
||||
l.mu.Unlock()
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) loop(ctx context.Context) {
|
||||
timer := time.NewTimer(time.Hour)
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
defer timer.Stop()
|
||||
|
||||
var timerCh <-chan time.Time
|
||||
l.resetTimer(timer, &timerCh, time.Now())
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-l.wakeCh:
|
||||
now := time.Now()
|
||||
l.applyDirty(now)
|
||||
l.resetTimer(timer, &timerCh, now)
|
||||
case <-timerCh:
|
||||
now := time.Now()
|
||||
l.handleDue(ctx, now)
|
||||
l.applyDirty(now)
|
||||
l.resetTimer(timer, &timerCh, now)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) resetTimer(timer *time.Timer, timerCh *<-chan time.Time, now time.Time) {
|
||||
next, ok := l.peek()
|
||||
if !ok {
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
*timerCh = nil
|
||||
return
|
||||
}
|
||||
|
||||
wait := next.Sub(now)
|
||||
if wait < 0 {
|
||||
wait = 0
|
||||
}
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(wait)
|
||||
*timerCh = timer.C
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) peek() (time.Time, bool) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if len(l.queue) == 0 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return l.queue[0].next, true
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) handleDue(ctx context.Context, now time.Time) {
|
||||
due := l.popDue(now)
|
||||
if len(due) == 0 {
|
||||
return
|
||||
}
|
||||
if log.IsLevelEnabled(log.DebugLevel) {
|
||||
log.Debugf("auto-refresh scheduler due auths: %d", len(due))
|
||||
}
|
||||
for _, authID := range due {
|
||||
l.handleDueAuth(ctx, now, authID)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) popDue(now time.Time) []string {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
var due []string
|
||||
for len(l.queue) > 0 {
|
||||
item := l.queue[0]
|
||||
if item == nil || item.next.After(now) {
|
||||
break
|
||||
}
|
||||
popped := heap.Pop(&l.queue).(*refreshHeapItem)
|
||||
if popped == nil {
|
||||
continue
|
||||
}
|
||||
delete(l.index, popped.id)
|
||||
due = append(due, popped.id)
|
||||
}
|
||||
return due
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) handleDueAuth(ctx context.Context, now time.Time, authID string) {
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
manager := l.manager
|
||||
|
||||
manager.mu.RLock()
|
||||
auth := manager.auths[authID]
|
||||
if auth == nil {
|
||||
manager.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
next, shouldSchedule := nextRefreshCheckAt(now, auth, l.interval)
|
||||
shouldRefresh := manager.shouldRefresh(auth, now)
|
||||
exec := manager.executors[auth.Provider]
|
||||
manager.mu.RUnlock()
|
||||
|
||||
if !shouldSchedule {
|
||||
l.remove(authID)
|
||||
return
|
||||
}
|
||||
|
||||
if !shouldRefresh {
|
||||
l.upsert(authID, next)
|
||||
return
|
||||
}
|
||||
|
||||
if exec == nil {
|
||||
l.upsert(authID, now.Add(l.interval))
|
||||
return
|
||||
}
|
||||
|
||||
if !manager.markRefreshPending(authID, now) {
|
||||
manager.mu.RLock()
|
||||
auth = manager.auths[authID]
|
||||
next, shouldSchedule = nextRefreshCheckAt(now, auth, l.interval)
|
||||
manager.mu.RUnlock()
|
||||
if shouldSchedule {
|
||||
l.upsert(authID, next)
|
||||
} else {
|
||||
l.remove(authID)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case l.jobs <- authID:
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) applyDirty(now time.Time) {
|
||||
dirty := l.drainDirty()
|
||||
if len(dirty) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, authID := range dirty {
|
||||
l.manager.mu.RLock()
|
||||
auth := l.manager.auths[authID]
|
||||
next, ok := nextRefreshCheckAt(now, auth, l.interval)
|
||||
l.manager.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
l.remove(authID)
|
||||
continue
|
||||
}
|
||||
l.upsert(authID, next)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) drainDirty() []string {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if len(l.dirty) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(l.dirty))
|
||||
for authID := range l.dirty {
|
||||
out = append(out, authID)
|
||||
delete(l.dirty, authID)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) upsert(authID string, next time.Time) {
|
||||
if authID == "" || next.IsZero() {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if item, ok := l.index[authID]; ok && item != nil {
|
||||
item.next = next
|
||||
heap.Fix(&l.queue, item.index)
|
||||
return
|
||||
}
|
||||
item := &refreshHeapItem{id: authID, next: next}
|
||||
heap.Push(&l.queue, item)
|
||||
l.index[authID] = item
|
||||
}
|
||||
|
||||
func (l *authAutoRefreshLoop) remove(authID string) {
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
item, ok := l.index[authID]
|
||||
if !ok || item == nil {
|
||||
return
|
||||
}
|
||||
heap.Remove(&l.queue, item.index)
|
||||
delete(l.index, authID)
|
||||
}
|
||||
|
||||
func nextRefreshCheckAt(now time.Time, auth *Auth, interval time.Duration) (time.Time, bool) {
|
||||
if auth == nil || auth.Disabled {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
accountType, _ := auth.AccountInfo()
|
||||
if accountType == "api_key" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
|
||||
return auth.NextRefreshAfter, true
|
||||
}
|
||||
|
||||
if evaluator, ok := auth.Runtime.(RefreshEvaluator); ok && evaluator != nil {
|
||||
if interval <= 0 {
|
||||
interval = refreshCheckInterval
|
||||
}
|
||||
return now.Add(interval), true
|
||||
}
|
||||
|
||||
lastRefresh := auth.LastRefreshedAt
|
||||
if lastRefresh.IsZero() {
|
||||
if ts, ok := authLastRefreshTimestamp(auth); ok {
|
||||
lastRefresh = ts
|
||||
}
|
||||
}
|
||||
|
||||
expiry, hasExpiry := auth.ExpirationTime()
|
||||
|
||||
if pref := authPreferredInterval(auth); pref > 0 {
|
||||
candidates := make([]time.Time, 0, 2)
|
||||
if hasExpiry && !expiry.IsZero() {
|
||||
if !expiry.After(now) || expiry.Sub(now) <= pref {
|
||||
return now, true
|
||||
}
|
||||
candidates = append(candidates, expiry.Add(-pref))
|
||||
}
|
||||
if lastRefresh.IsZero() {
|
||||
return now, true
|
||||
}
|
||||
candidates = append(candidates, lastRefresh.Add(pref))
|
||||
next := candidates[0]
|
||||
for _, candidate := range candidates[1:] {
|
||||
if candidate.Before(next) {
|
||||
next = candidate
|
||||
}
|
||||
}
|
||||
if !next.After(now) {
|
||||
return now, true
|
||||
}
|
||||
return next, true
|
||||
}
|
||||
|
||||
provider := strings.ToLower(auth.Provider)
|
||||
lead := ProviderRefreshLead(provider, auth.Runtime)
|
||||
if lead == nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
if hasExpiry && !expiry.IsZero() {
|
||||
dueAt := expiry.Add(-*lead)
|
||||
if !dueAt.After(now) {
|
||||
return now, true
|
||||
}
|
||||
return dueAt, true
|
||||
}
|
||||
if !lastRefresh.IsZero() {
|
||||
dueAt := lastRefresh.Add(*lead)
|
||||
if !dueAt.After(now) {
|
||||
return now, true
|
||||
}
|
||||
return dueAt, true
|
||||
}
|
||||
return now, true
|
||||
}
|
||||
|
||||
type refreshHeapItem struct {
|
||||
id string
|
||||
next time.Time
|
||||
index int
|
||||
}
|
||||
|
||||
type refreshMinHeap []*refreshHeapItem
|
||||
|
||||
func (h refreshMinHeap) Len() int { return len(h) }
|
||||
|
||||
func (h refreshMinHeap) Less(i, j int) bool {
|
||||
return h[i].next.Before(h[j].next)
|
||||
}
|
||||
|
||||
func (h refreshMinHeap) Swap(i, j int) {
|
||||
h[i], h[j] = h[j], h[i]
|
||||
h[i].index = i
|
||||
h[j].index = j
|
||||
}
|
||||
|
||||
func (h *refreshMinHeap) Push(x any) {
|
||||
item, ok := x.(*refreshHeapItem)
|
||||
if !ok || item == nil {
|
||||
return
|
||||
}
|
||||
item.index = len(*h)
|
||||
*h = append(*h, item)
|
||||
}
|
||||
|
||||
func (h *refreshMinHeap) Pop() any {
|
||||
old := *h
|
||||
n := len(old)
|
||||
if n == 0 {
|
||||
return (*refreshHeapItem)(nil)
|
||||
}
|
||||
item := old[n-1]
|
||||
item.index = -1
|
||||
*h = old[:n-1]
|
||||
return item
|
||||
}
|
||||
Reference in New Issue
Block a user