feat(auth): improve unauthorized error handling for refresh and auto-refresh
- Added `isUnauthorizedError` and `hasUnauthorizedAuthFailure` to classify and handle unauthorized errors. - Introduced `refreshErrorFromError` to map errors to standardized unauthorized responses. - Modified refresh logic to stop auto-refresh retries for unauthorized errors. - Updated tests to verify unauthorized error handling and refresh retry prevention.
This commit is contained in:
@@ -78,7 +78,7 @@ func RefreshAuthViaHome(ctx context.Context, cfg *config.Config, auth *cliproxya
|
||||
if msg == "" {
|
||||
msg = "home returned error"
|
||||
}
|
||||
return nil, true, homeStatusErr{code: http.StatusBadGateway, msg: msg}
|
||||
return nil, true, homeStatusErr{code: statusFromHomeErrorCode(code), msg: msg}
|
||||
}
|
||||
|
||||
var updated cliproxyauth.Auth
|
||||
@@ -89,3 +89,14 @@ func RefreshAuthViaHome(ctx context.Context, cfg *config.Config, auth *cliproxya
|
||||
updated.EnsureIndex()
|
||||
return &updated, true, nil
|
||||
}
|
||||
|
||||
func statusFromHomeErrorCode(code string) int {
|
||||
switch strings.ToLower(strings.TrimSpace(code)) {
|
||||
case "authentication_error", "unauthorized":
|
||||
return http.StatusUnauthorized
|
||||
case "model_not_found":
|
||||
return http.StatusNotFound
|
||||
default:
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStatusFromHomeErrorCodeMapsAuthenticationErrorToUnauthorized(t *testing.T) {
|
||||
if got := statusFromHomeErrorCode("authentication_error"); got != http.StatusUnauthorized {
|
||||
t.Fatalf("statusFromHomeErrorCode(authentication_error) = %d, want %d", got, http.StatusUnauthorized)
|
||||
}
|
||||
if got := statusFromHomeErrorCode("unauthorized"); got != http.StatusUnauthorized {
|
||||
t.Fatalf("statusFromHomeErrorCode(unauthorized) = %d, want %d", got, http.StatusUnauthorized)
|
||||
}
|
||||
}
|
||||
@@ -339,6 +339,9 @@ func nextRefreshCheckAt(now time.Time, auth *Auth, interval time.Duration) (time
|
||||
if auth == nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
if hasUnauthorizedAuthFailure(auth) {
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
accountType, _ := auth.AccountInfo()
|
||||
if accountType == "api_key" {
|
||||
|
||||
@@ -2486,6 +2486,40 @@ func statusCodeFromError(err error) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func isUnauthorizedError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if statusCodeFromError(err) == http.StatusUnauthorized {
|
||||
return true
|
||||
}
|
||||
raw := strings.ToLower(err.Error())
|
||||
return strings.Contains(raw, "status 401") || strings.Contains(raw, "401 unauthorized")
|
||||
}
|
||||
|
||||
func hasUnauthorizedAuthFailure(auth *Auth) bool {
|
||||
if auth == nil || auth.LastError == nil {
|
||||
return false
|
||||
}
|
||||
return auth.LastError.StatusCode() == http.StatusUnauthorized || strings.EqualFold(auth.LastError.Code, "unauthorized")
|
||||
}
|
||||
|
||||
func refreshErrorFromError(err error) *Error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
statusCode := statusCodeFromError(err)
|
||||
if statusCode == 0 && isUnauthorizedError(err) {
|
||||
statusCode = http.StatusUnauthorized
|
||||
}
|
||||
authErr := &Error{Message: err.Error(), HTTPStatus: statusCode}
|
||||
if statusCode == http.StatusUnauthorized {
|
||||
authErr.Code = "unauthorized"
|
||||
authErr.Retryable = false
|
||||
}
|
||||
return authErr
|
||||
}
|
||||
|
||||
func retryAfterFromError(err error) *time.Duration {
|
||||
if err == nil {
|
||||
return nil
|
||||
@@ -3680,6 +3714,9 @@ func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
if hasUnauthorizedAuthFailure(a) {
|
||||
return false
|
||||
}
|
||||
if !a.NextRefreshAfter.IsZero() && now.Before(a.NextRefreshAfter) {
|
||||
return false
|
||||
}
|
||||
@@ -3924,11 +3961,19 @@ func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
|
||||
now := time.Now()
|
||||
if err != nil {
|
||||
unauthorized := isUnauthorizedError(err)
|
||||
shouldReschedule := false
|
||||
m.mu.Lock()
|
||||
if current := m.auths[id]; current != nil {
|
||||
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
||||
current.LastError = &Error{Message: err.Error()}
|
||||
current.LastError = refreshErrorFromError(err)
|
||||
if unauthorized {
|
||||
current.NextRefreshAfter = time.Time{}
|
||||
current.Unavailable = true
|
||||
current.Status = StatusError
|
||||
current.StatusMessage = "unauthorized"
|
||||
} else {
|
||||
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
||||
}
|
||||
m.auths[id] = current
|
||||
shouldReschedule = true
|
||||
if m.scheduler != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
@@ -36,6 +37,59 @@ func (e schedulerProviderTestExecutor) HttpRequest(ctx context.Context, auth *Au
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type unauthorizedRefreshTestExecutor struct {
|
||||
schedulerProviderTestExecutor
|
||||
}
|
||||
|
||||
func (e unauthorizedRefreshTestExecutor) Refresh(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
return nil, errors.New("token refresh failed with status 401: invalid_grant")
|
||||
}
|
||||
|
||||
func TestManager_RefreshAuthUnauthorizedFailureStopsAutoRefreshRetry(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
manager := NewManager(nil, &RoundRobinSelector{}, nil)
|
||||
manager.RegisterExecutor(unauthorizedRefreshTestExecutor{
|
||||
schedulerProviderTestExecutor: schedulerProviderTestExecutor{provider: "codex"},
|
||||
})
|
||||
|
||||
auth := &Auth{
|
||||
ID: "unauthorized-refresh",
|
||||
Provider: "codex",
|
||||
Metadata: map[string]any{
|
||||
"email": "x@example.com",
|
||||
},
|
||||
}
|
||||
if _, errRegister := manager.Register(ctx, auth); errRegister != nil {
|
||||
t.Fatalf("register auth: %v", errRegister)
|
||||
}
|
||||
|
||||
manager.refreshAuth(ctx, auth.ID)
|
||||
|
||||
updated, ok := manager.GetByID(auth.ID)
|
||||
if !ok {
|
||||
t.Fatalf("expected auth %q after refresh", auth.ID)
|
||||
}
|
||||
if updated.LastError == nil {
|
||||
t.Fatal("expected unauthorized refresh failure to be recorded")
|
||||
}
|
||||
if got := updated.LastError.StatusCode(); got != http.StatusUnauthorized {
|
||||
t.Fatalf("LastError.StatusCode() = %d, want %d", got, http.StatusUnauthorized)
|
||||
}
|
||||
if updated.LastError.Code != "unauthorized" {
|
||||
t.Fatalf("LastError.Code = %q, want unauthorized", updated.LastError.Code)
|
||||
}
|
||||
if !updated.NextRefreshAfter.IsZero() {
|
||||
t.Fatalf("NextRefreshAfter = %s, want zero for unauthorized refresh failure", updated.NextRefreshAfter)
|
||||
}
|
||||
now := time.Now()
|
||||
if manager.shouldRefresh(updated, now) {
|
||||
t.Fatal("expected unauthorized auth to stop refresh attempts")
|
||||
}
|
||||
if _, shouldSchedule := nextRefreshCheckAt(now, updated, time.Second); shouldSchedule {
|
||||
t.Fatal("expected unauthorized auth to be removed from the auto-refresh schedule")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user