327 lines
13 KiB
Go
327 lines
13 KiB
Go
// Package gemini provides request translation functionality for Gemini CLI to Gemini API compatibility.
|
|
// It handles parsing and transforming Gemini CLI API requests into Gemini API format,
|
|
// extracting model information, system instructions, message contents, and tool declarations.
|
|
// The package performs JSON data transformation to ensure compatibility
|
|
// between Gemini CLI API format and Gemini API's expected format.
|
|
package gemini
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// ConvertGeminiRequestToAntigravity parses and transforms a Gemini CLI API request into Gemini API format.
|
|
// It extracts the model name, system instruction, message contents, and tool declarations
|
|
// from the raw JSON request and returns them in the format expected by the Gemini API.
|
|
// The function performs the following transformations:
|
|
// 1. Extracts the model information from the request
|
|
// 2. Restructures the JSON to match Gemini API format
|
|
// 3. Converts system instructions to the expected format
|
|
// 4. Fixes CLI tool response format and grouping
|
|
//
|
|
// Parameters:
|
|
// - modelName: The name of the model to use for the request (unused in current implementation)
|
|
// - rawJSON: The raw JSON request data from the Gemini CLI API
|
|
// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation)
|
|
//
|
|
// Returns:
|
|
// - []byte: The transformed request data in Gemini API format
|
|
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
|
rawJSON := inputRawJSON
|
|
template := `{"project":"","request":{},"model":""}`
|
|
templateBytes, _ := sjson.SetRawBytes([]byte(template), "request", rawJSON)
|
|
templateBytes, _ = sjson.SetBytes(templateBytes, "model", modelName)
|
|
template = string(templateBytes)
|
|
template, _ = sjson.Delete(template, "request.model")
|
|
|
|
template, errFixCLIToolResponse := fixCLIToolResponse(template)
|
|
if errFixCLIToolResponse != nil {
|
|
return []byte{}
|
|
}
|
|
|
|
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
|
if systemInstructionResult.Exists() {
|
|
templateBytes, _ = sjson.SetRawBytes([]byte(template), "request.systemInstruction", []byte(systemInstructionResult.Raw))
|
|
template = string(templateBytes)
|
|
template, _ = sjson.Delete(template, "request.system_instruction")
|
|
}
|
|
rawJSON = []byte(template)
|
|
|
|
// Normalize roles in request.contents: default to valid values if missing/invalid
|
|
contents := gjson.GetBytes(rawJSON, "request.contents")
|
|
if contents.Exists() {
|
|
prevRole := ""
|
|
idx := 0
|
|
contents.ForEach(func(_ gjson.Result, value gjson.Result) bool {
|
|
role := value.Get("role").String()
|
|
valid := role == "user" || role == "model"
|
|
if role == "" || !valid {
|
|
var newRole string
|
|
if prevRole == "" {
|
|
newRole = "user"
|
|
} else if prevRole == "user" {
|
|
newRole = "model"
|
|
} else {
|
|
newRole = "user"
|
|
}
|
|
path := fmt.Sprintf("request.contents.%d.role", idx)
|
|
rawJSON, _ = sjson.SetBytes(rawJSON, path, newRole)
|
|
role = newRole
|
|
}
|
|
prevRole = role
|
|
idx++
|
|
return true
|
|
})
|
|
}
|
|
|
|
toolsResult := gjson.GetBytes(rawJSON, "request.tools")
|
|
if toolsResult.Exists() && toolsResult.IsArray() {
|
|
toolResults := toolsResult.Array()
|
|
for i := 0; i < len(toolResults); i++ {
|
|
functionDeclarationsResult := gjson.GetBytes(rawJSON, fmt.Sprintf("request.tools.%d.function_declarations", i))
|
|
if functionDeclarationsResult.Exists() && functionDeclarationsResult.IsArray() {
|
|
functionDeclarationsResults := functionDeclarationsResult.Array()
|
|
for j := 0; j < len(functionDeclarationsResults); j++ {
|
|
parametersResult := gjson.GetBytes(rawJSON, fmt.Sprintf("request.tools.%d.function_declarations.%d.parameters", i, j))
|
|
if parametersResult.Exists() {
|
|
strJson, _ := util.RenameKey(string(rawJSON), fmt.Sprintf("request.tools.%d.function_declarations.%d.parameters", i, j), fmt.Sprintf("request.tools.%d.function_declarations.%d.parametersJsonSchema", i, j))
|
|
rawJSON = []byte(strJson)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Gemini-specific handling for non-Claude models:
|
|
// - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation.
|
|
// - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them).
|
|
if !strings.Contains(modelName, "claude") {
|
|
const skipSentinel = "skip_thought_signature_validator"
|
|
|
|
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
|
|
if content.Get("role").String() == "model" {
|
|
// First pass: collect indices of thinking parts to mark with skip sentinel
|
|
var thinkingIndicesToSkipSignature []int64
|
|
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
|
// Collect indices of thinking blocks to mark with skip sentinel
|
|
if part.Get("thought").Bool() {
|
|
thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int())
|
|
}
|
|
// Add skip sentinel to functionCall parts
|
|
if part.Get("functionCall").Exists() {
|
|
existingSig := part.Get("thoughtSignature").String()
|
|
if existingSig == "" || len(existingSig) < 50 {
|
|
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
|
|
// Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices
|
|
for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- {
|
|
idx := thinkingIndicesToSkipSignature[i]
|
|
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel)
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
|
}
|
|
|
|
// FunctionCallGroup represents a group of function calls and their responses
|
|
type FunctionCallGroup struct {
|
|
ResponsesNeeded int
|
|
CallNames []string // ordered function call names for backfilling empty response names
|
|
}
|
|
|
|
// parseFunctionResponseRaw attempts to normalize a function response part into a JSON object string.
|
|
// Falls back to a minimal "functionResponse" object when parsing fails.
|
|
// fallbackName is used when the response's own name is empty.
|
|
func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string {
|
|
if response.IsObject() && gjson.Valid(response.Raw) {
|
|
raw := response.Raw
|
|
name := response.Get("functionResponse.name").String()
|
|
if strings.TrimSpace(name) == "" && fallbackName != "" {
|
|
updated, _ := sjson.SetBytes([]byte(raw), "functionResponse.name", fallbackName)
|
|
raw = string(updated)
|
|
}
|
|
return raw
|
|
}
|
|
|
|
log.Debugf("parse function response failed, using fallback")
|
|
funcResp := response.Get("functionResponse")
|
|
if funcResp.Exists() {
|
|
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
|
name := funcResp.Get("name").String()
|
|
if strings.TrimSpace(name) == "" {
|
|
name = fallbackName
|
|
}
|
|
fr, _ = sjson.SetBytes(fr, "functionResponse.name", name)
|
|
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", funcResp.Get("response").String())
|
|
if id := funcResp.Get("id").String(); id != "" {
|
|
fr, _ = sjson.SetBytes(fr, "functionResponse.id", id)
|
|
}
|
|
return string(fr)
|
|
}
|
|
|
|
useName := fallbackName
|
|
if useName == "" {
|
|
useName = "unknown"
|
|
}
|
|
fr := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`)
|
|
fr, _ = sjson.SetBytes(fr, "functionResponse.name", useName)
|
|
fr, _ = sjson.SetBytes(fr, "functionResponse.response.result", response.String())
|
|
return string(fr)
|
|
}
|
|
|
|
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
|
// This function transforms the CLI tool response format by intelligently grouping function calls
|
|
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
|
|
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
|
|
// and their responses are properly associated and structured.
|
|
//
|
|
// Parameters:
|
|
// - input: The input JSON string to be processed
|
|
//
|
|
// Returns:
|
|
// - string: The processed JSON string with grouped function calls and responses
|
|
// - error: An error if the processing fails
|
|
func fixCLIToolResponse(input string) (string, error) {
|
|
// Parse the input JSON to extract the conversation structure
|
|
parsed := gjson.Parse(input)
|
|
|
|
// Extract the contents array which contains the conversation messages
|
|
contents := parsed.Get("request.contents")
|
|
if !contents.Exists() {
|
|
// log.Debugf(input)
|
|
return input, fmt.Errorf("contents not found in input")
|
|
}
|
|
|
|
// Initialize data structures for processing and grouping
|
|
contentsWrapper := []byte(`{"contents":[]}`)
|
|
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
|
var collectedResponses []gjson.Result // Standalone responses to be matched
|
|
|
|
// Process each content object in the conversation
|
|
// This iterates through messages and groups function calls with their responses
|
|
contents.ForEach(func(key, value gjson.Result) bool {
|
|
role := value.Get("role").String()
|
|
parts := value.Get("parts")
|
|
|
|
// Check if this content has function responses
|
|
var responsePartsInThisContent []gjson.Result
|
|
parts.ForEach(func(_, part gjson.Result) bool {
|
|
if part.Get("functionResponse").Exists() {
|
|
responsePartsInThisContent = append(responsePartsInThisContent, part)
|
|
}
|
|
return true
|
|
})
|
|
|
|
// If this content has function responses, collect them
|
|
if len(responsePartsInThisContent) > 0 {
|
|
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
|
|
|
// Check if pending groups can be satisfied (FIFO: oldest group first)
|
|
for len(pendingGroups) > 0 && len(collectedResponses) >= pendingGroups[0].ResponsesNeeded {
|
|
group := pendingGroups[0]
|
|
pendingGroups = pendingGroups[1:]
|
|
|
|
// Take the needed responses for this group
|
|
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
|
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
|
|
|
// Create merged function response content
|
|
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
|
|
for ri, response := range groupResponses {
|
|
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
|
|
if partRaw != "" {
|
|
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
|
|
}
|
|
}
|
|
|
|
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
|
|
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
|
|
}
|
|
}
|
|
|
|
return true // Skip adding this content, responses are merged
|
|
}
|
|
|
|
// If this is a model with function calls, create a new group
|
|
if role == "model" {
|
|
var callNames []string
|
|
parts.ForEach(func(_, part gjson.Result) bool {
|
|
if part.Get("functionCall").Exists() {
|
|
callNames = append(callNames, part.Get("functionCall.name").String())
|
|
}
|
|
return true
|
|
})
|
|
|
|
if len(callNames) > 0 {
|
|
// Add the model content
|
|
if !value.IsObject() {
|
|
log.Warnf("failed to parse model content")
|
|
return true
|
|
}
|
|
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
|
|
|
// Create a new group for tracking responses
|
|
group := &FunctionCallGroup{
|
|
ResponsesNeeded: len(callNames),
|
|
CallNames: callNames,
|
|
}
|
|
pendingGroups = append(pendingGroups, group)
|
|
} else {
|
|
// Regular model content without function calls
|
|
if !value.IsObject() {
|
|
log.Warnf("failed to parse content")
|
|
return true
|
|
}
|
|
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
|
}
|
|
} else {
|
|
// Non-model content (user, etc.)
|
|
if !value.IsObject() {
|
|
log.Warnf("failed to parse content")
|
|
return true
|
|
}
|
|
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", []byte(value.Raw))
|
|
}
|
|
|
|
return true
|
|
})
|
|
|
|
// Handle any remaining pending groups with remaining responses
|
|
for _, group := range pendingGroups {
|
|
if len(collectedResponses) >= group.ResponsesNeeded {
|
|
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
|
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
|
|
|
functionResponseContent := []byte(`{"parts":[],"role":"function"}`)
|
|
for ri, response := range groupResponses {
|
|
partRaw := parseFunctionResponseRaw(response, group.CallNames[ri])
|
|
if partRaw != "" {
|
|
functionResponseContent, _ = sjson.SetRawBytes(functionResponseContent, "parts.-1", []byte(partRaw))
|
|
}
|
|
}
|
|
|
|
if gjson.GetBytes(functionResponseContent, "parts.#").Int() > 0 {
|
|
contentsWrapper, _ = sjson.SetRawBytes(contentsWrapper, "contents.-1", functionResponseContent)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Update the original JSON with the new contents
|
|
result, _ := sjson.SetRawBytes([]byte(input), "request.contents", []byte(gjson.GetBytes(contentsWrapper, "contents").Raw))
|
|
|
|
return string(result), nil
|
|
}
|