fix(amp): preserve lowercase glob tool name
This commit is contained in:
@@ -123,6 +123,52 @@ func (rw *ResponseRewriter) Flush() {
|
|||||||
|
|
||||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||||
|
|
||||||
|
// ampCanonicalToolNames maps tool names to the exact casing expected by the
|
||||||
|
// Amp mode tool whitelist (case-sensitive match).
|
||||||
|
var ampCanonicalToolNames = map[string]string{
|
||||||
|
"bash": "Bash",
|
||||||
|
"read": "Read",
|
||||||
|
"grep": "Grep",
|
||||||
|
"glob": "glob",
|
||||||
|
"task": "Task",
|
||||||
|
"check": "Check",
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing.
|
||||||
|
// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash")
|
||||||
|
// which causes Amp's case-sensitive mode whitelist to reject them.
|
||||||
|
func normalizeAmpToolNames(data []byte) []byte {
|
||||||
|
// Non-streaming: content[].name in tool_use blocks
|
||||||
|
for index, block := range gjson.GetBytes(data, "content").Array() {
|
||||||
|
if block.Get("type").String() != "tool_use" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
name := block.Get("name").String()
|
||||||
|
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
|
||||||
|
path := fmt.Sprintf("content.%d.name", index)
|
||||||
|
var err error
|
||||||
|
data, err = sjson.SetBytes(data, path, canonical)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Streaming: content_block.name in content_block_start events
|
||||||
|
if gjson.GetBytes(data, "content_block.type").String() == "tool_use" {
|
||||||
|
name := gjson.GetBytes(data, "content_block.name").String()
|
||||||
|
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
|
||||||
|
var err error
|
||||||
|
data, err = sjson.SetBytes(data, "content_block.name", canonical)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
|
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
|
||||||
// in API responses so that the Amp TUI does not crash on P.signature.length.
|
// in API responses so that the Amp TUI does not crash on P.signature.length.
|
||||||
func ensureAmpSignature(data []byte) []byte {
|
func ensureAmpSignature(data []byte) []byte {
|
||||||
@@ -179,6 +225,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
|||||||
|
|
||||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||||
data = ensureAmpSignature(data)
|
data = ensureAmpSignature(data)
|
||||||
|
data = normalizeAmpToolNames(data)
|
||||||
data = rw.suppressAmpThinking(data)
|
data = rw.suppressAmpThinking(data)
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return data
|
return data
|
||||||
@@ -278,6 +325,9 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
|||||||
// Inject empty signature where needed
|
// Inject empty signature where needed
|
||||||
data = ensureAmpSignature(data)
|
data = ensureAmpSignature(data)
|
||||||
|
|
||||||
|
// Normalize tool names to canonical casing
|
||||||
|
data = normalizeAmpToolNames(data)
|
||||||
|
|
||||||
// Rewrite model name
|
// Rewrite model name
|
||||||
if rw.originalModel != "" {
|
if rw.originalModel != "" {
|
||||||
for _, path := range modelFieldPaths {
|
for _, path := range modelFieldPaths {
|
||||||
|
|||||||
@@ -175,6 +175,57 @@ func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) {
|
||||||
|
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`)
|
||||||
|
result := normalizeAmpToolNames(input)
|
||||||
|
|
||||||
|
if !contains(result, []byte(`"name":"Bash"`)) {
|
||||||
|
t.Errorf("expected bash->Bash, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"name":"Read"`)) {
|
||||||
|
t.Errorf("expected read->Read, got %s", string(result))
|
||||||
|
}
|
||||||
|
if contains(result, []byte(`"name":"bash"`)) {
|
||||||
|
t.Errorf("expected lowercase bash to be replaced, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeAmpToolNames_Streaming(t *testing.T) {
|
||||||
|
input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`)
|
||||||
|
result := normalizeAmpToolNames(input)
|
||||||
|
|
||||||
|
if !contains(result, []byte(`"name":"Grep"`)) {
|
||||||
|
t.Errorf("expected grep->Grep in streaming, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) {
|
||||||
|
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
||||||
|
result := normalizeAmpToolNames(input)
|
||||||
|
|
||||||
|
if string(result) != string(input) {
|
||||||
|
t.Errorf("expected no modification for correctly-cased tool, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) {
|
||||||
|
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
|
||||||
|
result := normalizeAmpToolNames(input)
|
||||||
|
|
||||||
|
if string(result) != string(input) {
|
||||||
|
t.Errorf("expected glob to remain lowercase, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) {
|
||||||
|
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`)
|
||||||
|
result := normalizeAmpToolNames(input)
|
||||||
|
|
||||||
|
if string(result) != string(input) {
|
||||||
|
t.Errorf("expected no modification for unknown tool, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func contains(data, substr []byte) bool {
|
func contains(data, substr []byte) bool {
|
||||||
for i := 0; i <= len(data)-len(substr); i++ {
|
for i := 0; i <= len(data)-len(substr); i++ {
|
||||||
if string(data[i:i+len(substr)]) == string(substr) {
|
if string(data[i:i+len(substr)]) == string(substr) {
|
||||||
|
|||||||
Reference in New Issue
Block a user