fix(gemini): clean tool schemas and eager_input_streaming
delegate schema sanitization to util.CleanJSONSchemaForGemini and drop the top-level eager_input_streaming key to prevent validation errors when sending claude tools to the gemini api
This commit is contained in:
@@ -253,6 +253,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||||
|
rewriter.suppressThinking = true
|
||||||
c.Writer = rewriter
|
c.Writer = rewriter
|
||||||
// Filter Anthropic-Beta header only for local handling paths
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
filterAntropicBetaHeader(c)
|
filterAntropicBetaHeader(c)
|
||||||
@@ -267,6 +268,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
// proxies (e.g. NewAPI) may return a different model name and lack
|
// proxies (e.g. NewAPI) may return a different model name and lack
|
||||||
// Amp-required fields like thinking.signature.
|
// Amp-required fields like thinking.signature.
|
||||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||||
|
rewriter.suppressThinking = providerName != "claude"
|
||||||
c.Writer = rewriter
|
c.Writer = rewriter
|
||||||
// Filter Anthropic-Beta header only for local handling paths
|
// Filter Anthropic-Beta header only for local handling paths
|
||||||
filterAntropicBetaHeader(c)
|
filterAntropicBetaHeader(c)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ type ResponseRewriter struct {
|
|||||||
body *bytes.Buffer
|
body *bytes.Buffer
|
||||||
originalModel string
|
originalModel string
|
||||||
isStreaming bool
|
isStreaming bool
|
||||||
suppressedContentBlock map[int]struct{}
|
suppressThinking bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResponseRewriter creates a new response rewriter for model name substitution.
|
// NewResponseRewriter creates a new response rewriter for model name substitution.
|
||||||
@@ -28,8 +28,7 @@ func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRe
|
|||||||
return &ResponseRewriter{
|
return &ResponseRewriter{
|
||||||
ResponseWriter: w,
|
ResponseWriter: w,
|
||||||
body: &bytes.Buffer{},
|
body: &bytes.Buffer{},
|
||||||
originalModel: originalModel,
|
originalModel: originalModel,
|
||||||
suppressedContentBlock: make(map[int]struct{}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,7 +90,8 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if rw.isStreaming {
|
if rw.isStreaming {
|
||||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
rewritten := rw.rewriteStreamChunk(data)
|
||||||
|
n, err := rw.ResponseWriter.Write(rewritten)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
@@ -154,19 +154,11 @@ func ensureAmpSignature(data []byte) []byte {
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *ResponseRewriter) markSuppressedContentBlock(index int) {
|
|
||||||
if rw.suppressedContentBlock == nil {
|
|
||||||
rw.suppressedContentBlock = make(map[int]struct{})
|
|
||||||
}
|
|
||||||
rw.suppressedContentBlock[index] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *ResponseRewriter) isSuppressedContentBlock(index int) bool {
|
|
||||||
_, ok := rw.suppressedContentBlock[index]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
||||||
|
if !rw.suppressThinking {
|
||||||
|
return data
|
||||||
|
}
|
||||||
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
|
||||||
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
|
||||||
if filtered.Exists() {
|
if filtered.Exists() {
|
||||||
@@ -177,33 +169,11 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
|||||||
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
data, err = sjson.SetBytes(data, "content", filtered.Value())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
|
||||||
} else {
|
|
||||||
log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
eventType := gjson.GetBytes(data, "type").String()
|
|
||||||
indexResult := gjson.GetBytes(data, "index")
|
|
||||||
if eventType == "content_block_start" && gjson.GetBytes(data, "content_block.type").String() == "thinking" && indexResult.Exists() {
|
|
||||||
rw.markSuppressedContentBlock(int(indexResult.Int()))
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if gjson.GetBytes(data, "delta.type").String() == "thinking_delta" {
|
|
||||||
if indexResult.Exists() {
|
|
||||||
rw.markSuppressedContentBlock(int(indexResult.Int()))
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if eventType == "content_block_stop" && indexResult.Exists() {
|
|
||||||
index := int(indexResult.Int())
|
|
||||||
if rw.isSuppressedContentBlock(index) {
|
|
||||||
delete(rw.suppressedContentBlock, index)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -255,7 +225,6 @@ func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
|||||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
rewritten := rw.rewriteStreamEvent(jsonData)
|
rewritten := rw.rewriteStreamEvent(jsonData)
|
||||||
if rewritten == nil {
|
if rewritten == nil {
|
||||||
// Event suppressed (e.g. thinking block), skip event+data pair
|
|
||||||
i = dataIdx + 1
|
i = dataIdx + 1
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -303,12 +272,6 @@ func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
|||||||
// rewriteStreamEvent processes a single JSON event in the SSE stream.
|
// rewriteStreamEvent processes a single JSON event in the SSE stream.
|
||||||
// It rewrites model names and ensures signature fields exist.
|
// It rewrites model names and ensures signature fields exist.
|
||||||
func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
||||||
// Suppress thinking blocks before any other processing.
|
|
||||||
data = rw.suppressAmpThinking(data)
|
|
||||||
if len(data) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Inject empty signature where needed
|
// Inject empty signature where needed
|
||||||
data = ensureAmpSignature(data)
|
data = ensureAmpSignature(data)
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@
|
|||||||
package claude
|
package claude
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
@@ -31,8 +30,6 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator"
|
|||||||
// - []byte: The transformed request in Gemini CLI format.
|
// - []byte: The transformed request in Gemini CLI format.
|
||||||
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||||
rawJSON := inputRawJSON
|
rawJSON := inputRawJSON
|
||||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
|
||||||
|
|
||||||
// Build output Gemini CLI request JSON
|
// Build output Gemini CLI request JSON
|
||||||
out := []byte(`{"contents":[]}`)
|
out := []byte(`{"contents":[]}`)
|
||||||
out, _ = sjson.SetBytes(out, "model", modelName)
|
out, _ = sjson.SetBytes(out, "model", modelName)
|
||||||
@@ -152,7 +149,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
|
toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
|
||||||
inputSchemaResult := toolResult.Get("input_schema")
|
inputSchemaResult := toolResult.Get("input_schema")
|
||||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||||
inputSchema := inputSchemaResult.Raw
|
inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
|
||||||
tool := []byte(toolResult.Raw)
|
tool := []byte(toolResult.Raw)
|
||||||
var err error
|
var err error
|
||||||
tool, err = sjson.DeleteBytes(tool, "input_schema")
|
tool, err = sjson.DeleteBytes(tool, "input_schema")
|
||||||
@@ -168,6 +165,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
|||||||
tool, _ = sjson.DeleteBytes(tool, "type")
|
tool, _ = sjson.DeleteBytes(tool, "type")
|
||||||
tool, _ = sjson.DeleteBytes(tool, "cache_control")
|
tool, _ = sjson.DeleteBytes(tool, "cache_control")
|
||||||
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
|
tool, _ = sjson.DeleteBytes(tool, "defer_loading")
|
||||||
|
tool, _ = sjson.DeleteBytes(tool, "eager_input_streaming")
|
||||||
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String()))
|
||||||
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
|
if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() {
|
||||||
if !hasTools {
|
if !hasTools {
|
||||||
|
|||||||
Reference in New Issue
Block a user