7c24d54ca8
When multiple auth credentials are configured, requests from the same
session are now routed to the same credential, improving upstream prompt
cache hit rates and maintaining context continuity.
Core components:
- SessionAffinitySelector: wraps RoundRobin/FillFirst selectors with
session-to-auth binding; automatic failover when bound auth is
unavailable, re-binding via the fallback selector for even distribution
- SessionCache: TTL-based in-memory cache with background cleanup
goroutine, supporting per-session and per-auth invalidation
- StoppableSelector interface: lifecycle hook for selectors holding
resources, called during Manager.StopAutoRefresh()
Session ID extraction priority (extractSessionIDs):
1. metadata.user_id with Claude Code session format (old
user_{hash}_session_{uuid} and new JSON {session_id} format)
2. X-Session-ID header (generic client support)
3. metadata.user_id (non-Claude format, used as-is)
4. conversation_id field
5. Stable FNV hash from system prompt + first user/assistant messages
(fallback for clients with no explicit session ID); returns both a
full hash (primaryID) and a short hash without assistant content
(fallbackID) to inherit bindings from the first turn
Multi-format message hash covers OpenAI messages, Claude system array,
Gemini contents/systemInstruction, and OpenAI Responses API input items
(including inline messages with role but no type field).
Configuration (config.yaml routing section):
- session-affinity: bool (default false)
- session-affinity-ttl: duration string (default "1h")
- claude-code-session-affinity: bool (deprecated alias for session-affinity)
All three fields trigger selector rebuild on config hot reload.
Side effect: Idempotency-Key header is no longer auto-generated with a
random UUID when absent — only forwarded when explicitly provided by the
client, to avoid polluting session hash extraction.
874 lines
25 KiB
Go
874 lines
25 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"hash/fnv"
|
|
"math"
|
|
"math/rand/v2"
|
|
"net/http"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
)
|
|
|
|
// RoundRobinSelector provides a simple provider scoped round-robin selection strategy.
// The zero value is ready to use: Pick allocates the cursor map on first use.
type RoundRobinSelector struct {
	// mu guards cursors.
	mu sync.Mutex
	// cursors holds the next round-robin position for each cursor key.
	cursors map[string]int
	// maxKeys caps the number of cursor keys kept; Pick substitutes 4096
	// when the value is zero or negative.
	maxKeys int
}
|
|
|
|
// FillFirstSelector selects the first available credential (deterministic ordering).
// This "burns" one account before moving to the next, which can help stagger
// rolling-window subscription caps (e.g. chat message limits).
// It carries no state, so the zero value is ready to use.
type FillFirstSelector struct{}
|
|
|
|
// blockReason classifies why isAuthBlockedForModel rejected a credential.
type blockReason int

const (
	// blockReasonNone is reported when the credential is not blocked.
	blockReasonNone blockReason = iota
	// blockReasonCooldown is reported when the credential is unavailable until
	// a future retry time and its quota is marked as exceeded.
	blockReasonCooldown
	// blockReasonDisabled is reported when the credential, or its state for
	// the requested model, is disabled.
	blockReasonDisabled
	// blockReasonOther is reported for a nil credential, or for one that is
	// unavailable until a future retry time without an exceeded quota.
	blockReasonOther
)
|
|
|
|
// modelCooldownError reports that every credential for a model is cooling
// down. Error renders it as a JSON error body; StatusCode and Headers supply
// the matching HTTP status and response headers.
type modelCooldownError struct {
	model    string
	resetIn  time.Duration
	provider string
}

// newModelCooldownError builds a modelCooldownError for the given model and
// provider. A negative resetIn is stored as zero.
func newModelCooldownError(model, provider string, resetIn time.Duration) *modelCooldownError {
	cooldown := &modelCooldownError{model: model, provider: provider}
	if resetIn > 0 {
		cooldown.resetIn = resetIn
	}
	return cooldown
}

// resetSeconds returns the remaining cooldown rounded up to whole seconds,
// never less than zero.
func (e *modelCooldownError) resetSeconds() int {
	if seconds := int(math.Ceil(e.resetIn.Seconds())); seconds > 0 {
		return seconds
	}
	return 0
}

