fix(claude): only reverse-remap OAuth tool names that were forward-renamed
remapOAuthToolNames renames lowercase client-sent tools (e.g. `glob` → `Glob`) to Claude Code equivalents on OAuth requests to avoid tool-name fingerprinting. The reverse pass previously ran against a *global* reverse map and rewrote every tool_use block whose name matched any value in oauthToolRenameMap — regardless of what the client actually sent. For clients that send mixed casing (notably Amp CLI — `Bash`, `Read`, `Grep`, `Task` alongside `glob`, `skill`, etc.) this corrupted the response. Any forward rename in the request set the "renamed" flag, which then unconditionally lowercased every `Bash` in the response to `bash`. Amp's tool registry has `Bash`, not `bash`, so it rejected the tool_use with `tool "bash" is not allowed for smart mode` and tool execution failed. Fix: `remapOAuthToolNames` now returns a per-request map keyed on the upstream (TitleCase) name valued with the original client-sent name. The reverse functions take this map and only touch entries in it. Names the client sent in TitleCase pass through untouched in both directions. - Change remapOAuthToolNames signature from `([]byte, bool)` to `([]byte, map[string]string)`; populate at every rename site (tools[], tool_choice.name, message tool_use, tool_reference, nested tool_reference inside tool_result). - Change reverseRemapOAuthToolNames and reverseRemapOAuthToolNamesFromStreamLine to accept and consume the per-request map; remove the global oauthToolRenameReverseMap. - Update all three executor call sites (Execute, ExecuteStream direct passthrough, ExecuteStream translated) + count_tokens. - Add regression tests for the mixed-case scenario in both the non-streaming and SSE code paths.
This commit is contained in:
@@ -65,14 +65,13 @@ var oauthToolRenameMap = map[string]string{
|
|||||||
"notebookedit": "NotebookEdit",
|
"notebookedit": "NotebookEdit",
|
||||||
}
|
}
|
||||||
|
|
||||||
// oauthToolRenameReverseMap is the inverse of oauthToolRenameMap for response decoding.
|
// The reverse map is now computed per-request in remapOAuthToolNames so that
|
||||||
var oauthToolRenameReverseMap = func() map[string]string {
|
// only names the client actually caused us to rewrite are restored on the
|
||||||
m := make(map[string]string, len(oauthToolRenameMap))
|
// response. A global reverse map — as used previously — corrupted responses
|
||||||
for k, v := range oauthToolRenameMap {
|
// for clients that sent mixed casing (e.g. Amp CLI sends `Bash` TitleCase
|
||||||
m[v] = k
|
// alongside `glob` lowercase; the request flagged renames via `glob→Glob`,
|
||||||
}
|
// then the global reverse map incorrectly rewrote every `Bash` in the
|
||||||
return m
|
// response to `bash`, causing Amp to reject the tool_use as unknown).
|
||||||
}()
|
|
||||||
|
|
||||||
// oauthToolsToRemove lists tool names that must be stripped from OAuth requests
|
// oauthToolsToRemove lists tool names that must be stripped from OAuth requests
|
||||||
// even after remapping. Currently empty — all tools are mapped instead of removed.
|
// even after remapping. Currently empty — all tools are mapped instead of removed.
|
||||||
@@ -191,7 +190,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
bodyForTranslation := body
|
bodyForTranslation := body
|
||||||
bodyForUpstream := body
|
bodyForUpstream := body
|
||||||
oauthToken := isClaudeOAuthToken(apiKey)
|
oauthToken := isClaudeOAuthToken(apiKey)
|
||||||
oauthToolNamesRemapped := false
|
var oauthToolNamesReverseMap map[string]string
|
||||||
if oauthToken && !auth.ToolPrefixDisabled() {
|
if oauthToken && !auth.ToolPrefixDisabled() {
|
||||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
@@ -199,7 +198,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
// tools without official counterparts. This prevents Anthropic from
|
// tools without official counterparts. This prevents Anthropic from
|
||||||
// fingerprinting the request as third-party via tool naming patterns.
|
// fingerprinting the request as third-party via tool naming patterns.
|
||||||
if oauthToken {
|
if oauthToken {
|
||||||
bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
|
bodyForUpstream, oauthToolNamesReverseMap = remapOAuthToolNames(bodyForUpstream)
|
||||||
}
|
}
|
||||||
// Enable cch signing by default for OAuth tokens (not just experimental flag).
|
// Enable cch signing by default for OAuth tokens (not just experimental flag).
|
||||||
// Claude Code always computes cch; missing or invalid cch is a detectable fingerprint.
|
// Claude Code always computes cch; missing or invalid cch is a detectable fingerprint.
|
||||||
@@ -297,8 +296,8 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
// Reverse the OAuth tool name remap so the downstream client sees original names.
|
// Reverse the OAuth tool name remap so the downstream client sees original names.
|
||||||
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
|
if isClaudeOAuthToken(apiKey) && len(oauthToolNamesReverseMap) > 0 {
|
||||||
data = reverseRemapOAuthToolNames(data)
|
data = reverseRemapOAuthToolNames(data, oauthToolNamesReverseMap)
|
||||||
}
|
}
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(
|
out := sdktranslator.TranslateNonStream(
|
||||||
@@ -373,7 +372,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
bodyForTranslation := body
|
bodyForTranslation := body
|
||||||
bodyForUpstream := body
|
bodyForUpstream := body
|
||||||
oauthToken := isClaudeOAuthToken(apiKey)
|
oauthToken := isClaudeOAuthToken(apiKey)
|
||||||
oauthToolNamesRemapped := false
|
var oauthToolNamesReverseMap map[string]string
|
||||||
if oauthToken && !auth.ToolPrefixDisabled() {
|
if oauthToken && !auth.ToolPrefixDisabled() {
|
||||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
@@ -381,7 +380,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
// tools without official counterparts. This prevents Anthropic from
|
// tools without official counterparts. This prevents Anthropic from
|
||||||
// fingerprinting the request as third-party via tool naming patterns.
|
// fingerprinting the request as third-party via tool naming patterns.
|
||||||
if oauthToken {
|
if oauthToken {
|
||||||
bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
|
bodyForUpstream, oauthToolNamesReverseMap = remapOAuthToolNames(bodyForUpstream)
|
||||||
}
|
}
|
||||||
// Enable cch signing by default for OAuth tokens (not just experimental flag).
|
// Enable cch signing by default for OAuth tokens (not just experimental flag).
|
||||||
if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) {
|
if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) {
|
||||||
@@ -475,8 +474,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
|
if isClaudeOAuthToken(apiKey) && len(oauthToolNamesReverseMap) > 0 {
|
||||||
line = reverseRemapOAuthToolNamesFromStreamLine(line)
|
line = reverseRemapOAuthToolNamesFromStreamLine(line, oauthToolNamesReverseMap)
|
||||||
}
|
}
|
||||||
// Forward the line as-is to preserve SSE format
|
// Forward the line as-is to preserve SSE format
|
||||||
cloned := make([]byte, len(line)+1)
|
cloned := make([]byte, len(line)+1)
|
||||||
@@ -505,8 +504,8 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||||
}
|
}
|
||||||
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
|
if isClaudeOAuthToken(apiKey) && len(oauthToolNamesReverseMap) > 0 {
|
||||||
line = reverseRemapOAuthToolNamesFromStreamLine(line)
|
line = reverseRemapOAuthToolNamesFromStreamLine(line, oauthToolNamesReverseMap)
|
||||||
}
|
}
|
||||||
chunks := sdktranslator.TranslateStream(
|
chunks := sdktranslator.TranslateStream(
|
||||||
ctx,
|
ctx,
|
||||||
@@ -1009,8 +1008,25 @@ func isClaudeOAuthToken(apiKey string) bool {
|
|||||||
// It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference
|
// It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference
|
||||||
// references in messages. Removed tools' corresponding tool_result blocks are preserved
|
// references in messages. Removed tools' corresponding tool_result blocks are preserved
|
||||||
// (they just become orphaned, which is safe for Claude).
|
// (they just become orphaned, which is safe for Claude).
|
||||||
func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
//
|
||||||
renamed := false
|
// The returned map is keyed on the upstream (TitleCase) name and maps to the
|
||||||
|
// client-supplied original name. Callers MUST pass this map to the reverse
|
||||||
|
// functions so only names the client actually caused us to rewrite are restored
|
||||||
|
// on the response. A global reverse map (the previous implementation) incorrectly
|
||||||
|
// rewrote names the client originally sent in TitleCase (e.g. Amp CLI's `Bash`)
|
||||||
|
// when any OTHER tool in the same request triggered a forward rename (e.g.
|
||||||
|
// Amp's `glob`→`Glob`), because the global reverse map contained `Bash`→`bash`
|
||||||
|
// regardless of what the client originally sent.
|
||||||
|
func remapOAuthToolNames(body []byte) ([]byte, map[string]string) {
|
||||||
|
reverseMap := make(map[string]string)
|
||||||
|
recordRename := func(original, renamed string) {
|
||||||
|
// Preserve the first-seen original name if the same upstream name is
|
||||||
|
// produced from multiple call sites; they all map back identically.
|
||||||
|
if _, exists := reverseMap[renamed]; !exists {
|
||||||
|
reverseMap[renamed] = original
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 1. Rewrite tools array in a single pass (if present).
|
// 1. Rewrite tools array in a single pass (if present).
|
||||||
// IMPORTANT: do not mutate names first and then rebuild from an older gjson
|
// IMPORTANT: do not mutate names first and then rebuild from an older gjson
|
||||||
// snapshot. gjson results are snapshots of the original bytes; rebuilding from a
|
// snapshot. gjson results are snapshots of the original bytes; rebuilding from a
|
||||||
@@ -1043,7 +1059,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
|||||||
updatedTool, err := sjson.Set(toolJSON, "name", newName)
|
updatedTool, err := sjson.Set(toolJSON, "name", newName)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
toolJSON = updatedTool
|
toolJSON = updatedTool
|
||||||
renamed = true
|
recordRename(name, newName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1068,7 +1084,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
|||||||
body, _ = sjson.DeleteBytes(body, "tool_choice")
|
body, _ = sjson.DeleteBytes(body, "tool_choice")
|
||||||
} else if newName, ok := oauthToolRenameMap[tcName]; ok && newName != tcName {
|
} else if newName, ok := oauthToolRenameMap[tcName]; ok && newName != tcName {
|
||||||
body, _ = sjson.SetBytes(body, "tool_choice.name", newName)
|
body, _ = sjson.SetBytes(body, "tool_choice.name", newName)
|
||||||
renamed = true
|
recordRename(tcName, newName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1088,14 +1104,14 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
|||||||
if newName, ok := oauthToolRenameMap[name]; ok && newName != name {
|
if newName, ok := oauthToolRenameMap[name]; ok && newName != name {
|
||||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||||
body, _ = sjson.SetBytes(body, path, newName)
|
body, _ = sjson.SetBytes(body, path, newName)
|
||||||
renamed = true
|
recordRename(name, newName)
|
||||||
}
|
}
|
||||||
case "tool_reference":
|
case "tool_reference":
|
||||||
toolName := part.Get("tool_name").String()
|
toolName := part.Get("tool_name").String()
|
||||||
if newName, ok := oauthToolRenameMap[toolName]; ok && newName != toolName {
|
if newName, ok := oauthToolRenameMap[toolName]; ok && newName != toolName {
|
||||||
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
|
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
|
||||||
body, _ = sjson.SetBytes(body, path, newName)
|
body, _ = sjson.SetBytes(body, path, newName)
|
||||||
renamed = true
|
recordRename(toolName, newName)
|
||||||
}
|
}
|
||||||
case "tool_result":
|
case "tool_result":
|
||||||
// Handle nested tool_reference blocks inside tool_result.content[]
|
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||||
@@ -1109,7 +1125,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
|||||||
if newName, ok := oauthToolRenameMap[nestedToolName]; ok && newName != nestedToolName {
|
if newName, ok := oauthToolRenameMap[nestedToolName]; ok && newName != nestedToolName {
|
||||||
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
|
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
|
||||||
body, _ = sjson.SetBytes(body, nestedPath, newName)
|
body, _ = sjson.SetBytes(body, nestedPath, newName)
|
||||||
renamed = true
|
recordRename(nestedToolName, newName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
@@ -1122,13 +1138,16 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return body, renamed
|
return body, reverseMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses.
|
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses
|
||||||
// It maps Claude Code TitleCase names back to the original lowercase names so the
|
// using the per-request map produced by remapOAuthToolNames. Names the client sent
|
||||||
// downstream client receives tool names it recognizes.
|
// that were NOT forward-renamed are passed through unchanged.
|
||||||
func reverseRemapOAuthToolNames(body []byte) []byte {
|
func reverseRemapOAuthToolNames(body []byte, reverseMap map[string]string) []byte {
|
||||||
|
if len(reverseMap) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
content := gjson.GetBytes(body, "content")
|
content := gjson.GetBytes(body, "content")
|
||||||
if !content.Exists() || !content.IsArray() {
|
if !content.Exists() || !content.IsArray() {
|
||||||
return body
|
return body
|
||||||
@@ -1138,13 +1157,13 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
|
|||||||
switch partType {
|
switch partType {
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
name := part.Get("name").String()
|
name := part.Get("name").String()
|
||||||
if origName, ok := oauthToolRenameReverseMap[name]; ok {
|
if origName, ok := reverseMap[name]; ok {
|
||||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||||
body, _ = sjson.SetBytes(body, path, origName)
|
body, _ = sjson.SetBytes(body, path, origName)
|
||||||
}
|
}
|
||||||
case "tool_reference":
|
case "tool_reference":
|
||||||
toolName := part.Get("tool_name").String()
|
toolName := part.Get("tool_name").String()
|
||||||
if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
|
if origName, ok := reverseMap[toolName]; ok {
|
||||||
path := fmt.Sprintf("content.%d.tool_name", index.Int())
|
path := fmt.Sprintf("content.%d.tool_name", index.Int())
|
||||||
body, _ = sjson.SetBytes(body, path, origName)
|
body, _ = sjson.SetBytes(body, path, origName)
|
||||||
}
|
}
|
||||||
@@ -1154,8 +1173,12 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
|
|||||||
return body
|
return body
|
||||||
}
|
}
|
||||||
|
|
||||||
// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE stream lines.
|
// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE
|
||||||
func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
|
// stream lines, using the per-request reverseMap produced by remapOAuthToolNames.
|
||||||
|
func reverseRemapOAuthToolNamesFromStreamLine(line []byte, reverseMap map[string]string) []byte {
|
||||||
|
if len(reverseMap) == 0 {
|
||||||
|
return line
|
||||||
|
}
|
||||||
payload := helps.JSONPayload(line)
|
payload := helps.JSONPayload(line)
|
||||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||||
return line
|
return line
|
||||||
@@ -1173,7 +1196,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
|
|||||||
switch blockType {
|
switch blockType {
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
name := contentBlock.Get("name").String()
|
name := contentBlock.Get("name").String()
|
||||||
if origName, ok := oauthToolRenameReverseMap[name]; ok {
|
if origName, ok := reverseMap[name]; ok {
|
||||||
updated, err = sjson.SetBytes(payload, "content_block.name", origName)
|
updated, err = sjson.SetBytes(payload, "content_block.name", origName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return line
|
return line
|
||||||
@@ -1183,7 +1206,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
|
|||||||
}
|
}
|
||||||
case "tool_reference":
|
case "tool_reference":
|
||||||
toolName := contentBlock.Get("tool_name").String()
|
toolName := contentBlock.Get("tool_name").String()
|
||||||
if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
|
if origName, ok := reverseMap[toolName]; ok {
|
||||||
updated, err = sjson.SetBytes(payload, "content_block.tool_name", origName)
|
updated, err = sjson.SetBytes(payload, "content_block.tool_name", origName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return line
|
return line
|
||||||
|
|||||||
@@ -1989,19 +1989,16 @@ func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOrigina
|
|||||||
func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
|
func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
|
||||||
body := []byte(`{"tools":[{"name":"Bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
body := []byte(`{"tools":[{"name":"Bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
out, renamed := remapOAuthToolNames(body)
|
out, reverseMap := remapOAuthToolNames(body)
|
||||||
if renamed {
|
if len(reverseMap) != 0 {
|
||||||
t.Fatalf("renamed = true, want false")
|
t.Fatalf("reverseMap = %v, want empty", reverseMap)
|
||||||
}
|
}
|
||||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
|
||||||
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
|
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
||||||
reversed := resp
|
reversed := reverseRemapOAuthToolNames(resp, reverseMap)
|
||||||
if renamed {
|
|
||||||
reversed = reverseRemapOAuthToolNames(resp)
|
|
||||||
}
|
|
||||||
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
|
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
|
||||||
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
|
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
|
||||||
}
|
}
|
||||||
@@ -2010,20 +2007,86 @@ func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
|
|||||||
func TestRemapOAuthToolNames_Lowercase_ReverseApplied(t *testing.T) {
|
func TestRemapOAuthToolNames_Lowercase_ReverseApplied(t *testing.T) {
|
||||||
body := []byte(`{"tools":[{"name":"bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
body := []byte(`{"tools":[{"name":"bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||||
|
|
||||||
out, renamed := remapOAuthToolNames(body)
|
out, reverseMap := remapOAuthToolNames(body)
|
||||||
if !renamed {
|
if reverseMap["Bash"] != "bash" {
|
||||||
t.Fatalf("renamed = false, want true")
|
t.Fatalf("reverseMap = %v, want entry Bash->bash", reverseMap)
|
||||||
}
|
}
|
||||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
|
||||||
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
|
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
||||||
reversed := resp
|
reversed := reverseRemapOAuthToolNames(resp, reverseMap)
|
||||||
if renamed {
|
|
||||||
reversed = reverseRemapOAuthToolNames(resp)
|
|
||||||
}
|
|
||||||
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "bash" {
|
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "bash" {
|
||||||
t.Fatalf("content.0.name = %q, want %q", got, "bash")
|
t.Fatalf("content.0.name = %q, want %q", got, "bash")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed is the regression
|
||||||
|
// test for a case where a single request contains both a TitleCase tool (which
|
||||||
|
// must pass through unchanged) and a lowercase tool that we forward-rename.
|
||||||
|
// Before the fix, triggering ANY forward rename caused the reverse pass to
|
||||||
|
// lowercase every TitleCase tool in the response using a global reverse map,
|
||||||
|
// corrupting tool names the client originally sent in TitleCase (notably Amp
|
||||||
|
// CLI's `Bash`, which its registry lookup cannot find as `bash`).
|
||||||
|
func TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed(t *testing.T) {
|
||||||
|
body := []byte(`{"tools":[` +
|
||||||
|
`{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
|
||||||
|
`{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
|
||||||
|
`]}`)
|
||||||
|
|
||||||
|
out, reverseMap := remapOAuthToolNames(body)
|
||||||
|
|
||||||
|
// Forward: TitleCase `Bash` is not a forward-map key, must pass through.
|
||||||
|
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
|
||||||
|
t.Fatalf("tools.0.name = %q, want %q (TitleCase tool must not be renamed)", got, "Bash")
|
||||||
|
}
|
||||||
|
// Forward: `glob` is a forward-map key, upstream sees `Glob`.
|
||||||
|
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "Glob" {
|
||||||
|
t.Fatalf("tools.1.name = %q, want %q", got, "Glob")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reverse map records ONLY the rename that happened.
|
||||||
|
if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
|
||||||
|
t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upstream responds with a `Bash` tool_use. Since we never renamed `Bash`,
|
||||||
|
// reverseRemap MUST leave it alone.
|
||||||
|
bashResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
||||||
|
reversed := reverseRemapOAuthToolNames(bashResp, reverseMap)
|
||||||
|
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
|
||||||
|
t.Fatalf("content.0.name = %q, want %q (Bash must be preserved; was never forward-renamed)", got, "Bash")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upstream responds with a `Glob` tool_use. Since we renamed `glob`→`Glob`,
|
||||||
|
// reverseRemap MUST restore the original `glob`.
|
||||||
|
globResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_02","name":"Glob","input":{"filePattern":"**/*.go"}}]}`)
|
||||||
|
reversed = reverseRemapOAuthToolNames(globResp, reverseMap)
|
||||||
|
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "glob" {
|
||||||
|
t.Fatalf("content.0.name = %q, want %q (Glob must be restored to client's original `glob`)", got, "glob")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap guards the
|
||||||
|
// SSE streaming code path against the same mixed-case bug.
|
||||||
|
func TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap(t *testing.T) {
|
||||||
|
reverseMap := map[string]string{"Glob": "glob"}
|
||||||
|
|
||||||
|
// Bash block was never renamed, must pass through as-is.
|
||||||
|
bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}}}`)
|
||||||
|
out := reverseRemapOAuthToolNamesFromStreamLine(bashLine, reverseMap)
|
||||||
|
if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
|
||||||
|
t.Fatalf("Bash should be preserved, got: %s", string(out))
|
||||||
|
}
|
||||||
|
if bytes.Contains(out, []byte(`"name":"bash"`)) {
|
||||||
|
t.Fatalf("Bash must not be lowercased, got: %s", string(out))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Glob block IS in the reverseMap, must be restored to `glob`.
|
||||||
|
globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"Glob","input":{}}}`)
|
||||||
|
out = reverseRemapOAuthToolNamesFromStreamLine(globLine, reverseMap)
|
||||||
|
if !bytes.Contains(out, []byte(`"name":"glob"`)) {
|
||||||
|
t.Fatalf("Glob should be restored to glob, got: %s", string(out))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user