diff --git a/cmd/server/main.go b/cmd/server/main.go index e6e9e686..6e40c4f0 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -12,9 +12,11 @@ import ( "strings" ) +// LogFormatter defines a custom log format for logrus. type LogFormatter struct { } +// Format renders a single log entry. func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { var b *bytes.Buffer if entry.Buffer != nil { @@ -25,33 +27,42 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { timestamp := entry.Time.Format("2006-01-02 15:04:05") var newLog string + // Customize the log format to include timestamp, level, caller file/line, and message. newLog = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, path.Base(entry.Caller.File), entry.Caller.Line, entry.Message) b.WriteString(newLog) return b.Bytes(), nil } +// init initializes the logger configuration. func init() { + // Set logger output to standard output. log.SetOutput(os.Stdout) + // Enable reporting the caller function's file and line number. log.SetReportCaller(true) + // Set the custom log formatter. log.SetFormatter(&LogFormatter{}) } +// main is the entry point of the application. func main() { var login bool var projectID string var configPath string + // Define command-line flags. flag.BoolVar(&login, "login", false, "Login Google Account") flag.StringVar(&projectID, "project_id", "", "Project ID") flag.StringVar(&configPath, "config", "", "Configure File Path") + // Parse the command-line flags. flag.Parse() var err error var cfg *config.Config var wd string + // Load configuration from the specified path or the default path. if configPath != "" { cfg, err = config.LoadConfig(configPath) } else { @@ -65,12 +76,14 @@ func main() { log.Fatalf("failed to load config: %v", err) } + // Set the log level based on the configuration. if cfg.Debug { log.SetLevel(log.DebugLevel) } else { log.SetLevel(log.InfoLevel) } + // Expand the tilde (~) in the auth directory path to the user's home directory. if strings.HasPrefix(cfg.AuthDir, "~") { home, errUserHomeDir := os.UserHomeDir() if errUserHomeDir != nil { @@ -85,6 +98,7 @@ func main() { } } + // Either perform login or start the service based on the 'login' flag. if login { cmd.DoLogin(cfg, projectID) } else { diff --git a/config.yaml b/config.yaml index ccc53dce..3f8b1885 100644 --- a/config.yaml +++ b/config.yaml @@ -1,6 +1,6 @@ port: 8317 auth_dir: "~/.cli-proxy-api" -debug: false +debug: true proxy-url: "" api_keys: - "12345" diff --git a/internal/api/handlers.go b/internal/api/handlers.go index f549934b..85bdfef9 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -2,14 +2,12 @@ package api import ( "context" - "encoding/json" "fmt" + "github.com/luispater/CLIProxyAPI/internal/api/translator" "github.com/luispater/CLIProxyAPI/internal/client" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - "github.com/tidwall/sjson" "net/http" - "strings" "sync" "time" @@ -21,13 +19,15 @@ var ( lastUsedClientIndex = 0 ) -// APIHandlers contains the handlers for API endpoints +// APIHandlers contains the handlers for API endpoints. +// It holds a pool of clients to interact with the backend service. type APIHandlers struct { cliClients []*client.Client debug bool } -// NewAPIHandlers creates a new API handlers instance +// NewAPIHandlers creates a new API handlers instance. +// It takes a slice of clients and a debug flag as input. func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers { return &APIHandlers{ cliClients: cliClients, @@ -35,6 +35,8 @@ func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers { } } +// Models handles the /v1/models endpoint. +// It returns a hardcoded list of available AI models. func (h *APIHandlers) Models(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "data": []map[string]any{ @@ -162,15 +164,23 @@ func (h *APIHandlers) Models(c *gin.Context) { }) } -// ChatCompletions handles the /v1/chat/completions endpoint +// ChatCompletions handles the /v1/chat/completions endpoint. +// It determines whether the request is for a streaming or non-streaming response +// and calls the appropriate handler. func (h *APIHandlers) ChatCompletions(c *gin.Context) { rawJson, err := c.GetRawData() - // If data retrieval fails, return 400 error + // If data retrieval fails, return a 400 Bad Request error. if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request: %v", err), "code": 400}) + c.JSON(http.StatusBadRequest, ErrorResponse{ + Error: ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) return } + // Check if the client requested a streaming response. streamResult := gjson.GetBytes(rawJson, "stream") if streamResult.Type == gjson.True { h.handleStreamingResponse(c, rawJson) @@ -179,184 +189,9 @@ func (h *APIHandlers) ChatCompletions(c *gin.Context) { } } -func (h *APIHandlers) prepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) { - // log.Debug(string(rawJson)) - modelName := "gemini-2.5-pro" - modelResult := gjson.GetBytes(rawJson, "model") - if modelResult.Type == gjson.String { - modelName = modelResult.String() - } - - contents := make([]client.Content, 0) - messagesResult := gjson.GetBytes(rawJson, "messages") - if messagesResult.IsArray() { - messagesResults := messagesResult.Array() - for i := 0; i < len(messagesResults); i++ { - messageResult := messagesResults[i] - roleResult := messageResult.Get("role") - contentResult := messageResult.Get("content") - if roleResult.Type == gjson.String { - if roleResult.String() == "system" { - if contentResult.Type == gjson.String { - contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}) - } else if contentResult.IsObject() { - contentTypeResult := contentResult.Get("type") - if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { - contentTextResult := contentResult.Get("text") - if contentTextResult.Type == gjson.String { - contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}}) - contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}}) - } - } - } - } else if roleResult.String() == "user" { - if contentResult.Type == gjson.String { - contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}) - } else if contentResult.IsObject() { - contentTypeResult := contentResult.Get("type") - if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { - contentTextResult := contentResult.Get("text") - if contentTextResult.Type == gjson.String { - contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}}) - } - } - } else if contentResult.IsArray() { - contentItemResults := contentResult.Array() - parts := make([]client.Part, 0) - for j := 0; j < len(contentItemResults); j++ { - contentItemResult := contentItemResults[j] - contentTypeResult := contentItemResult.Get("type") - if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { - contentTextResult := contentItemResult.Get("text") - if contentTextResult.Type == gjson.String { - parts = append(parts, client.Part{Text: contentTextResult.String()}) - } - } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image_url" { - imageURLResult := contentItemResult.Get("image_url.url") - if imageURLResult.Type == gjson.String { - imageURL := imageURLResult.String() - if len(imageURL) > 5 { - imageURLs := strings.SplitN(imageURL[5:], ";", 2) - if len(imageURLs) == 2 { - if len(imageURLs[1]) > 7 { - parts = append(parts, client.Part{InlineData: &client.InlineData{ - MimeType: imageURLs[0], - Data: imageURLs[1][7:], - }}) - } - } - } - } - } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "file" { - filenameResult := contentItemResult.Get("file.filename") - fileDataResult := contentItemResult.Get("file.file_data") - if filenameResult.Type == gjson.String && fileDataResult.Type == gjson.String { - filename := filenameResult.String() - splitFilename := strings.Split(filename, ".") - ext := splitFilename[len(splitFilename)-1] - - mimeType, ok := MimeTypes[ext] - if !ok { - log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j) - continue - } - - parts = append(parts, client.Part{InlineData: &client.InlineData{ - MimeType: mimeType, - Data: fileDataResult.String(), - }}) - } - } - } - contents = append(contents, client.Content{Role: "user", Parts: parts}) - } - } else if roleResult.String() == "assistant" { - if contentResult.Type == gjson.String { - contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}}) - } else if contentResult.IsObject() { - contentTypeResult := contentResult.Get("type") - if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { - contentTextResult := contentResult.Get("text") - if contentTextResult.Type == gjson.String { - contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}}) - } - } - } else if !contentResult.Exists() || contentResult.Type == gjson.Null { - toolCallsResult := messageResult.Get("tool_calls") - if toolCallsResult.IsArray() { - tcsResult := toolCallsResult.Array() - for j := 0; j < len(tcsResult); j++ { - tcResult := tcsResult[j] - functionNameResult := tcResult.Get("function.name") - functionArguments := tcResult.Get("function.arguments") - if functionNameResult.Exists() && functionNameResult.Type == gjson.String && functionArguments.Exists() && functionArguments.Type == gjson.String { - var args map[string]any - err := json.Unmarshal([]byte(functionArguments.String()), &args) - if err == nil { - contents = append(contents, client.Content{ - Role: "model", Parts: []client.Part{ - { - FunctionCall: &client.FunctionCall{ - Name: functionNameResult.String(), - Args: args, - }, - }, - }, - }) - } - } - } - } - } - } else if roleResult.String() == "tool" { - toolCallIDResult := messageResult.Get("tool_call_id") - if toolCallIDResult.Exists() && toolCallIDResult.Type == gjson.String { - if contentResult.Type == gjson.String { - functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}} - contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}}) - } else if contentResult.IsObject() { - contentTypeResult := contentResult.Get("type") - if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" { - contentTextResult := contentResult.Get("text") - if contentTextResult.Type == gjson.String { - functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}} - contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}}) - } - } - } - } - } - } - } - } - - var tools []client.ToolDeclaration - toolsResult := gjson.GetBytes(rawJson, "tools") - if toolsResult.IsArray() { - tools = make([]client.ToolDeclaration, 1) - tools[0].FunctionDeclarations = make([]any, 0) - toolsResults := toolsResult.Array() - for i := 0; i < len(toolsResults); i++ { - toolTypeResult := toolsResults[i].Get("type") - if toolTypeResult.Type != gjson.String || toolTypeResult.String() != "function" { - continue - } - functionTypeResult := toolsResults[i].Get("function") - if functionTypeResult.Exists() && functionTypeResult.IsObject() { - var functionDeclaration any - err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration) - if err == nil { - tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration) - } - } - } - } else { - tools = make([]client.ToolDeclaration, 0) - } - return modelName, contents, tools -} - -// handleNonStreamingResponse handles non-streaming responses +// handleNonStreamingResponse handles non-streaming chat completion responses. +// It selects a client from the pool, sends the request, and aggregates the response +// before sending it back to the client. func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) { c.Header("Content-Type", "application/json") @@ -372,7 +207,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) return } - modelName, contents, tools := h.prepareRequest(rawJson) + modelName, contents, tools := translator.PrepareRequest(rawJson) cliCtx, cliCancel := context.WithCancel(context.Background()) var cliClient *client.Client defer func() { @@ -425,19 +260,13 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) cliCancel() return } else { - jsonTemplate = h.convertCliToOpenAINonStream(jsonTemplate, chunk) + jsonTemplate = translator.ConvertCliToOpenAINonStream(jsonTemplate, chunk) } case err, okError := <-errChan: if okError { c.Status(err.StatusCode) _, _ = fmt.Fprint(c.Writer, err.Error.Error()) flusher.Flush() - // c.JSON(http.StatusInternalServerError, ErrorResponse{ - // Error: ErrorDetail{ - // Message: err.Error(), - // Type: "server_error", - // }, - // }) cliCancel() return } @@ -455,7 +284,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { c.Header("Connection", "keep-alive") c.Header("Access-Control-Allow-Origin", "*") - // Handle streaming manually + // Get the http.Flusher interface to manually flush the response. flusher, ok := c.Writer.(http.Flusher) if !ok { c.JSON(http.StatusInternalServerError, ErrorResponse{ @@ -466,28 +295,33 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { }) return } - modelName, contents, tools := h.prepareRequest(rawJson) + + // Prepare the request for the backend client. + modelName, contents, tools := translator.PrepareRequest(rawJson) cliCtx, cliCancel := context.WithCancel(context.Background()) var cliClient *client.Client defer func() { + // Ensure the client's mutex is unlocked on function exit. if cliClient != nil { cliClient.RequestMutex.Unlock() } }() - // Lock the mutex to update the last used page index + // Use a round-robin approach to select the next available client. + // This distributes the load among the available clients. mutex.Lock() startIndex := lastUsedClientIndex currentIndex := (startIndex + 1) % len(h.cliClients) lastUsedClientIndex = currentIndex mutex.Unlock() - // Reorder the pages to start from the last used index + // Reorder the clients to start from the next client in the rotation. reorderedPages := make([]*client.Client, len(h.cliClients)) for i := 0; i < len(h.cliClients); i++ { reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)] } + // Attempt to lock a client for the request. locked := false for i := 0; i < len(reorderedPages); i++ { cliClient = reorderedPages[i] @@ -496,235 +330,52 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { break } } + // If no client is available, block and wait for the first client. if !locked { cliClient = h.cliClients[0] cliClient.RequestMutex.Lock() } - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + // Send the message and receive response chunks and errors via channels. respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools) for { select { + // Handle client disconnection. case <-c.Request.Context().Done(): if c.Request.Context().Err().Error() == "context canceled" { log.Debugf("Client disconnected: %v", c.Request.Context().Err()) - cliCancel() + cliCancel() // Cancel the backend request. return } + // Process incoming response chunks. case chunk, okStream := <-respChan: if !okStream { + // Stream is closed, send the final [DONE] message. _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") flusher.Flush() cliCancel() return } else { - openAIFormat := h.convertCliToOpenAI(chunk) + // Convert the chunk to OpenAI format and send it to the client. + openAIFormat := translator.ConvertCliToOpenAI(chunk) if openAIFormat != "" { _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) flusher.Flush() } } + // Handle errors from the backend. case err, okError := <-errChan: if okError { c.Status(err.StatusCode) _, _ = fmt.Fprint(c.Writer, err.Error.Error()) flusher.Flush() - // c.JSON(http.StatusInternalServerError, ErrorResponse{ - // Error: ErrorDetail{ - // Message: err.Error(), - // Type: "server_error", - // }, - // }) cliCancel() return } + // Send a keep-alive signal to the client. case <-time.After(500 * time.Millisecond): _, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n")) flusher.Flush() } } } - -func (h *APIHandlers) convertCliToOpenAI(rawJson []byte) string { - // log.Debugf(string(rawJson)) - template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` - - modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion") - if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String { - template, _ = sjson.Set(template, "model", modelVersionResult.String()) - } - - createTimeResult := gjson.GetBytes(rawJson, "response.createTime") - if createTimeResult.Exists() && createTimeResult.Type == gjson.String { - t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) - var unixTimestamp int64 - if err == nil { - unixTimestamp = t.Unix() - } else { - unixTimestamp = time.Now().Unix() - } - template, _ = sjson.Set(template, "created", unixTimestamp) - } - - responseIdResult := gjson.GetBytes(rawJson, "response.responseId") - if responseIdResult.Exists() && responseIdResult.Type == gjson.String { - template, _ = sjson.Set(template, "id", responseIdResult.String()) - } - - finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason") - if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String { - template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String()) - template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String()) - } - - usageResult := gjson.GetBytes(rawJson, "response.usageMetadata") - candidatesTokenCountResult := usageResult.Get("candidatesTokenCount") - if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number { - template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) - } - totalTokenCountResult := usageResult.Get("totalTokenCount") - if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number { - template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) - } - thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount") - promptTokenCountResult := usageResult.Get("promptTokenCount") - if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number { - if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number { - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int()) - } else { - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()) - } - } - if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number { - template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int()) - } - - partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0") - partTextResult := partResult.Get("text") - functionCallResult := partResult.Get("functionCall") - - if partTextResult.Exists() && partTextResult.Type == gjson.String { - partThoughtResult := partResult.Get("thought") - if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True { - template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String()) - } else { - template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String()) - } - template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") - } else if functionCallResult.Exists() { - functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]` - fcNameResult := functionCallResult.Get("name") - if fcNameResult.Exists() && fcNameResult.Type == gjson.String { - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcNameResult.String()) - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcNameResult.String()) - } - fcArgsResult := functionCallResult.Get("args") - if fcArgsResult.Exists() && fcArgsResult.IsObject() { - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw) - } - template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") - template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate) - } else { - return "" - } - - return template -} - -func (h *APIHandlers) convertCliToOpenAINonStream(template string, rawJson []byte) string { - modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion") - if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String { - template, _ = sjson.Set(template, "model", modelVersionResult.String()) - } - - createTimeResult := gjson.GetBytes(rawJson, "response.createTime") - if createTimeResult.Exists() && createTimeResult.Type == gjson.String { - t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) - var unixTimestamp int64 - if err == nil { - unixTimestamp = t.Unix() - } else { - unixTimestamp = time.Now().Unix() - } - template, _ = sjson.Set(template, "created", unixTimestamp) - } - - responseIdResult := gjson.GetBytes(rawJson, "response.responseId") - if responseIdResult.Exists() && responseIdResult.Type == gjson.String { - template, _ = sjson.Set(template, "id", responseIdResult.String()) - } - - finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason") - if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String { - template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String()) - template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String()) - } - - usageResult := gjson.GetBytes(rawJson, "response.usageMetadata") - candidatesTokenCountResult := usageResult.Get("candidatesTokenCount") - if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number { - template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) - } - totalTokenCountResult := usageResult.Get("totalTokenCount") - if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number { - template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) - } - thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount") - promptTokenCountResult := usageResult.Get("promptTokenCount") - if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number { - if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number { - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int()) - } else { - template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()) - } - } - if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number { - template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int()) - } - - partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0") - partTextResult := partResult.Get("text") - functionCallResult := partResult.Get("functionCall") - - if partTextResult.Exists() && partTextResult.Type == gjson.String { - partThoughtResult := partResult.Get("thought") - if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True { - reasoningContentResult := gjson.Get(template, "choices.0.message.reasoning_content") - if reasoningContentResult.Type == gjson.String { - template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningContentResult.String()+partTextResult.String()) - } else { - template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String()) - } - } else { - reasoningContentResult := gjson.Get(template, "choices.0.message.content") - if reasoningContentResult.Type == gjson.String { - template, _ = sjson.Set(template, "choices.0.message.content", reasoningContentResult.String()+partTextResult.String()) - } else { - template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String()) - } - } - template, _ = sjson.Set(template, "choices.0.message.role", "assistant") - } else if functionCallResult.Exists() { - toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls") - if !toolCallsResult.Exists() || toolCallsResult.Type == gjson.Null { - template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`) - } - - functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` - fcNameResult := functionCallResult.Get("name") - if fcNameResult.Exists() && fcNameResult.Type == gjson.String { - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcNameResult.String()) - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcNameResult.String()) - } - fcArgsResult := functionCallResult.Get("args") - if fcArgsResult.Exists() && fcArgsResult.IsObject() { - functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw) - } - template, _ = sjson.Set(template, "choices.0.message.role", "assistant") - template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate) - } else { - return "" - } - - return template -} diff --git a/internal/api/models.go b/internal/api/models.go index 8cfbbea6..71f2bb5a 100644 --- a/internal/api/models.go +++ b/internal/api/models.go @@ -1,13 +1,18 @@ package api -// ErrorResponse represents an error response +// ErrorResponse represents a standard error response format for the API. +// It contains a single ErrorDetail field. type ErrorResponse struct { Error ErrorDetail `json:"error"` } -// ErrorDetail represents error details +// ErrorDetail provides specific information about an error that occurred. +// It includes a human-readable message, an error type, and an optional error code. type ErrorDetail struct { + // A human-readable message providing more details about the error. Message string `json:"message"` - Type string `json:"type"` - Code string `json:"code,omitempty"` + // The type of error that occurred (e.g., "invalid_request_error"). + Type string `json:"type"` + // A short code identifying the error, if applicable. + Code string `json:"code,omitempty"` } diff --git a/internal/api/server.go b/internal/api/server.go index 0f8e7ac2..8bfea165 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -11,7 +11,8 @@ import ( "strings" ) -// Server represents the API server +// Server represents the main API server. +// It encapsulates the Gin engine, HTTP server, handlers, and configuration. type Server struct { engine *gin.Engine server *http.Server @@ -19,14 +20,18 @@ type Server struct { cfg *ServerConfig } -// ServerConfig contains configuration for the API server +// ServerConfig contains the configuration for the API server. type ServerConfig struct { - Port string - Debug bool + // Port is the port number the server will listen on. + Port string + // Debug enables or disables debug mode for the server and Gin. + Debug bool + // ApiKeys is a list of valid API keys for authentication. ApiKeys []string } -// NewServer creates a new API server instance +// NewServer creates and initializes a new API server instance. +// It sets up the Gin engine, middleware, routes, and handlers. func NewServer(config *ServerConfig, cliClients []*client.Client) *Server { // Set gin mode if !config.Debug { @@ -63,7 +68,8 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server { return s } -// setupRoutes configures the API routes +// setupRoutes configures the API routes for the server. +// It defines the endpoints and associates them with their respective handlers. func (s *Server) setupRoutes() { // OpenAI compatible API routes v1 := s.engine.Group("/v1") @@ -86,11 +92,12 @@ func (s *Server) setupRoutes() { }) } -// Start starts the API server +// Start begins listening for and serving HTTP requests. +// It's a blocking call and will only return on an unrecoverable error. func (s *Server) Start() error { log.Debugf("Starting API server on %s", s.server.Addr) - // Start the HTTP server + // Start the HTTP server. if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("failed to start HTTP server: %v", err) } @@ -98,11 +105,12 @@ func (s *Server) Start() error { return nil } -// Stop gracefully stops the API server +// Stop gracefully shuts down the API server without interrupting any +// active connections. func (s *Server) Stop(ctx context.Context) error { log.Debug("Stopping API server...") - // Shutdown the HTTP server + // Shutdown the HTTP server. if err := s.server.Shutdown(ctx); err != nil { return fmt.Errorf("failed to shutdown HTTP server: %v", err) } @@ -111,7 +119,8 @@ func (s *Server) Stop(ctx context.Context) error { return nil } -// corsMiddleware adds CORS headers +// corsMiddleware returns a Gin middleware handler that adds CORS headers +// to every response, allowing cross-origin requests. func corsMiddleware() gin.HandlerFunc { return func(c *gin.Context) { c.Header("Access-Control-Allow-Origin", "*") @@ -127,7 +136,8 @@ func corsMiddleware() gin.HandlerFunc { } } -// AuthMiddleware authenticates requests using API keys +// AuthMiddleware returns a Gin middleware handler that authenticates requests +// using API keys. If no API keys are configured, it allows all requests. func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc { return func(c *gin.Context) { if len(cfg.ApiKeys) == 0 { diff --git a/internal/api/mine-type.go b/internal/api/translator/mime-type.go similarity index 99% rename from internal/api/mine-type.go rename to internal/api/translator/mime-type.go index 98b62408..95938ff1 100644 --- a/internal/api/mine-type.go +++ b/internal/api/translator/mime-type.go @@ -1,5 +1,7 @@ -package api +package translator +// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types. +// This is used to identify the type of file being uploaded or processed. var MimeTypes = map[string]string{ "ez": "application/andrew-inset", "aw": "application/applixware", diff --git a/internal/api/translator/request.go b/internal/api/translator/request.go new file mode 100644 index 00000000..69a9794e --- /dev/null +++ b/internal/api/translator/request.go @@ -0,0 +1,163 @@ +package translator + +import ( + "encoding/json" + "strings" + + "github.com/luispater/CLIProxyAPI/internal/client" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" +) + +// PrepareRequest translates a raw JSON request from an OpenAI-compatible format +// to the internal format expected by the backend client. It parses messages, +// roles, content types (text, image, file), and tool calls. +func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) { + // Extract the model name from the request, defaulting to "gemini-2.5-pro". + modelName := "gemini-2.5-pro" + modelResult := gjson.GetBytes(rawJson, "model") + if modelResult.Type == gjson.String { + modelName = modelResult.String() + } + + // Process the array of messages. + contents := make([]client.Content, 0) + messagesResult := gjson.GetBytes(rawJson, "messages") + if messagesResult.IsArray() { + messagesResults := messagesResult.Array() + for i := 0; i < len(messagesResults); i++ { + messageResult := messagesResults[i] + roleResult := messageResult.Get("role") + contentResult := messageResult.Get("content") + if roleResult.Type != gjson.String { + continue + } + + switch roleResult.String() { + // System messages are converted to a user message followed by a model's acknowledgment. + case "system": + if contentResult.Type == gjson.String { + contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}) + contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}}) + } else if contentResult.IsObject() { + // Handle object-based system messages. + if contentResult.Get("type").String() == "text" { + contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}) + contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}}) + } + } + // User messages can contain simple text or a multi-part body. + case "user": + if contentResult.Type == gjson.String { + contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}) + } else if contentResult.IsArray() { + // Handle multi-part user messages (text, images, files). + contentItemResults := contentResult.Array() + parts := make([]client.Part, 0) + for j := 0; j < len(contentItemResults); j++ { + contentItemResult := contentItemResults[j] + contentTypeResult := contentItemResult.Get("type") + switch contentTypeResult.String() { + case "text": + parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()}) + case "image_url": + // Parse data URI for images. + imageURL := contentItemResult.Get("image_url.url").String() + if len(imageURL) > 5 { + imageURLs := strings.SplitN(imageURL[5:], ";", 2) + if len(imageURLs) == 2 && len(imageURLs[1]) > 7 { + parts = append(parts, client.Part{InlineData: &client.InlineData{ + MimeType: imageURLs[0], + Data: imageURLs[1][7:], + }}) + } + } + case "file": + // Handle file attachments by determining MIME type from extension. + filename := contentItemResult.Get("file.filename").String() + fileData := contentItemResult.Get("file.file_data").String() + ext := "" + if split := strings.Split(filename, "."); len(split) > 1 { + ext = split[len(split)-1] + } + if mimeType, ok := MimeTypes[ext]; ok { + parts = append(parts, client.Part{InlineData: &client.InlineData{ + MimeType: mimeType, + Data: fileData, + }}) + } else { + log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j) + } + } + } + contents = append(contents, client.Content{Role: "user", Parts: parts}) + } + // Assistant messages can contain text or tool calls. + case "assistant": + if contentResult.Type == gjson.String { + contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}}) + } else if !contentResult.Exists() || contentResult.Type == gjson.Null { + // Handle tool calls made by the assistant. + toolCallsResult := messageResult.Get("tool_calls") + if toolCallsResult.IsArray() { + tcsResult := toolCallsResult.Array() + for j := 0; j < len(tcsResult); j++ { + tcResult := tcsResult[j] + functionName := tcResult.Get("function.name").String() + functionArgs := tcResult.Get("function.arguments").String() + var args map[string]any + if err := json.Unmarshal([]byte(functionArgs), &args); err == nil { + contents = append(contents, client.Content{ + Role: "model", Parts: []client.Part{{ + FunctionCall: &client.FunctionCall{ + Name: functionName, + Args: args, + }, + }}, + }) + } + } + } + } + // Tool messages contain the output of a tool call. + case "tool": + toolCallID := messageResult.Get("tool_call_id").String() + if toolCallID != "" { + var responseData string + if contentResult.Type == gjson.String { + responseData = contentResult.String() + } else if contentResult.IsObject() && contentResult.Get("type").String() == "text" { + responseData = contentResult.Get("text").String() + } + functionResponse := client.FunctionResponse{Name: toolCallID, Response: map[string]interface{}{"result": responseData}} + contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}}) + } + } + } + } + + // Translate the tool declarations from the request. + var tools []client.ToolDeclaration + toolsResult := gjson.GetBytes(rawJson, "tools") + if toolsResult.IsArray() { + tools = make([]client.ToolDeclaration, 1) + tools[0].FunctionDeclarations = make([]any, 0) + toolsResults := toolsResult.Array() + for i := 0; i < len(toolsResults); i++ { + toolResult := toolsResults[i] + if toolResult.Get("type").String() == "function" { + functionTypeResult := toolResult.Get("function") + if functionTypeResult.Exists() && functionTypeResult.IsObject() { + var functionDeclaration any + if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil { + tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration) + } + } + } + } + } else { + tools = make([]client.ToolDeclaration, 0) + } + + return modelName, contents, tools +} diff --git a/internal/api/translator/response.go b/internal/api/translator/response.go new file mode 100644 index 00000000..885b5f30 --- /dev/null +++ b/internal/api/translator/response.go @@ -0,0 +1,169 @@ +package translator + +import ( + "time" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// ConvertCliToOpenAI translates a single chunk of a streaming response from the +// backend client format to the OpenAI Server-Sent Events (SSE) format. +// It returns an empty string if the chunk contains no useful data. +func ConvertCliToOpenAI(rawJson []byte) string { + // Initialize the OpenAI SSE template. + template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}` + + // Extract and set the model version. + if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() { + template, _ = sjson.Set(template, "model", modelVersionResult.String()) + } + + // Extract and set the creation timestamp. + if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() { + t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) + unixTimestamp := time.Now().Unix() + if err == nil { + unixTimestamp = t.Unix() + } + template, _ = sjson.Set(template, "created", unixTimestamp) + } + + // Extract and set the response ID. + if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() { + template, _ = sjson.Set(template, "id", responseIdResult.String()) + } + + // Extract and set the finish reason. + if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() { + template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String()) + template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String()) + } + + // Extract and set usage metadata (token counts). + if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() { + if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { + template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) + } + if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { + template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) + } + promptTokenCount := usageResult.Get("promptTokenCount").Int() + thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + if thoughtsTokenCount > 0 { + template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) + } + } + + // Process the main content part of the response. + partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0") + partTextResult := partResult.Get("text") + functionCallResult := partResult.Get("functionCall") + + if partTextResult.Exists() { + // Handle text content, distinguishing between regular content and reasoning/thoughts. + if partResult.Get("thought").Bool() { + template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String()) + } else { + template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String()) + } + template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") + } else if functionCallResult.Exists() { + // Handle function call content. + functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]` + fcName := functionCallResult.Get("name").String() + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcName) + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcName) + if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw) + } + template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") + template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate) + } else { + // If no usable content is found, return an empty string. + return "" + } + + return template +} + +// ConvertCliToOpenAINonStream aggregates response chunks from the backend client +// into a single, non-streaming OpenAI-compatible JSON response. +func ConvertCliToOpenAINonStream(template string, rawJson []byte) string { + // Extract and set metadata fields that are typically set once per response. + if gjson.Get(template, "id").String() == "" { + if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() { + template, _ = sjson.Set(template, "model", modelVersionResult.String()) + } + if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() { + t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()) + unixTimestamp := time.Now().Unix() + if err == nil { + unixTimestamp = t.Unix() + } + template, _ = sjson.Set(template, "created", unixTimestamp) + } + if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() { + template, _ = sjson.Set(template, "id", responseIdResult.String()) + } + } + + // Extract and set the finish reason. + if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() { + template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String()) + template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String()) + } + + // Extract and set usage metadata (token counts). + if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() { + if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() { + template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int()) + } + if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() { + template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int()) + } + promptTokenCount := usageResult.Get("promptTokenCount").Int() + thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int() + template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount) + if thoughtsTokenCount > 0 { + template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount) + } + } + + // Process the main content part of the response. + partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0") + partTextResult := partResult.Get("text") + functionCallResult := partResult.Get("functionCall") + + if partTextResult.Exists() { + // Append text content, distinguishing between regular content and reasoning. + if partResult.Get("thought").Bool() { + currentContent := gjson.Get(template, "choices.0.message.reasoning_content").String() + template, _ = sjson.Set(template, "choices.0.message.reasoning_content", currentContent+partTextResult.String()) + } else { + currentContent := gjson.Get(template, "choices.0.message.content").String() + template, _ = sjson.Set(template, "choices.0.message.content", currentContent+partTextResult.String()) + } + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + } else if functionCallResult.Exists() { + // Append function call content to the tool_calls array. + if !gjson.Get(template, "choices.0.message.tool_calls").Exists() { + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`) + } + functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}` + fcName := functionCallResult.Get("name").String() + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName) + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName) + if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { + functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw) + } + template, _ = sjson.Set(template, "choices.0.message.role", "assistant") + template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate) + } else { + // If no usable content is found, return an empty string. + return "" + } + + return template +} diff --git a/internal/auth/auth.go b/internal/auth/auth.go index f1be254f..208c90d2 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -5,17 +5,18 @@ import ( "encoding/json" "errors" "fmt" - "github.com/luispater/CLIProxyAPI/internal/config" - log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" - "github.com/tidwall/gjson" - "golang.org/x/net/proxy" "io" "net" "net/http" "net/url" "time" + "github.com/luispater/CLIProxyAPI/internal/config" + log "github.com/sirupsen/logrus" + "github.com/skratchdot/open-golang/open" + "github.com/tidwall/gjson" + "golang.org/x/net/proxy" + "golang.org/x/oauth2" "golang.org/x/oauth2/google" ) @@ -33,76 +34,78 @@ var ( } ) -// GetAuthenticatedClient configures and returns an HTTP client with OAuth2 tokens. -// It handles the entire flow: loading, refreshing, and fetching new tokens. +// GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls. +// It manages the entire OAuth2 flow, including handling proxies, loading existing tokens, +// initiating a new web-based OAuth flow if necessary, and refreshing tokens. func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) { + // Configure proxy settings for the HTTP client if a proxy URL is provided. proxyURL, err := url.Parse(cfg.ProxyUrl) if err == nil { + var transport *http.Transport if proxyURL.Scheme == "socks5" { + // Handle SOCKS5 proxy. username := proxyURL.User.Username() password, _ := proxyURL.User.Password() - auth := &proxy.Auth{ - User: username, - Password: password, - } + auth := &proxy.Auth{User: username, Password: password} dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) if errSOCKS5 != nil { log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5) } - - transport := &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) { + transport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return dialer.Dial(network, addr) }, } - proxyClient := &http.Client{ - Transport: transport, - } - - ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - transport := &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - } - proxyClient := &http.Client{ - Transport: transport, - } + // Handle HTTP/HTTPS proxy. + transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} + } + + if transport != nil { + proxyClient := &http.Client{Transport: transport} ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) } } + // Configure the OAuth2 client. conf := &oauth2.Config{ ClientID: oauthClientID, ClientSecret: oauthClientSecret, - RedirectURL: "http://localhost:8085/oauth2callback", // Placeholder, will be updated + RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server. Scopes: oauthScopes, Endpoint: google.Endpoint, } var token *oauth2.Token + // If no token is found in storage, initiate the web-based OAuth flow. if ts.Token == nil { log.Info("Could not load token from file, starting OAuth flow.") token, err = getTokenFromWeb(ctx, conf) if err != nil { return nil, fmt.Errorf("failed to get token from web: %w", err) } - newTs, errSaveTokenToFile := createTokenStorage(ctx, conf, token, ts.ProjectID) - if errSaveTokenToFile != nil { - log.Errorf("Warning: failed to save token to file: %v", err) - return nil, errSaveTokenToFile + // After getting a new token, create a new token storage object with user info. + newTs, errCreateTokenStorage := createTokenStorage(ctx, conf, token, ts.ProjectID) + if errCreateTokenStorage != nil { + log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage) + return nil, errCreateTokenStorage } *ts = *newTs } + + // Unmarshal the stored token into an oauth2.Token object. tsToken, _ := json.Marshal(ts.Token) if err = json.Unmarshal(tsToken, &token); err != nil { - return nil, err + return nil, fmt.Errorf("failed to unmarshal token: %w", err) } + // Return an HTTP client that automatically handles token refreshing. return conf.Client(ctx, token), nil } -// createTokenStorage creates a token storage. +// createTokenStorage creates a new TokenStorage object. It fetches the user's email +// using the provided token and populates the storage structure. func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) { httpClient := config.Client(ctx, token) req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) @@ -117,7 +120,9 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth return nil, fmt.Errorf("failed to execute request: %w", err) } defer func() { - _ = resp.Body.Close() + if err = resp.Body.Close(); err != nil { + log.Printf("warn: failed to close response body: %v", err) + } }() bodyBytes, _ := io.ReadAll(resp.Body) @@ -154,7 +159,10 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth return &ts, nil } -// getTokenFromWeb starts a local server to handle the OAuth2 flow. +// getTokenFromWeb initiates the web-based OAuth2 authorization flow. +// It starts a local HTTP server to listen for the callback from Google's auth server, +// opens the user's browser to the authorization URL, and exchanges the received +// authorization code for an access token. func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) { // Use a channel to pass the authorization code from the HTTP handler to the main function. codeChan := make(chan string) diff --git a/internal/auth/models.go b/internal/auth/models.go index c90d47c5..33c03745 100644 --- a/internal/auth/models.go +++ b/internal/auth/models.go @@ -1,9 +1,17 @@ package auth +// TokenStorage defines the structure for storing OAuth2 token information, +// along with associated user and project details. This data is typically +// serialized to a JSON file for persistence. type TokenStorage struct { - Token any `json:"token"` + // Token holds the raw OAuth2 token data, including access and refresh tokens. + Token any `json:"token"` + // ProjectID is the Google Cloud Project ID associated with this token. ProjectID string `json:"project_id"` - Email string `json:"email"` - Auto bool `json:"auto"` - Checked bool `json:"checked"` + // Email is the email address of the authenticated user. + Email string `json:"email"` + // Auto indicates if the project ID was automatically selected. + Auto bool `json:"auto"` + // Checked indicates if the associated Cloud AI API has been verified as enabled. + Checked bool `json:"checked"` } diff --git a/internal/client/client.go b/internal/client/client.go index f721c3d7..07fb6d9c 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -6,12 +6,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/luispater/CLIProxyAPI/internal/auth" - "github.com/luispater/CLIProxyAPI/internal/config" - log "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "golang.org/x/oauth2" "io" "net/http" "os" @@ -20,6 +14,13 @@ import ( "strings" "sync" "time" + + "github.com/luispater/CLIProxyAPI/internal/auth" + "github.com/luispater/CLIProxyAPI/internal/config" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "golang.org/x/oauth2" ) const ( @@ -194,7 +195,9 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo return fmt.Errorf("failed to execute request: %w", err) } defer func() { - _ = resp.Body.Close() + if err = resp.Body.Close(); err != nil { + log.Printf("warn: failed to close response body: %v", err) + } }() if resp.StatusCode < 200 || resp.StatusCode >= 300 { @@ -253,7 +256,9 @@ func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body int if resp.StatusCode < 200 || resp.StatusCode >= 300 { defer func() { - _ = resp.Body.Close() + if err = resp.Body.Close(); err != nil { + log.Printf("warn: failed to close response body: %v", err) + } }() bodyBytes, _ := io.ReadAll(resp.Body) @@ -355,6 +360,9 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st return dataChan, errChan } +// CheckCloudAPIIsEnabled sends a simple test request to the API to verify +// that the Cloud AI API is enabled for the user's project. It provides +// an activation URL if the API is disabled. func (c *Client) CheckCloudAPIIsEnabled() (bool, error) { ctx, cancel := context.WithCancel(context.Background()) defer func() { @@ -363,79 +371,78 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) { }() c.RequestMutex.Lock() - requestBody := `{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}` - requestBody = fmt.Sprintf(requestBody, c.tokenStorage.ProjectID) - // log.Debug(requestBody) + // A simple request to test the API endpoint. + requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID) + stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody)) if err != nil { + // If a 403 Forbidden error occurs, it likely means the API is not enabled. if err.StatusCode == 403 { errJson := err.Error.Error() - codeResult := gjson.Get(errJson, "error.code") - if codeResult.Exists() && codeResult.Type == gjson.Number { - if codeResult.Int() == 403 { - activationUrlResult := gjson.Get(errJson, "error.details.0.metadata.activationUrl") - if activationUrlResult.Exists() { - log.Warnf( - "\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s", - activationUrlResult.String(), - os.Args[0], - c.tokenStorage.ProjectID, - ) - } + // Check for a specific error code and extract the activation URL. + if gjson.Get(errJson, "error.code").Int() == 403 { + activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String() + if activationUrl != "" { + log.Warnf( + "\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s", + activationUrl, + os.Args[0], + c.tokenStorage.ProjectID, + ) } } return false, nil } return false, err.Error } + defer func() { + _ = stream.Close() + }() + // We only need to know if the request was successful, so we can drain the stream. scanner := bufio.NewScanner(stream) for scanner.Scan() { - line := scanner.Text() - if !strings.HasPrefix(line, "data: ") { - continue - } + // Do nothing, just consume the stream. } - if scannerErr := scanner.Err(); scannerErr != nil { - _ = stream.Close() - } else { - _ = stream.Close() - } - - return true, nil + return scanner.Err() == nil, scanner.Err() } +// GetProjectList fetches a list of Google Cloud projects accessible by the user. func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) { token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token() + if err != nil { + return nil, fmt.Errorf("failed to get token: %w", err) + } + req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil) if err != nil { - return nil, fmt.Errorf("could not get project list: %v", err) + return nil, fmt.Errorf("could not create project list request: %v", err) } - req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) resp, err := c.httpClient.Do(req) if err != nil { - return nil, fmt.Errorf("failed to execute request: %w", err) + return nil, fmt.Errorf("failed to execute project list request: %w", err) } defer func() { _ = resp.Body.Close() }() - bodyBytes, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + bodyBytes, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) } var project GCPProject - err = json.Unmarshal(bodyBytes, &project) - if err != nil { + if err = json.NewDecoder(resp.Body).Decode(&project); err != nil { return nil, fmt.Errorf("failed to unmarshal project list: %w", err) } return &project, nil } +// SaveTokenToFile serializes the client's current token storage to a JSON file. +// The filename is constructed from the user's email and project ID. func (c *Client) SaveTokenToFile() error { if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil { return fmt.Errorf("failed to create directory: %v", err) @@ -457,7 +464,8 @@ func (c *Client) SaveTokenToFile() error { return nil } -// getClientMetadata returns metadata about the client environment. +// getClientMetadata returns a map of metadata about the client environment, +// such as IDE type, platform, and plugin version. func getClientMetadata() map[string]string { return map[string]string{ "ideType": "IDE_UNSPECIFIED", @@ -467,7 +475,8 @@ func getClientMetadata() map[string]string { } } -// getClientMetadataString returns the metadata as a comma-separated string. +// getClientMetadataString returns the client metadata as a single, +// comma-separated string, which is required for the 'Client-Metadata' header. func getClientMetadataString() string { md := getClientMetadata() parts := make([]string, 0, len(md)) @@ -477,11 +486,13 @@ func getClientMetadataString() string { return strings.Join(parts, ",") } +// getUserAgent constructs the User-Agent string for HTTP requests. func getUserAgent() string { - return fmt.Sprintf(fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)) + return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH) } -// getPlatform returns the OS and architecture in the format expected by the API. +// getPlatform determines the operating system and architecture and formats +// it into a string expected by the backend API. func getPlatform() string { goOS := runtime.GOOS arch := runtime.GOARCH diff --git a/internal/client/models.go b/internal/client/models.go index 23515000..a6aa9d44 100644 --- a/internal/client/models.go +++ b/internal/client/models.go @@ -2,17 +2,23 @@ package client import "time" +// ErrorMessage encapsulates an error with an associated HTTP status code. type ErrorMessage struct { StatusCode int Error error } +// GCPProject represents the response structure for a Google Cloud project list request. type GCPProject struct { Projects []GCPProjectProjects `json:"projects"` } + +// GCPProjectLabels defines the labels associated with a GCP project. type GCPProjectLabels struct { GenerativeLanguage string `json:"generative-language"` } + +// GCPProjectProjects contains details about a single Google Cloud project. type GCPProjectProjects struct { ProjectNumber string `json:"projectNumber"` ProjectID string `json:"projectId"` @@ -22,12 +28,14 @@ type GCPProjectProjects struct { CreateTime time.Time `json:"createTime"` } +// Content represents a single message in a conversation, with a role and parts. type Content struct { Role string `json:"role"` Parts []Part `json:"parts"` } -// Part represents a single part of a message's content. +// Part represents a distinct piece of content within a message, which can be +// text, inline data (like an image), a function call, or a function response. type Part struct { Text string `json:"text,omitempty"` InlineData *InlineData `json:"inlineData,omitempty"` @@ -35,46 +43,48 @@ type Part struct { FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` } +// InlineData represents base64-encoded data with its MIME type. type InlineData struct { MimeType string `json:"mime_type,omitempty"` Data string `json:"data,omitempty"` } -// FunctionCall represents a tool call requested by the model. +// FunctionCall represents a tool call requested by the model, including the +// function name and its arguments. type FunctionCall struct { Name string `json:"name"` Args map[string]interface{} `json:"args"` } -// FunctionResponse represents the result of a tool execution. +// FunctionResponse represents the result of a tool execution, sent back to the model. type FunctionResponse struct { Name string `json:"name"` Response map[string]interface{} `json:"response"` } -// GenerateContentRequest is the request payload for the streamGenerateContent endpoint. +// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint. type GenerateContentRequest struct { Contents []Content `json:"contents"` Tools []ToolDeclaration `json:"tools,omitempty"` GenerationConfig `json:"generationConfig"` } -// GenerationConfig defines model generation parameters. +// GenerationConfig defines parameters that control the model's generation behavior. type GenerationConfig struct { ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"` Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"topP,omitempty"` TopK float64 `json:"topK,omitempty"` - // Temperature, TopP, TopK, etc. can be added here. } +// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process. type GenerationConfigThinkingConfig struct { + // IncludeThoughts determines whether the model should output its reasoning process. IncludeThoughts bool `json:"include_thoughts,omitempty"` } -// ToolDeclaration is the structure for declaring tools to the API. -// For now, we'll assume a simple structure. A more complete implementation -// would mirror the OpenAPI schema definition. +// ToolDeclaration defines the structure for declaring tools (like functions) +// that the model can call. type ToolDeclaration struct { FunctionDeclarations []interface{} `json:"functionDeclarations"` } diff --git a/internal/cmd/login.go b/internal/cmd/login.go index d71b85a9..3d6aa94a 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -9,6 +9,9 @@ import ( "os" ) +// DoLogin handles the entire user login and setup process. +// It authenticates the user, sets up the user's project, checks API enablement, +// and saves the token for future use. func DoLogin(cfg *config.Config, projectID string) { var err error var ts auth.TokenStorage @@ -16,9 +19,8 @@ func DoLogin(cfg *config.Config, projectID string) { ts.ProjectID = projectID } - // 2. Initialize authenticated HTTP Client + // Initialize an authenticated HTTP client. This will trigger the OAuth flow if necessary. clientCtx := context.Background() - log.Info("Initializing authentication...") httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg) if errGetClient != nil { @@ -27,51 +29,57 @@ func DoLogin(cfg *config.Config, projectID string) { } log.Info("Authentication successful.") - // 3. Initialize CLI Client + // Initialize the API client. cliClient := client.NewClient(httpClient, &ts, cfg) + + // Perform the user setup process. err = cliClient.SetupUser(clientCtx, ts.Email, projectID) if err != nil { + // Handle the specific case where a project ID is required but not provided. if err.Error() == "failed to start user onboarding, need define a project id" { - log.Error("failed to start user onboarding") + log.Error("Failed to start user onboarding: A project ID is required.") + // Fetch and display the user's available projects to help them choose one. project, errGetProjectList := cliClient.GetProjectList(clientCtx) if errGetProjectList != nil { - log.Fatalf("failed to complete user setup: %v", err) + log.Fatalf("Failed to get project list: %v", err) } else { - log.Infof("Your account %s needs specify a project id.", ts.Email) + log.Infof("Your account %s needs to specify a project ID.", ts.Email) log.Info("========================================================================") - for i := 0; i < len(project.Projects); i++ { - log.Infof("Project ID: %s", project.Projects[i].ProjectID) - log.Infof("Project Name: %s", project.Projects[i].Name) - log.Info("========================================================================") + for _, p := range project.Projects { + log.Infof("Project ID: %s", p.ProjectID) + log.Infof("Project Name: %s", p.Name) + log.Info("------------------------------------------------------------------------") } - log.Infof("Please run this command to login again:\n\n%s --login --project_id \n", os.Args[0]) + log.Infof("Please run this command to login again with a specific project:\n\n%s --login --project_id \n", os.Args[0]) } } else { - // Log as a warning because in some cases, the CLI might still be usable - // or the user might want to retry setup later. - log.Fatalf("failed to complete user setup: %v", err) + log.Fatalf("Failed to complete user setup: %v", err) } - } else { - auto := projectID == "" - cliClient.SetIsAuto(auto) + return // Exit after handling the error. + } - if !cliClient.IsChecked() && !cliClient.IsAuto() { - isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled() - if checkErr != nil { - log.Fatalf("failed to check cloud api is enabled: %v", checkErr) - return - } - cliClient.SetIsChecked(isChecked) - } + // If setup is successful, proceed to check API status and save the token. + auto := projectID == "" + cliClient.SetIsAuto(auto) - if !cliClient.IsChecked() && !cliClient.IsAuto() { + // If the project was not automatically selected, check if the Cloud AI API is enabled. + if !cliClient.IsChecked() && !cliClient.IsAuto() { + isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled() + if checkErr != nil { + log.Fatalf("Failed to check if Cloud AI API is enabled: %v", checkErr) return } - - err = cliClient.SaveTokenToFile() - if err != nil { - log.Fatal(err) + cliClient.SetIsChecked(isChecked) + // If the check fails (returns false), the CheckCloudAPIIsEnabled function + // will have already printed instructions, so we can just exit. + if !isChecked { return } } + + // Save the successfully obtained and verified token to a file. + err = cliClient.SaveTokenToFile() + if err != nil { + log.Fatalf("Failed to save token to file: %v", err) + } } diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 8bf291bf..5fa6ebb1 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -18,20 +18,25 @@ import ( "time" ) +// StartService initializes and starts the main API proxy service. +// It loads all available authentication tokens, creates a pool of clients, +// starts the API server, and handles graceful shutdown signals. func StartService(cfg *config.Config) { - // Create API server configuration + // Configure the API server based on the main application config. apiConfig := &api.ServerConfig{ Port: fmt.Sprintf("%d", cfg.Port), Debug: cfg.Debug, ApiKeys: cfg.ApiKeys, } + // Create a pool of API clients, one for each token file found. cliClients := make([]*client.Client, 0) err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error { if err != nil { return err } + // Process only JSON files in the auth directory. if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") { log.Debugf("Loading token from: %s", path) f, errOpen := os.Open(path) @@ -42,58 +47,62 @@ func StartService(cfg *config.Config) { _ = f.Close() }() + // Decode the token storage file. var ts auth.TokenStorage if err = json.NewDecoder(f).Decode(&ts); err == nil { - // 2. Initialize authenticated HTTP Client + // For each valid token, create an authenticated client. clientCtx := context.Background() - - log.Info("Initializing authentication...") + log.Info("Initializing authentication for token...") httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg) if errGetClient != nil { - log.Fatalf("failed to get authenticated client: %v", errGetClient) + // Log fatal will exit, but we return the error for completeness. + log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient) return errGetClient } log.Info("Authentication successful.") - // 3. Initialize CLI Client + // Add the new client to the pool. cliClient := client.NewClient(httpClient, &ts, cfg) cliClients = append(cliClients, cliClient) } } return nil }) + if err != nil { + log.Fatalf("Error walking auth directory: %v", err) + } - // Create API server + // Create and start the API server with the pool of clients. apiServer := api.NewServer(apiConfig, cliClients) log.Infof("Starting API server on port %s", apiConfig.Port) if err = apiServer.Start(); err != nil { log.Fatalf("API server failed to start: %v", err) - return } - // Set up graceful shutdown + // Set up a channel to listen for OS signals for graceful shutdown. sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + // Main loop to wait for shutdown signal. for { select { case <-sigChan: log.Debugf("Received shutdown signal. Cleaning up...") - // Create shutdown context + // Create a context with a timeout for the shutdown process. ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - _ = ctx // Mark ctx as used to avoid error, as apiServer.Stop(ctx) is commented out + _ = cancel - // Stop API server + // Stop the API server gracefully. if err = apiServer.Stop(ctx); err != nil { log.Debugf("Error stopping API server: %v", err) } - cancel() log.Debugf("Cleanup completed. Exiting...") os.Exit(0) case <-time.After(5 * time.Second): - + // This case is currently empty and acts as a periodic check. + // It could be used for periodic tasks in the future. } } } diff --git a/internal/config/config.go b/internal/config/config.go index 01bc8bf0..38fb864d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,33 +6,35 @@ import ( "os" ) -// Config represents the application's configuration +// Config represents the application's configuration, loaded from a YAML file. type Config struct { - Port int `yaml:"port"` - AuthDir string `yaml:"auth_dir"` - Debug bool `yaml:"debug"` - ProxyUrl string `yaml:"proxy-url"` - ApiKeys []string `yaml:"api_keys"` + // Port is the network port on which the API server will listen. + Port int `yaml:"port"` + // AuthDir is the directory where authentication token files are stored. + AuthDir string `yaml:"auth_dir"` + // Debug enables or disables debug-level logging and other debug features. + Debug bool `yaml:"debug"` + // ProxyUrl is the URL of an optional proxy server to use for outbound requests. + ProxyUrl string `yaml:"proxy-url"` + // ApiKeys is a list of keys for authenticating clients to this proxy server. + ApiKeys []string `yaml:"api_keys"` } -// / LoadConfig loads the configuration from the specified file +// LoadConfig reads a YAML configuration file from the given path, +// unmarshals it into a Config struct, and returns it. func LoadConfig(configFile string) (*Config, error) { - // Read the configuration file + // Read the entire configuration file into memory. data, err := os.ReadFile(configFile) - // If reading the file fails if err != nil { - // Return an error return nil, fmt.Errorf("failed to read config file: %w", err) } - // Parse the YAML data + // Unmarshal the YAML data into the Config struct. var config Config - // If parsing the YAML data fails if err = yaml.Unmarshal(data, &config); err != nil { - // Return an error return nil, fmt.Errorf("failed to parse config file: %w", err) } - // Return the configuration + // Return the populated configuration struct. return &config, nil }