fix(translator): sanitize tool names for Gemini function_declarations compatibility

Claude Code and MCP clients may send tool names containing characters
invalid for Gemini's function_declarations (e.g. '/', '@', spaces).
Sanitize on request via SanitizeFunctionName and restore original names
on response for both antigravity/claude and gemini-cli/claude translators.
This commit is contained in:
sususu98
2026-03-22 13:10:53 +08:00
parent f81acd0760
commit 2398ebad55
6 changed files with 135 additions and 12 deletions
@@ -171,7 +171,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// NOTE: Do NOT inject dummy thinking blocks here. // NOTE: Do NOT inject dummy thinking blocks here.
// Antigravity API validates signatures, so dummy values are rejected. // Antigravity API validates signatures, so dummy values are rejected.
functionName := contentResult.Get("name").String() functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
argsResult := contentResult.Get("input") argsResult := contentResult.Get("input")
functionID := contentResult.Get("id").String() functionID := contentResult.Get("id").String()
@@ -233,7 +233,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
functionResponseJSON := []byte(`{}`) functionResponseJSON := []byte(`{}`)
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "id", toolCallID) functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "id", toolCallID)
functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", funcName) functionResponseJSON, _ = sjson.SetBytes(functionResponseJSON, "name", util.SanitizeFunctionName(funcName))
responseData := "" responseData := ""
if functionResponseResult.Type == gjson.String { if functionResponseResult.Type == gjson.String {
@@ -398,6 +398,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw) inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema") tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema)) tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
for toolKey := range gjson.ParseBytes(tool).Map() { for toolKey := range gjson.ParseBytes(tool).Map() {
if util.InArray(allowedToolKeys, toolKey) { if util.InArray(allowedToolKeys, toolKey) {
continue continue
@@ -471,7 +472,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
case "tool": case "tool":
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY") out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
if toolChoiceName != "" { if toolChoiceName != "" {
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName}) out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
} }
} }
} }
@@ -44,6 +44,10 @@ type Params struct {
// Signature caching support // Signature caching support
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
// Reverse map: sanitized Gemini function name → original Claude tool name.
// Populated lazily on the first response chunk from the original request JSON.
ToolNameMap map[string]string
} }
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. // toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
@@ -77,6 +81,10 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
params := (*param).(*Params) params := (*param).(*Params)
if params.ToolNameMap == nil {
params.ToolNameMap = util.SanitizedToolNameMap(originalRequestRawJSON)
}
if bytes.Equal(rawJSON, []byte("[DONE]")) { if bytes.Equal(rawJSON, []byte("[DONE]")) {
output := make([]byte, 0, 256) output := make([]byte, 0, 256)
// Only send final events if we have actually output content // Only send final events if we have actually output content
@@ -212,7 +220,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Handle function/tool calls from the AI model // Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude Code API compatibility // This processes tool usage requests and formats them for Claude Code API compatibility
params.HasToolUse = true params.HasToolUse = true
fcName := functionCallResult.Get("name").String() fcName := util.RestoreSanitizedToolName(params.ToolNameMap, functionCallResult.Get("name").String())
// Handle state transitions when switching to function calls // Handle state transitions when switching to function calls
// Close any existing function call block first // Close any existing function call block first
@@ -348,7 +356,7 @@ func resolveStopReason(params *Params) string {
// Returns: // Returns:
// - []byte: A Claude-compatible JSON response. // - []byte: A Claude-compatible JSON response.
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte { func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
_ = originalRequestRawJSON toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
modelName := gjson.GetBytes(requestRawJSON, "model").String() modelName := gjson.GetBytes(requestRawJSON, "model").String()
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
@@ -450,7 +458,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
flushText() flushText()
hasToolCall = true hasToolCall = true
name := functionCall.Get("name").String() name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
toolIDCounter++ toolIDCounter++
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`) toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter)) toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
@@ -89,7 +89,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
case "tool_use": case "tool_use":
functionName := contentResult.Get("name").String() functionName := util.SanitizeFunctionName(contentResult.Get("name").String())
functionArgs := contentResult.Get("input").String() functionArgs := contentResult.Get("input").String()
argsResult := gjson.Parse(functionArgs) argsResult := gjson.Parse(functionArgs)
if argsResult.IsObject() && gjson.Valid(functionArgs) { if argsResult.IsObject() && gjson.Valid(functionArgs) {
@@ -112,7 +112,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
} }
responseData := contentResult.Get("content").Raw responseData := contentResult.Get("content").Raw
part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`) part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
part, _ = sjson.SetBytes(part, "functionResponse.name", funcName) part, _ = sjson.SetBytes(part, "functionResponse.name", util.SanitizeFunctionName(funcName))
part, _ = sjson.SetBytes(part, "functionResponse.response.result", responseData) part, _ = sjson.SetBytes(part, "functionResponse.response.result", responseData)
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
@@ -151,6 +151,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw) inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema") tool, _ := sjson.DeleteBytes([]byte(toolResult.Raw), "input_schema")
tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema)) tool, _ = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema))
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
tool, _ = sjson.DeleteBytes(tool, "strict") tool, _ = sjson.DeleteBytes(tool, "strict")
tool, _ = sjson.DeleteBytes(tool, "input_examples") tool, _ = sjson.DeleteBytes(tool, "input_examples")
tool, _ = sjson.DeleteBytes(tool, "type") tool, _ = sjson.DeleteBytes(tool, "type")
@@ -194,7 +195,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
case "tool": case "tool":
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY") out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
if toolChoiceName != "" { if toolChoiceName != "" {
out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName}) out, _ = sjson.SetBytes(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)})
} }
} }
} }
@@ -28,6 +28,9 @@ type Params struct {
ResponseType int // Current response type: 0=none, 1=content, 2=thinking, 3=function ResponseType int // Current response type: 0=none, 1=content, 2=thinking, 3=function
ResponseIndex int // Index counter for content blocks in the streaming response ResponseIndex int // Index counter for content blocks in the streaming response
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
// Reverse map: sanitized Gemini function name → original Claude tool name.
ToolNameMap map[string]string
} }
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. // toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
@@ -55,6 +58,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
HasFirstResponse: false, HasFirstResponse: false,
ResponseType: 0, ResponseType: 0,
ResponseIndex: 0, ResponseIndex: 0,
ToolNameMap: util.SanitizedToolNameMap(originalRequestRawJSON),
} }
} }
@@ -165,7 +169,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Handle function/tool calls from the AI model // Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude Code API compatibility // This processes tool usage requests and formats them for Claude Code API compatibility
usedTool = true usedTool = true
fcName := functionCallResult.Get("name").String() fcName := util.RestoreSanitizedToolName((*param).(*Params).ToolNameMap, functionCallResult.Get("name").String())
// Handle state transitions when switching to function calls // Handle state transitions when switching to function calls
// Close any existing function call block first // Close any existing function call block first
@@ -248,7 +252,7 @@ func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalReque
// Returns: // Returns:
// - []byte: A Claude-compatible JSON response. // - []byte: A Claude-compatible JSON response.
func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte { func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []byte {
_ = originalRequestRawJSON toolNameMap := util.SanitizedToolNameMap(originalRequestRawJSON)
_ = requestRawJSON _ = requestRawJSON
root := gjson.ParseBytes(rawJSON) root := gjson.ParseBytes(rawJSON)
@@ -306,7 +310,7 @@ func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, orig
flushText() flushText()
hasToolCall = true hasToolCall = true
name := functionCall.Get("name").String() name := util.RestoreSanitizedToolName(toolNameMap, functionCall.Get("name").String())
toolIDCounter++ toolIDCounter++
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`) toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter)) toolBlock, _ = sjson.SetBytes(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
+60
View File
@@ -54,3 +54,63 @@ func TestSanitizeFunctionName(t *testing.T) {
}) })
} }
} }
func TestSanitizedToolNameMap(t *testing.T) {
t.Run("returns map for tools needing sanitization", func(t *testing.T) {
raw := []byte(`{"tools":[
{"name":"valid_tool","input_schema":{}},
{"name":"mcp/server/read","input_schema":{}},
{"name":"tool@v2","input_schema":{}}
]}`)
m := SanitizedToolNameMap(raw)
if m == nil {
t.Fatal("expected non-nil map")
}
if m["mcp_server_read"] != "mcp/server/read" {
t.Errorf("expected mcp_server_read → mcp/server/read, got %q", m["mcp_server_read"])
}
if m["tool_v2"] != "tool@v2" {
t.Errorf("expected tool_v2 → tool@v2, got %q", m["tool_v2"])
}
if _, exists := m["valid_tool"]; exists {
t.Error("valid_tool should not be in the map (no sanitization needed)")
}
})
t.Run("returns nil when no tools need sanitization", func(t *testing.T) {
raw := []byte(`{"tools":[{"name":"Read","input_schema":{}},{"name":"Write","input_schema":{}}]}`)
m := SanitizedToolNameMap(raw)
if m != nil {
t.Errorf("expected nil, got %v", m)
}
})
t.Run("returns nil for empty/missing tools", func(t *testing.T) {
if m := SanitizedToolNameMap([]byte(`{}`)); m != nil {
t.Error("expected nil for no tools")
}
if m := SanitizedToolNameMap(nil); m != nil {
t.Error("expected nil for nil input")
}
})
}
func TestRestoreSanitizedToolName(t *testing.T) {
m := map[string]string{
"mcp_server_read": "mcp/server/read",
"tool_v2": "tool@v2",
}
if got := RestoreSanitizedToolName(m, "mcp_server_read"); got != "mcp/server/read" {
t.Errorf("expected mcp/server/read, got %q", got)
}
if got := RestoreSanitizedToolName(m, "unknown"); got != "unknown" {
t.Errorf("expected passthrough for unknown, got %q", got)
}
if got := RestoreSanitizedToolName(nil, "name"); got != "name" {
t.Errorf("expected passthrough for nil map, got %q", got)
}
if got := RestoreSanitizedToolName(m, ""); got != "" {
t.Errorf("expected empty for empty name, got %q", got)
}
}
+49
View File
@@ -271,3 +271,52 @@ func MapToolName(toolNameMap map[string]string, name string) string {
} }
return name return name
} }
// SanitizedToolNameMap builds a sanitized-name → original-name map from Claude request tools.
// It is used to restore exact tool names for clients (e.g. Claude Code) after the proxy
// sanitizes tool names for Gemini/Vertex API compatibility via SanitizeFunctionName.
// Only entries where sanitization actually changes the name are included.
func SanitizedToolNameMap(rawJSON []byte) map[string]string {
if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) {
return nil
}
tools := gjson.GetBytes(rawJSON, "tools")
if !tools.Exists() || !tools.IsArray() {
return nil
}
out := make(map[string]string)
tools.ForEach(func(_, tool gjson.Result) bool {
name := strings.TrimSpace(tool.Get("name").String())
if name == "" {
return true
}
sanitized := SanitizeFunctionName(name)
if sanitized == name {
return true
}
if _, exists := out[sanitized]; !exists {
out[sanitized] = name
}
return true
})
if len(out) == 0 {
return nil
}
return out
}
// RestoreSanitizedToolName looks up a sanitized function name in the provided map
// and returns the original client-facing name. If no mapping exists, it returns
// the sanitized name unchanged.
func RestoreSanitizedToolName(toolNameMap map[string]string, sanitizedName string) string {
if sanitizedName == "" || toolNameMap == nil {
return sanitizedName
}
if original, ok := toolNameMap[sanitizedName]; ok {
return original
}
return sanitizedName
}