Compare commits

...

10 Commits

Author SHA1 Message Date
Luis Pater
65f47c196a Merge pull request #1 from chaudhryfaisal/main
Some checks failed
goreleaser / goreleaser (push) Has been cancelled
Correct config in README.md
2025-07-09 16:57:19 +08:00
Faisal Chaudhry
9be56fe8e0 Correct config in README.md 2025-07-08 23:28:55 -04:00
Luis Pater
589ae6d3aa Add support for Generative Language API Key and improve client initialization
Some checks failed
goreleaser / goreleaser (push) Has been cancelled
- Added `GlAPIKey` support in configuration to enable Generative Language API.
- Integrated `GenerativeLanguageAPIKey` handling in client and API handlers.
- Updated response translators to manage generative language responses properly.
- Enhanced HTTP client initialization logic with proxy support for API requests.
- Refactored streaming and non-streaming flows to account for generative language-specific logic.
2025-07-06 02:13:11 +08:00
Luis Pater
7cb76ae1a5 Enhance quota management and refactor configuration handling
Some checks failed
goreleaser / goreleaser (push) Has been cancelled
- Introduced `QuotaExceeded` settings in configuration to handle quota limits more effectively.
- Added preview model switching logic to `Client` to automatically use fallback models on quota exhaustion.
- Refactored `APIHandlers` to leverage new configuration structure.
- Simplified server initialization and removed redundant `ServerConfig` structure.
- Streamlined client initialization by unifying configuration handling throughout the project.
- Improved error handling and response mechanisms in both streaming and non-streaming flows.
2025-07-05 07:53:46 +08:00
Luis Pater
e73f165070 Refactor API handlers to streamline response handling
Some checks failed
goreleaser / goreleaser (push) Has been cancelled
- Replaced channel-based handling in `SendMessage` flow with direct synchronous execution.
- Introduced `hasFirstResponse` flag to manage keep-alive signals in streaming handler.
- Simplified error handling and removed redundant code for enhanced readability and maintainability.
2025-07-05 04:10:00 +08:00
Luis Pater
512f2d5247 Refactor API request flow and streamline response handling
Some checks failed
goreleaser / goreleaser (push) Has been cancelled
- Replaced `SendMessageStream` with synchronous `SendMessage` in API handlers for better manageability.
- Simplified `ConvertCliToOpenAINonStream` to reduce complexity and improve efficiency.
- Adjusted `client.go` functions to handle both streaming and non-streaming API requests more effectively.
- Improved error handling and channel communication in API handlers.
- Removed redundant and unused code for cleaner implementation.
2025-07-05 02:27:34 +08:00
Luis Pater
bf086464dd Add archive configuration to .goreleaser.yml
Some checks failed
goreleaser / goreleaser (push) Has been cancelled
- Included LICENSE, README.md, and config.yaml in the archive section for cli-proxy-api.
2025-07-04 18:50:55 +08:00
Luis Pater
5ec6450c50 Numerous Comments Added and Extensive Optimization Performed using Roo-Code with CLIProxyAPI itself.
Some checks failed
goreleaser / goreleaser (push) Has been cancelled
2025-07-04 18:44:55 +08:00
Luis Pater
8dd7f8e82f Update model name to include release date in API handlers
Some checks failed
goreleaser / goreleaser (push) Has been cancelled
2025-07-04 17:26:23 +08:00
Luis Pater
582280f4c5 Refactor token management, client initialization, and project handling
Some checks failed
goreleaser / goreleaser (push) Has been cancelled
- Consolidated `TokenStorage` struct into `internal/auth/models.go` for better organization.
- Updated `Client` to use `TokenStorage` for managing email and project ID.
- Simplified `SetupUser` method to ensure proper token and project assignment.
- Refactored API handlers to leverage new `GetEmail` and `GetProjectID` methods in `Client`.
- Cleanup: Removed unused structures and redundant code from `client.go` and `auth.go`.
- Adjusted CLI flow in `login.go` and `run.go` for streamlined user onboarding.
2025-07-04 17:08:58 +08:00
17 changed files with 1189 additions and 841 deletions

View File

@@ -8,4 +8,10 @@ builds:
- amd64 - amd64
- arm64 - arm64
main: ./cmd/server/ main: ./cmd/server/
binary: cli-proxy-api binary: cli-proxy-api
archives:
- id: "cli-proxy-api"
files:
- LICENSE
- README.md
- config.yaml

View File

