Merge pull request #2522 from aikins01/fix/strip-tool-use-signature
fix(amp): strip signature from tool_use blocks before forwarding to Claude
This commit is contained in:
@@ -2,6 +2,7 @@ package amp
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -290,8 +291,10 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
|
// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
|
||||||
// from the messages array in a request body before forwarding to the upstream API.
|
// and strips the proxy-injected "signature" field from tool_use blocks in the messages
|
||||||
// This prevents 400 errors from the API which requires valid signatures on thinking blocks.
|
// array before forwarding to the upstream API.
|
||||||
|
// This prevents 400 errors from the API which requires valid signatures on thinking
|
||||||
|
// blocks and does not accept a signature field on tool_use blocks.
|
||||||
func SanitizeAmpRequestBody(body []byte) []byte {
|
func SanitizeAmpRequestBody(body []byte) []byte {
|
||||||
messages := gjson.GetBytes(body, "messages")
|
messages := gjson.GetBytes(body, "messages")
|
||||||
if !messages.Exists() || !messages.IsArray() {
|
if !messages.Exists() || !messages.IsArray() {
|
||||||
@@ -309,21 +312,30 @@ func SanitizeAmpRequestBody(body []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var keepBlocks []interface{}
|
var keepBlocks []interface{}
|
||||||
removedCount := 0
|
contentModified := false
|
||||||
|
|
||||||
for _, block := range content.Array() {
|
for _, block := range content.Array() {
|
||||||
blockType := block.Get("type").String()
|
blockType := block.Get("type").String()
|
||||||
if blockType == "thinking" {
|
if blockType == "thinking" {
|
||||||
sig := block.Get("signature")
|
sig := block.Get("signature")
|
||||||
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
|
if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
|
||||||
removedCount++
|
contentModified = true
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
keepBlocks = append(keepBlocks, block.Value())
|
|
||||||
|
// Use raw JSON to prevent float64 rounding of large integers in tool_use inputs
|
||||||
|
blockRaw := []byte(block.Raw)
|
||||||
|
if blockType == "tool_use" && block.Get("signature").Exists() {
|
||||||
|
blockRaw, _ = sjson.DeleteBytes(blockRaw, "signature")
|
||||||
|
contentModified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// sjson.SetBytes supports raw JSON strings if wrapped in gjson.Raw
|
||||||
|
keepBlocks = append(keepBlocks, json.RawMessage(blockRaw))
|
||||||
}
|
}
|
||||||
|
|
||||||
if removedCount > 0 {
|
if contentModified {
|
||||||
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
|
contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
|
||||||
var err error
|
var err error
|
||||||
if len(keepBlocks) == 0 {
|
if len(keepBlocks) == 0 {
|
||||||
@@ -332,11 +344,10 @@ func SanitizeAmpRequestBody(body []byte) []byte {
|
|||||||
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
|
body, err = sjson.SetBytes(body, contentPath, keepBlocks)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err)
|
log.Warnf("Amp RequestSanitizer: failed to sanitize message %d: %v", msgIdx, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
modified = true
|
modified = true
|
||||||
log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -145,6 +145,36 @@ func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_StripsSignatureFromToolUseBlocks(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"thought","signature":"valid-sig"},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte(`"signature":""`)) {
|
||||||
|
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"valid-sig"`)) {
|
||||||
|
t.Fatalf("expected thinking signature to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testing.T) {
|
||||||
|
input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-me","signature":""},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"},"signature":""}]}]}`)
|
||||||
|
result := SanitizeAmpRequestBody(input)
|
||||||
|
|
||||||
|
if contains(result, []byte("drop-me")) {
|
||||||
|
t.Fatalf("expected invalid thinking block to be removed, got %s", string(result))
|
||||||
|
}
|
||||||
|
if contains(result, []byte(`"signature"`)) {
|
||||||
|
t.Fatalf("expected signature to be stripped from tool_use block, got %s", string(result))
|
||||||
|
}
|
||||||
|
if !contains(result, []byte(`"tool_use"`)) {
|
||||||
|
t.Fatalf("expected tool_use block to remain, got %s", string(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func contains(data, substr []byte) bool {
|
func contains(data, substr []byte) bool {
|
||||||
for i := 0; i <= len(data)-len(substr); i++ {
|
for i := 0; i <= len(data)-len(substr); i++ {
|
||||||
if string(data[i:i+len(substr)]) == string(substr) {
|
if string(data[i:i+len(substr)]) == string(substr) {
|
||||||
|
|||||||
Reference in New Issue
Block a user