feat(auth): add websocket session reuse for home auths with caching support

- Introduced `homeRuntimeAuths` to cache home auths for websocket session reuse.
- Updated `pickNextViaHome` to prioritize cached auths for pinned websocket sessions.
- Implemented automatic clearing of cached home auths when home mode is disabled.
- Added unit tests to validate caching behavior, clearing logic, and fallback scenarios.
This commit is contained in:
Luis Pater
2026-05-10 13:39:14 +08:00
parent a44e5eb1ab
commit dc1cc7f115
2 changed files with 216 additions and 4 deletions
+103 -4
View File
@@ -151,6 +151,9 @@ type Manager struct {
mu sync.RWMutex
auths map[string]*Auth
scheduler *authScheduler
// homeRuntimeAuths caches auths returned by Home so websocket sessions can
// reuse an established upstream credential without dispatching every turn.
homeRuntimeAuths map[string]*Auth
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
providerOffsets map[string]int
@@ -195,6 +198,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
selector: selector,
hook: hook,
auths: make(map[string]*Auth),
homeRuntimeAuths: make(map[string]*Auth),
providerOffsets: make(map[string]int),
modelPoolOffsets: make(map[string]int),
}
@@ -376,6 +380,9 @@ func (m *Manager) SetConfig(cfg *internalconfig.Config) {
cfg = &internalconfig.Config{}
}
m.runtimeConfig.Store(cfg)
if !cfg.Home.Enabled {
m.clearHomeRuntimeAuths()
}
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
}
@@ -2713,7 +2720,10 @@ func (m *Manager) GetByID(id string) (*Auth, bool) {
defer m.mu.RUnlock()
auth, ok := m.auths[id]
if !ok {
return nil, false
auth, ok = m.homeRuntimeAuths[id]
if !ok {
return nil, false
}
}
return auth.Clone(), true
}
@@ -2751,12 +2761,15 @@ func (m *Manager) CloseExecutionSession(sessionID string) {
return
}
m.mu.RLock()
m.mu.Lock()
if sessionID == CloseAllExecutionSessionsID {
m.clearHomeRuntimeAuthsLocked()
}
executors := make([]ProviderExecutor, 0, len(m.executors))
for _, exec := range m.executors {
executors = append(executors, exec)
}
m.mu.RUnlock()
m.mu.Unlock()
for i := range executors {
if closer, ok := executors[i].(ExecutionSessionCloser); ok && closer != nil {
@@ -3168,6 +3181,80 @@ func setHomeUserAPIKeyOnGinContext(ctx context.Context, apiKey string) {
ginCtx.Set("userApiKey", apiKey)
}
func homeExecutionSessionIDFromMetadata(meta map[string]any) string {
if len(meta) == 0 {
return ""
}
raw, ok := meta[cliproxyexecutor.ExecutionSessionMetadataKey]
if !ok || raw == nil {
return ""
}
switch value := raw.(type) {
case string:
return strings.TrimSpace(value)
case []byte:
return strings.TrimSpace(string(value))
default:
return ""
}
}
func (m *Manager) clearHomeRuntimeAuths() {
if m == nil {
return
}
m.mu.Lock()
m.clearHomeRuntimeAuthsLocked()
m.mu.Unlock()
}
func (m *Manager) clearHomeRuntimeAuthsLocked() {
if m == nil {
return
}
m.homeRuntimeAuths = make(map[string]*Auth)
}
func (m *Manager) rememberHomeRuntimeAuth(auth *Auth) {
if m == nil || auth == nil || strings.TrimSpace(auth.ID) == "" || !authWebsocketsEnabled(auth) {
return
}
m.mu.Lock()
if m.homeRuntimeAuths == nil {
m.homeRuntimeAuths = make(map[string]*Auth)
}
m.homeRuntimeAuths[auth.ID] = auth.Clone()
m.mu.Unlock()
}
func (m *Manager) homeRuntimeAuthByID(authID string) (*Auth, ProviderExecutor, string, bool) {
authID = strings.TrimSpace(authID)
if m == nil || authID == "" {
return nil, nil, "", false
}
m.mu.RLock()
auth := m.homeRuntimeAuths[authID]
m.mu.RUnlock()
if auth == nil || !authWebsocketsEnabled(auth) {
return nil, nil, "", false
}
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
if providerKey == "" {
return nil, nil, "", false
}
executor, ok := m.Executor(providerKey)
if !ok && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["base_url"]) != "" {
executor, ok = m.Executor("openai-compatibility")
if ok {
providerKey = "openai-compatibility"
}
}
if !ok {
return nil, nil, "", false
}
return auth.Clone(), executor, providerKey, true
}
func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts cliproxyexecutor.Options) (*Auth, ProviderExecutor, string, error) {
if m == nil {
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
@@ -3175,6 +3262,14 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro
if ctx == nil {
ctx = context.Background()
}
if cliproxyexecutor.DownstreamWebsocket(ctx) && homeExecutionSessionIDFromMetadata(opts.Metadata) != "" {
if pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata); pinnedAuthID != "" {
if auth, executor, providerKey, ok := m.homeRuntimeAuthByID(pinnedAuthID); ok {
return auth, executor, providerKey, nil
}
}
}
client := home.Current()
if client == nil || !client.HeartbeatOK() {
return nil, nil, "", &Error{Code: "home_unavailable", Message: "home control center unavailable", HTTPStatus: http.StatusServiceUnavailable}
@@ -3254,7 +3349,11 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered", HTTPStatus: http.StatusBadGateway}
}
return auth.Clone(), executor, providerKey, nil
authCopy := auth.Clone()
if cliproxyexecutor.DownstreamWebsocket(ctx) && homeExecutionSessionIDFromMetadata(opts.Metadata) != "" && authWebsocketsEnabled(authCopy) {
m.rememberHomeRuntimeAuth(authCopy)
}
return authCopy, executor, providerKey, nil
}
func requestedModelFromMetadata(metadata map[string]any, fallback string) string {
@@ -0,0 +1,113 @@
package auth
import (
"context"
"errors"
"net/http"
"testing"
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
)
func TestPickNextViaHomeReusesPinnedWebsocketAuthWithoutHomeDispatch(t *testing.T) {
manager := NewManager(nil, nil, nil)
manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}})
manager.RegisterExecutor(schedulerTestExecutor{})
auth := &Auth{
ID: "home-auth-1",
Provider: "test",
Status: StatusActive,
Attributes: map[string]string{
"websockets": "true",
homeUpstreamModelAttributeKey: "upstream-model",
},
Metadata: map[string]any{"email": "home@example.com"},
}
auth.EnsureIndex()
manager.rememberHomeRuntimeAuth(auth)
cachedAuth, ok := manager.GetByID("home-auth-1")
if !ok || cachedAuth == nil || !authWebsocketsEnabled(cachedAuth) {
t.Fatalf("GetByID() did not expose remembered websocket home auth: auth=%#v ok=%v", cachedAuth, ok)
}
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
opts := cliproxyexecutor.Options{
Metadata: map[string]any{
cliproxyexecutor.ExecutionSessionMetadataKey: "session-1",
cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1",
},
Headers: http.Header{"Authorization": {"Bearer client-key"}},
}
got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts)
if errPick != nil {
t.Fatalf("pickNextViaHome() error = %v", errPick)
}
if got == nil || got.ID != "home-auth-1" {
t.Fatalf("pickNextViaHome() auth = %#v, want home-auth-1", got)
}
if executor == nil {
t.Fatal("pickNextViaHome() executor is nil")
}
if provider != "test" {
t.Fatalf("pickNextViaHome() provider = %q, want test", provider)
}
}
func TestPickNextViaHomeDoesNotReusePinnedNonWebsocketAuth(t *testing.T) {
manager := NewManager(nil, nil, nil)
manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}})
manager.RegisterExecutor(schedulerTestExecutor{})
manager.mu.Lock()
manager.homeRuntimeAuths["home-auth-1"] = &Auth{
ID: "home-auth-1",
Provider: "test",
Status: StatusActive,
}
manager.mu.Unlock()
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
opts := cliproxyexecutor.Options{
Metadata: map[string]any{
cliproxyexecutor.ExecutionSessionMetadataKey: "session-1",
cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1",
},
Headers: http.Header{"Authorization": {"Bearer client-key"}},
}
got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts)
if errPick == nil {
t.Fatal("pickNextViaHome() error is nil, want home unavailable error")
}
var authErr *Error
if !errors.As(errPick, &authErr) || authErr.Code != "home_unavailable" {
t.Fatalf("pickNextViaHome() error = %v, want home_unavailable", errPick)
}
if got != nil || executor != nil || provider != "" {
t.Fatalf("pickNextViaHome() reused non-websocket auth: auth=%#v executor=%#v provider=%q", got, executor, provider)
}
}
func TestHomeRuntimeAuthsClearWhenHomeDisabled(t *testing.T) {
manager := NewManager(nil, nil, nil)
manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}})
manager.rememberHomeRuntimeAuth(&Auth{
ID: "home-auth-1",
Provider: "test",
Attributes: map[string]string{
"websockets": "true",
},
})
if _, ok := manager.GetByID("home-auth-1"); !ok {
t.Fatal("expected remembered home auth before disabling home")
}
manager.SetConfig(&internalconfig.Config{})
if _, ok := manager.GetByID("home-auth-1"); ok {
t.Fatal("remembered home auth was not cleared when home was disabled")
}
}