Merge pull request #2972 from XYenon/feat/amp-thread-id
feat: support X-Amp-Thread-Id for session affinity
This commit is contained in:
@@ -223,6 +223,19 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
|||||||
return meta
|
return meta
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// headersFromContext extracts the original HTTP request headers from the gin context
|
||||||
|
// embedded in the provided context. This allows session affinity selectors to read
|
||||||
|
// client headers like X-Amp-Thread-Id.
|
||||||
|
func headersFromContext(ctx context.Context) http.Header {
|
||||||
|
if ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||||
|
return ginCtx.Request.Header.Clone()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func pinnedAuthIDFromContext(ctx context.Context) string {
|
func pinnedAuthIDFromContext(ctx context.Context) string {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return ""
|
return ""
|
||||||
@@ -508,6 +521,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
|||||||
Alt: alt,
|
Alt: alt,
|
||||||
OriginalRequest: rawJSON,
|
OriginalRequest: rawJSON,
|
||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
|
Headers: headersFromContext(ctx),
|
||||||
}
|
}
|
||||||
opts.Metadata = reqMeta
|
opts.Metadata = reqMeta
|
||||||
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
|
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
|
||||||
@@ -555,6 +569,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
|||||||
Alt: alt,
|
Alt: alt,
|
||||||
OriginalRequest: rawJSON,
|
OriginalRequest: rawJSON,
|
||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
|
Headers: headersFromContext(ctx),
|
||||||
}
|
}
|
||||||
opts.Metadata = reqMeta
|
opts.Metadata = reqMeta
|
||||||
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
|
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
|
||||||
@@ -606,6 +621,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
Alt: alt,
|
Alt: alt,
|
||||||
OriginalRequest: rawJSON,
|
OriginalRequest: rawJSON,
|
||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
|
Headers: headersFromContext(ctx),
|
||||||
}
|
}
|
||||||
opts.Metadata = reqMeta
|
opts.Metadata = reqMeta
|
||||||
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||||
|
|||||||
@@ -570,9 +570,10 @@ func (s *SessionAffinitySelector) InvalidateAuth(authID string) {
|
|||||||
// Priority order:
|
// Priority order:
|
||||||
// 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority for Claude Code clients
|
// 1. metadata.user_id (Claude Code format with _session_{uuid}) - highest priority for Claude Code clients
|
||||||
// 2. X-Session-ID header
|
// 2. X-Session-ID header
|
||||||
// 3. metadata.user_id (non-Claude Code format)
|
// 3. X-Amp-Thread-Id header (Amp CLI thread ID)
|
||||||
// 4. conversation_id field in request body
|
// 4. metadata.user_id (non-Claude Code format)
|
||||||
// 5. Stable hash from first few messages content (fallback)
|
// 5. conversation_id field in request body
|
||||||
|
// 6. Stable hash from first few messages content (fallback)
|
||||||
func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string {
|
func ExtractSessionID(headers http.Header, payload []byte, metadata map[string]any) string {
|
||||||
primary, _ := extractSessionIDs(headers, payload, metadata)
|
primary, _ := extractSessionIDs(headers, payload, metadata)
|
||||||
return primary
|
return primary
|
||||||
@@ -608,22 +609,29 @@ func extractSessionIDs(headers http.Header, payload []byte, metadata map[string]
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 3. X-Amp-Thread-Id header (Amp CLI thread ID)
|
||||||
|
if headers != nil {
|
||||||
|
if tid := headers.Get("X-Amp-Thread-Id"); tid != "" {
|
||||||
|
return "amp:" + tid, ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(payload) == 0 {
|
if len(payload) == 0 {
|
||||||
return "", ""
|
return "", ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. metadata.user_id (non-Claude Code format)
|
// 4. metadata.user_id (non-Claude Code format)
|
||||||
userID := gjson.GetBytes(payload, "metadata.user_id").String()
|
userID := gjson.GetBytes(payload, "metadata.user_id").String()
|
||||||
if userID != "" {
|
if userID != "" {
|
||||||
return "user:" + userID, ""
|
return "user:" + userID, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. conversation_id field
|
// 5. conversation_id field
|
||||||
if convID := gjson.GetBytes(payload, "conversation_id").String(); convID != "" {
|
if convID := gjson.GetBytes(payload, "conversation_id").String(); convID != "" {
|
||||||
return "conv:" + convID, ""
|
return "conv:" + convID, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// 5. Hash-based fallback from message content
|
// 6. Hash-based fallback from message content
|
||||||
return extractMessageHashIDs(payload)
|
return extractMessageHashIDs(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -776,6 +776,46 @@ func TestExtractSessionID_Headers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractSessionID_AmpThreadId(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("X-Amp-Thread-Id", "T-7873e6bd-6354-4a9a-be2c-c7702c6e1b64")
|
||||||
|
|
||||||
|
got := ExtractSessionID(headers, nil, nil)
|
||||||
|
want := "amp:T-7873e6bd-6354-4a9a-be2c-c7702c6e1b64"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("ExtractSessionID() with X-Amp-Thread-Id = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractSessionID_AmpThreadIdLowerPriority verifies X-Amp-Thread-Id is lower
|
||||||
|
// priority than Claude Code metadata.user_id but higher than conversation_id.
|
||||||
|
func TestExtractSessionID_AmpThreadIdPriority(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// X-Amp-Thread-Id should be used when no Claude Code user_id is present
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("X-Amp-Thread-Id", "T-priority-test")
|
||||||
|
|
||||||
|
payload := []byte(`{"conversation_id":"conv-12345"}`)
|
||||||
|
got := ExtractSessionID(headers, payload, nil)
|
||||||
|
want := "amp:T-priority-test"
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("ExtractSessionID() = %q, want %q (Amp thread ID should take priority over conversation_id)", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claude Code user_id should take priority over X-Amp-Thread-Id
|
||||||
|
headers2 := make(http.Header)
|
||||||
|
headers2.Set("X-Amp-Thread-Id", "T-priority-test")
|
||||||
|
payload2 := []byte(`{"metadata":{"user_id":"user_xxx_account__session_ac980658-63bd-4fb3-97ba-8da64cb1e344"}}`)
|
||||||
|
got2 := ExtractSessionID(headers2, payload2, nil)
|
||||||
|
want2 := "claude:ac980658-63bd-4fb3-97ba-8da64cb1e344"
|
||||||
|
if got2 != want2 {
|
||||||
|
t.Errorf("ExtractSessionID() = %q, want %q (Claude Code should take priority over Amp thread ID)", got2, want2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestExtractSessionID_IdempotencyKey verifies that idempotency_key is intentionally
|
// TestExtractSessionID_IdempotencyKey verifies that idempotency_key is intentionally
|
||||||
// ignored for session affinity (it's auto-generated per-request, causing cache misses).
|
// ignored for session affinity (it's auto-generated per-request, causing cache misses).
|
||||||
func TestExtractSessionID_IdempotencyKey(t *testing.T) {
|
func TestExtractSessionID_IdempotencyKey(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user