// Error returns the JSON document describing the cooldown. The "provider"
// field is included only when a provider is set; an empty model is described
// as "requested model" in the message.
func (e *modelCooldownError) Error() string {
	subject := e.model
	if subject == "" {
		subject = "requested model"
	}
	text := "All credentials for model " + subject + " are cooling down"
	if e.provider != "" {
		text += " via provider " + e.provider
	}

	// Sub-second waits are shown as one second; everything else is rounded
	// to the nearest second.
	var shown time.Duration
	switch {
	case e.resetIn > 0 && e.resetIn < time.Second:
		shown = time.Second
	default:
		shown = e.resetIn.Round(time.Second)
	}

	detail := map[string]any{
		"code":          "model_cooldown",
		"message":       text,
		"model":         e.model,
		"reset_time":    shown.String(),
		"reset_seconds": e.resetSeconds(),
	}
	if e.provider != "" {
		detail["provider"] = e.provider
	}

	encoded, err := json.Marshal(map[string]any{"error": detail})
	if err != nil {
		return fmt.Sprintf(`{"error":{"code":"model_cooldown","message":"%s"}}`, text)
	}
	return string(encoded)
}

// StatusCode returns HTTP 429 (Too Many Requests).
func (e *modelCooldownError) StatusCode() int {
	return http.StatusTooManyRequests
}

