fix(executor): handle OAuth tool name remapping with rename detection and add tests

Closes: #2656
This commit is contained in:
Luis Pater
2026-04-10 21:54:59 +08:00
parent 65ce86338b
commit 5ab9afac83
2 changed files with 96 additions and 46 deletions
+21 -13
View File
@@ -192,6 +192,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
bodyForTranslation := body bodyForTranslation := body
bodyForUpstream := body bodyForUpstream := body
oauthToken := isClaudeOAuthToken(apiKey) oauthToken := isClaudeOAuthToken(apiKey)
oauthToolNamesRemapped := false
if oauthToken && !auth.ToolPrefixDisabled() { if oauthToken && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix) bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
} }
@@ -199,7 +200,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
// tools without official counterparts. This prevents Anthropic from // tools without official counterparts. This prevents Anthropic from
// fingerprinting the request as third-party via tool naming patterns. // fingerprinting the request as third-party via tool naming patterns.
if oauthToken { if oauthToken {
bodyForUpstream = remapOAuthToolNames(bodyForUpstream) bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
} }
// Enable cch signing by default for OAuth tokens (not just experimental flag). // Enable cch signing by default for OAuth tokens (not just experimental flag).
// Claude Code always computes cch; missing or invalid cch is a detectable fingerprint. // Claude Code always computes cch; missing or invalid cch is a detectable fingerprint.
@@ -297,7 +298,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix) data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
} }
// Reverse the OAuth tool name remap so the downstream client sees original names. // Reverse the OAuth tool name remap so the downstream client sees original names.
if isClaudeOAuthToken(apiKey) { if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
data = reverseRemapOAuthToolNames(data) data = reverseRemapOAuthToolNames(data)
} }
var param any var param any
@@ -373,6 +374,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
bodyForTranslation := body bodyForTranslation := body
bodyForUpstream := body bodyForUpstream := body
oauthToken := isClaudeOAuthToken(apiKey) oauthToken := isClaudeOAuthToken(apiKey)
oauthToolNamesRemapped := false
if oauthToken && !auth.ToolPrefixDisabled() { if oauthToken && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix) bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
} }
@@ -380,7 +382,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
// tools without official counterparts. This prevents Anthropic from // tools without official counterparts. This prevents Anthropic from
// fingerprinting the request as third-party via tool naming patterns. // fingerprinting the request as third-party via tool naming patterns.
if oauthToken { if oauthToken {
bodyForUpstream = remapOAuthToolNames(bodyForUpstream) bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
} }
// Enable cch signing by default for OAuth tokens (not just experimental flag). // Enable cch signing by default for OAuth tokens (not just experimental flag).
if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) { if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) {
@@ -474,7 +476,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() { if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix) line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
} }
if isClaudeOAuthToken(apiKey) { if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
line = reverseRemapOAuthToolNamesFromStreamLine(line) line = reverseRemapOAuthToolNamesFromStreamLine(line)
} }
// Forward the line as-is to preserve SSE format // Forward the line as-is to preserve SSE format
@@ -504,7 +506,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() { if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix) line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
} }
if isClaudeOAuthToken(apiKey) { if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
line = reverseRemapOAuthToolNamesFromStreamLine(line) line = reverseRemapOAuthToolNamesFromStreamLine(line)
} }
chunks := sdktranslator.TranslateStream( chunks := sdktranslator.TranslateStream(
@@ -561,7 +563,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
} }
// Remap tool names for OAuth token requests to avoid third-party fingerprinting. // Remap tool names for OAuth token requests to avoid third-party fingerprinting.
if isClaudeOAuthToken(apiKey) { if isClaudeOAuthToken(apiKey) {
body = remapOAuthToolNames(body) body, _ = remapOAuthToolNames(body)
} }
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL) url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
@@ -1018,7 +1020,8 @@ func isClaudeOAuthToken(apiKey string) bool {
// It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference // It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference
// references in messages. Removed tools' corresponding tool_result blocks are preserved // references in messages. Removed tools' corresponding tool_result blocks are preserved
// (they just become orphaned, which is safe for Claude). // (they just become orphaned, which is safe for Claude).
func remapOAuthToolNames(body []byte) []byte { func remapOAuthToolNames(body []byte) ([]byte, bool) {
renamed := false
// 1. Rewrite tools array in a single pass (if present). // 1. Rewrite tools array in a single pass (if present).
// IMPORTANT: do not mutate names first and then rebuild from an older gjson // IMPORTANT: do not mutate names first and then rebuild from an older gjson
// snapshot. gjson results are snapshots of the original bytes; rebuilding from a // snapshot. gjson results are snapshots of the original bytes; rebuilding from a
@@ -1047,10 +1050,11 @@ func remapOAuthToolNames(body []byte) []byte {
} }
toolJSON := tool.Raw toolJSON := tool.Raw
if newName, ok := oauthToolRenameMap[name]; ok { if newName, ok := oauthToolRenameMap[name]; ok && newName != name {
updatedTool, err := sjson.Set(toolJSON, "name", newName) updatedTool, err := sjson.Set(toolJSON, "name", newName)
if err == nil { if err == nil {
toolJSON = updatedTool toolJSON = updatedTool
renamed = true
} }
} }
@@ -1073,8 +1077,9 @@ func remapOAuthToolNames(body []byte) []byte {
// The chosen tool was removed from the tools array, so drop tool_choice to // The chosen tool was removed from the tools array, so drop tool_choice to
// keep the payload internally consistent and fall back to normal auto tool use. // keep the payload internally consistent and fall back to normal auto tool use.
body, _ = sjson.DeleteBytes(body, "tool_choice") body, _ = sjson.DeleteBytes(body, "tool_choice")
} else if newName, ok := oauthToolRenameMap[tcName]; ok { } else if newName, ok := oauthToolRenameMap[tcName]; ok && newName != tcName {
body, _ = sjson.SetBytes(body, "tool_choice.name", newName) body, _ = sjson.SetBytes(body, "tool_choice.name", newName)
renamed = true
} }
} }
@@ -1091,15 +1096,17 @@ func remapOAuthToolNames(body []byte) []byte {
switch partType { switch partType {
case "tool_use": case "tool_use":
name := part.Get("name").String() name := part.Get("name").String()
if newName, ok := oauthToolRenameMap[name]; ok { if newName, ok := oauthToolRenameMap[name]; ok && newName != name {
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int()) path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, newName) body, _ = sjson.SetBytes(body, path, newName)
renamed = true
} }
case "tool_reference": case "tool_reference":
toolName := part.Get("tool_name").String() toolName := part.Get("tool_name").String()
if newName, ok := oauthToolRenameMap[toolName]; ok { if newName, ok := oauthToolRenameMap[toolName]; ok && newName != toolName {
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int()) path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, newName) body, _ = sjson.SetBytes(body, path, newName)
renamed = true
} }
case "tool_result": case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[] // Handle nested tool_reference blocks inside tool_result.content[]
@@ -1110,9 +1117,10 @@ func remapOAuthToolNames(body []byte) []byte {
nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool { nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
if nestedPart.Get("type").String() == "tool_reference" { if nestedPart.Get("type").String() == "tool_reference" {
nestedToolName := nestedPart.Get("tool_name").String() nestedToolName := nestedPart.Get("tool_name").String()
if newName, ok := oauthToolRenameMap[nestedToolName]; ok { if newName, ok := oauthToolRenameMap[nestedToolName]; ok && newName != nestedToolName {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int()) nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, newName) body, _ = sjson.SetBytes(body, nestedPath, newName)
renamed = true
} }
} }
return true return true
@@ -1125,7 +1133,7 @@ func remapOAuthToolNames(body []byte) []byte {
}) })
} }
return body return body, renamed
} }
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses. // reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses.
@@ -1949,3 +1949,45 @@ func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOrigina
t.Fatalf("temperature = %v, want 0", got) t.Fatalf("temperature = %v, want 0", got)
} }
} }
func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
body := []byte(`{"tools":[{"name":"Bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
out, renamed := remapOAuthToolNames(body)
if renamed {
t.Fatalf("renamed = true, want false")
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
}
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := resp
if renamed {
reversed = reverseRemapOAuthToolNames(resp)
}
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
}
}
func TestRemapOAuthToolNames_Lowercase_ReverseApplied(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
out, renamed := remapOAuthToolNames(body)
if !renamed {
t.Fatalf("renamed = false, want true")
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
}
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := resp
if renamed {
reversed = reverseRemapOAuthToolNames(resp)
}
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "bash" {
t.Fatalf("content.0.name = %q, want %q", got, "bash")
}
}