@@ -10,6 +10,7 @@ A proxy server that provides an OpenAI-compatible API interface for CLI. This al
- Multimodal input support (text and images) - Multimodal input support (text and images)
- Multiple account support with load balancing - Multiple account support with load balancing
- Simple CLI authentication flow - Simple CLI authentication flow
- Support for Generative Language API Key
## Installation ## Installation
@@ -146,13 +147,14 @@ The server uses a YAML configuration file (`config.yaml`) located in the project
### Configuration Options ### Configuration Options
| Parameter | Type | Default | Description | | Parameter | Type | Default | Description |
|-------------|----------|--------------------|----------------------------------------------------------------------------------------------| |-------------------------------|----------|--------------------|----------------------------------------------------------------------------------------------|
| `port` | integer | 8317 | The port number on which the server will listen | | `port` | integer | 8317 | The port number on which the server will listen |
| `auth_dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory | | `auth-dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory |
| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ | | `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ |
| `debug` | boolean | false | Enable debug mode for verbose logging | | `debug` | boolean | false | Enable debug mode for verbose logging |
| `api_keys` | string[] | [] | List of API keys that can be used to authenticate requests | | `api-keys` | string[] | [] | List of API keys that can be used to authenticate requests |
| `generative-language-api-key` | string[] | [] | List of Generative Language API keys |
### Example Configuration File ### Example Configuration File
@@ -161,24 +163,24 @@ The server uses a YAML configuration file (`config.yaml`) located in the project
port: 8317 port: 8317
# Authentication directory (supports ~ for home directory) # Authentication directory (supports ~ for home directory)
auth_dir: "~/.cli-proxy-api" auth-dir: "~/.cli-proxy-api"
# Enable debug logging # Enable debug logging
debug: false debug: false
# API keys for authentication # API keys for authentication
api_keys: api-keys:
- "your-api-key-1" - "your-api-key-1"
- "your-api-key-2" - "your-api-key-2"
``` ```
### Authentication Directory ### Authentication Directory
The `auth_dir` parameter specifies where authentication tokens are stored. When you run the login command, the application will create JSON files in this directory containing the authentication tokens for your Google accounts. Multiple accounts can be used for load balancing. The `auth-dir` parameter specifies where authentication tokens are stored. When you run the login command, the application will create JSON files in this directory containing the authentication tokens for your Google accounts. Multiple accounts can be used for load balancing.
### API Keys ### API Keys
The `api_keys` parameter allows you to define a list of API keys that can be used to authenticate requests to your proxy server. When making requests to the API, you can include one of these keys in the `Authorization` header: The `api-keys` parameter allows you to define a list of API keys that can be used to authenticate requests to your proxy server. When making requests to the API, you can include one of these keys in the `Authorization` header:
``` ```
Authorization: Bearer your-api-key-1 Authorization: Bearer your-api-key-1

View File

@@ -12,9 +12,11 @@ import (
"strings" "strings"
) )
// LogFormatter defines a custom log format for logrus.
type LogFormatter struct { type LogFormatter struct {
} }
// Format renders a single log entry.
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
var b *bytes.Buffer var b *bytes.Buffer
if entry.Buffer != nil { if entry.Buffer != nil {
@@ -25,33 +27,42 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02 15:04:05") timestamp := entry.Time.Format("2006-01-02 15:04:05")
var newLog string var newLog string
// Customize the log format to include timestamp, level, caller file/line, and message.
newLog = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, path.Base(entry.Caller.File), entry.Caller.Line, entry.Message) newLog = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, path.Base(entry.Caller.File), entry.Caller.Line, entry.Message)
b.WriteString(newLog) b.WriteString(newLog)
return b.Bytes(), nil return b.Bytes(), nil
} }
// init initializes the logger configuration.
func init() { func init() {
// Set logger output to standard output.
log.SetOutput(os.Stdout) log.SetOutput(os.Stdout)
// Enable reporting the caller function's file and line number.
log.SetReportCaller(true) log.SetReportCaller(true)
// Set the custom log formatter.
log.SetFormatter(&LogFormatter{}) log.SetFormatter(&LogFormatter{})
} }
// main is the entry point of the application.
func main() { func main() {
var login bool var login bool
var projectID string var projectID string
var configPath string var configPath string
// Define command-line flags.
flag.BoolVar(&login, "login", false, "Login Google Account") flag.BoolVar(&login, "login", false, "Login Google Account")
flag.StringVar(&projectID, "project_id", "", "Project ID") flag.StringVar(&projectID, "project_id", "", "Project ID")
flag.StringVar(&configPath, "config", "", "Configure File Path") flag.StringVar(&configPath, "config", "", "Configure File Path")
// Parse the command-line flags.
flag.Parse() flag.Parse()
var err error var err error
var cfg *config.Config var cfg *config.Config
var wd string var wd string
// Load configuration from the specified path or the default path.
if configPath != "" { if configPath != "" {
cfg, err = config.LoadConfig(configPath) cfg, err = config.LoadConfig(configPath)
} else { } else {
@@ -65,12 +76,14 @@ func main() {
log.Fatalf("failed to load config: %v", err) log.Fatalf("failed to load config: %v", err)
} }
// Set the log level based on the configuration.
if cfg.Debug { if cfg.Debug {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
} else { } else {
log.SetLevel(log.InfoLevel) log.SetLevel(log.InfoLevel)
} }
// Expand the tilde (~) in the auth directory path to the user's home directory.
if strings.HasPrefix(cfg.AuthDir, "~") { if strings.HasPrefix(cfg.AuthDir, "~") {
home, errUserHomeDir := os.UserHomeDir() home, errUserHomeDir := os.UserHomeDir()
if errUserHomeDir != nil { if errUserHomeDir != nil {
@@ -85,6 +98,7 @@ func main() {
} }
} }
// Either perform login or start the service based on the 'login' flag.
if login { if login {
cmd.DoLogin(cfg, projectID) cmd.DoLogin(cfg, projectID)
} else { } else {

View File

@@ -1,7 +1,15 @@
port: 8317 port: 8317
auth_dir: "~/.cli-proxy-api" auth-dir: "~/.cli-proxy-api"
debug: false debug: true
proxy-url: "" proxy-url: ""
api_keys: quota-exceeded:
switch-project: true
switch-preview-model: true
api-keys:
- "12345" - "12345"
- "23456" - "23456"
generative-language-api-key:
- "AIzaSy...01"
- "AIzaSy...02"
- "AIzaSy...03"
- "AIzaSy...04"

View File

@@ -2,14 +2,13 @@ package api
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
@@ -21,20 +20,24 @@ var (
lastUsedClientIndex = 0 lastUsedClientIndex = 0
) )
// APIHandlers contains the handlers for API endpoints // APIHandlers contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service.
type APIHandlers struct { type APIHandlers struct {
cliClients []*client.Client cliClients []*client.Client
debug bool cfg *config.Config
} }
// NewAPIHandlers creates a new API handlers instance // NewAPIHandlers creates a new API handlers instance.
func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers { // It takes a slice of clients and a debug flag as input.
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
return &APIHandlers{ return &APIHandlers{
cliClients: cliClients, cliClients: cliClients,
debug: debug, cfg: cfg,
} }
} }
// Models handles the /v1/models endpoint.
// It returns a hardcoded list of available AI models.
func (h *APIHandlers) Models(c *gin.Context) { func (h *APIHandlers) Models(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"data": []map[string]any{ "data": []map[string]any{
@@ -62,7 +65,7 @@ func (h *APIHandlers) Models(c *gin.Context) {
"id": "gemini-2.5-pro-preview-06-05", "id": "gemini-2.5-pro-preview-06-05",
"object": "model", "object": "model",
"version": "2.5-preview-06-05", "version": "2.5-preview-06-05",
"name": "Gemini 2.5 Pro Preview", "name": "Gemini 2.5 Pro Preview 06-05",
"description": "Preview release (June 5th, 2025) of Gemini 2.5 Pro", "description": "Preview release (June 5th, 2025) of Gemini 2.5 Pro",
"context_length": 1048576, "context_length": 1048576,
"max_completion_tokens": 65536, "max_completion_tokens": 65536,
@@ -162,15 +165,23 @@ func (h *APIHandlers) Models(c *gin.Context) {
}) })
} }
// ChatCompletions handles the /v1/chat/completions endpoint // ChatCompletions handles the /v1/chat/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler.
func (h *APIHandlers) ChatCompletions(c *gin.Context) { func (h *APIHandlers) ChatCompletions(c *gin.Context) {
rawJson, err := c.GetRawData() rawJson, err := c.GetRawData()
// If data retrieval fails, return 400 error // If data retrieval fails, return a 400 Bad Request error.
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request: %v", err), "code": 400}) c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return return
} }
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJson, "stream") streamResult := gjson.GetBytes(rawJson, "stream")
if streamResult.Type == gjson.True { if streamResult.Type == gjson.True {
h.handleStreamingResponse(c, rawJson) h.handleStreamingResponse(c, rawJson)
@@ -179,184 +190,9 @@ func (h *APIHandlers) ChatCompletions(c *gin.Context) {
} }
} }
func (h *APIHandlers) prepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) { // handleNonStreamingResponse handles non-streaming chat completion responses.
// log.Debug(string(rawJson)) // It selects a client from the pool, sends the request, and aggregates the response
modelName := "gemini-2.5-pro" // before sending it back to the client.
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
contents := make([]client.Content, 0)
messagesResult := gjson.GetBytes(rawJson, "messages")
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
contentResult := messageResult.Get("content")
if roleResult.Type == gjson.String {
if roleResult.String() == "system" {
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
}
}
}
} else if roleResult.String() == "user" {
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
}
}
} else if contentResult.IsArray() {
contentItemResults := contentResult.Array()
parts := make([]client.Part, 0)
for j := 0; j < len(contentItemResults); j++ {
contentItemResult := contentItemResults[j]
contentTypeResult := contentItemResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentItemResult.Get("text")
if contentTextResult.Type == gjson.String {
parts = append(parts, client.Part{Text: contentTextResult.String()})
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image_url" {
imageURLResult := contentItemResult.Get("image_url.url")
if imageURLResult.Type == gjson.String {
imageURL := imageURLResult.String()
if len(imageURL) > 5 {
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
if len(imageURLs) == 2 {
if len(imageURLs[1]) > 7 {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: imageURLs[0],
Data: imageURLs[1][7:],
}})
}
}
}
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "file" {
filenameResult := contentItemResult.Get("file.filename")
fileDataResult := contentItemResult.Get("file.file_data")
if filenameResult.Type == gjson.String && fileDataResult.Type == gjson.String {
filename := filenameResult.String()
splitFilename := strings.Split(filename, ".")
ext := splitFilename[len(splitFilename)-1]
mimeType, ok := MimeTypes[ext]
if !ok {
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
continue
}
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: mimeType,
Data: fileDataResult.String(),
}})
}
}
}
contents = append(contents, client.Content{Role: "user", Parts: parts})
}
} else if roleResult.String() == "assistant" {
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
}
}
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
tcsResult := toolCallsResult.Array()
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]
functionNameResult := tcResult.Get("function.name")
functionArguments := tcResult.Get("function.arguments")
if functionNameResult.Exists() && functionNameResult.Type == gjson.String && functionArguments.Exists() && functionArguments.Type == gjson.String {
var args map[string]any
err := json.Unmarshal([]byte(functionArguments.String()), &args)
if err == nil {
contents = append(contents, client.Content{
Role: "model", Parts: []client.Part{
{
FunctionCall: &client.FunctionCall{
Name: functionNameResult.String(),
Args: args,
},
},
},
})
}
}
}
}
}
} else if roleResult.String() == "tool" {
toolCallIDResult := messageResult.Get("tool_call_id")
if toolCallIDResult.Exists() && toolCallIDResult.Type == gjson.String {
if contentResult.Type == gjson.String {
functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}}
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}}
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
}
}
}
}
}
}
}
}
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolTypeResult := toolsResults[i].Get("type")
if toolTypeResult.Type != gjson.String || toolTypeResult.String() != "function" {
continue
}
functionTypeResult := toolsResults[i].Get("function")
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
var functionDeclaration any
err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration)
if err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, contents, tools
}
// handleNonStreamingResponse handles non-streaming responses
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) { func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
@@ -372,7 +208,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
return return
} }
modelName, contents, tools := h.prepareRequest(rawJson) modelName, contents, tools := translator.PrepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client var cliClient *client.Client
defer func() { defer func() {
@@ -381,69 +217,74 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
} }
}() }()
// Lock the mutex to update the last used page index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the pages to start from the last used index
reorderedPages := make([]*client.Client, len(h.cliClients))
for i := 0; i < len(h.cliClients); i++ {
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
}
locked := false
for i := 0; i < len(reorderedPages); i++ {
cliClient = reorderedPages[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
log.Debugf("Request use account: %s, project id: %s", cliClient.Email, cliClient.ProjectID)
jsonTemplate := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
for { for {
select { // Lock the mutex to update the last used client index
case <-c.Request.Context().Done(): mutex.Lock()
if c.Request.Context().Err().Error() == "context canceled" { startIndex := lastUsedClientIndex
log.Debugf("Client disconnected: %v", c.Request.Context().Err()) currentIndex := (startIndex + 1) % len(h.cliClients)
cliCancel() lastUsedClientIndex = currentIndex
return mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
} }
case chunk, okStream := <-respChan: reorderedClients = append(reorderedClients, cliClient)
if !okStream { }
_, _ = fmt.Fprint(c.Writer, jsonTemplate)
flusher.Flush() if len(reorderedClients) == 0 {
cliCancel() c.Status(429)
return _, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
flusher.Flush()
cliCancel()
return
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
isGlAPIKey = true
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else { } else {
jsonTemplate = h.convertCliToOpenAINonStream(jsonTemplate, chunk)
}
case err, okError := <-errChan:
if okError {
c.Status(err.StatusCode) c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error()) _, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush() flusher.Flush()
// c.JSON(http.StatusInternalServerError, ErrorResponse{
// Error: ErrorDetail{
// Message: err.Error(),
// Type: "server_error",
// },
// })
cliCancel() cliCancel()
return
} }
case <-time.After(500 * time.Millisecond): break
_, _ = c.Writer.Write([]byte("\n")) } else {
flusher.Flush() openAIFormat := translator.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
flusher.Flush()
}
cliCancel()
break
} }
} }
} }
@@ -455,7 +296,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Origin", "*")
// Handle streaming manually // Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{ c.JSON(http.StatusInternalServerError, ErrorResponse{
@@ -466,265 +307,116 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
}) })
return return
} }
modelName, contents, tools := h.prepareRequest(rawJson)
// Prepare the request for the backend client.
modelName, contents, tools := translator.PrepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client var cliClient *client.Client
defer func() { defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil { if cliClient != nil {
cliClient.RequestMutex.Unlock() cliClient.RequestMutex.Unlock()
} }
}() }()
// Lock the mutex to update the last used page index outLoop:
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the pages to start from the last used index
reorderedPages := make([]*client.Client, len(h.cliClients))
for i := 0; i < len(h.cliClients); i++ {
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
}
locked := false
for i := 0; i < len(reorderedPages); i++ {
cliClient = reorderedPages[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
log.Debugf("Request use account: %s, project id: %s", cliClient.Email, cliClient.ProjectID)
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
for { for {
select { // Lock the mutex to update the last used client index
case <-c.Request.Context().Done(): mutex.Lock()
if c.Request.Context().Err().Error() == "context canceled" { startIndex := lastUsedClientIndex
log.Debugf("Client disconnected: %v", c.Request.Context().Err()) currentIndex := (startIndex + 1) % len(h.cliClients)
cliCancel() lastUsedClientIndex = currentIndex
return mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
} }
case chunk, okStream := <-respChan: reorderedClients = append(reorderedClients, cliClient)
if !okStream { }
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush() if len(reorderedClients) == 0 {
cliCancel() c.Status(429)
return _, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
} else { flusher.Flush()
openAIFormat := h.convertCliToOpenAI(chunk) cliCancel()
if openAIFormat != "" { return
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) }
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
isGlAPIKey = true
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
hasFirstResponse := false
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
// Stream is closed, send the final [DONE] message.
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel()
return
} else {
// Convert the chunk to OpenAI format and send it to the client.
hasFirstResponse = true
openAIFormat := translator.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
flusher.Flush()
}
}
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
flusher.Flush() flusher.Flush()
} }
} }
case err, okError := <-errChan:
if okError {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
// c.JSON(http.StatusInternalServerError, ErrorResponse{
// Error: ErrorDetail{
// Message: err.Error(),
// Type: "server_error",
// },
// })
cliCancel()
return
}
case <-time.After(500 * time.Millisecond):
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
flusher.Flush()
} }
} }
} }
func (h *APIHandlers) convertCliToOpenAI(rawJson []byte) string {
// log.Debugf(string(rawJson))
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion")
if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
createTimeResult := gjson.GetBytes(rawJson, "response.createTime")
if createTimeResult.Exists() && createTimeResult.Type == gjson.String {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
var unixTimestamp int64
if err == nil {
unixTimestamp = t.Unix()
} else {
unixTimestamp = time.Now().Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
}
responseIdResult := gjson.GetBytes(rawJson, "response.responseId")
if responseIdResult.Exists() && responseIdResult.Type == gjson.String {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason")
if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
candidatesTokenCountResult := usageResult.Get("candidatesTokenCount")
if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
totalTokenCountResult := usageResult.Get("totalTokenCount")
if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount")
promptTokenCountResult := usageResult.Get("promptTokenCount")
if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number {
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int())
} else {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int())
}
}
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int())
}
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() && partTextResult.Type == gjson.String {
partThoughtResult := partResult.Get("thought")
if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
fcNameResult := functionCallResult.Get("name")
if fcNameResult.Exists() && fcNameResult.Type == gjson.String {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcNameResult.String())
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcNameResult.String())
}
fcArgsResult := functionCallResult.Get("args")
if fcArgsResult.Exists() && fcArgsResult.IsObject() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
} else {
return ""
}
return template
}
func (h *APIHandlers) convertCliToOpenAINonStream(template string, rawJson []byte) string {
modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion")
if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
createTimeResult := gjson.GetBytes(rawJson, "response.createTime")
if createTimeResult.Exists() && createTimeResult.Type == gjson.String {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
var unixTimestamp int64
if err == nil {
unixTimestamp = t.Unix()
} else {
unixTimestamp = time.Now().Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
}
responseIdResult := gjson.GetBytes(rawJson, "response.responseId")
if responseIdResult.Exists() && responseIdResult.Type == gjson.String {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason")
if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
candidatesTokenCountResult := usageResult.Get("candidatesTokenCount")
if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
totalTokenCountResult := usageResult.Get("totalTokenCount")
if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount")
promptTokenCountResult := usageResult.Get("promptTokenCount")
if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number {
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int())
} else {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int())
}
}
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int())
}
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() && partTextResult.Type == gjson.String {
partThoughtResult := partResult.Get("thought")
if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True {
reasoningContentResult := gjson.Get(template, "choices.0.message.reasoning_content")
if reasoningContentResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningContentResult.String()+partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
}
} else {
reasoningContentResult := gjson.Get(template, "choices.0.message.content")
if reasoningContentResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.message.content", reasoningContentResult.String()+partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
}
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
} else if functionCallResult.Exists() {
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
if !toolCallsResult.Exists() || toolCallsResult.Type == gjson.Null {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcNameResult := functionCallResult.Get("name")
if fcNameResult.Exists() && fcNameResult.Type == gjson.String {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcNameResult.String())
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcNameResult.String())
}
fcArgsResult := functionCallResult.Get("args")
if fcArgsResult.Exists() && fcArgsResult.IsObject() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
} else {
return ""
}
return template
}

View File

@@ -1,13 +1,18 @@
package api package api
// ErrorResponse represents an error response // ErrorResponse represents a standard error response format for the API.
// It contains a single ErrorDetail field.
type ErrorResponse struct { type ErrorResponse struct {
Error ErrorDetail `json:"error"` Error ErrorDetail `json:"error"`
} }
// ErrorDetail represents error details // ErrorDetail provides specific information about an error that occurred.
// It includes a human-readable message, an error type, and an optional error code.
type ErrorDetail struct { type ErrorDetail struct {
// A human-readable message providing more details about the error.
Message string `json:"message"` Message string `json:"message"`
Type string `json:"type"` // The type of error that occurred (e.g., "invalid_request_error").
Code string `json:"code,omitempty"` Type string `json:"type"`
// A short code identifying the error, if applicable.
Code string `json:"code,omitempty"`
} }

View File

@@ -6,35 +6,31 @@ import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net/http" "net/http"
"strings" "strings"
) )
// Server represents the API server // Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct { type Server struct {
engine *gin.Engine engine *gin.Engine
server *http.Server server *http.Server
handlers *APIHandlers handlers *APIHandlers
cfg *ServerConfig cfg *config.Config
} }
// ServerConfig contains configuration for the API server // NewServer creates and initializes a new API server instance.
type ServerConfig struct { // It sets up the Gin engine, middleware, routes, and handlers.
Port string func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
Debug bool
ApiKeys []string
}
// NewServer creates a new API server instance
func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
// Set gin mode // Set gin mode
if !config.Debug { if !cfg.Debug {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
// Create handlers // Create handlers
handlers := NewAPIHandlers(cliClients, config.Debug) handlers := NewAPIHandlers(cliClients, cfg)
// Create gin engine // Create gin engine
engine := gin.New() engine := gin.New()
@@ -48,7 +44,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
s := &Server{ s := &Server{
engine: engine, engine: engine,
handlers: handlers, handlers: handlers,
cfg: config, cfg: cfg,
} }
// Setup routes // Setup routes
@@ -56,14 +52,15 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
// Create HTTP server // Create HTTP server
s.server = &http.Server{ s.server = &http.Server{
Addr: ":" + config.Port, Addr: fmt.Sprintf(":%d", cfg.Port),
Handler: engine, Handler: engine,
} }
return s return s
} }
// setupRoutes configures the API routes // setupRoutes configures the API routes for the server.
// It defines the endpoints and associates them with their respective handlers.
func (s *Server) setupRoutes() { func (s *Server) setupRoutes() {
// OpenAI compatible API routes // OpenAI compatible API routes
v1 := s.engine.Group("/v1") v1 := s.engine.Group("/v1")
@@ -86,11 +83,12 @@ func (s *Server) setupRoutes() {
}) })
} }
// Start starts the API server // Start begins listening for and serving HTTP requests.
// It's a blocking call and will only return on an unrecoverable error.
func (s *Server) Start() error { func (s *Server) Start() error {
log.Debugf("Starting API server on %s", s.server.Addr) log.Debugf("Starting API server on %s", s.server.Addr)
// Start the HTTP server // Start the HTTP server.
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start HTTP server: %v", err) return fmt.Errorf("failed to start HTTP server: %v", err)
} }
@@ -98,11 +96,12 @@ func (s *Server) Start() error {
return nil return nil
} }
// Stop gracefully stops the API server // Stop gracefully shuts down the API server without interrupting any
// active connections.
func (s *Server) Stop(ctx context.Context) error { func (s *Server) Stop(ctx context.Context) error {
log.Debug("Stopping API server...") log.Debug("Stopping API server...")
// Shutdown the HTTP server // Shutdown the HTTP server.
if err := s.server.Shutdown(ctx); err != nil { if err := s.server.Shutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown HTTP server: %v", err) return fmt.Errorf("failed to shutdown HTTP server: %v", err)
} }
@@ -111,7 +110,8 @@ func (s *Server) Stop(ctx context.Context) error {
return nil return nil
} }
// corsMiddleware adds CORS headers // corsMiddleware returns a Gin middleware handler that adds CORS headers
// to every response, allowing cross-origin requests.
func corsMiddleware() gin.HandlerFunc { func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Origin", "*")
@@ -127,8 +127,9 @@ func corsMiddleware() gin.HandlerFunc {
} }
} }
// AuthMiddleware authenticates requests using API keys // AuthMiddleware returns a Gin middleware handler that authenticates requests
func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc { // using API keys. If no API keys are configured, it allows all requests.
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if len(cfg.ApiKeys) == 0 { if len(cfg.ApiKeys) == 0 {
c.Next() c.Next()

View File

@@ -1,5 +1,7 @@
package api package translator
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.
// This is used to identify the type of file being uploaded or processed.
var MimeTypes = map[string]string{ var MimeTypes = map[string]string{
"ez": "application/andrew-inset", "ez": "application/andrew-inset",
"aw": "application/applixware", "aw": "application/applixware",

View File

@@ -0,0 +1,163 @@
package translator
import (
"encoding/json"
"strings"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
// to the internal format expected by the backend client. It parses messages,
// roles, content types (text, image, file), and tool calls.
func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) {
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
// Process the array of messages.
contents := make([]client.Content, 0)
messagesResult := gjson.GetBytes(rawJson, "messages")
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
contentResult := messageResult.Get("content")
if roleResult.Type != gjson.String {
continue
}
switch roleResult.String() {
// System messages are converted to a user message followed by a model's acknowledgment.
case "system":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
} else if contentResult.IsObject() {
// Handle object-based system messages.
if contentResult.Get("type").String() == "text" {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
}
}
// User messages can contain simple text or a multi-part body.
case "user":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsArray() {
// Handle multi-part user messages (text, images, files).
contentItemResults := contentResult.Array()
parts := make([]client.Part, 0)
for j := 0; j < len(contentItemResults); j++ {
contentItemResult := contentItemResults[j]
contentTypeResult := contentItemResult.Get("type")
switch contentTypeResult.String() {
case "text":
parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()})
case "image_url":
// Parse data URI for images.
imageURL := contentItemResult.Get("image_url.url").String()
if len(imageURL) > 5 {
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
if len(imageURLs) == 2 && len(imageURLs[1]) > 7 {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: imageURLs[0],
Data: imageURLs[1][7:],
}})
}
}
case "file":
// Handle file attachments by determining MIME type from extension.
filename := contentItemResult.Get("file.filename").String()
fileData := contentItemResult.Get("file.file_data").String()
ext := ""
if split := strings.Split(filename, "."); len(split) > 1 {
ext = split[len(split)-1]
}
if mimeType, ok := MimeTypes[ext]; ok {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: mimeType,
Data: fileData,
}})
} else {
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
}
}
}
contents = append(contents, client.Content{Role: "user", Parts: parts})
}
// Assistant messages can contain text or tool calls.
case "assistant":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
// Handle tool calls made by the assistant.
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
tcsResult := toolCallsResult.Array()
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]
functionName := tcResult.Get("function.name").String()
functionArgs := tcResult.Get("function.arguments").String()
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
contents = append(contents, client.Content{
Role: "model", Parts: []client.Part{{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
}},
})
}
}
}
}
// Tool messages contain the output of a tool call.
case "tool":
toolCallID := messageResult.Get("tool_call_id").String()
if toolCallID != "" {
var responseData string
if contentResult.Type == gjson.String {
responseData = contentResult.String()
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
responseData = contentResult.Get("text").String()
}
functionResponse := client.FunctionResponse{Name: toolCallID, Response: map[string]interface{}{"result": responseData}}
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
}
}
}
}
// Translate the tool declarations from the request.
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
if toolResult.Get("type").String() == "function" {
functionTypeResult := toolResult.Get("function")
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
var functionDeclaration any
if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
}
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, contents, tools
}

View File

@@ -0,0 +1,181 @@
package translator
import (
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
// backend client format to the OpenAI Server-Sent Events (SSE) format.
// It returns an empty string if the chunk contains no useful data.
func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
}
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
// Extract and set the creation timestamp.
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
}
// Extract and set the response ID.
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcName)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
} else {
// If no usable content is found, return an empty string.
return ""
}
return template
}
// ConvertCliToOpenAINonStream aggregates response from the backend client
// convert a single, non-streaming OpenAI-compatible JSON response.
func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
}
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
}
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partsResults := partsResult.Array()
for i := 0; i < len(partsResults); i++ {
partResult := partsResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Append text content, distinguishing between regular content and reasoning.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
} else if functionCallResult.Exists() {
// Append function call content to the tool_calls array.
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
} else {
// If no usable content is found, return an empty string.
return ""
}
}
}
return template
}

View File

@@ -5,17 +5,18 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"golang.org/x/net/proxy"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
"github.com/tidwall/gjson"
"golang.org/x/net/proxy"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
) )
@@ -33,84 +34,78 @@ var (
} }
) )
type TokenStorage struct { // GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls.
Token any `json:"token"` // It manages the entire OAuth2 flow, including handling proxies, loading existing tokens,
ProjectID string `json:"project_id"` // initiating a new web-based OAuth flow if necessary, and refreshing tokens.
Email string `json:"email"`
Auto bool `json:"auto"`
Checked bool `json:"checked"`
}
// GetAuthenticatedClient configures and returns an HTTP client with OAuth2 tokens.
// It handles the entire flow: loading, refreshing, and fetching new tokens.
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) { func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
// Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyUrl) proxyURL, err := url.Parse(cfg.ProxyUrl)
if err == nil { if err == nil {
var transport *http.Transport
if proxyURL.Scheme == "socks5" { if proxyURL.Scheme == "socks5" {
// Handle SOCKS5 proxy.
username := proxyURL.User.Username() username := proxyURL.User.Username()
password, _ := proxyURL.User.Password() password, _ := proxyURL.User.Password()
auth := &proxy.Auth{ auth := &proxy.Auth{User: username, Password: password}
User: username,
Password: password,
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if errSOCKS5 != nil { if errSOCKS5 != nil {
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5) log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
} }
transport = &http.Transport{
transport := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
return dialer.Dial(network, addr) return dialer.Dial(network, addr)
}, },
} }
proxyClient := &http.Client{
Transport: transport,
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
transport := &http.Transport{ // Handle HTTP/HTTPS proxy.
Proxy: http.ProxyURL(proxyURL), transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
} }
proxyClient := &http.Client{
Transport: transport, if transport != nil {
} proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
} }
} }
// Configure the OAuth2 client.
conf := &oauth2.Config{ conf := &oauth2.Config{
ClientID: oauthClientID, ClientID: oauthClientID,
ClientSecret: oauthClientSecret, ClientSecret: oauthClientSecret,
RedirectURL: "http://localhost:8085/oauth2callback", // Placeholder, will be updated RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
Scopes: oauthScopes, Scopes: oauthScopes,
Endpoint: google.Endpoint, Endpoint: google.Endpoint,
} }
var token *oauth2.Token var token *oauth2.Token
// If no token is found in storage, initiate the web-based OAuth flow.
if ts.Token == nil { if ts.Token == nil {
log.Info("Could not load token from file, starting OAuth flow.") log.Info("Could not load token from file, starting OAuth flow.")
token, err = getTokenFromWeb(ctx, conf) token, err = getTokenFromWeb(ctx, conf)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get token from web: %w", err) return nil, fmt.Errorf("failed to get token from web: %w", err)
} }
newTs, errSaveTokenToFile := createTokenStorage(ctx, conf, token, ts.ProjectID) // After getting a new token, create a new token storage object with user info.
if errSaveTokenToFile != nil { newTs, errCreateTokenStorage := createTokenStorage(ctx, conf, token, ts.ProjectID)
log.Errorf("Warning: failed to save token to file: %v", err) if errCreateTokenStorage != nil {
return nil, errSaveTokenToFile log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage)
return nil, errCreateTokenStorage
} }
*ts = *newTs *ts = *newTs
} }
// Unmarshal the stored token into an oauth2.Token object.
tsToken, _ := json.Marshal(ts.Token) tsToken, _ := json.Marshal(ts.Token)
if err = json.Unmarshal(tsToken, &token); err != nil { if err = json.Unmarshal(tsToken, &token); err != nil {
return nil, err return nil, fmt.Errorf("failed to unmarshal token: %w", err)
} }
// Return an HTTP client that automatically handles token refreshing.
return conf.Client(ctx, token), nil return conf.Client(ctx, token), nil
} }
// createTokenStorage creates a token storage. // createTokenStorage creates a new TokenStorage object. It fetches the user's email
// using the provided token and populates the storage structure.
func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) { func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) {
httpClient := config.Client(ctx, token) httpClient := config.Client(ctx, token)
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
@@ -125,7 +120,9 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
return nil, fmt.Errorf("failed to execute request: %w", err) return nil, fmt.Errorf("failed to execute request: %w", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}() }()
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
@@ -162,7 +159,10 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
return &ts, nil return &ts, nil
} }
// getTokenFromWeb starts a local server to handle the OAuth2 flow. // getTokenFromWeb initiates the web-based OAuth2 authorization flow.
// It starts a local HTTP server to listen for the callback from Google's auth server,
// opens the user's browser to the authorization URL, and exchanges the received
// authorization code for an access token.
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) { func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
// Use a channel to pass the authorization code from the HTTP handler to the main function. // Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string) codeChan := make(chan string)
@@ -199,7 +199,8 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token,
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n", authURL) log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n", authURL)
err := open.Run(authURL) var err error
err = open.Run(authURL)
if err != nil { if err != nil {
log.Errorf("Failed to open browser: %v. Please open the URL manually.", err) log.Errorf("Failed to open browser: %v. Please open the URL manually.", err)
} }

17
internal/auth/models.go Normal file
View File

@@ -0,0 +1,17 @@
package auth
// TokenStorage defines the structure for storing OAuth2 token information,
// along with associated user and project details. This data is typically
// serialized to a JSON file for persistence.
type TokenStorage struct {
// Token holds the raw OAuth2 token data, including access and refresh tokens.
Token any `json:"token"`
// ProjectID is the Google Cloud Project ID associated with this token.
ProjectID string `json:"project_id"`
// Email is the email address of the authenticated user.
Email string `json:"email"`
// Auto indicates if the project ID was automatically selected.
Auto bool `json:"auto"`
// Checked indicates if the associated Cloud AI API has been verified as enabled.
Checked bool `json:"checked"`
}

View File

@@ -6,12 +6,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/oauth2"
"io" "io"
"net/http" "net/http"
"os" "os"
@@ -20,108 +14,53 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/oauth2"
) )
// --- Constants ---
const ( const (
codeAssistEndpoint = "https://cloudcode-pa.googleapis.com" codeAssistEndpoint = "https://cloudcode-pa.googleapis.com"
apiVersion = "v1internal" apiVersion = "v1internal"
pluginVersion = "1.0.0" pluginVersion = "0.1.9"
glEndPoint = "https://generativelanguage.googleapis.com/"
glApiVersion = "v1beta"
) )
type ErrorMessage struct { var (
StatusCode int previewModels = map[string][]string{
Error error "gemini-2.5-pro": {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"},
} "gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"},
}
type GCPProject struct { )
Projects []GCPProjectProjects `json:"projects"`
}
type GCPProjectLabels struct {
GenerativeLanguage string `json:"generative-language"`
}
type GCPProjectProjects struct {
ProjectNumber string `json:"projectNumber"`
ProjectID string `json:"projectId"`
LifecycleState string `json:"lifecycleState"`
Name string `json:"name"`
Labels GCPProjectLabels `json:"labels"`
CreateTime time.Time `json:"createTime"`
}
type Content struct {
Role string `json:"role"`
Parts []Part `json:"parts"`
}
// Part represents a single part of a message's content.
type Part struct {
Text string `json:"text,omitempty"`
InlineData *InlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
}
type InlineData struct {
MimeType string `json:"mime_type,omitempty"`
Data string `json:"data,omitempty"`
}
// FunctionCall represents a tool call requested by the model.
type FunctionCall struct {
Name string `json:"name"`
Args map[string]interface{} `json:"args"`
}
// FunctionResponse represents the result of a tool execution.
type FunctionResponse struct {
Name string `json:"name"`
Response map[string]interface{} `json:"response"`
}
// GenerateContentRequest is the request payload for the streamGenerateContent endpoint.
type GenerateContentRequest struct {
Contents []Content `json:"contents"`
Tools []ToolDeclaration `json:"tools,omitempty"`
GenerationConfig `json:"generationConfig"`
}
// GenerationConfig defines model generation parameters.
type GenerationConfig struct {
ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
// Temperature, TopP, TopK, etc. can be added here.
}
type GenerationConfigThinkingConfig struct {
IncludeThoughts bool `json:"include_thoughts,omitempty"`
}
// ToolDeclaration is the structure for declaring tools to the API.
// For now, we'll assume a simple structure. A more complete implementation
// would mirror the OpenAPI schema definition.
type ToolDeclaration struct {
FunctionDeclarations []interface{} `json:"functionDeclarations"`
}
// Client is the main client for interacting with the CLI API. // Client is the main client for interacting with the CLI API.
type Client struct { type Client struct {
httpClient *http.Client httpClient *http.Client
ProjectID string RequestMutex sync.Mutex
RequestMutex sync.Mutex tokenStorage *auth.TokenStorage
Email string cfg *config.Config
tokenStorage *auth.TokenStorage modelQuotaExceeded map[string]*time.Time
cfg *config.Config glAPIKey string
} }
// NewClient creates a new CLI API client. // NewClient creates a new CLI API client.
func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config) *Client { func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config, glAPIKey ...string) *Client {
var glKey string
if len(glAPIKey) > 0 {
glKey = glAPIKey[0]
}
return &Client{ return &Client{
httpClient: httpClient, httpClient: httpClient,
tokenStorage: ts, tokenStorage: ts,
cfg: cfg, cfg: cfg,
modelQuotaExceeded: make(map[string]*time.Time),
glAPIKey: glKey,
} }
} }
@@ -145,9 +84,24 @@ func (c *Client) IsAuto() bool {
return c.tokenStorage.Auto return c.tokenStorage.Auto
} }
func (c *Client) GetEmail() string {
return c.tokenStorage.Email
}
func (c *Client) GetProjectID() string {
if c.tokenStorage != nil {
return c.tokenStorage.ProjectID
}
return ""
}
func (c *Client) GetGenerativeLanguageAPIKey() string {
return c.glAPIKey
}
// SetupUser performs the initial user onboarding and setup. // SetupUser performs the initial user onboarding and setup.
func (c *Client) SetupUser(ctx context.Context, email, projectID string) (string, error) { func (c *Client) SetupUser(ctx context.Context, email, projectID string) error {
c.Email = email c.tokenStorage.Email = email
log.Info("Performing user onboarding...") log.Info("Performing user onboarding...")
// 1. LoadCodeAssist // 1. LoadCodeAssist
@@ -161,7 +115,7 @@ func (c *Client) SetupUser(ctx context.Context, email, projectID string) (string
var loadAssistResp map[string]interface{} var loadAssistResp map[string]interface{}
err := c.makeAPIRequest(ctx, "loadCodeAssist", "POST", loadAssistReqBody, &loadAssistResp) err := c.makeAPIRequest(ctx, "loadCodeAssist", "POST", loadAssistReqBody, &loadAssistResp)
if err != nil { if err != nil {
return projectID, fmt.Errorf("failed to load code assist: %w", err) return fmt.Errorf("failed to load code assist: %w", err)
} }
// a, _ := json.Marshal(&loadAssistResp) // a, _ := json.Marshal(&loadAssistResp)
@@ -197,14 +151,14 @@ func (c *Client) SetupUser(ctx context.Context, email, projectID string) (string
if onboardProjectID != "" { if onboardProjectID != "" {
onboardReqBody["cloudaicompanionProject"] = onboardProjectID onboardReqBody["cloudaicompanionProject"] = onboardProjectID
} else { } else {
return projectID, fmt.Errorf("failed to start user onboarding, need define a project id") return fmt.Errorf("failed to start user onboarding, need define a project id")
} }
for { for {
var lroResp map[string]interface{} var lroResp map[string]interface{}
err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp) err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp)
if err != nil { if err != nil {
return projectID, fmt.Errorf("failed to start user onboarding: %w", err) return fmt.Errorf("failed to start user onboarding: %w", err)
} }
// a, _ := json.Marshal(&lroResp) // a, _ := json.Marshal(&lroResp)
// log.Debug(string(a)) // log.Debug(string(a))
@@ -214,12 +168,12 @@ func (c *Client) SetupUser(ctx context.Context, email, projectID string) (string
if doneOk && done { if doneOk && done {
if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk { if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk {
if projectID != "" { if projectID != "" {
c.ProjectID = projectID c.tokenStorage.ProjectID = projectID
} else { } else {
c.ProjectID = project["id"].(string) c.tokenStorage.ProjectID = project["id"].(string)
} }
log.Infof("Onboarding complete. Using Project ID: %s", c.ProjectID) log.Infof("Onboarding complete. Using Project ID: %s", c.tokenStorage.ProjectID)
return c.ProjectID, nil return nil
} }
} else { } else {
log.Println("Onboarding in progress, waiting 5 seconds...") log.Println("Onboarding in progress, waiting 5 seconds...")
@@ -266,7 +220,9 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
return fmt.Errorf("failed to execute request: %w", err) return fmt.Errorf("failed to execute request: %w", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
@@ -283,8 +239,8 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
return nil return nil
} }
// StreamAPIRequest handles making streaming requests to the CLI API endpoints. // APIRequest handles making requests to the CLI API endpoints.
func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body interface{}) (io.ReadCloser, *ErrorMessage) { func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
var jsonBody []byte var jsonBody []byte
var err error var err error
if byteBody, ok := body.([]byte); ok { if byteBody, ok := body.([]byte); ok {
@@ -295,47 +251,156 @@ func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body int
return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)} return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)}
} }
} }
var url string
if c.glAPIKey == "" {
// Add alt=sse for streaming
url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
if stream {
url = url + "?alt=sse"
}
} else {
modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
if stream {
url = url + "?alt=sse"
}
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
}
// log.Debug(string(jsonBody)) // log.Debug(string(jsonBody))
reqBody := bytes.NewBuffer(jsonBody) reqBody := bytes.NewBuffer(jsonBody)
// Add alt=sse for streaming
url := fmt.Sprintf("%s/%s:%s?alt=sse", codeAssistEndpoint, apiVersion, endpoint)
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody) req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
if err != nil { if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %w", err)} return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err)}
}
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %w", err)}
} }
// Set headers // Set headers
metadataStr := getClientMetadataString() metadataStr := getClientMetadataString()
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", getUserAgent()) if c.glAPIKey == "" {
req.Header.Set("Client-Metadata", metadataStr) token, errToken := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) if errToken != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %v", errToken)}
}
req.Header.Set("User-Agent", getUserAgent())
req.Header.Set("Client-Metadata", metadataStr)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
} else {
req.Header.Set("x-goog-api-key", c.glAPIKey)
}
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %w", err)} return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err)}
} }
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() { defer func() {
_ = resp.Body.Close() if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}() }()
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))} return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
// return nil, fmt.Errorf("api streaming request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
} }
return resp.Body, nil return resp.Body, nil
} }
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
ThinkingConfig: GenerationConfigThinkingConfig{
IncludeThoughts: true,
},
},
}
request.Tools = tools
requestBody := map[string]interface{}{
"project": c.GetProjectID(), // Assuming ProjectID is available
"request": request,
"model": model,
}
byteRequestBody, _ := json.Marshal(requestBody)
// log.Debug(string(byteRequestBody))
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
if reasoningEffortResult.String() == "none" {
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
} else if reasoningEffortResult.String() == "auto" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
} else if reasoningEffortResult.String() == "low" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
} else if reasoningEffortResult.String() == "medium" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
} else if reasoningEffortResult.String() == "high" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
} else {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
temperatureResult := gjson.GetBytes(rawJson, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
}
topPResult := gjson.GetBytes(rawJson, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
}
topKResult := gjson.GetBytes(rawJson, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
}
modelName := model
// log.Debug(string(byteRequestBody))
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
continue
}
}
return nil, &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
}
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
continue
}
}
return nil, err
}
delete(c.modelQuotaExceeded, modelName)
bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil {
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
}
return bodyBytes, nil
}
}
// SendMessageStream handles a single conversational turn, including tool calls. // SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) { func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ") dataTag := []byte("data: ")
@@ -356,7 +421,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
request.Tools = tools request.Tools = tools
requestBody := map[string]interface{}{ requestBody := map[string]interface{}{
"project": c.ProjectID, // Assuming ProjectID is available "project": c.GetProjectID(), // Assuming ProjectID is available
"request": request, "request": request,
"model": model, "model": model,
} }
@@ -397,12 +462,39 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
} }
// log.Debug(string(byteRequestBody)) // log.Debug(string(byteRequestBody))
modelName := model
stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", byteRequestBody) var stream io.ReadCloser
if err != nil { for {
// log.Println(err) if c.isModelQuotaExceeded(modelName) {
errChan <- err if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
return modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
continue
}
}
errChan <- &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
return
}
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
continue
}
}
errChan <- err
return
}
delete(c.modelQuotaExceeded, modelName)
break
} }
scanner := bufio.NewScanner(stream) scanner := bufio.NewScanner(stream)
@@ -427,6 +519,41 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
return dataChan, errChan return dataChan, errChan
} }
func (c *Client) isModelQuotaExceeded(model string) bool {
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
duration := time.Now().Sub(*lastExceededTime)
if duration > 30*time.Minute {
return false
}
return true
}
return false
}
func (c *Client) getPreviewModel(model string) string {
if models, hasKey := previewModels[model]; hasKey {
for i := 0; i < len(models); i++ {
if !c.isModelQuotaExceeded(models[i]) {
return models[i]
}
}
}
return ""
}
func (c *Client) IsModelQuotaExceeded(model string) bool {
if c.isModelQuotaExceeded(model) {
if c.cfg.QuotaExceeded.SwitchPreviewModel {
return c.getPreviewModel(model) == ""
}
return true
}
return false
}
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
// that the Cloud AI API is enabled for the user's project. It provides
// an activation URL if the API is disabled.
func (c *Client) CheckCloudAPIIsEnabled() (bool, error) { func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer func() { defer func() {
@@ -435,79 +562,78 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
}() }()
c.RequestMutex.Lock() c.RequestMutex.Lock()
requestBody := `{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}` // A simple request to test the API endpoint.
requestBody = fmt.Sprintf(requestBody, c.tokenStorage.ProjectID) requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
// log.Debug(requestBody)
stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody)) stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), true)
if err != nil { if err != nil {
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
if err.StatusCode == 403 { if err.StatusCode == 403 {
errJson := err.Error.Error() errJson := err.Error.Error()
codeResult := gjson.Get(errJson, "error.code") // Check for a specific error code and extract the activation URL.
if codeResult.Exists() && codeResult.Type == gjson.Number { if gjson.Get(errJson, "error.code").Int() == 403 {
if codeResult.Int() == 403 { activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String()
activationUrlResult := gjson.Get(errJson, "error.details.0.metadata.activationUrl") if activationUrl != "" {
if activationUrlResult.Exists() { log.Warnf(
log.Warnf( "\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s", activationUrl,
activationUrlResult.String(), os.Args[0],
os.Args[0], c.tokenStorage.ProjectID,
c.tokenStorage.ProjectID, )
)
}
} }
} }
return false, nil return false, nil
} }
return false, err.Error return false, err.Error
} }
defer func() {
_ = stream.Close()
}()
// We only need to know if the request was successful, so we can drain the stream.
scanner := bufio.NewScanner(stream) scanner := bufio.NewScanner(stream)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() // Do nothing, just consume the stream.
if !strings.HasPrefix(line, "data: ") {
continue
}
} }
if scannerErr := scanner.Err(); scannerErr != nil { return scanner.Err() == nil, scanner.Err()
_ = stream.Close()
} else {
_ = stream.Close()
}
return true, nil
} }
// GetProjectList fetches a list of Google Cloud projects accessible by the user.
func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) { func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) {
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token() token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
if err != nil {
return nil, fmt.Errorf("failed to get token: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil) req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get project list: %v", err) return nil, fmt.Errorf("could not create project list request: %v", err)
} }
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err) return nil, fmt.Errorf("failed to execute project list request: %w", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = resp.Body.Close()
}() }()
bodyBytes, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
} }
var project GCPProject var project GCPProject
err = json.Unmarshal(bodyBytes, &project) if err = json.NewDecoder(resp.Body).Decode(&project); err != nil {
if err != nil {
return nil, fmt.Errorf("failed to unmarshal project list: %w", err) return nil, fmt.Errorf("failed to unmarshal project list: %w", err)
} }
return &project, nil return &project, nil
} }
// SaveTokenToFile serializes the client's current token storage to a JSON file.
// The filename is constructed from the user's email and project ID.
func (c *Client) SaveTokenToFile() error { func (c *Client) SaveTokenToFile() error {
if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil { if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err) return fmt.Errorf("failed to create directory: %v", err)
@@ -529,7 +655,8 @@ func (c *Client) SaveTokenToFile() error {
return nil return nil
} }
// getClientMetadata returns metadata about the client environment. // getClientMetadata returns a map of metadata about the client environment,
// such as IDE type, platform, and plugin version.
func getClientMetadata() map[string]string { func getClientMetadata() map[string]string {
return map[string]string{ return map[string]string{
"ideType": "IDE_UNSPECIFIED", "ideType": "IDE_UNSPECIFIED",
@@ -539,7 +666,8 @@ func getClientMetadata() map[string]string {
} }
} }
// getClientMetadataString returns the metadata as a comma-separated string. // getClientMetadataString returns the client metadata as a single,
// comma-separated string, which is required for the 'Client-Metadata' header.
func getClientMetadataString() string { func getClientMetadataString() string {
md := getClientMetadata() md := getClientMetadata()
parts := make([]string, 0, len(md)) parts := make([]string, 0, len(md))
@@ -549,11 +677,13 @@ func getClientMetadataString() string {
return strings.Join(parts, ",") return strings.Join(parts, ",")
} }
// getUserAgent constructs the User-Agent string for HTTP requests.
func getUserAgent() string { func getUserAgent() string {
return fmt.Sprintf(fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)) return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)
} }
// getPlatform returns the OS and architecture in the format expected by the API. // getPlatform determines the operating system and architecture and formats
// it into a string expected by the backend API.
func getPlatform() string { func getPlatform() string {
goOS := runtime.GOOS goOS := runtime.GOOS
arch := runtime.GOARCH arch := runtime.GOARCH

90
internal/client/models.go Normal file
View File

@@ -0,0 +1,90 @@
package client
import "time"
// ErrorMessage encapsulates an error with an associated HTTP status code.
type ErrorMessage struct {
StatusCode int
Error error
}
// GCPProject represents the response structure for a Google Cloud project list request.
type GCPProject struct {
Projects []GCPProjectProjects `json:"projects"`
}
// GCPProjectLabels defines the labels associated with a GCP project.
type GCPProjectLabels struct {
GenerativeLanguage string `json:"generative-language"`
}
// GCPProjectProjects contains details about a single Google Cloud project.
type GCPProjectProjects struct {
ProjectNumber string `json:"projectNumber"`
ProjectID string `json:"projectId"`
LifecycleState string `json:"lifecycleState"`
Name string `json:"name"`
Labels GCPProjectLabels `json:"labels"`
CreateTime time.Time `json:"createTime"`
}
// Content represents a single message in a conversation, with a role and parts.
type Content struct {
Role string `json:"role"`
Parts []Part `json:"parts"`
}
// Part represents a distinct piece of content within a message, which can be
// text, inline data (like an image), a function call, or a function response.
type Part struct {
Text string `json:"text,omitempty"`
InlineData *InlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
}
// InlineData represents base64-encoded data with its MIME type.
type InlineData struct {
MimeType string `json:"mime_type,omitempty"`
Data string `json:"data,omitempty"`
}
// FunctionCall represents a tool call requested by the model, including the
// function name and its arguments.
type FunctionCall struct {
Name string `json:"name"`
Args map[string]interface{} `json:"args"`
}
// FunctionResponse represents the result of a tool execution, sent back to the model.
type FunctionResponse struct {
Name string `json:"name"`
Response map[string]interface{} `json:"response"`
}
// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
type GenerateContentRequest struct {
Contents []Content `json:"contents"`
Tools []ToolDeclaration `json:"tools,omitempty"`
GenerationConfig `json:"generationConfig"`
}
// GenerationConfig defines parameters that control the model's generation behavior.
type GenerationConfig struct {
ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
}
// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process.
type GenerationConfigThinkingConfig struct {
// IncludeThoughts determines whether the model should output its reasoning process.
IncludeThoughts bool `json:"include_thoughts,omitempty"`
}
// ToolDeclaration defines the structure for declaring tools (like functions)
// that the model can call.
type ToolDeclaration struct {
FunctionDeclarations []interface{} `json:"functionDeclarations"`
}

View File

@@ -9,6 +9,9 @@ import (
"os" "os"
) )
// DoLogin handles the entire user login and setup process.
// It authenticates the user, sets up the user's project, checks API enablement,
// and saves the token for future use.
func DoLogin(cfg *config.Config, projectID string) { func DoLogin(cfg *config.Config, projectID string) {
var err error var err error
var ts auth.TokenStorage var ts auth.TokenStorage
@@ -16,9 +19,8 @@ func DoLogin(cfg *config.Config, projectID string) {
ts.ProjectID = projectID ts.ProjectID = projectID
} }
// 2. Initialize authenticated HTTP Client // Initialize an authenticated HTTP client. This will trigger the OAuth flow if necessary.
clientCtx := context.Background() clientCtx := context.Background()
log.Info("Initializing authentication...") log.Info("Initializing authentication...")
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg) httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
if errGetClient != nil { if errGetClient != nil {
@@ -27,52 +29,57 @@ func DoLogin(cfg *config.Config, projectID string) {
} }
log.Info("Authentication successful.") log.Info("Authentication successful.")
// 3. Initialize CLI Client // Initialize the API client.
cliClient := client.NewClient(httpClient, &ts, cfg) cliClient := client.NewClient(httpClient, &ts, cfg)
projectID, err = cliClient.SetupUser(clientCtx, ts.Email, projectID)
// Perform the user setup process.
err = cliClient.SetupUser(clientCtx, ts.Email, projectID)
if err != nil { if err != nil {
// Handle the specific case where a project ID is required but not provided.
if err.Error() == "failed to start user onboarding, need define a project id" { if err.Error() == "failed to start user onboarding, need define a project id" {
log.Error("failed to start user onboarding") log.Error("Failed to start user onboarding: A project ID is required.")
// Fetch and display the user's available projects to help them choose one.
project, errGetProjectList := cliClient.GetProjectList(clientCtx) project, errGetProjectList := cliClient.GetProjectList(clientCtx)
if errGetProjectList != nil { if errGetProjectList != nil {
log.Fatalf("failed to complete user setup: %v", err) log.Fatalf("Failed to get project list: %v", err)
} else { } else {
log.Infof("Your account %s needs specify a project id.", ts.Email) log.Infof("Your account %s needs to specify a project ID.", ts.Email)
log.Info("========================================================================") log.Info("========================================================================")
for i := 0; i < len(project.Projects); i++ { for _, p := range project.Projects {
log.Infof("Project ID: %s", project.Projects[i].ProjectID) log.Infof("Project ID: %s", p.ProjectID)
log.Infof("Project Name: %s", project.Projects[i].Name) log.Infof("Project Name: %s", p.Name)
log.Info("========================================================================") log.Info("------------------------------------------------------------------------")
} }
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0]) log.Infof("Please run this command to login again with a specific project:\n\n%s --login --project_id <project_id>\n", os.Args[0])
} }
} else { } else {
// Log as a warning because in some cases, the CLI might still be usable log.Fatalf("Failed to complete user setup: %v", err)
// or the user might want to retry setup later.
log.Fatalf("failed to complete user setup: %v", err)
} }
} else { return // Exit after handling the error.
auto := ts.ProjectID == "" }
cliClient.SetProjectID(projectID)
cliClient.SetIsAuto(auto)
if !cliClient.IsChecked() && !cliClient.IsAuto() { // If setup is successful, proceed to check API status and save the token.
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled() auto := projectID == ""
if checkErr != nil { cliClient.SetIsAuto(auto)
log.Fatalf("failed to check cloud api is enabled: %v", checkErr)
return
}
cliClient.SetIsChecked(isChecked)
}
if !cliClient.IsChecked() && !cliClient.IsAuto() { // If the project was not automatically selected, check if the Cloud AI API is enabled.
if !cliClient.IsChecked() && !cliClient.IsAuto() {
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
if checkErr != nil {
log.Fatalf("Failed to check if Cloud AI API is enabled: %v", checkErr)
return return
} }
cliClient.SetIsChecked(isChecked)
err = cliClient.SaveTokenToFile() // If the check fails (returns false), the CheckCloudAPIIsEnabled function
if err != nil { // will have already printed instructions, so we can just exit.
log.Fatal(err) if !isChecked {
return return
} }
} }
// Save the successfully obtained and verified token to a file.
err = cliClient.SaveTokenToFile()
if err != nil {
log.Fatalf("Failed to save token to file: %v", err)
}
} }

View File

@@ -3,13 +3,16 @@ package cmd
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"github.com/luispater/CLIProxyAPI/internal/api" "github.com/luispater/CLIProxyAPI/internal/api"
"github.com/luispater/CLIProxyAPI/internal/auth" "github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config" "github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"io/fs" "io/fs"
"net"
"net/http"
"net/url"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
@@ -18,20 +21,18 @@ import (
"time" "time"
) )
// StartService initializes and starts the main API proxy service.
// It loads all available authentication tokens, creates a pool of clients,
// starts the API server, and handles graceful shutdown signals.
func StartService(cfg *config.Config) { func StartService(cfg *config.Config) {
// Create API server configuration // Create a pool of API clients, one for each token file found.
apiConfig := &api.ServerConfig{
Port: fmt.Sprintf("%d", cfg.Port),
Debug: cfg.Debug,
ApiKeys: cfg.ApiKeys,
}
cliClients := make([]*client.Client, 0) cliClients := make([]*client.Client, 0)
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error { err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
if err != nil { if err != nil {
return err return err
} }
// Process only JSON files in the auth directory.
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") { if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
log.Debugf("Loading token from: %s", path) log.Debugf("Loading token from: %s", path)
f, errOpen := os.Open(path) f, errOpen := os.Open(path)
@@ -42,81 +43,96 @@ func StartService(cfg *config.Config) {
_ = f.Close() _ = f.Close()
}() }()
// Decode the token storage file.
var ts auth.TokenStorage var ts auth.TokenStorage
if err = json.NewDecoder(f).Decode(&ts); err == nil { if err = json.NewDecoder(f).Decode(&ts); err == nil {
// 2. Initialize authenticated HTTP Client // For each valid token, create an authenticated client.
clientCtx := context.Background() clientCtx := context.Background()
log.Info("Initializing authentication for token...")
log.Info("Initializing authentication...")
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg) httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
if errGetClient != nil { if errGetClient != nil {
log.Fatalf("failed to get authenticated client: %v", errGetClient) // Log fatal will exit, but we return the error for completeness.
log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient)
return errGetClient return errGetClient
} }
log.Info("Authentication successful.") log.Info("Authentication successful.")
// 3. Initialize CLI Client // Add the new client to the pool.
cliClient := client.NewClient(httpClient, &ts, cfg) cliClient := client.NewClient(httpClient, &ts, cfg)
if _, err = cliClient.SetupUser(clientCtx, ts.Email, ts.ProjectID); err != nil { cliClients = append(cliClients, cliClient)
if err.Error() == "failed to start user onboarding, need define a project id" {
log.Error("failed to start user onboarding")
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
if errGetProjectList != nil {
log.Fatalf("failed to complete user setup: %v", err)
} else {
log.Infof("Your account %s needs specify a project id.", ts.Email)
log.Info("========================================================================")
for i := 0; i < len(project.Projects); i++ {
log.Infof("Project ID: %s", project.Projects[i].ProjectID)
log.Infof("Project Name: %s", project.Projects[i].Name)
log.Info("========================================================================")
}
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0])
}
} else {
// Log as a warning because in some cases, the CLI might still be usable
// or the user might want to retry setup later.
log.Fatalf("failed to complete user setup: %v", err)
}
} else {
cliClients = append(cliClients, cliClient)
}
} }
} }
return nil return nil
}) })
if err != nil {
// Create API server log.Fatalf("Error walking auth directory: %v", err)
apiServer := api.NewServer(apiConfig, cliClients)
log.Infof("Starting API server on port %s", apiConfig.Port)
if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err)
return
} }
// Set up graceful shutdown if len(cfg.GlAPIKey) > 0 {
var transport *http.Transport
proxyURL, errParse := url.Parse(cfg.ProxyUrl)
if errParse == nil {
if proxyURL.Scheme == "socks5" {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Handle HTTP/HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
}
for i := 0; i < len(cfg.GlAPIKey); i++ {
httpClient := &http.Client{}
if transport != nil {
httpClient.Transport = transport
}
log.Debug("Initializing with Generative Language API key...")
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
cliClients = append(cliClients, cliClient)
}
}
// Create and start the API server with the pool of clients.
apiServer := api.NewServer(cfg, cliClients)
log.Infof("Starting API server on port %d", cfg.Port)
if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err)
}
// Set up a channel to listen for OS signals for graceful shutdown.
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Main loop to wait for shutdown signal.
for { for {
select { select {
case <-sigChan: case <-sigChan:
log.Debugf("Received shutdown signal. Cleaning up...") log.Debugf("Received shutdown signal. Cleaning up...")
// Create shutdown context // Create a context with a timeout for the shutdown process.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
_ = ctx // Mark ctx as used to avoid error, as apiServer.Stop(ctx) is commented out _ = cancel
// Stop API server // Stop the API server gracefully.
if err = apiServer.Stop(ctx); err != nil { if err = apiServer.Stop(ctx); err != nil {
log.Debugf("Error stopping API server: %v", err) log.Debugf("Error stopping API server: %v", err)
} }
cancel()
log.Debugf("Cleanup completed. Exiting...") log.Debugf("Cleanup completed. Exiting...")
os.Exit(0) os.Exit(0)
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
// This case is currently empty and acts as a periodic check.
// It could be used for periodic tasks in the future.
} }
} }
} }

View File

@@ -6,33 +6,46 @@ import (
"os" "os"
) )
// Config represents the application's configuration // Config represents the application's configuration, loaded from a YAML file.
type Config struct { type Config struct {
Port int `yaml:"port"` // Port is the network port on which the API server will listen.
AuthDir string `yaml:"auth_dir"` Port int `yaml:"port"`
Debug bool `yaml:"debug"` // AuthDir is the directory where authentication token files are stored.
ProxyUrl string `yaml:"proxy-url"` AuthDir string `yaml:"auth-dir"`
ApiKeys []string `yaml:"api_keys"` // Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug"`
// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
ProxyUrl string `yaml:"proxy-url"`
// ApiKeys is a list of keys for authenticating clients to this proxy server.
ApiKeys []string `yaml:"api-keys"`
// QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"`
// GlAPIKey is the API key for the generative language API.
GlAPIKey []string `yaml:"generative-language-api-key"`
} }
// / LoadConfig loads the configuration from the specified file type ConfigQuotaExceeded struct {
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
SwitchProject bool `yaml:"switch-project"`
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
SwitchPreviewModel bool `yaml:"switch-preview-model"`
}
// LoadConfig reads a YAML configuration file from the given path,
// unmarshals it into a Config struct, and returns it.
func LoadConfig(configFile string) (*Config, error) { func LoadConfig(configFile string) (*Config, error) {
// Read the configuration file // Read the entire configuration file into memory.
data, err := os.ReadFile(configFile) data, err := os.ReadFile(configFile)
// If reading the file fails
if err != nil { if err != nil {
// Return an error
return nil, fmt.Errorf("failed to read config file: %w", err) return nil, fmt.Errorf("failed to read config file: %w", err)
} }
// Parse the YAML data // Unmarshal the YAML data into the Config struct.
var config Config var config Config
// If parsing the YAML data fails
if err = yaml.Unmarshal(data, &config); err != nil { if err = yaml.Unmarshal(data, &config); err != nil {
// Return an error
return nil, fmt.Errorf("failed to parse config file: %w", err) return nil, fmt.Errorf("failed to parse config file: %w", err)
} }
// Return the configuration // Return the populated configuration struct.
return &config, nil return &config, nil
} }