feat(session-affinity): add session-sticky routing for multi-account load balancing
When multiple auth credentials are configured, requests from the same
session are now routed to the same credential, improving upstream prompt
cache hit rates and maintaining context continuity.
Core components:
- SessionAffinitySelector: wraps RoundRobin/FillFirst selectors with
session-to-auth binding; automatic failover when bound auth is
unavailable, re-binding via the fallback selector for even distribution
- SessionCache: TTL-based in-memory cache with background cleanup
goroutine, supporting per-session and per-auth invalidation
- StoppableSelector interface: lifecycle hook for selectors holding
resources, called during Manager.StopAutoRefresh()
Session ID extraction priority (extractSessionIDs):
1. metadata.user_id in a Claude Code session format (legacy
user_{hash}_account__session_{uuid} string, or the newer JSON object with a session_id field)
2. X-Session-ID header (generic client support)
3. metadata.user_id (non-Claude format, used as-is)
4. conversation_id field
5. Stable FNV hash from system prompt + first user/assistant messages
(fallback for clients with no explicit session ID); returns both a
full hash (primaryID) and a short hash without assistant content
(fallbackID) to inherit bindings from the first turn
Multi-format message hash covers OpenAI messages, Claude system array,
Gemini contents/systemInstruction, and OpenAI Responses API input items
(including inline messages with role but no type field).
Configuration (config.yaml routing section):
- session-affinity: bool (default false)
- session-affinity-ttl: duration string (default "1h")
- claude-code-session-affinity: bool (deprecated, alias for above)
All three fields trigger selector rebuild on config hot reload.
Side effect: Idempotency-Key header is no longer auto-generated with a
random UUID when absent — only forwarded when explicitly provided by the
client, to avoid polluting session hash extraction.
This commit is contained in:
@@ -4,15 +4,21 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
@@ -420,3 +426,448 @@ func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, block
|
||||
}
|
||||
return false, blockReasonNone, time.Time{}
|
||||
}
|
||||
|
||||
// sessionPattern recognises the legacy Claude Code metadata.user_id layout,
// user_{hash}_account__session_{uuid}. It anchors on the end of the string and
// captures, as submatch 1, the run of lowercase hex digits and dashes that
// follows the final "_session_" marker.
var sessionPattern = regexp.MustCompile(`_session_([a-f0-9-]+)$`)
|
||||
|
||||
// SessionAffinitySelector wraps another selector with session-sticky behavior.
|
||||
// It extracts session ID from multiple sources and maintains session-to-auth
|
||||
// mappings with automatic failover when the bound auth becomes unavailable.
|
||||
type SessionAffinitySelector struct {
|
||||
fallback Selector
|
||||
cache *SessionCache
|
||||
}
|
||||
|
||||
// SessionAffinityConfig configures the session affinity selector.
|
||||
type SessionAffinityConfig struct {
|
||||
Fallback Selector
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// NewSessionAffinitySelector creates a new session-aware selector.
|
||||
func NewSessionAffinitySelector(fallback Selector) *SessionAffinitySelector {
|
||||
return NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
||||
Fallback: fallback,
|
||||
TTL: time.Hour,
|
||||
})
|
||||
}
|
||||
|
||||
// NewSessionAffinitySelectorWithConfig creates a selector with custom configuration.
|
||||
func NewSessionAffinitySelectorWithConfig(cfg SessionAffinityConfig) *SessionAffinitySelector {
|
||||
if cfg.Fallback == nil {
|
||||
cfg.Fallback = &RoundRobinSelector{}
|
||||
}
|
||||
if cfg.TTL <= 0 {
|
||||
cfg.TTL = time.Hour
|
||||
}
|
||||
return &SessionAffinitySelector{
|
||||
fallback: cfg.Fallback,
|
||||
cache: NewSessionCache(cfg.TTL),
|
||||
}
|
||||
}
|
||||
|
||||
// Pick selects an auth with session affinity when possible.
|
||||
// Priority for session ID extraction:
|
||||
// 1. metadata.user_id (Claude Code format) - highest priority
|
||||
// 2. X-Session-ID header
|
||||
// 3. metadata.user_id (non-Claude Code format)
|
||||
// 4. conversation_id field
|
||||
// 5. Hash-based fallback from messages
|
||||
//
|
||||
// Note: The cache key includes provider, session ID, and model to handle cases where
|
||||
// a session uses multiple models (e.g., gemini-2.5-pro and gemini-3-flash-preview)
|
||||
// that may be supported by different auth credentials, and to avoid cross-provider conflicts.
|
||||
func (s *SessionAffinitySelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
entry := selectorLogEntry(ctx)
|
||||
primaryID, fallbackID := extractSessionIDs(opts.Headers, opts.OriginalRequest, opts.Metadata)
|
||||
if primaryID == "" {
|
||||
entry.Debugf("session-affinity: no session ID extracted, falling back to default selector | provider=%s model=%s", provider, model)
|
||||
return s.fallback.Pick(ctx, provider, model, opts, auths)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
available, err := getAvailableAuths(auths, provider, model, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cacheKey := provider + "::" + primaryID + "::" + model
|
||||
|
||||
if cachedAuthID, ok := s.cache.GetAndRefresh(cacheKey); ok {
|
||||
for _, auth := range available {
|
||||
if auth.ID == cachedAuthID {
|
||||
entry.Infof("session-affinity: cache hit | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
// Cached auth not available, reselect via fallback selector for even distribution
|
||||
auth, err := s.fallback.Pick(ctx, provider, model, opts, auths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.cache.Set(cacheKey, auth.ID)
|
||||
entry.Infof("session-affinity: cache hit but auth unavailable, reselected | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
if fallbackID != "" && fallbackID != primaryID {
|
||||
fallbackKey := provider + "::" + fallbackID + "::" + model
|
||||
if cachedAuthID, ok := s.cache.Get(fallbackKey); ok {
|
||||
for _, auth := range available {
|
||||
if auth.ID == cachedAuthID {
|
||||
s.cache.Set(cacheKey, auth.ID)
|
||||
entry.Infof("session-affinity: fallback cache hit | session=%s fallback=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), truncateSessionID(fallbackID), auth.ID, provider, model)
|
||||
return auth, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auth, err := s.fallback.Pick(ctx, provider, model, opts, auths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.cache.Set(cacheKey, auth.ID)
|
||||
entry.Infof("session-affinity: cache miss, new binding | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func selectorLogEntry(ctx context.Context) *log.Entry {
|
||||
if ctx == nil {
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
if reqID := logging.GetRequestID(ctx); reqID != "" {
|
||||
return log.WithField("request_id", reqID)
|
||||
}
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
|
||||
// truncateSessionID shortens a session ID for log output.
//
// IDs of at most 20 bytes are returned unchanged. Longer IDs are reduced to
// their first 8 bytes followed by "...". Because session IDs can embed
// client-supplied text (the X-Session-ID header or metadata.user_id), the cut
// is moved back to the nearest UTF-8 rune boundary so the result never ends in
// a partial multi-byte character.
func truncateSessionID(id string) string {
	const (
		maxVerbatim = 20 // longest ID that is logged in full
		prefixLen   = 8  // bytes kept from longer IDs
	)
	if len(id) <= maxVerbatim {
		return id
	}

	cut := prefixLen
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes. If the byte at
	// the cut is one, the rune before it would be split, so step back to the
	// start of that rune.
	for cut > 0 && id[cut]&0xC0 == 0x80 {
		cut--
	}
	return id[:cut] + "..."
}
|
||||
|
||||
// Stop releases resources held by the selector.
|
||||
func (s *SessionAffinitySelector) Stop() {
|
||||
if s.cache != nil {
|
||||
s.cache.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateAuth removes all session bindings for a specific auth.
|
||||
// Called when an auth becomes rate-limited or unavailable.
|
||||
func (s *SessionAffinitySelector) InvalidateAuth(authID string) {
|
||||
if s.cache != nil {
|
||||
s.cache.InvalidateAuth(authID)
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractSessionID extracts session identifier from multiple sources.
|
||||
// Priority order:
|
||||
// 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority for Claude Code clients
|
||||
// 2. X-Session-ID header
|
||||
// 3. metadata.user_id (non-Claude Code format)
|
||||
// 4. conversation_id field in request body
|
||||
// 5. Stable hash from first few messages content (fallback)
|
||||
func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string {
|
||||
primary, _ := extractSessionIDs(headers, payload, metadata)
|
||||
return primary
|
||||
}
|
||||
|
||||
// extractSessionIDs returns (primaryID, fallbackID) for session affinity.
|
||||
// primaryID: full hash including assistant response (stable after first turn)
|
||||
// fallbackID: short hash without assistant (used to inherit binding from first turn)
|
||||
func extractSessionIDs(headers http.Header, payload []byte, metadata map[string]any) (string, string) {
|
||||
// 1. metadata.user_id with Claude Code session format (highest priority)
|
||||
if len(payload) > 0 {
|
||||
userID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||
if userID != "" {
|
||||
// Old format: user_{hash}_account__session_{uuid}
|
||||
if matches := sessionPattern.FindStringSubmatch(userID); len(matches) >= 2 {
|
||||
id := "claude:" + matches[1]
|
||||
return id, ""
|
||||
}
|
||||
// New format: JSON object with session_id field
|
||||
// e.g. {"device_id":"...","account_uuid":"...","session_id":"uuid"}
|
||||
if len(userID) > 0 && userID[0] == '{' {
|
||||
if sid := gjson.Get(userID, "session_id").String(); sid != "" {
|
||||
return "claude:" + sid, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. X-Session-ID header
|
||||
if headers != nil {
|
||||
if sid := headers.Get("X-Session-ID"); sid != "" {
|
||||
return "header:" + sid, ""
|
||||
}
|
||||
}
|
||||
|
||||
if len(payload) == 0 {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// 3. metadata.user_id (non-Claude Code format)
|
||||
userID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||
if userID != "" {
|
||||
return "user:" + userID, ""
|
||||
}
|
||||
|
||||
// 4. conversation_id field
|
||||
if convID := gjson.GetBytes(payload, "conversation_id").String(); convID != "" {
|
||||
return "conv:" + convID, ""
|
||||
}
|
||||
|
||||
// 5. Hash-based fallback from message content
|
||||
return extractMessageHashIDs(payload)
|
||||
}
|
||||
|
||||
func extractMessageHashIDs(payload []byte) (primaryID, fallbackID string) {
|
||||
var systemPrompt, firstUserMsg, firstAssistantMsg string
|
||||
|
||||
// OpenAI/Claude messages format
|
||||
messages := gjson.GetBytes(payload, "messages")
|
||||
if messages.Exists() && messages.IsArray() {
|
||||
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||
role := msg.Get("role").String()
|
||||
content := extractMessageContent(msg.Get("content"))
|
||||
if content == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
switch role {
|
||||
case "system":
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = truncateString(content, 100)
|
||||
}
|
||||
case "user":
|
||||
if firstUserMsg == "" {
|
||||
firstUserMsg = truncateString(content, 100)
|
||||
}
|
||||
case "assistant":
|
||||
if firstAssistantMsg == "" {
|
||||
firstAssistantMsg = truncateString(content, 100)
|
||||
}
|
||||
}
|
||||
|
||||
if systemPrompt != "" && firstUserMsg != "" && firstAssistantMsg != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Claude API: top-level "system" field (array or string)
|
||||
if systemPrompt == "" {
|
||||
topSystem := gjson.GetBytes(payload, "system")
|
||||
if topSystem.Exists() {
|
||||
if topSystem.IsArray() {
|
||||
topSystem.ForEach(func(_, part gjson.Result) bool {
|
||||
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
|
||||
systemPrompt = truncateString(text, 100)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
} else if topSystem.Type == gjson.String {
|
||||
systemPrompt = truncateString(topSystem.String(), 100)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Gemini format
|
||||
if systemPrompt == "" && firstUserMsg == "" {
|
||||
sysInstr := gjson.GetBytes(payload, "systemInstruction.parts")
|
||||
if sysInstr.Exists() && sysInstr.IsArray() {
|
||||
sysInstr.ForEach(func(_, part gjson.Result) bool {
|
||||
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
|
||||
systemPrompt = truncateString(text, 100)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
contents := gjson.GetBytes(payload, "contents")
|
||||
if contents.Exists() && contents.IsArray() {
|
||||
contents.ForEach(func(_, msg gjson.Result) bool {
|
||||
role := msg.Get("role").String()
|
||||
msg.Get("parts").ForEach(func(_, part gjson.Result) bool {
|
||||
text := part.Get("text").String()
|
||||
if text == "" {
|
||||
return true
|
||||
}
|
||||
switch role {
|
||||
case "user":
|
||||
if firstUserMsg == "" {
|
||||
firstUserMsg = truncateString(text, 100)
|
||||
}
|
||||
case "model":
|
||||
if firstAssistantMsg == "" {
|
||||
firstAssistantMsg = truncateString(text, 100)
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
if firstUserMsg != "" && firstAssistantMsg != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI Responses API format (v1/responses)
|
||||
if systemPrompt == "" && firstUserMsg == "" {
|
||||
if instr := gjson.GetBytes(payload, "instructions").String(); instr != "" {
|
||||
systemPrompt = truncateString(instr, 100)
|
||||
}
|
||||
|
||||
input := gjson.GetBytes(payload, "input")
|
||||
if input.Exists() && input.IsArray() {
|
||||
input.ForEach(func(_, item gjson.Result) bool {
|
||||
itemType := item.Get("type").String()
|
||||
if itemType == "reasoning" {
|
||||
return true
|
||||
}
|
||||
// Skip non-message typed items (function_call, function_call_output, etc.)
|
||||
// but allow items with no type that have a role (inline message format).
|
||||
if itemType != "" && itemType != "message" {
|
||||
return true
|
||||
}
|
||||
|
||||
role := item.Get("role").String()
|
||||
if itemType == "" && role == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle both string content and array content (multimodal).
|
||||
content := item.Get("content")
|
||||
var text string
|
||||
if content.Type == gjson.String {
|
||||
text = content.String()
|
||||
} else {
|
||||
text = extractResponsesAPIContent(content)
|
||||
}
|
||||
if text == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
switch role {
|
||||
case "developer", "system":
|
||||
if systemPrompt == "" {
|
||||
systemPrompt = truncateString(text, 100)
|
||||
}
|
||||
case "user":
|
||||
if firstUserMsg == "" {
|
||||
firstUserMsg = truncateString(text, 100)
|
||||
}
|
||||
case "assistant":
|
||||
if firstAssistantMsg == "" {
|
||||
firstAssistantMsg = truncateString(text, 100)
|
||||
}
|
||||
}
|
||||
|
||||
if firstUserMsg != "" && firstAssistantMsg != "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if systemPrompt == "" && firstUserMsg == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
shortHash := computeSessionHash(systemPrompt, firstUserMsg, "")
|
||||
if firstAssistantMsg == "" {
|
||||
return shortHash, ""
|
||||
}
|
||||
|
||||
fullHash := computeSessionHash(systemPrompt, firstUserMsg, firstAssistantMsg)
|
||||
return fullHash, shortHash
|
||||
}
|
||||
|
||||
// computeSessionHash returns a session ID of the form "msg:" followed by 16
// lowercase hex digits: the 64-bit FNV-1a hash of the non-empty inputs. Each
// non-empty input contributes its tag ("sys:", "usr:" or "ast:"), its text and
// a trailing newline, in that fixed order; empty inputs contribute nothing.
func computeSessionHash(systemPrompt, userMsg, assistantMsg string) string {
	hasher := fnv.New64a()

	sections := [...]struct{ tag, body string }{
		{"sys:", systemPrompt},
		{"usr:", userMsg},
		{"ast:", assistantMsg},
	}
	for _, section := range sections {
		if section.body == "" {
			continue
		}
		// hash.Hash documents that Write never returns an error.
		_, _ = hasher.Write([]byte(section.tag + section.body + "\n"))
	}

	return fmt.Sprintf("msg:%016x", hasher.Sum64())
}
|
||||
|
||||
// truncateString returns s unchanged when it is at most maxLen bytes long, and
// otherwise its first maxLen bytes. The limit is counted in bytes, not runes.
func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	return s[:maxLen]
}
|
||||
|
||||
// extractMessageContent extracts text content from a message content field.
|
||||
// Handles both string content and array content (multimodal messages).
|
||||
// For array content, extracts text from all text-type elements.
|
||||
func extractMessageContent(content gjson.Result) string {
|
||||
// String content: "Hello world"
|
||||
if content.Type == gjson.String {
|
||||
return content.String()
|
||||
}
|
||||
|
||||
// Array content: [{"type":"text","text":"Hello"},{"type":"image",...}]
|
||||
if content.IsArray() {
|
||||
var texts []string
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
// Handle Claude format: {"type":"text","text":"content"}
|
||||
if part.Get("type").String() == "text" {
|
||||
if text := part.Get("text").String(); text != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
// Handle OpenAI format: {"type":"text","text":"content"}
|
||||
// Same structure as Claude, already handled above
|
||||
return true
|
||||
})
|
||||
if len(texts) > 0 {
|
||||
return strings.Join(texts, " ")
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractResponsesAPIContent(content gjson.Result) string {
|
||||
if !content.IsArray() {
|
||||
return ""
|
||||
}
|
||||
var texts []string
|
||||
content.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
if partType == "input_text" || partType == "output_text" || partType == "text" {
|
||||
if text := part.Get("text").String(); text != "" {
|
||||
texts = append(texts, text)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if len(texts) > 0 {
|
||||
return strings.Join(texts, " ")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractSessionID is kept for backward compatibility.
|
||||
// Deprecated: Use ExtractSessionID instead.
|
||||
func extractSessionID(payload []byte) string {
|
||||
return ExtractSessionID(nil, payload, nil)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user