From 3c0c61aaf1193425d20fac29528089a7aafb6be2 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Fri, 11 Jul 2025 13:46:27 +0800 Subject: [PATCH] Add Claude compatibility and enhance API handling - Integrated Claude API compatibility in handlers, translators, and server routes. - Introduced `/messages` endpoint and upgraded `AuthMiddleware` for `X-Api-Key` header. - Improved streaming response handling with `ConvertCliToClaude` for SSE compatibility. - Enhanced request processing and tool-response mapping in translators. - Updated README to reflect Claude integration and clarify supported features. --- README.md | 6 +- internal/api/handlers.go | 80 ----------- internal/api/server.go | 6 +- internal/api/translator/request.go | 203 ++++++++++++++++++++++++++-- internal/api/translator/response.go | 192 ++++++++++++++++++++++++++ internal/client/client.go | 86 ++++++++++-- 6 files changed, 463 insertions(+), 110 deletions(-) diff --git a/README.md b/README.md index c03f3265..f8bf1380 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # CLI Proxy API -A proxy server that provides an OpenAI-compatible/Gemini-compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI/Gemini API. +A proxy server that provides an OpenAI/Gemini/Claude compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI/Gemini/Claude API. ## Features -- OpenAI/Gemini compatible API endpoints for CLI models +- OpenAI/Gemini/Claude compatible API endpoints for CLI models - Support for both streaming and non-streaming responses - Function calling/tools support - Multimodal input support (text and images) @@ -136,7 +136,7 @@ console.log(response.choices[0].message.content); - gemini-2.5-pro - gemini-2.5-flash -- And various preview versions +- And it automates switching to various preview versions ## Configuration diff --git a/internal/api/handlers.go b/internal/api/handlers.go index f627914d..589d787d 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -41,46 +41,6 @@ func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandler func (h *APIHandlers) Models(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "data": []map[string]any{ - { - "id": "gemini-2.5-pro-preview-05-06", - "object": "model", - "version": "2.5-preview-05-06", - "name": "Gemini 2.5 Pro Preview 05-06", - "description": "Preview release (May 6th, 2025) of Gemini 2.5 Pro", - "context_length": 1048576, - "max_completion_tokens": 65536, - "supported_parameters": []string{ - "tools", - "temperature", - "top_p", - "top_k", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - { - "id": "gemini-2.5-pro-preview-06-05", - "object": "model", - "version": "2.5-preview-06-05", - "name": "Gemini 2.5 Pro Preview 06-05", - "description": "Preview release (June 5th, 2025) of Gemini 2.5 Pro", - "context_length": 1048576, - "max_completion_tokens": 65536, - "supported_parameters": []string{ - "tools", - "temperature", - "top_p", - "top_k", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, { "id": "gemini-2.5-pro", "object": "model", @@ -101,46 +61,6 @@ func (h *APIHandlers) Models(c *gin.Context) { "maxTemperature": 2, "thinking": true, }, - { - "id": "gemini-2.5-flash-preview-04-17", - "object": "model", - "version": "2.5-preview-04-17", - "name": "Gemini 2.5 Flash Preview 04-17", - "description": "Preview release (April 17th, 2025) of Gemini 2.5 Flash", - "context_length": 1048576, - "max_completion_tokens": 65536, - "supported_parameters": []string{ - "tools", - "temperature", - "top_p", - "top_k", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - { - "id": "gemini-2.5-flash-preview-05-20", - "object": "model", - "version": "2.5-preview-05-20", - "name": "Gemini 2.5 Flash Preview 05-20", - "description": "Preview release (April 17th, 2025) of Gemini 2.5 Flash", - "context_length": 1048576, - "max_completion_tokens": 65536, - "supported_parameters": []string{ - "tools", - "temperature", - "top_p", - "top_k", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, { "id": "gemini-2.5-flash", "object": "model", diff --git a/internal/api/server.go b/internal/api/server.go index eb360e10..b5f352ac 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -68,6 +68,7 @@ func (s *Server) setupRoutes() { { v1.GET("/models", s.handlers.Models) v1.POST("/chat/completions", s.handlers.ChatCompletions) + v1.POST("/messages", s.handlers.ClaudeMessages) } // Gemini compatible API routes @@ -149,7 +150,8 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc { // Get the Authorization header authHeader := c.GetHeader("Authorization") authHeaderGoogle := c.GetHeader("X-Goog-Api-Key") - if authHeader == "" && authHeaderGoogle == "" { + authHeaderAnthropic := c.GetHeader("X-Api-Key") + if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "error": "Missing API key", }) @@ -168,7 +170,7 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc { // Find the API key in the in-memory list var foundKey string for i := range cfg.ApiKeys { - if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle { + if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle || cfg.ApiKeys[i] == authHeaderAnthropic { foundKey = cfg.ApiKeys[i] break } diff --git a/internal/api/translator/request.go b/internal/api/translator/request.go index d8c3a75a..73bbbcf7 100644 --- a/internal/api/translator/request.go +++ b/internal/api/translator/request.go @@ -1,6 +1,7 @@ package translator import ( + "bytes" "encoding/json" "fmt" "github.com/tidwall/sjson" @@ -22,11 +23,15 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, modelName = modelResult.String() } - // Process the array of messages. + // Initialize data structures for processing conversation messages + // contents: stores the processed conversation history + // systemInstruction: stores system-level instructions separate from conversation contents := make([]client.Content, 0) var systemInstruction *client.Content messagesResult := gjson.GetBytes(rawJson, "messages") + // Pre-process tool responses to create a lookup map + // This first pass collects all tool responses so they can be matched with their corresponding calls toolItems := make(map[string]*client.FunctionResponse) if messagesResult.IsArray() { messagesResults := messagesResult.Array() @@ -37,21 +42,26 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, continue } contentResult := messageResult.Get("content") + + // Extract tool responses for later matching with function calls if roleResult.String() == "tool" { toolCallID := messageResult.Get("tool_call_id").String() if toolCallID != "" { var responseData string + // Handle both string and object-based tool response formats if contentResult.Type == gjson.String { responseData = contentResult.String() } else if contentResult.IsObject() && contentResult.Get("type").String() == "text" { responseData = contentResult.Get("text").String() } - // drop the timestamp from the tool call ID + // Clean up tool call ID by removing timestamp suffix + // This normalizes IDs for consistent matching between calls and responses toolCallIDs := strings.Split(toolCallID, "-") strings.Join(toolCallIDs, "-") newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-") + // Create function response object with normalized ID and response data functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}} toolItems[toolCallID] = &functionResponse } @@ -126,25 +136,33 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, } contents = append(contents, client.Content{Role: "user", Parts: parts}) } - // Assistant messages can contain text or tool calls. + // Assistant messages can contain text responses or tool calls + // In the internal format, assistant messages are converted to "model" role case "assistant": if contentResult.Type == gjson.String { + // Simple text response from the assistant contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}}) } else if !contentResult.Exists() || contentResult.Type == gjson.Null { - // Handle tool calls made by the assistant. + // Handle complex tool calls made by the assistant + // This processes function calls and matches them with their responses functionIDs := make([]string, 0) toolCallsResult := messageResult.Get("tool_calls") if toolCallsResult.IsArray() { parts := make([]client.Part, 0) tcsResult := toolCallsResult.Array() + + // Process each tool call in the assistant's message for j := 0; j < len(tcsResult); j++ { tcResult := tcsResult[j] + // Extract function call details functionID := tcResult.Get("id").String() functionIDs = append(functionIDs, functionID) functionName := tcResult.Get("function.name").String() functionArgs := tcResult.Get("function.arguments").String() + + // Parse function arguments from JSON string to map var args map[string]any if err := json.Unmarshal([]byte(functionArgs), &args); err == nil { parts = append(parts, client.Part{ @@ -155,17 +173,22 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, }) } } + + // Add the model's function calls to the conversation if len(parts) > 0 { contents = append(contents, client.Content{ Role: "model", Parts: parts, }) + // Create a separate tool response message with the collected responses + // This matches function calls with their corresponding responses toolParts := make([]client.Part, 0) for _, functionID := range functionIDs { if functionResponse, ok := toolItems[functionID]; ok { toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse}) } } + // Add the tool responses as a separate message in the conversation contents = append(contents, client.Content{Role: "tool", Parts: toolParts}) } } @@ -207,23 +230,28 @@ type FunctionCallGroup struct { ResponsesNeeded int } -// FixCLIToolResponse converts the format from 1.json to 2.json -// It groups function calls with their corresponding responses +// FixCLIToolResponse performs sophisticated tool response format conversion and grouping. +// This function transforms the CLI tool response format by intelligently grouping function calls +// with their corresponding responses, ensuring proper conversation flow and API compatibility. +// It converts from a linear format (1.json) to a grouped format (2.json) where function calls +// and their responses are properly associated and structured. func FixCLIToolResponse(input string) (string, error) { - // Parse the input JSON + // Parse the input JSON to extract the conversation structure parsed := gjson.Parse(input) - // Get the contents array + // Extract the contents array which contains the conversation messages contents := parsed.Get("request.contents") if !contents.Exists() { return input, fmt.Errorf("contents not found in input") } - var newContents []interface{} - var pendingGroups []*FunctionCallGroup - var collectedResponses []gjson.Result + // Initialize data structures for processing and grouping + var newContents []interface{} // Final processed contents array + var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses + var collectedResponses []gjson.Result // Standalone responses to be matched - // Process each content object + // Process each content object in the conversation + // This iterates through messages and groups function calls with their responses contents.ForEach(func(key, value gjson.Result) bool { role := value.Get("role").String() parts := value.Get("parts") @@ -363,3 +391,154 @@ func FixCLIToolResponse(input string) (string, error) { return result, nil } + +func PrepareClaudeRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) { + var pathsToDelete []string + root := gjson.ParseBytes(rawJson) + walk(root, "", "additionalProperties", &pathsToDelete) + walk(root, "", "$schema", &pathsToDelete) + + var err error + for _, p := range pathsToDelete { + rawJson, err = sjson.DeleteBytes(rawJson, p) + if err != nil { + continue + } + } + rawJson = bytes.Replace(rawJson, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1) + + // log.Debug(string(rawJson)) + modelName := "gemini-2.5-pro" + modelResult := gjson.GetBytes(rawJson, "model") + if modelResult.Type == gjson.String { + modelName = modelResult.String() + } + + contents := make([]client.Content, 0) + + var systemInstruction *client.Content + + systemResult := gjson.GetBytes(rawJson, "system") + if systemResult.IsArray() { + systemResults := systemResult.Array() + systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}} + for i := 0; i < len(systemResults); i++ { + systemPromptResult := systemResults[i] + systemTypePromptResult := systemPromptResult.Get("type") + if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" { + systemPrompt := systemPromptResult.Get("text").String() + systemPart := client.Part{Text: systemPrompt} + systemInstruction.Parts = append(systemInstruction.Parts, systemPart) + } + } + if len(systemInstruction.Parts) == 0 { + systemInstruction = nil + } + } + + messagesResult := gjson.GetBytes(rawJson, "messages") + if messagesResult.IsArray() { + messageResults := messagesResult.Array() + for i := 0; i < len(messageResults); i++ { + messageResult := messageResults[i] + roleResult := messageResult.Get("role") + if roleResult.Type != gjson.String { + continue + } + role := roleResult.String() + if role == "assistant" { + role = "model" + } + clientContent := client.Content{Role: role, Parts: []client.Part{}} + + contentsResult := messageResult.Get("content") + if contentsResult.IsArray() { + contentResults := contentsResult.Array() + for j := 0; j < len(contentResults); j++ { + contentResult := contentResults[j] + contentTypeResult := contentResult.Get("type") + if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { + prompt := contentResult.Get("text").String() + clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt}) + } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { + functionName := contentResult.Get("name").String() + functionArgs := contentResult.Get("input").String() + var args map[string]any + if err = json.Unmarshal([]byte(functionArgs), &args); err == nil { + clientContent.Parts = append(clientContent.Parts, client.Part{ + FunctionCall: &client.FunctionCall{ + Name: functionName, + Args: args, + }, + }) + } + } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" { + toolCallID := contentResult.Get("tool_use_id").String() + if toolCallID != "" { + funcName := toolCallID + toolCallIDs := strings.Split(toolCallID, "-") + if len(toolCallIDs) > 1 { + funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-") + } + responseData := contentResult.Get("content").String() + functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}} + clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse}) + } + } + } + contents = append(contents, clientContent) + } else if contentsResult.Type == gjson.String { + prompt := contentsResult.String() + contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}}) + } + } + } + + var tools []client.ToolDeclaration + toolsResult := gjson.GetBytes(rawJson, "tools") + if toolsResult.IsArray() { + tools = make([]client.ToolDeclaration, 1) + tools[0].FunctionDeclarations = make([]any, 0) + toolsResults := toolsResult.Array() + for i := 0; i < len(toolsResults); i++ { + toolResult := toolsResults[i] + inputSchemaResult := toolResult.Get("input_schema") + if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { + inputSchema := inputSchemaResult.Raw + inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties") + inputSchema, _ = sjson.Delete(inputSchema, "$schema") + + tool, _ := sjson.Delete(toolResult.Raw, "input_schema") + tool, _ = sjson.SetRaw(tool, "parameters", inputSchema) + var toolDeclaration any + if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil { + tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) + } + } + } + } else { + tools = make([]client.ToolDeclaration, 0) + } + + return modelName, systemInstruction, contents, tools +} + +func walk(value gjson.Result, path, field string, pathsToDelete *[]string) { + switch value.Type { + case gjson.JSON: + value.ForEach(func(key, val gjson.Result) bool { + var childPath string + if path == "" { + childPath = key.String() + } else { + childPath = path + "." + key.String() + } + if key.String() == field { + *pathsToDelete = append(*pathsToDelete, childPath) + } + walk(val, childPath, field, pathsToDelete) + return true + }) + case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null: + } +} diff --git a/internal/api/translator/response.go b/internal/api/translator/response.go index 82c9a551..0cc45da6 100644 --- a/internal/api/translator/response.go +++ b/internal/api/translator/response.go @@ -1,6 +1,7 @@ package translator import ( + "bytes" "fmt" "time" @@ -188,3 +189,194 @@ func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey return template } + +// ConvertCliToClaude performs sophisticated streaming response format conversion. +// This function implements a complex state machine that translates backend client responses +// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types +// and handles state transitions between content blocks, thinking processes, and function calls. +// +// Response type states: 0=none, 1=content, 2=thinking, 3=function +// The function maintains state across multiple calls to ensure proper SSE event sequencing. +func ConvertCliToClaude(rawJson []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string { + // Normalize the response format for different API key types + // Generative Language API keys have a different response structure + if isGlAPIKey { + rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson) + } + + // Track whether tools are being used in this response chunk + usedTool := false + output := "" + + // Initialize the streaming session with a message_start event + // This is only sent for the very first response chunk + if !hasFirstResponse { + output = "event: message_start\n" + + // Create the initial message structure with default values + // This follows the Claude API specification for streaming message initialization + messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}` + + // Override default values with actual response metadata if available + if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() { + messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String()) + } + if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() { + messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIdResult.String()) + } + output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate) + } + + // Process the response parts array from the backend client + // Each part can contain text content, thinking content, or function calls + partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts") + if partsResult.IsArray() { + partResults := partsResult.Array() + for i := 0; i < len(partResults); i++ { + partResult := partResults[i] + + // Extract the different types of content from each part + partTextResult := partResult.Get("text") + functionCallResult := partResult.Get("functionCall") + + // Handle text content (both regular content and thinking) + if partTextResult.Exists() { + // Process thinking content (internal reasoning) + if partResult.Get("thought").Bool() { + // Continue existing thinking block + if *responseType == 2 { + output = output + "event: content_block_delta\n" + data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String()) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + } else { + // Transition from another state to thinking + // First, close any existing content block + if *responseType != 0 { + if *responseType == 2 { + output = output + "event: content_block_delta\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex) + output = output + "\n\n\n" + } + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) + output = output + "\n\n\n" + *responseIndex++ + } + + // Start a new thinking content block + output = output + "event: content_block_start\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex) + output = output + "\n\n\n" + output = output + "event: content_block_delta\n" + data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String()) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + *responseType = 2 // Set state to thinking + } + } else { + // Process regular text content (user-visible output) + // Continue existing text block + if *responseType == 1 { + output = output + "event: content_block_delta\n" + data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String()) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + } else { + // Transition from another state to text content + // First, close any existing content block + if *responseType != 0 { + if *responseType == 2 { + output = output + "event: content_block_delta\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex) + output = output + "\n\n\n" + } + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) + output = output + "\n\n\n" + *responseIndex++ + } + + // Start a new text content block + output = output + "event: content_block_start\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex) + output = output + "\n\n\n" + output = output + "event: content_block_delta\n" + data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String()) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + *responseType = 1 // Set state to content + } + } + } else if functionCallResult.Exists() { + // Handle function/tool calls from the AI model + // This processes tool usage requests and formats them for Claude API compatibility + usedTool = true + fcName := functionCallResult.Get("name").String() + + // Handle state transitions when switching to function calls + // Close any existing function call block first + if *responseType == 3 { + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) + output = output + "\n\n\n" + *responseIndex++ + *responseType = 0 + } + + // Special handling for thinking state transition + if *responseType == 2 { + output = output + "event: content_block_delta\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex) + output = output + "\n\n\n" + } + + // Close any other existing content block + if *responseType != 0 { + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) + output = output + "\n\n\n" + *responseIndex++ + } + + // Start a new tool use content block + // This creates the structure for a function call in Claude format + output = output + "event: content_block_start\n" + + // Create the tool use block with unique ID and function details + data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex) + data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano())) + data, _ = sjson.Set(data, "content_block.name", fcName) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + + if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { + output = output + "event: content_block_delta\n" + data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw) + output = output + fmt.Sprintf("data: %s\n\n\n", data) + } + *responseType = 3 + } + } + } + + usageResult := gjson.GetBytes(rawJson, "response.usageMetadata") + if usageResult.Exists() && bytes.Contains(rawJson, []byte(`"finishReason"`)) { + if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { + output = output + "event: content_block_stop\n" + output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex) + output = output + "\n\n\n" + + output = output + "event: message_delta\n" + output = output + `data: ` + + template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` + if usedTool { + template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}` + } + + thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount) + template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int()) + + output = output + template + "\n\n\n" + } + } + + return output +} diff --git a/internal/client/client.go b/internal/client/client.go index 1d5dca10..3070ac32 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -412,124 +412,184 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, } } -// SendMessageStream handles a single conversational turn, including tool calls. -func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) { +// SendMessageStream handles streaming conversational turns with comprehensive parameter management. +// This function implements a sophisticated streaming system that supports tool calls, reasoning modes, +// quota management, and automatic model fallback. It returns two channels for asynchronous communication: +// one for streaming response data and another for error handling. +func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) { + // Define the data prefix used in Server-Sent Events streaming format dataTag := []byte("data: ") + + // Create channels for asynchronous communication + // errChan: delivers error messages during streaming + // dataChan: delivers response data chunks errChan := make(chan *ErrorMessage) dataChan := make(chan []byte) + + // Launch a goroutine to handle the streaming process asynchronously + // This allows the function to return immediately while processing continues in the background go func() { + // Ensure channels are properly closed when the goroutine exits defer close(errChan) defer close(dataChan) + // Configure thinking/reasoning capabilities + // Default to including thoughts unless explicitly disabled + includeThoughtsFlag := true + if len(includeThoughts) > 0 { + includeThoughtsFlag = includeThoughts[0] + } + + // Build the base request structure for the Gemini API + // This includes conversation contents and generation configuration request := GenerateContentRequest{ Contents: contents, GenerationConfig: GenerationConfig{ ThinkingConfig: GenerationConfigThinkingConfig{ - IncludeThoughts: true, + IncludeThoughts: includeThoughtsFlag, }, }, } + // Add system instructions if provided + // System instructions guide the AI's behavior and response style request.SystemInstruction = systemInstruction + // Add available tools for function calling capabilities + // Tools allow the AI to perform actions beyond text generation request.Tools = tools + // Construct the complete request body with project context + // The project ID is essential for proper API routing and billing requestBody := map[string]interface{}{ - "project": c.GetProjectID(), // Assuming ProjectID is available + "project": c.GetProjectID(), // Project ID for API routing and quota management "request": request, "model": model, } + // Serialize the request body to JSON for API transmission byteRequestBody, _ := json.Marshal(requestBody) - // log.Debug(string(byteRequestBody)) - + // Parse and configure reasoning effort levels from the original request + // This maps Claude-style reasoning effort parameters to Gemini's thinking budget system reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort") if reasoningEffortResult.String() == "none" { + // Disable thinking entirely for fastest responses byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0) } else if reasoningEffortResult.String() == "auto" { + // Let the model decide the appropriate thinking budget automatically byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) } else if reasoningEffortResult.String() == "low" { + // Minimal thinking for simple tasks (1KB thinking budget) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024) } else if reasoningEffortResult.String() == "medium" { + // Moderate thinking for complex tasks (8KB thinking budget) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192) } else if reasoningEffortResult.String() == "high" { + // Maximum thinking for very complex tasks (24KB thinking budget) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576) } else { + // Default to automatic thinking budget if no specific level is provided byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) } + // Configure temperature parameter for response randomness control + // Temperature affects the creativity vs consistency trade-off in responses temperatureResult := gjson.GetBytes(rawJson, "temperature") if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) } + // Configure top-p parameter for nucleus sampling + // Controls the cumulative probability threshold for token selection topPResult := gjson.GetBytes(rawJson, "top_p") if topPResult.Exists() && topPResult.Type == gjson.Number { byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) } + // Configure top-k parameter for limiting token candidates + // Restricts the model to consider only the top K most likely tokens topKResult := gjson.GetBytes(rawJson, "top_k") if topKResult.Exists() && topKResult.Type == gjson.Number { byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) } - // log.Debug(string(byteRequestBody)) + // Initialize model name for quota management and potential fallback modelName := model var stream io.ReadCloser + + // Quota management and model fallback loop + // This loop handles quota exceeded scenarios and automatic model switching for { + // Check if the current model has exceeded its quota if c.isModelQuotaExceeded(modelName) { + // Attempt to switch to a preview model if configured and using account auth if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { modelName = c.getPreviewModel(model) if modelName != "" { log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) + // Update the request body with the new model name byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName) - continue + continue // Retry with the preview model } } + // If no fallback is available, return a quota exceeded error errChan <- &ErrorMessage{ StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model), } return } + + // Attempt to establish a streaming connection with the API var err *ErrorMessage stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true) if err != nil { + // Handle quota exceeded errors by marking the model and potentially retrying if err.StatusCode == 429 { now := time.Now() - c.modelQuotaExceeded[modelName] = &now + c.modelQuotaExceeded[modelName] = &now // Mark model as quota exceeded + // If preview model switching is enabled, retry the loop if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" { continue } } + // Forward other errors to the error channel errChan <- err return } + // Clear any previous quota exceeded status for this model delete(c.modelQuotaExceeded, modelName) - break + break // Successfully established connection, exit the retry loop } + // Process the streaming response using a scanner + // This handles the Server-Sent Events format from the API scanner := bufio.NewScanner(stream) for scanner.Scan() { line := scanner.Bytes() - // log.Printf("Received stream chunk: %s", line) + // Filter and forward only data lines (those prefixed with "data: ") + // This extracts the actual JSON content from the SSE format if bytes.HasPrefix(line, dataTag) { - dataChan <- line[6:] + dataChan <- line[6:] // Remove "data: " prefix and send the JSON content } } + // Handle any scanning errors that occurred during stream processing if errScanner := scanner.Err(); errScanner != nil { - // log.Println(err) + // Send a 500 Internal Server Error for scanning failures errChan <- &ErrorMessage{500, errScanner} _ = stream.Close() return } + // Ensure the stream is properly closed to prevent resource leaks _ = stream.Close() }() + // Return the channels immediately for asynchronous communication + // The caller can read from these channels while the goroutine processes the request return dataChan, errChan }