Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
beff9282f6 | ||
|
|
31a9e2d11f | ||
|
|
423faae3da | ||
|
|
ead71fb7ef | ||
|
|
58b7afdf1e | ||
|
|
c86545d7e1 | ||
|
|
f49a530c1a | ||
|
|
368796349e | ||
|
|
c601542f6f | ||
|
|
3c0c61aaf1 | ||
|
|
edeadfc389 | ||
|
|
aa9fd057fe | ||
|
|
b3607d3981 | ||
|
|
fa8d94971f | ||
|
|
ef68a97526 |
42
.github/workflows/docker-image.yml
vendored
Normal file
42
.github/workflows/docker-image.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
name: docker-image
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- v*
|
||||||
|
|
||||||
|
env:
|
||||||
|
APP_NAME: CLIProxyAPI
|
||||||
|
DOCKERHUB_REPO: eceasy/cli-proxy-api
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
docker:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: docker/setup-qemu-action@v3
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
- name: Login to DockerHub
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
|
- name: Generate App Version
|
||||||
|
run: echo APP_VERSION=`git describe --tags --always` >> $GITHUB_ENV
|
||||||
|
- name: Build and push
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
platforms: |
|
||||||
|
linux/amd64
|
||||||
|
linux/arm64
|
||||||
|
push: true
|
||||||
|
build-args: |
|
||||||
|
APP_NAME=${{ env.APP_NAME }}
|
||||||
|
APP_VERSION=${{ env.APP_VERSION }}
|
||||||
|
tags: |
|
||||||
|
${{ env.DOCKERHUB_REPO }}:latest
|
||||||
|
${{ env.DOCKERHUB_REPO }}:${{ env.APP_VERSION }}
|
||||||
@@ -14,4 +14,5 @@ archives:
|
|||||||
files:
|
files:
|
||||||
- LICENSE
|
- LICENSE
|
||||||
- README.md
|
- README.md
|
||||||
|
- README_CN.md
|
||||||
- config.yaml
|
- config.yaml
|
||||||
23
Dockerfile
Normal file
23
Dockerfile
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
FROM golang:1.24-alpine AS builder
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
RUN CGO_ENABLED=0 GOOS=linux go build -o ./CLIProxyAPI ./cmd/server/
|
||||||
|
|
||||||
|
FROM alpine:3.22.0
|
||||||
|
|
||||||
|
RUN mkdir /CLIProxyAPI
|
||||||
|
|
||||||
|
COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI
|
||||||
|
|
||||||
|
WORKDIR /CLIProxyAPI
|
||||||
|
|
||||||
|
EXPOSE 8317
|
||||||
|
|
||||||
|
CMD ["./CLIProxyAPI"]
|
||||||
39
README.md
39
README.md
@@ -1,10 +1,12 @@
|
|||||||
# CLI Proxy API
|
# CLI Proxy API
|
||||||
|
|
||||||
A proxy server that provides an OpenAI-compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI API.
|
English | [中文](README_CN.md)
|
||||||
|
|
||||||
|
A proxy server that provides an OpenAI/Gemini/Claude compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI/Gemini/Claude API.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- OpenAI-compatible API endpoints for CLI models
|
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
|
||||||
- Support for both streaming and non-streaming responses
|
- Support for both streaming and non-streaming responses
|
||||||
- Function calling/tools support
|
- Function calling/tools support
|
||||||
- Multimodal input support (text and images)
|
- Multimodal input support (text and images)
|
||||||
@@ -136,7 +138,7 @@ console.log(response.choices[0].message.content);
|
|||||||
|
|
||||||
- gemini-2.5-pro
|
- gemini-2.5-pro
|
||||||
- gemini-2.5-flash
|
- gemini-2.5-flash
|
||||||
- And various preview versions
|
- Automatic switching to various preview versions when a quota is exceeded
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
@@ -149,10 +151,13 @@ The server uses a YAML configuration file (`config.yaml`) located in the project
|
|||||||
### Configuration Options
|
### Configuration Options
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
| Parameter | Type | Default | Description |
|
||||||
|-------------------------------|----------|--------------------|----------------------------------------------------------------------------------------------|
|
|---------------------------------------|----------|--------------------|----------------------------------------------------------------------------------------------|
|
||||||
| `port` | integer | 8317 | The port number on which the server will listen |
|
| `port` | integer | 8317 | The port number on which the server will listen |
|
||||||
| `auth-dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory |
|
| `auth-dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory |
|
||||||
| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ |
|
| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ |
|
||||||
|
| `quota-exceeded` | object | {} | Configuration for handling quota exceeded |
|
||||||
|
| `quota-exceeded.switch-project` | boolean | true | Whether to automatically switch to another project when a quota is exceeded |
|
||||||
|
| `quota-exceeded.switch-preview-model` | boolean | true | Whether to automatically switch to a preview model when a quota is exceeded |
|
||||||
| `debug` | boolean | false | Enable debug mode for verbose logging |
|
| `debug` | boolean | false | Enable debug mode for verbose logging |
|
||||||
| `api-keys` | string[] | [] | List of API keys that can be used to authenticate requests |
|
| `api-keys` | string[] | [] | List of API keys that can be used to authenticate requests |
|
||||||
| `generative-language-api-key` | string[] | [] | List of Generative Language API keys |
|
| `generative-language-api-key` | string[] | [] | List of Generative Language API keys |
|
||||||
@@ -169,6 +174,14 @@ auth-dir: "~/.cli-proxy-api"
|
|||||||
# Enable debug logging
|
# Enable debug logging
|
||||||
debug: false
|
debug: false
|
||||||
|
|
||||||
|
# Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/
|
||||||
|
proxy-url: ""
|
||||||
|
|
||||||
|
# Quota exceeded behavior
|
||||||
|
quota-exceeded:
|
||||||
|
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||||
|
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||||
|
|
||||||
# API keys for authentication
|
# API keys for authentication
|
||||||
api-keys:
|
api-keys:
|
||||||
- "your-api-key-1"
|
- "your-api-key-1"
|
||||||
@@ -208,6 +221,24 @@ export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
|
|||||||
|
|
||||||
The server will relay the `loadCodeAssist`, `onboardUser`, and `countTokens` requests. And automatically load balance the text generation requests between the multiple accounts.
|
The server will relay the `loadCodeAssist`, `onboardUser`, and `countTokens` requests. And automatically load balance the text generation requests between the multiple accounts.
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> This feature only allows local access because I couldn't find a way to authenticate the requests.
|
||||||
|
> I hardcoded `127.0.0.1` into the load balancing.
|
||||||
|
|
||||||
|
## Run with Docker
|
||||||
|
|
||||||
|
Run the following command to login:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
|
||||||
|
```
|
||||||
|
|
||||||
|
Run the following command to start the server:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest
|
||||||
|
```
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||||
|
|||||||
254
README_CN.md
Normal file
254
README_CN.md
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
# CLI 代理 API
|
||||||
|
|
||||||
|
[English](README.md) | 中文
|
||||||
|
|
||||||
|
一个为 CLI 提供 OpenAI/Gemini/Claude 兼容 API 接口的代理服务器。这让您可以摆脱终端界面的束缚,将 Gemini 的强大能力以 API 的形式轻松接入到任何您喜爱的客户端或应用中。
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
- 为 CLI 模型提供 OpenAI/Gemini/Claude 兼容的 API 端点
|
||||||
|
- 支持流式和非流式响应
|
||||||
|
- 函数调用/工具支持
|
||||||
|
- 多模态输入支持(文本和图像)
|
||||||
|
- 多账户支持与负载均衡
|
||||||
|
- 简单的 CLI 身份验证流程
|
||||||
|
- 支持 Gemini AIStudio API 密钥
|
||||||
|
- 支持 Gemini CLI 多账户轮询
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
### 前置要求
|
||||||
|
|
||||||
|
- Go 1.24 或更高版本
|
||||||
|
- 有权访问 CLI 模型的 Google 账户
|
||||||
|
|
||||||
|
### 从源码构建
|
||||||
|
|
||||||
|
1. 克隆仓库:
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/luispater/CLIProxyAPI.git
|
||||||
|
cd CLIProxyAPI
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 构建应用程序:
|
||||||
|
```bash
|
||||||
|
go build -o cli-proxy-api ./cmd/server
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用方法
|
||||||
|
|
||||||
|
### 身份验证
|
||||||
|
|
||||||
|
在使用 API 之前,您需要使用 Google 账户进行身份验证:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./cli-proxy-api --login
|
||||||
|
```
|
||||||
|
|
||||||
|
如果您是旧版 gemini code 用户,可能需要指定项目 ID:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./cli-proxy-api --login --project_id <your_project_id>
|
||||||
|
```
|
||||||
|
|
||||||
|
### 启动服务器
|
||||||
|
|
||||||
|
身份验证完成后,启动服务器:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./cli-proxy-api
|
||||||
|
```
|
||||||
|
|
||||||
|
默认情况下,服务器在端口 8317 上运行。
|
||||||
|
|
||||||
|
### API 端点
|
||||||
|
|
||||||
|
#### 列出模型
|
||||||
|
|
||||||
|
```
|
||||||
|
GET http://localhost:8317/v1/models
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 聊天补全
|
||||||
|
|
||||||
|
```
|
||||||
|
POST http://localhost:8317/v1/chat/completions
|
||||||
|
```
|
||||||
|
|
||||||
|
请求体示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"model": "gemini-2.5-pro",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "你好,你好吗?"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 与 OpenAI 库一起使用
|
||||||
|
|
||||||
|
您可以通过将基础 URL 设置为本地服务器来将此代理与任何 OpenAI 兼容的库一起使用:
|
||||||
|
|
||||||
|
#### Python(使用 OpenAI 库)
|
||||||
|
|
||||||
|
```python
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
api_key="dummy", # 不使用但必需
|
||||||
|
base_url="http://localhost:8317/v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="gemini-2.5-pro",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "你好,你好吗?"}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(response.choices[0].message.content)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### JavaScript/TypeScript
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
import OpenAI from 'openai';
|
||||||
|
|
||||||
|
const openai = new OpenAI({
|
||||||
|
apiKey: 'dummy', // 不使用但必需
|
||||||
|
baseURL: 'http://localhost:8317/v1',
|
||||||
|
});
|
||||||
|
|
||||||
|
const response = await openai.chat.completions.create({
|
||||||
|
model: 'gemini-2.5-pro',
|
||||||
|
messages: [
|
||||||
|
{ role: 'user', content: '你好,你好吗?' }
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log(response.choices[0].message.content);
|
||||||
|
```
|
||||||
|
|
||||||
|
## 支持的模型
|
||||||
|
|
||||||
|
- gemini-2.5-pro
|
||||||
|
- gemini-2.5-flash
|
||||||
|
- 并且自动切换到之前的预览版本
|
||||||
|
|
||||||
|
## 配置
|
||||||
|
|
||||||
|
服务器默认使用位于项目根目录的 YAML 配置文件(`config.yaml`)。您可以使用 `--config` 标志指定不同的配置文件路径:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./cli-proxy --config /path/to/your/config.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### 配置选项
|
||||||
|
|
||||||
|
| 参数 | 类型 | 默认值 | 描述 |
|
||||||
|
|---------------------------------------|----------|--------------------|------------------------------------------------------------------------|
|
||||||
|
| `port` | integer | 8317 | 服务器监听的端口号 |
|
||||||
|
| `auth-dir` | string | "~/.cli-proxy-api" | 存储身份验证令牌的目录。支持使用 `~` 表示主目录 |
|
||||||
|
| `proxy-url` | string | "" | 代理 URL,支持 socks5/http/https 协议,示例:socks5://user:pass@192.168.1.1:1080/ |
|
||||||
|
| `quota-exceeded` | object | {} | 处理配额超限的配置 |
|
||||||
|
| `quota-exceeded.switch-project` | boolean | true | 当配额超限时是否自动切换到另一个项目 |
|
||||||
|
| `quota-exceeded.switch-preview-model` | boolean | true | 当配额超限时是否自动切换到预览模型 |
|
||||||
|
| `debug` | boolean | false | 启用调试模式以进行详细日志记录 |
|
||||||
|
| `api-keys` | string[] | [] | 可用于验证请求的 API 密钥列表 |
|
||||||
|
| `generative-language-api-key` | string[] | [] | 生成式语言 API 密钥列表 |
|
||||||
|
|
||||||
|
### 配置文件示例
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# 服务器端口
|
||||||
|
port: 8317
|
||||||
|
|
||||||
|
# 身份验证目录(支持 ~ 表示主目录)
|
||||||
|
auth-dir: "~/.cli-proxy-api"
|
||||||
|
|
||||||
|
# 启用调试日志
|
||||||
|
debug: false
|
||||||
|
|
||||||
|
# 代理 URL,支持 socks5/http/https 协议,示例:socks5://user:pass@192.168.1.1:1080/
|
||||||
|
proxy-url: ""
|
||||||
|
|
||||||
|
# 配额超限行为
|
||||||
|
quota-exceeded:
|
||||||
|
switch-project: true # 当配额超限时是否自动切换到另一个项目
|
||||||
|
switch-preview-model: true # 当配额超限时是否自动切换到预览模型
|
||||||
|
|
||||||
|
# 用于本地身份验证的 API 密钥
|
||||||
|
api-keys:
|
||||||
|
- "your-api-key-1"
|
||||||
|
- "your-api-key-2"
|
||||||
|
|
||||||
|
# AIStudio Gemini API 的 API 密钥
|
||||||
|
generative-language-api-key:
|
||||||
|
- "AIzaSy...01"
|
||||||
|
- "AIzaSy...02"
|
||||||
|
- "AIzaSy...03"
|
||||||
|
- "AIzaSy...04"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 身份验证目录
|
||||||
|
|
||||||
|
`auth-dir` 参数指定身份验证令牌的存储位置。当您运行登录命令时,应用程序将在此目录中创建包含 Google 账户身份验证令牌的 JSON 文件。多个账户可用于轮询。
|
||||||
|
|
||||||
|
### API 密钥
|
||||||
|
|
||||||
|
`api-keys` 参数允许您定义可用于验证对代理服务器请求的 API 密钥列表。在向 API 发出请求时,您可以在 `Authorization` 标头中包含其中一个密钥:
|
||||||
|
|
||||||
|
```
|
||||||
|
Authorization: Bearer your-api-key-1
|
||||||
|
```
|
||||||
|
|
||||||
|
### 官方生成式语言 API
|
||||||
|
|
||||||
|
`generative-language-api-key` 参数允许您定义可用于验证对官方 AIStudio Gemini API 请求的 API 密钥列表。
|
||||||
|
|
||||||
|
## Gemini CLI 多账户负载均衡
|
||||||
|
|
||||||
|
启动 CLI 代理 API 服务器,然后将 `CODE_ASSIST_ENDPOINT` 环境变量设置为 CLI 代理 API 服务器的 URL。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
|
||||||
|
```
|
||||||
|
|
||||||
|
服务器将中继 `loadCodeAssist`、`onboardUser` 和 `countTokens` 请求。并自动在多个账户之间轮询文本生成请求。
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> 此功能仅允许本地访问,因为找不到一个可以验证请求的方法。
|
||||||
|
> 所以只能强制只有 `127.0.0.1` 可以访问。
|
||||||
|
|
||||||
|
## 使用 Docker 运行
|
||||||
|
|
||||||
|
运行以下命令进行登录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
|
||||||
|
```
|
||||||
|
|
||||||
|
运行以下命令启动服务器:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
## 贡献
|
||||||
|
|
||||||
|
欢迎贡献!请随时提交 Pull Request。
|
||||||
|
|
||||||
|
1. Fork 仓库
|
||||||
|
2. 创建您的功能分支(`git checkout -b feature/amazing-feature`)
|
||||||
|
3. 提交您的更改(`git commit -m 'Add some amazing feature'`)
|
||||||
|
4. 推送到分支(`git push origin feature/amazing-feature`)
|
||||||
|
5. 打开 Pull Request
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||||
188
internal/api/claude-code-handlers.go
Normal file
188
internal/api/claude-code-handlers.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClaudeMessages handles Claude-compatible streaming chat completions.
|
||||||
|
// This function implements a sophisticated client rotation and quota management system
|
||||||
|
// to ensure high availability and optimal resource utilization across multiple backend clients.
|
||||||
|
func (h *APIHandlers) ClaudeMessages(c *gin.Context) {
|
||||||
|
// Extract raw JSON data from the incoming request
|
||||||
|
rawJson, err := c.GetRawData()
|
||||||
|
// If data retrieval fails, return a 400 Bad Request error.
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up Server-Sent Events (SSE) headers for streaming response
|
||||||
|
// These headers are essential for maintaining a persistent connection
|
||||||
|
// and enabling real-time streaming of chat completions
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
|
||||||
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
|
// This is crucial for streaming as it allows immediate sending of data chunks
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: "Streaming not supported",
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and prepare the Claude request, extracting model name, system instructions,
|
||||||
|
// conversation contents, and available tools from the raw JSON
|
||||||
|
modelName, systemInstruction, contents, tools := translator.PrepareClaudeRequest(rawJson)
|
||||||
|
|
||||||
|
// Map Claude model names to corresponding Gemini models
|
||||||
|
// This allows the proxy to handle Claude API calls using Gemini backend
|
||||||
|
if modelName == "claude-sonnet-4-20250514" {
|
||||||
|
modelName = "gemini-2.5-pro"
|
||||||
|
} else if modelName == "claude-3-5-haiku-20241022" {
|
||||||
|
modelName = "gemini-2.5-flash"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a cancellable context for the backend client request
|
||||||
|
// This allows proper cleanup and cancellation of ongoing requests
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient *client.Client
|
||||||
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
|
// This prevents deadlocks and ensures proper resource cleanup
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.RequestMutex.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Main client rotation loop with quota management
|
||||||
|
// This loop implements a sophisticated load balancing and failover mechanism
|
||||||
|
outLoop:
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.getClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the authentication method being used by the selected client
|
||||||
|
// This affects how responses are formatted and logged
|
||||||
|
isGlAPIKey := false
|
||||||
|
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
|
isGlAPIKey = true
|
||||||
|
} else {
|
||||||
|
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
|
}
|
||||||
|
// Initiate streaming communication with the backend client
|
||||||
|
// This returns two channels: one for response chunks and one for errors
|
||||||
|
|
||||||
|
includeThoughts := false
|
||||||
|
if userAgent, hasKey := c.Request.Header["User-Agent"]; hasKey {
|
||||||
|
includeThoughts = !strings.Contains(userAgent[0], "claude-cli")
|
||||||
|
}
|
||||||
|
|
||||||
|
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools, includeThoughts)
|
||||||
|
|
||||||
|
// Track response state for proper Claude format conversion
|
||||||
|
hasFirstResponse := false
|
||||||
|
responseType := 0
|
||||||
|
responseIndex := 0
|
||||||
|
|
||||||
|
// Main streaming loop - handles multiple concurrent events using Go channels
|
||||||
|
// This select statement manages four different types of events simultaneously
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Case 1: Handle client disconnection
|
||||||
|
// Detects when the HTTP client has disconnected and cleans up resources
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
|
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||||
|
cliCancel() // Cancel the backend request to prevent resource leaks
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 2: Process incoming response chunks from the backend
|
||||||
|
// This handles the actual streaming data from the AI model
|
||||||
|
case chunk, okStream := <-respChan:
|
||||||
|
if !okStream {
|
||||||
|
// Stream has ended - send the final message_stop event
|
||||||
|
// This follows the Claude API specification for stream termination
|
||||||
|
_, _ = c.Writer.Write([]byte(`event: message_stop`))
|
||||||
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
_, _ = c.Writer.Write([]byte(`data: {"type":"message_stop"}`))
|
||||||
|
_, _ = c.Writer.Write([]byte("\n\n\n"))
|
||||||
|
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
// Convert the backend response to Claude-compatible format
|
||||||
|
// This translation layer ensures API compatibility
|
||||||
|
claudeFormat := translator.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
|
||||||
|
if claudeFormat != "" {
|
||||||
|
_, _ = c.Writer.Write([]byte(claudeFormat))
|
||||||
|
flusher.Flush() // Immediately send the chunk to the client
|
||||||
|
}
|
||||||
|
hasFirstResponse = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 3: Handle errors from the backend
|
||||||
|
// This manages various error conditions and implements retry logic
|
||||||
|
case errInfo, okError := <-errChan:
|
||||||
|
if okError {
|
||||||
|
// Special handling for quota exceeded errors
|
||||||
|
// If configured, attempt to switch to a different project/client
|
||||||
|
if errInfo.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue outLoop // Restart the client selection process
|
||||||
|
} else {
|
||||||
|
// Forward other errors directly to the client
|
||||||
|
c.Status(errInfo.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errInfo.Error.Error())
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 4: Send periodic keep-alive signals
|
||||||
|
// Prevents connection timeouts during long-running requests
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
if hasFirstResponse {
|
||||||
|
// Send a ping event to maintain the connection
|
||||||
|
// This is especially important for slow AI model responses
|
||||||
|
output := "event: ping\n"
|
||||||
|
output = output + `data: {"type": "ping"}`
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
_, _ = c.Writer.Write([]byte(output))
|
||||||
|
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
239
internal/api/cli-handlers.go
Normal file
239
internal/api/cli-handlers.go
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *APIHandlers) CLIHandler(c *gin.Context) {
|
||||||
|
if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
|
||||||
|
c.JSON(http.StatusForbidden, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: "CLI reply only allow local access",
|
||||||
|
Type: "forbidden",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawJson, _ := c.GetRawData()
|
||||||
|
requestRawURI := c.Request.URL.Path
|
||||||
|
if requestRawURI == "/v1internal:generateContent" {
|
||||||
|
h.internalGenerateContent(c, rawJson)
|
||||||
|
} else if requestRawURI == "/v1internal:streamGenerateContent" {
|
||||||
|
h.internalStreamGenerateContent(c, rawJson)
|
||||||
|
} else {
|
||||||
|
reqBody := bytes.NewBuffer(rawJson)
|
||||||
|
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for key, value := range c.Request.Header {
|
||||||
|
req.Header[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient, err := util.SetProxy(h.cfg, &http.Client{})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("set proxy failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
defer func() {
|
||||||
|
if err = resp.Body.Close(); err != nil {
|
||||||
|
log.Printf("warn: failed to close response body: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: string(bodyBytes),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for key, value := range resp.Header {
|
||||||
|
c.Header(key, value[0])
|
||||||
|
}
|
||||||
|
output, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to read response body: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = c.Writer.Write(output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) {
|
||||||
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: "Streaming not supported",
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(rawJson, "model")
|
||||||
|
modelName := modelResult.String()
|
||||||
|
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient *client.Client
|
||||||
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.RequestMutex.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
outLoop:
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.getClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
|
}
|
||||||
|
// Send the message and receive response chunks and errors via channels.
|
||||||
|
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, "")
|
||||||
|
hasFirstResponse := false
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Handle client disconnection.
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
|
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||||
|
cliCancel() // Cancel the backend request.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Process incoming response chunks.
|
||||||
|
case chunk, okStream := <-respChan:
|
||||||
|
if !okStream {
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
hasFirstResponse = true
|
||||||
|
if cliClient.GetGenerativeLanguageAPIKey() != "" {
|
||||||
|
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
|
||||||
|
}
|
||||||
|
_, _ = c.Writer.Write([]byte("data: "))
|
||||||
|
_, _ = c.Writer.Write(chunk)
|
||||||
|
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
// Handle errors from the backend.
|
||||||
|
case err, okError := <-errChan:
|
||||||
|
if okError {
|
||||||
|
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue outLoop
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Send a keep-alive signal to the client.
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
if hasFirstResponse {
|
||||||
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(rawJson, "model")
|
||||||
|
modelName := modelResult.String()
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient *client.Client
|
||||||
|
defer func() {
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.RequestMutex.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.getClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := cliClient.SendRawMessage(cliCtx, rawJson, "")
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
_, _ = c.Writer.Write(resp)
|
||||||
|
cliCancel()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
394
internal/api/gemini-handlers.go
Normal file
394
internal/api/gemini-handlers.go
Normal file
@@ -0,0 +1,394 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *APIHandlers) GeminiModels(c *gin.Context) {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
c.Header("Content-Type", "application/json; charset=UTF-8")
|
||||||
|
_, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `))
|
||||||
|
_, _ = c.Writer.Write([]byte(`2.5 Flash","description":"Stable version of Gemini 2.5 Flash, our mid-size multimod`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`al model that supports up to 1 million tokens, released in June of 2025.","inputTok`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`enLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateCo`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`ntent","countTokens","createCachedContent","batchGenerateContent"],"temperature":1,`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true},{"name":"models/gemini-2.`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`5-pro","version":"2.5","displayName":"Gemini 2.5 Pro","description":"Stable release`))
|
||||||
|
_, _ = c.Writer.Write([]byte(` (June 17th, 2025) of Gemini 2.5 Pro","inputTokenLimit":1048576,"outputTokenLimit":`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`65536,"supportedGenerationMethods":["generateContent","countTokens","createCachedCo`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`ntent","batchGenerateContent"],"temperature":1,"topP":0.95,"topK":64,"maxTemperatur`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *APIHandlers) GeminiGetHandler(c *gin.Context) {
|
||||||
|
var request struct {
|
||||||
|
Action string `uri:"action" binding:"required"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindUri(&request); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if request.Action == "gemini-2.5-pro" {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
c.Header("Content-Type", "application/json; charset=UTF-8")
|
||||||
|
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-pro","version":"2.5","displayName":"Gemini 2.5 Pro",`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`"description":"Stable release (June 17th, 2025) of Gemini 2.5 Pro","inputTokenL`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`imit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateC`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`ontent","countTokens","createCachedContent","batchGenerateContent"],"temperatur`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`e":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
|
||||||
|
} else if request.Action == "gemini-2.5-flash" {
|
||||||
|
c.Status(http.StatusOK)
|
||||||
|
c.Header("Content-Type", "application/json; charset=UTF-8")
|
||||||
|
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini 2.5 Fla`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`sh","description":"Stable version of Gemini 2.5 Flash, our mid-size multimodal `))
|
||||||
|
_, _ = c.Writer.Write([]byte(`model that supports up to 1 million tokens, released in June of 2025.","inputTo`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`kenLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["gener`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`ateContent","countTokens","createCachedContent","batchGenerateContent"],"temper`))
|
||||||
|
_, _ = c.Writer.Write([]byte(`ature":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
|
||||||
|
} else {
|
||||||
|
c.Status(http.StatusNotFound)
|
||||||
|
_, _ = c.Writer.Write([]byte(
|
||||||
|
`{"error":{"message":"Not Found","code":404,"status":"NOT_FOUND"}}`,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *APIHandlers) GeminiHandler(c *gin.Context) {
|
||||||
|
var request struct {
|
||||||
|
Action string `uri:"action" binding:"required"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindUri(&request); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
action := strings.Split(request.Action, ":")
|
||||||
|
if len(action) != 2 {
|
||||||
|
c.JSON(http.StatusNotFound, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName := action[0]
|
||||||
|
method := action[1]
|
||||||
|
rawJson, _ := c.GetRawData()
|
||||||
|
rawJson, _ = sjson.SetBytes(rawJson, "model", []byte(modelName))
|
||||||
|
|
||||||
|
if method == "generateContent" {
|
||||||
|
h.geminiGenerateContent(c, rawJson)
|
||||||
|
} else if method == "streamGenerateContent" {
|
||||||
|
h.geminiStreamGenerateContent(c, rawJson)
|
||||||
|
} else if method == "countTokens" {
|
||||||
|
h.geminiCountTokens(c, rawJson)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte) {
|
||||||
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: "Streaming not supported",
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(rawJson, "model")
|
||||||
|
modelName := modelResult.String()
|
||||||
|
|
||||||
|
alt := h.getAlt(c)
|
||||||
|
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient *client.Client
|
||||||
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.RequestMutex.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
outLoop:
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.getClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
template := `{"project":"","request":{},"model":""}`
|
||||||
|
template, _ = sjson.SetRaw(template, "request", string(rawJson))
|
||||||
|
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||||
|
template, _ = sjson.Delete(template, "request.model")
|
||||||
|
|
||||||
|
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
|
||||||
|
if errFixCLIToolResponse != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: errFixCLIToolResponse.Error(),
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||||
|
if systemInstructionResult.Exists() {
|
||||||
|
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||||
|
template, _ = sjson.Delete(template, "request.system_instruction")
|
||||||
|
}
|
||||||
|
rawJson = []byte(template)
|
||||||
|
|
||||||
|
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the message and receive response chunks and errors via channels.
|
||||||
|
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, alt)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Handle client disconnection.
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
|
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||||
|
cliCancel() // Cancel the backend request.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Process incoming response chunks.
|
||||||
|
case chunk, okStream := <-respChan:
|
||||||
|
if !okStream {
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
if cliClient.GetGenerativeLanguageAPIKey() == "" {
|
||||||
|
if alt == "" {
|
||||||
|
responseResult := gjson.GetBytes(chunk, "response")
|
||||||
|
if responseResult.Exists() {
|
||||||
|
chunk = []byte(responseResult.Raw)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
chunkTemplate := "[]"
|
||||||
|
responseResult := gjson.ParseBytes(chunk)
|
||||||
|
if responseResult.IsArray() {
|
||||||
|
responseResultItems := responseResult.Array()
|
||||||
|
for i := 0; i < len(responseResultItems); i++ {
|
||||||
|
responseResultItem := responseResultItems[i]
|
||||||
|
if responseResultItem.Get("response").Exists() {
|
||||||
|
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chunk = []byte(chunkTemplate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if alt == "" {
|
||||||
|
_, _ = c.Writer.Write([]byte("data: "))
|
||||||
|
_, _ = c.Writer.Write(chunk)
|
||||||
|
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||||
|
} else {
|
||||||
|
_, _ = c.Writer.Write(chunk)
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
// Handle errors from the backend.
|
||||||
|
case err, okError := <-errChan:
|
||||||
|
if okError {
|
||||||
|
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue outLoop
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Send a keep-alive signal to the client.
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
alt := h.getAlt(c)
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(rawJson, "model")
|
||||||
|
modelName := modelResult.String()
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient *client.Client
|
||||||
|
defer func() {
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.RequestMutex.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.getClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
|
|
||||||
|
template := `{"request":{}}`
|
||||||
|
template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJson, "generateContentRequest").Raw)
|
||||||
|
template, _ = sjson.Delete(template, "generateContentRequest")
|
||||||
|
rawJson = []byte(template)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := cliClient.SendRawTokenCount(cliCtx, rawJson, alt)
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
if cliClient.GetGenerativeLanguageAPIKey() == "" {
|
||||||
|
responseResult := gjson.GetBytes(resp, "response")
|
||||||
|
if responseResult.Exists() {
|
||||||
|
resp = []byte(responseResult.Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, _ = c.Writer.Write(resp)
|
||||||
|
cliCancel()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
alt := h.getAlt(c)
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(rawJson, "model")
|
||||||
|
modelName := modelResult.String()
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient *client.Client
|
||||||
|
defer func() {
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.RequestMutex.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.getClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
template := `{"project":"","request":{},"model":""}`
|
||||||
|
template, _ = sjson.SetRaw(template, "request", string(rawJson))
|
||||||
|
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||||
|
template, _ = sjson.Delete(template, "request.model")
|
||||||
|
|
||||||
|
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
|
||||||
|
if errFixCLIToolResponse != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: errFixCLIToolResponse.Error(),
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||||
|
if systemInstructionResult.Exists() {
|
||||||
|
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||||
|
template, _ = sjson.Delete(template, "request.system_instruction")
|
||||||
|
}
|
||||||
|
rawJson = []byte(template)
|
||||||
|
|
||||||
|
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
|
}
|
||||||
|
resp, err := cliClient.SendRawMessage(cliCtx, rawJson, alt)
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
if cliClient.GetGenerativeLanguageAPIKey() == "" {
|
||||||
|
responseResult := gjson.GetBytes(resp, "response")
|
||||||
|
if responseResult.Exists() {
|
||||||
|
resp = []byte(responseResult.Raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, _ = c.Writer.Write(resp)
|
||||||
|
cliCancel()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *APIHandlers) getAlt(c *gin.Context) string {
|
||||||
|
var alt string
|
||||||
|
var hasAlt bool
|
||||||
|
alt, hasAlt = c.GetQuery("alt")
|
||||||
|
if !hasAlt {
|
||||||
|
alt, _ = c.GetQuery("$alt")
|
||||||
|
}
|
||||||
|
if alt == "sse" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return alt
|
||||||
|
}
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||||
@@ -9,12 +8,7 @@ import (
|
|||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -47,46 +41,6 @@ func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandler
|
|||||||
func (h *APIHandlers) Models(c *gin.Context) {
|
func (h *APIHandlers) Models(c *gin.Context) {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"data": []map[string]any{
|
"data": []map[string]any{
|
||||||
{
|
|
||||||
"id": "gemini-2.5-pro-preview-05-06",
|
|
||||||
"object": "model",
|
|
||||||
"version": "2.5-preview-05-06",
|
|
||||||
"name": "Gemini 2.5 Pro Preview 05-06",
|
|
||||||
"description": "Preview release (May 6th, 2025) of Gemini 2.5 Pro",
|
|
||||||
"context_length": 1048576,
|
|
||||||
"max_completion_tokens": 65536,
|
|
||||||
"supported_parameters": []string{
|
|
||||||
"tools",
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"top_k",
|
|
||||||
},
|
|
||||||
"temperature": 1,
|
|
||||||
"topP": 0.95,
|
|
||||||
"topK": 64,
|
|
||||||
"maxTemperature": 2,
|
|
||||||
"thinking": true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gemini-2.5-pro-preview-06-05",
|
|
||||||
"object": "model",
|
|
||||||
"version": "2.5-preview-06-05",
|
|
||||||
"name": "Gemini 2.5 Pro Preview 06-05",
|
|
||||||
"description": "Preview release (June 5th, 2025) of Gemini 2.5 Pro",
|
|
||||||
"context_length": 1048576,
|
|
||||||
"max_completion_tokens": 65536,
|
|
||||||
"supported_parameters": []string{
|
|
||||||
"tools",
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"top_k",
|
|
||||||
},
|
|
||||||
"temperature": 1,
|
|
||||||
"topP": 0.95,
|
|
||||||
"topK": 64,
|
|
||||||
"maxTemperature": 2,
|
|
||||||
"thinking": true,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": "gemini-2.5-pro",
|
"id": "gemini-2.5-pro",
|
||||||
"object": "model",
|
"object": "model",
|
||||||
@@ -107,46 +61,6 @@ func (h *APIHandlers) Models(c *gin.Context) {
|
|||||||
"maxTemperature": 2,
|
"maxTemperature": 2,
|
||||||
"thinking": true,
|
"thinking": true,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"id": "gemini-2.5-flash-preview-04-17",
|
|
||||||
"object": "model",
|
|
||||||
"version": "2.5-preview-04-17",
|
|
||||||
"name": "Gemini 2.5 Flash Preview 04-17",
|
|
||||||
"description": "Preview release (April 17th, 2025) of Gemini 2.5 Flash",
|
|
||||||
"context_length": 1048576,
|
|
||||||
"max_completion_tokens": 65536,
|
|
||||||
"supported_parameters": []string{
|
|
||||||
"tools",
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"top_k",
|
|
||||||
},
|
|
||||||
"temperature": 1,
|
|
||||||
"topP": 0.95,
|
|
||||||
"topK": 64,
|
|
||||||
"maxTemperature": 2,
|
|
||||||
"thinking": true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gemini-2.5-flash-preview-05-20",
|
|
||||||
"object": "model",
|
|
||||||
"version": "2.5-preview-05-20",
|
|
||||||
"name": "Gemini 2.5 Flash Preview 05-20",
|
|
||||||
"description": "Preview release (April 17th, 2025) of Gemini 2.5 Flash",
|
|
||||||
"context_length": 1048576,
|
|
||||||
"max_completion_tokens": 65536,
|
|
||||||
"supported_parameters": []string{
|
|
||||||
"tools",
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"top_k",
|
|
||||||
},
|
|
||||||
"temperature": 1,
|
|
||||||
"topP": 0.95,
|
|
||||||
"topK": 64,
|
|
||||||
"maxTemperature": 2,
|
|
||||||
"thinking": true,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"id": "gemini-2.5-flash",
|
"id": "gemini-2.5-flash",
|
||||||
"object": "model",
|
"object": "model",
|
||||||
@@ -171,6 +85,52 @@ func (h *APIHandlers) Models(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *APIHandlers) getClient(modelName string) (*client.Client, *client.ErrorMessage) {
|
||||||
|
if len(h.cliClients) == 0 {
|
||||||
|
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
|
||||||
|
}
|
||||||
|
|
||||||
|
var cliClient *client.Client
|
||||||
|
|
||||||
|
// Lock the mutex to update the last used client index
|
||||||
|
mutex.Lock()
|
||||||
|
startIndex := lastUsedClientIndex
|
||||||
|
currentIndex := (startIndex + 1) % len(h.cliClients)
|
||||||
|
lastUsedClientIndex = currentIndex
|
||||||
|
mutex.Unlock()
|
||||||
|
|
||||||
|
// Reorder the client to start from the last used index
|
||||||
|
reorderedClients := make([]*client.Client, 0)
|
||||||
|
for i := 0; i < len(h.cliClients); i++ {
|
||||||
|
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
||||||
|
if cliClient.IsModelQuotaExceeded(modelName) {
|
||||||
|
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
|
cliClient = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
reorderedClients = append(reorderedClients, cliClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(reorderedClients) == 0 {
|
||||||
|
return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)}
|
||||||
|
}
|
||||||
|
|
||||||
|
locked := false
|
||||||
|
for i := 0; i < len(reorderedClients); i++ {
|
||||||
|
cliClient = reorderedClients[i]
|
||||||
|
if cliClient.RequestMutex.TryLock() {
|
||||||
|
locked = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !locked {
|
||||||
|
cliClient = h.cliClients[0]
|
||||||
|
cliClient.RequestMutex.Lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return cliClient, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ChatCompletions handles the /v1/chat/completions endpoint.
|
// ChatCompletions handles the /v1/chat/completions endpoint.
|
||||||
// It determines whether the request is for a streaming or non-streaming response
|
// It determines whether the request is for a streaming or non-streaming response
|
||||||
// and calls the appropriate handler.
|
// and calls the appropriate handler.
|
||||||
@@ -212,45 +172,15 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// Lock the mutex to update the last used client index
|
var errorResponse *client.ErrorMessage
|
||||||
mutex.Lock()
|
cliClient, errorResponse = h.getClient(modelName)
|
||||||
startIndex := lastUsedClientIndex
|
if errorResponse != nil {
|
||||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
c.Status(errorResponse.StatusCode)
|
||||||
lastUsedClientIndex = currentIndex
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
mutex.Unlock()
|
|
||||||
|
|
||||||
// Reorder the client to start from the last used index
|
|
||||||
reorderedClients := make([]*client.Client, 0)
|
|
||||||
for i := 0; i < len(h.cliClients); i++ {
|
|
||||||
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
|
||||||
if cliClient.IsModelQuotaExceeded(modelName) {
|
|
||||||
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
cliClient = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
reorderedClients = append(reorderedClients, cliClient)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(reorderedClients) == 0 {
|
|
||||||
c.Status(429)
|
|
||||||
_, _ = c.Writer.Write([]byte(fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)))
|
|
||||||
cliCancel()
|
cliCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
locked := false
|
|
||||||
for i := 0; i < len(reorderedClients); i++ {
|
|
||||||
cliClient = reorderedClients[i]
|
|
||||||
if cliClient.RequestMutex.TryLock() {
|
|
||||||
locked = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !locked {
|
|
||||||
cliClient = h.cliClients[0]
|
|
||||||
cliClient.RequestMutex.Lock()
|
|
||||||
}
|
|
||||||
|
|
||||||
isGlAPIKey := false
|
isGlAPIKey := false
|
||||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
@@ -312,46 +242,16 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
|||||||
|
|
||||||
outLoop:
|
outLoop:
|
||||||
for {
|
for {
|
||||||
// Lock the mutex to update the last used client index
|
var errorResponse *client.ErrorMessage
|
||||||
mutex.Lock()
|
cliClient, errorResponse = h.getClient(modelName)
|
||||||
startIndex := lastUsedClientIndex
|
if errorResponse != nil {
|
||||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
c.Status(errorResponse.StatusCode)
|
||||||
lastUsedClientIndex = currentIndex
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
mutex.Unlock()
|
|
||||||
|
|
||||||
// Reorder the client to start from the last used index
|
|
||||||
reorderedClients := make([]*client.Client, 0)
|
|
||||||
for i := 0; i < len(h.cliClients); i++ {
|
|
||||||
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
|
||||||
if cliClient.IsModelQuotaExceeded(modelName) {
|
|
||||||
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
cliClient = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
reorderedClients = append(reorderedClients, cliClient)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(reorderedClients) == 0 {
|
|
||||||
c.Status(429)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
|
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel()
|
cliCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
locked := false
|
|
||||||
for i := 0; i < len(reorderedClients); i++ {
|
|
||||||
cliClient = reorderedClients[i]
|
|
||||||
if cliClient.RequestMutex.TryLock() {
|
|
||||||
locked = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !locked {
|
|
||||||
cliClient = h.cliClients[0]
|
|
||||||
cliClient.RequestMutex.Lock()
|
|
||||||
}
|
|
||||||
|
|
||||||
isGlAPIKey := false
|
isGlAPIKey := false
|
||||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
@@ -411,295 +311,3 @@ outLoop:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *APIHandlers) Internal(c *gin.Context) {
|
|
||||||
rawJson, _ := c.GetRawData()
|
|
||||||
requestRawURI := c.Request.URL.Path
|
|
||||||
if requestRawURI == "/v1internal:generateContent" {
|
|
||||||
h.internalGenerateContent(c, rawJson)
|
|
||||||
} else if requestRawURI == "/v1internal:streamGenerateContent" {
|
|
||||||
h.internalStreamGenerateContent(c, rawJson)
|
|
||||||
} else {
|
|
||||||
reqBody := bytes.NewBuffer(rawJson)
|
|
||||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
|
||||||
Error: ErrorDetail{
|
|
||||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for key, value := range c.Request.Header {
|
|
||||||
req.Header[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
var transport *http.Transport
|
|
||||||
proxyURL, errParse := url.Parse(h.cfg.ProxyUrl)
|
|
||||||
if errParse == nil {
|
|
||||||
if proxyURL.Scheme == "socks5" {
|
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
proxyAuth := &proxy.Auth{User: username, Password: password}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
||||||
}
|
|
||||||
transport = &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
|
||||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
httpClient := &http.Client{}
|
|
||||||
if transport != nil {
|
|
||||||
httpClient.Transport = transport
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
|
||||||
Error: ErrorDetail{
|
|
||||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
||||||
defer func() {
|
|
||||||
if err = resp.Body.Close(); err != nil {
|
|
||||||
log.Printf("warn: failed to close response body: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
|
||||||
Error: ErrorDetail{
|
|
||||||
Message: string(bodyBytes),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
for key, value := range resp.Header {
|
|
||||||
c.Header(key, value[0])
|
|
||||||
}
|
|
||||||
output, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to read response body: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, _ = c.Writer.Write(output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) {
|
|
||||||
// Get the http.Flusher interface to manually flush the response.
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
|
||||||
Error: ErrorDetail{
|
|
||||||
Message: "Streaming not supported",
|
|
||||||
Type: "server_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
modelResult := gjson.GetBytes(rawJson, "model")
|
|
||||||
modelName := modelResult.String()
|
|
||||||
|
|
||||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
|
||||||
var cliClient *client.Client
|
|
||||||
defer func() {
|
|
||||||
// Ensure the client's mutex is unlocked on function exit.
|
|
||||||
if cliClient != nil {
|
|
||||||
cliClient.RequestMutex.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
outLoop:
|
|
||||||
for {
|
|
||||||
// Lock the mutex to update the last used client index
|
|
||||||
mutex.Lock()
|
|
||||||
startIndex := lastUsedClientIndex
|
|
||||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
|
||||||
lastUsedClientIndex = currentIndex
|
|
||||||
mutex.Unlock()
|
|
||||||
|
|
||||||
// Reorder the client to start from the last used index
|
|
||||||
reorderedClients := make([]*client.Client, 0)
|
|
||||||
for i := 0; i < len(h.cliClients); i++ {
|
|
||||||
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
|
||||||
if cliClient.IsModelQuotaExceeded(modelName) {
|
|
||||||
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
cliClient = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
reorderedClients = append(reorderedClients, cliClient)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(reorderedClients) == 0 {
|
|
||||||
c.Status(429)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
locked := false
|
|
||||||
for i := 0; i < len(reorderedClients); i++ {
|
|
||||||
cliClient = reorderedClients[i]
|
|
||||||
if cliClient.RequestMutex.TryLock() {
|
|
||||||
locked = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !locked {
|
|
||||||
cliClient = h.cliClients[0]
|
|
||||||
cliClient.RequestMutex.Lock()
|
|
||||||
}
|
|
||||||
|
|
||||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
|
||||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
|
||||||
} else {
|
|
||||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
}
|
|
||||||
// Send the message and receive response chunks and errors via channels.
|
|
||||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson)
|
|
||||||
hasFirstResponse := false
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
// Handle client disconnection.
|
|
||||||
case <-c.Request.Context().Done():
|
|
||||||
if c.Request.Context().Err().Error() == "context canceled" {
|
|
||||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
|
||||||
cliCancel() // Cancel the backend request.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Process incoming response chunks.
|
|
||||||
case chunk, okStream := <-respChan:
|
|
||||||
if !okStream {
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
hasFirstResponse = true
|
|
||||||
if cliClient.GetGenerativeLanguageAPIKey() != "" {
|
|
||||||
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
|
|
||||||
}
|
|
||||||
_, _ = c.Writer.Write([]byte("data: "))
|
|
||||||
_, _ = c.Writer.Write(chunk)
|
|
||||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
// Handle errors from the backend.
|
|
||||||
case err, okError := <-errChan:
|
|
||||||
if okError {
|
|
||||||
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
|
||||||
continue outLoop
|
|
||||||
} else {
|
|
||||||
c.Status(err.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Send a keep-alive signal to the client.
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
if hasFirstResponse {
|
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
|
|
||||||
c.Header("Content-Type", "application/json")
|
|
||||||
|
|
||||||
modelResult := gjson.GetBytes(rawJson, "model")
|
|
||||||
modelName := modelResult.String()
|
|
||||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
|
||||||
var cliClient *client.Client
|
|
||||||
defer func() {
|
|
||||||
if cliClient != nil {
|
|
||||||
cliClient.RequestMutex.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
// Lock the mutex to update the last used client index
|
|
||||||
mutex.Lock()
|
|
||||||
startIndex := lastUsedClientIndex
|
|
||||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
|
||||||
lastUsedClientIndex = currentIndex
|
|
||||||
mutex.Unlock()
|
|
||||||
|
|
||||||
// Reorder the client to start from the last used index
|
|
||||||
reorderedClients := make([]*client.Client, 0)
|
|
||||||
for i := 0; i < len(h.cliClients); i++ {
|
|
||||||
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
|
||||||
if cliClient.IsModelQuotaExceeded(modelName) {
|
|
||||||
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
cliClient = nil
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
reorderedClients = append(reorderedClients, cliClient)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(reorderedClients) == 0 {
|
|
||||||
c.Status(429)
|
|
||||||
_, _ = c.Writer.Write([]byte(fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)))
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
locked := false
|
|
||||||
for i := 0; i < len(reorderedClients); i++ {
|
|
||||||
cliClient = reorderedClients[i]
|
|
||||||
if cliClient.RequestMutex.TryLock() {
|
|
||||||
locked = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !locked {
|
|
||||||
cliClient = h.cliClients[0]
|
|
||||||
cliClient.RequestMutex.Lock()
|
|
||||||
}
|
|
||||||
|
|
||||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
|
||||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
|
||||||
} else {
|
|
||||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := cliClient.SendRawMessage(cliCtx, rawJson)
|
|
||||||
if err != nil {
|
|
||||||
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
c.Status(err.StatusCode)
|
|
||||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
|
||||||
cliCancel()
|
|
||||||
}
|
|
||||||
break
|
|
||||||
} else {
|
|
||||||
_, _ = c.Writer.Write(resp)
|
|
||||||
cliCancel()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -68,6 +68,16 @@ func (s *Server) setupRoutes() {
|
|||||||
{
|
{
|
||||||
v1.GET("/models", s.handlers.Models)
|
v1.GET("/models", s.handlers.Models)
|
||||||
v1.POST("/chat/completions", s.handlers.ChatCompletions)
|
v1.POST("/chat/completions", s.handlers.ChatCompletions)
|
||||||
|
v1.POST("/messages", s.handlers.ClaudeMessages)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gemini compatible API routes
|
||||||
|
v1beta := s.engine.Group("/v1beta")
|
||||||
|
v1beta.Use(AuthMiddleware(s.cfg))
|
||||||
|
{
|
||||||
|
v1beta.GET("/models", s.handlers.GeminiModels)
|
||||||
|
v1beta.POST("/models/:action", s.handlers.GeminiHandler)
|
||||||
|
v1beta.GET("/models/:action", s.handlers.GeminiGetHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Root endpoint
|
// Root endpoint
|
||||||
@@ -81,7 +91,7 @@ func (s *Server) setupRoutes() {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
s.engine.POST("/v1internal:method", s.handlers.Internal)
|
s.engine.POST("/v1internal:method", s.handlers.CLIHandler)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,7 +150,13 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
|
|||||||
|
|
||||||
// Get the Authorization header
|
// Get the Authorization header
|
||||||
authHeader := c.GetHeader("Authorization")
|
authHeader := c.GetHeader("Authorization")
|
||||||
if authHeader == "" {
|
authHeaderGoogle := c.GetHeader("X-Goog-Api-Key")
|
||||||
|
authHeaderAnthropic := c.GetHeader("X-Api-Key")
|
||||||
|
|
||||||
|
// Get the API key from the query parameter
|
||||||
|
apiKeyQuery, _ := c.GetQuery("key")
|
||||||
|
|
||||||
|
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && apiKeyQuery == "" {
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||||
"error": "Missing API key",
|
"error": "Missing API key",
|
||||||
})
|
})
|
||||||
@@ -159,7 +175,7 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
|
|||||||
// Find the API key in the in-memory list
|
// Find the API key in the in-memory list
|
||||||
var foundKey string
|
var foundKey string
|
||||||
for i := range cfg.ApiKeys {
|
for i := range cfg.ApiKeys {
|
||||||
if cfg.ApiKeys[i] == apiKey {
|
if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle || cfg.ApiKeys[i] == authHeaderAnthropic || cfg.ApiKeys[i] == apiKeyQuery {
|
||||||
foundKey = cfg.ApiKeys[i]
|
foundKey = cfg.ApiKeys[i]
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
package translator
|
package translator
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
@@ -20,10 +23,52 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content,
|
|||||||
modelName = modelResult.String()
|
modelName = modelResult.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the array of messages.
|
// Initialize data structures for processing conversation messages
|
||||||
|
// contents: stores the processed conversation history
|
||||||
|
// systemInstruction: stores system-level instructions separate from conversation
|
||||||
contents := make([]client.Content, 0)
|
contents := make([]client.Content, 0)
|
||||||
var systemInstruction *client.Content
|
var systemInstruction *client.Content
|
||||||
messagesResult := gjson.GetBytes(rawJson, "messages")
|
messagesResult := gjson.GetBytes(rawJson, "messages")
|
||||||
|
|
||||||
|
// Pre-process tool responses to create a lookup map
|
||||||
|
// This first pass collects all tool responses so they can be matched with their corresponding calls
|
||||||
|
toolItems := make(map[string]*client.FunctionResponse)
|
||||||
|
if messagesResult.IsArray() {
|
||||||
|
messagesResults := messagesResult.Array()
|
||||||
|
for i := 0; i < len(messagesResults); i++ {
|
||||||
|
messageResult := messagesResults[i]
|
||||||
|
roleResult := messageResult.Get("role")
|
||||||
|
if roleResult.Type != gjson.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
contentResult := messageResult.Get("content")
|
||||||
|
|
||||||
|
// Extract tool responses for later matching with function calls
|
||||||
|
if roleResult.String() == "tool" {
|
||||||
|
toolCallID := messageResult.Get("tool_call_id").String()
|
||||||
|
if toolCallID != "" {
|
||||||
|
var responseData string
|
||||||
|
// Handle both string and object-based tool response formats
|
||||||
|
if contentResult.Type == gjson.String {
|
||||||
|
responseData = contentResult.String()
|
||||||
|
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
|
||||||
|
responseData = contentResult.Get("text").String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up tool call ID by removing timestamp suffix
|
||||||
|
// This normalizes IDs for consistent matching between calls and responses
|
||||||
|
toolCallIDs := strings.Split(toolCallID, "-")
|
||||||
|
strings.Join(toolCallIDs, "-")
|
||||||
|
newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-")
|
||||||
|
|
||||||
|
// Create function response object with normalized ID and response data
|
||||||
|
functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}}
|
||||||
|
toolItems[toolCallID] = &functionResponse
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if messagesResult.IsArray() {
|
if messagesResult.IsArray() {
|
||||||
messagesResults := messagesResult.Array()
|
messagesResults := messagesResult.Array()
|
||||||
for i := 0; i < len(messagesResults); i++ {
|
for i := 0; i < len(messagesResults); i++ {
|
||||||
@@ -91,45 +136,62 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content,
|
|||||||
}
|
}
|
||||||
contents = append(contents, client.Content{Role: "user", Parts: parts})
|
contents = append(contents, client.Content{Role: "user", Parts: parts})
|
||||||
}
|
}
|
||||||
// Assistant messages can contain text or tool calls.
|
// Assistant messages can contain text responses or tool calls
|
||||||
|
// In the internal format, assistant messages are converted to "model" role
|
||||||
case "assistant":
|
case "assistant":
|
||||||
if contentResult.Type == gjson.String {
|
if contentResult.Type == gjson.String {
|
||||||
|
// Simple text response from the assistant
|
||||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
|
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||||
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
|
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
|
||||||
// Handle tool calls made by the assistant.
|
// Handle complex tool calls made by the assistant
|
||||||
|
// This processes function calls and matches them with their responses
|
||||||
|
functionIDs := make([]string, 0)
|
||||||
toolCallsResult := messageResult.Get("tool_calls")
|
toolCallsResult := messageResult.Get("tool_calls")
|
||||||
if toolCallsResult.IsArray() {
|
if toolCallsResult.IsArray() {
|
||||||
|
parts := make([]client.Part, 0)
|
||||||
tcsResult := toolCallsResult.Array()
|
tcsResult := toolCallsResult.Array()
|
||||||
|
|
||||||
|
// Process each tool call in the assistant's message
|
||||||
for j := 0; j < len(tcsResult); j++ {
|
for j := 0; j < len(tcsResult); j++ {
|
||||||
tcResult := tcsResult[j]
|
tcResult := tcsResult[j]
|
||||||
|
|
||||||
|
// Extract function call details
|
||||||
|
functionID := tcResult.Get("id").String()
|
||||||
|
functionIDs = append(functionIDs, functionID)
|
||||||
|
|
||||||
functionName := tcResult.Get("function.name").String()
|
functionName := tcResult.Get("function.name").String()
|
||||||
functionArgs := tcResult.Get("function.arguments").String()
|
functionArgs := tcResult.Get("function.arguments").String()
|
||||||
|
|
||||||
|
// Parse function arguments from JSON string to map
|
||||||
var args map[string]any
|
var args map[string]any
|
||||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||||
contents = append(contents, client.Content{
|
parts = append(parts, client.Part{
|
||||||
Role: "model", Parts: []client.Part{{
|
|
||||||
FunctionCall: &client.FunctionCall{
|
FunctionCall: &client.FunctionCall{
|
||||||
Name: functionName,
|
Name: functionName,
|
||||||
Args: args,
|
Args: args,
|
||||||
},
|
},
|
||||||
}},
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add the model's function calls to the conversation
|
||||||
|
if len(parts) > 0 {
|
||||||
|
contents = append(contents, client.Content{
|
||||||
|
Role: "model", Parts: parts,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create a separate tool response message with the collected responses
|
||||||
|
// This matches function calls with their corresponding responses
|
||||||
|
toolParts := make([]client.Part, 0)
|
||||||
|
for _, functionID := range functionIDs {
|
||||||
|
if functionResponse, ok := toolItems[functionID]; ok {
|
||||||
|
toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Tool messages contain the output of a tool call.
|
// Add the tool responses as a separate message in the conversation
|
||||||
case "tool":
|
contents = append(contents, client.Content{Role: "tool", Parts: toolParts})
|
||||||
toolCallID := messageResult.Get("tool_call_id").String()
|
}
|
||||||
if toolCallID != "" {
|
|
||||||
var responseData string
|
|
||||||
if contentResult.Type == gjson.String {
|
|
||||||
responseData = contentResult.String()
|
|
||||||
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
|
|
||||||
responseData = contentResult.Get("text").String()
|
|
||||||
}
|
}
|
||||||
functionResponse := client.FunctionResponse{Name: toolCallID, Response: map[string]interface{}{"result": responseData}}
|
|
||||||
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -160,3 +222,323 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content,
|
|||||||
|
|
||||||
return modelName, systemInstruction, contents, tools
|
return modelName, systemInstruction, contents, tools
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FunctionCallGroup represents a group of function calls and their responses
|
||||||
|
type FunctionCallGroup struct {
|
||||||
|
ModelContent map[string]interface{}
|
||||||
|
FunctionCalls []gjson.Result
|
||||||
|
ResponsesNeeded int
|
||||||
|
}
|
||||||
|
|
||||||
|
// FixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||||
|
// This function transforms the CLI tool response format by intelligently grouping function calls
|
||||||
|
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
|
||||||
|
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
|
||||||
|
// and their responses are properly associated and structured.
|
||||||
|
func FixCLIToolResponse(input string) (string, error) {
|
||||||
|
// Parse the input JSON to extract the conversation structure
|
||||||
|
parsed := gjson.Parse(input)
|
||||||
|
|
||||||
|
// Extract the contents array which contains the conversation messages
|
||||||
|
contents := parsed.Get("request.contents")
|
||||||
|
if !contents.Exists() {
|
||||||
|
return input, fmt.Errorf("contents not found in input")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize data structures for processing and grouping
|
||||||
|
var newContents []interface{} // Final processed contents array
|
||||||
|
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
||||||
|
var collectedResponses []gjson.Result // Standalone responses to be matched
|
||||||
|
|
||||||
|
// Process each content object in the conversation
|
||||||
|
// This iterates through messages and groups function calls with their responses
|
||||||
|
contents.ForEach(func(key, value gjson.Result) bool {
|
||||||
|
role := value.Get("role").String()
|
||||||
|
parts := value.Get("parts")
|
||||||
|
|
||||||
|
// Check if this content has function responses
|
||||||
|
var responsePartsInThisContent []gjson.Result
|
||||||
|
parts.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if part.Get("functionResponse").Exists() {
|
||||||
|
responsePartsInThisContent = append(responsePartsInThisContent, part)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// If this content has function responses, collect them
|
||||||
|
if len(responsePartsInThisContent) > 0 {
|
||||||
|
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
||||||
|
|
||||||
|
// Check if any pending groups can be satisfied
|
||||||
|
for i := len(pendingGroups) - 1; i >= 0; i-- {
|
||||||
|
group := pendingGroups[i]
|
||||||
|
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||||
|
// Take the needed responses for this group
|
||||||
|
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||||
|
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||||
|
|
||||||
|
// Create merged function response content
|
||||||
|
var responseParts []interface{}
|
||||||
|
for _, response := range groupResponses {
|
||||||
|
var responseMap map[string]interface{}
|
||||||
|
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||||
|
if errUnmarshal != nil {
|
||||||
|
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
responseParts = append(responseParts, responseMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(responseParts) > 0 {
|
||||||
|
functionResponseContent := map[string]interface{}{
|
||||||
|
"parts": responseParts,
|
||||||
|
"role": "function",
|
||||||
|
}
|
||||||
|
newContents = append(newContents, functionResponseContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove this group as it's been satisfied
|
||||||
|
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true // Skip adding this content, responses are merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is a model with function calls, create a new group
|
||||||
|
if role == "model" {
|
||||||
|
var functionCallsInThisModel []gjson.Result
|
||||||
|
parts.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if part.Get("functionCall").Exists() {
|
||||||
|
functionCallsInThisModel = append(functionCallsInThisModel, part)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(functionCallsInThisModel) > 0 {
|
||||||
|
// Add the model content
|
||||||
|
var contentMap map[string]interface{}
|
||||||
|
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||||
|
if errUnmarshal != nil {
|
||||||
|
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
newContents = append(newContents, contentMap)
|
||||||
|
|
||||||
|
// Create a new group for tracking responses
|
||||||
|
group := &FunctionCallGroup{
|
||||||
|
ModelContent: contentMap,
|
||||||
|
FunctionCalls: functionCallsInThisModel,
|
||||||
|
ResponsesNeeded: len(functionCallsInThisModel),
|
||||||
|
}
|
||||||
|
pendingGroups = append(pendingGroups, group)
|
||||||
|
} else {
|
||||||
|
// Regular model content without function calls
|
||||||
|
var contentMap map[string]interface{}
|
||||||
|
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||||
|
if errUnmarshal != nil {
|
||||||
|
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
newContents = append(newContents, contentMap)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Non-model content (user, etc.)
|
||||||
|
var contentMap map[string]interface{}
|
||||||
|
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||||
|
if errUnmarshal != nil {
|
||||||
|
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
newContents = append(newContents, contentMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle any remaining pending groups with remaining responses
|
||||||
|
for _, group := range pendingGroups {
|
||||||
|
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||||
|
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||||
|
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||||
|
|
||||||
|
var responseParts []interface{}
|
||||||
|
for _, response := range groupResponses {
|
||||||
|
var responseMap map[string]interface{}
|
||||||
|
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||||
|
if errUnmarshal != nil {
|
||||||
|
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
responseParts = append(responseParts, responseMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(responseParts) > 0 {
|
||||||
|
functionResponseContent := map[string]interface{}{
|
||||||
|
"parts": responseParts,
|
||||||
|
"role": "function",
|
||||||
|
}
|
||||||
|
newContents = append(newContents, functionResponseContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the original JSON with the new contents
|
||||||
|
result := input
|
||||||
|
newContentsJSON, _ := json.Marshal(newContents)
|
||||||
|
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrepareClaudeRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
|
||||||
|
var pathsToDelete []string
|
||||||
|
root := gjson.ParseBytes(rawJson)
|
||||||
|
walk(root, "", "additionalProperties", &pathsToDelete)
|
||||||
|
walk(root, "", "$schema", &pathsToDelete)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for _, p := range pathsToDelete {
|
||||||
|
rawJson, err = sjson.DeleteBytes(rawJson, p)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rawJson = bytes.Replace(rawJson, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||||
|
|
||||||
|
// log.Debug(string(rawJson))
|
||||||
|
modelName := "gemini-2.5-pro"
|
||||||
|
modelResult := gjson.GetBytes(rawJson, "model")
|
||||||
|
if modelResult.Type == gjson.String {
|
||||||
|
modelName = modelResult.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
contents := make([]client.Content, 0)
|
||||||
|
|
||||||
|
var systemInstruction *client.Content
|
||||||
|
|
||||||
|
systemResult := gjson.GetBytes(rawJson, "system")
|
||||||
|
if systemResult.IsArray() {
|
||||||
|
systemResults := systemResult.Array()
|
||||||
|
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
|
||||||
|
for i := 0; i < len(systemResults); i++ {
|
||||||
|
systemPromptResult := systemResults[i]
|
||||||
|
systemTypePromptResult := systemPromptResult.Get("type")
|
||||||
|
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||||
|
systemPrompt := systemPromptResult.Get("text").String()
|
||||||
|
systemPart := client.Part{Text: systemPrompt}
|
||||||
|
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(systemInstruction.Parts) == 0 {
|
||||||
|
systemInstruction = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messagesResult := gjson.GetBytes(rawJson, "messages")
|
||||||
|
if messagesResult.IsArray() {
|
||||||
|
messageResults := messagesResult.Array()
|
||||||
|
for i := 0; i < len(messageResults); i++ {
|
||||||
|
messageResult := messageResults[i]
|
||||||
|
roleResult := messageResult.Get("role")
|
||||||
|
if roleResult.Type != gjson.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
role := roleResult.String()
|
||||||
|
if role == "assistant" {
|
||||||
|
role = "model"
|
||||||
|
}
|
||||||
|
clientContent := client.Content{Role: role, Parts: []client.Part{}}
|
||||||
|
|
||||||
|
contentsResult := messageResult.Get("content")
|
||||||
|
if contentsResult.IsArray() {
|
||||||
|
contentResults := contentsResult.Array()
|
||||||
|
for j := 0; j < len(contentResults); j++ {
|
||||||
|
contentResult := contentResults[j]
|
||||||
|
contentTypeResult := contentResult.Get("type")
|
||||||
|
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||||
|
prompt := contentResult.Get("text").String()
|
||||||
|
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
||||||
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||||
|
functionName := contentResult.Get("name").String()
|
||||||
|
functionArgs := contentResult.Get("input").String()
|
||||||
|
var args map[string]any
|
||||||
|
if err = json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||||
|
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||||
|
FunctionCall: &client.FunctionCall{
|
||||||
|
Name: functionName,
|
||||||
|
Args: args,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||||
|
toolCallID := contentResult.Get("tool_use_id").String()
|
||||||
|
if toolCallID != "" {
|
||||||
|
funcName := toolCallID
|
||||||
|
toolCallIDs := strings.Split(toolCallID, "-")
|
||||||
|
if len(toolCallIDs) > 1 {
|
||||||
|
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||||
|
}
|
||||||
|
responseData := contentResult.Get("content").String()
|
||||||
|
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||||
|
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contents = append(contents, clientContent)
|
||||||
|
} else if contentsResult.Type == gjson.String {
|
||||||
|
prompt := contentsResult.String()
|
||||||
|
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tools []client.ToolDeclaration
|
||||||
|
toolsResult := gjson.GetBytes(rawJson, "tools")
|
||||||
|
if toolsResult.IsArray() {
|
||||||
|
tools = make([]client.ToolDeclaration, 1)
|
||||||
|
tools[0].FunctionDeclarations = make([]any, 0)
|
||||||
|
toolsResults := toolsResult.Array()
|
||||||
|
for i := 0; i < len(toolsResults); i++ {
|
||||||
|
toolResult := toolsResults[i]
|
||||||
|
inputSchemaResult := toolResult.Get("input_schema")
|
||||||
|
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||||
|
inputSchema := inputSchemaResult.Raw
|
||||||
|
inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties")
|
||||||
|
inputSchema, _ = sjson.Delete(inputSchema, "$schema")
|
||||||
|
|
||||||
|
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
|
||||||
|
tool, _ = sjson.SetRaw(tool, "parameters", inputSchema)
|
||||||
|
var toolDeclaration any
|
||||||
|
if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
|
||||||
|
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tools = make([]client.ToolDeclaration, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelName, systemInstruction, contents, tools
|
||||||
|
}
|
||||||
|
|
||||||
|
func walk(value gjson.Result, path, field string, pathsToDelete *[]string) {
|
||||||
|
switch value.Type {
|
||||||
|
case gjson.JSON:
|
||||||
|
value.ForEach(func(key, val gjson.Result) bool {
|
||||||
|
var childPath string
|
||||||
|
if path == "" {
|
||||||
|
childPath = key.String()
|
||||||
|
} else {
|
||||||
|
childPath = path + "." + key.String()
|
||||||
|
}
|
||||||
|
if key.String() == field {
|
||||||
|
*pathsToDelete = append(*pathsToDelete, childPath)
|
||||||
|
}
|
||||||
|
walk(val, childPath, field, pathsToDelete)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package translator
|
package translator
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -62,7 +64,11 @@ func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Process the main content part of the response.
|
// Process the main content part of the response.
|
||||||
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
|
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
|
||||||
|
if partsResult.IsArray() {
|
||||||
|
partResults := partsResult.Array()
|
||||||
|
for i := 0; i < len(partResults); i++ {
|
||||||
|
partResult := partResults[i]
|
||||||
partTextResult := partResult.Get("text")
|
partTextResult := partResult.Get("text")
|
||||||
functionCallResult := partResult.Get("functionCall")
|
functionCallResult := partResult.Get("functionCall")
|
||||||
|
|
||||||
@@ -76,18 +82,22 @@ func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) st
|
|||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
} else if functionCallResult.Exists() {
|
} else if functionCallResult.Exists() {
|
||||||
// Handle function call content.
|
// Handle function call content.
|
||||||
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
|
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||||
|
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||||
|
}
|
||||||
|
|
||||||
|
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||||
fcName := functionCallResult.Get("name").String()
|
fcName := functionCallResult.Get("name").String()
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcName)
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcName)
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||||
}
|
}
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
|
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate)
|
||||||
} else {
|
}
|
||||||
// If no usable content is found, return an empty string.
|
}
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return template
|
return template
|
||||||
@@ -163,7 +173,7 @@ func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey
|
|||||||
}
|
}
|
||||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||||
fcName := functionCallResult.Get("name").String()
|
fcName := functionCallResult.Get("name").String()
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName)
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
||||||
@@ -179,3 +189,194 @@ func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey
|
|||||||
|
|
||||||
return template
|
return template
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConvertCliToClaude performs sophisticated streaming response format conversion.
|
||||||
|
// This function implements a complex state machine that translates backend client responses
|
||||||
|
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
|
||||||
|
// and handles state transitions between content blocks, thinking processes, and function calls.
|
||||||
|
//
|
||||||
|
// Response type states: 0=none, 1=content, 2=thinking, 3=function
|
||||||
|
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
|
||||||
|
func ConvertCliToClaude(rawJson []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
|
||||||
|
// Normalize the response format for different API key types
|
||||||
|
// Generative Language API keys have a different response structure
|
||||||
|
if isGlAPIKey {
|
||||||
|
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track whether tools are being used in this response chunk
|
||||||
|
usedTool := false
|
||||||
|
output := ""
|
||||||
|
|
||||||
|
// Initialize the streaming session with a message_start event
|
||||||
|
// This is only sent for the very first response chunk
|
||||||
|
if !hasFirstResponse {
|
||||||
|
output = "event: message_start\n"
|
||||||
|
|
||||||
|
// Create the initial message structure with default values
|
||||||
|
// This follows the Claude API specification for streaming message initialization
|
||||||
|
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
||||||
|
|
||||||
|
// Override default values with actual response metadata if available
|
||||||
|
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
||||||
|
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||||
|
}
|
||||||
|
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
||||||
|
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIdResult.String())
|
||||||
|
}
|
||||||
|
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the response parts array from the backend client
|
||||||
|
// Each part can contain text content, thinking content, or function calls
|
||||||
|
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
|
||||||
|
if partsResult.IsArray() {
|
||||||
|
partResults := partsResult.Array()
|
||||||
|
for i := 0; i < len(partResults); i++ {
|
||||||
|
partResult := partResults[i]
|
||||||
|
|
||||||
|
// Extract the different types of content from each part
|
||||||
|
partTextResult := partResult.Get("text")
|
||||||
|
functionCallResult := partResult.Get("functionCall")
|
||||||
|
|
||||||
|
// Handle text content (both regular content and thinking)
|
||||||
|
if partTextResult.Exists() {
|
||||||
|
// Process thinking content (internal reasoning)
|
||||||
|
if partResult.Get("thought").Bool() {
|
||||||
|
// Continue existing thinking block
|
||||||
|
if *responseType == 2 {
|
||||||
|
output = output + "event: content_block_delta\n"
|
||||||
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
|
||||||
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
} else {
|
||||||
|
// Transition from another state to thinking
|
||||||
|
// First, close any existing content block
|
||||||
|
if *responseType != 0 {
|
||||||
|
if *responseType == 2 {
|
||||||
|
output = output + "event: content_block_delta\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
}
|
||||||
|
output = output + "event: content_block_stop\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
*responseIndex++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a new thinking content block
|
||||||
|
output = output + "event: content_block_start\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
output = output + "event: content_block_delta\n"
|
||||||
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
|
||||||
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
*responseType = 2 // Set state to thinking
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Process regular text content (user-visible output)
|
||||||
|
// Continue existing text block
|
||||||
|
if *responseType == 1 {
|
||||||
|
output = output + "event: content_block_delta\n"
|
||||||
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
|
||||||
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
} else {
|
||||||
|
// Transition from another state to text content
|
||||||
|
// First, close any existing content block
|
||||||
|
if *responseType != 0 {
|
||||||
|
if *responseType == 2 {
|
||||||
|
output = output + "event: content_block_delta\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
}
|
||||||
|
output = output + "event: content_block_stop\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
*responseIndex++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a new text content block
|
||||||
|
output = output + "event: content_block_start\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
output = output + "event: content_block_delta\n"
|
||||||
|
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
|
||||||
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
*responseType = 1 // Set state to content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if functionCallResult.Exists() {
|
||||||
|
// Handle function/tool calls from the AI model
|
||||||
|
// This processes tool usage requests and formats them for Claude API compatibility
|
||||||
|
usedTool = true
|
||||||
|
fcName := functionCallResult.Get("name").String()
|
||||||
|
|
||||||
|
// Handle state transitions when switching to function calls
|
||||||
|
// Close any existing function call block first
|
||||||
|
if *responseType == 3 {
|
||||||
|
output = output + "event: content_block_stop\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
*responseIndex++
|
||||||
|
*responseType = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special handling for thinking state transition
|
||||||
|
if *responseType == 2 {
|
||||||
|
output = output + "event: content_block_delta\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close any other existing content block
|
||||||
|
if *responseType != 0 {
|
||||||
|
output = output + "event: content_block_stop\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
*responseIndex++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start a new tool use content block
|
||||||
|
// This creates the structure for a function call in Claude format
|
||||||
|
output = output + "event: content_block_start\n"
|
||||||
|
|
||||||
|
// Create the tool use block with unique ID and function details
|
||||||
|
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex)
|
||||||
|
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
||||||
|
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||||
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
|
||||||
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
|
output = output + "event: content_block_delta\n"
|
||||||
|
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||||
|
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||||
|
}
|
||||||
|
*responseType = 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
|
||||||
|
if usageResult.Exists() && bytes.Contains(rawJson, []byte(`"finishReason"`)) {
|
||||||
|
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||||
|
output = output + "event: content_block_stop\n"
|
||||||
|
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
|
||||||
|
output = output + "event: message_delta\n"
|
||||||
|
output = output + `data: `
|
||||||
|
|
||||||
|
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||||
|
if usedTool {
|
||||||
|
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||||
|
}
|
||||||
|
|
||||||
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
|
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
|
||||||
|
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
|
||||||
|
|
||||||
|
output = output + template + "\n\n\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|||||||
@@ -168,11 +168,12 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token,
|
|||||||
codeChan := make(chan string)
|
codeChan := make(chan string)
|
||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
|
|
||||||
// Create a new HTTP server.
|
// Create a new HTTP server with its own multiplexer.
|
||||||
server := &http.Server{Addr: "localhost:8085"}
|
mux := http.NewServeMux()
|
||||||
|
server := &http.Server{Addr: ":8085", Handler: mux}
|
||||||
config.RedirectURL = "http://localhost:8085/oauth2callback"
|
config.RedirectURL = "http://localhost:8085/oauth2callback"
|
||||||
|
|
||||||
http.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := r.URL.Query().Get("error"); err != "" {
|
if err := r.URL.Query().Get("error"); err != "" {
|
||||||
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
|
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
|
||||||
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
|
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ const (
|
|||||||
apiVersion = "v1internal"
|
apiVersion = "v1internal"
|
||||||
pluginVersion = "0.1.9"
|
pluginVersion = "0.1.9"
|
||||||
|
|
||||||
glEndPoint = "https://generativelanguage.googleapis.com/"
|
glEndPoint = "https://generativelanguage.googleapis.com"
|
||||||
glApiVersion = "v1beta"
|
glApiVersion = "v1beta"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -241,7 +241,7 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
// APIRequest handles making requests to the CLI API endpoints.
|
// APIRequest handles making requests to the CLI API endpoints.
|
||||||
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
|
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, alt string, stream bool) (io.ReadCloser, *ErrorMessage) {
|
||||||
var jsonBody []byte
|
var jsonBody []byte
|
||||||
var err error
|
var err error
|
||||||
if byteBody, ok := body.([]byte); ok {
|
if byteBody, ok := body.([]byte); ok {
|
||||||
@@ -257,14 +257,26 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
|
|||||||
if c.glAPIKey == "" {
|
if c.glAPIKey == "" {
|
||||||
// Add alt=sse for streaming
|
// Add alt=sse for streaming
|
||||||
url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
|
url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
|
||||||
if stream {
|
if alt == "" && stream {
|
||||||
url = url + "?alt=sse"
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
if alt != "" {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", alt)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if endpoint == "countTokens" {
|
||||||
|
modelResult := gjson.GetBytes(jsonBody, "model")
|
||||||
|
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
|
||||||
} else {
|
} else {
|
||||||
modelResult := gjson.GetBytes(jsonBody, "model")
|
modelResult := gjson.GetBytes(jsonBody, "model")
|
||||||
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
|
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
|
||||||
if stream {
|
if alt == "" && stream {
|
||||||
url = url + "?alt=sse"
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
if alt != "" {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", alt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
|
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
|
||||||
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
|
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
|
||||||
@@ -274,8 +286,10 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
|
|||||||
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
|
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// log.Debug(string(jsonBody))
|
// log.Debug(string(jsonBody))
|
||||||
|
// log.Debug(url)
|
||||||
reqBody := bytes.NewBuffer(jsonBody)
|
reqBody := bytes.NewBuffer(jsonBody)
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
|
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
|
||||||
@@ -311,6 +325,7 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
// log.Debug(string(jsonBody))
|
||||||
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
|
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -391,7 +406,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
|
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, "", false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.StatusCode == 429 {
|
if err.StatusCode == 429 {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -411,90 +426,209 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendMessageStream handles a single conversational turn, including tool calls.
|
// SendMessageStream handles streaming conversational turns with comprehensive parameter management.
|
||||||
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
|
// This function implements a sophisticated streaming system that supports tool calls, reasoning modes,
|
||||||
|
// quota management, and automatic model fallback. It returns two channels for asynchronous communication:
|
||||||
|
// one for streaming response data and another for error handling.
|
||||||
|
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) {
|
||||||
|
// Define the data prefix used in Server-Sent Events streaming format
|
||||||
dataTag := []byte("data: ")
|
dataTag := []byte("data: ")
|
||||||
|
|
||||||
|
// Create channels for asynchronous communication
|
||||||
|
// errChan: delivers error messages during streaming
|
||||||
|
// dataChan: delivers response data chunks
|
||||||
errChan := make(chan *ErrorMessage)
|
errChan := make(chan *ErrorMessage)
|
||||||
dataChan := make(chan []byte)
|
dataChan := make(chan []byte)
|
||||||
|
|
||||||
|
// Launch a goroutine to handle the streaming process asynchronously
|
||||||
|
// This allows the function to return immediately while processing continues in the background
|
||||||
go func() {
|
go func() {
|
||||||
|
// Ensure channels are properly closed when the goroutine exits
|
||||||
defer close(errChan)
|
defer close(errChan)
|
||||||
defer close(dataChan)
|
defer close(dataChan)
|
||||||
|
|
||||||
|
// Configure thinking/reasoning capabilities
|
||||||
|
// Default to including thoughts unless explicitly disabled
|
||||||
|
includeThoughtsFlag := true
|
||||||
|
if len(includeThoughts) > 0 {
|
||||||
|
includeThoughtsFlag = includeThoughts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the base request structure for the Gemini API
|
||||||
|
// This includes conversation contents and generation configuration
|
||||||
request := GenerateContentRequest{
|
request := GenerateContentRequest{
|
||||||
Contents: contents,
|
Contents: contents,
|
||||||
GenerationConfig: GenerationConfig{
|
GenerationConfig: GenerationConfig{
|
||||||
ThinkingConfig: GenerationConfigThinkingConfig{
|
ThinkingConfig: GenerationConfigThinkingConfig{
|
||||||
IncludeThoughts: true,
|
IncludeThoughts: includeThoughtsFlag,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add system instructions if provided
|
||||||
|
// System instructions guide the AI's behavior and response style
|
||||||
request.SystemInstruction = systemInstruction
|
request.SystemInstruction = systemInstruction
|
||||||
|
|
||||||
|
// Add available tools for function calling capabilities
|
||||||
|
// Tools allow the AI to perform actions beyond text generation
|
||||||
request.Tools = tools
|
request.Tools = tools
|
||||||
|
|
||||||
|
// Construct the complete request body with project context
|
||||||
|
// The project ID is essential for proper API routing and billing
|
||||||
requestBody := map[string]interface{}{
|
requestBody := map[string]interface{}{
|
||||||
"project": c.GetProjectID(), // Assuming ProjectID is available
|
"project": c.GetProjectID(), // Project ID for API routing and quota management
|
||||||
"request": request,
|
"request": request,
|
||||||
"model": model,
|
"model": model,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Serialize the request body to JSON for API transmission
|
||||||
byteRequestBody, _ := json.Marshal(requestBody)
|
byteRequestBody, _ := json.Marshal(requestBody)
|
||||||
|
|
||||||
// log.Debug(string(byteRequestBody))
|
// Parse and configure reasoning effort levels from the original request
|
||||||
|
// This maps Claude-style reasoning effort parameters to Gemini's thinking budget system
|
||||||
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
||||||
if reasoningEffortResult.String() == "none" {
|
if reasoningEffortResult.String() == "none" {
|
||||||
|
// Disable thinking entirely for fastest responses
|
||||||
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||||
} else if reasoningEffortResult.String() == "auto" {
|
} else if reasoningEffortResult.String() == "auto" {
|
||||||
|
// Let the model decide the appropriate thinking budget automatically
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
} else if reasoningEffortResult.String() == "low" {
|
} else if reasoningEffortResult.String() == "low" {
|
||||||
|
// Minimal thinking for simple tasks (1KB thinking budget)
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||||
} else if reasoningEffortResult.String() == "medium" {
|
} else if reasoningEffortResult.String() == "medium" {
|
||||||
|
// Moderate thinking for complex tasks (8KB thinking budget)
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||||
} else if reasoningEffortResult.String() == "high" {
|
} else if reasoningEffortResult.String() == "high" {
|
||||||
|
// Maximum thinking for very complex tasks (24KB thinking budget)
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
||||||
} else {
|
} else {
|
||||||
|
// Default to automatic thinking budget if no specific level is provided
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure temperature parameter for response randomness control
|
||||||
|
// Temperature affects the creativity vs consistency trade-off in responses
|
||||||
temperatureResult := gjson.GetBytes(rawJson, "temperature")
|
temperatureResult := gjson.GetBytes(rawJson, "temperature")
|
||||||
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure top-p parameter for nucleus sampling
|
||||||
|
// Controls the cumulative probability threshold for token selection
|
||||||
topPResult := gjson.GetBytes(rawJson, "top_p")
|
topPResult := gjson.GetBytes(rawJson, "top_p")
|
||||||
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure top-k parameter for limiting token candidates
|
||||||
|
// Restricts the model to consider only the top K most likely tokens
|
||||||
topKResult := gjson.GetBytes(rawJson, "top_k")
|
topKResult := gjson.GetBytes(rawJson, "top_k")
|
||||||
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||||
}
|
}
|
||||||
|
|
||||||
// log.Debug(string(byteRequestBody))
|
// Initialize model name for quota management and potential fallback
|
||||||
modelName := model
|
modelName := model
|
||||||
var stream io.ReadCloser
|
var stream io.ReadCloser
|
||||||
|
|
||||||
|
// Quota management and model fallback loop
|
||||||
|
// This loop handles quota exceeded scenarios and automatic model switching
|
||||||
for {
|
for {
|
||||||
|
// Check if the current model has exceeded its quota
|
||||||
if c.isModelQuotaExceeded(modelName) {
|
if c.isModelQuotaExceeded(modelName) {
|
||||||
|
// Attempt to switch to a preview model if configured and using account auth
|
||||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
modelName = c.getPreviewModel(model)
|
modelName = c.getPreviewModel(model)
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||||
|
// Update the request body with the new model name
|
||||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||||
continue
|
continue // Retry with the preview model
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// If no fallback is available, return a quota exceeded error
|
||||||
errChan <- &ErrorMessage{
|
errChan <- &ErrorMessage{
|
||||||
StatusCode: 429,
|
StatusCode: 429,
|
||||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attempt to establish a streaming connection with the API
|
||||||
var err *ErrorMessage
|
var err *ErrorMessage
|
||||||
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
|
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, "", true)
|
||||||
|
if err != nil {
|
||||||
|
// Handle quota exceeded errors by marking the model and potentially retrying
|
||||||
|
if err.StatusCode == 429 {
|
||||||
|
now := time.Now()
|
||||||
|
c.modelQuotaExceeded[modelName] = &now // Mark model as quota exceeded
|
||||||
|
// If preview model switching is enabled, retry the loop
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Forward other errors to the error channel
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Clear any previous quota exceeded status for this model
|
||||||
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
|
break // Successfully established connection, exit the retry loop
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the streaming response using a scanner
|
||||||
|
// This handles the Server-Sent Events format from the API
|
||||||
|
scanner := bufio.NewScanner(stream)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
// Filter and forward only data lines (those prefixed with "data: ")
|
||||||
|
// This extracts the actual JSON content from the SSE format
|
||||||
|
if bytes.HasPrefix(line, dataTag) {
|
||||||
|
dataChan <- line[6:] // Remove "data: " prefix and send the JSON content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle any scanning errors that occurred during stream processing
|
||||||
|
if errScanner := scanner.Err(); errScanner != nil {
|
||||||
|
// Send a 500 Internal Server Error for scanning failures
|
||||||
|
errChan <- &ErrorMessage{500, errScanner}
|
||||||
|
_ = stream.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the stream is properly closed to prevent resource leaks
|
||||||
|
_ = stream.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Return the channels immediately for asynchronous communication
|
||||||
|
// The caller can read from these channels while the goroutine processes the request
|
||||||
|
return dataChan, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendRawTokenCount handles a token count.
|
||||||
|
func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) {
|
||||||
|
modelResult := gjson.GetBytes(rawJson, "model")
|
||||||
|
model := modelResult.String()
|
||||||
|
modelName := model
|
||||||
|
for {
|
||||||
|
if c.isModelQuotaExceeded(modelName) {
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
modelName = c.getPreviewModel(model)
|
||||||
|
if modelName != "" {
|
||||||
|
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||||
|
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, &ErrorMessage{
|
||||||
|
StatusCode: 429,
|
||||||
|
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := c.APIRequest(ctx, "countTokens", rawJson, alt, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.StatusCode == 429 {
|
if err.StatusCode == 429 {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -503,38 +637,22 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
errChan <- err
|
return nil, err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
delete(c.modelQuotaExceeded, modelName)
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
break
|
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||||
|
if errReadAll != nil {
|
||||||
|
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||||
}
|
}
|
||||||
|
return bodyBytes, nil
|
||||||
scanner := bufio.NewScanner(stream)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Bytes()
|
|
||||||
// log.Printf("Received stream chunk: %s", line)
|
|
||||||
if bytes.HasPrefix(line, dataTag) {
|
|
||||||
dataChan <- line[6:]
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if errScanner := scanner.Err(); errScanner != nil {
|
|
||||||
// log.Println(err)
|
|
||||||
errChan <- &ErrorMessage{500, errScanner}
|
|
||||||
_ = stream.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = stream.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return dataChan, errChan
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendRawMessage handles a single conversational turn, including tool calls.
|
// SendRawMessage handles a single conversational turn, including tool calls.
|
||||||
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *ErrorMessage) {
|
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) {
|
||||||
|
if c.glAPIKey == "" {
|
||||||
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
|
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
|
||||||
|
}
|
||||||
|
|
||||||
modelResult := gjson.GetBytes(rawJson, "model")
|
modelResult := gjson.GetBytes(rawJson, "model")
|
||||||
model := modelResult.String()
|
model := modelResult.String()
|
||||||
@@ -555,7 +673,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *E
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
respBody, err := c.APIRequest(ctx, "generateContent", rawJson, false)
|
respBody, err := c.APIRequest(ctx, "generateContent", rawJson, alt, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.StatusCode == 429 {
|
if err.StatusCode == 429 {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -576,7 +694,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *E
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendRawMessageStream handles a single conversational turn, including tool calls.
|
// SendRawMessageStream handles a single conversational turn, including tool calls.
|
||||||
func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-chan []byte, <-chan *ErrorMessage) {
|
func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
|
||||||
dataTag := []byte("data: ")
|
dataTag := []byte("data: ")
|
||||||
errChan := make(chan *ErrorMessage)
|
errChan := make(chan *ErrorMessage)
|
||||||
dataChan := make(chan []byte)
|
dataChan := make(chan []byte)
|
||||||
@@ -584,7 +702,9 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch
|
|||||||
defer close(errChan)
|
defer close(errChan)
|
||||||
defer close(dataChan)
|
defer close(dataChan)
|
||||||
|
|
||||||
|
if c.glAPIKey == "" {
|
||||||
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
|
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
|
||||||
|
}
|
||||||
|
|
||||||
modelResult := gjson.GetBytes(rawJson, "model")
|
modelResult := gjson.GetBytes(rawJson, "model")
|
||||||
model := modelResult.String()
|
model := modelResult.String()
|
||||||
@@ -607,7 +727,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
var err *ErrorMessage
|
var err *ErrorMessage
|
||||||
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, true)
|
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, alt, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.StatusCode == 429 {
|
if err.StatusCode == 429 {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -623,6 +743,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if alt == "" {
|
||||||
scanner := bufio.NewScanner(stream)
|
scanner := bufio.NewScanner(stream)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
@@ -637,7 +758,17 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
data, err := io.ReadAll(stream)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- &ErrorMessage{500, err}
|
||||||
_ = stream.Close()
|
_ = stream.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
_ = stream.Close()
|
||||||
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return dataChan, errChan
|
return dataChan, errChan
|
||||||
@@ -689,7 +820,7 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
|
|||||||
// A simple request to test the API endpoint.
|
// A simple request to test the API endpoint.
|
||||||
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
|
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
|
||||||
|
|
||||||
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), true)
|
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), "", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
|
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
|
||||||
if err.StatusCode == 403 {
|
if err.StatusCode == 403 {
|
||||||
@@ -706,6 +837,7 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJson)
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
return false, err.Error
|
return false, err.Error
|
||||||
|
|||||||
@@ -73,6 +73,7 @@ func DoLogin(cfg *config.Config, projectID string) {
|
|||||||
// If the check fails (returns false), the CheckCloudAPIIsEnabled function
|
// If the check fails (returns false), the CheckCloudAPIIsEnabled function
|
||||||
// will have already printed instructions, so we can just exit.
|
// will have already printed instructions, so we can just exit.
|
||||||
if !isChecked {
|
if !isChecked {
|
||||||
|
log.Fatal("Failed to check if Cloud AI API is enabled. If you encounter an error message, please create an issue.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,12 +7,10 @@ import (
|
|||||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -69,33 +67,12 @@ func StartService(cfg *config.Config) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(cfg.GlAPIKey) > 0 {
|
if len(cfg.GlAPIKey) > 0 {
|
||||||
var transport *http.Transport
|
for i := 0; i < len(cfg.GlAPIKey); i++ {
|
||||||
proxyURL, errParse := url.Parse(cfg.ProxyUrl)
|
httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{})
|
||||||
if errParse == nil {
|
if errSetProxy != nil {
|
||||||
if proxyURL.Scheme == "socks5" {
|
log.Fatalf("set proxy failed: %v", errSetProxy)
|
||||||
username := proxyURL.User.Username()
|
|
||||||
password, _ := proxyURL.User.Password()
|
|
||||||
proxyAuth := &proxy.Auth{User: username, Password: password}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
|
||||||
if errSOCKS5 != nil {
|
|
||||||
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
|
||||||
}
|
|
||||||
transport = &http.Transport{
|
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
return dialer.Dial(network, addr)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
|
||||||
// Handle HTTP/HTTPS proxy.
|
|
||||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < len(cfg.GlAPIKey); i++ {
|
|
||||||
httpClient := &http.Client{}
|
|
||||||
if transport != nil {
|
|
||||||
httpClient.Transport = transport
|
|
||||||
}
|
|
||||||
log.Debug("Initializing with Generative Language API key...")
|
log.Debug("Initializing with Generative Language API key...")
|
||||||
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
|
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
|
||||||
cliClients = append(cliClients, cliClient)
|
cliClients = append(cliClients, cliClient)
|
||||||
|
|||||||
37
internal/util/proxy.go
Normal file
37
internal/util/proxy.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
	"context"
	"fmt"
	"net"
	"net/http"
	"net/url"

	"github.com/luispater/CLIProxyAPI/internal/config"
	"golang.org/x/net/proxy"
)
|
||||||
|
|
||||||
|
func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) {
|
||||||
|
var transport *http.Transport
|
||||||
|
proxyURL, errParse := url.Parse(cfg.ProxyUrl)
|
||||||
|
if errParse == nil {
|
||||||
|
if proxyURL.Scheme == "socks5" {
|
||||||
|
username := proxyURL.User.Username()
|
||||||
|
password, _ := proxyURL.User.Password()
|
||||||
|
proxyAuth := &proxy.Auth{User: username, Password: password}
|
||||||
|
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
||||||
|
if errSOCKS5 != nil {
|
||||||
|
return nil, errSOCKS5
|
||||||
|
}
|
||||||
|
transport = &http.Transport{
|
||||||
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
return dialer.Dial(network, addr)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||||
|
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if transport != nil {
|
||||||
|
httpClient.Transport = transport
|
||||||
|
}
|
||||||
|
return httpClient, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user