Compare commits
49 Commits
| SHA1 |
|---|
| f93dde43df |
| 84a2874c7b |
| c10f8ae2e2 |
| 17363edf25 |
| 486cd4c343 |
| 25feceb783 |
| d26752250d |
| b15453c369 |
| 04ba8c8bc3 |
| 6570692291 |
| 13aa5b3375 |
| 6d8de0ade4 |
| 1587ff5e74 |
| f033d3a6df |
| 145e0e0b5d |
| 9b7d7021af |
| e41c22ef44 |
| 55271403fb |
| 36fba66619 |
| b9b127a7ea |
| 2741e7b7b3 |
| 1767a56d4f |
| 779e6c2d2f |
| 73c831747b |
| 10b824fcac |
| e5d3541b5a |
| 79755e76ea |
| 35f158d526 |
| 6962e09dd9 |
| 4c4cbd44da |
| 26eca8b6ba |
| 62b17f40a1 |
| 511b8a992e |
| 7dccc7ba2f |
| 70c90687fd |
| 8144ffd5c8 |
| 0ab977c236 |
| 224f0de353 |
| 6b45d311ec |
| d54de441d3 |
| 1821bf7051 |
| 754f3bcbc3 |
| 36973d4a6f |
| c89d19b300 |
| cc32f5ff61 |
| fbff68b9e0 |
| 7e1a543b79 |
| 74b862d8b8 |
| 5c817a9b42 |
API_USAGE.md (new file, 343 lines)
@@ -0,0 +1,343 @@
# CLIProxyAPI Usage Guide

## Connection Details

| Item | Value |
|------|-----|
| External URL | `https://cliproxy.gru.farm` |
| Internal URL | `http://192.168.0.17:8317` |
| API key | `Jinie4eva!` |
| Auth scheme | `Authorization: Bearer <API key>` |

## Endpoints

| Purpose | Path |
|------|------|
| Claude native (recommended) | `/api/provider/claude/v1/messages` |
| OpenAI-compatible | `/v1/chat/completions` |
| Model list | `/v1/models` |

## Available Models

| Model ID | Description |
|---------|------|
| `claude-sonnet-4-6` | Claude Sonnet 4.6 (latest, recommended) |
| `claude-opus-4-6` | Claude Opus 4.6 (highest capability) |
| `claude-sonnet-4-5-20250929` | Claude Sonnet 4.5 |
| `claude-opus-4-5-20251101` | Claude Opus 4.5 |
| `claude-haiku-4-5-20251001` | Claude Haiku 4.5 (lightweight/fast) |
| `claude-sonnet-4-20250514` | Claude Sonnet 4 |
| `claude-opus-4-20250514` | Claude Opus 4 |
| `claude-3-7-sonnet-20250219` | Claude 3.7 Sonnet |
| `claude-3-5-haiku-20241022` | Claude 3.5 Haiku |

---

## 1. curl

### Basic call

```bash
curl -X POST https://cliproxy.gru.farm/api/provider/claude/v1/messages \
  -H "Authorization: Bearer Jinie4eva!" \
  -H "anthropic-version: 2023-06-01" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "claude-sonnet-4-6",
    "max_tokens": 1024,
    "messages": [
      {"role": "user", "content": "Hello! Introduce yourself briefly."}
    ]
  }'
```

### Streaming

```bash
curl -X POST https://cliproxy.gru.farm/api/provider/claude/v1/messages \
  -H "Authorization: Bearer Jinie4eva!" \
  -H "anthropic-version: 2023-06-01" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "claude-sonnet-4-6",
    "max_tokens": 1024,
    "stream": true,
    "messages": [
      {"role": "user", "content": "Hello!"}
    ]
  }'
```

### List models

```bash
curl https://cliproxy.gru.farm/v1/models \
  -H "Authorization: Bearer Jinie4eva!"
```

---

## 2. Python — Anthropic SDK

### Install

```bash
pip install anthropic
```

### Basic call

```python
from anthropic import Anthropic

client = Anthropic(
    base_url="https://cliproxy.gru.farm/api/provider/claude",
    api_key="Jinie4eva!"
)

response = client.messages.create(
    model="claude-sonnet-4-6",
    max_tokens=1024,
    messages=[
        {"role": "user", "content": "Hello! Introduce yourself briefly."}
    ]
)

print(response.content[0].text)
```

### Streaming

```python
from anthropic import Anthropic

client = Anthropic(
    base_url="https://cliproxy.gru.farm/api/provider/claude",
    api_key="Jinie4eva!"
)

with client.messages.stream(
    model="claude-sonnet-4-6",
    max_tokens=1024,
    messages=[
        {"role": "user", "content": "Hello! Introduce yourself briefly."}
    ]
) as stream:
    for text in stream.text_stream:
        print(text, end="", flush=True)
```

### System prompt + multi-turn

```python
from anthropic import Anthropic

client = Anthropic(
    base_url="https://cliproxy.gru.farm/api/provider/claude",
    api_key="Jinie4eva!"
)

response = client.messages.create(
    model="claude-sonnet-4-6",
    max_tokens=1024,
    system="You are a friendly Korean-speaking AI assistant.",
    messages=[
        {"role": "user", "content": "What is Python?"},
        {"role": "assistant", "content": "Python is a programming language."},
        {"role": "user", "content": "Then what about JavaScript?"}
    ]
)

print(response.content[0].text)
```

---

## 3. Python — OpenAI SDK (compatibility mode)

### Install

```bash
pip install openai
```

### Basic call

```python
from openai import OpenAI

client = OpenAI(
    base_url="https://cliproxy.gru.farm/v1",
    api_key="Jinie4eva!"
)

response = client.chat.completions.create(
    model="claude-sonnet-4-6",
    messages=[
        {"role": "user", "content": "Hello!"}
    ]
)

print(response.choices[0].message.content)
```

### Streaming

```python
from openai import OpenAI

client = OpenAI(
    base_url="https://cliproxy.gru.farm/v1",
    api_key="Jinie4eva!"
)

stream = client.chat.completions.create(
    model="claude-sonnet-4-6",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True
)

for chunk in stream:
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```

---

## 4. Node.js — Anthropic SDK

### Install

```bash
npm install @anthropic-ai/sdk
```

### Basic call

```javascript
import Anthropic from "@anthropic-ai/sdk";

const client = new Anthropic({
  baseURL: "https://cliproxy.gru.farm/api/provider/claude",
  apiKey: "Jinie4eva!",
});

const response = await client.messages.create({
  model: "claude-sonnet-4-6",
  max_tokens: 1024,
  messages: [{ role: "user", content: "Hello!" }],
});

console.log(response.content[0].text);
```

### Streaming

```javascript
import Anthropic from "@anthropic-ai/sdk";

const client = new Anthropic({
  baseURL: "https://cliproxy.gru.farm/api/provider/claude",
  apiKey: "Jinie4eva!",
});

const stream = client.messages.stream({
  model: "claude-sonnet-4-6",
  max_tokens: 1024,
  messages: [{ role: "user", content: "Hello!" }],
});

for await (const chunk of stream) {
  if (
    chunk.type === "content_block_delta" &&
    chunk.delta.type === "text_delta"
  ) {
    process.stdout.write(chunk.delta.text);
  }
}
```

---

## 5. Node.js — OpenAI SDK (compatibility mode)

### Install

```bash
npm install openai
```

### Basic call

```javascript
import OpenAI from "openai";

const client = new OpenAI({
  baseURL: "https://cliproxy.gru.farm/v1",
  apiKey: "Jinie4eva!",
});

const response = await client.chat.completions.create({
  model: "claude-sonnet-4-6",
  messages: [{ role: "user", content: "Hello!" }],
});

console.log(response.choices[0].message.content);
```

---

## 6. Claude Code CLI

```bash
export ANTHROPIC_BASE_URL=https://cliproxy.gru.farm/api/provider/claude
export ANTHROPIC_API_KEY=Jinie4eva!

claude
```

Persistent setup (`~/.zshrc` or `~/.bashrc`):

```bash
echo 'export ANTHROPIC_BASE_URL=https://cliproxy.gru.farm/api/provider/claude' >> ~/.zshrc
echo 'export ANTHROPIC_API_KEY=Jinie4eva!' >> ~/.zshrc
source ~/.zshrc
```

---

## 7. Managing via environment variables

`.env` file:

```env
ANTHROPIC_BASE_URL=https://cliproxy.gru.farm/api/provider/claude
ANTHROPIC_API_KEY=Jinie4eva!
```

Using `.env` from Python:

```python
from dotenv import load_dotenv  # pip install python-dotenv
from anthropic import Anthropic

load_dotenv()

# base_url and api_key are read automatically from the environment
client = Anthropic()

response = client.messages.create(
    model="claude-sonnet-4-6",
    max_tokens=1024,
    messages=[{"role": "user", "content": "Hello!"}]
)
print(response.content[0].text)
```

---

## Notes

- **When accessing from the internal network**, change the URL to `http://192.168.0.17:8317`
- **OpenAI compatibility mode** uses `/v1/chat/completions`, but for Claude-native features (extended thinking, etc.) `/api/provider/claude/v1/messages` is recommended
- **Timeouts**: for long responses, set the client timeout to 600 seconds or more
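
The first note (switching between the internal and external URL) can be wrapped in a small helper. This is an illustrative sketch only; the `CLIPROXY_INTERNAL` environment variable is our own convention, not something the proxy defines:

```python
import os

# Deployment URLs from the connection table at the top of this guide
EXTERNAL_URL = "https://cliproxy.gru.farm/api/provider/claude"
INTERNAL_URL = "http://192.168.0.17:8317/api/provider/claude"

def claude_base_url() -> str:
    """Return the internal URL when CLIPROXY_INTERNAL is set, else the external one."""
    return INTERNAL_URL if os.environ.get("CLIPROXY_INTERNAL") else EXTERNAL_URL
```

The result can be passed as `base_url=` when constructing the `Anthropic` client shown in section 2.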

DOCKER_DEPLOY.md (new file, 212 lines)
@@ -0,0 +1,212 @@
# CLIProxyAPI Docker Deployment Guide

This guide covers deploying CLIProxyAPI with Docker on the NAS (nas.gru.farm).

## Prerequisites

| Item | Details |
|------|------|
| NAS access | `ssh airkjw@nas.gru.farm -p 22` |
| Docker | `sudo /usr/local/bin/docker` (NOPASSWD) |
| Docker Compose | `sudo /usr/local/bin/docker compose` |
| NAS internal IP | 192.168.0.17 |

## 1. Prepare the deployment directory

```bash
ssh airkjw@nas.gru.farm

# Create the deployment directory
mkdir -p ~/docker/cli-proxy-api
cd ~/docker/cli-proxy-api
```

## 2. Required files

The NAS needs the following files:

```
~/docker/cli-proxy-api/
├── docker-compose.yml   # Container definition
├── config.yaml          # Service settings (API keys, port, etc.)
├── auths/               # OAuth credential data (created automatically)
└── logs/                # Log directory (created automatically)
```

## 3. docker-compose.yml

Local build (build directly from source):

```yaml
services:
  cli-proxy-api:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: cli-proxy-api
    ports:
      - "8317:8317"   # Main API port
      # Open additional ports if needed
      # - "8085:8085"
    volumes:
      - ./config.yaml:/CLIProxyAPI/config.yaml
      - ./auths:/root/.cli-proxy-api
      - ./logs:/CLIProxyAPI/logs
    environment:
      - TZ=Asia/Seoul
    restart: unless-stopped
```

Or use the official image:

```yaml
services:
  cli-proxy-api:
    image: eceasy/cli-proxy-api:latest
    container_name: cli-proxy-api
    ports:
      - "8317:8317"
    volumes:
      - ./config.yaml:/CLIProxyAPI/config.yaml
      - ./auths:/root/.cli-proxy-api
      - ./logs:/CLIProxyAPI/logs
    environment:
      - TZ=Asia/Seoul
    restart: unless-stopped
```

## 4. config.yaml

Base it on `config.example.yaml`.

### Minimal example

```yaml
# Server binding
host: ""
port: 8317

# API keys (for client authentication; choose your own values)
api-keys:
  - "my-secret-api-key-1"

# Debug (true recommended during initial setup, false once stable)
debug: false

# Write logs to file
logging-to-file: true
logs-max-total-size-mb: 100

# Retry settings
request-retry: 3
```

### Add this when using a Claude API key

```yaml
claude-api-key:
  - api-key: "sk-ant-xxxxx"
    # base-url: "https://api.anthropic.com"  # Default, can be omitted
```

### Add this when using a Gemini API key

```yaml
gemini-api-key:
  - api-key: "AIzaSy..."
```

### Enable the Management UI (web admin panel)

```yaml
remote-management:
  allow-remote: true
  secret-key: "my-management-password"
  disable-control-panel: false
```

## 5. Deploy

```bash
cd ~/docker/cli-proxy-api

# With the official image
sudo /usr/local/bin/docker compose up -d

# Building from source (clone from Gitea first)
git clone http://nas.gru.farm:3001/airkjw/CLIProxyAPI.git src
sudo /usr/local/bin/docker compose -f src/docker-compose.yml up -d --build
```

## 6. Verify

```bash
# Container status
sudo /usr/local/bin/docker ps | grep cli-proxy-api

# Logs
sudo /usr/local/bin/docker logs cli-proxy-api

# API response test
curl http://localhost:8317/
curl http://192.168.0.17:8317/

# Model list (API key auth)
curl -H "Authorization: Bearer my-secret-api-key-1" http://localhost:8317/v1/models
```

## 7. Client connection

Once CLIProxyAPI is running, point each AI CLI tool at the proxy address.

### Claude Code

```bash
# Set environment variables
export ANTHROPIC_BASE_URL=http://192.168.0.17:8317
export ANTHROPIC_API_KEY=my-secret-api-key-1
```

### OpenAI-compatible clients

```bash
export OPENAI_BASE_URL=http://192.168.0.17:8317/v1
export OPENAI_API_KEY=my-secret-api-key-1
```

## 8. Management & operations

```bash
# Stop the container
sudo /usr/local/bin/docker compose down

# Restart after a config change
sudo /usr/local/bin/docker compose restart

# Update the image (official image only)
sudo /usr/local/bin/docker compose pull
sudo /usr/local/bin/docker compose up -d

# Tail logs in real time
sudo /usr/local/bin/docker logs -f cli-proxy-api
```

## Port list

| Port | Purpose | Required |
|------|------|-----------|
| 8317 | Main API | Required |
| 8085 | Additional API | Optional |
| 1455 | Additional service | Optional |
| 54545 | Additional service | Optional |
| 51121 | Additional service | Optional |
| 11451 | Additional service | Optional |

> Opening only port 8317 is enough by default; the rest are needed only for specific features.

## Notes

- `config.yaml` is listed in `.gitignore`, so it is never committed to Git (protects API keys)
- OAuth authentication (Claude, Gemini, etc.) requires a one-time browser login
- Mounting the `auths/` directory as a volume preserves credentials across container re-creation
- Accessing the NAS from outside requires firewall/port-forwarding setup

README.md (20 lines changed)
@@ -30,6 +30,14 @@ Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via <a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
<td>Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through <a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus - Premium AI Accounts & Top-ups</a>, users can unlock the mind-blowing rate of <b>10% of the official GPT subscription price (90% OFF)</b>!</td>
</tr>
<tr>
<td width="180"><a href="https://www.lingtrue.com/register"><img src="./assets/lingtrue.png" alt="LingtrueAPI" width="150"></a></td>
<td>Thanks to LingtrueAPI for its sponsorship of this project! LingtrueAPI is a global large-model API relay platform that provides API access to top models such as Claude Code, Codex, and Gemini, committed to letting users connect to global AI capabilities at low cost and with high stability. LingtrueAPI offers special discounts to users of this software: register using <a href="https://www.lingtrue.com/register">this link</a> and enter the promo code "LingtrueAPI" on your first top-up to enjoy a 10% discount.</td>
</tr>
</tbody>
</table>

@@ -74,6 +82,14 @@ CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and A
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
- Security-first design with localhost-only management endpoints

When you need the request/response shape of a specific backend family, use the provider-specific paths instead of the merged `/v1/...` endpoints:

- Use `/api/provider/{provider}/v1/messages` for messages-style backends.
- Use `/api/provider/{provider}/v1beta/models/...` for model-scoped generate endpoints.
- Use `/api/provider/{provider}/v1/chat/completions` for chat-completions backends.

These routes help you select the protocol surface, but they do not by themselves guarantee a unique inference executor when the same client-visible model name is reused across multiple backends. Inference routing is still resolved from the request model/alias. For strict backend pinning, use unique aliases, prefixes, or otherwise avoid overlapping client-visible model names.
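
As a quick illustration, the provider-specific URLs above can be assembled mechanically. The snippet and its helper name are illustrative, not part of CLIProxyAPI, and the host is a placeholder:

```python
BASE = "http://localhost:8317"  # placeholder proxy address

def provider_url(provider: str, surface: str) -> str:
    """Build a provider-specific endpoint URL for a given protocol surface."""
    surfaces = {
        "messages": "v1/messages",        # messages-style backends
        "chat": "v1/chat/completions",    # chat-completions backends
    }
    return f"{BASE}/api/provider/{provider}/{surfaces[surface]}"
```

The merged surface, by contrast, stays at `{BASE}/v1/...` regardless of provider.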

**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/agent-client/amp-cli.html)**

## SDK Docs

@@ -110,10 +126,6 @@ Browser-based tool to translate SRT subtitles using your Gemini subscription via

CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed

### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)

Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed.

### [Quotio](https://github.com/nguyenphutrong/quotio)

Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
README_CN.md (20 lines changed)
@@ -30,6 +30,14 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels from 38% / 2% / 9% of list price, with extra discounts on top-ups! AICodeMirror offers a special benefit for CLIProxyAPI users: register via <a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">this link</a> for 20% off your first top-up; enterprise customers can get up to 25% off!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
<td>Thanks to BmoPlus for sponsoring this project! BmoPlus is a reliable AI account top-up provider built for heavy AI subscription users, offering stable official top-ups and ready-made accounts for ChatGPT Plus / ChatGPT Pro (fully warrantied) / Claude Pro / Super Grok / Gemini Pro. Users who register and order through <a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AI accounts & top-ups</a> get the stunning rate of <b>10% of the official GPT subscription price</b>!</td>
</tr>
<tr>
<td width="180"><a href="https://www.lingtrue.com/register"><img src="./assets/lingtrue.png" alt="LingtrueAPI" width="150"></a></td>
<td>Thanks to LingtrueAPI for sponsoring this project! LingtrueAPI is a global large-model API relay platform that provides API access to top models such as Claude Code, Codex, and Gemini, committed to letting users connect to global AI capabilities at low cost and with high stability. LingtrueAPI offers a special deal to users of this software: register using <a href="https://www.lingtrue.com/register">this link</a> and enter the promo code "LingtrueAPI" on your first top-up for a 10% discount.</td>
</tr>
</tbody>
</table>

@@ -73,6 +81,14 @@ CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支
- Smart model fallback with automatic routing
- Security-first design, with management endpoints restricted to localhost

When you need the request/response protocol shape of a specific backend family, prefer the provider-specific paths over the merged `/v1/...` endpoints:

- For messages-style backends, use `/api/provider/{provider}/v1/messages`.
- For backends exposing per-model generate endpoints, use `/api/provider/{provider}/v1beta/models/...`.
- For chat-completions-style backends, use `/api/provider/{provider}/v1/chat/completions`.

These paths help you select the protocol surface, but when multiple backends reuse the same client-visible model name they do not by themselves guarantee a unique inference executor. Actual inference routing is still resolved from the model/alias in the request. To strictly pin a backend, use a unique alias or prefix, or avoid exposing the same client model name from multiple backends.

**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/cn/agent-client/amp-cli.html)**

## SDK Documentation

@@ -109,10 +125,6 @@ CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支

CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth, no API keys needed.

### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)

Native macOS GUI for CLIProxyAPI: configure providers, model mappings, and OAuth endpoints, no API keys needed.

### [Quotio](https://github.com/nguyenphutrong/quotio)

Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions, with real-time quota tracking and smart auto-failover, supporting AI coding tools such as Claude Code, OpenCode, and Droid, no API keys needed.
README_JA.md (20 lines changed)
@@ -30,6 +30,14 @@ GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels are available at 38% / 2% / 9% of the original price, with further discounts on top-ups! Special benefit for CLIProxyAPI users: register via <a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">this link</a> for 20% off your first top-up; enterprise customers can receive up to 25% off!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
<td>Thanks to BmoPlus for supporting this project! BmoPlus is a reliable AI account service provider specializing in heavy AI subscription users, offering stable official top-ups and ready-to-use accounts for ChatGPT Plus / ChatGPT Pro (fully warrantied) / Claude Pro / Super Grok / Gemini Pro. Users who register and order via <a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AI accounts & top-ups</a> can use GPT at the astonishing price of <b>about 10% of the official site price (90% OFF)</b>!</td>
</tr>
<tr>
<td width="180"><a href="https://www.lingtrue.com/register"><img src="./assets/lingtrue.png" alt="LingtrueAPI" width="150"></a></td>
<td>Thanks to LingtrueAPI for sponsoring this project! LingtrueAPI is a global large-model API relay platform offering API access to top models such as Claude Code, Codex, and Gemini, helping users connect to AI capabilities worldwide at low cost and with high stability. LingtrueAPI offers a special discount for users of this software: register via <a href="https://www.lingtrue.com/register">this link</a> and enter the promo code "LingtrueAPI" on your first top-up for a 10% discount.</td>
</tr>
</tbody>
</table>

@@ -74,6 +82,14 @@ CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
- Security-first design with localhost-only management endpoints

When you need the request/response shape of a specific backend family, prefer the provider-specific paths over the merged `/v1/...` endpoints.

- `/api/provider/{provider}/v1/messages` for messages-style backends
- `/api/provider/{provider}/v1beta/models/...` for per-model generate endpoints
- `/api/provider/{provider}/v1/chat/completions` for chat-completions backends

These paths help with protocol-surface selection, but when the same client-facing model name is reused across multiple backends, they alone do not pin a unique inference executor. Actual inference routing still follows model/alias resolution within the request. To strictly pin a backend, use a unique alias or prefix, or avoid duplicated client-facing model names altogether.

**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/agent-client/amp-cli.html)**

## SDK Documentation

@@ -110,10 +126,6 @@ CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を

CLI wrapper for instantly switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed

### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)

Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed

### [Quotio](https://github.com/nguyenphutrong/quotio)

Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions, with real-time quota tracking and smart auto-failover, for coding tools such as Claude Code, OpenCode, and Droid - no API keys needed
REVERSE_PROXY_SETUP.md (new file, 104 lines)
@@ -0,0 +1,104 @@
# CLIProxyAPI Reverse Proxy & HTTPS Setup Guide

This configures external access to CLIProxyAPI at `https://cliproxy.gru.farm`.

## Step 1: Add a DNS record

In the hostcocoa.com DNS management panel, add an A record.

| Type | Host | Value |
|------|--------|-----|
| A | cliproxy | 125.188.185.74 |

> This is the same IP as the existing `nas.gru.farm`, `haesol.gru.farm`, etc.

## Step 2: Configure the Synology DSM reverse proxy

1. Open the DSM web UI (usually `https://nas.gru.farm:5001`)
2. **Control Panel** → **Login Portal** → **Advanced** tab → **Reverse Proxy**
3. Click **Create**
4. Enter the following:

### General settings

| Item | Value |
|------|-----|
| Description | `CLIProxyAPI` |
| **Source (frontend)** | |
| Protocol | `HTTPS` |
| Hostname | `cliproxy.gru.farm` |
| Port | `443` |
| HSTS | Disabled |
| **Destination (backend)** | |
| Protocol | `HTTP` |
| Hostname | `localhost` |
| Port | `8317` |

### Custom headers (optional)

Add custom headers for WebSocket support if needed:
- `Upgrade` → `$http_upgrade`
- `Connection` → `$connection_upgrade`

### Timeout settings

AI requests can take a long time to respond, so increase the timeouts:
- Connection timeout: `600`
- Send timeout: `600`
- Receive timeout: `600`

5. Click **Save**

## Step 3: SSL certificate

Set up an SSL certificate for `cliproxy.gru.farm` in Synology DSM.

### Issue a Let's Encrypt certificate (recommended)

1. **Control Panel** → **Security** → **Certificate** tab
2. **Add** → **Add a new certificate** → **Get a certificate from Let's Encrypt**
3. Domain: `cliproxy.gru.farm`
4. Email: your email address
5. After issuance, click **Configure**
6. Assign the newly issued certificate to the `cliproxy.gru.farm` reverse proxy entry

### If you already have a wildcard certificate

If a `*.gru.farm` wildcard certificate exists, simply select it; no separate issuance is needed.

## Step 4: Router port forwarding

Confirm that port 443 is forwarded from the router to the NAS (192.168.0.17).

> If `haesol.gru.farm` and other hosts already serve HTTPS, this is most likely set up already.

| External port | Internal IP | Internal port | Protocol |
|-----------|---------|-----------|----------|
| 443 | 192.168.0.17 | 443 | TCP |

## Step 5: Verify

```bash
# Check DNS propagation
dig +short cliproxy.gru.farm
# Success if it prints 125.188.185.74

# HTTPS connectivity test
curl https://cliproxy.gru.farm/
# {"endpoints":[...],"message":"CLI Proxy API Server"}

# Model list
curl -H "Authorization: Bearer Jinie4eva!" https://cliproxy.gru.farm/v1/models
```

## Client connection (external)

```bash
# Claude Code
export ANTHROPIC_BASE_URL=https://cliproxy.gru.farm
export ANTHROPIC_API_KEY=Jinie4eva!

# OpenAI-compatible
export OPENAI_BASE_URL=https://cliproxy.gru.farm/v1
export OPENAI_API_KEY=Jinie4eva!
```
assets/bmoplus.png (new binary file, 28 KiB, not shown)
assets/lingtrue.png (new binary file, 129 KiB, not shown)
@@ -281,6 +281,10 @@ nonstream-keepalive-interval: 0

```yaml
# These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
# you select the protocol surface, but inference backend selection can still follow the resolved
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
# You can repeat the same name with different aliases to expose multiple client model names.
# oauth-model-alias:
#   gemini-cli:
```
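
To make the trailing commented example concrete, one plausible shape for an alias entry is sketched below; the field names and model IDs here are assumptions for illustration and should be checked against `config.example.yaml`:

```yaml
# oauth-model-alias:
#   gemini-cli:
#     - name: "gemini-2.5-pro"    # backend model ID (field name assumed)
#       alias: "my-gemini-pro"    # client-visible name (field name assumed)
#     - name: "gemini-2.5-pro"    # the same name repeated with a second alias
#       alias: "gemini-pro-alt"
```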
```diff
@@ -541,10 +541,23 @@ func isRuntimeOnlyAuth(auth *coreauth.Auth) bool {
 	return strings.EqualFold(strings.TrimSpace(auth.Attributes["runtime_only"]), "true")
 }
 
+func isUnsafeAuthFileName(name string) bool {
+	if strings.TrimSpace(name) == "" {
+		return true
+	}
+	if strings.ContainsAny(name, "/\\") {
+		return true
+	}
+	if filepath.VolumeName(name) != "" {
+		return true
+	}
+	return false
+}
+
 // Download single auth file by name
 func (h *Handler) DownloadAuthFile(c *gin.Context) {
-	name := c.Query("name")
-	if name == "" || strings.Contains(name, string(os.PathSeparator)) {
+	name := strings.TrimSpace(c.Query("name"))
+	if isUnsafeAuthFileName(name) {
 		c.JSON(400, gin.H{"error": "invalid name"})
 		return
 	}
@@ -626,8 +639,8 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": "no files uploaded"})
 		return
 	}
-	name := c.Query("name")
-	if name == "" || strings.Contains(name, string(os.PathSeparator)) {
+	name := strings.TrimSpace(c.Query("name"))
+	if isUnsafeAuthFileName(name) {
 		c.JSON(400, gin.H{"error": "invalid name"})
 		return
 	}
@@ -860,7 +873,7 @@ func uniqueAuthFileNames(names []string) []string {
 
 func (h *Handler) deleteAuthFileByName(ctx context.Context, name string) (string, int, error) {
 	name = strings.TrimSpace(name)
-	if name == "" || strings.Contains(name, string(os.PathSeparator)) {
+	if isUnsafeAuthFileName(name) {
 		return "", http.StatusBadRequest, fmt.Errorf("invalid name")
 	}
```
internal/api/handlers/management/auth_files_download_test.go (new file, 62 lines)
@@ -0,0 +1,62 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestDownloadAuthFile_ReturnsFile(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
authDir := t.TempDir()
|
||||
fileName := "download-user.json"
|
||||
expected := []byte(`{"type":"codex"}`)
|
||||
if err := os.WriteFile(filepath.Join(authDir, fileName), expected, 0o600); err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(fileName), nil)
|
||||
h.DownloadAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected download status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
if got := rec.Body.Bytes(); string(got) != string(expected) {
|
||||
t.Fatalf("unexpected download content: %q", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAuthFile_RejectsPathSeparators(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, nil)
|
||||
|
||||
for _, name := range []string{
|
||||
"../external/secret.json",
|
||||
`..\\external\\secret.json`,
|
||||
"nested/secret.json",
|
||||
`nested\\secret.json`,
|
||||
} {
|
||||
rec := httptest.NewRecorder()
|
||||
ctx, _ := gin.CreateTestContext(rec)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(name), nil)
|
||||
h.DownloadAuthFile(ctx)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected %d for name %q, got %d with body %s", http.StatusBadRequest, name, rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
}
@@ -0,0 +1,51 @@
//go:build windows

package management

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)

func TestDownloadAuthFile_PreventsWindowsSlashTraversal(t *testing.T) {
	t.Setenv("MANAGEMENT_PASSWORD", "")
	gin.SetMode(gin.TestMode)

	tempDir := t.TempDir()
	authDir := filepath.Join(tempDir, "auth")
	externalDir := filepath.Join(tempDir, "external")
	if err := os.MkdirAll(authDir, 0o700); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	if err := os.MkdirAll(externalDir, 0o700); err != nil {
		t.Fatalf("failed to create external dir: %v", err)
	}

	secretName := "secret.json"
	secretPath := filepath.Join(externalDir, secretName)
	if err := os.WriteFile(secretPath, []byte(`{"secret":true}`), 0o600); err != nil {
		t.Fatalf("failed to write external file: %v", err)
	}

	h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(
		http.MethodGet,
		"/v0/management/auth-files/download?name="+url.QueryEscape("../external/"+secretName),
		nil,
	)
	h.DownloadAuthFile(ctx)

	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected status %d, got %d with body %s", http.StatusBadRequest, rec.Code, rec.Body.String())
	}
}
@@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
		return
	}

+	// Sanitize request body: remove thinking blocks with invalid signatures
+	// to prevent upstream API 400 errors
+	bodyBytes = SanitizeAmpRequestBody(bodyBytes)
+
	// Restore the body for the handler to read
	c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))

@@ -259,10 +263,16 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
	} else if len(providers) > 0 {
		// Log: Using local provider (free)
		logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
+		// Wrap with ResponseRewriter for local providers too, because upstream
+		// proxies (e.g. NewAPI) may return a different model name and lack
+		// Amp-required fields like thinking.signature.
+		rewriter := NewResponseRewriter(c.Writer, modelName)
+		c.Writer = rewriter
		// Filter Anthropic-Beta header only for local handling paths
		filterAntropicBetaHeader(c)
		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		handler(c)
+		rewriter.Flush()
	} else {
		// No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response
		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
@@ -2,6 +2,7 @@ package amp

import (
	"bytes"
+	"fmt"
	"net/http"
	"strings"
@@ -12,32 +13,83 @@ import (
)

// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body
-// It's used to rewrite model names in responses when model mapping is used
+// It is used to rewrite model names in responses when model mapping is used
+// and to keep Amp-compatible response shapes.
type ResponseRewriter struct {
	gin.ResponseWriter
-	body          *bytes.Buffer
-	originalModel string
-	isStreaming   bool
+	body                   *bytes.Buffer
+	originalModel          string
+	isStreaming            bool
+	suppressedContentBlock map[int]struct{}
}

-// NewResponseRewriter creates a new response rewriter for model name substitution
+// NewResponseRewriter creates a new response rewriter for model name substitution.
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
	return &ResponseRewriter{
-		ResponseWriter: w,
-		body:           &bytes.Buffer{},
-		originalModel:  originalModel,
+		ResponseWriter:         w,
+		body:                   &bytes.Buffer{},
+		originalModel:          originalModel,
+		suppressedContentBlock: make(map[int]struct{}),
	}
}

-// Write intercepts response writes and buffers them for model name replacement
+const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap
+
+func looksLikeSSEChunk(data []byte) bool {
+	for _, line := range bytes.Split(data, []byte("\n")) {
+		trimmed := bytes.TrimSpace(line)
+		if bytes.HasPrefix(trimmed, []byte("data:")) ||
+			bytes.HasPrefix(trimmed, []byte("event:")) {
+			return true
+		}
+	}
+	return false
+}
+
+func (rw *ResponseRewriter) enableStreaming(reason string) error {
+	if rw.isStreaming {
+		return nil
+	}
+	rw.isStreaming = true
+
+	if rw.body != nil && rw.body.Len() > 0 {
+		buf := rw.body.Bytes()
+		toFlush := make([]byte, len(buf))
+		copy(toFlush, buf)
+		rw.body.Reset()
+
+		if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil {
+			return err
+		}
+		if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
+			flusher.Flush()
+		}
+	}
+
+	log.Debugf("amp response rewriter: switched to streaming (%s)", reason)
+	return nil
+}
+
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
	// Detect streaming on first write
-	if rw.body.Len() == 0 && !rw.isStreaming {
+	if !rw.isStreaming && rw.body.Len() == 0 {
		contentType := rw.Header().Get("Content-Type")
		rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
			strings.Contains(contentType, "stream")
	}

+	if !rw.isStreaming {
+		if looksLikeSSEChunk(data) {
+			if err := rw.enableStreaming("sse heuristic"); err != nil {
+				return 0, err
+			}
+		} else if rw.body.Len()+len(data) > maxBufferedResponseBytes {
+			log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes)
+			if err := rw.enableStreaming("buffer limit"); err != nil {
+				return 0, err
+			}
+		}
+	}
+
	if rw.isStreaming {
		n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
		if err == nil {
@@ -50,7 +102,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
	return rw.body.Write(data)
}

-// Flush writes the buffered response with model names rewritten
func (rw *ResponseRewriter) Flush() {
	if rw.isStreaming {
		if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
@@ -59,26 +110,68 @@ func (rw *ResponseRewriter) Flush() {
		return
	}
	if rw.body.Len() > 0 {
-		if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil {
+		rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
+		// Update Content-Length to match the rewritten body size, since
+		// signature injection and model name changes alter the payload length.
+		rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten)))
+		if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
			log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
		}
	}
}

// modelFieldPaths lists all JSON paths where model name may appear
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}

-// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON
-// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility
-func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
-	// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected
-	// The Amp client struggles when both thinking and tool_use blocks are present
+// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
+// in API responses so that the Amp TUI does not crash on P.signature.length.
+func ensureAmpSignature(data []byte) []byte {
+	for index, block := range gjson.GetBytes(data, "content").Array() {
+		blockType := block.Get("type").String()
+		if blockType != "tool_use" && blockType != "thinking" {
+			continue
+		}
+		signaturePath := fmt.Sprintf("content.%d.signature", index)
+		if gjson.GetBytes(data, signaturePath).Exists() {
+			continue
+		}
+		var err error
+		data, err = sjson.SetBytes(data, signaturePath, "")
+		if err != nil {
+			log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err)
+			break
+		}
+	}
+
+	contentBlockType := gjson.GetBytes(data, "content_block.type").String()
+	if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() {
+		var err error
+		data, err = sjson.SetBytes(data, "content_block.signature", "")
+		if err != nil {
+			log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err)
+		}
+	}
+
+	return data
+}
+
+func (rw *ResponseRewriter) markSuppressedContentBlock(index int) {
+	if rw.suppressedContentBlock == nil {
+		rw.suppressedContentBlock = make(map[int]struct{})
+	}
+	rw.suppressedContentBlock[index] = struct{}{}
+}
+
+func (rw *ResponseRewriter) isSuppressedContentBlock(index int) bool {
+	_, ok := rw.suppressedContentBlock[index]
+	return ok
+}
+
+func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
	if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
		filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
		if filtered.Exists() {
			originalCount := gjson.GetBytes(data, "content.#").Int()
			filteredCount := filtered.Get("#").Int()

			if originalCount > filteredCount {
				var err error
				data, err = sjson.SetBytes(data, "content", filtered.Value())
@@ -86,13 +179,41 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
					log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
				} else {
					log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
					// Log the result for verification
					log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
				}
			}
		}
	}

+	eventType := gjson.GetBytes(data, "type").String()
+	indexResult := gjson.GetBytes(data, "index")
+	if eventType == "content_block_start" && gjson.GetBytes(data, "content_block.type").String() == "thinking" && indexResult.Exists() {
+		rw.markSuppressedContentBlock(int(indexResult.Int()))
+		return nil
+	}
+	if gjson.GetBytes(data, "delta.type").String() == "thinking_delta" {
+		if indexResult.Exists() {
+			rw.markSuppressedContentBlock(int(indexResult.Int()))
+		}
+		return nil
+	}
+	if eventType == "content_block_stop" && indexResult.Exists() {
+		index := int(indexResult.Int())
+		if rw.isSuppressedContentBlock(index) {
+			delete(rw.suppressedContentBlock, index)
+			return nil
+		}
+	}
+
+	return data
+}
+
+func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
+	data = ensureAmpSignature(data)
+	data = rw.suppressAmpThinking(data)
+	if len(data) == 0 {
+		return data
+	}
+
	if rw.originalModel == "" {
		return data
	}
@@ -104,24 +225,158 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
	return data
}

// rewriteStreamChunk rewrites model names in SSE stream chunks
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
-	if rw.originalModel == "" {
-		return chunk
-	}
-
-	// SSE format: "data: {json}\n\n"
-	lines := bytes.Split(chunk, []byte("\n"))
-	for i, line := range lines {
-		if bytes.HasPrefix(line, []byte("data: ")) {
-			jsonData := bytes.TrimPrefix(line, []byte("data: "))
-			if len(jsonData) > 0 && jsonData[0] == '{' {
-				// Rewrite JSON in the data line
-				rewritten := rw.rewriteModelInResponse(jsonData)
-				lines[i] = append([]byte("data: "), rewritten...)
-			}
-		}
-	}
-	return bytes.Join(lines, []byte("\n"))
+	lines := bytes.Split(chunk, []byte("\n"))
+	var out [][]byte
+
+	i := 0
+	for i < len(lines) {
+		line := lines[i]
+		trimmed := bytes.TrimSpace(line)
+
+		// Case 1: "event:" line - look ahead for its "data:" line
+		if bytes.HasPrefix(trimmed, []byte("event: ")) {
+			// Scan forward past blank lines to find the data: line
+			dataIdx := -1
+			for j := i + 1; j < len(lines); j++ {
+				t := bytes.TrimSpace(lines[j])
+				if len(t) == 0 {
+					continue
+				}
+				if bytes.HasPrefix(t, []byte("data: ")) {
+					dataIdx = j
+				}
+				break
+			}
+
+			if dataIdx >= 0 {
+				// Found event+data pair - process through rewriter
+				jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: "))
+				if len(jsonData) > 0 && jsonData[0] == '{' {
+					rewritten := rw.rewriteStreamEvent(jsonData)
+					if rewritten == nil {
+						// Event suppressed (e.g. thinking block), skip event+data pair
+						i = dataIdx + 1
+						continue
+					}
+					// Emit event line
+					out = append(out, line)
+					// Emit blank lines between event and data
+					for k := i + 1; k < dataIdx; k++ {
+						out = append(out, lines[k])
+					}
+					// Emit rewritten data
+					out = append(out, append([]byte("data: "), rewritten...))
+					i = dataIdx + 1
+					continue
+				}
+			}
+
+			// No data line found (orphan event from cross-chunk split)
+			// Pass it through as-is - the data will arrive in the next chunk
+			out = append(out, line)
+			i++
+			continue
+		}
+
+		// Case 2: standalone "data:" line (no preceding event: in this chunk)
+		if bytes.HasPrefix(trimmed, []byte("data: ")) {
+			jsonData := bytes.TrimPrefix(trimmed, []byte("data: "))
+			if len(jsonData) > 0 && jsonData[0] == '{' {
+				rewritten := rw.rewriteStreamEvent(jsonData)
+				if rewritten != nil {
+					out = append(out, append([]byte("data: "), rewritten...))
+				}
+				i++
+				continue
+			}
+		}
+
+		// Case 3: everything else
+		out = append(out, line)
+		i++
+	}
+
+	return bytes.Join(out, []byte("\n"))
}

+// rewriteStreamEvent processes a single JSON event in the SSE stream.
+// It rewrites model names and ensures signature fields exist.
+func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
+	// Suppress thinking blocks before any other processing.
+	data = rw.suppressAmpThinking(data)
+	if len(data) == 0 {
+		return nil
+	}
+
+	// Inject empty signature where needed
+	data = ensureAmpSignature(data)
+
+	// Rewrite model name
+	if rw.originalModel != "" {
+		for _, path := range modelFieldPaths {
+			if gjson.GetBytes(data, path).Exists() {
+				data, _ = sjson.SetBytes(data, path, rw.originalModel)
+			}
+		}
+	}
+
+	return data
+}

// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures
// from the messages array in a request body before forwarding to the upstream API.
// This prevents 400 errors from the API which requires valid signatures on thinking blocks.
func SanitizeAmpRequestBody(body []byte) []byte {
	messages := gjson.GetBytes(body, "messages")
	if !messages.Exists() || !messages.IsArray() {
		return body
	}

	modified := false
	for msgIdx, msg := range messages.Array() {
		if msg.Get("role").String() != "assistant" {
			continue
		}
		content := msg.Get("content")
		if !content.Exists() || !content.IsArray() {
			continue
		}

		var keepBlocks []interface{}
		removedCount := 0

		for _, block := range content.Array() {
			blockType := block.Get("type").String()
			if blockType == "thinking" {
				sig := block.Get("signature")
				if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" {
					removedCount++
					continue
				}
			}
			keepBlocks = append(keepBlocks, block.Value())
		}

		if removedCount > 0 {
			contentPath := fmt.Sprintf("messages.%d.content", msgIdx)
			var err error
			if len(keepBlocks) == 0 {
				body, err = sjson.SetBytes(body, contentPath, []interface{}{})
			} else {
				body, err = sjson.SetBytes(body, contentPath, keepBlocks)
			}
			if err != nil {
				log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err)
				continue
			}
			modified = true
			log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx)
		}
	}

	if modified {
		log.Debugf("Amp RequestSanitizer: sanitized request body")
	}
	return body
}
@@ -100,6 +100,44 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) {
	}
}

func TestRewriteStreamChunk_SuppressesThinkingContentBlockFrames(t *testing.T) {
	rw := &ResponseRewriter{suppressedContentBlock: make(map[int]struct{})}

	chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n")
	result := rw.rewriteStreamChunk(chunk)

	if contains(result, []byte("\"thinking\"")) || contains(result, []byte("\"thinking_delta\"")) {
		t.Fatalf("expected thinking content_block frames to be suppressed, got %s", string(result))
	}
	if contains(result, []byte("content_block_stop")) {
		t.Fatalf("expected suppressed thinking content_block_stop to be removed, got %s", string(result))
	}
	if !contains(result, []byte("\"tool_use\"")) {
		t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result))
	}
	if !contains(result, []byte("\"signature\":\"\"")) {
		t.Fatalf("expected tool_use content_block signature injection, got %s", string(result))
	}
}

func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) {
	input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`)
	result := SanitizeAmpRequestBody(input)

	if contains(result, []byte("drop-whitespace")) {
		t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result))
	}
	if contains(result, []byte("drop-number")) {
		t.Fatalf("expected non-string signature block to be removed, got %s", string(result))
	}
	if !contains(result, []byte("keep-valid")) {
		t.Fatalf("expected valid thinking block to remain, got %s", string(result))
	}
	if !contains(result, []byte("keep-text")) {
		t.Fatalf("expected non-thinking content to remain, got %s", string(result))
	}
}

func contains(data, substr []byte) bool {
	for i := 0; i <= len(data)-len(substr); i++ {
		if string(data[i:i+len(substr)]) == string(substr) {
@@ -22,6 +22,7 @@ import (
	claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
+	"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -44,6 +45,10 @@ type ClaudeExecutor struct {
// Previously "proxy_" was used but this is a detectable fingerprint difference.
const claudeToolPrefix = ""

+// Anthropic-compatible upstreams may reject or even crash when Claude models
+// omit max_tokens. Prefer registered model metadata before using a fallback.
+const defaultModelMaxTokens = 1024
+
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }

func (e *ClaudeExecutor) Identifier() string { return "claude" }
@@ -127,6 +132,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r

	requestedModel := payloadRequestedModel(opts, req.Model)
	body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
+	body = ensureModelMaxTokens(body, baseModel)

	// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
	body = disableThinkingIfToolChoiceForced(body)
@@ -293,6 +299,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A

	requestedModel := payloadRequestedModel(opts, req.Model)
	body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
+	body = ensureModelMaxTokens(body, baseModel)

	// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
	body = disableThinkingIfToolChoiceForced(body)
@@ -1880,3 +1887,26 @@ func injectSystemCacheControl(payload []byte) []byte {

	return payload
}

func ensureModelMaxTokens(body []byte, modelID string) []byte {
	if len(body) == 0 || !gjson.ValidBytes(body) {
		return body
	}

	if maxTokens := gjson.GetBytes(body, "max_tokens"); maxTokens.Exists() {
		return body
	}

	for _, provider := range registry.GetGlobalRegistry().GetModelProviders(strings.TrimSpace(modelID)) {
		if strings.EqualFold(provider, "claude") {
			maxTokens := defaultModelMaxTokens
			if info := registry.GetGlobalRegistry().GetModelInfo(strings.TrimSpace(modelID), "claude"); info != nil && info.MaxCompletionTokens > 0 {
				maxTokens = info.MaxCompletionTokens
			}
			body, _ = sjson.SetBytes(body, "max_tokens", maxTokens)
			return body
		}
	}

	return body
}
@@ -15,6 +15,7 @@ import (
	"github.com/gin-gonic/gin"
	"github.com/klauspost/compress/zstd"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+	"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
	sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -1183,6 +1184,83 @@ func testClaudeExecutorInvalidCompressedErrorBody(
	}
}

func TestEnsureModelMaxTokens_UsesRegisteredMaxCompletionTokens(t *testing.T) {
	reg := registry.GetGlobalRegistry()
	clientID := "test-claude-max-completion-tokens-client"
	modelID := "test-claude-max-completion-tokens-model"
	reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
		ID:                  modelID,
		Type:                "claude",
		OwnedBy:             "anthropic",
		Object:              "model",
		Created:             time.Now().Unix(),
		MaxCompletionTokens: 4096,
		UserDefined:         true,
	}})
	defer reg.UnregisterClient(clientID)

	input := []byte(`{"model":"test-claude-max-completion-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
	out := ensureModelMaxTokens(input, modelID)

	if got := gjson.GetBytes(out, "max_tokens").Int(); got != 4096 {
		t.Fatalf("max_tokens = %d, want %d", got, 4096)
	}
}

func TestEnsureModelMaxTokens_DefaultsMissingValue(t *testing.T) {
	reg := registry.GetGlobalRegistry()
	clientID := "test-claude-default-max-tokens-client"
	modelID := "test-claude-default-max-tokens-model"
	reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
		ID:          modelID,
		Type:        "claude",
		OwnedBy:     "anthropic",
		Object:      "model",
		Created:     time.Now().Unix(),
		UserDefined: true,
	}})
	defer reg.UnregisterClient(clientID)

	input := []byte(`{"model":"test-claude-default-max-tokens-model","messages":[{"role":"user","content":"hi"}]}`)
	out := ensureModelMaxTokens(input, modelID)

	if got := gjson.GetBytes(out, "max_tokens").Int(); got != defaultModelMaxTokens {
		t.Fatalf("max_tokens = %d, want %d", got, defaultModelMaxTokens)
	}
}

func TestEnsureModelMaxTokens_PreservesExplicitValue(t *testing.T) {
	reg := registry.GetGlobalRegistry()
	clientID := "test-claude-preserve-max-tokens-client"
	modelID := "test-claude-preserve-max-tokens-model"
	reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{
		ID:                  modelID,
		Type:                "claude",
		OwnedBy:             "anthropic",
		Object:              "model",
		Created:             time.Now().Unix(),
		MaxCompletionTokens: 4096,
		UserDefined:         true,
	}})
	defer reg.UnregisterClient(clientID)

	input := []byte(`{"model":"test-claude-preserve-max-tokens-model","max_tokens":2048,"messages":[{"role":"user","content":"hi"}]}`)
	out := ensureModelMaxTokens(input, modelID)

	if got := gjson.GetBytes(out, "max_tokens").Int(); got != 2048 {
		t.Fatalf("max_tokens = %d, want %d", got, 2048)
	}
}

func TestEnsureModelMaxTokens_SkipsUnregisteredModel(t *testing.T) {
	input := []byte(`{"model":"test-claude-unregistered-model","messages":[{"role":"user","content":"hi"}]}`)
	out := ensureModelMaxTokens(input, "test-claude-unregistered-model")

	if gjson.GetBytes(out, "max_tokens").Exists() {
		t.Fatalf("max_tokens should remain unset, got %s", gjson.GetBytes(out, "max_tokens").Raw)
	}
}

// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
// requests use Accept-Encoding: identity so the upstream cannot respond with a
// compressed SSE body that would silently break the line scanner.

@@ -113,6 +113,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
	body, _ = sjson.DeleteBytes(body, "previous_response_id")
	body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
	body, _ = sjson.DeleteBytes(body, "safety_identifier")
+	body, _ = sjson.DeleteBytes(body, "stream_options")
	if !gjson.GetBytes(body, "instructions").Exists() {
		body, _ = sjson.SetBytes(body, "instructions", "")
	}
@@ -311,6 +312,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
	body, _ = sjson.DeleteBytes(body, "previous_response_id")
	body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
	body, _ = sjson.DeleteBytes(body, "safety_identifier")
+	body, _ = sjson.DeleteBytes(body, "stream_options")
	body, _ = sjson.SetBytes(body, "model", baseModel)
	if !gjson.GetBytes(body, "instructions").Exists() {
		body, _ = sjson.SetBytes(body, "instructions", "")
@@ -415,6 +417,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
	body, _ = sjson.DeleteBytes(body, "previous_response_id")
	body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
	body, _ = sjson.DeleteBytes(body, "safety_identifier")
+	body, _ = sjson.DeleteBytes(body, "stream_options")
	body, _ = sjson.SetBytes(body, "stream", false)
	if !gjson.GetBytes(body, "instructions").Exists() {
		body, _ = sjson.SetBytes(body, "instructions", "")
@@ -685,13 +688,39 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
}

func newCodexStatusErr(statusCode int, body []byte) statusErr {
-	err := statusErr{code: statusCode, msg: string(body)}
-	if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
+	errCode := statusCode
+	if isCodexModelCapacityError(body) {
+		errCode = http.StatusTooManyRequests
+	}
+	err := statusErr{code: errCode, msg: string(body)}
+	if retryAfter := parseCodexRetryAfter(errCode, body, time.Now()); retryAfter != nil {
		err.retryAfter = retryAfter
	}
	return err
}

+func isCodexModelCapacityError(errorBody []byte) bool {
+	if len(errorBody) == 0 {
+		return false
+	}
+	candidates := []string{
+		gjson.GetBytes(errorBody, "error.message").String(),
+		gjson.GetBytes(errorBody, "message").String(),
+		string(errorBody),
+	}
+	for _, candidate := range candidates {
+		lower := strings.ToLower(strings.TrimSpace(candidate))
+		if lower == "" {
+			continue
+		}
+		if strings.Contains(lower, "selected model is at capacity") ||
+			strings.Contains(lower, "model is at capacity. please try a different model") {
+			return true
+		}
+	}
+	return false
+}
+
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
	if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
		return nil
|
||||
|
||||
@@ -60,6 +60,19 @@ func TestParseCodexRetryAfter(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit(t *testing.T) {
|
||||
body := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)
|
||||
|
||||
err := newCodexStatusErr(http.StatusBadRequest, body)
|
||||
|
||||
if got := err.StatusCode(); got != http.StatusTooManyRequests {
|
||||
t.Fatalf("status code = %d, want %d", got, http.StatusTooManyRequests)
|
||||
}
|
||||
if err.RetryAfter() != nil {
|
||||
t.Fatalf("expected nil explicit retryAfter for capacity fallback, got %v", *err.RetryAfter())
|
||||
}
|
||||
}
|
||||
|
||||
func itoa(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
|
||||
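The hunks above change `newCodexStatusErr` so a "model is at capacity" response is treated as a retryable 429 regardless of the upstream status code. A minimal stdlib-only sketch of that remapping rule (`effectiveStatus` is a hypothetical helper name; the real code additionally inspects the `error.message` and `message` JSON fields via gjson):

```go
package main

import (
	"fmt"
	"net/http"
	"strings"
)

// effectiveStatus mirrors the capacity fallback: if the error body mentions
// that the model is at capacity, report 429 so retry logic kicks in;
// otherwise keep the original status code.
func effectiveStatus(statusCode int, body []byte) int {
	lower := strings.ToLower(string(body))
	if strings.Contains(lower, "selected model is at capacity") ||
		strings.Contains(lower, "model is at capacity. please try a different model") {
		return http.StatusTooManyRequests
	}
	return statusCode
}

func main() {
	capacity := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)
	fmt.Println(effectiveStatus(http.StatusBadRequest, capacity)) // 429
	fmt.Println(effectiveStatus(http.StatusBadRequest, []byte(`{"error":{"message":"bad request"}}`))) // 400
}
```

The test `TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit` above exercises exactly this path: a 400 body carrying the capacity message comes back with `StatusCode() == 429` and no explicit retry-after.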
@@ -330,32 +330,45 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
		}
	}

	// Reorder parts for 'model' role to ensure thinking block is first
	// Reorder parts for 'model' role:
	// 1. Thinking parts first (Antigravity API requirement)
	// 2. Regular parts (text, inlineData, etc.)
	// 3. FunctionCall parts last
	//
	// Moving functionCall parts to the end prevents tool_use↔tool_result
	// pairing breakage: the Antigravity API internally splits model messages
	// at functionCall boundaries. If a text part follows a functionCall, the
	// split creates an extra assistant turn between tool_use and tool_result,
	// which Claude rejects with "tool_use ids were found without tool_result
	// blocks immediately after".
	if role == "model" {
		partsResult := gjson.GetBytes(clientContentJSON, "parts")
		if partsResult.IsArray() {
			parts := partsResult.Array()
			var thinkingParts []gjson.Result
			var otherParts []gjson.Result
			for _, part := range parts {
				if part.Get("thought").Bool() {
					thinkingParts = append(thinkingParts, part)
				} else {
					otherParts = append(otherParts, part)
				}
			}
			if len(thinkingParts) > 0 {
				firstPartIsThinking := parts[0].Get("thought").Bool()
				if !firstPartIsThinking || len(thinkingParts) > 1 {
					var newParts []interface{}
					for _, p := range thinkingParts {
						newParts = append(newParts, p.Value())
			if len(parts) > 1 {
				var thinkingParts []gjson.Result
				var regularParts []gjson.Result
				var functionCallParts []gjson.Result
				for _, part := range parts {
					if part.Get("thought").Bool() {
						thinkingParts = append(thinkingParts, part)
					} else if part.Get("functionCall").Exists() {
						functionCallParts = append(functionCallParts, part)
					} else {
						regularParts = append(regularParts, part)
					}
				for _, p := range otherParts {
					newParts = append(newParts, p.Value())
				}
				clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
				}
				var newParts []interface{}
				for _, p := range thinkingParts {
					newParts = append(newParts, p.Value())
				}
				for _, p := range regularParts {
					newParts = append(newParts, p.Value())
				}
				for _, p := range functionCallParts {
					newParts = append(newParts, p.Value())
				}
				clientContentJSON, _ = sjson.SetBytes(clientContentJSON, "parts", newParts)
			}
		}
	}

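The converter above does its reordering on raw JSON via gjson/sjson. The three-bucket ordering itself (thinking → regular → functionCall) can be sketched with plain structs; `Part` and `reorderModelParts` are hypothetical stand-ins, assuming the three part kinds named in the comment:

```go
package main

import "fmt"

// Part is a simplified stand-in for a gjson part.
type Part struct {
	Kind string // "thinking", "text", or "functionCall"
	Text string
}

// reorderModelParts applies the same three-bucket ordering as the converter:
// thinking parts first, regular parts next, functionCall parts last.
// Relative order inside each bucket is preserved.
func reorderModelParts(parts []Part) []Part {
	var thinking, regular, calls []Part
	for _, p := range parts {
		switch p.Kind {
		case "thinking":
			thinking = append(thinking, p)
		case "functionCall":
			calls = append(calls, p)
		default:
			regular = append(regular, p)
		}
	}
	out := make([]Part, 0, len(parts))
	out = append(out, thinking...)
	out = append(out, regular...)
	out = append(out, calls...)
	return out
}

func main() {
	in := []Part{
		{"text", "Let me check..."},
		{"functionCall", "Read"},
		{"text", "Reading the file now"},
	}
	for _, p := range reorderModelParts(in) {
		fmt.Println(p.Kind, p.Text)
	}
}
```

With the sample input above, both text parts come out before the functionCall part, which is the shape the new tests below assert against.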
@@ -361,6 +361,167 @@ func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
	}
}

func TestConvertClaudeRequestToAntigravity_ReorderTextAfterFunctionCall(t *testing.T) {
	// Bug: text part after tool_use in an assistant message causes Antigravity
	// to split at functionCall boundary, creating an extra assistant turn that
	// breaks tool_use↔tool_result adjacency (upstream issue #989).
	// Fix: reorder parts so functionCall comes last.
	inputJSON := []byte(`{
		"model": "claude-sonnet-4-5",
		"messages": [
			{
				"role": "assistant",
				"content": [
					{"type": "text", "text": "Let me check..."},
					{
						"type": "tool_use",
						"id": "call_abc",
						"name": "Read",
						"input": {"file": "test.go"}
					},
					{"type": "text", "text": "Reading the file now"}
				]
			},
			{
				"role": "user",
				"content": [
					{
						"type": "tool_result",
						"tool_use_id": "call_abc",
						"content": "file content"
					}
				]
			}
		]
	}`)

	output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
	outputStr := string(output)

	parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
	if len(parts) != 3 {
		t.Fatalf("Expected 3 parts, got %d", len(parts))
	}

	// Text parts should come before functionCall
	if parts[0].Get("text").String() != "Let me check..." {
		t.Errorf("Expected first text part first, got %s", parts[0].Raw)
	}
	if parts[1].Get("text").String() != "Reading the file now" {
		t.Errorf("Expected second text part second, got %s", parts[1].Raw)
	}
	if !parts[2].Get("functionCall").Exists() {
		t.Errorf("Expected functionCall last, got %s", parts[2].Raw)
	}
	if parts[2].Get("functionCall.name").String() != "Read" {
		t.Errorf("Expected functionCall name 'Read', got '%s'", parts[2].Get("functionCall.name").String())
	}
}

func TestConvertClaudeRequestToAntigravity_ReorderParallelFunctionCalls(t *testing.T) {
	inputJSON := []byte(`{
		"model": "claude-sonnet-4-5",
		"messages": [
			{
				"role": "assistant",
				"content": [
					{"type": "text", "text": "Reading both files."},
					{
						"type": "tool_use",
						"id": "call_1",
						"name": "Read",
						"input": {"file": "a.go"}
					},
					{"type": "text", "text": "And this one too."},
					{
						"type": "tool_use",
						"id": "call_2",
						"name": "Read",
						"input": {"file": "b.go"}
					}
				]
			}
		]
	}`)

	output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
	outputStr := string(output)

	parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
	if len(parts) != 4 {
		t.Fatalf("Expected 4 parts, got %d", len(parts))
	}

	if parts[0].Get("text").String() != "Reading both files." {
		t.Errorf("Expected first text, got %s", parts[0].Raw)
	}
	if parts[1].Get("text").String() != "And this one too." {
		t.Errorf("Expected second text, got %s", parts[1].Raw)
	}
	if parts[2].Get("functionCall.name").String() != "Read" || parts[2].Get("functionCall.id").String() != "call_1" {
		t.Errorf("Expected fc1 third, got %s", parts[2].Raw)
	}
	if parts[3].Get("functionCall.name").String() != "Read" || parts[3].Get("functionCall.id").String() != "call_2" {
		t.Errorf("Expected fc2 fourth, got %s", parts[3].Raw)
	}
}

func TestConvertClaudeRequestToAntigravity_ReorderThinkingAndTextBeforeFunctionCall(t *testing.T) {
	cache.ClearSignatureCache("")

	validSignature := "abc123validSignature1234567890123456789012345678901234567890"
	thinkingText := "Let me think about this..."

	inputJSON := []byte(`{
		"model": "claude-sonnet-4-5-thinking",
		"messages": [
			{
				"role": "user",
				"content": [{"type": "text", "text": "Hello"}]
			},
			{
				"role": "assistant",
				"content": [
					{"type": "text", "text": "Before thinking"},
					{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
					{
						"type": "tool_use",
						"id": "call_xyz",
						"name": "Bash",
						"input": {"command": "ls"}
					},
					{"type": "text", "text": "After tool call"}
				]
			}
		]
	}`)

	cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)

	output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
	outputStr := string(output)

	// contents.1 = assistant message (contents.0 = user)
	parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
	if len(parts) != 4 {
		t.Fatalf("Expected 4 parts, got %d", len(parts))
	}

	// Order: thinking → text → text → functionCall
	if !parts[0].Get("thought").Bool() {
		t.Error("First part should be thinking")
	}
	if parts[1].Get("functionCall").Exists() || parts[1].Get("thought").Bool() {
		t.Errorf("Second part should be text, got %s", parts[1].Raw)
	}
	if parts[2].Get("functionCall").Exists() || parts[2].Get("thought").Bool() {
		t.Errorf("Third part should be text, got %s", parts[2].Raw)
	}
	if !parts[3].Get("functionCall").Exists() {
		t.Errorf("Last part should be functionCall, got %s", parts[3].Raw)
	}
}

func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
	inputJSON := []byte(`{
		"model": "claude-3-5-sonnet-20240620",

@@ -4,6 +4,7 @@ import (
	"bytes"
	"context"
	"fmt"
	"sort"
	"strings"
	"sync/atomic"
	"time"
@@ -16,6 +17,7 @@ import (
type oaiToResponsesStateReasoning struct {
	ReasoningID   string
	ReasoningData string
	OutputIndex   int
}
type oaiToResponsesState struct {
	Seq int
@@ -29,16 +31,19 @@ type oaiToResponsesState struct {
	MsgTextBuf   map[int]*strings.Builder
	ReasoningBuf strings.Builder
	Reasonings   []oaiToResponsesStateReasoning
	FuncArgsBuf  map[int]*strings.Builder // index -> args
	FuncNames    map[int]string           // index -> name
	FuncCallIDs  map[int]string           // index -> call_id
	FuncArgsBuf  map[string]*strings.Builder
	FuncNames    map[string]string
	FuncCallIDs  map[string]string
	FuncOutputIx map[string]int
	MsgOutputIx  map[int]int
	NextOutputIx int
	// message item state per output index
	MsgItemAdded    map[int]bool // whether response.output_item.added emitted for message
	MsgContentAdded map[int]bool // whether response.content_part.added emitted for message
	MsgItemDone     map[int]bool // whether message done events were emitted
	// function item done state
	FuncArgsDone map[int]bool
	FuncItemDone map[int]bool
	FuncArgsDone map[string]bool
	FuncItemDone map[string]bool
	// usage aggregation
	PromptTokens int64
	CachedTokens int64
@@ -60,15 +65,17 @@ func emitRespEvent(event string, payload []byte) []byte {
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
	if *param == nil {
		*param = &oaiToResponsesState{
			FuncArgsBuf:     make(map[int]*strings.Builder),
			FuncNames:       make(map[int]string),
			FuncCallIDs:     make(map[int]string),
			FuncArgsBuf:     make(map[string]*strings.Builder),
			FuncNames:       make(map[string]string),
			FuncCallIDs:     make(map[string]string),
			FuncOutputIx:    make(map[string]int),
			MsgOutputIx:     make(map[int]int),
			MsgTextBuf:      make(map[int]*strings.Builder),
			MsgItemAdded:    make(map[int]bool),
			MsgContentAdded: make(map[int]bool),
			MsgItemDone:     make(map[int]bool),
			FuncArgsDone:    make(map[int]bool),
			FuncItemDone:    make(map[int]bool),
			FuncArgsDone:    make(map[string]bool),
			FuncItemDone:    make(map[string]bool),
			Reasonings:      make([]oaiToResponsesStateReasoning, 0),
		}
	}
@@ -125,6 +132,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
	}

	nextSeq := func() int { st.Seq++; return st.Seq }
	allocOutputIndex := func() int {
		ix := st.NextOutputIx
		st.NextOutputIx++
		return ix
	}
	toolStateKey := func(outputIndex, toolIndex int) string { return fmt.Sprintf("%d:%d", outputIndex, toolIndex) }
	var out [][]byte

	if !st.Started {
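The `allocOutputIndex` closure introduced above hands out a single monotonically increasing `output_index` shared by reasoning, message, and function-call items, instead of reusing the upstream choice index for all of them. A stdlib-only sketch (`indexAllocator` is a hypothetical stand-in for the closure over `st.NextOutputIx`):

```go
package main

import "fmt"

// indexAllocator mirrors allocOutputIndex: each emitted item takes the next
// global output_index, so reasoning, message, and function_call items never
// collide on the same index.
type indexAllocator struct{ next int }

func (a *indexAllocator) alloc() int {
	ix := a.next
	a.next++
	return ix
}

func main() {
	var a indexAllocator
	reasoningIx := a.alloc() // first item gets 0
	msgIx := a.alloc()       // next gets 1
	funcIx := a.alloc()      // next gets 2
	fmt.Println(reasoningIx, msgIx, funcIx)
}
```

The later hunks thread these allocated indexes through every `output_index` field in place of the raw choice index `idx`.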
@@ -135,14 +148,17 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
		st.ReasoningBuf.Reset()
		st.ReasoningID = ""
		st.ReasoningIndex = 0
		st.FuncArgsBuf = make(map[int]*strings.Builder)
		st.FuncNames = make(map[int]string)
		st.FuncCallIDs = make(map[int]string)
		st.FuncArgsBuf = make(map[string]*strings.Builder)
		st.FuncNames = make(map[string]string)
		st.FuncCallIDs = make(map[string]string)
		st.FuncOutputIx = make(map[string]int)
		st.MsgOutputIx = make(map[int]int)
		st.NextOutputIx = 0
		st.MsgItemAdded = make(map[int]bool)
		st.MsgContentAdded = make(map[int]bool)
		st.MsgItemDone = make(map[int]bool)
		st.FuncArgsDone = make(map[int]bool)
		st.FuncItemDone = make(map[int]bool)
		st.FuncArgsDone = make(map[string]bool)
		st.FuncItemDone = make(map[string]bool)
		st.PromptTokens = 0
		st.CachedTokens = 0
		st.CompletionTokens = 0
@@ -185,7 +201,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
		outputItemDone, _ = sjson.SetBytes(outputItemDone, "item.summary.text", text)
		out = append(out, emitRespEvent("response.output_item.done", outputItemDone))

		st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text})
		st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text, OutputIndex: st.ReasoningIndex})
		st.ReasoningID = ""
	}

@@ -201,10 +217,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
			stopReasoning(st.ReasoningBuf.String())
			st.ReasoningBuf.Reset()
		}
		if _, exists := st.MsgOutputIx[idx]; !exists {
			st.MsgOutputIx[idx] = allocOutputIndex()
		}
		msgOutputIndex := st.MsgOutputIx[idx]
		if !st.MsgItemAdded[idx] {
			item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
			item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
			item, _ = sjson.SetBytes(item, "output_index", idx)
			item, _ = sjson.SetBytes(item, "output_index", msgOutputIndex)
			item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
			out = append(out, emitRespEvent("response.output_item.added", item))
			st.MsgItemAdded[idx] = true
@@ -213,7 +233,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
			part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
			part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
			part, _ = sjson.SetBytes(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
			part, _ = sjson.SetBytes(part, "output_index", idx)
			part, _ = sjson.SetBytes(part, "output_index", msgOutputIndex)
			part, _ = sjson.SetBytes(part, "content_index", 0)
			out = append(out, emitRespEvent("response.content_part.added", part))
			st.MsgContentAdded[idx] = true
@@ -222,7 +242,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
		msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
		msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
		msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
		msg, _ = sjson.SetBytes(msg, "output_index", idx)
		msg, _ = sjson.SetBytes(msg, "output_index", msgOutputIndex)
		msg, _ = sjson.SetBytes(msg, "content_index", 0)
		msg, _ = sjson.SetBytes(msg, "delta", c.String())
		out = append(out, emitRespEvent("response.output_text.delta", msg))
@@ -238,10 +258,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
		// On first appearance, add reasoning item and part
		if st.ReasoningID == "" {
			st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
			st.ReasoningIndex = idx
			st.ReasoningIndex = allocOutputIndex()
			item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
			item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
			item, _ = sjson.SetBytes(item, "output_index", idx)
			item, _ = sjson.SetBytes(item, "output_index", st.ReasoningIndex)
			item, _ = sjson.SetBytes(item, "item.id", st.ReasoningID)
			out = append(out, emitRespEvent("response.output_item.added", item))
			part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)
@@ -269,6 +289,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
		// Before emitting any function events, if a message is open for this index,
		// close its text/content to match Codex expected ordering.
		if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] {
			msgOutputIndex := st.MsgOutputIx[idx]
			fullText := ""
			if b := st.MsgTextBuf[idx]; b != nil {
				fullText = b.String()
@@ -276,7 +297,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
			done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
			done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
			done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
			done, _ = sjson.SetBytes(done, "output_index", idx)
			done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
			done, _ = sjson.SetBytes(done, "content_index", 0)
			done, _ = sjson.SetBytes(done, "text", fullText)
			out = append(out, emitRespEvent("response.output_text.done", done))
@@ -284,69 +305,72 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
			partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
			partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
			partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
			partDone, _ = sjson.SetBytes(partDone, "output_index", idx)
			partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
			partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
			partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
			out = append(out, emitRespEvent("response.content_part.done", partDone))

			itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
			itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
			itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
			itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
			itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
			itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
			out = append(out, emitRespEvent("response.output_item.done", itemDone))
			st.MsgItemDone[idx] = true
		}

		// Only emit item.added once per tool call and preserve call_id across chunks.
		newCallID := tcs.Get("0.id").String()
		nameChunk := tcs.Get("0.function.name").String()
		if nameChunk != "" {
			st.FuncNames[idx] = nameChunk
		}
		existingCallID := st.FuncCallIDs[idx]
		effectiveCallID := existingCallID
		shouldEmitItem := false
		if existingCallID == "" && newCallID != "" {
			// First time seeing a valid call_id for this index
			effectiveCallID = newCallID
			st.FuncCallIDs[idx] = newCallID
			shouldEmitItem = true
		}

		if shouldEmitItem && effectiveCallID != "" {
			o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
			o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
			o, _ = sjson.SetBytes(o, "output_index", idx)
			o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
			o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
			name := st.FuncNames[idx]
			o, _ = sjson.SetBytes(o, "item.name", name)
			out = append(out, emitRespEvent("response.output_item.added", o))
		}

		// Ensure args buffer exists for this index
		if st.FuncArgsBuf[idx] == nil {
			st.FuncArgsBuf[idx] = &strings.Builder{}
		}

		// Append arguments delta if available and we have a valid call_id to reference
		if args := tcs.Get("0.function.arguments"); args.Exists() && args.String() != "" {
			// Prefer an already known call_id; fall back to newCallID if first time
			refCallID := st.FuncCallIDs[idx]
			if refCallID == "" {
				refCallID = newCallID
		tcs.ForEach(func(_, tc gjson.Result) bool {
			toolIndex := int(tc.Get("index").Int())
			key := toolStateKey(idx, toolIndex)
			newCallID := tc.Get("id").String()
			nameChunk := tc.Get("function.name").String()
			if nameChunk != "" {
				st.FuncNames[key] = nameChunk
			}
			if refCallID != "" {
				ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
				ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
				ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
				ad, _ = sjson.SetBytes(ad, "output_index", idx)
				ad, _ = sjson.SetBytes(ad, "delta", args.String())
				out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))

			existingCallID := st.FuncCallIDs[key]
			effectiveCallID := existingCallID
			shouldEmitItem := false
			if existingCallID == "" && newCallID != "" {
				effectiveCallID = newCallID
				st.FuncCallIDs[key] = newCallID
				st.FuncOutputIx[key] = allocOutputIndex()
				shouldEmitItem = true
			}
			st.FuncArgsBuf[idx].WriteString(args.String())
		}

			if shouldEmitItem && effectiveCallID != "" {
				outputIndex := st.FuncOutputIx[key]
				o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
				o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
				o, _ = sjson.SetBytes(o, "output_index", outputIndex)
				o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
				o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
				o, _ = sjson.SetBytes(o, "item.name", st.FuncNames[key])
				out = append(out, emitRespEvent("response.output_item.added", o))
			}

			if st.FuncArgsBuf[key] == nil {
				st.FuncArgsBuf[key] = &strings.Builder{}
			}

			if args := tc.Get("function.arguments"); args.Exists() && args.String() != "" {
				refCallID := st.FuncCallIDs[key]
				if refCallID == "" {
					refCallID = newCallID
				}
				if refCallID != "" {
					outputIndex := st.FuncOutputIx[key]
					ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
					ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
					ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
					ad, _ = sjson.SetBytes(ad, "output_index", outputIndex)
					ad, _ = sjson.SetBytes(ad, "delta", args.String())
					out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
				}
				st.FuncArgsBuf[key].WriteString(args.String())
			}
			return true
		})
	}
}

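The hunks above switch per-tool-call state from the bare choice index to the composite `toolStateKey(idx, toolIndex)` string, so parallel tool calls streamed under the same choice no longer overwrite each other's name, call_id, and argument buffers. A minimal sketch of that composite key:

```go
package main

import "fmt"

// toolStateKey mirrors the converter's key: state is tracked per
// (choice output index, tool_calls index) pair rather than per choice index.
func toolStateKey(outputIndex, toolIndex int) string {
	return fmt.Sprintf("%d:%d", outputIndex, toolIndex)
}

func main() {
	callIDs := map[string]string{}
	// Two parallel tool calls arrive in the same streamed choice (index 0);
	// under int keys both would have landed on key 0 and clobbered each other.
	callIDs[toolStateKey(0, 0)] = "call_1"
	callIDs[toolStateKey(0, 1)] = "call_2"
	fmt.Println(len(callIDs)) // two distinct entries survive
}
```

The string keys carry no useful ordering of their own, which is why the done-event loop below sorts them by their allocated `FuncOutputIx` value instead of sorting the keys directly.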
@@ -360,15 +384,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
for i := range st.MsgItemAdded {
|
||||
idxs = append(idxs, i)
|
||||
}
|
||||
for i := 0; i < len(idxs); i++ {
|
||||
for j := i + 1; j < len(idxs); j++ {
|
||||
if idxs[j] < idxs[i] {
|
||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
sort.Slice(idxs, func(i, j int) bool { return st.MsgOutputIx[idxs[i]] < st.MsgOutputIx[idxs[j]] })
|
||||
for _, i := range idxs {
|
||||
if st.MsgItemAdded[i] && !st.MsgItemDone[i] {
|
||||
msgOutputIndex := st.MsgOutputIx[i]
|
||||
fullText := ""
|
||||
if b := st.MsgTextBuf[i]; b != nil {
|
||||
fullText = b.String()
|
||||
@@ -376,7 +395,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
done, _ = sjson.SetBytes(done, "output_index", i)
|
||||
done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
|
||||
done, _ = sjson.SetBytes(done, "content_index", 0)
|
||||
done, _ = sjson.SetBytes(done, "text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_text.done", done))
|
||||
@@ -384,14 +403,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", i)
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
|
||||
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
|
||||
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
||||
out = append(out, emitRespEvent("response.content_part.done", partDone))
|
||||
|
||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||
@@ -407,43 +426,42 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
|
||||
// Emit function call done events for any active function calls
|
||||
if len(st.FuncCallIDs) > 0 {
|
||||
idxs := make([]int, 0, len(st.FuncCallIDs))
|
||||
for i := range st.FuncCallIDs {
|
||||
idxs = append(idxs, i)
|
||||
keys := make([]string, 0, len(st.FuncCallIDs))
|
||||
for key := range st.FuncCallIDs {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
for i := 0; i < len(idxs); i++ {
|
||||
for j := i + 1; j < len(idxs); j++ {
|
||||
if idxs[j] < idxs[i] {
|
||||
idxs[i], idxs[j] = idxs[j], idxs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, i := range idxs {
|
||||
callID := st.FuncCallIDs[i]
|
||||
if callID == "" || st.FuncItemDone[i] {
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
left := st.FuncOutputIx[keys[i]]
|
||||
right := st.FuncOutputIx[keys[j]]
|
||||
return left < right || (left == right && keys[i] < keys[j])
|
||||
})
|
||||
for _, key := range keys {
|
||||
callID := st.FuncCallIDs[key]
|
||||
if callID == "" || st.FuncItemDone[key] {
|
||||
continue
|
||||
}
|
||||
+			outputIndex := st.FuncOutputIx[key]
 			args := "{}"
-			if b := st.FuncArgsBuf[i]; b != nil && b.Len() > 0 {
+			if b := st.FuncArgsBuf[key]; b != nil && b.Len() > 0 {
 				args = b.String()
 			}
 			fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
 			fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
 			fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", callID))
-			fcDone, _ = sjson.SetBytes(fcDone, "output_index", i)
+			fcDone, _ = sjson.SetBytes(fcDone, "output_index", outputIndex)
 			fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
 			out = append(out, emitRespEvent("response.function_call_arguments.done", fcDone))

 			itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
 			itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
-			itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
+			itemDone, _ = sjson.SetBytes(itemDone, "output_index", outputIndex)
 			itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", callID))
 			itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
 			itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", callID)
-			itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[i])
+			itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[key])
 			out = append(out, emitRespEvent("response.output_item.done", itemDone))
-			st.FuncItemDone[i] = true
-			st.FuncArgsDone[i] = true
+			st.FuncItemDone[key] = true
+			st.FuncArgsDone[key] = true
 		}
 	}
 	completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
@@ -516,28 +534,21 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
 	}
 	// Build response.output using aggregated buffers
 	outputsWrapper := []byte(`{"arr":[]}`)
+	type completedOutputItem struct {
+		index int
+		raw   []byte
+	}
+	outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
 	if len(st.Reasonings) > 0 {
 		for _, r := range st.Reasonings {
 			item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
 			item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
 			item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
-			outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
+			outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
 		}
 	}
 	// Append message items in ascending index order
 	if len(st.MsgItemAdded) > 0 {
 		midxs := make([]int, 0, len(st.MsgItemAdded))
 		for i := range st.MsgItemAdded {
 			midxs = append(midxs, i)
 		}
 		for i := 0; i < len(midxs); i++ {
 			for j := i + 1; j < len(midxs); j++ {
 				if midxs[j] < midxs[i] {
 					midxs[i], midxs[j] = midxs[j], midxs[i]
 				}
 			}
 		}
 		for _, i := range midxs {
 			txt := ""
 			if b := st.MsgTextBuf[i]; b != nil {
 				txt = b.String()
@@ -545,37 +556,29 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
 			item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
 			item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
 			item, _ = sjson.SetBytes(item, "content.0.text", txt)
-			outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
+			outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
 		}
 	}
 	if len(st.FuncArgsBuf) > 0 {
-		idxs := make([]int, 0, len(st.FuncArgsBuf))
-		for i := range st.FuncArgsBuf {
-			idxs = append(idxs, i)
-		}
-		// small-N sort without extra imports
-		for i := 0; i < len(idxs); i++ {
-			for j := i + 1; j < len(idxs); j++ {
-				if idxs[j] < idxs[i] {
-					idxs[i], idxs[j] = idxs[j], idxs[i]
-				}
-			}
-		}
-		for _, i := range idxs {
+		for key := range st.FuncArgsBuf {
 			args := ""
-			if b := st.FuncArgsBuf[i]; b != nil {
+			if b := st.FuncArgsBuf[key]; b != nil {
 				args = b.String()
 			}
-			callID := st.FuncCallIDs[i]
-			name := st.FuncNames[i]
+			callID := st.FuncCallIDs[key]
+			name := st.FuncNames[key]
 			item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
 			item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
 			item, _ = sjson.SetBytes(item, "arguments", args)
 			item, _ = sjson.SetBytes(item, "call_id", callID)
 			item, _ = sjson.SetBytes(item, "name", name)
-			outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
+			outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
 		}
 	}
+	sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
+	for _, item := range outputItems {
+		outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
+	}
 	if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
 		completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
 	}
|
||||
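The `completedOutputItem` aggregation replaces the hand-rolled appends with one `sort.Slice` over the recorded stream indexes, so `response.completed` lists items in announcement order. A standalone sketch (struct shape taken from the diff; the sample data is made up):

```go
package main

import (
	"fmt"
	"sort"
)

// completedOutputItem pairs an item's streamed output_index with its
// pre-rendered JSON, as in the converter's aggregation step.
type completedOutputItem struct {
	index int
	raw   []byte
}

// orderOutput sorts aggregated items ascending by output_index so the
// final response.output matches the order events were emitted in.
func orderOutput(items []completedOutputItem) []completedOutputItem {
	sort.Slice(items, func(i, j int) bool { return items[i].index < items[j].index })
	return items
}

func main() {
	items := []completedOutputItem{
		{index: 2, raw: []byte(`{"type":"function_call"}`)},
		{index: 0, raw: []byte(`{"type":"message"}`)},
		{index: 1, raw: []byte(`{"type":"function_call"}`)},
	}
	for _, it := range orderOutput(items) {
		fmt.Println(it.index, string(it.raw))
	}
}
```

This is why the old "small-N sort without extra imports" bubble sorts could be deleted: a single sort over the collected items replaces per-bucket ordering.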

@@ -0,0 +1,305 @@
package responses

import (
	"context"
	"strings"
	"testing"

	"github.com/tidwall/gjson"
)

func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Result) {
	t.Helper()

	lines := strings.Split(string(chunk), "\n")
	if len(lines) < 2 {
		t.Fatalf("unexpected SSE chunk: %q", chunk)
	}

	event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
	dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
	if !gjson.Valid(dataLine) {
		t.Fatalf("invalid SSE data JSON: %q", dataLine)
	}
	return event, gjson.Parse(dataLine)
}

func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) {
	in := []string{
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\",\"limit\":400,\"offset\":1}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
	}

	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)

	var param any
	var out [][]byte
	for _, line := range in {
		out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
	}

	addedNames := map[string]string{}
	doneArgs := map[string]string{}
	doneNames := map[string]string{}
	outputItems := map[string]gjson.Result{}

	for _, chunk := range out {
		ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
		switch ev {
		case "response.output_item.added":
			if data.Get("item.type").String() != "function_call" {
				continue
			}
			addedNames[data.Get("item.call_id").String()] = data.Get("item.name").String()
		case "response.output_item.done":
			if data.Get("item.type").String() != "function_call" {
				continue
			}
			callID := data.Get("item.call_id").String()
			doneArgs[callID] = data.Get("item.arguments").String()
			doneNames[callID] = data.Get("item.name").String()
		case "response.completed":
			output := data.Get("response.output")
			for _, item := range output.Array() {
				if item.Get("type").String() == "function_call" {
					outputItems[item.Get("call_id").String()] = item
				}
			}
		}
	}

	if len(addedNames) != 2 {
		t.Fatalf("expected 2 function_call added events, got %d", len(addedNames))
	}
	if len(doneArgs) != 2 {
		t.Fatalf("expected 2 function_call done events, got %d", len(doneArgs))
	}

	if addedNames["call_read"] != "read" {
		t.Fatalf("unexpected added name for call_read: %q", addedNames["call_read"])
	}
	if addedNames["call_glob"] != "glob" {
		t.Fatalf("unexpected added name for call_glob: %q", addedNames["call_glob"])
	}

	if !gjson.Valid(doneArgs["call_read"]) {
		t.Fatalf("invalid JSON args for call_read: %q", doneArgs["call_read"])
	}
	if !gjson.Valid(doneArgs["call_glob"]) {
		t.Fatalf("invalid JSON args for call_glob: %q", doneArgs["call_glob"])
	}
	if strings.Contains(doneArgs["call_read"], "}{") {
		t.Fatalf("call_read args were concatenated: %q", doneArgs["call_read"])
	}
	if strings.Contains(doneArgs["call_glob"], "}{") {
		t.Fatalf("call_glob args were concatenated: %q", doneArgs["call_glob"])
	}

	if doneNames["call_read"] != "read" {
		t.Fatalf("unexpected done name for call_read: %q", doneNames["call_read"])
	}
	if doneNames["call_glob"] != "glob" {
		t.Fatalf("unexpected done name for call_glob: %q", doneNames["call_glob"])
	}

	if got := gjson.Get(doneArgs["call_read"], "filePath").String(); got != `C:\repo` {
		t.Fatalf("unexpected filePath for call_read: %q", got)
	}
	if got := gjson.Get(doneArgs["call_glob"], "path").String(); got != `C:\repo` {
		t.Fatalf("unexpected path for call_glob: %q", got)
	}
	if got := gjson.Get(doneArgs["call_glob"], "pattern").String(); got != "*.{yml,yaml}" {
		t.Fatalf("unexpected pattern for call_glob: %q", got)
	}

	if len(outputItems) != 2 {
		t.Fatalf("expected 2 function_call items in response.output, got %d", len(outputItems))
	}
	if outputItems["call_read"].Get("name").String() != "read" {
		t.Fatalf("unexpected response.output name for call_read: %q", outputItems["call_read"].Get("name").String())
	}
	if outputItems["call_glob"].Get("name").String() != "glob" {
		t.Fatalf("unexpected response.output name for call_glob: %q", outputItems["call_glob"].Get("name").String())
	}
}

func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCallsUseDistinctOutputIndexes(t *testing.T) {
	in := []string{
		`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
	}

	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)

	var param any
	var out [][]byte
	for _, line := range in {
		out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
	}

	type fcEvent struct {
		outputIndex int64
		name        string
		arguments   string
	}

	added := map[string]fcEvent{}
	done := map[string]fcEvent{}

	for _, chunk := range out {
		ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
		switch ev {
		case "response.output_item.added":
			if data.Get("item.type").String() != "function_call" {
				continue
			}
			callID := data.Get("item.call_id").String()
			added[callID] = fcEvent{
				outputIndex: data.Get("output_index").Int(),
				name:        data.Get("item.name").String(),
			}
		case "response.output_item.done":
			if data.Get("item.type").String() != "function_call" {
				continue
			}
			callID := data.Get("item.call_id").String()
			done[callID] = fcEvent{
				outputIndex: data.Get("output_index").Int(),
				name:        data.Get("item.name").String(),
				arguments:   data.Get("item.arguments").String(),
			}
		}
	}

	if len(added) != 2 {
		t.Fatalf("expected 2 function_call added events, got %d", len(added))
	}
	if len(done) != 2 {
		t.Fatalf("expected 2 function_call done events, got %d", len(done))
	}

	if added["call_choice0"].name != "glob" {
		t.Fatalf("unexpected added name for call_choice0: %q", added["call_choice0"].name)
	}
	if added["call_choice1"].name != "read" {
		t.Fatalf("unexpected added name for call_choice1: %q", added["call_choice1"].name)
	}
	if added["call_choice0"].outputIndex == added["call_choice1"].outputIndex {
		t.Fatalf("expected distinct output indexes for different choices, both got %d", added["call_choice0"].outputIndex)
	}

	if !gjson.Valid(done["call_choice0"].arguments) {
		t.Fatalf("invalid JSON args for call_choice0: %q", done["call_choice0"].arguments)
	}
	if !gjson.Valid(done["call_choice1"].arguments) {
		t.Fatalf("invalid JSON args for call_choice1: %q", done["call_choice1"].arguments)
	}
	if done["call_choice0"].outputIndex == done["call_choice1"].outputIndex {
		t.Fatalf("expected distinct done output indexes for different choices, both got %d", done["call_choice0"].outputIndex)
	}
	if done["call_choice0"].name != "glob" {
		t.Fatalf("unexpected done name for call_choice0: %q", done["call_choice0"].name)
	}
	if done["call_choice1"].name != "read" {
		t.Fatalf("unexpected done name for call_choice1: %q", done["call_choice1"].name)
	}
}

func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndToolUseDistinctOutputIndexes(t *testing.T) {
	in := []string{
		`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
	}

	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)

	var param any
	var out [][]byte
	for _, line := range in {
		out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
	}

	var messageOutputIndex int64 = -1
	var toolOutputIndex int64 = -1

	for _, chunk := range out {
		ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
		if ev != "response.output_item.added" {
			continue
		}
		switch data.Get("item.type").String() {
		case "message":
			if data.Get("item.id").String() == "msg_resp_mixed_0" {
				messageOutputIndex = data.Get("output_index").Int()
			}
		case "function_call":
			if data.Get("item.call_id").String() == "call_choice1" {
				toolOutputIndex = data.Get("output_index").Int()
			}
		}
	}

	if messageOutputIndex < 0 {
		t.Fatal("did not find message output index")
	}
	if toolOutputIndex < 0 {
		t.Fatal("did not find tool output index")
	}
	if messageOutputIndex == toolOutputIndex {
		t.Fatalf("expected distinct output indexes for message and tool call, both got %d", messageOutputIndex)
	}
}

func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneAndCompletedOutputStayAscending(t *testing.T) {
	in := []string{
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
	}

	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)

	var param any
	var out [][]byte
	for _, line := range in {
		out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
	}

	var doneIndexes []int64
	var completedOrder []string

	for _, chunk := range out {
		ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
		switch ev {
		case "response.output_item.done":
			if data.Get("item.type").String() == "function_call" {
				doneIndexes = append(doneIndexes, data.Get("output_index").Int())
			}
		case "response.completed":
			for _, item := range data.Get("response.output").Array() {
				if item.Get("type").String() == "function_call" {
					completedOrder = append(completedOrder, item.Get("call_id").String())
				}
			}
		}
	}

	if len(doneIndexes) != 2 {
		t.Fatalf("expected 2 function_call done indexes, got %d", len(doneIndexes))
	}
	if doneIndexes[0] >= doneIndexes[1] {
		t.Fatalf("expected ascending done output indexes, got %v", doneIndexes)
	}
	if len(completedOrder) != 2 {
		t.Fatalf("expected 2 function_call items in completed output, got %d", len(completedOrder))
	}
	if completedOrder[0] != "call_glob" || completedOrder[1] != "call_read" {
		t.Fatalf("unexpected completed function_call order: %v", completedOrder)
	}
}
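The `parseOpenAIResponsesSSEEvent` helper above assumes each chunk carries exactly one `event:`/`data:` pair. That parsing step can be sketched standalone, without the `testing.T` plumbing (function name here is illustrative, not from the repository):

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// parseSSE splits a single SSE chunk of the form
// "event: <name>\ndata: <json>\n\n" into its event name and data line.
func parseSSE(chunk string) (event, data string, err error) {
	lines := strings.Split(chunk, "\n")
	if len(lines) < 2 {
		return "", "", errors.New("chunk too short")
	}
	event = strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
	data = strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
	return event, data, nil
}

func main() {
	ev, data, err := parseSSE("event: response.completed\ndata: {\"ok\":true}\n\n")
	if err != nil {
		panic(err)
	}
	fmt.Println(ev, data)
}
```

A stricter parser would also validate that the data line is JSON, which the test helper does with `gjson.Valid`.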
@@ -923,8 +923,10 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
 		auth.Index = existing.Index
 		auth.indexAssigned = existing.indexAssigned
 	}
-	if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
-		auth.ModelStates = existing.ModelStates
-	}
+	if !existing.Disabled && existing.Status != StatusDisabled && !auth.Disabled && auth.Status != StatusDisabled {
+		if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
+			auth.ModelStates = existing.ModelStates
+		}
+	}
 	auth.EnsureIndex()
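The `Update` change above only lets a refreshed auth inherit stored `ModelStates` when neither side is disabled. A reduced sketch of that merge rule (field names follow the diff; the `auth` struct and `mergeModelStates` helper here are stand-ins, not the manager's real types):

```go
package main

import "fmt"

// auth is a stand-in for the manager's Auth record; only the fields the
// merge rule touches are modeled.
type auth struct {
	Disabled    bool
	Status      string
	ModelStates map[string]int
}

// mergeModelStates mirrors the guarded inherit: per-model cooldown state
// is only carried over between two enabled auths, so a disabled entry
// cannot leak stale backoff levels into a fresh one.
func mergeModelStates(existing, incoming *auth) {
	if !existing.Disabled && existing.Status != "disabled" &&
		!incoming.Disabled && incoming.Status != "disabled" {
		if len(incoming.ModelStates) == 0 && len(existing.ModelStates) > 0 {
			incoming.ModelStates = existing.ModelStates
		}
	}
}

func main() {
	old := &auth{Disabled: true, Status: "disabled", ModelStates: map[string]int{"stale-model": 5}}
	fresh := &auth{}
	mergeModelStates(old, fresh)
	fmt.Println(len(fresh.ModelStates))
}
```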
@@ -1732,77 +1734,79 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
 		}
 	} else {
 		if result.Model != "" {
-			state := ensureModelState(auth, result.Model)
-			state.Unavailable = true
-			state.Status = StatusError
-			state.UpdatedAt = now
-			if result.Error != nil {
-				state.LastError = cloneError(result.Error)
-				state.StatusMessage = result.Error.Message
-				auth.LastError = cloneError(result.Error)
-				auth.StatusMessage = result.Error.Message
-			}
-
-			statusCode := statusCodeFromResult(result.Error)
-			if isModelSupportResultError(result.Error) {
-				next := now.Add(12 * time.Hour)
-				state.NextRetryAfter = next
-				suspendReason = "model_not_supported"
-				shouldSuspendModel = true
-			} else {
-				switch statusCode {
-				case 401:
-					next := now.Add(30 * time.Minute)
-					state.NextRetryAfter = next
-					suspendReason = "unauthorized"
-					shouldSuspendModel = true
-				case 402, 403:
-					next := now.Add(30 * time.Minute)
-					state.NextRetryAfter = next
-					suspendReason = "payment_required"
-					shouldSuspendModel = true
-				case 404:
-					next := now.Add(12 * time.Hour)
-					state.NextRetryAfter = next
-					suspendReason = "not_found"
-					shouldSuspendModel = true
-				case 429:
-					var next time.Time
-					backoffLevel := state.Quota.BackoffLevel
-					if result.RetryAfter != nil {
-						next = now.Add(*result.RetryAfter)
-					} else {
-						cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
-						if cooldown > 0 {
-							next = now.Add(cooldown)
-						}
-						backoffLevel = nextLevel
-					}
-					state.NextRetryAfter = next
-					state.Quota = QuotaState{
-						Exceeded:      true,
-						Reason:        "quota",
-						NextRecoverAt: next,
-						BackoffLevel:  backoffLevel,
-					}
-					suspendReason = "quota"
-					shouldSuspendModel = true
-					setModelQuota = true
-				case 408, 500, 502, 503, 504:
-					if quotaCooldownDisabledForAuth(auth) {
-						state.NextRetryAfter = time.Time{}
-					} else {
-						next := now.Add(1 * time.Minute)
-						state.NextRetryAfter = next
-					}
-				default:
-					state.NextRetryAfter = time.Time{}
-				}
-			}
-
-			auth.Status = StatusError
-			auth.UpdatedAt = now
-			updateAggregatedAvailability(auth, now)
+			if !isRequestScopedNotFoundResultError(result.Error) {
+				state := ensureModelState(auth, result.Model)
+				state.Unavailable = true
+				state.Status = StatusError
+				state.UpdatedAt = now
+				if result.Error != nil {
+					state.LastError = cloneError(result.Error)
+					state.StatusMessage = result.Error.Message
+					auth.LastError = cloneError(result.Error)
+					auth.StatusMessage = result.Error.Message
+				}
+
+				statusCode := statusCodeFromResult(result.Error)
+				if isModelSupportResultError(result.Error) {
+					next := now.Add(12 * time.Hour)
+					state.NextRetryAfter = next
+					suspendReason = "model_not_supported"
+					shouldSuspendModel = true
+				} else {
+					switch statusCode {
+					case 401:
+						next := now.Add(30 * time.Minute)
+						state.NextRetryAfter = next
+						suspendReason = "unauthorized"
+						shouldSuspendModel = true
+					case 402, 403:
+						next := now.Add(30 * time.Minute)
+						state.NextRetryAfter = next
+						suspendReason = "payment_required"
+						shouldSuspendModel = true
+					case 404:
+						next := now.Add(12 * time.Hour)
+						state.NextRetryAfter = next
+						suspendReason = "not_found"
+						shouldSuspendModel = true
+					case 429:
+						var next time.Time
+						backoffLevel := state.Quota.BackoffLevel
+						if result.RetryAfter != nil {
+							next = now.Add(*result.RetryAfter)
+						} else {
+							cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
+							if cooldown > 0 {
+								next = now.Add(cooldown)
+							}
+							backoffLevel = nextLevel
+						}
+						state.NextRetryAfter = next
+						state.Quota = QuotaState{
+							Exceeded:      true,
+							Reason:        "quota",
+							NextRecoverAt: next,
+							BackoffLevel:  backoffLevel,
+						}
+						suspendReason = "quota"
+						shouldSuspendModel = true
+						setModelQuota = true
+					case 408, 500, 502, 503, 504:
+						if quotaCooldownDisabledForAuth(auth) {
+							state.NextRetryAfter = time.Time{}
+						} else {
+							next := now.Add(1 * time.Minute)
+							state.NextRetryAfter = next
+						}
+					default:
+						state.NextRetryAfter = time.Time{}
+					}
+				}
+
+				auth.Status = StatusError
+				auth.UpdatedAt = now
+				updateAggregatedAvailability(auth, now)
+			}
 		} else {
 			applyAuthFailureState(auth, result.Error, result.RetryAfter, now)
 		}
@@ -2054,11 +2058,29 @@ func isModelSupportResultError(err *Error) bool {
 	return isModelSupportErrorMessage(err.Message)
 }

+func isRequestScopedNotFoundMessage(message string) bool {
+	if message == "" {
+		return false
+	}
+	lower := strings.ToLower(message)
+	return strings.Contains(lower, "item with id") &&
+		strings.Contains(lower, "not found") &&
+		strings.Contains(lower, "items are not persisted when `store` is set to false")
+}
+
+func isRequestScopedNotFoundResultError(err *Error) bool {
+	if err == nil || statusCodeFromResult(err) != http.StatusNotFound {
+		return false
+	}
+	return isRequestScopedNotFoundMessage(err.Message)
+}
+
 // isRequestInvalidError returns true if the error represents a client request
 // error that should not be retried. Specifically, it treats 400 responses with
-// "invalid_request_error" and all 422 responses as request-shape failures,
-// where switching auths or pooled upstream models will not help. Model-support
-// errors are excluded so routing can fall through to another auth or upstream.
+// "invalid_request_error", request-scoped 404 item misses caused by `store=false`,
+// and all 422 responses as request-shape failures, where switching auths or
+// pooled upstream models will not help. Model-support errors are excluded so
+// routing can fall through to another auth or upstream.
 func isRequestInvalidError(err error) bool {
 	if err == nil {
 		return false
@@ -2070,6 +2092,8 @@ func isRequestInvalidError(err error) bool {
 	switch status {
 	case http.StatusBadRequest:
 		return strings.Contains(err.Error(), "invalid_request_error")
+	case http.StatusNotFound:
+		return isRequestScopedNotFoundMessage(err.Error())
 	case http.StatusUnprocessableEntity:
 		return true
 	default:
@@ -2081,6 +2105,9 @@ func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Durati
 	if auth == nil {
 		return
 	}
+	if isRequestScopedNotFoundResultError(resultErr) {
+		return
+	}
 	auth.Unavailable = true
 	auth.Status = StatusError
 	auth.UpdatedAt = now

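`isRequestScopedNotFoundMessage` above matches the upstream wording case-insensitively. A standalone copy of the predicate plus a usage check (the sample message is shortened from the test constant):

```go
package main

import (
	"fmt"
	"strings"
)

// isRequestScopedNotFound reproduces the message check: a 404 about a
// non-persisted response item (store=false) is a problem with this
// particular request, not with the auth or the model.
func isRequestScopedNotFound(message string) bool {
	if message == "" {
		return false
	}
	lower := strings.ToLower(message)
	return strings.Contains(lower, "item with id") &&
		strings.Contains(lower, "not found") &&
		strings.Contains(lower, "items are not persisted when `store` is set to false")
}

func main() {
	msg := "Item with id 'rs_abc' not found. Items are not persisted when `store` is set to false."
	fmt.Println(isRequestScopedNotFound(msg))
	fmt.Println(isRequestScopedNotFound("model not found"))
}
```

Matching on lowercased substrings keeps the check robust to casing changes upstream, at the cost of coupling to the exact error wording.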
@@ -12,6 +12,8 @@ import (
 	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
 )

+const requestScopedNotFoundMessage = "Item with id 'rs_0b5f3eb6f51f175c0169ca74e4a85881998539920821603a74' not found. Items are not persisted when `store` is set to false. Try again with `store` set to true, or remove this item from your input."
+
 func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testing.T) {
 	m := NewManager(nil, nil, nil)
 	m.SetRetryConfig(3, 30*time.Second, 0)
@@ -447,3 +449,114 @@ func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
 		t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", state.NextRetryAfter)
 	}
 }
+
+func TestManager_MarkResult_RequestScopedNotFoundDoesNotCooldownAuth(t *testing.T) {
+	m := NewManager(nil, nil, nil)
+
+	auth := &Auth{
+		ID:       "auth-1",
+		Provider: "openai",
+	}
+	if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
+		t.Fatalf("register auth: %v", errRegister)
+	}
+
+	model := "gpt-4.1"
+	m.MarkResult(context.Background(), Result{
+		AuthID:   auth.ID,
+		Provider: auth.Provider,
+		Model:    model,
+		Success:  false,
+		Error: &Error{
+			HTTPStatus: http.StatusNotFound,
+			Message:    requestScopedNotFoundMessage,
+		},
+	})
+
+	updated, ok := m.GetByID(auth.ID)
+	if !ok || updated == nil {
+		t.Fatalf("expected auth to be present")
+	}
+	if updated.Unavailable {
+		t.Fatalf("expected request-scoped 404 to keep auth available")
+	}
+	if !updated.NextRetryAfter.IsZero() {
+		t.Fatalf("expected request-scoped 404 to keep auth cooldown unset, got %v", updated.NextRetryAfter)
+	}
+	if state := updated.ModelStates[model]; state != nil {
+		t.Fatalf("expected request-scoped 404 to avoid model cooldown state, got %#v", state)
+	}
+}
+
+func TestManager_RequestScopedNotFoundStopsRetryWithoutSuspendingAuth(t *testing.T) {
+	m := NewManager(nil, nil, nil)
+	executor := &authFallbackExecutor{
+		id: "openai",
+		executeErrors: map[string]error{
+			"aa-bad-auth": &Error{
+				HTTPStatus: http.StatusNotFound,
+				Message:    requestScopedNotFoundMessage,
+			},
+		},
+	}
+	m.RegisterExecutor(executor)
+
+	model := "gpt-4.1"
+	badAuth := &Auth{ID: "aa-bad-auth", Provider: "openai"}
+	goodAuth := &Auth{ID: "bb-good-auth", Provider: "openai"}
+
+	reg := registry.GetGlobalRegistry()
+	reg.RegisterClient(badAuth.ID, "openai", []*registry.ModelInfo{{ID: model}})
+	reg.RegisterClient(goodAuth.ID, "openai", []*registry.ModelInfo{{ID: model}})
+	t.Cleanup(func() {
+		reg.UnregisterClient(badAuth.ID)
+		reg.UnregisterClient(goodAuth.ID)
+	})
+
+	if _, errRegister := m.Register(context.Background(), badAuth); errRegister != nil {
+		t.Fatalf("register bad auth: %v", errRegister)
+	}
+	if _, errRegister := m.Register(context.Background(), goodAuth); errRegister != nil {
+		t.Fatalf("register good auth: %v", errRegister)
+	}
+
+	_, errExecute := m.Execute(context.Background(), []string{"openai"}, cliproxyexecutor.Request{Model: model}, cliproxyexecutor.Options{})
+	if errExecute == nil {
+		t.Fatal("expected request-scoped not-found error")
+	}
+	errResult, ok := errExecute.(*Error)
+	if !ok {
+		t.Fatalf("expected *Error, got %T", errExecute)
+	}
+	if errResult.HTTPStatus != http.StatusNotFound {
+		t.Fatalf("status = %d, want %d", errResult.HTTPStatus, http.StatusNotFound)
+	}
+	if errResult.Message != requestScopedNotFoundMessage {
+		t.Fatalf("message = %q, want %q", errResult.Message, requestScopedNotFoundMessage)
+	}
+
+	got := executor.ExecuteCalls()
+	want := []string{badAuth.ID}
+	if len(got) != len(want) {
+		t.Fatalf("execute calls = %v, want %v", got, want)
+	}
+	for i := range want {
+		if got[i] != want[i] {
+			t.Fatalf("execute call %d auth = %q, want %q", i, got[i], want[i])
+		}
+	}
+
+	updatedBad, ok := m.GetByID(badAuth.ID)
+	if !ok || updatedBad == nil {
+		t.Fatalf("expected bad auth to remain registered")
+	}
+	if updatedBad.Unavailable {
+		t.Fatalf("expected request-scoped 404 to keep bad auth available")
+	}
+	if !updatedBad.NextRetryAfter.IsZero() {
+		t.Fatalf("expected request-scoped 404 to keep bad auth cooldown unset, got %v", updatedBad.NextRetryAfter)
+	}
+	if state := updatedBad.ModelStates[model]; state != nil {
+		t.Fatalf("expected request-scoped 404 to avoid bad auth model cooldown state, got %#v", state)
+	}
+}

@@ -47,3 +47,158 @@ func TestManager_Update_PreservesModelStates(t *testing.T) {
		t.Fatalf("expected BackoffLevel to be %d, got %d", backoffLevel, state.Quota.BackoffLevel)
	}
}

func TestManager_Update_DisabledExistingDoesNotInheritModelStates(t *testing.T) {
	m := NewManager(nil, nil, nil)

	// Register a disabled auth with existing ModelStates.
	if _, err := m.Register(context.Background(), &Auth{
		ID:       "auth-disabled",
		Provider: "claude",
		Disabled: true,
		Status:   StatusDisabled,
		ModelStates: map[string]*ModelState{
			"stale-model": {
				Quota: QuotaState{BackoffLevel: 5},
			},
		},
	}); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	// Update with empty ModelStates — should NOT inherit stale states.
	if _, err := m.Update(context.Background(), &Auth{
		ID:       "auth-disabled",
		Provider: "claude",
		Disabled: true,
		Status:   StatusDisabled,
	}); err != nil {
		t.Fatalf("update auth: %v", err)
	}

	updated, ok := m.GetByID("auth-disabled")
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	if len(updated.ModelStates) != 0 {
		t.Fatalf("expected disabled auth NOT to inherit ModelStates, got %d entries", len(updated.ModelStates))
	}
}

func TestManager_Update_ActiveToDisabledDoesNotInheritModelStates(t *testing.T) {
	m := NewManager(nil, nil, nil)

	// Register an active auth with ModelStates (simulates existing live auth).
	if _, err := m.Register(context.Background(), &Auth{
		ID:       "auth-a2d",
		Provider: "claude",
		Status:   StatusActive,
		ModelStates: map[string]*ModelState{
			"stale-model": {
				Quota: QuotaState{BackoffLevel: 9},
			},
		},
	}); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	// File watcher deletes config → synthesizes Disabled=true auth → Update.
	// Even though existing is active, incoming auth is disabled → skip inheritance.
	if _, err := m.Update(context.Background(), &Auth{
		ID:       "auth-a2d",
		Provider: "claude",
		Disabled: true,
		Status:   StatusDisabled,
	}); err != nil {
		t.Fatalf("update auth: %v", err)
	}

	updated, ok := m.GetByID("auth-a2d")
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	if len(updated.ModelStates) != 0 {
		t.Fatalf("expected active→disabled transition NOT to inherit ModelStates, got %d entries", len(updated.ModelStates))
	}
}

func TestManager_Update_DisabledToActiveDoesNotInheritStaleModelStates(t *testing.T) {
	m := NewManager(nil, nil, nil)

	// Register a disabled auth with stale ModelStates.
	if _, err := m.Register(context.Background(), &Auth{
		ID:       "auth-d2a",
		Provider: "claude",
		Disabled: true,
		Status:   StatusDisabled,
		ModelStates: map[string]*ModelState{
			"stale-model": {
				Quota: QuotaState{BackoffLevel: 4},
			},
		},
	}); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	// Re-enable: incoming auth is active, existing is disabled → skip inheritance.
	if _, err := m.Update(context.Background(), &Auth{
		ID:       "auth-d2a",
		Provider: "claude",
		Status:   StatusActive,
	}); err != nil {
		t.Fatalf("update auth: %v", err)
	}

	updated, ok := m.GetByID("auth-d2a")
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	if len(updated.ModelStates) != 0 {
		t.Fatalf("expected disabled→active transition NOT to inherit stale ModelStates, got %d entries", len(updated.ModelStates))
	}
}

func TestManager_Update_ActiveInheritsModelStates(t *testing.T) {
	m := NewManager(nil, nil, nil)

	model := "active-model"
	backoffLevel := 3

	// Register an active auth with ModelStates.
	if _, err := m.Register(context.Background(), &Auth{
		ID:       "auth-active",
		Provider: "claude",
		Status:   StatusActive,
		ModelStates: map[string]*ModelState{
			model: {
				Quota: QuotaState{BackoffLevel: backoffLevel},
			},
		},
	}); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	// Update with empty ModelStates — both sides active → SHOULD inherit.
	if _, err := m.Update(context.Background(), &Auth{
		ID:       "auth-active",
		Provider: "claude",
		Status:   StatusActive,
	}); err != nil {
		t.Fatalf("update auth: %v", err)
	}

	updated, ok := m.GetByID("auth-active")
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	if len(updated.ModelStates) == 0 {
		t.Fatalf("expected active auth to inherit ModelStates")
	}
	state := updated.ModelStates[model]
	if state == nil {
		t.Fatalf("expected model state to be present")
	}
	if state.Quota.BackoffLevel != backoffLevel {
		t.Fatalf("expected BackoffLevel to be %d, got %d", backoffLevel, state.Quota.BackoffLevel)
	}
}
@@ -293,12 +293,46 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
	}

	cursorKey := strings.Join(normalized, ",") + ":" + modelKey
-	start := 0
-	if len(normalized) > 0 {
-		start = s.mixedCursors[cursorKey] % len(normalized)
-	}
+	weights := make([]int, len(normalized))
+	segmentStarts := make([]int, len(normalized))
+	segmentEnds := make([]int, len(normalized))
+	totalWeight := 0
+	for providerIndex, shard := range candidateShards {
+		segmentStarts[providerIndex] = totalWeight
+		if shard != nil {
+			weights[providerIndex] = shard.readyCountAtPriorityLocked(false, bestPriority)
+		}
+		totalWeight += weights[providerIndex]
+		segmentEnds[providerIndex] = totalWeight
+	}
+	if totalWeight == 0 {
+		return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
+	}
+
+	startSlot := s.mixedCursors[cursorKey] % totalWeight
+	startProviderIndex := -1
+	for providerIndex := range normalized {
+		if weights[providerIndex] == 0 {
+			continue
+		}
+		if startSlot < segmentEnds[providerIndex] {
+			startProviderIndex = providerIndex
+			break
+		}
+	}
+	if startProviderIndex < 0 {
+		return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
+	}
+
+	slot := startSlot
	for offset := 0; offset < len(normalized); offset++ {
-		providerIndex := (start + offset) % len(normalized)
+		providerIndex := (startProviderIndex + offset) % len(normalized)
+		if weights[providerIndex] == 0 {
+			continue
+		}
+		if providerIndex != startProviderIndex {
+			slot = segmentStarts[providerIndex]
+		}
		providerKey := normalized[providerIndex]
		shard := candidateShards[providerIndex]
		if shard == nil {
@@ -308,7 +342,7 @@ func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model
		if picked == nil {
			continue
		}
-		s.mixedCursors[cursorKey] = providerIndex + 1
+		s.mixedCursors[cursorKey] = slot + 1
		return picked, providerKey, nil
	}
	return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
@@ -704,6 +738,20 @@ func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priorit
	return picked.auth
}

+func (m *modelScheduler) readyCountAtPriorityLocked(preferWebsocket bool, priority int) int {
+	if m == nil {
+		return 0
+	}
+	bucket := m.readyByPriority[priority]
+	if bucket == nil {
+		return 0
+	}
+	if preferWebsocket && len(bucket.ws.flat) > 0 {
+		return len(bucket.ws.flat)
+	}
+	return len(bucket.all.flat)
+}
+
// unavailableErrorLocked returns the correct unavailable or cooldown error for the shard.
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
	now := time.Now()
@@ -208,7 +208,7 @@ func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T)
	}
}

-func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
+func TestSchedulerPick_MixedProvidersUsesWeightedProviderRotationOverReadyCandidates(t *testing.T) {
	t.Parallel()

	scheduler := newSchedulerForTest(
@@ -218,8 +218,8 @@ func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *
		&Auth{ID: "claude-a", Provider: "claude"},
	)

-	wantProviders := []string{"gemini", "claude", "gemini", "claude"}
-	wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
+	wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
+	wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
	for index := range wantProviders {
		got, provider, errPick := scheduler.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
@@ -272,7 +272,7 @@ func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
	}
}

-func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
+func TestManager_PickNextMixed_UsesWeightedProviderRotationBeforeCredentialRotation(t *testing.T) {
	t.Parallel()

	manager := NewManager(nil, &RoundRobinSelector{}, nil)
@@ -288,8 +288,8 @@ func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *t
		t.Fatalf("Register(claude-a) error = %v", errRegister)
	}

-	wantProviders := []string{"gemini", "claude", "gemini", "claude"}
-	wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
+	wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
+	wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
	for index := range wantProviders {
		got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
		if errPick != nil {
@@ -399,8 +399,8 @@ func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
		t.Fatalf("Register(claude-a) error = %v", errRegister)
	}

-	wantProviders := []string{"gemini", "claude", "gemini", "claude"}
-	wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
+	wantProviders := []string{"gemini", "gemini", "claude", "gemini"}
+	wantIDs := []string{"gemini-a", "gemini-b", "claude-a", "gemini-a"}
	for index := range wantProviders {
		got, _, provider, errPick := manager.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
@@ -286,10 +286,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
	var err error
	if existing, ok := s.coreManager.GetByID(auth.ID); ok {
		auth.CreatedAt = existing.CreatedAt
-		auth.LastRefreshedAt = existing.LastRefreshedAt
-		auth.NextRefreshAfter = existing.NextRefreshAfter
-		if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
-			auth.ModelStates = existing.ModelStates
+		if !existing.Disabled && existing.Status != coreauth.StatusDisabled && !auth.Disabled && auth.Status != coreauth.StatusDisabled {
+			auth.LastRefreshedAt = existing.LastRefreshedAt
+			auth.NextRefreshAfter = existing.NextRefreshAfter
+			if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
+				auth.ModelStates = existing.ModelStates
+			}
		}
		op = "update"
		_, err = s.coreManager.Update(ctx, auth)
sdk/cliproxy/service_stale_state_test.go (new file, 85 lines)
@@ -0,0 +1,85 @@
package cliproxy

import (
	"context"
	"testing"
	"time"

	"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
	coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)

func TestServiceApplyCoreAuthAddOrUpdate_DeleteReAddDoesNotInheritStaleRuntimeState(t *testing.T) {
	service := &Service{
		cfg:         &config.Config{},
		coreManager: coreauth.NewManager(nil, nil, nil),
	}

	authID := "service-stale-state-auth"
	modelID := "stale-model"
	lastRefreshedAt := time.Date(2026, time.March, 1, 8, 0, 0, 0, time.UTC)
	nextRefreshAfter := lastRefreshedAt.Add(30 * time.Minute)

	t.Cleanup(func() {
		GlobalModelRegistry().UnregisterClient(authID)
	})

	service.applyCoreAuthAddOrUpdate(context.Background(), &coreauth.Auth{
		ID:               authID,
		Provider:         "claude",
		Status:           coreauth.StatusActive,
		LastRefreshedAt:  lastRefreshedAt,
		NextRefreshAfter: nextRefreshAfter,
		ModelStates: map[string]*coreauth.ModelState{
			modelID: {
				Quota: coreauth.QuotaState{BackoffLevel: 7},
			},
		},
	})

	service.applyCoreAuthRemoval(context.Background(), authID)

	disabled, ok := service.coreManager.GetByID(authID)
	if !ok || disabled == nil {
		t.Fatalf("expected disabled auth after removal")
	}
	if !disabled.Disabled || disabled.Status != coreauth.StatusDisabled {
		t.Fatalf("expected disabled auth after removal, got disabled=%v status=%v", disabled.Disabled, disabled.Status)
	}
	if disabled.LastRefreshedAt.IsZero() {
		t.Fatalf("expected disabled auth to still carry prior LastRefreshedAt for regression setup")
	}
	if disabled.NextRefreshAfter.IsZero() {
		t.Fatalf("expected disabled auth to still carry prior NextRefreshAfter for regression setup")
	}
	if len(disabled.ModelStates) == 0 {
		t.Fatalf("expected disabled auth to still carry prior ModelStates for regression setup")
	}

	service.applyCoreAuthAddOrUpdate(context.Background(), &coreauth.Auth{
		ID:       authID,
		Provider: "claude",
		Status:   coreauth.StatusActive,
	})

	updated, ok := service.coreManager.GetByID(authID)
	if !ok || updated == nil {
		t.Fatalf("expected re-added auth to be present")
	}
	if updated.Disabled {
		t.Fatalf("expected re-added auth to be active")
	}
	if !updated.LastRefreshedAt.IsZero() {
		t.Fatalf("expected LastRefreshedAt to reset on delete -> re-add, got %v", updated.LastRefreshedAt)
	}
	if !updated.NextRefreshAfter.IsZero() {
		t.Fatalf("expected NextRefreshAfter to reset on delete -> re-add, got %v", updated.NextRefreshAfter)
	}
	if len(updated.ModelStates) != 0 {
		t.Fatalf("expected ModelStates to reset on delete -> re-add, got %d entries", len(updated.ModelStates))
	}
	if models := registry.GetGlobalRegistry().GetModelsForClient(authID); len(models) == 0 {
		t.Fatalf("expected re-added auth to re-register models in global registry")
	}
}
@@ -68,14 +68,18 @@ func Parse(raw string) (Setting, error) {
	}
}

+func cloneDefaultTransport() *http.Transport {
+	if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
+		return transport.Clone()
+	}
+	return &http.Transport{}
+}
+
// NewDirectTransport returns a transport that bypasses environment proxies.
func NewDirectTransport() *http.Transport {
-	if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
-		clone := transport.Clone()
-		clone.Proxy = nil
-		return clone
-	}
-	return &http.Transport{Proxy: nil}
+	clone := cloneDefaultTransport()
+	clone.Proxy = nil
+	return clone
}

// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting.
@@ -102,14 +106,16 @@ func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
		if errSOCKS5 != nil {
			return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
		}
-		return &http.Transport{
-			Proxy: nil,
-			DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
-				return dialer.Dial(network, addr)
-			},
-		}, setting.Mode, nil
+		transport := cloneDefaultTransport()
+		transport.Proxy = nil
+		transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
+			return dialer.Dial(network, addr)
+		}
+		return transport, setting.Mode, nil
	}
-	return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil
+	transport := cloneDefaultTransport()
+	transport.Proxy = http.ProxyURL(setting.URL)
+	return transport, setting.Mode, nil
default:
	return nil, setting.Mode, nil
}
@@ -5,6 +5,16 @@ import (
	"testing"
)

+func mustDefaultTransport(t *testing.T) *http.Transport {
+	t.Helper()
+
+	transport, ok := http.DefaultTransport.(*http.Transport)
+	if !ok || transport == nil {
+		t.Fatal("http.DefaultTransport is not an *http.Transport")
+	}
+	return transport
+}
+
func TestParse(t *testing.T) {
	t.Parallel()

@@ -86,4 +96,44 @@ func TestBuildHTTPTransportHTTPProxy(t *testing.T) {
	if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" {
		t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL)
	}
+
+	defaultTransport := mustDefaultTransport(t)
+	if transport.ForceAttemptHTTP2 != defaultTransport.ForceAttemptHTTP2 {
+		t.Fatalf("ForceAttemptHTTP2 = %v, want %v", transport.ForceAttemptHTTP2, defaultTransport.ForceAttemptHTTP2)
+	}
+	if transport.IdleConnTimeout != defaultTransport.IdleConnTimeout {
+		t.Fatalf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaultTransport.IdleConnTimeout)
+	}
+	if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout {
+		t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
+	}
}

+func TestBuildHTTPTransportSOCKS5ProxyInheritsDefaultTransportSettings(t *testing.T) {
+	t.Parallel()
+
+	transport, mode, errBuild := BuildHTTPTransport("socks5://proxy.example.com:1080")
+	if errBuild != nil {
+		t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
+	}
+	if mode != ModeProxy {
+		t.Fatalf("mode = %d, want %d", mode, ModeProxy)
+	}
+	if transport == nil {
+		t.Fatal("expected transport, got nil")
+	}
+	if transport.Proxy != nil {
+		t.Fatal("expected SOCKS5 transport to bypass http proxy function")
+	}
+
+	defaultTransport := mustDefaultTransport(t)
+	if transport.ForceAttemptHTTP2 != defaultTransport.ForceAttemptHTTP2 {
+		t.Fatalf("ForceAttemptHTTP2 = %v, want %v", transport.ForceAttemptHTTP2, defaultTransport.ForceAttemptHTTP2)
+	}
+	if transport.IdleConnTimeout != defaultTransport.IdleConnTimeout {
+		t.Fatalf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaultTransport.IdleConnTimeout)
+	}
+	if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout {
+		t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
+	}
+}