feat(auth): add websocket session reuse for home auths with caching support
- Introduced `homeRuntimeAuths` to cache home auths for websocket session reuse. - Updated `pickNextViaHome` to prioritize cached auths for pinned websocket sessions. - Implemented automatic clearing of cached home auths when home mode is disabled. - Added unit tests to validate caching behavior, clearing logic, and fallback scenarios.
This commit is contained in:
@@ -151,6 +151,9 @@ type Manager struct {
|
||||
mu sync.RWMutex
|
||||
auths map[string]*Auth
|
||||
scheduler *authScheduler
|
||||
// homeRuntimeAuths caches auths returned by Home so websocket sessions can
|
||||
// reuse an established upstream credential without dispatching every turn.
|
||||
homeRuntimeAuths map[string]*Auth
|
||||
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
|
||||
providerOffsets map[string]int
|
||||
|
||||
@@ -195,6 +198,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
||||
selector: selector,
|
||||
hook: hook,
|
||||
auths: make(map[string]*Auth),
|
||||
homeRuntimeAuths: make(map[string]*Auth),
|
||||
providerOffsets: make(map[string]int),
|
||||
modelPoolOffsets: make(map[string]int),
|
||||
}
|
||||
@@ -376,6 +380,9 @@ func (m *Manager) SetConfig(cfg *internalconfig.Config) {
|
||||
cfg = &internalconfig.Config{}
|
||||
}
|
||||
m.runtimeConfig.Store(cfg)
|
||||
if !cfg.Home.Enabled {
|
||||
m.clearHomeRuntimeAuths()
|
||||
}
|
||||
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
||||
}
|
||||
|
||||
@@ -2713,7 +2720,10 @@ func (m *Manager) GetByID(id string) (*Auth, bool) {
|
||||
defer m.mu.RUnlock()
|
||||
auth, ok := m.auths[id]
|
||||
if !ok {
|
||||
return nil, false
|
||||
auth, ok = m.homeRuntimeAuths[id]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
return auth.Clone(), true
|
||||
}
|
||||
@@ -2751,12 +2761,15 @@ func (m *Manager) CloseExecutionSession(sessionID string) {
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
m.mu.Lock()
|
||||
if sessionID == CloseAllExecutionSessionsID {
|
||||
m.clearHomeRuntimeAuthsLocked()
|
||||
}
|
||||
executors := make([]ProviderExecutor, 0, len(m.executors))
|
||||
for _, exec := range m.executors {
|
||||
executors = append(executors, exec)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
m.mu.Unlock()
|
||||
|
||||
for i := range executors {
|
||||
if closer, ok := executors[i].(ExecutionSessionCloser); ok && closer != nil {
|
||||
@@ -3168,6 +3181,80 @@ func setHomeUserAPIKeyOnGinContext(ctx context.Context, apiKey string) {
|
||||
ginCtx.Set("userApiKey", apiKey)
|
||||
}
|
||||
|
||||
func homeExecutionSessionIDFromMetadata(meta map[string]any) string {
|
||||
if len(meta) == 0 {
|
||||
return ""
|
||||
}
|
||||
raw, ok := meta[cliproxyexecutor.ExecutionSessionMetadataKey]
|
||||
if !ok || raw == nil {
|
||||
return ""
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(value)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(value))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) clearHomeRuntimeAuths() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.clearHomeRuntimeAuthsLocked()
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *Manager) clearHomeRuntimeAuthsLocked() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.homeRuntimeAuths = make(map[string]*Auth)
|
||||
}
|
||||
|
||||
func (m *Manager) rememberHomeRuntimeAuth(auth *Auth) {
|
||||
if m == nil || auth == nil || strings.TrimSpace(auth.ID) == "" || !authWebsocketsEnabled(auth) {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
if m.homeRuntimeAuths == nil {
|
||||
m.homeRuntimeAuths = make(map[string]*Auth)
|
||||
}
|
||||
m.homeRuntimeAuths[auth.ID] = auth.Clone()
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *Manager) homeRuntimeAuthByID(authID string) (*Auth, ProviderExecutor, string, bool) {
|
||||
authID = strings.TrimSpace(authID)
|
||||
if m == nil || authID == "" {
|
||||
return nil, nil, "", false
|
||||
}
|
||||
m.mu.RLock()
|
||||
auth := m.homeRuntimeAuths[authID]
|
||||
m.mu.RUnlock()
|
||||
if auth == nil || !authWebsocketsEnabled(auth) {
|
||||
return nil, nil, "", false
|
||||
}
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if providerKey == "" {
|
||||
return nil, nil, "", false
|
||||
}
|
||||
executor, ok := m.Executor(providerKey)
|
||||
if !ok && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["base_url"]) != "" {
|
||||
executor, ok = m.Executor("openai-compatibility")
|
||||
if ok {
|
||||
providerKey = "openai-compatibility"
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
return nil, nil, "", false
|
||||
}
|
||||
return auth.Clone(), executor, providerKey, true
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts cliproxyexecutor.Options) (*Auth, ProviderExecutor, string, error) {
|
||||
if m == nil {
|
||||
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
@@ -3175,6 +3262,14 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if cliproxyexecutor.DownstreamWebsocket(ctx) && homeExecutionSessionIDFromMetadata(opts.Metadata) != "" {
|
||||
if pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata); pinnedAuthID != "" {
|
||||
if auth, executor, providerKey, ok := m.homeRuntimeAuthByID(pinnedAuthID); ok {
|
||||
return auth, executor, providerKey, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
client := home.Current()
|
||||
if client == nil || !client.HeartbeatOK() {
|
||||
return nil, nil, "", &Error{Code: "home_unavailable", Message: "home control center unavailable", HTTPStatus: http.StatusServiceUnavailable}
|
||||
@@ -3254,7 +3349,11 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro
|
||||
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered", HTTPStatus: http.StatusBadGateway}
|
||||
}
|
||||
|
||||
return auth.Clone(), executor, providerKey, nil
|
||||
authCopy := auth.Clone()
|
||||
if cliproxyexecutor.DownstreamWebsocket(ctx) && homeExecutionSessionIDFromMetadata(opts.Metadata) != "" && authWebsocketsEnabled(authCopy) {
|
||||
m.rememberHomeRuntimeAuth(authCopy)
|
||||
}
|
||||
return authCopy, executor, providerKey, nil
|
||||
}
|
||||
|
||||
func requestedModelFromMetadata(metadata map[string]any, fallback string) string {
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
func TestPickNextViaHomeReusesPinnedWebsocketAuthWithoutHomeDispatch(t *testing.T) {
|
||||
manager := NewManager(nil, nil, nil)
|
||||
manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}})
|
||||
manager.RegisterExecutor(schedulerTestExecutor{})
|
||||
|
||||
auth := &Auth{
|
||||
ID: "home-auth-1",
|
||||
Provider: "test",
|
||||
Status: StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"websockets": "true",
|
||||
homeUpstreamModelAttributeKey: "upstream-model",
|
||||
},
|
||||
Metadata: map[string]any{"email": "home@example.com"},
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
manager.rememberHomeRuntimeAuth(auth)
|
||||
cachedAuth, ok := manager.GetByID("home-auth-1")
|
||||
if !ok || cachedAuth == nil || !authWebsocketsEnabled(cachedAuth) {
|
||||
t.Fatalf("GetByID() did not expose remembered websocket home auth: auth=%#v ok=%v", cachedAuth, ok)
|
||||
}
|
||||
|
||||
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||
opts := cliproxyexecutor.Options{
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.ExecutionSessionMetadataKey: "session-1",
|
||||
cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1",
|
||||
},
|
||||
Headers: http.Header{"Authorization": {"Bearer client-key"}},
|
||||
}
|
||||
|
||||
got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts)
|
||||
if errPick != nil {
|
||||
t.Fatalf("pickNextViaHome() error = %v", errPick)
|
||||
}
|
||||
if got == nil || got.ID != "home-auth-1" {
|
||||
t.Fatalf("pickNextViaHome() auth = %#v, want home-auth-1", got)
|
||||
}
|
||||
if executor == nil {
|
||||
t.Fatal("pickNextViaHome() executor is nil")
|
||||
}
|
||||
if provider != "test" {
|
||||
t.Fatalf("pickNextViaHome() provider = %q, want test", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPickNextViaHomeDoesNotReusePinnedNonWebsocketAuth(t *testing.T) {
|
||||
manager := NewManager(nil, nil, nil)
|
||||
manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}})
|
||||
manager.RegisterExecutor(schedulerTestExecutor{})
|
||||
|
||||
manager.mu.Lock()
|
||||
manager.homeRuntimeAuths["home-auth-1"] = &Auth{
|
||||
ID: "home-auth-1",
|
||||
Provider: "test",
|
||||
Status: StatusActive,
|
||||
}
|
||||
manager.mu.Unlock()
|
||||
|
||||
ctx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
|
||||
opts := cliproxyexecutor.Options{
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.ExecutionSessionMetadataKey: "session-1",
|
||||
cliproxyexecutor.PinnedAuthMetadataKey: "home-auth-1",
|
||||
},
|
||||
Headers: http.Header{"Authorization": {"Bearer client-key"}},
|
||||
}
|
||||
|
||||
got, executor, provider, errPick := manager.pickNextViaHome(ctx, "gpt-5.4", opts)
|
||||
if errPick == nil {
|
||||
t.Fatal("pickNextViaHome() error is nil, want home unavailable error")
|
||||
}
|
||||
var authErr *Error
|
||||
if !errors.As(errPick, &authErr) || authErr.Code != "home_unavailable" {
|
||||
t.Fatalf("pickNextViaHome() error = %v, want home_unavailable", errPick)
|
||||
}
|
||||
if got != nil || executor != nil || provider != "" {
|
||||
t.Fatalf("pickNextViaHome() reused non-websocket auth: auth=%#v executor=%#v provider=%q", got, executor, provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHomeRuntimeAuthsClearWhenHomeDisabled(t *testing.T) {
|
||||
manager := NewManager(nil, nil, nil)
|
||||
manager.SetConfig(&internalconfig.Config{Home: internalconfig.HomeConfig{Enabled: true}})
|
||||
manager.rememberHomeRuntimeAuth(&Auth{
|
||||
ID: "home-auth-1",
|
||||
Provider: "test",
|
||||
Attributes: map[string]string{
|
||||
"websockets": "true",
|
||||
},
|
||||
})
|
||||
|
||||
if _, ok := manager.GetByID("home-auth-1"); !ok {
|
||||
t.Fatal("expected remembered home auth before disabling home")
|
||||
}
|
||||
|
||||
manager.SetConfig(&internalconfig.Config{})
|
||||
if _, ok := manager.GetByID("home-auth-1"); ok {
|
||||
t.Fatal("remembered home auth was not cleared when home was disabled")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user