// Headers returns a JSON Content-Type header and a Retry-After header holding
// the remaining cooldown in whole seconds.
func (e *modelCooldownError) Headers() http.Header {
	return http.Header{
		"Content-Type": {"application/json"},
		"Retry-After":  {strconv.Itoa(e.resetSeconds())},
	}
}
|
|
|
|
func authPriority(auth *Auth) int {
|
|
if auth == nil || auth.Attributes == nil {
|
|
return 0
|
|
}
|
|
raw := strings.TrimSpace(auth.Attributes["priority"])
|
|
if raw == "" {
|
|
return 0
|
|
}
|
|
parsed, err := strconv.Atoi(raw)
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
return parsed
|
|
}
|
|
|
|
func canonicalModelKey(model string) string {
|
|
model = strings.TrimSpace(model)
|
|
if model == "" {
|
|
return ""
|
|
}
|
|
parsed := thinking.ParseSuffix(model)
|
|
modelName := strings.TrimSpace(parsed.ModelName)
|
|
if modelName == "" {
|
|
return model
|
|
}
|
|
return modelName
|
|
}
|
|
|
|
func authWebsocketsEnabled(auth *Auth) bool {
|
|
if auth == nil {
|
|
return false
|
|
}
|
|
if len(auth.Attributes) > 0 {
|
|
if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" {
|
|
parsed, errParse := strconv.ParseBool(raw)
|
|
if errParse == nil {
|
|
return parsed
|
|
}
|
|
}
|
|
}
|
|
if len(auth.Metadata) == 0 {
|
|
return false
|
|
}
|
|
raw, ok := auth.Metadata["websockets"]
|
|
if !ok || raw == nil {
|
|
return false
|
|
}
|
|
switch v := raw.(type) {
|
|
case bool:
|
|
return v
|
|
case string:
|
|
parsed, errParse := strconv.ParseBool(strings.TrimSpace(v))
|
|
if errParse == nil {
|
|
return parsed
|
|
}
|
|
default:
|
|
}
|
|
return false
|
|
}
|
|
|
|
func preferCodexWebsocketAuths(ctx context.Context, provider string, available []*Auth) []*Auth {
|
|
if len(available) == 0 {
|
|
return available
|
|
}
|
|
if !cliproxyexecutor.DownstreamWebsocket(ctx) {
|
|
return available
|
|
}
|
|
if !strings.EqualFold(strings.TrimSpace(provider), "codex") {
|
|
return available
|
|
}
|
|
|
|
wsEnabled := make([]*Auth, 0, len(available))
|
|
for i := 0; i < len(available); i++ {
|
|
candidate := available[i]
|
|
if authWebsocketsEnabled(candidate) {
|
|
wsEnabled = append(wsEnabled, candidate)
|
|
}
|
|
}
|
|
if len(wsEnabled) > 0 {
|
|
return wsEnabled
|
|
}
|
|
return available
|
|
}
|
|
|
|
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
|
available = make(map[int][]*Auth)
|
|
for i := 0; i < len(auths); i++ {
|
|
candidate := auths[i]
|
|
blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
|
|
if !blocked {
|
|
priority := authPriority(candidate)
|
|
available[priority] = append(available[priority], candidate)
|
|
continue
|
|
}
|
|
if reason == blockReasonCooldown {
|
|
cooldownCount++
|
|
if !next.IsZero() && (earliest.IsZero() || next.Before(earliest)) {
|
|
earliest = next
|
|
}
|
|
}
|
|
}
|
|
return available, cooldownCount, earliest
|
|
}
|
|
|
|
func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]*Auth, error) {
|
|
if len(auths) == 0 {
|
|
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
|
|
}
|
|
|
|
availableByPriority, cooldownCount, earliest := collectAvailableByPriority(auths, model, now)
|
|
if len(availableByPriority) == 0 {
|
|
if cooldownCount == len(auths) && !earliest.IsZero() {
|
|
providerForError := provider
|
|
if providerForError == "mixed" {
|
|
providerForError = ""
|
|
}
|
|
resetIn := earliest.Sub(now)
|
|
if resetIn < 0 {
|
|
resetIn = 0
|
|
}
|
|
return nil, newModelCooldownError(model, providerForError, resetIn)
|
|
}
|
|
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
|
|
}
|
|
|
|
bestPriority := 0
|
|
found := false
|
|
for priority := range availableByPriority {
|
|
if !found || priority > bestPriority {
|
|
bestPriority = priority
|
|
found = true
|
|
}
|
|
}
|
|
|
|
available := availableByPriority[bestPriority]
|
|
if len(available) > 1 {
|
|
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
|
|
}
|
|
return available, nil
|
|
}
|
|
|
|
// Pick selects the next available auth for the provider in a round-robin manner.
|
|
// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute),
|
|
// a two-level round-robin is used: first cycling across credential groups (parent
|
|
// accounts), then cycling within each group's project auths.
|
|
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
|
_ = opts
|
|
now := time.Now()
|
|
available, err := getAvailableAuths(auths, provider, model, now)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
available = preferCodexWebsocketAuths(ctx, provider, available)
|
|
key := provider + ":" + canonicalModelKey(model)
|
|
s.mu.Lock()
|
|
if s.cursors == nil {
|
|
s.cursors = make(map[string]int)
|
|
}
|
|
limit := s.maxKeys
|
|
if limit <= 0 {
|
|
limit = 4096
|
|
}
|
|
|
|
// Check if any available auth has gemini_virtual_parent attribute,
|
|
// indicating gemini-cli virtual auths that should use credential-level polling.
|
|
groups, parentOrder := groupByVirtualParent(available)
|
|
if len(parentOrder) > 1 {
|
|
// Two-level round-robin: first select a credential group, then pick within it.
|
|
groupKey := key + "::group"
|
|
s.ensureCursorKey(groupKey, limit)
|
|
if _, exists := s.cursors[groupKey]; !exists {
|
|
// Seed with a random initial offset so the starting credential is randomized.
|
|
s.cursors[groupKey] = rand.IntN(len(parentOrder))
|
|
}
|
|
groupIndex := s.cursors[groupKey]
|
|
if groupIndex >= 2_147_483_640 {
|
|
groupIndex = 0
|
|
}
|
|
s.cursors[groupKey] = groupIndex + 1
|
|
|
|
selectedParent := parentOrder[groupIndex%len(parentOrder)]
|
|
group := groups[selectedParent]
|
|
|
|
// Second level: round-robin within the selected credential group.
|
|
innerKey := key + "::cred:" + selectedParent
|
|
s.ensureCursorKey(innerKey, limit)
|
|
innerIndex := s.cursors[innerKey]
|
|
if innerIndex >= 2_147_483_640 {
|
|
innerIndex = 0
|
|
}
|
|
s.cursors[innerKey] = innerIndex + 1
|
|
s.mu.Unlock()
|
|
return group[innerIndex%len(group)], nil
|
|
}
|
|
|
|
// Flat round-robin for non-grouped auths (original behavior).
|
|
s.ensureCursorKey(key, limit)
|
|
index := s.cursors[key]
|
|
if index >= 2_147_483_640 {
|
|
index = 0
|
|
}
|
|
s.cursors[key] = index + 1
|
|
s.mu.Unlock()
|
|
return available[index%len(available)], nil
|
|
}
|
|
|
|
// ensureCursorKey ensures the cursor map has capacity for the given key.
|
|
// Must be called with s.mu held.
|
|
func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) {
|
|
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
|
s.cursors = make(map[string]int)
|
|
}
|
|
}
|
|
|
|
// groupByVirtualParent groups auths by their gemini_virtual_parent attribute.
|
|
// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration.
|
|
// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks
|
|
// this attribute, nil/nil is returned so the caller falls back to flat round-robin.
|
|
func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) {
|
|
if len(auths) == 0 {
|
|
return nil, nil
|
|
}
|
|
groups := make(map[string][]*Auth)
|
|
for _, a := range auths {
|
|
parent := ""
|
|
if a.Attributes != nil {
|
|
parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"])
|
|
}
|
|
if parent == "" {
|
|
// Non-virtual auth present; fall back to flat round-robin.
|
|
return nil, nil
|
|
}
|
|
groups[parent] = append(groups[parent], a)
|
|
}
|
|
// Collect parent IDs in sorted order for stable cursor indexing.
|
|
parentOrder := make([]string, 0, len(groups))
|
|
for p := range groups {
|
|
parentOrder = append(parentOrder, p)
|
|
}
|
|
sort.Strings(parentOrder)
|
|
return groups, parentOrder
|
|
}
|
|
|
|
// Pick selects the first available auth for the provider in a deterministic manner.
|
|
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
|
_ = opts
|
|
now := time.Now()
|
|
available, err := getAvailableAuths(auths, provider, model, now)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
available = preferCodexWebsocketAuths(ctx, provider, available)
|
|
return available[0], nil
|
|
}
|
|
|
|
func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) {
|
|
if auth == nil {
|
|
return true, blockReasonOther, time.Time{}
|
|
}
|
|
if auth.Disabled || auth.Status == StatusDisabled {
|
|
return true, blockReasonDisabled, time.Time{}
|
|
}
|
|
if model != "" {
|
|
if len(auth.ModelStates) > 0 {
|
|
state, ok := auth.ModelStates[model]
|
|
if (!ok || state == nil) && model != "" {
|
|
baseModel := canonicalModelKey(model)
|
|
if baseModel != "" && baseModel != model {
|
|
state, ok = auth.ModelStates[baseModel]
|
|
}
|
|
}
|
|
if ok && state != nil {
|
|
if state.Status == StatusDisabled {
|
|
return true, blockReasonDisabled, time.Time{}
|
|
}
|
|
if state.Unavailable {
|
|
if state.NextRetryAfter.IsZero() {
|
|
return false, blockReasonNone, time.Time{}
|
|
}
|
|
if state.NextRetryAfter.After(now) {
|
|
next := state.NextRetryAfter
|
|
if !state.Quota.NextRecoverAt.IsZero() && state.Quota.NextRecoverAt.After(now) {
|
|
next = state.Quota.NextRecoverAt
|
|
}
|
|
if next.Before(now) {
|
|
next = now
|
|
}
|
|
if state.Quota.Exceeded {
|
|
return true, blockReasonCooldown, next
|
|
}
|
|
return true, blockReasonOther, next
|
|
}
|
|
}
|
|
return false, blockReasonNone, time.Time{}
|
|
}
|
|
}
|
|
return false, blockReasonNone, time.Time{}
|
|
}
|
|
if auth.Unavailable && auth.NextRetryAfter.After(now) {
|
|
next := auth.NextRetryAfter
|
|
if !auth.Quota.NextRecoverAt.IsZero() && auth.Quota.NextRecoverAt.After(now) {
|
|
next = auth.Quota.NextRecoverAt
|
|
}
|
|
if next.Before(now) {
|
|
next = now
|
|
}
|
|
if auth.Quota.Exceeded {
|
|
return true, blockReasonCooldown, next
|
|
}
|
|
return true, blockReasonOther, next
|
|
}
|
|
return false, blockReasonNone, time.Time{}
|
|
}
|
|
|
|
// sessionPattern matches Claude Code user_id format:
// user_{hash}_account__session_{uuid}
// The pattern is anchored at the end of the string only. Capture group 1 is
// the run of lowercase hex digits and hyphens that follows "_session_".
var sessionPattern = regexp.MustCompile(`_session_([a-f0-9-]+)$`)
|
|
|
|
// SessionAffinitySelector wraps another selector with session-sticky behavior.
// It extracts session ID from multiple sources and maintains session-to-auth
// mappings with automatic failover when the bound auth becomes unavailable.
type SessionAffinitySelector struct {
	// fallback picks an auth when there is no session ID or no usable binding.
	fallback Selector
	// cache stores the session-to-auth bindings.
	cache *SessionCache
}
|
|
|
|
// SessionAffinityConfig configures the session affinity selector.
type SessionAffinityConfig struct {
	// Fallback is the wrapped selector; NewSessionAffinitySelectorWithConfig
	// substitutes a RoundRobinSelector when it is nil.
	Fallback Selector
	// TTL is passed to the session cache; NewSessionAffinitySelectorWithConfig
	// substitutes one hour when it is zero or negative.
	TTL time.Duration
}
|
|
|
|
// NewSessionAffinitySelector creates a new session-aware selector.
|
|
func NewSessionAffinitySelector(fallback Selector) *SessionAffinitySelector {
|
|
return NewSessionAffinitySelectorWithConfig(SessionAffinityConfig{
|
|
Fallback: fallback,
|
|
TTL: time.Hour,
|
|
})
|
|
}
|
|
|
|
// NewSessionAffinitySelectorWithConfig creates a selector with custom configuration.
|
|
func NewSessionAffinitySelectorWithConfig(cfg SessionAffinityConfig) *SessionAffinitySelector {
|
|
if cfg.Fallback == nil {
|
|
cfg.Fallback = &RoundRobinSelector{}
|
|
}
|
|
if cfg.TTL <= 0 {
|
|
cfg.TTL = time.Hour
|
|
}
|
|
return &SessionAffinitySelector{
|
|
fallback: cfg.Fallback,
|
|
cache: NewSessionCache(cfg.TTL),
|
|
}
|
|
}
|
|
|
|
// Pick selects an auth with session affinity when possible.
|
|
// Priority for session ID extraction:
|
|
// 1. metadata.user_id (Claude Code format) - highest priority
|
|
// 2. X-Session-ID header
|
|
// 3. metadata.user_id (non-Claude Code format)
|
|
// 4. conversation_id field
|
|
// 5. Hash-based fallback from messages
|
|
//
|
|
// Note: The cache key includes provider, session ID, and model to handle cases where
|
|
// a session uses multiple models (e.g., gemini-2.5-pro and gemini-3-flash-preview)
|
|
// that may be supported by different auth credentials, and to avoid cross-provider conflicts.
|
|
func (s *SessionAffinitySelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
|
entry := selectorLogEntry(ctx)
|
|
primaryID, fallbackID := extractSessionIDs(opts.Headers, opts.OriginalRequest, opts.Metadata)
|
|
if primaryID == "" {
|
|
entry.Debugf("session-affinity: no session ID extracted, falling back to default selector | provider=%s model=%s", provider, model)
|
|
return s.fallback.Pick(ctx, provider, model, opts, auths)
|
|
}
|
|
|
|
now := time.Now()
|
|
available, err := getAvailableAuths(auths, provider, model, now)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cacheKey := provider + "::" + primaryID + "::" + model
|
|
|
|
if cachedAuthID, ok := s.cache.GetAndRefresh(cacheKey); ok {
|
|
for _, auth := range available {
|
|
if auth.ID == cachedAuthID {
|
|
entry.Infof("session-affinity: cache hit | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
|
return auth, nil
|
|
}
|
|
}
|
|
// Cached auth not available, reselect via fallback selector for even distribution
|
|
auth, err := s.fallback.Pick(ctx, provider, model, opts, auths)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s.cache.Set(cacheKey, auth.ID)
|
|
entry.Infof("session-affinity: cache hit but auth unavailable, reselected | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
|
return auth, nil
|
|
}
|
|
|
|
if fallbackID != "" && fallbackID != primaryID {
|
|
fallbackKey := provider + "::" + fallbackID + "::" + model
|
|
if cachedAuthID, ok := s.cache.Get(fallbackKey); ok {
|
|
for _, auth := range available {
|
|
if auth.ID == cachedAuthID {
|
|
s.cache.Set(cacheKey, auth.ID)
|
|
entry.Infof("session-affinity: fallback cache hit | session=%s fallback=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), truncateSessionID(fallbackID), auth.ID, provider, model)
|
|
return auth, nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
auth, err := s.fallback.Pick(ctx, provider, model, opts, auths)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s.cache.Set(cacheKey, auth.ID)
|
|
entry.Infof("session-affinity: cache miss, new binding | session=%s auth=%s provider=%s model=%s", truncateSessionID(primaryID), auth.ID, provider, model)
|
|
return auth, nil
|
|
}
|
|
|
|
func selectorLogEntry(ctx context.Context) *log.Entry {
|
|
if ctx == nil {
|
|
return log.NewEntry(log.StandardLogger())
|
|
}
|
|
if reqID := logging.GetRequestID(ctx); reqID != "" {
|
|
return log.WithField("request_id", reqID)
|
|
}
|
|
return log.NewEntry(log.StandardLogger())
|
|
}
|
|
|
|
// truncateSessionID shortens a session ID for logging.
// IDs of at most 20 bytes are returned unchanged. Longer IDs are cut to their
// first 8 bytes followed by "...". Session IDs can carry arbitrary client
// text, so the cut may land inside a multi-byte UTF-8 character; such
// partial bytes are dropped to keep the logged value valid UTF-8.
func truncateSessionID(id string) string {
	const (
		keepWholeUpTo = 20 // longest ID that is logged in full
		prefixLen     = 8  // bytes kept from longer IDs
	)
	if len(id) <= keepWholeUpTo {
		return id
	}
	return strings.ToValidUTF8(id[:prefixLen], "") + "..."
}
|
|
|
|
// Stop releases resources held by the selector.
|
|
func (s *SessionAffinitySelector) Stop() {
|
|
if s.cache != nil {
|
|
s.cache.Stop()
|
|
}
|
|
}
|
|
|
|
// InvalidateAuth removes all session bindings for a specific auth.
|
|
// Called when an auth becomes rate-limited or unavailable.
|
|
func (s *SessionAffinitySelector) InvalidateAuth(authID string) {
|
|
if s.cache != nil {
|
|
s.cache.InvalidateAuth(authID)
|
|
}
|
|
}
|
|
|
|
// ExtractSessionID extracts session identifier from multiple sources.
|
|
// Priority order:
|
|
// 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority for Claude Code clients
|
|
// 2. X-Session-ID header
|
|
// 3. metadata.user_id (non-Claude Code format)
|
|
// 4. conversation_id field in request body
|
|
// 5. Stable hash from first few messages content (fallback)
|
|
func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string {
|
|
primary, _ := extractSessionIDs(headers, payload, metadata)
|
|
return primary
|
|
}
|
|
|
|
// extractSessionIDs returns (primaryID, fallbackID) for session affinity.
|
|
// primaryID: full hash including assistant response (stable after first turn)
|
|
// fallbackID: short hash without assistant (used to inherit binding from first turn)
|
|
func extractSessionIDs(headers http.Header, payload []byte, metadata map[string]any) (string, string) {
|
|
// 1. metadata.user_id with Claude Code session format (highest priority)
|
|
if len(payload) > 0 {
|
|
userID := gjson.GetBytes(payload, "metadata.user_id").String()
|
|
if userID != "" {
|
|
// Old format: user_{hash}_account__session_{uuid}
|
|
if matches := sessionPattern.FindStringSubmatch(userID); len(matches) >= 2 {
|
|
id := "claude:" + matches[1]
|
|
return id, ""
|
|
}
|
|
// New format: JSON object with session_id field
|
|
// e.g. {"device_id":"...","account_uuid":"...","session_id":"uuid"}
|
|
if len(userID) > 0 && userID[0] == '{' {
|
|
if sid := gjson.Get(userID, "session_id").String(); sid != "" {
|
|
return "claude:" + sid, ""
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 2. X-Session-ID header
|
|
if headers != nil {
|
|
if sid := headers.Get("X-Session-ID"); sid != "" {
|
|
return "header:" + sid, ""
|
|
}
|
|
}
|
|
|
|
if len(payload) == 0 {
|
|
return "", ""
|
|
}
|
|
|
|
// 3. metadata.user_id (non-Claude Code format)
|
|
userID := gjson.GetBytes(payload, "metadata.user_id").String()
|
|
if userID != "" {
|
|
return "user:" + userID, ""
|
|
}
|
|
|
|
// 4. conversation_id field
|
|
if convID := gjson.GetBytes(payload, "conversation_id").String(); convID != "" {
|
|
return "conv:" + convID, ""
|
|
}
|
|
|
|
// 5. Hash-based fallback from message content
|
|
return extractMessageHashIDs(payload)
|
|
}
|
|
|
|
func extractMessageHashIDs(payload []byte) (primaryID, fallbackID string) {
|
|
var systemPrompt, firstUserMsg, firstAssistantMsg string
|
|
|
|
// OpenAI/Claude messages format
|
|
messages := gjson.GetBytes(payload, "messages")
|
|
if messages.Exists() && messages.IsArray() {
|
|
messages.ForEach(func(_, msg gjson.Result) bool {
|
|
role := msg.Get("role").String()
|
|
content := extractMessageContent(msg.Get("content"))
|
|
if content == "" {
|
|
return true
|
|
}
|
|
|
|
switch role {
|
|
case "system":
|
|
if systemPrompt == "" {
|
|
systemPrompt = truncateString(content, 100)
|
|
}
|
|
case "user":
|
|
if firstUserMsg == "" {
|
|
firstUserMsg = truncateString(content, 100)
|
|
}
|
|
case "assistant":
|
|
if firstAssistantMsg == "" {
|
|
firstAssistantMsg = truncateString(content, 100)
|
|
}
|
|
}
|
|
|
|
if systemPrompt != "" && firstUserMsg != "" && firstAssistantMsg != "" {
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
// Claude API: top-level "system" field (array or string)
|
|
if systemPrompt == "" {
|
|
topSystem := gjson.GetBytes(payload, "system")
|
|
if topSystem.Exists() {
|
|
if topSystem.IsArray() {
|
|
topSystem.ForEach(func(_, part gjson.Result) bool {
|
|
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
|
|
systemPrompt = truncateString(text, 100)
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
} else if topSystem.Type == gjson.String {
|
|
systemPrompt = truncateString(topSystem.String(), 100)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Gemini format
|
|
if systemPrompt == "" && firstUserMsg == "" {
|
|
sysInstr := gjson.GetBytes(payload, "systemInstruction.parts")
|
|
if sysInstr.Exists() && sysInstr.IsArray() {
|
|
sysInstr.ForEach(func(_, part gjson.Result) bool {
|
|
if text := part.Get("text").String(); text != "" && systemPrompt == "" {
|
|
systemPrompt = truncateString(text, 100)
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
contents := gjson.GetBytes(payload, "contents")
|
|
if contents.Exists() && contents.IsArray() {
|
|
contents.ForEach(func(_, msg gjson.Result) bool {
|
|
role := msg.Get("role").String()
|
|
msg.Get("parts").ForEach(func(_, part gjson.Result) bool {
|
|
text := part.Get("text").String()
|
|
if text == "" {
|
|
return true
|
|
}
|
|
switch role {
|
|
case "user":
|
|
if firstUserMsg == "" {
|
|
firstUserMsg = truncateString(text, 100)
|
|
}
|
|
case "model":
|
|
if firstAssistantMsg == "" {
|
|
firstAssistantMsg = truncateString(text, 100)
|
|
}
|
|
}
|
|
return false
|
|
})
|
|
if firstUserMsg != "" && firstAssistantMsg != "" {
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
// OpenAI Responses API format (v1/responses)
|
|
if systemPrompt == "" && firstUserMsg == "" {
|
|
if instr := gjson.GetBytes(payload, "instructions").String(); instr != "" {
|
|
systemPrompt = truncateString(instr, 100)
|
|
}
|
|
|
|
input := gjson.GetBytes(payload, "input")
|
|
if input.Exists() && input.IsArray() {
|
|
input.ForEach(func(_, item gjson.Result) bool {
|
|
itemType := item.Get("type").String()
|
|
if itemType == "reasoning" {
|
|
return true
|
|
}
|
|
// Skip non-message typed items (function_call, function_call_output, etc.)
|
|
// but allow items with no type that have a role (inline message format).
|
|
if itemType != "" && itemType != "message" {
|
|
return true
|
|
}
|
|
|
|
role := item.Get("role").String()
|
|
if itemType == "" && role == "" {
|
|
return true
|
|
}
|
|
|
|
// Handle both string content and array content (multimodal).
|
|
content := item.Get("content")
|
|
var text string
|
|
if content.Type == gjson.String {
|
|
text = content.String()
|
|
} else {
|
|
text = extractResponsesAPIContent(content)
|
|
}
|
|
if text == "" {
|
|
return true
|
|
}
|
|
|
|
switch role {
|
|
case "developer", "system":
|
|
if systemPrompt == "" {
|
|
systemPrompt = truncateString(text, 100)
|
|
}
|
|
case "user":
|
|
if firstUserMsg == "" {
|
|
firstUserMsg = truncateString(text, 100)
|
|
}
|
|
case "assistant":
|
|
if firstAssistantMsg == "" {
|
|
firstAssistantMsg = truncateString(text, 100)
|
|
}
|
|
}
|
|
|
|
if firstUserMsg != "" && firstAssistantMsg != "" {
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
if systemPrompt == "" && firstUserMsg == "" {
|
|
return "", ""
|
|
}
|
|
|
|
shortHash := computeSessionHash(systemPrompt, firstUserMsg, "")
|
|
if firstAssistantMsg == "" {
|
|
return shortHash, ""
|
|
}
|
|
|
|
fullHash := computeSessionHash(systemPrompt, firstUserMsg, firstAssistantMsg)
|
|
return fullHash, shortHash
|
|
}
|
|
|
|
// computeSessionHash returns "msg:" followed by the 16-digit lowercase hex
// FNV-1a (64-bit) hash of the non-empty inputs. Each non-empty input is
// written as "<tag><text>\n" with the tags "sys:", "usr:" and "ast:", in that
// order; empty inputs contribute nothing.
func computeSessionHash(systemPrompt, userMsg, assistantMsg string) string {
	sections := [...]struct{ tag, text string }{
		{"sys:", systemPrompt},
		{"usr:", userMsg},
		{"ast:", assistantMsg},
	}

	hasher := fnv.New64a()
	for _, section := range sections {
		if section.text == "" {
			continue
		}
		// hash.Hash.Write never returns an error.
		_, _ = hasher.Write([]byte(section.tag + section.text + "\n"))
	}
	return fmt.Sprintf("msg:%016x", hasher.Sum64())
}
|
|
|
|
// truncateString returns s cut to at most maxLen bytes. The cut is made on a
// byte boundary, not a character boundary.
func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	return s[:maxLen]
}
|
|
|
|
// extractMessageContent extracts text content from a message content field.
|
|
// Handles both string content and array content (multimodal messages).
|
|
// For array content, extracts text from all text-type elements.
|
|
func extractMessageContent(content gjson.Result) string {
|
|
// String content: "Hello world"
|
|
if content.Type == gjson.String {
|
|
return content.String()
|
|
}
|
|
|
|
// Array content: [{"type":"text","text":"Hello"},{"type":"image",...}]
|
|
if content.IsArray() {
|
|
var texts []string
|
|
content.ForEach(func(_, part gjson.Result) bool {
|
|
// Handle Claude format: {"type":"text","text":"content"}
|
|
if part.Get("type").String() == "text" {
|
|
if text := part.Get("text").String(); text != "" {
|
|
texts = append(texts, text)
|
|
}
|
|
}
|
|
// Handle OpenAI format: {"type":"text","text":"content"}
|
|
// Same structure as Claude, already handled above
|
|
return true
|
|
})
|
|
if len(texts) > 0 {
|
|
return strings.Join(texts, " ")
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func extractResponsesAPIContent(content gjson.Result) string {
|
|
if !content.IsArray() {
|
|
return ""
|
|
}
|
|
var texts []string
|
|
content.ForEach(func(_, part gjson.Result) bool {
|
|
partType := part.Get("type").String()
|
|
if partType == "input_text" || partType == "output_text" || partType == "text" {
|
|
if text := part.Get("text").String(); text != "" {
|
|
texts = append(texts, text)
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
if len(texts) > 0 {
|
|
return strings.Join(texts, " ")
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// extractSessionID is kept for backward compatibility.
|
|
// Deprecated: Use ExtractSessionID instead.
|
|
func extractSessionID(payload []byte) string {
|
|
return ExtractSessionID(nil, payload, nil)
|
|
}
|