Compare commits

..

50 Commits

Author SHA1 Message Date
airkjw a9472dfdee docs: add v7.1.20 update history 2026-05-24 14:34:12 +09:00
airkjw 087045a5f1 docs: restore deployment guides after v7.1.20 reset 2026-05-24 14:33:42 +09:00
Luis Pater aaec9194d5 feat(models): add Grok Build 0.1 to registry
- Registered `grok-build-0.1` model with enhanced context length and agentic engineering support.
- Supports dynamic thinking levels for improved software workflows.
2026-05-23 22:49:36 +08:00
Luis Pater 33f4904b25 fix(translator): handle system role as developer in Claude request conversion
- Updated `ConvertClaudeRequestToGemini` logic to treat `system` role as `developer`.
- Added unit test case to validate the behavior.

Closes: #3510
2026-05-22 12:04:27 +08:00
Luis Pater cecd39317d Merge pull request #3498 from router-for-me/test
fix(auth): update import paths to v7 for registry and executor
2026-05-21 10:50:58 +08:00
hkfires 3c62a9a9b0 fix(auth): update import paths to v7 for registry and executor 2026-05-21 10:00:22 +08:00
Luis Pater 21fad9dbb4 Merge pull request #3477 from router-for-me/cluster
Add cluster-specific docker-compose configuration for CLIProxyAPI
2026-05-21 03:00:50 +08:00
Luis Pater 48a1c88115 Merge pull request #3476 from sususu98/fix/codex-context-length-stream-errors-dev
fix codex context length stream errors
2026-05-21 02:53:54 +08:00
Luis Pater 8b9ecffc2f Merge pull request #3382 from sususu98/dev
fix: scope antigravity credits fallback gate
2026-05-21 02:52:49 +08:00
Luis Pater 42e9605871 Merge pull request #3254 from sususu98/fix/antigravity-project-id-onboard
fix: require antigravity project id
2026-05-21 02:52:32 +08:00
Luis Pater a726e37394 feat(redis): enhance Redis protocol handling with subscription and queue operations
- Added support for advanced RESP commands (`AUTH`, `SUBSCRIBE`, `RPOP`, `LPOP`) with extended functionality.
- Implemented queue operations for usage events via `RPOP` and `LPOP` commands.
- Introduced subscription handling with new Pub/Sub message features and error handling improvements.
- Updated Redis connection logic to enforce authentication requirements and validate inputs.
- Expanded related unit tests to cover new scenarios and edge cases.
2026-05-20 17:20:03 +08:00
Luis Pater f1ee883cd3 Merge pull request #3484 from yavon007/main
Add reasoning_effort to usage event payloads
2026-05-20 12:34:40 +08:00
Luis Pater 1c632d151d fix(translator): skip empty text parts in Claude request conversion
- Updated `ConvertClaudeRequestToGemini` to ignore empty `text` entries during processing.
- Added unit tests to ensure empty `text` parts are skipped correctly.

Closes: #3485
2026-05-20 11:59:31 +08:00
Luis Pater 0ec07e57dd feat(models): add Gemini 3.5 Flash to registry with enhanced thinking capabilities
- Registered `gemini-3.5-flash` model with dynamic thinking levels and extended token limits.
- Supports multiple generation methods, including cached and batch content creation.
2026-05-20 10:53:31 +08:00
Luis Pater fdffe49974 feat(models): register Gemini 3.5 Flash with dynamic thinking levels
- Added new model `gemini-3.5-flash` to the registry with enhanced intelligence and speed capabilities.
- Supports extended thinking levels (`minimal`, `low`, `medium`, `high`) and dynamic adjustments.
- Expanded generation methods, including content creation and token counting.
2026-05-20 10:50:02 +08:00
Luis Pater de0394917a feat(models): expand supported reasoning levels for Codex
- Added new reasoning levels: `none`, `minimal`, and `unsupported` to Codex model configurations.
- Introduced metadata sanitization and normalization for reasoning levels in API response.
- Extended unit tests to cover reasoning levels validation and metadata sanitation logic.
2026-05-20 03:21:46 +08:00
Luis Pater ea25949479 feat(models): add Gemini 3.5 Flash models to registry
- Registered new models: `gemini-3-flash-agent` and `gemini-3.5-flash-low` with detailed specifications.
- Includes support for dynamic thinking levels and extended context capabilities.
2026-05-20 02:17:49 +08:00
Luis Pater 99fa530967 test: remove unused Redis protocol tests and helpers
- Removed obsolete Redis protocol test cases and helper functions that were no longer relevant due to recent architecture changes.
- Streamlined remaining test files to align with updated Redis handling and connection management logic.
2026-05-19 23:12:57 +08:00
Luis Pater b9589e8ed6 Merge pull request #3482 from 9ycrooked/patch-1
Add Codex Switch tool to README
2026-05-19 22:59:28 +08:00
yavon007 0de0ad0d36 Add reasoning effort to usage events 2026-05-19 22:10:48 +08:00
Xinyao Xu 5ef7693933 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-05-19 22:05:52 +08:00
Xinyao Xu 7f68fa2414 Add Codex Switch tool to README
Added a new section for Codex Switch tool with details.
2026-05-19 18:00:28 +08:00
Luis Pater bb5ac40a67 feat(client): add timeout handling for Redis operations and subscription failover
- Introduced `homeRedisOperationTimeout` and `homeSubscriptionReceiveTimeout` constants for configurable timeouts.
- Enhanced Redis connection options with operation timeout settings and failover mechanisms.
- Implemented subscription failover logic on heartbeat timeouts to improve resilience.
- Updated message handling to support additional Redis event types, including Pong and Subscription.
2026-05-19 16:44:42 +08:00
hkfires 7efc1629ba feat(docker): add cluster-specific docker-compose configuration for CLIProxyAPI 2026-05-19 16:24:34 +08:00
Luis Pater 67f22514ed style(docs): improve sponsor section clarity in README files
- Updated text formatting with bold emphasis for consistent branding.
- Refined wording for VisionCoder's promotion details in Chinese, Japanese, and English README.
2026-05-19 16:11:48 +08:00
sususu98 ad868308c0 fix codex context length stream errors 2026-05-19 16:05:40 +08:00
Luis Pater bbe30f53b5 feat(server): enhance Home certificate handling with CA fingerprint verification
- Added support for `ClusterID`, `CAFingerprint`, and `EnrollmentSecret` in Home JWT claims.
- Implemented CA fingerprint normalization and verification for PEM and file-based certificates.
- Improved certificate request validation and error handling.
- Updated server-side logic to include `EnrollmentSecret` in certificate requests.
2026-05-19 10:25:57 +08:00
Luis Pater feebe6c7f2 feat(api): add OpenAI compatibility for image models
- Introduced OpenAI-compatible image model support in the API, enabling integration through image generation and editing endpoints.
- Added registry type for OpenAIImageModelType to classify and validate compatibility.
- Implemented request handling for OpenAI-compatible image models, including JSON and multipart formats.
- Enhanced executor methods to support OpenAI-compatible image streaming and non-streaming requests.
- Included tests to validate model registration, streaming behavior, and multipart payload formatting.
2026-05-19 10:13:26 +08:00
sususu98 b67eb6f25d Merge pull request #3470 from sususu98/fix/antigravity-gemini-thought-signatures
Fix Antigravity Gemini thought signatures
2026-05-19 09:48:25 +08:00
sususu98 644823529f Merge pull request #3469 from sususu98/fix/gemini-max-output-token-cap
Cap Gemini max output tokens
2026-05-19 09:48:08 +08:00
Luis Pater bac006e72b feat(thinking): add xAI provider support with reasoning.effort implementation
- Implemented `xAI` provider for thinking configurations with support for reasoning.effort levels.
- Registered `xAI` in available providers and updated relevant APIs for compatibility.
- Added unit tests for `xAI` provider functionality, including fallback logic for unsupported levels.
- Integrated `xAI` with executor handling and ensured conformance with OpenAI-compatible standards.
2026-05-19 03:09:53 +08:00
Luis Pater ad98c9549a feat(runtime): track upstream response headers in logging and usage reporting
- Added APIs to store, retrieve, and clone upstream response headers in context for detailed logging.
- Updated `RecordAPIResponseMetadata`, `RecordAPIWebsocketHandshake`, and related methods to capture response headers.
- Extended `UsageReporter` to include response headers in published usage records.
- Enhanced payload tests to validate response headers' integrity and persistence.
- Refactored `usage.Record` to support optional `ResponseHeaders` field.
2026-05-19 01:29:23 +08:00
Luis Pater 77ba15f71b feat(server): add mTLS certificate bootstrap via JWT for Home connections
- Introduced `-home-jwt` flag and `HOME_JWT` environment variable to provide JWT for mTLS certificate generation.
- Added new APIs to handle certificate requests, validate JWT claims, and manage local certificate files.
- Updated Home TLS configuration to support client certificates, keys, and dynamic server name resolution.
2026-05-19 00:53:40 +08:00
sususu98 32a0d69b17 Fix Antigravity Gemini thought signatures 2026-05-18 19:01:51 +08:00
sususu98 1583cb4ef0 Cap Gemini max output tokens 2026-05-18 18:41:45 +08:00
Luis Pater cc0cb057b3 Merge pull request #3468 from sususu98/fix/claude-codex-call-id-length
Fix Claude-Codex long tool call IDs
2026-05-18 18:04:55 +08:00
Luis Pater 2710f56ae1 Merge pull request #3450 from sususu98/fix/http-connect-proxy-dialer
fix(proxy): support HTTP CONNECT dialer
2026-05-18 18:03:41 +08:00
sususu98 8bc2eff58a fix: shorten claude codex tool call ids 2026-05-18 17:49:42 +08:00
sususu98 ec79951e7f fix(proxy): support HTTP CONNECT dialer 2026-05-18 12:20:41 +08:00
Luis Pater 24602055a8 Merge pull request #2926 from slicenferqin/fix-tool-use-name-loss-and-duplicates
fix(openai→claude): suppress empty/duplicate tool_use content_block_start
2026-05-18 12:11:41 +08:00
Luis Pater 4ad6ffefb7 Merge pull request #3438 from madwiki/fix/strip-claude-code-attribution
fix: strip Claude Code attribution from non-Anthropic translations
2026-05-18 11:25:38 +08:00
slicenfer 1c2153a2cb fix(openai-claude): stabilize streaming tool_use blocks 2026-05-18 11:25:33 +08:00
Luis Pater 64d233fe93 Merge pull request #3448 from LongDinhh/feat/home-env-vars
feat(server): add HOME_ADDR and HOME_PASSWORD env var fallback
2026-05-18 11:20:18 +08:00
Luis Pater 66c5d60b3d refactor(api): remove newTestServerWithOptions and spoofed IP rejection test
- Simplified test server initialization by removing `newTestServerWithOptions`.
- Deleted `TestManagementLocalPasswordRejectsSpoofedForwardedFor` as spoofed IP handling is no longer applicable.
- Removed trusted proxy configuration from Gin engine setup.
2026-05-18 11:01:10 +08:00
Long Dinh 5f039654f0 refactor: move home env vars after godotenv and use lookupEnv helper
Address review feedback: move HOME_ADDR/HOME_PASSWORD lookup after
godotenv.Load() so .env files work, and use the lookupEnv helper
for case-insensitive key support consistent with PGSTORE_* etc.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-18 08:52:57 +07:00
Long Dinh ed0ac68324 feat(server): add HOME_ADDR and HOME_PASSWORD env var fallback for home flags
Allow configuring the home control plane connection via environment
variables HOME_ADDR and HOME_PASSWORD as an alternative to the --home
and --home-password command-line flags. This enables Docker Swarm stack
deployments without needing docker service update --args.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-18 03:11:19 +07:00
Mad Wiki d606faa99c fix: strip Claude Code attribution from non-Anthropic translations 2026-05-17 04:21:53 +08:00
sususu98 bfdc0b3989 fix: scope antigravity credits fallback gate 2026-05-13 18:17:22 +08:00
sususu98 809feb1e86 fix(antigravity): mask project_id in logs 2026-05-07 16:28:53 +08:00
sususu98 33130f18d2 fix: require antigravity project id 2026-05-07 12:55:31 +08:00
96 changed files with 5988 additions and 1335 deletions
+5
View File
@@ -0,0 +1,5 @@
# Cluster JWT example.
# After deploying https://github.com/router-for-me/CLIProxyAPIHome, get the JWT value with:
# curl -sS -X POST "http://<home-host>:8327/v0/management/certificates/clients" -H "X-MANAGEMENT-KEY: <management-key>" | jq -r '.home_jwt'
# Then paste it into HOME_JWT here or export it before starting Compose.
HOME_JWT=your-home-jwt-here
+1
View File
@@ -215,6 +215,7 @@ sudo /usr/local/bin/docker logs -f cli-proxy-api
| 날짜 | 버전 | 비고 |
|------|------|------|
| 2026-05-24 | v7.1.20 | v7.1.10 → v7.1.20 패치 — Claude tool-use 이름 손실/중복 수정(v7.1.12), Claude 요청 변환 system→developer 처리(v7.1.20), Gemini 3.5 Flash 모델 추가(v7.1.18), Grok Build 0.1, Redis timeout/failover, xAI reasoning.effort. 무중단, 재인증 불필요 |
| 2026-05-18 | v7.1.10 | 메이저 v6→v7 — Home Control Plane(Redis) 신설, ClaudeCodeSessionAffinity 제거, Usage tracking 제거(v6.10.0), xAI Grok 이미지/비디오, Codex client models, Local mgmt password validation + spoofed IP rejection. Auth 파일 호환(재인증 불필요), config 신규 필드 모두 옵션 |
| 2026-05-04 | v6.10.4 | 69개 커밋 변경 — WebSocket compact 처리 개선, X-Amp-Thread-Id 기반 session affinity, Codex reasoning/이미지 처리 강화, GPT-5.5 모델 추가, OpenAI 호환 provider 비활성화 옵션. 무중단 업데이트, 재인증 불필요 |
| 2026-04-26 | v6.9.38 | Protocol multiplexer + Redis queue 도입, 관리키/Redis AUTH 반복 실패 시 IP 차단 추가. 무중단 업데이트, 재인증 불필요 |
+6 -2
View File
@@ -32,9 +32,9 @@ PackyCode provides special discounts for our software users: register using <a h
</tr>
<tr>
<td width="180"><a href="https://coder.visioncoder.cn"><img src="./assets/visioncoder.png" alt="VisionCoder" width="150"></a></td>
<td>Thanks to VisionCoder for supporting this project. <a href="https://coder.visioncoder.cn" target="_blank">VisionCoder Developer Platform</a> is a reliable and efficient API relay service provider, offering access to mainstream AI models such as Claude Code, Codex, and Gemini. It helps developers and teams integrate AI capabilities more easily and improve productivity.
<td>Thanks to <b>VisionCoder</b> for supporting this project. <a href="https://coder.visioncoder.cn" target="_blank">VisionCoder Developer Platform</a> is a reliable and efficient API relay service provider, offering access to mainstream AI models such as Claude Code, Codex, and Gemini. It helps developers and teams integrate AI capabilities more easily and improve productivity.
<p></p>
VisionCoder is also offering our users a limited-time <a href="https://coder.visioncoder.cn" target="_blank">Token Plan</a> promotion: buy 1 month and get 1 month free.</td>
VisionCoder is also offering our users a limited-time <a href="https://coder.visioncoder.cn" target="_blank">Token Plan</a> promotion: <b>buy 1 month and get 1 month free</b>.</td>
</tr>
</tbody>
</table>
@@ -222,6 +222,10 @@ OmniRoute is an AI gateway for multi-provider LLMs: an OpenAI-compatible endpoin
A public CLIProxyAPI-compatible fork and bundled management panel. It keeps upstream-style usage while restoring built-in usage statistics, adding cache hit rate, first-byte latency, TPS tracking, and Docker-oriented self-hosted installation docs.
### [Codex Switch](https://github.com/9ycrooked/CodexSwitch)
This is a tool built with Tauri 2 + Vue 3 for managing multiple OpenAI Codex desktop accounts. Switch between saved ChatGPT/Codex authentication profiles, check 5-hour and weekly quota usage in real time, verify token health, view active account details, and import or save auth.json files without manual copying.
> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
+6 -2
View File
@@ -32,9 +32,9 @@ PackyCode 为本软件用户提供了特别优惠:使用<a href="https://www.p
</tr>
<tr>
<td width="180"><a href="https://coder.visioncoder.cn"><img src="./assets/visioncoder.png" alt="VisionCoder" width="150"></a></td>
<td>感谢 VisionCoder 对本项目的支持。<a href="https://coder.visioncoder.cn" target="_blank">VisionCoder 开发平台</a> 是一个可靠高效的 API 中继服务提供商,提供 Claude Code、Codex、Gemini 等主流 AI 模型,帮助开发者和团队更轻松地集成 AI 功能,提升工作效率。
<td>感谢 <b>VisionCoder</b> 对本项目的支持。<a href="https://coder.visioncoder.cn" target="_blank">VisionCoder 开发平台</a> 是一个可靠高效的 API 中继服务提供商,提供 Claude Code、Codex、Gemini 等主流 AI 模型,帮助开发者和团队更轻松地集成 AI 功能,提升工作效率。
<p></p>
VisionCoder 还为我们的用户提供 <a href="https://coder.visioncoder.cn" target="_blank">Token Plan</a> 限时活动:购买 1 个月,赠送 1 个月。</td>
VisionCoder 还为我们的用户提供 <a href="https://coder.visioncoder.cn" target="_blank">Token Plan</a> 限时活动:<b>购买 1 个月,赠送 1 个月</b>。</td>
</tr>
</tbody>
</table>
@@ -218,6 +218,10 @@ OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼
一个公开的 CLIProxyAPI 兼容二开版本和配套管理面板,尽量保持与上游一致的使用方式,同时恢复内置使用量统计,并补充缓存命中率、首字响应时间、TPS 记录和面向 Docker 自托管的安装说明。
### [Codex Switch](https://github.com/9ycrooked/CodexSwitch)
这是一个使用 Tauri 2 + Vue 3 构建的工具,用于管理多个 OpenAI Codex 桌面账户。它可以在已保存的 ChatGPT/Codex 认证配置之间切换,实时查看 5 小时和每周配额使用情况,验证 token 健康状态,查看当前账户详情,并在无需手动复制的情况下导入或保存 auth.json 文件。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
+5 -1
View File
@@ -32,7 +32,7 @@ PackyCodeは当ソフトウェアのユーザーに特別割引を提供して
</tr>
<tr>
<td width="180"><a href="https://coder.visioncoder.cn"><img src="./assets/visioncoder.png" alt="VisionCoder" width="150"></a></td>
<td>VisionCoderのご支援に感謝します!<a href="https://coder.visioncoder.cn">VisionCoder 開発プラットフォーム</a> は、信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどの主要AIモデルを提供し、開発者やチームがより簡単にAI機能を統合して生産性を向上できるよう支援します。さらに、VisionCoderはユーザー向けに <a href="https://coder.visioncoder.cn">Token Plan</a> の期間限定キャンペーン(1か月購入で1か月分プレゼント)も提供しています。</td>
<td><b>VisionCoder</b>のご支援に感謝します!<a href="https://coder.visioncoder.cn">VisionCoder 開発プラットフォーム</a> は、信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどの主要AIモデルを提供し、開発者やチームがより簡単にAI機能を統合して生産性を向上できるよう支援します。さらに、VisionCoderはユーザー向けに <a href="https://coder.visioncoder.cn">Token Plan</a> の期間限定キャンペーン(1か月購入で1か月分プレゼント)も提供しています。</td>
</tr>
</tbody>
</table>
@@ -217,6 +217,10 @@ OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです:
上流に近い使い方を維持する公開CLIProxyAPI互換フォーク兼管理パネルです。内蔵の使用量統計を復元し、キャッシュヒット率、初回バイト待ち時間、TPSの記録、Docker向けのセルフホスト手順を追加しています。
### [Codex Switch](https://github.com/9ycrooked/CodexSwitch)
Tauri 2 + Vue 3で構築された、複数のOpenAI Codexデスクトップアカウントを管理するためのツールです。保存済みのChatGPT/Codex認証プロファイルを切り替え、5時間および週次クォータ使用量をリアルタイムで確認し、tokenの状態を検証し、現在のアカウント詳細を表示し、手動コピーなしでauth.jsonファイルをインポートまたは保存できます。
> [!NOTE]
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
-77
View File
@@ -1,77 +0,0 @@
package main
import "testing"
// TestParseHomeFlagConfigHostPort checks that a bare "host:port" address is
// parsed into an enabled home config with the given password and TLS left off.
func TestParseHomeFlagConfigHostPort(t *testing.T) {
	got, errParse := parseHomeFlagConfig("home.example.com:8327", "secret")
	if errParse != nil {
		t.Fatalf("parseHomeFlagConfig() error = %v", errParse)
	}

	// Cases are evaluated top to bottom; the first mismatch aborts the test.
	switch {
	case !got.Enabled:
		t.Fatal("Enabled = false, want true")
	case got.Host != "home.example.com":
		t.Fatalf("Host = %q, want home.example.com", got.Host)
	case got.Port != 8327:
		t.Fatalf("Port = %d, want 8327", got.Port)
	case got.Password != "secret":
		t.Fatalf("Password = %q, want secret", got.Password)
	case got.TLS.Enable:
		t.Fatal("TLS.Enable = true, want false")
	}
}
// TestParseHomeFlagConfigRediss checks that a rediss:// URL yields a TLS-enabled
// config whose password, server name, skip-verify flag, and CA certificate path
// are taken from the URL user info and query string.
func TestParseHomeFlagConfigRediss(t *testing.T) {
	const addr = "rediss://:url-secret@home.example.com:444?server-name=home.example.com&skip_verify=true&ca-cert=C%3A%2Fcerts%2Fca.pem"

	got, errParse := parseHomeFlagConfig(addr, "")
	if errParse != nil {
		t.Fatalf("parseHomeFlagConfig() error = %v", errParse)
	}

	// Cases are evaluated top to bottom; the first mismatch aborts the test.
	switch {
	case got.Host != "home.example.com":
		t.Fatalf("Host = %q, want home.example.com", got.Host)
	case got.Port != 444:
		t.Fatalf("Port = %d, want 444", got.Port)
	case got.Password != "url-secret":
		t.Fatalf("Password = %q, want url-secret", got.Password)
	case !got.TLS.Enable:
		t.Fatal("TLS.Enable = false, want true")
	case got.TLS.ServerName != "home.example.com":
		t.Fatalf("TLS.ServerName = %q, want home.example.com", got.TLS.ServerName)
	case !got.TLS.InsecureSkipVerify:
		t.Fatal("TLS.InsecureSkipVerify = false, want true")
	case got.TLS.CACert != "C:/certs/ca.pem":
		t.Fatalf("TLS.CACert = %q, want C:/certs/ca.pem", got.TLS.CACert)
	}
}
// TestParseHomeFlagConfigPasswordFlagOverridesURLPassword checks that an
// explicitly supplied password takes precedence over the one embedded in the URL.
func TestParseHomeFlagConfigPasswordFlagOverridesURLPassword(t *testing.T) {
	const (
		addr         = "rediss://:url-secret@home.example.com:444"
		flagPassword = "flag-secret"
	)

	got, errParse := parseHomeFlagConfig(addr, flagPassword)
	if errParse != nil {
		t.Fatalf("parseHomeFlagConfig() error = %v", errParse)
	}
	if got.Password != flagPassword {
		t.Fatalf("Password = %q, want flag-secret", got.Password)
	}
}
// TestParseHomeFlagConfigDisableClusterDiscovery checks that the
// disable-cluster-discovery query parameter is reflected in the parsed config.
func TestParseHomeFlagConfigDisableClusterDiscovery(t *testing.T) {
	const addr = "redis://home.example.com:8327?disable-cluster-discovery=true"

	got, errParse := parseHomeFlagConfig(addr, "")
	if errParse != nil {
		t.Fatalf("parseHomeFlagConfig() error = %v", errParse)
	}
	if got.DisableClusterDiscovery {
		return
	}
	t.Fatal("DisableClusterDiscovery = false, want true")
}
+18 -128
View File
@@ -10,11 +10,9 @@ import (
"fmt"
"io"
"io/fs"
"net"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
@@ -53,120 +51,6 @@ func init() {
buildinfo.BuildDate = BuildDate
}
// parseHomeFlagConfig turns a home address string into an enabled HomeConfig.
//
// The address is whitespace-trimmed first. Anything containing "://" is handed
// to parseHomeURLConfig; otherwise it must be a plain host:port pair. The
// password is stored as given. An error is returned when the address is empty,
// cannot be split into host and port, has an empty host, or has an invalid port.
func parseHomeFlagConfig(rawAddr string, password string) (config.HomeConfig, error) {
	var none config.HomeConfig

	addr := strings.TrimSpace(rawAddr)
	switch {
	case addr == "":
		return none, fmt.Errorf("address is empty")
	case strings.Contains(addr, "://"):
		// URL form carries its own scheme-specific options.
		return parseHomeURLConfig(addr, password)
	}

	hostPart, portPart, errSplit := net.SplitHostPort(addr)
	if errSplit != nil {
		return none, fmt.Errorf("expected host:port, redis://host:port, or rediss://host:port: %w", errSplit)
	}

	hostName := strings.TrimSpace(hostPart)
	if hostName == "" {
		return none, fmt.Errorf("host is empty")
	}

	portNum, errPort := parseHomePort(portPart)
	if errPort != nil {
		return none, errPort
	}

	result := none
	result.Enabled = true
	result.Host = hostName
	result.Port = portNum
	result.Password = password
	return result, nil
}
// parseHomeURLConfig builds an enabled HomeConfig from a redis:// or rediss://
// URL.
//
// The scheme is matched case-insensitively; any other scheme is rejected. The
// URL password is used only when the password argument is empty. The
// disable-cluster-discovery query option is read for both schemes, while the
// TLS options (server name, skip-verify, CA certificate) are read only for
// rediss://, which also switches TLS on.
func parseHomeURLConfig(rawAddr string, password string) (config.HomeConfig, error) {
	var none config.HomeConfig

	u, errParse := url.Parse(rawAddr)
	if errParse != nil {
		return none, fmt.Errorf("parse URL: %w", errParse)
	}

	scheme := strings.ToLower(strings.TrimSpace(u.Scheme))
	useTLS := scheme == "rediss"
	if !useTLS && scheme != "redis" {
		return none, fmt.Errorf("unsupported URL scheme %q", u.Scheme)
	}

	hostName := strings.TrimSpace(u.Hostname())
	if hostName == "" {
		return none, fmt.Errorf("host is empty")
	}

	portNum, errPort := parseHomePort(u.Port())
	if errPort != nil {
		return none, errPort
	}

	// An explicitly supplied password wins over the one embedded in the URL.
	secret := password
	if secret == "" && u.User != nil {
		if fromURL, found := u.User.Password(); found {
			secret = fromURL
		}
	}

	params := u.Query()
	result := config.HomeConfig{
		Enabled:  true,
		Host:     hostName,
		Port:     portNum,
		Password: secret,
	}
	result.DisableClusterDiscovery = parseHomeBoolQuery(params, "disable-cluster-discovery", "disable_cluster_discovery")

	if !useTLS {
		return result, nil
	}

	result.TLS.Enable = true
	result.TLS.ServerName = strings.TrimSpace(firstHomeQueryValue(params, "server-name", "server_name"))
	result.TLS.InsecureSkipVerify = parseHomeBoolQuery(params, "insecure-skip-verify", "insecure_skip_verify", "skip_verify")
	result.TLS.CACert = strings.TrimSpace(firstHomeQueryValue(params, "ca-cert", "ca_cert"))
	return result, nil
}
// parseHomePort converts a textual port into an integer in the range 1..65535.
//
// Surrounding whitespace is ignored. An empty value yields "port is empty";
// anything non-numeric or out of range yields an "invalid port" error quoting
// the trimmed input.
func parseHomePort(rawPort string) (int, error) {
	trimmed := strings.TrimSpace(rawPort)
	if trimmed == "" {
		return 0, fmt.Errorf("port is empty")
	}

	value, errConv := strconv.Atoi(trimmed)
	if errConv == nil && value >= 1 && value <= 65535 {
		return value, nil
	}
	return 0, fmt.Errorf("invalid port %q", trimmed)
}
// firstHomeQueryValue returns the value of the first key, in the order given,
// whose first query value is non-empty. It returns "" when no key qualifies.
func firstHomeQueryValue(values url.Values, keys ...string) string {
	for i := range keys {
		candidate := values.Get(keys[i])
		if candidate == "" {
			continue
		}
		return candidate
	}
	return ""
}
// parseHomeBoolQuery reads a boolean option that may appear under several
// alternative query keys.
//
// Keys are tried in order and the first one with a non-blank value decides the
// result: true only if that value parses as a true boolean. An unparsable value
// yields false without consulting later keys. With no usable key the result is
// false.
func parseHomeBoolQuery(values url.Values, keys ...string) bool {
	for i := range keys {
		raw := strings.TrimSpace(values.Get(keys[i]))
		if raw == "" {
			continue
		}
		flag, errParse := strconv.ParseBool(raw)
		if errParse != nil {
			return false
		}
		return flag
	}
	return false
}
// main is the entry point of the application.
// It parses command-line flags, loads configuration, and starts the appropriate
// service based on the provided flags (login, codex-login, or server mode).
@@ -188,8 +72,7 @@ func main() {
var vertexImportPrefix string
var configPath string
var password string
var homeAddr string
var homePassword string
var homeJWT string
var homeDisableClusterDiscovery bool
var tuiMode bool
var standalone bool
@@ -210,9 +93,8 @@ func main() {
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
flag.StringVar(&password, "password", "", "")
flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)")
flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)")
flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address")
flag.StringVar(&homeJWT, "home-jwt", "", "Home control plane JWT for mTLS certificate bootstrap and connection")
flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home-jwt address")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
@@ -299,6 +181,13 @@ func main() {
return "", false
}
writableBase := util.WritablePath()
if strings.TrimSpace(homeJWT) == "" {
if v, ok := lookupEnv("HOME_JWT", "home_jwt"); ok {
homeJWT = v
}
}
if value, ok := lookupEnv("PGSTORE_DSN", "pgstore_dsn"); ok {
usePostgresStore = true
pgStoreDSN = value
@@ -362,12 +251,13 @@ func main() {
// Determine and load the configuration file.
// Prefer the Postgres store when configured, otherwise fallback to git or local files.
var configFilePath string
if strings.TrimSpace(homeAddr) != "" {
if strings.TrimSpace(homeJWT) != "" {
configLoadedFromHome = true
trimmedHomePassword := strings.TrimSpace(homePassword)
homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword)
ctxHome, cancelHome := context.WithTimeout(context.Background(), 30*time.Second)
homeCfg, errHomeCfg := home.ConfigFromJWT(ctxHome, homeJWT)
cancelHome()
if errHomeCfg != nil {
log.Errorf("invalid -home address %q: %v", homeAddr, errHomeCfg)
log.Errorf("invalid -home-jwt: %v", errHomeCfg)
return
}
if homeDisableClusterDiscovery {
@@ -376,9 +266,9 @@ func main() {
homeClient := home.New(homeCfg)
defer homeClient.Close()
ctxHome, cancelHome := context.WithTimeout(context.Background(), 30*time.Second)
raw, errGetConfig := homeClient.GetConfig(ctxHome)
cancelHome()
ctxHomeConfig, cancelHomeConfig := context.WithTimeout(context.Background(), 30*time.Second)
raw, errGetConfig := homeClient.GetConfig(ctxHomeConfig)
cancelHomeConfig()
if errGetConfig != nil {
log.Errorf("failed to fetch config from home: %v", errGetConfig)
return
+3 -22
View File
@@ -11,26 +11,6 @@ tls:
cert: ""
key: ""
# Optional "home" control plane integration over Redis protocol.
home:
enabled: false
host: "127.0.0.1"
port: 6379
password: ""
# Keep CPA pinned to the configured home address instead of switching to CLUSTER NODES entries.
# Useful when Home is behind NAT, Docker networking, or a reverse proxy.
disable-cluster-discovery: false
# Optional TLS for the outbound Redis connection to the home control plane.
# Enable this when connecting through rediss:// or an SSL stream proxy.
tls:
enable: false
# Optional SNI/certificate name override. Leave empty to use the configured home host.
server-name: ""
# Trust a private CA bundle in addition to system roots.
ca-cert: ""
# Only for testing self-signed endpoints; disables certificate verification.
insecure-skip-verify: false
# Management API settings
remote-management:
# Whether to allow remote (non-localhost) management access.
@@ -86,8 +66,8 @@ error-logs-max-files: 10
# When false, disable in-memory usage statistics aggregation
usage-statistics-enabled: false
# How long (in seconds) Redis usage queue items are retained in memory for the RESP interface (LPOP/RPOP).
# Note: the in-process Redis RESP usage output is disabled when home.enabled is true.
# How long (in seconds) usage queue items are retained in memory for the Management API.
# The local Redis RESP usage output is disabled.
# Default: 60. Max: 3600.
redis-usage-queue-retention-seconds: 60
@@ -277,6 +257,7 @@ nonstream-keepalive-interval: 0
# models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API.
# image: false # optional: set true to allow this model on /v1/images/generations and /v1/images/edits
# thinking: # optional: omit to default to levels ["low","medium","high"]
# levels: ["low", "medium", "high"]
# # You may repeat the same alias to build an internal model pool.
+29
View File
@@ -0,0 +1,29 @@
services:
cli-proxy-api:
image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api:latest}
pull_policy: always
build:
context: .
dockerfile: Dockerfile
args:
VERSION: ${VERSION:-dev}
COMMIT: ${COMMIT:-none}
BUILD_DATE: ${BUILD_DATE:-unknown}
container_name: cli-proxy-api-cluster
environment:
HOME_JWT: ${HOME_JWT:-}
ports:
- "8317:8317"
volumes:
- ./home:/root/.cli-proxy-api
- ./logs:/CLIProxyAPI/logs
command: >
sh -eu -c '
if [ -z "$$HOME_JWT" ]; then
echo "HOME_JWT is required" >&2
exit 1
fi
exec ./CLIProxyAPI -home-jwt "$$HOME_JWT"
'
restart: unless-stopped
@@ -2081,7 +2081,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
} else {
projectID = fetchedProjectID
log.Infof("antigravity: obtained project ID %s", projectID)
log.Infof("antigravity: obtained project ID %s", util.HideAPIKey(projectID))
}
}
@@ -2125,7 +2125,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
CompleteOAuthSessionsByProvider("antigravity")
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if projectID != "" {
fmt.Printf("Using GCP project: %s\n", projectID)
fmt.Printf("Using GCP project: %s\n", util.HideAPIKey(projectID))
}
fmt.Println("You can now use Antigravity services through this CLI")
}()
-12
View File
@@ -103,18 +103,6 @@ func (s *Server) routeMuxConnection(conn net.Conn, httpListener *muxListener) {
}
if isRedisRESPPrefix(prefix[0]) {
if s.cfg != nil && s.cfg.Home.Enabled {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close redis connection while home mode is enabled: %v", errClose)
}
return
}
if !s.managementRoutesEnabled.Load() {
if errClose := conn.Close(); errClose != nil {
log.Errorf("failed to close redis connection while management is disabled: %v", errClose)
}
return
}
_ = conn.SetReadDeadline(time.Time{})
s.handleRedisConnection(conn, reader)
return
+81 -93
View File
@@ -31,9 +31,12 @@ func isRedisRESPPrefix(prefix byte) bool {
}
func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
if s == nil || conn == nil || reader == nil {
if s == nil || conn == nil {
return
}
if reader == nil {
reader = bufio.NewReader(conn)
}
clientIP, localClient := resolveRemoteIP(conn.RemoteAddr())
authed := false
@@ -63,10 +66,10 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
return
}
args, err := readRESPArray(reader)
if err != nil {
if !errors.Is(err, io.EOF) {
_ = writeRedisError(writer, "ERR "+err.Error())
args, errRead := readRESPArray(reader)
if errRead != nil {
if !errors.Is(errRead, io.EOF) {
_ = writeRedisError(writer, "ERR "+errRead.Error())
_ = writer.Flush()
}
return
@@ -139,13 +142,6 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
return
}
case "SUBSCRIBE":
if !authed {
_ = writeRedisError(writer, "NOAUTH Authentication required.")
if !flush() {
return
}
continue
}
channel, ok := parseSubscribeChannel(args)
if !ok {
_ = writeRedisError(writer, "ERR wrong number of arguments for 'subscribe' command")
@@ -174,13 +170,6 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
s.streamRedisUsageSubscription(reader, writer, messages, unsubscribe)
return
case "LPOP", "RPOP":
if !authed {
_ = writeRedisError(writer, "NOAUTH Authentication required.")
if !flush() {
return
}
continue
}
count, hasCount, ok := parsePopCount(args)
if !ok {
_ = writeRedisError(writer, "ERR wrong number of arguments for '"+strings.ToLower(cmd)+"' command")
@@ -270,11 +259,11 @@ func readRedisSubscriptionCommands(reader *bufio.Reader, commands chan<- redisSu
defer close(commands)
for {
args, err := readRESPArray(reader)
if err != nil {
if !errors.Is(err, io.EOF) {
args, errRead := readRESPArray(reader)
if errRead != nil {
if !errors.Is(errRead, io.EOF) {
select {
case commands <- redisSubscriptionCommand{err: err}:
case commands <- redisSubscriptionCommand{err: errRead}:
case <-done:
}
}
@@ -336,7 +325,7 @@ func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) {
}
default:
host = addr.String()
if h, _, err := net.SplitHostPort(host); err == nil {
if h, _, errSplit := net.SplitHostPort(host); errSplit == nil {
host = h
}
host = strings.TrimSpace(host)
@@ -362,7 +351,6 @@ func parseAuthPassword(args []string) (string, bool) {
case 2:
return args[1], true
case 3:
// Support AUTH <username> <password> by ignoring username for compatibility.
return args[2], true
default:
return "", false
@@ -383,34 +371,34 @@ func parsePopCount(args []string) (count int, hasCount bool, ok bool) {
if len(args) == 2 {
return 1, false, true
}
parsed, err := strconv.Atoi(strings.TrimSpace(args[2]))
if err != nil {
parsed, errParse := strconv.Atoi(strings.TrimSpace(args[2]))
if errParse != nil {
return 0, true, true
}
return parsed, true, true
}
func readRESPArray(reader *bufio.Reader) ([]string, error) {
prefix, err := reader.ReadByte()
if err != nil {
return nil, err
prefix, errRead := reader.ReadByte()
if errRead != nil {
return nil, errRead
}
if prefix != '*' {
return nil, fmt.Errorf("protocol error")
}
line, err := readRESPLine(reader)
if err != nil {
return nil, err
line, errLine := readRESPLine(reader)
if errLine != nil {
return nil, errLine
}
count, err := strconv.Atoi(line)
if err != nil || count < 0 {
count, errParse := strconv.Atoi(line)
if errParse != nil || count < 0 {
return nil, fmt.Errorf("protocol error")
}
args := make([]string, 0, count)
for i := 0; i < count; i++ {
value, err := readRESPString(reader)
if err != nil {
return nil, err
value, errString := readRESPString(reader)
if errString != nil {
return nil, errString
}
args = append(args, value)
}
@@ -418,9 +406,9 @@ func readRESPArray(reader *bufio.Reader) ([]string, error) {
}
func readRESPString(reader *bufio.Reader) (string, error) {
prefix, err := reader.ReadByte()
if err != nil {
return "", err
prefix, errRead := reader.ReadByte()
if errRead != nil {
return "", errRead
}
switch prefix {
case '$':
@@ -433,20 +421,20 @@ func readRESPString(reader *bufio.Reader) (string, error) {
}
func readRESPBulkString(reader *bufio.Reader) (string, error) {
line, err := readRESPLine(reader)
if err != nil {
return "", err
line, errLine := readRESPLine(reader)
if errLine != nil {
return "", errLine
}
length, err := strconv.Atoi(line)
if err != nil {
length, errParse := strconv.Atoi(line)
if errParse != nil {
return "", fmt.Errorf("protocol error")
}
if length < 0 {
return "", nil
}
buf := make([]byte, length+2)
if _, err := io.ReadFull(reader, buf); err != nil {
return "", err
if _, errRead := io.ReadFull(reader, buf); errRead != nil {
return "", errRead
}
if length+2 < 2 || buf[length] != '\r' || buf[length+1] != '\n' {
return "", fmt.Errorf("protocol error")
@@ -455,9 +443,9 @@ func readRESPBulkString(reader *bufio.Reader) (string, error) {
}
func readRESPLine(reader *bufio.Reader) (string, error) {
line, err := reader.ReadString('\n')
if err != nil {
return "", err
line, errRead := reader.ReadString('\n')
if errRead != nil {
return "", errRead
}
line = strings.TrimSuffix(line, "\n")
line = strings.TrimSuffix(line, "\r")
@@ -468,24 +456,24 @@ func writeRedisSimpleString(writer *bufio.Writer, value string) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString("+" + value + "\r\n")
return err
_, errWrite := writer.WriteString("+" + value + "\r\n")
return errWrite
}
func writeRedisError(writer *bufio.Writer, message string) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString("-" + message + "\r\n")
return err
_, errWrite := writer.WriteString("-" + message + "\r\n")
return errWrite
}
func writeRedisNilBulkString(writer *bufio.Writer) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString("$-1\r\n")
return err
_, errWrite := writer.WriteString("$-1\r\n")
return errWrite
}
func writeRedisBulkString(writer *bufio.Writer, payload []byte) error {
@@ -495,26 +483,26 @@ func writeRedisBulkString(writer *bufio.Writer, payload []byte) error {
if payload == nil {
return writeRedisNilBulkString(writer)
}
if _, err := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); err != nil {
return err
if _, errWrite := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); errWrite != nil {
return errWrite
}
if _, err := writer.Write(payload); err != nil {
return err
if _, errWrite := writer.Write(payload); errWrite != nil {
return errWrite
}
_, err := writer.WriteString("\r\n")
return err
_, errWrite := writer.WriteString("\r\n")
return errWrite
}
func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error {
if writer == nil {
return net.ErrClosed
}
if _, err := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); err != nil {
return err
if _, errWrite := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); errWrite != nil {
return errWrite
}
for i := range items {
if err := writeRedisBulkString(writer, items[i]); err != nil {
return err
if errWrite := writeRedisBulkString(writer, items[i]); errWrite != nil {
return errWrite
}
}
return nil
@@ -524,63 +512,63 @@ func writeRedisInteger(writer *bufio.Writer, value int) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString(":" + strconv.Itoa(value) + "\r\n")
return err
_, errWrite := writer.WriteString(":" + strconv.Itoa(value) + "\r\n")
return errWrite
}
func writeRedisArrayHeader(writer *bufio.Writer, count int) error {
if writer == nil {
return net.ErrClosed
}
_, err := writer.WriteString("*" + strconv.Itoa(count) + "\r\n")
return err
_, errWrite := writer.WriteString("*" + strconv.Itoa(count) + "\r\n")
return errWrite
}
func writeRedisPubSubSubscribe(writer *bufio.Writer, channel string, count int) error {
if err := writeRedisArrayHeader(writer, 3); err != nil {
return err
if errWrite := writeRedisArrayHeader(writer, 3); errWrite != nil {
return errWrite
}
if err := writeRedisBulkString(writer, []byte("subscribe")); err != nil {
return err
if errWrite := writeRedisBulkString(writer, []byte("subscribe")); errWrite != nil {
return errWrite
}
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
return err
if errWrite := writeRedisBulkString(writer, []byte(channel)); errWrite != nil {
return errWrite
}
return writeRedisInteger(writer, count)
}
func writeRedisPubSubUnsubscribe(writer *bufio.Writer, channel string, count int) error {
if err := writeRedisArrayHeader(writer, 3); err != nil {
return err
if errWrite := writeRedisArrayHeader(writer, 3); errWrite != nil {
return errWrite
}
if err := writeRedisBulkString(writer, []byte("unsubscribe")); err != nil {
return err
if errWrite := writeRedisBulkString(writer, []byte("unsubscribe")); errWrite != nil {
return errWrite
}
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
return err
if errWrite := writeRedisBulkString(writer, []byte(channel)); errWrite != nil {
return errWrite
}
return writeRedisInteger(writer, count)
}
func writeRedisPubSubMessage(writer *bufio.Writer, channel string, payload []byte) error {
if err := writeRedisArrayHeader(writer, 3); err != nil {
return err
if errWrite := writeRedisArrayHeader(writer, 3); errWrite != nil {
return errWrite
}
if err := writeRedisBulkString(writer, []byte("message")); err != nil {
return err
if errWrite := writeRedisBulkString(writer, []byte("message")); errWrite != nil {
return errWrite
}
if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
return err
if errWrite := writeRedisBulkString(writer, []byte(channel)); errWrite != nil {
return errWrite
}
return writeRedisBulkString(writer, payload)
}
func writeRedisPubSubPong(writer *bufio.Writer, payload []byte) error {
if err := writeRedisArrayHeader(writer, 2); err != nil {
return err
if errWrite := writeRedisArrayHeader(writer, 2); errWrite != nil {
return errWrite
}
if err := writeRedisBulkString(writer, []byte("pong")); err != nil {
return err
if errWrite := writeRedisBulkString(writer, []byte("pong")); errWrite != nil {
return errWrite
}
return writeRedisBulkString(writer, payload)
}
@@ -3,13 +3,10 @@ package api
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
@@ -18,18 +15,6 @@ import (
"github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue"
)
// remoteAddrConn wraps a net.Conn and reports a caller-chosen remote address
// instead of the one the embedded connection would report.
type remoteAddrConn struct {
	net.Conn
	remoteAddr net.Addr
}

// RemoteAddr returns the overriding address stored on the wrapper, or nil
// when the receiver itself is nil.
func (c *remoteAddrConn) RemoteAddr() net.Addr {
	if c != nil {
		return c.remoteAddr
	}
	return nil
}
func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) {
t.Helper()
@@ -86,17 +71,6 @@ func readTestRESPLine(r *bufio.Reader) (string, error) {
return strings.TrimSuffix(line, "\r\n"), nil
}
// readTestRESPSimpleString consumes one RESP simple-string reply from r and
// returns its text. The first byte must be '+'; the remainder of the line is
// read with readTestRESPLine.
func readTestRESPSimpleString(r *bufio.Reader) (string, error) {
	marker, errRead := r.ReadByte()
	switch {
	case errRead != nil:
		return "", errRead
	case marker != '+':
		return "", fmt.Errorf("expected simple string prefix '+', got %q", marker)
	}
	return readTestRESPLine(r)
}
func readTestRESPError(r *bufio.Reader) (string, error) {
prefix, err := r.ReadByte()
if err != nil {
@@ -108,22 +82,33 @@ func readTestRESPError(r *bufio.Reader) (string, error) {
return readTestRESPLine(r)
}
// readTestRESPSimpleString reads a single RESP simple string from r. It
// returns the read error if the prefix byte cannot be read, a descriptive
// error if that byte is not '+', and otherwise whatever readTestRESPLine
// yields for the rest of the line.
func readTestRESPSimpleString(r *bufio.Reader) (string, error) {
	first, errFirst := r.ReadByte()
	if errFirst != nil {
		return "", errFirst
	}
	if first == '+' {
		return readTestRESPLine(r)
	}
	return "", fmt.Errorf("expected simple string prefix '+', got %q", first)
}
func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
prefix, err := r.ReadByte()
if err != nil {
return nil, err
prefix, errRead := r.ReadByte()
if errRead != nil {
return nil, errRead
}
if prefix != '$' {
return nil, fmt.Errorf("expected bulk string prefix '$', got %q", prefix)
}
line, err := readTestRESPLine(r)
if err != nil {
return nil, err
line, errLine := readTestRESPLine(r)
if errLine != nil {
return nil, errLine
}
length, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("invalid bulk string length %q: %v", line, err)
length, errParse := strconv.Atoi(line)
if errParse != nil {
return nil, fmt.Errorf("invalid bulk string length %q: %v", line, errParse)
}
if length == -1 {
return nil, nil
@@ -133,8 +118,8 @@ func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
}
payload := make([]byte, length+2)
if _, err := io.ReadFull(r, payload); err != nil {
return nil, err
if _, errRead := io.ReadFull(r, payload); errRead != nil {
return nil, errRead
}
if payload[length] != '\r' || payload[length+1] != '\n' {
return nil, fmt.Errorf("invalid bulk string terminator")
@@ -143,21 +128,21 @@ func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
}
func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
prefix, err := r.ReadByte()
if err != nil {
return nil, err
prefix, errRead := r.ReadByte()
if errRead != nil {
return nil, errRead
}
if prefix != '*' {
return nil, fmt.Errorf("expected array prefix '*', got %q", prefix)
}
line, err := readTestRESPLine(r)
if err != nil {
return nil, err
line, errLine := readTestRESPLine(r)
if errLine != nil {
return nil, errLine
}
count, err := strconv.Atoi(line)
if err != nil {
return nil, fmt.Errorf("invalid array length %q: %v", line, err)
count, errParse := strconv.Atoi(line)
if errParse != nil {
return nil, fmt.Errorf("invalid array length %q: %v", line, errParse)
}
if count < 0 {
return nil, fmt.Errorf("invalid array length %d", count)
@@ -165,114 +150,15 @@ func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
out := make([][]byte, 0, count)
for i := 0; i < count; i++ {
item, err := readTestRESPBulkString(r)
if err != nil {
return nil, err
item, errItem := readTestRESPBulkString(r)
if errItem != nil {
return nil, errItem
}
out = append(out, item)
}
return out, nil
}
// readTestRESPInteger reads one RESP integer reply from r. The reply must
// start with ':'; the rest of the line is fetched with readTestRESPLine and
// parsed as a base-10 int.
func readTestRESPInteger(r *bufio.Reader) (int, error) {
	marker, errRead := r.ReadByte()
	switch {
	case errRead != nil:
		return 0, errRead
	case marker != ':':
		return 0, fmt.Errorf("expected integer prefix ':', got %q", marker)
	}

	text, errLine := readTestRESPLine(r)
	if errLine != nil {
		return 0, errLine
	}

	parsed, errParse := strconv.Atoi(text)
	if errParse != nil {
		return 0, fmt.Errorf("invalid integer %q: %v", text, errParse)
	}
	return parsed, nil
}
// readTestRESPArrayHeader reads only the header of a RESP array from r (the
// '*' prefix and the element count line) and returns the element count. The
// elements themselves are left unread. Negative counts are rejected.
func readTestRESPArrayHeader(r *bufio.Reader) (int, error) {
	marker, errRead := r.ReadByte()
	if errRead != nil {
		return 0, errRead
	}
	if marker != '*' {
		return 0, fmt.Errorf("expected array prefix '*', got %q", marker)
	}

	header, errLine := readTestRESPLine(r)
	if errLine != nil {
		return 0, errLine
	}

	size, errParse := strconv.Atoi(header)
	switch {
	case errParse != nil:
		return 0, fmt.Errorf("invalid array length %q: %v", header, errParse)
	case size < 0:
		return 0, fmt.Errorf("invalid array length %d", size)
	}
	return size, nil
}
// readTestRESPPubSubSubscribe reads a pub/sub subscribe acknowledgement from
// r: a three-element array holding the bulk string "subscribe", the channel
// name as a bulk string, and an integer. It returns the channel name and that
// integer.
func readTestRESPPubSubSubscribe(r *bufio.Reader) (string, int, error) {
	size, errHeader := readTestRESPArrayHeader(r)
	if errHeader != nil {
		return "", 0, errHeader
	}
	if size != 3 {
		return "", 0, fmt.Errorf("subscribe array length = %d, want 3", size)
	}

	kindBytes, errKind := readTestRESPBulkString(r)
	if errKind != nil {
		return "", 0, errKind
	}
	if kind := string(kindBytes); kind != "subscribe" {
		return "", 0, fmt.Errorf("pubsub kind = %q, want subscribe", kind)
	}

	channelBytes, errChannel := readTestRESPBulkString(r)
	if errChannel != nil {
		return "", 0, errChannel
	}

	total, errCount := readTestRESPInteger(r)
	if errCount != nil {
		return "", 0, errCount
	}
	return string(channelBytes), total, nil
}
// readTestRESPPubSubMessage reads a pub/sub message frame from r: a
// three-element array holding the bulk string "message", the channel name,
// and the payload. It returns the channel name and the raw payload bytes.
func readTestRESPPubSubMessage(r *bufio.Reader) (string, []byte, error) {
	size, errHeader := readTestRESPArrayHeader(r)
	if errHeader != nil {
		return "", nil, errHeader
	}
	if size != 3 {
		return "", nil, fmt.Errorf("message array length = %d, want 3", size)
	}

	kindBytes, errKind := readTestRESPBulkString(r)
	if errKind != nil {
		return "", nil, errKind
	}
	if kind := string(kindBytes); kind != "message" {
		return "", nil, fmt.Errorf("pubsub kind = %q, want message", kind)
	}

	channelBytes, errChannel := readTestRESPBulkString(r)
	if errChannel != nil {
		return "", nil, errChannel
	}

	body, errBody := readTestRESPBulkString(r)
	if errBody != nil {
		return "", nil, errBody
	}
	return string(channelBytes), body, nil
}
func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
redisqueue.SetEnabled(false)
@@ -333,13 +219,19 @@ func TestRedisProtocol_HomeEnabled_DisablesConnection(t *testing.T) {
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
_ = writeTestRESPCommand(conn, "PING")
if msg, err := readTestRESPError(bufio.NewReader(conn)); err != nil {
t.Fatalf("failed to read home-mode RESP error: %v", err)
} else if msg != "ERR redis usage output disabled in home mode" {
t.Fatalf("unexpected disabled RESP error: %q", msg)
}
buf := make([]byte, 1)
_, errRead := conn.Read(buf)
if errRead == nil {
t.Fatalf("expected connection to be closed when home mode is enabled")
t.Fatalf("expected connection to be closed after home-mode RESP error")
}
if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
t.Fatalf("expected connection to be closed when home mode is enabled, got timeout: %v", errRead)
t.Fatalf("expected connection to be closed after home-mode RESP error, got timeout: %v", errRead)
}
}
@@ -368,29 +260,11 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
if errWrite := writeTestRESPCommand(conn, "AUTH", "test-key"); errWrite != nil {
t.Fatalf("failed to write AUTH command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read AUTH error: %v", err)
} else if msg != "ERR invalid management key" {
t.Fatalf("unexpected AUTH error: %q", msg)
}
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
} else if msg != "NOAUTH Authentication required." {
t.Fatalf("unexpected LPOP NOAUTH error: %q", msg)
}
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write AUTH command: %v", errWrite)
}
if msg, err := readTestRESPSimpleString(reader); err != nil {
t.Fatalf("failed to read AUTH response: %v", err)
if msg, errRead := readTestRESPSimpleString(reader); errRead != nil {
t.Fatalf("failed to read AUTH response: %v", errRead)
} else if msg != "OK" {
t.Fatalf("unexpected AUTH response: %q", msg)
}
@@ -402,25 +276,25 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
redisqueue.Enqueue([]byte("b"))
redisqueue.Enqueue([]byte("c"))
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue"); errWrite != nil {
if errWrite := writeTestRESPCommand(conn, "RPOP", "usage"); errWrite != nil {
t.Fatalf("failed to write RPOP command: %v", errWrite)
}
if item, err := readTestRESPBulkString(reader); err != nil {
t.Fatalf("failed to read RPOP response: %v", err)
if item, errRead := readTestRESPBulkString(reader); errRead != nil {
t.Fatalf("failed to read RPOP response: %v", errRead)
} else if string(item) != "a" {
t.Fatalf("unexpected RPOP item: %q", string(item))
}
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
if errWrite := writeTestRESPCommand(conn, "LPOP", "usage"); errWrite != nil {
t.Fatalf("failed to write LPOP command: %v", errWrite)
}
if item, err := readTestRESPBulkString(reader); err != nil {
t.Fatalf("failed to read LPOP response: %v", err)
if item, errRead := readTestRESPBulkString(reader); errRead != nil {
t.Fatalf("failed to read LPOP response: %v", errRead)
} else if string(item) != "b" {
t.Fatalf("unexpected LPOP item: %q", string(item))
}
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "10"); errWrite != nil {
if errWrite := writeTestRESPCommand(conn, "RPOP", "usage", "10"); errWrite != nil {
t.Fatalf("failed to write RPOP count command: %v", errWrite)
}
items, errItems := readRESPArrayOfBulkStrings(reader)
@@ -431,7 +305,7 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
t.Fatalf("unexpected RPOP count items: %#v", items)
}
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
if errWrite := writeTestRESPCommand(conn, "LPOP", "usage"); errWrite != nil {
t.Fatalf("failed to write LPOP empty command: %v", errWrite)
}
item, errItem := readTestRESPBulkString(reader)
@@ -442,7 +316,7 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
t.Fatalf("expected nil bulk string for empty queue, got %q", string(item))
}
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "2"); errWrite != nil {
if errWrite := writeTestRESPCommand(conn, "RPOP", "usage", "2"); errWrite != nil {
t.Fatalf("failed to write RPOP empty count command: %v", errWrite)
}
emptyItems, errEmpty := readRESPArrayOfBulkStrings(reader)
@@ -453,284 +327,3 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
t.Fatalf("expected empty array for empty queue with count, got %#v", emptyItems)
}
}
// TestRedisProtocol_SubscribeUsageBroadcastsAndSkipsQueue checks that a
// payload enqueued while two authenticated clients are subscribed to "usage"
// is delivered to both of them as a pub/sub message, and that the same payload
// is afterwards returned neither by LPOP on a third connection nor by the
// management usage-queue HTTP endpoint.
func TestRedisProtocol_SubscribeUsageBroadcastsAndSkipsQueue(t *testing.T) {
const managementPassword = "test-management-password"
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
// Reset the redisqueue enabled flag to false now and again at cleanup.
redisqueue.SetEnabled(false)
t.Cleanup(func() { redisqueue.SetEnabled(false) })
server := newTestServer(t)
if !server.managementRoutesEnabled.Load() {
t.Fatalf("expected managementRoutesEnabled to be true")
}
addr, stop := startRedisMuxListener(t, server)
t.Cleanup(stop)
// First subscriber: connect, AUTH, then SUBSCRIBE to "usage".
firstConn, errDialFirst := net.DialTimeout("tcp", addr, time.Second)
if errDialFirst != nil {
t.Fatalf("failed to dial first redis listener: %v", errDialFirst)
}
t.Cleanup(func() { _ = firstConn.Close() })
firstReader := bufio.NewReader(firstConn)
// Bound every read/write on this connection so a missing reply fails fast.
_ = firstConn.SetDeadline(time.Now().Add(5 * time.Second))
if errWrite := writeTestRESPCommand(firstConn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write first AUTH command: %v", errWrite)
}
if msg, err := readTestRESPSimpleString(firstReader); err != nil {
t.Fatalf("failed to read first AUTH response: %v", err)
} else if msg != "OK" {
t.Fatalf("unexpected first AUTH response: %q", msg)
}
if errWrite := writeTestRESPCommand(firstConn, "SUBSCRIBE", "usage"); errWrite != nil {
t.Fatalf("failed to write first SUBSCRIBE command: %v", errWrite)
}
if channel, count, err := readTestRESPPubSubSubscribe(firstReader); err != nil {
t.Fatalf("failed to read first SUBSCRIBE response: %v", err)
} else if channel != "usage" || count != 1 {
t.Fatalf("unexpected first SUBSCRIBE response channel=%q count=%d", channel, count)
}
// Second subscriber: same handshake on an independent connection.
secondConn, errDialSecond := net.DialTimeout("tcp", addr, time.Second)
if errDialSecond != nil {
t.Fatalf("failed to dial second redis listener: %v", errDialSecond)
}
t.Cleanup(func() { _ = secondConn.Close() })
secondReader := bufio.NewReader(secondConn)
_ = secondConn.SetDeadline(time.Now().Add(5 * time.Second))
if errWrite := writeTestRESPCommand(secondConn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write second AUTH command: %v", errWrite)
}
if msg, err := readTestRESPSimpleString(secondReader); err != nil {
t.Fatalf("failed to read second AUTH response: %v", err)
} else if msg != "OK" {
t.Fatalf("unexpected second AUTH response: %q", msg)
}
if errWrite := writeTestRESPCommand(secondConn, "SUBSCRIBE", "usage"); errWrite != nil {
t.Fatalf("failed to write second SUBSCRIBE command: %v", errWrite)
}
// The second connection also expects count == 1 in its acknowledgement.
if channel, count, err := readTestRESPPubSubSubscribe(secondReader); err != nil {
t.Fatalf("failed to read second SUBSCRIBE response: %v", err)
} else if channel != "usage" || count != 1 {
t.Fatalf("unexpected second SUBSCRIBE response channel=%q count=%d", channel, count)
}
// Publish one payload; both subscribers must receive it verbatim.
redisqueue.Enqueue([]byte(`{"id":1}`))
if channel, payload, err := readTestRESPPubSubMessage(firstReader); err != nil {
t.Fatalf("failed to read first pubsub message: %v", err)
} else if channel != "usage" || string(payload) != `{"id":1}` {
t.Fatalf("unexpected first pubsub message channel=%q payload=%q", channel, string(payload))
}
if channel, payload, err := readTestRESPPubSubMessage(secondReader); err != nil {
t.Fatalf("failed to read second pubsub message: %v", err)
} else if channel != "usage" || string(payload) != `{"id":1}` {
t.Fatalf("unexpected second pubsub message channel=%q payload=%q", channel, string(payload))
}
// Third connection: after AUTH, LPOP must return a nil bulk string, i.e. the
// broadcast payload was not also left in the queue.
popConn, errDialPop := net.DialTimeout("tcp", addr, time.Second)
if errDialPop != nil {
t.Fatalf("failed to dial pop redis listener: %v", errDialPop)
}
t.Cleanup(func() { _ = popConn.Close() })
popReader := bufio.NewReader(popConn)
_ = popConn.SetDeadline(time.Now().Add(5 * time.Second))
if errWrite := writeTestRESPCommand(popConn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write pop AUTH command: %v", errWrite)
}
if msg, err := readTestRESPSimpleString(popReader); err != nil {
t.Fatalf("failed to read pop AUTH response: %v", err)
} else if msg != "OK" {
t.Fatalf("unexpected pop AUTH response: %q", msg)
}
if errWrite := writeTestRESPCommand(popConn, "LPOP", "usage"); errWrite != nil {
t.Fatalf("failed to write pop LPOP command: %v", errWrite)
}
item, errItem := readTestRESPBulkString(popReader)
if errItem != nil {
t.Fatalf("failed to read pop LPOP response: %v", errItem)
}
if item != nil {
t.Fatalf("expected subscribed usage to skip queue, got %q", string(item))
}
// The HTTP management endpoint must likewise report an empty JSON array.
managementReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=1", nil)
managementReq.Header.Set("Authorization", "Bearer "+managementPassword)
managementRR := httptest.NewRecorder()
server.engine.ServeHTTP(managementRR, managementReq)
if managementRR.Code != http.StatusOK {
t.Fatalf("management usage status = %d, want %d body=%s", managementRR.Code, http.StatusOK, managementRR.Body.String())
}
var managementPayload []json.RawMessage
if errUnmarshal := json.Unmarshal(managementRR.Body.Bytes(), &managementPayload); errUnmarshal != nil {
t.Fatalf("unmarshal management usage response: %v", errUnmarshal)
}
if len(managementPayload) != 0 {
t.Fatalf("expected management usage queue to be empty, got %s", managementRR.Body.String())
}
}
// TestRedisProtocol_IPBan_MirrorsManagementPolicy drives the Redis handler
// over an in-memory pipe whose server side reports the remote address
// 1.2.3.4:1234. It expects five unauthenticated LPOP commands to each receive
// a NOAUTH error and the sixth to receive an "IP banned" error instead.
func TestRedisProtocol_IPBan_MirrorsManagementPolicy(t *testing.T) {
const managementPassword = "test-management-password"
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
redisqueue.SetEnabled(false)
t.Cleanup(func() { redisqueue.SetEnabled(false) })
server := newTestServer(t)
if !server.managementRoutesEnabled.Load() {
t.Fatalf("expected managementRoutesEnabled to be true")
}
// net.Pipe gives a synchronous in-memory connection pair; no listener is used.
clientConn, serverConn := net.Pipe()
t.Cleanup(func() { _ = clientConn.Close() })
t.Cleanup(func() { _ = serverConn.Close() })
fakeRemote := &net.TCPAddr{
IP: net.ParseIP("1.2.3.4"),
Port: 1234,
}
// Wrap the server side so the handler sees fakeRemote as the peer address.
wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote}
go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn))
reader := bufio.NewReader(clientConn)
_ = clientConn.SetDeadline(time.Now().Add(5 * time.Second))
// Five unauthenticated attempts, each expected to be answered with NOAUTH.
for i := 0; i < 5; i++ {
if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
} else if msg != "NOAUTH Authentication required." {
t.Fatalf("unexpected LPOP NOAUTH error at attempt %d: %q", i+1, msg)
}
}
// The sixth attempt is expected to hit the ban; only the message prefix is
// compared because the remainder of the message is not fixed by this test.
if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil {
t.Fatalf("failed to write LPOP command after failures: %v", errWrite)
}
msg, err := readTestRESPError(reader)
if err != nil {
t.Fatalf("failed to read LPOP banned error: %v", err)
}
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
t.Fatalf("unexpected LPOP banned error: %q", msg)
}
}
// TestRedisProtocol_AUTH_IPBan_BlocksCorrectPasswordDuringBan drives the
// Redis handler over an in-memory pipe that reports the remote address
// 1.2.3.4:1234. It expects five wrong AUTH attempts to be rejected as invalid,
// the next two wrong attempts to be answered with an "IP banned" error, and a
// following AUTH with the correct password to be answered with the same ban
// error rather than OK.
func TestRedisProtocol_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) {
const managementPassword = "test-management-password"
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
redisqueue.SetEnabled(false)
t.Cleanup(func() { redisqueue.SetEnabled(false) })
server := newTestServer(t)
if !server.managementRoutesEnabled.Load() {
t.Fatalf("expected managementRoutesEnabled to be true")
}
clientConn, serverConn := net.Pipe()
t.Cleanup(func() { _ = clientConn.Close() })
t.Cleanup(func() { _ = serverConn.Close() })
fakeRemote := &net.TCPAddr{
IP: net.ParseIP("1.2.3.4"),
Port: 1234,
}
// Wrap the server side so the handler sees fakeRemote as the peer address.
wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote}
go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn))
reader := bufio.NewReader(clientConn)
_ = clientConn.SetDeadline(time.Now().Add(5 * time.Second))
// Phase 1: five wrong passwords, each rejected as an invalid key.
for i := 0; i < 5; i++ {
if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil {
t.Fatalf("failed to write AUTH command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read AUTH error: %v", err)
} else if msg != "ERR invalid management key" {
t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg)
}
}
// Phase 2: further wrong passwords are expected to report the ban
// (attempts 6 and 7, hence i+6 in the failure message).
for i := 0; i < 2; i++ {
if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil {
t.Fatalf("failed to write AUTH command after failures: %v", errWrite)
}
msg, err := readTestRESPError(reader)
if err != nil {
t.Fatalf("failed to read AUTH banned error: %v", err)
}
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
t.Fatalf("unexpected AUTH banned error at attempt %d: %q", i+6, msg)
}
}
// Phase 3: even the correct password must be refused while the ban is active.
if errWrite := writeTestRESPCommand(clientConn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write AUTH command with correct password: %v", errWrite)
}
msg, err := readTestRESPError(reader)
if err != nil {
t.Fatalf("failed to read AUTH banned error for correct password: %v", err)
}
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
t.Fatalf("unexpected AUTH banned error for correct password: %q", msg)
}
}
// TestRedisProtocol_LOCALHOST_AUTH_IPBan_BlocksCorrectPasswordDuringBan
// repeats the AUTH ban scenario over a real TCP connection to the address
// returned by startRedisMuxListener. It expects five wrong AUTH attempts to be
// rejected as invalid and a following AUTH with the correct password to be
// answered with an "IP banned" error.
// NOTE(review): the name implies the client is seen as localhost; the listen
// address is chosen by startRedisMuxListener, which is not visible here.
func TestRedisProtocol_LOCALHOST_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) {
const managementPassword = "test-management-password"
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
redisqueue.SetEnabled(false)
t.Cleanup(func() { redisqueue.SetEnabled(false) })
server := newTestServer(t)
if !server.managementRoutesEnabled.Load() {
t.Fatalf("expected managementRoutesEnabled to be true")
}
addr, stop := startRedisMuxListener(t, server)
t.Cleanup(stop)
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
if errDial != nil {
t.Fatalf("failed to dial redis listener: %v", errDial)
}
t.Cleanup(func() { _ = conn.Close() })
reader := bufio.NewReader(conn)
// Bound every read/write so a missing reply fails the test instead of hanging.
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
// Five wrong passwords, each rejected as an invalid key.
for i := 0; i < 5; i++ {
if errWrite := writeTestRESPCommand(conn, "AUTH", "wrong-password"); errWrite != nil {
t.Fatalf("failed to write AUTH command: %v", errWrite)
}
if msg, err := readTestRESPError(reader); err != nil {
t.Fatalf("failed to read AUTH error: %v", err)
} else if msg != "ERR invalid management key" {
t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg)
}
}
// The correct password on the sixth attempt must be refused with the ban error.
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
t.Fatalf("failed to write AUTH command with correct password: %v", errWrite)
}
msg, err := readTestRESPError(reader)
if err != nil {
t.Fatalf("failed to read AUTH banned error for correct password: %v", err)
}
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
t.Fatalf("unexpected AUTH banned error for correct password: %q", msg)
}
}
-3
View File
@@ -217,9 +217,6 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// Create gin engine
engine := gin.New()
if errSetTrustedProxies := engine.SetTrustedProxies(nil); errSetTrustedProxies != nil {
log.Warnf("failed to disable trusted proxy headers: %v", errSetTrustedProxies)
}
if optionState.engineConfigurator != nil {
optionState.engineConfigurator(engine)
}
+24 -26
View File
@@ -21,10 +21,6 @@ import (
)
func newTestServer(t *testing.T) *Server {
return newTestServerWithOptions(t)
}
func newTestServerWithOptions(t *testing.T, opts ...ServerOption) *Server {
t.Helper()
gin.SetMode(gin.TestMode)
@@ -50,7 +46,7 @@ func newTestServerWithOptions(t *testing.T, opts ...ServerOption) *Server {
accessManager := sdkaccess.NewManager()
configPath := filepath.Join(tmpDir, "config.yaml")
return NewServer(cfg, authManager, accessManager, configPath, opts...)
return NewServer(cfg, authManager, accessManager, configPath)
}
func TestHealthz(t *testing.T) {
@@ -152,26 +148,6 @@ func TestManagementUsageRequiresManagementAuthAndPopsArray(t *testing.T) {
}
}
// TestManagementLocalPasswordRejectsSpoofedForwardedFor sends a management
// request whose RemoteAddr is 203.0.113.10 but whose X-Forwarded-For header
// claims 127.0.0.1, authenticated with the local management password. The
// server must answer 403 with a body mentioning "remote management disabled".
func TestManagementLocalPasswordRejectsSpoofedForwardedFor(t *testing.T) {
	t.Setenv("MANAGEMENT_PASSWORD", "")
	server := newTestServerWithOptions(t, WithLocalManagementPassword("test-local-key"))

	request := httptest.NewRequest(http.MethodGet, "/v0/management/config", nil)
	request.RemoteAddr = "203.0.113.10:45678"
	request.Header.Set("Authorization", "Bearer test-local-key")
	request.Header.Set("X-Forwarded-For", "127.0.0.1")

	recorder := httptest.NewRecorder()
	server.engine.ServeHTTP(recorder, request)

	if got, want := recorder.Code, http.StatusForbidden; got != want {
		t.Fatalf("status = %d, want %d body=%s", got, want, recorder.Body.String())
	}
	if body := recorder.Body.String(); !strings.Contains(body, "remote management disabled") {
		t.Fatalf("body = %q, want remote management disabled", body)
	}
}
func TestHomeEnabledHidesManagementEndpointsAndControlPanel(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "test-management-key")
@@ -287,7 +263,7 @@ func TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) {
DisplayName: "Custom Codex Model",
Description: "Custom model from registry",
ContextLength: 123456,
Thinking: &registry.ThinkingSupport{Levels: []string{"low", "medium"}},
Thinking: &registry.ThinkingSupport{Levels: []string{"none", "minimal", "low", "medium", "unsupported", "high", "xhigh"}},
},
{ID: "grok-imagine-image-quality", Object: "model", OwnedBy: "xai", Type: "openai"},
{ID: "gpt-image-2", Object: "model", OwnedBy: "openai", Type: "openai"},
@@ -358,6 +334,7 @@ func TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) {
if got, _ := custom["context_window"].(float64); got != 123456 {
t.Fatalf("custom context_window = %v, want 123456", custom["context_window"])
}
assertCodexSupportedReasoningLevels(t, custom, []string{"none", "low", "medium", "high", "xhigh"})
if custom["base_instructions"] != gpt55["base_instructions"] {
t.Fatal("expected custom model to use gpt-5.5 base_instructions fallback")
}
@@ -400,6 +377,27 @@ func TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) {
}
}
// assertCodexSupportedReasoningLevels fails the test unless model carries a
// "supported_reasoning_levels" array of objects whose "effort" values equal
// want, element by element and in the same order.
func assertCodexSupportedReasoningLevels(t *testing.T, model map[string]any, want []string) {
	t.Helper()

	entries, isList := model["supported_reasoning_levels"].([]any)
	if !isList {
		t.Fatalf("expected supported_reasoning_levels, got %#v", model["supported_reasoning_levels"])
	}
	if gotLen, wantLen := len(entries), len(want); gotLen != wantLen {
		t.Fatalf("supported_reasoning_levels length = %d, want %d: %#v", gotLen, wantLen, entries)
	}

	for i := range entries {
		entry, isObject := entries[i].(map[string]any)
		if !isObject {
			t.Fatalf("supported_reasoning_levels[%d] = %#v, want object", i, entries[i])
		}
		// A missing or non-string effort compares as the empty string.
		effort, _ := entry["effort"].(string)
		if effort != want[i] {
			t.Fatalf("supported_reasoning_levels[%d].effort = %q, want %q", i, effort, want[i])
		}
	}
}
func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
t.Setenv("WRITABLE_PATH", "")
t.Setenv("writable_path", "")
+84 -56
View File
@@ -48,10 +48,76 @@ func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *Antigravit
}
}
func (o *AntigravityAuth) loadCodeAssistUserAgent() string {
// shortUserAgent returns the User-Agent string produced by
// misc.AntigravityRequestUserAgent when called with an empty argument.
func (o *AntigravityAuth) shortUserAgent() string {
	const empty = ""
	return misc.AntigravityRequestUserAgent(empty)
}
// nodeUserAgent returns the User-Agent string produced by
// misc.AntigravityLoadCodeAssistUserAgent when called with an empty argument.
func (o *AntigravityAuth) nodeUserAgent() string {
	const empty = ""
	return misc.AntigravityLoadCodeAssistUserAgent(empty)
}
// antigravityLoadCodeAssistMetadata builds the "metadata" object sent with a
// loadCodeAssist request. A fresh map is returned on every call.
func antigravityLoadCodeAssistMetadata() map[string]string {
	metadata := make(map[string]string, 1)
	metadata["ideType"] = "ANTIGRAVITY"
	return metadata
}
// antigravityControlPlaneMetadata builds the snake_case "metadata" object used
// by the onboardUser request. The ide_version entry is derived from userAgent
// via misc.AntigravityVersionFromUserAgent.
func antigravityControlPlaneMetadata(userAgent string) map[string]string {
	version := misc.AntigravityVersionFromUserAgent(userAgent)

	metadata := make(map[string]string, 3)
	metadata["ide_type"] = "ANTIGRAVITY"
	metadata["ide_version"] = version
	metadata["ide_name"] = "antigravity"
	return metadata
}
// extractCloudaicompanionProject pulls a project identifier out of a decoded
// JSON object. It inspects the keys "cloudaicompanionProject", "projectId" and
// "project" in that order; each value may be either a string or an object with
// a string "id" field. The first non-blank identifier is returned with
// surrounding whitespace removed, or "" when none is found.
func extractCloudaicompanionProject(data map[string]any) string {
	if data == nil {
		return ""
	}

	candidateKeys := [...]string{"cloudaicompanionProject", "projectId", "project"}
	for i := range candidateKeys {
		var candidate string
		if text, isString := data[candidateKeys[i]].(string); isString {
			candidate = text
		} else if object, isObject := data[candidateKeys[i]].(map[string]any); isObject {
			// A missing or non-string "id" leaves candidate empty.
			candidate, _ = object["id"].(string)
		} else {
			continue
		}

		if candidate = strings.TrimSpace(candidate); candidate != "" {
			return candidate
		}
	}
	return ""
}
// defaultAntigravityTierID chooses the tier ID to onboard with from a decoded
// loadCodeAssist response. Preference order:
//  1. the first "allowedTiers" entry with isDefault == true and a non-blank id,
//  2. the id of "currentTier",
//  3. the literal "free-tier".
func defaultAntigravityTierID(loadResp map[string]any) string {
	// trimmedID returns the whitespace-trimmed "id" of a JSON object, or ""
	// when raw is not an object or has no string id.
	trimmedID := func(raw any) string {
		object, isObject := raw.(map[string]any)
		if !isObject {
			return ""
		}
		id, _ := object["id"].(string)
		return strings.TrimSpace(id)
	}

	allowed, _ := loadResp["allowedTiers"].([]any)
	for _, raw := range allowed {
		tier, isObject := raw.(map[string]any)
		if !isObject {
			continue
		}
		if flagged, _ := tier["isDefault"].(bool); !flagged {
			continue
		}
		if id := trimmedID(tier); id != "" {
			return id
		}
	}

	if id := trimmedID(loadResp["currentTier"]); id != "" {
		return id
	}
	return "free-tier"
}
// BuildAuthURL generates the OAuth authorization URL.
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
if strings.TrimSpace(redirectURI) == "" {
@@ -123,7 +189,7 @@ func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string)
return "", fmt.Errorf("antigravity userinfo: create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("User-Agent", o.loadCodeAssistUserAgent())
req.Header.Set("User-Agent", o.shortUserAgent())
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
@@ -159,13 +225,9 @@ func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string)
// FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
userAgent := o.loadCodeAssistUserAgent()
userAgent := o.shortUserAgent()
loadReqBody := map[string]any{
"metadata": map[string]string{
"ide_type": "ANTIGRAVITY",
"ide_version": misc.AntigravityVersionFromUserAgent(userAgent),
"ide_name": "antigravity",
},
"metadata": antigravityLoadCodeAssistMetadata(),
}
rawBody, errMarshal := json.Marshal(loadReqBody)
@@ -179,9 +241,9 @@ func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "*/*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
req.Header.Set("X-Goog-Api-Client", misc.AntigravityGoogAPIClientUA)
resp, errDo := o.httpClient.Do(req)
if errDo != nil {
@@ -207,40 +269,16 @@ func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string
return "", fmt.Errorf("decode response: %w", errDecode)
}
// Extract projectID from response
projectID := ""
if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
projectID = strings.TrimSpace(id)
}
if projectID == "" {
if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
if id, okID := projectMap["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
}
}
projectID := extractCloudaicompanionProject(loadResp)
if projectID == "" {
tierID := "legacy-tier"
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
for _, rawTier := range tiers {
tier, okTier := rawTier.(map[string]any)
if !okTier {
continue
}
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
tierID = strings.TrimSpace(id)
break
}
}
}
}
projectID, err = o.OnboardUser(ctx, accessToken, tierID)
projectID, err = o.OnboardUser(ctx, accessToken, defaultAntigravityTierID(loadResp))
if err != nil {
return "", err
}
if projectID == "" {
return "", fmt.Errorf("project id not found in loadCodeAssist or onboardUser response")
}
return projectID, nil
}
@@ -250,14 +288,10 @@ func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string
// OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
log.Infof("Antigravity: onboarding user with tier: %s", tierID)
userAgent := o.loadCodeAssistUserAgent()
userAgent := o.nodeUserAgent()
requestBody := map[string]any{
"tierId": tierID,
"metadata": map[string]string{
"ide_type": "ANTIGRAVITY",
"ide_version": misc.AntigravityVersionFromUserAgent(userAgent),
"ide_name": "antigravity",
},
"tier_id": tierID,
"metadata": antigravityControlPlaneMetadata(userAgent),
}
rawBody, errMarshal := json.Marshal(requestBody)
@@ -276,13 +310,14 @@ func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID s
}
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
endpointURL := fmt.Sprintf("%s/%s:onboardUser", DailyAPIEndpoint, APIVersion)
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if errRequest != nil {
cancel()
return "", fmt.Errorf("create request: %w", errRequest)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "*/*")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", userAgent)
req.Header.Set("X-Goog-Api-Client", misc.AntigravityGoogAPIClientUA)
@@ -312,18 +347,11 @@ func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID s
if done, okDone := data["done"].(bool); okDone && done {
projectID := ""
if responseData, okResp := data["response"].(map[string]any); okResp {
switch projectValue := responseData["cloudaicompanionProject"].(type) {
case map[string]any:
if id, okID := projectValue["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
case string:
projectID = strings.TrimSpace(projectValue)
}
projectID = extractCloudaicompanionProject(responseData)
}
if projectID != "" {
log.Infof("Successfully fetched project_id: %s", projectID)
log.Infof("Successfully fetched project_id: %s", util.HideAPIKey(projectID))
return projectID, nil
}
@@ -346,5 +374,5 @@ func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID s
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
}
return "", nil
return "", fmt.Errorf("onboard user did not complete after %d attempts", maxAttempts)
}
+127
View File
@@ -0,0 +1,127 @@
package antigravity
import (
"context"
"io"
"net/http"
"strings"
"testing"
)
// roundTripperFunc adapts a plain function to http.RoundTripper so tests can
// stub the HTTP transport inline.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by invoking the wrapped function.
func (fn roundTripperFunc) RoundTrip(request *http.Request) (*http.Response, error) {
	response, err := fn(request)
	return response, err
}
func TestFetchProjectIDFromLoadCodeAssist(t *testing.T) {
auth := NewAntigravityAuth(nil, &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" {
t.Fatalf("unexpected request URL: %s", req.URL.String())
}
assertLoadCodeAssistHeaders(t, req)
assertJSONContains(t, req, `"ideType":"ANTIGRAVITY"`)
return jsonResponse(`{"cloudaicompanionProject":"cogent-snow-4mnnp"}`), nil
})})
projectID, err := auth.FetchProjectID(context.Background(), "access-token")
if err != nil {
t.Fatalf("FetchProjectID error: %v", err)
}
if projectID != "cogent-snow-4mnnp" {
t.Fatalf("projectID = %q", projectID)
}
}
func TestFetchProjectIDFallsBackToDailyOnboardUser(t *testing.T) {
var sawOnboard bool
auth := NewAntigravityAuth(nil, &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
switch req.URL.String() {
case "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist":
assertLoadCodeAssistHeaders(t, req)
return jsonResponse(`{"allowedTiers":[{"id":"free-tier","isDefault":true}]}`), nil
case "https://daily-cloudcode-pa.googleapis.com/v1internal:onboardUser":
sawOnboard = true
assertOnboardUserHeaders(t, req)
assertJSONContains(t, req, `"tier_id":"free-tier"`)
assertJSONContains(t, req, `"ide_type":"ANTIGRAVITY"`)
return jsonResponse(`{
"done": true,
"response": {
"cloudaicompanionProject": {
"id": "cogent-snow-4mnnp",
"name": "cogent-snow-4mnnp",
"projectNumber": "22597072101"
}
}
}`), nil
default:
t.Fatalf("unexpected request URL: %s", req.URL.String())
return nil, nil
}
})})
projectID, err := auth.FetchProjectID(context.Background(), "access-token")
if err != nil {
t.Fatalf("FetchProjectID error: %v", err)
}
if !sawOnboard {
t.Fatalf("expected onboardUser fallback")
}
if projectID != "cogent-snow-4mnnp" {
t.Fatalf("projectID = %q", projectID)
}
}
// assertLoadCodeAssistHeaders checks the headers expected on a loadCodeAssist
// request: the bearer token, a wildcard Accept, no X-Goog-Api-Client header,
// and a User-Agent that does not advertise google-api-nodejs-client.
func assertLoadCodeAssistHeaders(t *testing.T, req *http.Request) {
	t.Helper()

	headers := req.Header
	switch {
	case headers.Get("Authorization") != "Bearer access-token":
		t.Fatalf("Authorization = %q", headers.Get("Authorization"))
	case headers.Get("Accept") != "*/*":
		t.Fatalf("Accept = %q", headers.Get("Accept"))
	case headers.Get("X-Goog-Api-Client") != "":
		t.Fatalf("X-Goog-Api-Client = %q, want empty", headers.Get("X-Goog-Api-Client"))
	case strings.Contains(headers.Get("User-Agent"), "google-api-nodejs-client/"):
		t.Fatalf("User-Agent = %q", headers.Get("User-Agent"))
	}
}
// assertOnboardUserHeaders checks the headers expected on an onboardUser
// request: the bearer token, a wildcard Accept, the gl-node X-Goog-Api-Client
// value, and a User-Agent naming google-api-nodejs-client/10.3.0.
func assertOnboardUserHeaders(t *testing.T, req *http.Request) {
	t.Helper()

	headers := req.Header
	switch {
	case headers.Get("Authorization") != "Bearer access-token":
		t.Fatalf("Authorization = %q", headers.Get("Authorization"))
	case headers.Get("Accept") != "*/*":
		t.Fatalf("Accept = %q", headers.Get("Accept"))
	case headers.Get("X-Goog-Api-Client") != "gl-node/22.21.1":
		t.Fatalf("X-Goog-Api-Client = %q", headers.Get("X-Goog-Api-Client"))
	case !strings.Contains(headers.Get("User-Agent"), "google-api-nodejs-client/10.3.0"):
		t.Fatalf("User-Agent = %q", headers.Get("User-Agent"))
	}
}
// assertJSONContains fails the test unless the request body contains the
// literal substring want. The body is fully consumed and then replaced with a
// fresh reader over the same text, so it can be read again afterwards.
func assertJSONContains(t *testing.T, req *http.Request, want string) {
	t.Helper()

	raw, errRead := io.ReadAll(req.Body)
	if errRead != nil {
		t.Fatalf("read body: %v", errRead)
	}
	text := string(raw)

	// Restore the body before asserting so later readers still see it.
	req.Body = io.NopCloser(strings.NewReader(text))

	if strings.Contains(text, want) {
		return
	}
	t.Fatalf("body missing %s: %s", want, text)
}
// jsonResponse builds a 200 OK *http.Response with an empty header set and
// the given text as its body.
func jsonResponse(body string) *http.Response {
	resp := new(http.Response)
	resp.StatusCode = http.StatusOK
	resp.Header = http.Header{}
	resp.Body = io.NopCloser(strings.NewReader(body))
	return resp
}
+3 -2
View File
@@ -26,6 +26,7 @@ const (
// Antigravity API configuration
const (
APIEndpoint = "https://cloudcode-pa.googleapis.com"
APIVersion = "v1internal"
APIEndpoint = "https://cloudcode-pa.googleapis.com"
DailyAPIEndpoint = "https://daily-cloudcode-pa.googleapis.com"
APIVersion = "v1internal"
)
+1 -1
View File
@@ -34,7 +34,7 @@ func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
if cfg != nil {
proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
if errBuild != nil {
log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
log.Errorf("failed to configure proxy dialer for %q: %v", proxyutil.Redact(cfg.ProxyURL), errBuild)
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
dialer = proxyDialer
}
+7 -4
View File
@@ -37,8 +37,8 @@ type Config struct {
// TLS config controls HTTPS server settings.
TLS TLSConfig `yaml:"tls" json:"tls"`
// Home config enables the Redis-based control plane integration.
Home HomeConfig `yaml:"home" json:"-"`
// Home config is runtime-only and is populated from -home-jwt.
Home HomeConfig `yaml:"-" json:"-"`
// RemoteManagement nests management-related options under 'remote-management'.
RemoteManagement RemoteManagement `yaml:"remote-management" json:"-"`
@@ -69,8 +69,8 @@ type Config struct {
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
// RedisUsageQueueRetentionSeconds controls how long (in seconds) usage queue items
// are retained in memory for the Redis RESP interface (LPOP/RPOP).
// RedisUsageQueueRetentionSeconds controls how long usage queue items are retained
// in memory for Management API consumers.
// Default: 60. Max: 3600.
RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"`
@@ -585,6 +585,9 @@ type OpenAICompatibilityModel struct {
// Alias is the model name alias that clients will use to reference this model.
Alias string `yaml:"alias" json:"alias"`
// Image marks this model as callable through /v1/images/generations and /v1/images/edits.
Image bool `yaml:"image,omitempty" json:"image,omitempty"`
// Thinking configures the thinking/reasoning capability for this model.
// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`
+8 -6
View File
@@ -1,19 +1,21 @@
package config
// HomeConfig configures the optional "home" control plane integration over Redis protocol.
// HomeConfig stores runtime-only Home control plane settings from -home-jwt.
type HomeConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
Host string `yaml:"host" json:"-"`
Port int `yaml:"port" json:"-"`
Password string `yaml:"password" json:"-"`
DisableClusterDiscovery bool `yaml:"disable-cluster-discovery" json:"-"`
TLS HomeTLSConfig `yaml:"tls" json:"-"`
}
// HomeTLSConfig configures client-side TLS for the home Redis connection.
type HomeTLSConfig struct {
Enable bool `yaml:"enable" json:"-"`
ServerName string `yaml:"server-name" json:"-"`
InsecureSkipVerify bool `yaml:"insecure-skip-verify" json:"-"`
CACert string `yaml:"ca-cert" json:"-"`
Enable bool `yaml:"enable" json:"-"`
ServerName string `yaml:"server-name" json:"-"`
InsecureSkipVerify bool `yaml:"insecure-skip-verify" json:"-"`
CACert string `yaml:"ca-cert" json:"-"`
ClientCert string `yaml:"-" json:"-"`
ClientKey string `yaml:"-" json:"-"`
UseTargetServerName bool `yaml:"-" json:"-"`
}
+17 -21
View File
@@ -2,13 +2,12 @@ package config
import "testing"
func TestParseConfigBytesHomeTLS(t *testing.T) {
func TestParseConfigBytesIgnoresHomeConfig(t *testing.T) {
cfg, err := ParseConfigBytes([]byte(`
home:
enabled: true
host: home.example.com
port: 444
password: secret
disable-cluster-discovery: true
tls:
enable: true
@@ -20,31 +19,28 @@ home:
t.Fatalf("ParseConfigBytes() error = %v", err)
}
if !cfg.Home.Enabled {
t.Fatal("Home.Enabled = false, want true")
if cfg.Home.Enabled {
t.Fatal("Home.Enabled = true, want false")
}
if cfg.Home.Host != "home.example.com" {
t.Fatalf("Home.Host = %q, want home.example.com", cfg.Home.Host)
if cfg.Home.Host != "" {
t.Fatalf("Home.Host = %q, want empty", cfg.Home.Host)
}
if cfg.Home.Port != 444 {
t.Fatalf("Home.Port = %d, want 444", cfg.Home.Port)
if cfg.Home.Port != 0 {
t.Fatalf("Home.Port = %d, want 0", cfg.Home.Port)
}
if cfg.Home.Password != "secret" {
t.Fatalf("Home.Password = %q, want secret", cfg.Home.Password)
if cfg.Home.DisableClusterDiscovery {
t.Fatal("Home.DisableClusterDiscovery = true, want false")
}
if !cfg.Home.DisableClusterDiscovery {
t.Fatal("Home.DisableClusterDiscovery = false, want true")
if cfg.Home.TLS.Enable {
t.Fatal("Home.TLS.Enable = true, want false")
}
if !cfg.Home.TLS.Enable {
t.Fatal("Home.TLS.Enable = false, want true")
if cfg.Home.TLS.ServerName != "" {
t.Fatalf("Home.TLS.ServerName = %q, want empty", cfg.Home.TLS.ServerName)
}
if cfg.Home.TLS.ServerName != "home.example.com" {
t.Fatalf("Home.TLS.ServerName = %q, want home.example.com", cfg.Home.TLS.ServerName)
if cfg.Home.TLS.CACert != "" {
t.Fatalf("Home.TLS.CACert = %q, want empty", cfg.Home.TLS.CACert)
}
if cfg.Home.TLS.CACert != "C:/certs/ca.pem" {
t.Fatalf("Home.TLS.CACert = %q, want C:/certs/ca.pem", cfg.Home.TLS.CACert)
}
if !cfg.Home.TLS.InsecureSkipVerify {
t.Fatal("Home.TLS.InsecureSkipVerify = false, want true")
if cfg.Home.TLS.InsecureSkipVerify {
t.Fatal("Home.TLS.InsecureSkipVerify = true, want false")
}
}
+386
View File
@@ -0,0 +1,386 @@
package home
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
)
// homeCertificateRequestTimeout bounds a certificate enrollment exchange: it is
// the timeout of the dial context, whose deadline is also applied to the
// connection's reads and writes.
const homeCertificateRequestTimeout = 30 * time.Second
// homeJWTClaims is the payload decoded from the second segment of the Home JWT.
// parseHomeJWTClaims requires CertificateID, ClusterID, CAFingerprint,
// EnrollmentSecret, IP and a positive Port; IssuedAt is decoded but not
// validated there.
type homeJWTClaims struct {
// CertificateID is used as the CSR common name and sent with the enrollment request.
CertificateID string `json:"certificate_id"`
ClusterID string `json:"cluster_id"`
// CAFingerprint is compared (after normalization) with the SHA-256 of the CA certificate.
CAFingerprint string `json:"ca_fingerprint"`
// EnrollmentSecret is sent with the certificate request.
EnrollmentSecret string `json:"enrollment_secret"`
// IP and Port form the address dialed for enrollment and stored as the Home host/port.
IP string `json:"ip"`
Port int `json:"port"`
IssuedAt int64 `json:"iat"`
}
// certificateRequestResponse is the JSON document returned for a certificate
// request. OK must be true for the request to be treated as successful;
// Certificate and CA are written to disk as the client certificate and CA
// certificate files.
type certificateRequestResponse struct {
OK bool `json:"ok"`
Certificate string `json:"certificate"`
CA string `json:"ca"`
}
// certificatePaths groups the on-disk locations of the Home mTLS material:
// the containing directory, the client certificate, the client private key
// and the CA certificate.
type certificatePaths struct {
Dir string
ClientCert string
ClientKey string
CACert string
}
// ConfigFromJWT builds the runtime Home configuration from a Home JWT.
//
// It decodes and validates the JWT claims, resolves the default certificate
// locations, and makes sure the client certificate, key and CA certificate
// exist on disk (requesting them when needed). The returned config is enabled,
// targets the host and port from the claims, and has TLS enabled with the
// resolved certificate paths and UseTargetServerName set. On any failure the
// zero HomeConfig and the underlying error are returned.
func ConfigFromJWT(ctx context.Context, rawJWT string) (config.HomeConfig, error) {
	var none config.HomeConfig

	claims, errClaims := parseHomeJWTClaims(rawJWT)
	if errClaims != nil {
		return none, errClaims
	}

	paths, errPaths := defaultCertificatePaths()
	if errPaths != nil {
		return none, errPaths
	}

	errEnsure := ensureHomeCertificateFiles(ctx, claims, paths)
	if errEnsure != nil {
		return none, errEnsure
	}

	tlsSettings := config.HomeTLSConfig{
		Enable:              true,
		CACert:              paths.CACert,
		ClientCert:          paths.ClientCert,
		ClientKey:           paths.ClientKey,
		UseTargetServerName: true,
	}
	homeCfg := config.HomeConfig{
		Enabled: true,
		Host:    strings.TrimSpace(claims.IP),
		Port:    claims.Port,
		TLS:     tlsSettings,
	}
	return homeCfg, nil
}
// parseHomeJWTClaims decodes the payload (second segment) of a three-segment
// JWT into homeJWTClaims and validates the required fields.
//
// The token signature is NOT verified here; only the payload is decoded.
//
// It returns an error when the token does not have exactly three segments, the
// payload is not valid base64url or JSON, any of certificate_id, cluster_id,
// ca_fingerprint or enrollment_secret is blank, the ip is blank, or the port
// is outside 1-65535. On error the zero homeJWTClaims is returned.
func parseHomeJWTClaims(rawJWT string) (homeJWTClaims, error) {
	const maxPort = 65535

	parts := strings.Split(strings.TrimSpace(rawJWT), ".")
	if len(parts) != 3 {
		return homeJWTClaims{}, fmt.Errorf("home jwt is invalid")
	}

	payload, errDecode := decodeJWTPart(parts[1])
	if errDecode != nil {
		return homeJWTClaims{}, fmt.Errorf("home jwt payload decode: %w", errDecode)
	}

	var claims homeJWTClaims
	if errUnmarshal := json.Unmarshal(payload, &claims); errUnmarshal != nil {
		return homeJWTClaims{}, fmt.Errorf("home jwt payload parse: %w", errUnmarshal)
	}

	switch {
	case strings.TrimSpace(claims.CertificateID) == "":
		return homeJWTClaims{}, fmt.Errorf("home jwt certificate_id is required")
	case strings.TrimSpace(claims.ClusterID) == "":
		return homeJWTClaims{}, fmt.Errorf("home jwt cluster_id is required")
	case normalizeFingerprint(claims.CAFingerprint) == "":
		return homeJWTClaims{}, fmt.Errorf("home jwt ca_fingerprint is required")
	case strings.TrimSpace(claims.EnrollmentSecret) == "":
		return homeJWTClaims{}, fmt.Errorf("home jwt enrollment_secret is required")
	case strings.TrimSpace(claims.IP) == "", claims.Port <= 0, claims.Port > maxPort:
		// A port above 65535 can never be dialed; reject it up front rather
		// than failing later with a less specific network error.
		return homeJWTClaims{}, fmt.Errorf("home jwt target address is invalid")
	}
	return claims, nil
}
// decodeJWTPart decodes one base64url JWT segment. Unpadded encoding is tried
// first; if that fails, the padded URL-safe encoding is tried and its result
// (including any error) is returned.
func decodeJWTPart(part string) ([]byte, error) {
	unpadded, errUnpadded := base64.RawURLEncoding.DecodeString(part)
	if errUnpadded != nil {
		return base64.URLEncoding.DecodeString(part)
	}
	return unpadded, nil
}
// defaultCertificatePaths returns the certificate locations under
// ".cli-proxy-api" inside the current user's home directory. It fails only
// when the home directory cannot be determined.
func defaultCertificatePaths() (certificatePaths, error) {
	userHome, errHome := os.UserHomeDir()
	if errHome != nil {
		return certificatePaths{}, errHome
	}

	var paths certificatePaths
	paths.Dir = filepath.Join(userHome, ".cli-proxy-api")
	paths.ClientCert = filepath.Join(paths.Dir, "client-crt.pem")
	paths.ClientKey = filepath.Join(paths.Dir, "client-key.pem")
	paths.CACert = filepath.Join(paths.Dir, "home-ca-crt.pem")
	return paths, nil
}
// ensureHomeCertificateFiles makes sure the client certificate, client key and
// CA certificate exist at the given paths.
//
// When both the client certificate and key are already present, no network
// request is made: the CA file must exist and match claims.CAFingerprint, and
// the three files are re-chmodded to 0600.
//
// Otherwise the directory is created (0700), the client key is loaded or
// generated, a CSR is built for claims.CertificateID and sent to the Home
// server, and the returned certificate and CA are written with mode 0600 after
// the CA fingerprint has been verified.
func ensureHomeCertificateFiles(ctx context.Context, claims homeJWTClaims, paths certificatePaths) error {
// Fast path: certificate and key already on disk.
if fileExists(paths.ClientCert) && fileExists(paths.ClientKey) {
if !fileExists(paths.CACert) {
return fmt.Errorf("home ca certificate file is missing")
}
// The stored CA must still match the fingerprint pinned in the JWT.
if errVerify := verifyCACertificateFile(paths.CACert, claims.CAFingerprint); errVerify != nil {
return errVerify
}
if errChmod := chmodCertificateFiles(paths); errChmod != nil {
return errChmod
}
return nil
}
// Enrollment path: create the directory and key, then request a certificate.
if errMkdir := os.MkdirAll(paths.Dir, 0o700); errMkdir != nil {
return errMkdir
}
// An existing key file is reused; otherwise a new key is generated and saved.
key, errKey := loadOrCreateClientKey(paths.ClientKey)
if errKey != nil {
return errKey
}
csrPEM, errCSR := createClientCSR(claims.CertificateID, key)
if errCSR != nil {
return errCSR
}
response, errRequest := requestClientCertificate(ctx, claims, csrPEM)
if errRequest != nil {
return errRequest
}
if strings.TrimSpace(response.Certificate) == "" || strings.TrimSpace(response.CA) == "" {
return fmt.Errorf("home certificate response is incomplete")
}
// Verify the returned CA against the pinned fingerprint before anything is written.
if errVerify := verifyCACertificatePEM([]byte(response.CA), claims.CAFingerprint); errVerify != nil {
return errVerify
}
if errWrite := writeFile0600(paths.ClientCert, []byte(response.Certificate)); errWrite != nil {
return errWrite
}
if errWrite := writeFile0600(paths.CACert, []byte(response.CA)); errWrite != nil {
return errWrite
}
return nil
}
// verifyCACertificateFile reads the PEM file at path and checks that its
// certificate matches expectedFingerprint. Read errors are returned as-is.
func verifyCACertificateFile(path string, expectedFingerprint string) error {
	pemBytes, errRead := os.ReadFile(path)
	if errRead == nil {
		return verifyCACertificatePEM(pemBytes, expectedFingerprint)
	}
	return errRead
}
// verifyCACertificatePEM checks that the SHA-256 fingerprint of the first
// certificate in raw equals expectedFingerprint after normalization. It fails
// when the PEM cannot be parsed, when the expected fingerprint normalizes to
// an empty string, or when the two fingerprints differ.
func verifyCACertificatePEM(raw []byte, expectedFingerprint string) error {
	got, errFingerprint := certificateFingerprintPEM(raw)
	if errFingerprint != nil {
		return errFingerprint
	}

	switch want := normalizeFingerprint(expectedFingerprint); {
	case want == "":
		return fmt.Errorf("home ca fingerprint is required")
	case got != want:
		return fmt.Errorf("home ca fingerprint mismatch")
	default:
		return nil
	}
}
// certificateFingerprintPEM returns the lowercase hex SHA-256 digest of the
// DER encoding of the first PEM block in raw, which must be a CERTIFICATE
// block containing a parseable X.509 certificate.
func certificateFingerprintPEM(raw []byte) (string, error) {
	first, _ := pem.Decode(raw)
	if first == nil {
		return "", fmt.Errorf("home ca certificate pem is invalid")
	}
	if first.Type != "CERTIFICATE" {
		return "", fmt.Errorf("home ca certificate pem is invalid")
	}

	parsed, errParse := x509.ParseCertificate(first.Bytes)
	if errParse != nil {
		return "", errParse
	}

	digest := sha256.Sum256(parsed.Raw)
	return hex.EncodeToString(digest[:]), nil
}
// normalizeFingerprint canonicalizes a certificate fingerprint for comparison:
// it lowercases the text, trims surrounding whitespace, and removes every
// colon and space character.
func normalizeFingerprint(fingerprint string) string {
	lowered := strings.TrimSpace(strings.ToLower(fingerprint))
	separators := strings.NewReplacer(":", "", " ", "")
	return separators.Replace(lowered)
}
// loadOrCreateClientKey returns the RSA private key stored at path, creating
// one when no file exists there.
//
// An existing file is parsed and then chmodded to 0600. Otherwise a 2048-bit
// key is generated, PEM-encoded as "RSA PRIVATE KEY" (PKCS#1) and written to
// path with mode 0600.
func loadOrCreateClientKey(path string) (*rsa.PrivateKey, error) {
	if !fileExists(path) {
		generated, errGenerate := rsa.GenerateKey(rand.Reader, 2048)
		if errGenerate != nil {
			return nil, errGenerate
		}
		block := pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(generated)}
		if errWrite := writeFile0600(path, pem.EncodeToMemory(&block)); errWrite != nil {
			return nil, errWrite
		}
		return generated, nil
	}

	stored, errRead := os.ReadFile(path)
	if errRead != nil {
		return nil, errRead
	}
	existing, errParse := parseRSAPrivateKeyPEM(stored)
	if errParse != nil {
		return nil, errParse
	}
	// Tighten permissions on a key that may have been created elsewhere.
	if errChmod := os.Chmod(path, 0o600); errChmod != nil {
		return nil, errChmod
	}
	return existing, nil
}
// writeFile0600 writes raw to path and then sets the file mode to 0600. The
// explicit chmod covers files that already existed with different permissions.
func writeFile0600(path string, raw []byte) error {
	const ownerOnly os.FileMode = 0o600

	errWrite := os.WriteFile(path, raw, ownerOnly)
	if errWrite == nil {
		return os.Chmod(path, ownerOnly)
	}
	return errWrite
}
// chmodCertificateFiles sets mode 0600 on the client certificate, client key
// and CA certificate files, in that order, stopping at the first failure.
func chmodCertificateFiles(paths certificatePaths) error {
	targets := [...]string{paths.ClientCert, paths.ClientKey, paths.CACert}
	for i := range targets {
		errChmod := os.Chmod(targets[i], 0o600)
		if errChmod != nil {
			return errChmod
		}
	}
	return nil
}
// parseRSAPrivateKeyPEM decodes the first PEM block in raw into an RSA private
// key. "RSA PRIVATE KEY" blocks are parsed as PKCS#1 and "PRIVATE KEY" blocks
// as PKCS#8 (which must contain an RSA key); any other block type is rejected.
func parseRSAPrivateKeyPEM(raw []byte) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(raw)
	if block == nil {
		return nil, fmt.Errorf("client key pem is invalid")
	}

	if block.Type == "RSA PRIVATE KEY" {
		return x509.ParsePKCS1PrivateKey(block.Bytes)
	}
	if block.Type != "PRIVATE KEY" {
		return nil, fmt.Errorf("client key pem type %q is unsupported", block.Type)
	}

	parsed, errParse := x509.ParsePKCS8PrivateKey(block.Bytes)
	if errParse != nil {
		return nil, errParse
	}
	if rsaKey, isRSA := parsed.(*rsa.PrivateKey); isRSA {
		return rsaKey, nil
	}
	return nil, fmt.Errorf("client key is not rsa")
}
// createClientCSR returns a PEM-encoded PKCS#10 certificate request signed by
// key whose subject common name is the trimmed certificateID. A blank
// certificateID is rejected before the key is used.
func createClientCSR(certificateID string, key *rsa.PrivateKey) ([]byte, error) {
	commonName := strings.TrimSpace(certificateID)
	if commonName == "" {
		return nil, fmt.Errorf("certificate id is required")
	}

	request := x509.CertificateRequest{Subject: pkix.Name{CommonName: commonName}}
	der, errCreate := x509.CreateCertificateRequest(rand.Reader, &request, key)
	if errCreate != nil {
		return nil, errCreate
	}

	encoded := pem.Block{Type: "CERTIFICATE REQUEST", Bytes: der}
	return pem.EncodeToMemory(&encoded), nil
}
// requestClientCertificate sends a certificate enrollment request to the Home
// server named in claims and returns the decoded reply.
//
// It dials claims.IP:claims.Port over TCP, writes a single RESP array
// ("CERTIFICATE", "REQUEST", certificate ID, enrollment secret, CSR PEM),
// reads one RESP bulk reply and decodes it as JSON. The whole exchange is
// bounded by homeCertificateRequestTimeout. An error is returned when the
// dial, write, read or JSON decode fails, or when the reply's "ok" is false.
//
// NOTE(review): the connection is opened with a plain net.Dialer, so as
// written the enrollment secret travels without TLS — confirm this is
// acceptable for the deployment.
func requestClientCertificate(ctx context.Context, claims homeJWTClaims, csrPEM []byte) (certificateRequestResponse, error) {
var response certificateRequestResponse
// Tolerate a nil context from callers.
if ctx == nil {
ctx = context.Background()
}
dialCtx, cancel := context.WithTimeout(ctx, homeCertificateRequestTimeout)
defer cancel()
addr := net.JoinHostPort(strings.TrimSpace(claims.IP), strconv.Itoa(claims.Port))
conn, errDial := (&net.Dialer{}).DialContext(dialCtx, "tcp", addr)
if errDial != nil {
return response, errDial
}
defer func() {
_ = conn.Close()
}()
// Apply the context deadline to the connection so the write and read below cannot block past it.
if deadline, ok := dialCtx.Deadline(); ok {
_ = conn.SetDeadline(deadline)
}
if _, errWrite := conn.Write(encodeRESPArray("CERTIFICATE", "REQUEST", claims.CertificateID, claims.EnrollmentSecret, string(csrPEM))); errWrite != nil {
return response, errWrite
}
// The reply is expected to be one bulk string holding a JSON document.
raw, errRead := readRESPBulk(bufio.NewReader(conn))
if errRead != nil {
return response, errRead
}
if errUnmarshal := json.Unmarshal(raw, &response); errUnmarshal != nil {
return response, errUnmarshal
}
if !response.OK {
return response, fmt.Errorf("home certificate request failed")
}
return response, nil
}
// encodeRESPArray serializes args as a RESP array of bulk strings:
// "*<count>\r\n" followed by "$<byte length>\r\n<arg>\r\n" for each argument.
func encodeRESPArray(args ...string) []byte {
	const crlf = "\r\n"

	out := new(bytes.Buffer)
	out.WriteString("*" + strconv.Itoa(len(args)) + crlf)
	for i := range args {
		out.WriteString("$" + strconv.Itoa(len(args[i])) + crlf)
		out.WriteString(args[i] + crlf)
	}
	return out.Bytes()
}
// readRESPBulk reads one RESP reply from reader and returns the payload of a
// bulk string ("$<len>\r\n<payload>\r\n").
//
// A RESP error reply ("-message\r\n") is returned as an error carrying the
// message. A nil bulk ("$-1"), any other reply type, a bulk longer than
// maxBulkSize, or a payload that is not terminated by CRLF is rejected.
//
// The declared length comes from the remote peer, so it is bounded before any
// buffer is allocated; otherwise a hostile or broken peer could force an
// arbitrarily large allocation.
func readRESPBulk(reader *bufio.Reader) ([]byte, error) {
	// maxBulkSize caps the accepted payload. The reply is a small JSON
	// document holding two PEM certificates, so 1 MiB is generous.
	const maxBulkSize = 1 << 20

	prefix, errRead := reader.ReadByte()
	if errRead != nil {
		return nil, errRead
	}

	switch prefix {
	case '$':
		line, errLine := reader.ReadString('\n')
		if errLine != nil {
			return nil, errLine
		}
		size, errSize := strconv.Atoi(strings.TrimSpace(line))
		if errSize != nil {
			return nil, fmt.Errorf("home certificate request returned invalid bulk length: %w", errSize)
		}
		if size < 0 {
			return nil, fmt.Errorf("home certificate request returned nil")
		}
		if size > maxBulkSize {
			return nil, fmt.Errorf("home certificate request returned oversized bulk of %d bytes", size)
		}

		// Read the payload together with its trailing CRLF.
		payload := make([]byte, size+2)
		if _, errFull := io.ReadFull(reader, payload); errFull != nil {
			return nil, errFull
		}
		if payload[size] != '\r' || payload[size+1] != '\n' {
			return nil, fmt.Errorf("home certificate request returned malformed bulk terminator")
		}
		return payload[:size], nil
	case '-':
		line, errLine := reader.ReadString('\n')
		if errLine != nil {
			return nil, errLine
		}
		return nil, fmt.Errorf("%s", strings.TrimSpace(line))
	default:
		return nil, fmt.Errorf("home certificate request returned unsupported resp prefix %q", prefix)
	}
}
// fileExists reports whether path can be stat'ed and is not a directory. Any
// stat error (including permission problems) is treated as "does not exist".
func fileExists(path string) bool {
	info, errStat := os.Stat(path)
	if errStat != nil {
		return false
	}
	return !info.IsDir()
}
+99 -17
View File
@@ -31,6 +31,8 @@ const (
homeReconnectInterval = time.Second
homeReconnectFailoverThreshold = 3
homeRedisOperationTimeout = 3 * time.Second
homeSubscriptionReceiveTimeout = 3 * time.Second
redisChannelCluster = "cluster"
)
@@ -172,21 +174,30 @@ func (c *Client) ensureClients() error {
}
func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) {
tlsConfig, errTLS := c.homeTLSConfigLocked()
tlsConfig, errTLS := c.homeTLSConfigLocked(addr)
if errTLS != nil {
return nil, errTLS
}
return &redis.Options{
Addr: addr,
Password: c.homeCfg.Password,
TLSConfig: tlsConfig,
Addr: addr,
TLSConfig: tlsConfig,
DialTimeout: homeRedisOperationTimeout,
ReadTimeout: homeRedisOperationTimeout,
WriteTimeout: homeRedisOperationTimeout,
MaxRetries: -1,
DialerRetries: 1,
ContextTimeoutEnabled: true,
}, nil
}
func (c *Client) homeTLSConfigLocked() (*tls.Config, error) {
func (c *Client) homeTLSConfigLocked(addr string) (*tls.Config, error) {
serverName := strings.TrimSpace(c.homeCfg.TLS.ServerName)
if serverName == "" {
serverName = strings.TrimSpace(c.seedHost)
if c.homeCfg.TLS.UseTargetServerName {
serverName = hostFromAddress(addr)
} else {
serverName = strings.TrimSpace(c.seedHost)
}
}
if serverName == "" {
serverName = strings.TrimSpace(c.homeCfg.Host)
@@ -194,6 +205,14 @@ func (c *Client) homeTLSConfigLocked() (*tls.Config, error) {
return newHomeTLSConfig(c.homeCfg.TLS, serverName)
}
// hostFromAddress extracts the host portion of a "host:port" address for use
// as a TLS server name.
//
// When addr has no port (or cannot be split), the trimmed input is returned,
// except that a bracketed IPv6 literal such as "[::1]" is returned without its
// brackets: brackets are URL/address syntax and are not part of the host.
func hostFromAddress(addr string) string {
	trimmed := strings.TrimSpace(addr)

	if host, _, errSplit := net.SplitHostPort(trimmed); errSplit == nil {
		return strings.TrimSpace(host)
	}

	// No port present. Unwrap "[v6-literal]" so callers get a bare host.
	if len(trimmed) > 2 && strings.HasPrefix(trimmed, "[") && strings.HasSuffix(trimmed, "]") {
		return strings.TrimSpace(trimmed[1 : len(trimmed)-1])
	}
	return trimmed
}
func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls.Config, error) {
if !cfg.Enable {
return nil, nil
@@ -210,6 +229,19 @@ func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls
InsecureSkipVerify: cfg.InsecureSkipVerify,
}
clientCertPath := strings.TrimSpace(cfg.ClientCert)
clientKeyPath := strings.TrimSpace(cfg.ClientKey)
if clientCertPath != "" || clientKeyPath != "" {
if clientCertPath == "" || clientKeyPath == "" {
return nil, fmt.Errorf("home tls: client certificate and key must be set together")
}
certPair, errLoad := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
if errLoad != nil {
return nil, fmt.Errorf("home tls: load client certificate: %w", errLoad)
}
tlsConfig.Certificates = []tls.Certificate{certPair}
}
caCertPath := strings.TrimSpace(cfg.CACert)
if caCertPath == "" {
return tlsConfig, nil
@@ -404,6 +436,25 @@ func (c *Client) failoverAfterReconnectFailure() (bool, string) {
}
c.reconnectFailures = 0
return c.switchToNextNodeLocked()
}
// failoverAfterSubscriptionTimeout reacts to a subscription receive timeout.
// The reconnect failure counter is always reset. When cluster discovery is
// enabled the client switches to the next node and the result of that switch
// is returned; otherwise (or for a nil client) it returns (false, "").
func (c *Client) failoverAfterSubscriptionTimeout() (bool, string) {
	if c == nil {
		return false, ""
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	discoveryEnabled := c.clusterDiscoveryEnabledLocked()
	c.reconnectFailures = 0
	if discoveryEnabled {
		return c.switchToNextNodeLocked()
	}
	return false, ""
}
func (c *Client) switchToNextNodeLocked() (bool, string) {
currentHost := strings.TrimSpace(c.homeCfg.Host)
currentPort := c.homeCfg.Port
candidates := append([]clusterNode(nil), c.clusterNodes...)
@@ -426,6 +477,13 @@ func (c *Client) failoverAfterReconnectFailure() (bool, string) {
return false, ""
}
// markSubscriptionTimeout handles a subscription heartbeat timeout by asking
// for a failover and logging a warning when a node switch actually happened.
func (c *Client) markSubscriptionTimeout() {
	if switched, nextAddr := c.failoverAfterSubscriptionTimeout(); switched {
		log.Warnf("home subscription heartbeat timeout; switching to %s", nextAddr)
	}
}
func (c *Client) resetReconnectFailures() {
if c == nil {
return
@@ -683,7 +741,7 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte
}
// Ensure the subscription is established before marking heartbeat OK.
if _, errReceive := pubsub.Receive(ctx); errReceive != nil {
if _, errReceive := pubsub.ReceiveTimeout(ctx, homeSubscriptionReceiveTimeout); errReceive != nil {
_ = pubsub.Close()
c.markReconnectFailure("subscribe")
sleepWithContext(ctx, homeReconnectInterval)
@@ -694,28 +752,52 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte
c.heartbeatOK.Store(true)
for {
msg, errMsg := pubsub.ReceiveMessage(ctx)
event, errMsg := pubsub.ReceiveTimeout(ctx, homeSubscriptionReceiveTimeout)
if errMsg != nil {
_ = pubsub.Close()
c.heartbeatOK.Store(false)
c.markReconnectFailure("subscription")
if isTimeoutError(errMsg) {
c.markSubscriptionTimeout()
} else {
c.markReconnectFailure("subscription")
}
sleepWithContext(ctx, homeReconnectInterval)
break
}
if msg == nil {
continue
}
if errApply := c.handleSubscriptionPayload(msg.Channel, msg.Payload, onConfig); errApply != nil {
if strings.EqualFold(strings.TrimSpace(msg.Channel), redisChannelCluster) {
log.Warn("failed to apply cluster update from home control center, ignoring")
} else {
log.Warn("failed to apply config update from home control center, ignoring")
switch msg := event.(type) {
case *redis.Message:
if msg == nil {
continue
}
if errApply := c.handleSubscriptionPayload(msg.Channel, msg.Payload, onConfig); errApply != nil {
if strings.EqualFold(strings.TrimSpace(msg.Channel), redisChannelCluster) {
log.Warn("failed to apply cluster update from home control center, ignoring")
} else {
log.Warn("failed to apply config update from home control center, ignoring")
}
}
case *redis.Pong:
c.resetReconnectFailures()
case *redis.Subscription:
continue
default:
log.Debugf("home subscription returned unsupported message type %T", event)
}
}
}
}
// isTimeoutError reports whether err represents a timeout: either it matches
// context.DeadlineExceeded, or its chain contains a net.Error whose Timeout
// method returns true. A nil error is never a timeout.
func isTimeoutError(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, context.DeadlineExceeded):
		return true
	}

	var timeoutErr net.Error
	if !errors.As(err, &timeoutErr) {
		return false
	}
	return timeoutErr.Timeout()
}
func sleepWithContext(ctx context.Context, d time.Duration) {
if d <= 0 {
return
+5 -6
View File
@@ -37,10 +37,9 @@ func TestAuthDispatchRequestDefaultsCountToOne(t *testing.T) {
func TestRedisOptionsHomeTLSDisabled(t *testing.T) {
client := New(config.HomeConfig{
Enabled: true,
Host: "127.0.0.1",
Port: 6379,
Password: "secret",
Enabled: true,
Host: "127.0.0.1",
Port: 6379,
})
client.mu.Lock()
@@ -53,8 +52,8 @@ func TestRedisOptionsHomeTLSDisabled(t *testing.T) {
if options.TLSConfig != nil {
t.Fatalf("TLSConfig = %#v, want nil", options.TLSConfig)
}
if options.Password != "secret" {
t.Fatalf("Password = %q, want secret", options.Password)
if options.Password != "" {
t.Fatalf("Password = %q, want empty", options.Password)
}
}
+55
View File
@@ -2,16 +2,24 @@ package logging
import (
"context"
"net/http"
"sync"
"sync/atomic"
)
// endpointKey is an unexported, zero-size context key type.
// NOTE(review): presumably the key under which WithEndpoint stores the
// endpoint string — confirm against WithEndpoint's body.
type endpointKey struct{}
// responseStatusKey is the context key under which WithResponseStatusHolder
// stores a *responseStatusHolder.
type responseStatusKey struct{}
// responseHeadersKey is the context key under which WithResponseHeadersHolder
// stores a *responseHeadersHolder.
type responseHeadersKey struct{}
// responseStatusHolder is a mutable slot carried in a context; the status is
// written by SetResponseStatus and read by GetResponseStatus via atomic
// Store/Load.
type responseStatusHolder struct {
status atomic.Int32
}
// responseHeadersHolder is a mutable slot carried in a context that keeps a
// copy of response headers. mu guards headers: SetResponseHeaders takes the
// write lock and GetResponseHeaders takes the read lock.
type responseHeadersHolder struct {
mu sync.RWMutex
headers http.Header
}
func WithEndpoint(ctx context.Context, endpoint string) context.Context {
if ctx == nil {
ctx = context.Background()
@@ -39,6 +47,16 @@ func WithResponseStatusHolder(ctx context.Context) context.Context {
return context.WithValue(ctx, responseStatusKey{}, &responseStatusHolder{})
}
// WithResponseHeadersHolder returns a context that carries a response headers
// holder. A nil ctx is treated as context.Background(). If the context already
// carries a non-nil holder, it is returned unchanged; otherwise a fresh holder
// is attached.
func WithResponseHeadersHolder(ctx context.Context) context.Context {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	// A failed assertion yields a nil pointer, so one nil check covers both
	// "absent" and "present but nil".
	existing, _ := parent.Value(responseHeadersKey{}).(*responseHeadersHolder)
	if existing != nil {
		return parent
	}
	return context.WithValue(parent, responseHeadersKey{}, &responseHeadersHolder{})
}
func SetResponseStatus(ctx context.Context, status int) {
if ctx == nil || status <= 0 {
return
@@ -50,6 +68,19 @@ func SetResponseStatus(ctx context.Context, status int) {
holder.status.Store(int32(status))
}
// SetResponseHeaders stores a copy of headers in the holder carried by ctx.
// It is a no-op when ctx is nil or carries no holder. The headers are cloned,
// so later changes to the caller's map do not affect the stored value.
func SetResponseHeaders(ctx context.Context, headers http.Header) {
	if ctx == nil {
		return
	}
	target, _ := ctx.Value(responseHeadersKey{}).(*responseHeadersHolder)
	if target == nil {
		return
	}

	target.mu.Lock()
	defer target.mu.Unlock()
	target.headers = cloneHTTPHeader(headers)
}
func GetResponseStatus(ctx context.Context) int {
if ctx == nil {
return 0
@@ -60,3 +91,27 @@ func GetResponseStatus(ctx context.Context) int {
}
return int(holder.status.Load())
}
// GetResponseHeaders returns a copy of the headers stored in the holder
// carried by ctx. It returns nil when ctx is nil or carries no holder
// (cloneHTTPHeader also yields nil for empty headers).
func GetResponseHeaders(ctx context.Context) http.Header {
	if ctx == nil {
		return nil
	}
	source, _ := ctx.Value(responseHeadersKey{}).(*responseHeadersHolder)
	if source == nil {
		return nil
	}

	source.mu.RLock()
	defer source.mu.RUnlock()
	return cloneHTTPHeader(source.headers)
}
// cloneHTTPHeader returns a deep copy of src: each key gets its own value
// slice. A nil or empty src yields nil, and an empty value slice is copied
// as nil.
func cloneHTTPHeader(src http.Header) http.Header {
	if len(src) == 0 {
		return nil
	}

	cloned := make(http.Header, len(src))
	for name, vals := range src {
		// Appending to a nil slice keeps "no values" as nil rather than an
		// allocated empty slice.
		var copied []string
		copied = append(copied, vals...)
		cloned[name] = copied
	}
	return cloned
}
+38 -29
View File
@@ -3,6 +3,7 @@ package redisqueue
import (
"context"
"encoding/json"
"net/http"
"strings"
"time"
@@ -47,6 +48,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
}
apiKey := strings.TrimSpace(record.APIKey)
requestID := strings.TrimSpace(internallogging.GetRequestID(ctx))
reasoningEffort := strings.TrimSpace(record.ReasoningEffort)
if reasoningEffort == "" {
reasoningEffort = coreusage.ReasoningEffortFromContext(ctx)
}
tokens := tokenStats{
InputTokens: record.Detail.InputTokens,
@@ -71,24 +76,26 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
fail := resolveFail(ctx, record, failed)
detail := requestDetail{
Timestamp: timestamp,
LatencyMs: record.Latency.Milliseconds(),
Source: record.Source,
AuthIndex: record.AuthIndex,
Tokens: tokens,
Failed: failed,
Fail: fail,
Timestamp: timestamp,
LatencyMs: record.Latency.Milliseconds(),
Source: record.Source,
AuthIndex: record.AuthIndex,
Tokens: tokens,
Failed: failed,
Fail: fail,
ResponseHeaders: record.ResponseHeaders,
}
payload, err := json.Marshal(queuedUsageDetail{
requestDetail: detail,
Provider: provider,
Model: modelName,
Alias: aliasName,
Endpoint: resolveEndpoint(ctx),
AuthType: authType,
APIKey: apiKey,
RequestID: requestID,
requestDetail: detail,
Provider: provider,
Model: modelName,
Alias: aliasName,
Endpoint: resolveEndpoint(ctx),
AuthType: authType,
APIKey: apiKey,
RequestID: requestID,
ReasoningEffort: reasoningEffort,
})
if err != nil {
return
@@ -98,23 +105,25 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
type queuedUsageDetail struct {
requestDetail
Provider string `json:"provider"`
Model string `json:"model"`
Alias string `json:"alias"`
Endpoint string `json:"endpoint"`
AuthType string `json:"auth_type"`
APIKey string `json:"api_key"`
RequestID string `json:"request_id"`
Provider string `json:"provider"`
Model string `json:"model"`
Alias string `json:"alias"`
Endpoint string `json:"endpoint"`
AuthType string `json:"auth_type"`
APIKey string `json:"api_key"`
RequestID string `json:"request_id"`
ReasoningEffort string `json:"reasoning_effort"`
}
type requestDetail struct {
Timestamp time.Time `json:"timestamp"`
LatencyMs int64 `json:"latency_ms"`
Source string `json:"source"`
AuthIndex string `json:"auth_index"`
Tokens tokenStats `json:"tokens"`
Failed bool `json:"failed"`
Fail failDetail `json:"fail"`
Timestamp time.Time `json:"timestamp"`
LatencyMs int64 `json:"latency_ms"`
Source string `json:"source"`
AuthIndex string `json:"auth_index"`
Tokens tokenStats `json:"tokens"`
Failed bool `json:"failed"`
Fail failDetail `json:"fail"`
ResponseHeaders http.Header `json:"response_headers,omitempty"`
}
type tokenStats struct {
+88 -10
View File
@@ -19,9 +19,69 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
ctx = internallogging.WithResponseStatusHolder(ctx)
internallogging.SetResponseStatus(ctx, http.StatusOK)
responseHeaders := http.Header{}
responseHeaders.Add("X-Upstream-Request-Id", "upstream-req-1")
responseHeaders.Add("Retry-After", "30")
plugin := &usageQueuePlugin{}
plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4",
Alias: "client-gpt",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
Source: "user@example.com",
ReasoningEffort: "medium",
RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC),
Latency: 1500 * time.Millisecond,
Detail: coreusage.Detail{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
ResponseHeaders: responseHeaders.Clone(),
})
responseHeaders.Set("Retry-After", "999")
payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4")
requireStringField(t, payload, "alias", "client-gpt")
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
requireStringField(t, payload, "auth_type", "apikey")
requireMissingField(t, payload, "user_api_key")
requireStringField(t, payload, "request_id", "ctx-request-id")
requireStringField(t, payload, "reasoning_effort", "medium")
requireHeaderField(t, payload, "response_headers", "X-Upstream-Request-Id", []string{"upstream-req-1"})
requireHeaderField(t, payload, "response_headers", "Retry-After", []string{"30"})
requireBoolField(t, payload, "failed", false)
requireFailField(t, payload, http.StatusOK, "")
})
}
func TestUsageQueuePluginAsyncUsesRecordResponseHeaders(t *testing.T) {
withEnabledQueue(t, func() {
ctx := internallogging.WithRequestID(context.Background(), "ctx-request-id")
ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
ctx = internallogging.WithResponseStatusHolder(ctx)
ctx = internallogging.WithResponseHeadersHolder(ctx)
internallogging.SetResponseStatus(ctx, http.StatusOK)
initialHeaders := http.Header{}
initialHeaders.Set("X-Upstream-Request-Id", "upstream-req-1")
internallogging.SetResponseHeaders(ctx, initialHeaders)
mgr := coreusage.NewManager(16)
defer mgr.Stop()
mgr.Register(pluginFunc(func(ctx context.Context, _ coreusage.Record) {
nextHeaders := http.Header{}
nextHeaders.Set("X-Upstream-Request-Id", "upstream-req-2")
internallogging.SetResponseHeaders(ctx, nextHeaders)
}))
mgr.Register(&usageQueuePlugin{})
mgr.Publish(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4",
Alias: "client-gpt",
@@ -36,18 +96,11 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
OutputTokens: 20,
TotalTokens: 30,
},
ResponseHeaders: internallogging.GetResponseHeaders(ctx),
})
payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4")
requireStringField(t, payload, "alias", "client-gpt")
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
requireStringField(t, payload, "auth_type", "apikey")
requireMissingField(t, payload, "user_api_key")
requireStringField(t, payload, "request_id", "ctx-request-id")
requireBoolField(t, payload, "failed", false)
requireFailField(t, payload, http.StatusOK, "")
payload := waitForSinglePayload(t, 2*time.Second)
requireHeaderField(t, payload, "response_headers", "X-Upstream-Request-Id", []string{"upstream-req-1"})
})
}
@@ -276,3 +329,28 @@ func requireFailField(t *testing.T, payload map[string]json.RawMessage, wantStat
t.Fatalf("fail = {status_code:%d body:%q}, want {status_code:%d body:%q}", got.StatusCode, got.Body, wantStatus, wantBody)
}
}
// requireHeaderField fails the test unless payload[field] is a JSON object
// mapping header names to string lists and the list under key equals want,
// element by element.
func requireHeaderField(t *testing.T, payload map[string]json.RawMessage, field, key string, want []string) {
	t.Helper()

	encoded, present := payload[field]
	if !present {
		t.Fatalf("payload missing %q", field)
	}

	var decoded map[string][]string
	if err := json.Unmarshal(encoded, &decoded); err != nil {
		t.Fatalf("unmarshal %q: %v", field, err)
	}

	values, found := decoded[key]
	if !found {
		t.Fatalf("%s missing header %q", field, key)
	}

	// Length and per-element mismatches share one failure message.
	matches := len(values) == len(want)
	for i := 0; matches && i < len(want); i++ {
		matches = values[i] == want[i]
	}
	if !matches {
		t.Fatalf("%s[%q] = %v, want %v", field, key, values, want)
	}
}
-146
View File
@@ -1,146 +0,0 @@
package registry
import "testing"
// TestCodexFreeModelsExcludeGPT55 verifies that the codex free tier model
// list does not contain gpt-5.5.
func TestCodexFreeModelsExcludeGPT55(t *testing.T) {
	if found := findModelInfo(GetCodexFreeModels(), "gpt-5.5"); found != nil {
		t.Fatal("expected codex free tier to NOT include gpt-5.5")
	}
}
// TestCodexStaticModelsIncludeGPT55 verifies that the team, plus and pro codex
// tiers, as well as the static lookup, expose gpt-5.5 with the expected
// metadata.
func TestCodexStaticModelsIncludeGPT55(t *testing.T) {
	modelsByTier := map[string][]*ModelInfo{
		"team": GetCodexTeamModels(),
		"plus": GetCodexPlusModels(),
		"pro":  GetCodexProModels(),
	}

	for tierName, tierList := range modelsByTier {
		t.Run(tierName, func(t *testing.T) {
			found := findModelInfo(tierList, "gpt-5.5")
			if found == nil {
				t.Fatalf("expected codex %s tier to include gpt-5.5", tierName)
			}
			assertGPT55ModelInfo(t, tierName, found)
		})
	}

	looked := LookupStaticModelInfo("gpt-5.5")
	if looked == nil {
		t.Fatal("expected LookupStaticModelInfo to find gpt-5.5")
	}
	assertGPT55ModelInfo(t, "lookup", looked)
}
// TestWithXAIBuiltinsAddsVideoModel verifies that WithXAIBuiltins(nil) yields
// the builtin video model and that it is owned by "xai".
func TestWithXAIBuiltinsAddsVideoModel(t *testing.T) {
	seen := false
	for _, candidate := range WithXAIBuiltins(nil) {
		if candidate == nil || candidate.ID != xaiBuiltinVideoModelID {
			continue
		}
		seen = true
		if candidate.OwnedBy != "xai" {
			t.Fatalf("OwnedBy = %q, want xai", candidate.OwnedBy)
		}
	}
	if !seen {
		t.Fatalf("expected %s builtin model", xaiBuiltinVideoModelID)
	}
}
// TestValidateModelsCatalogAllowsMissingSections verifies that a catalog with
// its XAI section set to nil still passes validation.
func TestValidateModelsCatalogAllowsMissingSections(t *testing.T) {
	catalog := validTestModelsCatalog()
	catalog.XAI = nil

	err := validateModelsCatalog(catalog)
	if err != nil {
		t.Fatalf("validateModelsCatalog() error = %v", err)
	}
}
// TestValidateModelsCatalogRejectsInvalidDefinitions verifies that a catalog
// containing a model with an empty ID fails validation.
func TestValidateModelsCatalogRejectsInvalidDefinitions(t *testing.T) {
	catalog := validTestModelsCatalog()
	catalog.Claude = []*ModelInfo{{ID: ""}}

	err := validateModelsCatalog(catalog)
	if err == nil {
		t.Fatal("expected invalid model definition error")
	}
}
// validTestModelsCatalog builds a catalog in which every listed section points
// at the same single-entry model slice ("test-model").
func validTestModelsCatalog() *staticModelsJSON {
	shared := []*ModelInfo{{ID: "test-model"}}

	catalog := &staticModelsJSON{}
	catalog.Claude = shared
	catalog.Gemini = shared
	catalog.Vertex = shared
	catalog.GeminiCLI = shared
	catalog.AIStudio = shared
	catalog.CodexFree = shared
	catalog.CodexTeam = shared
	catalog.CodexPlus = shared
	catalog.CodexPro = shared
	catalog.Kimi = shared
	catalog.Antigravity = shared
	catalog.XAI = shared
	return catalog
}
// findModelInfo returns the first non-nil entry in models whose ID equals id,
// or nil when there is no such entry.
func findModelInfo(models []*ModelInfo, id string) *ModelInfo {
	for i := range models {
		if candidate := models[i]; candidate != nil && candidate.ID == id {
			return candidate
		}
	}
	return nil
}
// assertGPT55ModelInfo fails the test at the first field of model that does
// not match the expected gpt-5.5 definition. source is only used to label
// failure messages (e.g. the tier name or "lookup"). model must be non-nil:
// its fields are dereferenced without a nil check.
func assertGPT55ModelInfo(t *testing.T, source string, model *ModelInfo) {
t.Helper()
// Identity and ownership fields.
if model.ID != "gpt-5.5" {
t.Fatalf("%s id mismatch: got %q", source, model.ID)
}
if model.Object != "model" {
t.Fatalf("%s object mismatch: got %q", source, model.Object)
}
if model.Created != 1776902400 {
t.Fatalf("%s created timestamp mismatch: got %d", source, model.Created)
}
if model.OwnedBy != "openai" {
t.Fatalf("%s owned_by mismatch: got %q", source, model.OwnedBy)
}
if model.Type != "openai" {
t.Fatalf("%s type mismatch: got %q", source, model.Type)
}
// Human-readable metadata.
if model.DisplayName != "GPT 5.5" {
t.Fatalf("%s display name mismatch: got %q", source, model.DisplayName)
}
if model.Version != "gpt-5.5" {
t.Fatalf("%s version mismatch: got %q", source, model.Version)
}
if model.Description != "Frontier model for complex coding, research, and real-world work." {
t.Fatalf("%s description mismatch: got %q", source, model.Description)
}
// Token limits.
if model.ContextLength != 272000 {
t.Fatalf("%s context length mismatch: got %d", source, model.ContextLength)
}
if model.MaxCompletionTokens != 128000 {
t.Fatalf("%s max completion tokens mismatch: got %d", source, model.MaxCompletionTokens)
}
// Exactly one supported parameter is expected: "tools".
if len(model.SupportedParameters) != 1 || model.SupportedParameters[0] != "tools" {
t.Fatalf("%s supported parameters mismatch: got %v", source, model.SupportedParameters)
}
if model.Thinking == nil {
t.Fatalf("%s missing thinking support", source)
}
// Thinking levels must match in count and order; the length check below
// keeps the indexed comparison in the loop within bounds.
want := []string{"low", "medium", "high", "xhigh"}
if len(model.Thinking.Levels) != len(want) {
t.Fatalf("%s thinking level count mismatch: got %d, want %d", source, len(model.Thinking.Levels), len(want))
}
for i, level := range want {
if model.Thinking.Levels[i] != level {
t.Fatalf("%s thinking level %d mismatch: got %q, want %q", source, i, model.Thinking.Levels[i], level)
}
}
}
+3
View File
@@ -15,6 +15,9 @@ import (
log "github.com/sirupsen/logrus"
)
// OpenAIImageModelType marks models that are callable through OpenAI-compatible image endpoints.
const OpenAIImageModelType = "openai-image"
// ModelInfo represents information about an available model
type ModelInfo struct {
// ID is the unique identifier for the model
+155
View File
@@ -421,6 +421,36 @@
"high"
]
}
},
{
"id": "gemini-3.5-flash",
"object": "model",
"created": 1779235200,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.5 Flash",
"name": "models/gemini-3.5-flash",
"version": "3.5",
"description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
}
],
"vertex": [
@@ -762,6 +792,36 @@
"supportedGenerationMethods": [
"predict"
]
},
{
"id": "gemini-3.5-flash",
"object": "model",
"created": 1779235200,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.5 Flash",
"name": "models/gemini-3.5-flash",
"version": "3.5",
"description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
}
],
"gemini-cli": [
@@ -1221,6 +1281,36 @@
"createCachedContent",
"batchGenerateContent"
]
},
{
"id": "gemini-3.5-flash",
"object": "model",
"created": 1779235200,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.5 Flash",
"name": "models/gemini-3.5-flash",
"version": "3.5",
"description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
}
],
"codex-free": [
@@ -1954,6 +2044,28 @@
]
}
},
{
"id": "gemini-3-flash-agent",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3.5 Flash",
"name": "gemini-3-flash-agent",
"description": "Gemini 3.5 Flash",
"context_length": 1048576,
"max_completion_tokens": 65536,
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
},
{
"id": "gemini-3-pro-high",
"object": "model",
@@ -2087,9 +2199,52 @@
"high"
]
}
},
{
"id": "gemini-3.5-flash-low",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3.5 Flash (Low)",
"name": "gemini-3.5-flash-low",
"description": "Gemini 3.5 Flash (Low)",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": {
"min": 1,
"max": 65535,
"dynamic_allowed": true,
"levels": [
"low",
"medium",
"high"
]
}
}
],
"xai": [
{
"id": "grok-build-0.1",
"object": "model",
"created": 1779321600,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok Build 0.1",
"name": "grok-build-0.1",
"description": "Grok Build 0.1 is xAI's fast coding model trained specifically for agentic software engineering workflows.",
"context_length": 256000,
"max_completion_tokens": 256000,
"thinking": {
"zero_allowed": true,
"levels": [
"none",
"low",
"medium",
"high"
]
}
},
{
"id": "grok-4.3",
"object": "model",
@@ -1415,6 +1415,41 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au
return updated, nil
}
// ShouldPrepareRequestAuth reports whether auth has no usable project_id in
// its metadata, as determined by antigravityProjectIDFromAuth.
func (e *AntigravityExecutor) ShouldPrepareRequestAuth(auth *cliproxyauth.Auth) bool {
	projectID := antigravityProjectIDFromAuth(auth)
	return len(projectID) == 0
}
// PrepareRequestAuth returns a copy of auth whose metadata carries a
// project_id, fetching the ID when it is missing. The input auth is not
// modified; all changes are made on a clone (or on the auth returned by
// ensureAccessToken).
//
// It returns (nil, nil) when auth is nil or already has a project_id. It
// returns the ensureAccessToken error unchanged, and a
// missingAntigravityProjectIDError when the ID cannot be fetched or comes
// back empty.
func (e *AntigravityExecutor) PrepareRequestAuth(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	if auth == nil || !e.ShouldPrepareRequestAuth(auth) {
		return nil, nil
	}

	working := auth.Clone()
	accessToken, refreshed, errToken := e.ensureAccessToken(ctx, working)
	if errToken != nil {
		return nil, errToken
	}
	if refreshed != nil {
		working = refreshed
	}

	// The auth returned by ensureAccessToken may already carry a project_id.
	if antigravityProjectIDFromAuth(working) != "" {
		return working, nil
	}

	fetchedID, errFetch := e.fetchAntigravityProjectID(ctx, working, accessToken)
	switch {
	case errFetch != nil:
		return nil, missingAntigravityProjectIDError(errFetch)
	case fetchedID == "":
		return nil, missingAntigravityProjectIDError(nil)
	}

	if working.Metadata == nil {
		working.Metadata = make(map[string]any)
	}
	working.Metadata["project_id"] = fetchedID
	return working, nil
}
// CountTokens counts tokens for the given request using the Antigravity API.
func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
@@ -1752,34 +1787,67 @@ func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, au
return nil
}
if auth.Metadata["project_id"] != nil {
if antigravityProjectIDFromAuth(auth) != "" {
return nil
}
token := strings.TrimSpace(accessToken)
if token == "" {
token = metaStringValue(auth.Metadata, "access_token")
}
if token == "" {
return nil
}
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
projectID, errFetch := e.fetchAntigravityProjectID(ctx, auth, accessToken)
if errFetch != nil {
return errFetch
}
if strings.TrimSpace(projectID) == "" {
if projectID == "" {
return nil
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
auth.Metadata["project_id"] = strings.TrimSpace(projectID)
auth.Metadata["project_id"] = projectID
return nil
}
// fetchAntigravityProjectID looks up the project ID via
// sdkAuth.FetchAntigravityProjectID. It uses accessToken, falling back to the
// "access_token" metadata entry when accessToken is blank. With no token at
// all it returns ("", nil) without making a request. The returned ID is
// whitespace-trimmed.
func (e *AntigravityExecutor) fetchAntigravityProjectID(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) (string, error) {
	bearer := strings.TrimSpace(accessToken)
	if bearer == "" {
		if bearer = metaStringValue(auth.Metadata, "access_token"); bearer == "" {
			return "", nil
		}
	}

	client := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
	rawID, err := sdkAuth.FetchAntigravityProjectID(ctx, bearer, client)
	if err != nil {
		return "", err
	}
	return strings.TrimSpace(rawID), nil
}
// projectIDForRequest returns the project_id stored in auth metadata, or a
// missingAntigravityProjectIDError when none is present. The context and
// token parameters are unused.
func (e *AntigravityExecutor) projectIDForRequest(_ context.Context, auth *cliproxyauth.Auth, _ string) (string, error) {
	projectID := antigravityProjectIDFromAuth(auth)
	if projectID == "" {
		return "", missingAntigravityProjectIDError(nil)
	}
	return projectID, nil
}
// antigravityProjectIDFromAuth returns the whitespace-trimmed "project_id"
// metadata entry of auth. It returns "" when auth or its metadata is nil, or
// when the entry is absent or not a string.
func antigravityProjectIDFromAuth(auth *cliproxyauth.Auth) string {
	if auth == nil || auth.Metadata == nil {
		return ""
	}
	rawID, isString := auth.Metadata["project_id"].(string)
	if !isString {
		return ""
	}
	return strings.TrimSpace(rawID)
}
// missingAntigravityProjectIDError builds a 400 Bad Request statusErr saying
// that the auth has no project_id. A non-nil cause is appended to the message.
func missingAntigravityProjectIDError(cause error) statusErr {
	const base = "antigravity auth missing project_id"
	if cause == nil {
		return statusErr{code: http.StatusBadRequest, msg: base}
	}
	return statusErr{code: http.StatusBadRequest, msg: fmt.Sprintf("%s: %v", base, cause)}
}
func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) {
if auth == nil || strings.TrimSpace(auth.ID) == "" {
return
@@ -1792,19 +1860,17 @@ func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Contex
return
}
userAgent := resolveLoadCodeAssistUserAgent(auth)
userAgent := resolveUserAgent(auth)
loadReqBody, errMarshal := json.Marshal(map[string]any{
"metadata": map[string]string{
"ide_type": "ANTIGRAVITY",
"ide_version": misc.AntigravityVersionFromUserAgent(userAgent),
"ide_name": "antigravity",
"ideType": "ANTIGRAVITY",
},
})
if errMarshal != nil {
log.Debugf("antigravity executor: marshal loadCodeAssist request error: %v", errMarshal)
return
}
baseURL := buildBaseURL(auth)
baseURL := antigravityLoadCodeAssistBaseURL(auth)
endpointURL := strings.TrimSuffix(baseURL, "/") + "/v1internal:loadCodeAssist"
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(loadReqBody))
if errReq != nil {
@@ -1812,9 +1878,9 @@ func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Contex
return
}
httpReq.Header.Set("Authorization", "Bearer "+token)
httpReq.Header.Set("Accept", "*/*")
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("User-Agent", userAgent)
httpReq.Header.Set("X-Goog-Api-Client", misc.AntigravityGoogAPIClientUA)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
@@ -1909,12 +1975,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
requestURL.WriteString(url.QueryEscape(alt))
}
// Extract project_id from auth metadata if available
projectID := ""
if auth != nil && auth.Metadata != nil {
if pid, ok := auth.Metadata["project_id"].(string); ok {
projectID = strings.TrimSpace(pid)
}
projectID, errProject := e.projectIDForRequest(ctx, auth, token)
if errProject != nil {
return nil, errProject
}
payload = geminiToAntigravity(modelName, payload, projectID)
payload, _ = sjson.SetBytes(payload, "model", modelName)
@@ -2100,6 +2163,13 @@ func buildBaseURL(auth *cliproxyauth.Auth) string {
return antigravityBaseURLDaily
}
// antigravityLoadCodeAssistBaseURL returns the custom base URL resolved for
// auth when there is one, and antigravityBaseURLProd otherwise.
func antigravityLoadCodeAssistBaseURL(auth *cliproxyauth.Auth) string {
	custom := resolveCustomAntigravityBaseURL(auth)
	if custom == "" {
		return antigravityBaseURLProd
	}
	return custom
}
func resolveHost(base string) string {
parsed, errParse := url.Parse(base)
if errParse != nil {
@@ -2338,11 +2408,10 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
}
template, _ = sjson.SetBytes(template, "requestType", reqType)
// Use real project ID from auth if available, otherwise generate random (legacy fallback)
if projectID != "" {
template, _ = sjson.SetBytes(template, "project", projectID)
} else {
template, _ = sjson.SetBytes(template, "project", generateProjectID())
template, _ = sjson.DeleteBytes(template, "project")
}
if isImageModel {
@@ -2391,14 +2460,3 @@ func generateStableSessionID(payload []byte) string {
}
return generateSessionID()
}
// generateProjectID builds a random identifier of the form
// "<adjective>-<noun>-<5 chars>", where the words are drawn from fixed lists
// using the shared randSource (under randSourceMutex) and the suffix is the
// first five characters of a lowercased UUID string.
func generateProjectID() string {
	adjectives := [...]string{"useful", "bright", "swift", "calm", "bold"}
	nouns := [...]string{"fuze", "wave", "spark", "flow", "core"}

	// The adjective is drawn before the noun.
	randSourceMutex.Lock()
	adjective := adjectives[randSource.Intn(len(adjectives))]
	noun := nouns[randSource.Intn(len(nouns))]
	randSourceMutex.Unlock()

	suffix := strings.ToLower(uuid.NewString())[:5]
	return strings.Join([]string{adjective, noun, suffix}, "-")
}
@@ -4,7 +4,10 @@ import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"testing"
"time"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
)
@@ -90,6 +93,82 @@ func TestAntigravityBuildRequest_SkipsSchemaSanitizationWithEmptyToolsArray(t *t
assertNonSchemaRequestPreserved(t, body)
}
// TestAntigravityBuildRequest_UsesAuthProjectID verifies that the request body
// produced for a minimal Gemini-style payload has its "project" field set to
// "project-1".
// NOTE(review): this relies on buildRequestBodyFromRawPayload supplying auth
// metadata with project_id "project-1" — confirm against that helper.
func TestAntigravityBuildRequest_UsesAuthProjectID(t *testing.T) {
body := buildRequestBodyFromRawPayload(t, "gemini-3.1-pro", []byte(`{
"request": {
"contents": [
{
"role": "user",
"parts": [{"text": "hello"}]
}
]
}
}`))
// The project must be present, be a string, and match the auth metadata value.
if got, ok := body["project"].(string); !ok || got != "project-1" {
t.Fatalf("project should come from auth metadata, got=%v", body["project"])
}
}
// TestAntigravityPrepareRequestAuth_FetchesMissingProjectID verifies that
// PrepareRequestAuth, given an auth without project_id, issues a
// loadCodeAssist discovery request and returns an updated auth carrying the
// fetched project ID, leaving the original auth metadata untouched.
func TestAntigravityPrepareRequestAuth_FetchesMissingProjectID(t *testing.T) {
	executor := &AntigravityExecutor{}
	auth := &cliproxyauth.Auth{Metadata: map[string]any{
		"access_token": "token",
		"expired":      time.Now().Add(1 * time.Hour).Format(time.RFC3339),
	}}

	// Stub transport: validates the discovery request and answers with a
	// fixed project ID.
	transport := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
		if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" {
			t.Fatalf("unexpected project discovery request: %s", req.URL.String())
		}
		if apiClient := req.Header.Get("X-Goog-Api-Client"); apiClient != "" {
			t.Fatalf("X-Goog-Api-Client = %q, want empty", apiClient)
		}
		payload, errRead := io.ReadAll(req.Body)
		if errRead != nil {
			t.Fatalf("read discovery body: %v", errRead)
		}
		if !strings.Contains(string(payload), `"ideType":"ANTIGRAVITY"`) {
			t.Fatalf("unexpected discovery body: %s", string(payload))
		}
		resp := &http.Response{
			StatusCode: http.StatusOK,
			Header:     make(http.Header),
			Body:       io.NopCloser(strings.NewReader(`{"cloudaicompanionProject":"fetched-project"}`)),
		}
		return resp, nil
	})
	ctx := context.WithValue(context.Background(), "cliproxy.roundtripper", transport)

	updated, err := executor.PrepareRequestAuth(ctx, auth)
	if err != nil {
		t.Fatalf("PrepareRequestAuth error: %v", err)
	}
	if updated == nil {
		t.Fatalf("PrepareRequestAuth returned nil auth")
	}
	if _, mutated := auth.Metadata["project_id"]; mutated {
		t.Fatalf("original auth metadata should not be mutated")
	}
	if got, ok := updated.Metadata["project_id"].(string); !ok || got != "fetched-project" {
		t.Fatalf("updated auth metadata project_id = %v, want fetched-project", updated.Metadata["project_id"])
	}
}
// TestAntigravityBuildRequest_RejectsMissingProjectID verifies that
// buildRequest fails with a 400 Bad Request status error when the auth
// metadata carries no project_id.
func TestAntigravityBuildRequest_RejectsMissingProjectID(t *testing.T) {
	executor := &AntigravityExecutor{}
	authWithoutProject := &cliproxyauth.Auth{Metadata: map[string]any{}}

	_, err := executor.buildRequest(context.Background(), authWithoutProject, "token", "gemini-3.1-pro", []byte(`{"request":{}}`), false, "", "https://example.com")
	if err == nil {
		t.Fatalf("buildRequest should fail when auth has no project_id")
	}

	coded, hasStatus := err.(interface{ StatusCode() int })
	if !hasStatus {
		t.Fatalf("error should expose status code, got %T", err)
	}
	if code := coded.StatusCode(); code != http.StatusBadRequest {
		t.Fatalf("status code = %d, want %d", code, http.StatusBadRequest)
	}
}
func assertNonSchemaRequestPreserved(t *testing.T, body map[string]any) {
t.Helper()
@@ -172,13 +251,19 @@ func buildRequestBodyFromRawPayload(t *testing.T, modelName string, payload []by
t.Helper()
executor := &AntigravityExecutor{}
auth := &cliproxyauth.Auth{}
auth := &cliproxyauth.Auth{Metadata: map[string]any{"project_id": "project-1"}}
req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com")
if err != nil {
t.Fatalf("buildRequest error: %v", err)
}
return requestBody(t, req)
}
func requestBody(t *testing.T, req *http.Request) map[string]any {
t.Helper()
raw, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("read request body error: %v", err)
@@ -444,24 +444,25 @@ func TestUpdateAntigravityCreditsBalance_LoadCodeAssistUserAgent(t *testing.T) {
t.Cleanup(resetAntigravityCreditsRetryState)
exec := NewAntigravityExecutor(&config.Config{})
const userAgent = "antigravity/1.23.2 windows/amd64 google-api-nodejs-client/10.3.0"
const configuredUserAgent = "antigravity/1.23.2 windows/amd64 google-api-nodejs-client/10.3.0"
const loadCodeAssistUserAgent = "antigravity/1.23.2 windows/amd64"
auth := &cliproxyauth.Auth{
ID: "auth-load-code-assist-ua",
Attributes: map[string]string{"user_agent": userAgent},
Attributes: map[string]string{"user_agent": configuredUserAgent},
}
ctx := context.WithValue(context.Background(), "cliproxy.roundtripper", roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" {
t.Fatalf("unexpected request url %s", req.URL.String())
}
if got := req.Header.Get("User-Agent"); got != userAgent {
t.Fatalf("User-Agent = %q, want %q", got, userAgent)
if got := req.Header.Get("User-Agent"); got != loadCodeAssistUserAgent {
t.Fatalf("User-Agent = %q, want %q", got, loadCodeAssistUserAgent)
}
if got := req.Header.Get("X-Goog-Api-Client"); got != "gl-node/22.21.1" {
t.Fatalf("X-Goog-Api-Client = %q, want %q", got, "gl-node/22.21.1")
if got := req.Header.Get("X-Goog-Api-Client"); got != "" {
t.Fatalf("X-Goog-Api-Client = %q, want empty", got)
}
body, _ := io.ReadAll(req.Body)
_ = req.Body.Close()
if string(body) != `{"metadata":{"ide_name":"antigravity","ide_type":"ANTIGRAVITY","ide_version":"1.23.2"}}` {
if string(body) != `{"metadata":{"ideType":"ANTIGRAVITY"}}` {
t.Fatalf("loadCodeAssist body = %s", string(body))
}
return &http.Response{
+117
View File
@@ -100,6 +100,103 @@ func patchCodexCompletedOutput(eventData []byte, outputItemsByIndex map[int64][]
return completedDataPatched
}
// codexTerminalStreamContextLengthErr inspects one SSE event payload. When the
// payload is a terminal "error" or "response.failed" event that describes a
// context length overflow, it is returned as an HTTP 400 status error and the
// boolean is true. Every other event, including terminal errors of a
// different kind, yields a zero statusErr and false.
func codexTerminalStreamContextLengthErr(eventData []byte) (statusErr, bool) {
	// Each terminal event type has a preferred error location and a fallback.
	var errBody []byte
	switch gjson.GetBytes(eventData, "type").String() {
	case "error":
		if errBody = codexTerminalErrorBody(eventData, "error"); len(errBody) == 0 {
			errBody = codexTerminalTopLevelErrorBody(eventData)
		}
	case "response.failed":
		if errBody = codexTerminalErrorBody(eventData, "response.error"); len(errBody) == 0 {
			errBody = codexTerminalErrorBody(eventData, "error")
		}
	default:
		return statusErr{}, false
	}

	if len(errBody) == 0 || !codexTerminalErrorIsContextLength(errBody) {
		return statusErr{}, false
	}
	return newCodexStatusErr(http.StatusBadRequest, errBody), true
}
// codexTerminalErrorBody wraps the value found at path inside eventData into a
// {"error":{...}} body. It returns nil when path does not exist.
//
// A JSON value is copied verbatim under "error"; any other non-blank value
// becomes "error.message". While "error.message" is still blank it is filled,
// in order, from the event's response.error.message, then from the body's own
// error.code, then from its error.type.
func codexTerminalErrorBody(eventData []byte, path string) []byte {
	source := gjson.GetBytes(eventData, path)
	if !source.Exists() {
		return nil
	}

	out := []byte(`{"error":{}}`)
	if source.Type == gjson.JSON {
		out, _ = sjson.SetRawBytes(out, "error", []byte(source.Raw))
	} else if text := strings.TrimSpace(source.String()); text != "" {
		out, _ = sjson.SetBytes(out, "error.message", text)
	}

	// Candidates are evaluated lazily so each one sees the body as it stands
	// after the previous attempt.
	candidates := []func() string{
		func() string { return gjson.GetBytes(eventData, "response.error.message").String() },
		func() string { return gjson.GetBytes(out, "error.code").String() },
		func() string { return gjson.GetBytes(out, "error.type").String() },
	}
	for _, candidate := range candidates {
		if strings.TrimSpace(gjson.GetBytes(out, "error.message").String()) != "" {
			break
		}
		if fallback := strings.TrimSpace(candidate()); fallback != "" {
			out, _ = sjson.SetBytes(out, "error.message", fallback)
		}
	}
	return out
}
// codexTerminalTopLevelErrorBody builds a {"error":{...}} body from error
// fields stored directly on the event ("message", "code", "error_type",
// "param"). It returns nil when all four are blank. When no message is
// present, the code, or failing that the error type, is used as the message.
func codexTerminalTopLevelErrorBody(eventData []byte) []byte {
	read := func(key string) string {
		return strings.TrimSpace(gjson.GetBytes(eventData, key).String())
	}
	message, code, errorType, param := read("message"), read("code"), read("error_type"), read("param")
	if message == "" && code == "" && errorType == "" && param == "" {
		return nil
	}

	out := []byte(`{"error":{}}`)
	for _, entry := range []struct{ path, value string }{
		{"error.message", message},
		{"error.code", code},
		{"error.type", errorType},
		{"error.param", param},
	} {
		if entry.value != "" {
			out, _ = sjson.SetBytes(out, entry.path, entry.value)
		}
	}
	if message != "" {
		return out
	}

	// No upstream message: reuse the most specific identifier available.
	switch {
	case code != "":
		out, _ = sjson.SetBytes(out, "error.message", code)
	case errorType != "":
		out, _ = sjson.SetBytes(out, "error.message", errorType)
	}
	return out
}
// codexTerminalErrorIsContextLength reports whether an {"error":{...}} body
// describes a context length overflow, judged by a known error.code or by a
// telltale phrase in error.message. Both are compared case-insensitively.
func codexTerminalErrorIsContextLength(body []byte) bool {
	normalized := func(path string) string {
		return strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, path).String()))
	}

	switch normalized("error.code") {
	case "context_length_exceeded", "context_too_large":
		return true
	}

	text := normalized("error.message")
	for _, phrase := range []string{"context window", "context length", "too many tokens"} {
		if strings.Contains(text, phrase) {
			return true
		}
	}
	return false
}
// CodexExecutor is a stateless executor for Codex (OpenAI Responses API entrypoint).
// If api_key is unavailable on auth, it falls back to legacy via ClientAdapter.
type CodexExecutor struct {
@@ -147,6 +244,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
if opts.Alt == "responses/compact" {
return e.executeCompact(ctx, auth, req, opts)
}
if isCodexOpenAIImageRequest(opts) {
return e.executeOpenAIImage(ctx, auth, req, opts)
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)
@@ -246,6 +346,11 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
eventData := bytes.TrimSpace(line[5:])
eventType := gjson.GetBytes(eventData, "type").String()
if streamErr, ok := codexTerminalStreamContextLengthErr(eventData); ok {
err = streamErr
return resp, err
}
if eventType == "response.output_item.done" {
itemResult := gjson.GetBytes(eventData, "item")
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
@@ -397,6 +502,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
}
if isCodexOpenAIImageRequest(opts) {
return e.executeOpenAIImageStream(ctx, auth, req, opts)
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)
@@ -500,6 +608,15 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
if bytes.HasPrefix(line, dataTag) {
data := bytes.TrimSpace(line[5:])
if streamErr, ok := codexTerminalStreamContextLengthErr(data); ok {
helps.RecordAPIResponseError(ctx, e.cfg, streamErr)
reporter.PublishFailure(ctx, streamErr)
select {
case out <- cliproxyexecutor.StreamChunk{Err: streamErr}:
case <-ctx.Done():
}
return
}
switch gjson.GetBytes(data, "type").String() {
case "response.output_item.done":
collectCodexOutputItemDone(data, outputItemsByIndex, &outputItemsFallback)
@@ -5,6 +5,7 @@ import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
@@ -46,6 +47,128 @@ func TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone(t *t
}
}
func TestCodexExecutorExecuteSurfacesTerminalStreamError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: response.created\n"))
_, _ = w.Write([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.5"}}` + "\n\n"))
_, _ = w.Write([]byte("event: error\n"))
_, _ = w.Write([]byte(`data: {"type":"error","error":{"type":"invalid_request_error","code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again.","param":"input"},"sequence_number":2}` + "\n\n"))
_, _ = w.Write([]byte("event: response.failed\n"))
_, _ = w.Write([]byte(`data: {"type":"response.failed","response":{"id":"resp_1","status":"failed","error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."}}}` + "\n\n"))
}))
defer server.Close()
executor := NewCodexExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL,
"api_key": "test",
}}
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "gpt-5.5",
Payload: []byte(`{"model":"gpt-5.5","input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
Stream: false,
})
if err == nil {
t.Fatal("expected terminal stream error, got nil")
}
if got := statusCodeFromTestError(t, err); got != http.StatusBadRequest {
t.Fatalf("status code = %d, want %d; err=%v", got, http.StatusBadRequest, err)
}
assertCodexErrorCode(t, err.Error(), "invalid_request_error", "context_too_large")
if !strings.Contains(err.Error(), "Your input exceeds the context window") {
t.Fatalf("error message missing upstream context text: %v", err)
}
}
func TestCodexExecutorExecuteStreamSurfacesTerminalStreamError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("event: response.created\n"))
_, _ = w.Write([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.5"}}` + "\n\n"))
_, _ = w.Write([]byte("event: error\n"))
_, _ = w.Write([]byte(`data: {"type":"error","error":{"type":"invalid_request_error","code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again.","param":"input"},"sequence_number":2}` + "\n\n"))
}))
defer server.Close()
executor := NewCodexExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL,
"api_key": "test",
}}
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "gpt-5.5",
Payload: []byte(`{"model":"gpt-5.5","input":"hello"}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai-response"),
Stream: true,
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
var streamErr error
for chunk := range result.Chunks {
if chunk.Err != nil {
streamErr = chunk.Err
break
}
}
if streamErr == nil {
t.Fatal("missing stream terminal error")
}
if got := statusCodeFromTestError(t, streamErr); got != http.StatusBadRequest {
t.Fatalf("status code = %d, want %d; err=%v", got, http.StatusBadRequest, streamErr)
}
assertCodexErrorCode(t, streamErr.Error(), "invalid_request_error", "context_too_large")
}
// TestCodexTerminalStreamContextLengthErrFromResponseFailed checks that a
// "response.failed" event whose response.error reports a context length
// overflow is recognised and mapped to HTTP 400.
func TestCodexTerminalStreamContextLengthErrFromResponseFailed(t *testing.T) {
	const event = `{"type":"response.failed","response":{"id":"resp_1","status":"failed","error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."}}}`

	streamErr, matched := codexTerminalStreamContextLengthErr([]byte(event))
	if !matched {
		t.Fatal("expected context length terminal error")
	}
	if got := statusCodeFromTestError(t, streamErr); got != http.StatusBadRequest {
		t.Fatalf("status code = %d, want %d; err=%v", got, http.StatusBadRequest, streamErr)
	}
	assertCodexErrorCode(t, streamErr.Error(), "invalid_request_error", "context_too_large")
}
// TestCodexTerminalStreamContextLengthErrFromTopLevelError checks that an
// "error" event carrying its code and message at the top level (no nested
// "error" object) is still recognised, mapped to HTTP 400, and keeps the
// upstream message text.
func TestCodexTerminalStreamContextLengthErrFromTopLevelError(t *testing.T) {
	const event = `{"type":"error","code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again.","sequence_number":2}`

	streamErr, matched := codexTerminalStreamContextLengthErr([]byte(event))
	if !matched {
		t.Fatal("expected top-level context length terminal error")
	}
	if got := statusCodeFromTestError(t, streamErr); got != http.StatusBadRequest {
		t.Fatalf("status code = %d, want %d; err=%v", got, http.StatusBadRequest, streamErr)
	}
	rendered := streamErr.Error()
	assertCodexErrorCode(t, rendered, "invalid_request_error", "context_too_large")
	if !strings.Contains(streamErr.Error(), "Your input exceeds the context window") {
		t.Fatalf("error message missing upstream context text: %v", streamErr)
	}
}
// TestCodexTerminalStreamContextLengthErrIgnoresOtherTerminalErrors checks
// that a terminal error unrelated to context length (here a rate limit) is
// left for the regular stream handling.
func TestCodexTerminalStreamContextLengthErrIgnoresOtherTerminalErrors(t *testing.T) {
	const rateLimitEvent = `{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"Rate limit reached."}}`

	if _, matched := codexTerminalStreamContextLengthErr([]byte(rateLimitEvent)); matched {
		t.Fatal("rate limit terminal error should not be handled by context length fix")
	}
}
// statusCodeFromTestError returns the HTTP status carried by err. The test
// fails immediately when err has no StatusCode() int method.
func statusCodeFromTestError(t *testing.T, err error) int {
	t.Helper()
	if coded, ok := err.(interface{ StatusCode() int }); ok {
		return coded.StatusCode()
	}
	t.Fatalf("error %T does not expose StatusCode(): %v", err, err)
	return 0 // not reached: Fatalf stops the test goroutine
}
func TestCodexExecutorExecuteStream_EmptyStreamCompletionOutputUsesOutputItemDone(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
@@ -0,0 +1,678 @@
package executor
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Identifiers used to serve OpenAI Images API requests through the Codex
// Responses endpoint.
const (
// codexOpenAIImageSourceFormat is the source-format name that, together with
// an images endpoint path, marks a request as an OpenAI image request.
codexOpenAIImageSourceFormat = "openai-image"
// Request path suffixes of the two supported image endpoints.
codexImagesGenerationsPath = "/v1/images/generations"
codexImagesEditsPath = "/v1/images/edits"
// codexOpenAIImagesMainModel is written into the "model" field of the
// upstream Responses request and used for usage reporting; the image model
// itself travels inside the image_generation tool definition.
codexOpenAIImagesMainModel = "gpt-5.4-mini"
)
// codexOpenAIImagePreparedRequest is an Images API request after translation
// into a Responses API request.
type codexOpenAIImagePreparedRequest struct {
// Body is the translated Responses request JSON.
Body []byte
// ResponseFormat is the normalized client response format: "url" or "b64_json".
ResponseFormat string
// StreamPrefix is passed to the stream frame builders; the prepare functions
// set it to "image_generation" or "image_edit".
StreamPrefix string
}
// codexImageCallResult holds the fields read from one "image_generation_call"
// output item of a completed Responses payload. All values are stored with
// surrounding whitespace trimmed.
type codexImageCallResult struct {
// Result is the item's "result" value; it is emitted as b64_json or embedded
// in a base64 data URL when building the Images API response.
Result string
RevisedPrompt string
OutputFormat string
Size string
Background string
Quality string
}
// isCodexOpenAIImageRequest reports whether opts describes an OpenAI Images
// API call: the source format must be "openai-image" (case-insensitive) and
// the request path must point at one of the images endpoints.
func isCodexOpenAIImageRequest(opts cliproxyexecutor.Options) bool {
	sourceFormat := strings.TrimSpace(opts.SourceFormat.String())
	return strings.EqualFold(sourceFormat, codexOpenAIImageSourceFormat) &&
		codexIsImagesEndpointPath(helps.PayloadRequestPath(opts))
}
// codexIsImagesEndpointPath reports whether path, ignoring surrounding
// whitespace, is or ends with the image generations or image edits endpoint.
func codexIsImagesEndpointPath(path string) bool {
	trimmed := strings.TrimSpace(path)
	// An exact match is also a suffix match, so one check covers both cases.
	for _, endpoint := range []string{codexImagesGenerationsPath, codexImagesEditsPath} {
		if strings.HasSuffix(trimmed, endpoint) {
			return true
		}
	}
	return false
}
// executeOpenAIImage serves a non-streaming OpenAI Images API request.
//
// The request is translated into a Responses API request, sent upstream with
// streaming enabled, and the whole SSE body is buffered. The
// "response.completed" event is then converted into an Images API response.
// A non-2xx upstream status is returned as a Codex status error; a body with
// no "response.completed" event yields a 504 status error.
func (e *CodexExecutor) executeOpenAIImage(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
prepared, errPrepare := codexPrepareOpenAIImageRequest(req, opts)
if errPrepare != nil {
return resp, errPrepare
}
apiKey, baseURL := codexCreds(auth)
// Fall back to the default Codex backend when the auth sets no base URL.
if baseURL == "" {
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := helps.NewUsageReporter(ctx, e.Identifier(), codexOpenAIImagesMainModel, auth)
// err is the named result, so the deferred call sees whatever error is returned below.
defer reporter.TrackFailure(ctx, &err)
body, errBuild := e.prepareCodexOpenAIImageBody(prepared.Body, req, opts)
if errBuild != nil {
return resp, errBuild
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, errCache := e.cacheHelper(ctx, sdktranslator.FromString(codexOpenAIImageSourceFormat), url, req, body)
if errCache != nil {
return resp, errCache
}
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
recordCodexOpenAIImageRequest(ctx, e.cfg, e.Identifier(), auth, url, httpReq.Header.Clone(), body)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose)
}
}()
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
// The upstream reply is read in full before any event is processed.
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = newCodexStatusErr(httpResp.StatusCode, data)
return resp, err
}
// Output items seen before completion are collected so they can be patched
// into the completed event.
outputItemsByIndex := make(map[int64][]byte)
var outputItemsFallback [][]byte
for _, line := range bytes.Split(data, []byte("\n")) {
// Only SSE data lines carry event payloads.
if !bytes.HasPrefix(line, dataTag) {
continue
}
eventData := bytes.TrimSpace(line[len(dataTag):])
switch gjson.GetBytes(eventData, "type").String() {
case "response.output_item.done":
collectCodexOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
case "response.completed":
if detail, ok := helps.ParseCodexUsage(eventData); ok {
reporter.Publish(ctx, detail)
}
publishCodexImageToolUsage(ctx, reporter, body, eventData)
completedData := patchCodexCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
results, createdAt, usageRaw, firstMeta, errExtract := codexExtractImagesFromResponsesCompleted(completedData)
if errExtract != nil {
return resp, errExtract
}
if len(results) == 0 {
return resp, statusErr{code: http.StatusBadGateway, msg: "upstream did not return image output"}
}
out, errOutput := codexBuildImagesAPIResponse(results, createdAt, usageRaw, firstMeta, prepared.ResponseFormat)
if errOutput != nil {
return resp, errOutput
}
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
}
}
// The body ended without a response.completed event.
err = statusErr{code: http.StatusGatewayTimeout, msg: "stream error: stream disconnected before completion"}
return resp, err
}
// executeOpenAIImageStream serves a streaming OpenAI Images API request.
//
// The request is translated into a Responses API request and sent upstream.
// A non-2xx upstream status is returned synchronously as a Codex status
// error. Otherwise a goroutine scans the SSE body and forwards partial-image
// frames and, on "response.completed", one completed frame per image through
// the returned chunk channel, which is closed when the goroutine exits.
func (e *CodexExecutor) executeOpenAIImageStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
prepared, errPrepare := codexPrepareOpenAIImageRequest(req, opts)
if errPrepare != nil {
return nil, errPrepare
}
apiKey, baseURL := codexCreds(auth)
// Fall back to the default Codex backend when the auth sets no base URL.
if baseURL == "" {
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := helps.NewUsageReporter(ctx, e.Identifier(), codexOpenAIImagesMainModel, auth)
// Runs when this function returns, i.e. it only observes errors raised before
// the goroutine below takes over.
defer reporter.TrackFailure(ctx, &err)
body, errBuild := e.prepareCodexOpenAIImageBody(prepared.Body, req, opts)
if errBuild != nil {
return nil, errBuild
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, errCache := e.cacheHelper(ctx, sdktranslator.FromString(codexOpenAIImageSourceFormat), url, req, body)
if errCache != nil {
return nil, errCache
}
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
recordCodexOpenAIImageRequest(ctx, e.cfg, e.Identifier(), auth, url, httpReq.Header.Clone(), body)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
// Error replies are read in full and closed here; no goroutine is started.
data, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose)
}
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return nil, errRead
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = newCodexStatusErr(httpResp.StatusCode, data)
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
// The goroutine owns both the channel and the response body from here on.
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose)
}
}()
// sendPayload and sendError report false when ctx is cancelled before the
// receiver takes the chunk, so the goroutine cannot block forever on out.
sendPayload := func(payload []byte) bool {
select {
case out <- cliproxyexecutor.StreamChunk{Payload: payload}:
return true
case <-ctx.Done():
return false
}
}
sendError := func(errSend error) bool {
select {
case out <- cliproxyexecutor.StreamChunk{Err: errSend}:
return true
case <-ctx.Done():
return false
}
}
scanner := bufio.NewScanner(httpResp.Body)
// Raise the maximum line size so large base64 image events fit in one token.
scanner.Buffer(nil, 52_428_800) // 50MB
outputItemsByIndex := make(map[int64][]byte)
var outputItemsFallback [][]byte
for scanner.Scan() {
line := scanner.Bytes()
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
// Only SSE data lines carry event payloads.
if !bytes.HasPrefix(line, dataTag) {
continue
}
eventData := bytes.TrimSpace(line[len(dataTag):])
switch gjson.GetBytes(eventData, "type").String() {
case "response.output_item.done":
collectCodexOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
case "response.image_generation_call.partial_image":
frame := codexBuildImagePartialFrame(eventData, prepared.ResponseFormat, prepared.StreamPrefix)
if len(frame) > 0 && !sendPayload(frame) {
return
}
case "response.completed":
if detail, ok := helps.ParseCodexUsage(eventData); ok {
reporter.Publish(ctx, detail)
}
publishCodexImageToolUsage(ctx, reporter, body, eventData)
completedData := patchCodexCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
results, _, usageRaw, _, errExtract := codexExtractImagesFromResponsesCompleted(completedData)
if errExtract != nil {
sendError(errExtract)
return
}
if len(results) == 0 {
sendError(statusErr{code: http.StatusBadGateway, msg: "upstream did not return image output"})
return
}
for _, img := range results {
frame := codexBuildImageCompletedFrame(img, usageRaw, prepared.ResponseFormat, prepared.StreamPrefix)
if len(frame) > 0 && !sendPayload(frame) {
return
}
}
// Completion is terminal: stop reading once every image frame is sent.
return
}
}
// NOTE(review): when the stream ends without "response.completed" and without
// a scan error, the channel is closed with no error chunk, whereas the
// buffered path (executeOpenAIImage) returns a 504 — confirm this is intended.
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx, errScan)
sendError(errScan)
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// prepareCodexOpenAIImageBody finalizes a translated image request for the
// upstream Responses endpoint. It applies thinking settings and payload
// configuration, forces the main model and streaming, removes fields that are
// not forwarded upstream, and normalizes the instructions.
func (e *CodexExecutor) prepareCodexOpenAIImageBody(body []byte, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) ([]byte, error) {
	payload, errThinking := thinking.ApplyThinking(body, codexOpenAIImagesMainModel, codexOpenAIImageSourceFormat, "codex", e.Identifier())
	if errThinking != nil {
		return nil, errThinking
	}

	payload = helps.ApplyPayloadConfigWithRequest(
		e.cfg, codexOpenAIImagesMainModel, "codex", codexOpenAIImageSourceFormat, "",
		payload, body,
		helps.PayloadRequestedModel(opts, req.Model),
		helps.PayloadRequestPath(opts),
		opts.Headers,
	)

	// These two are always forced, whatever the payload configuration produced.
	payload, _ = sjson.SetBytes(payload, "model", codexOpenAIImagesMainModel)
	payload, _ = sjson.SetBytes(payload, "stream", true)

	for _, dropped := range []string{
		"previous_response_id",
		"prompt_cache_retention",
		"safety_identifier",
		"stream_options",
	} {
		payload, _ = sjson.DeleteBytes(payload, dropped)
	}
	return normalizeCodexInstructions(payload), nil
}
// recordCodexOpenAIImageRequest logs the outgoing upstream POST request. The
// auth identity fields are filled in only when auth is non-nil.
func recordCodexOpenAIImageRequest(ctx context.Context, cfg *config.Config, provider string, auth *cliproxyauth.Auth, url string, headers http.Header, body []byte) {
	entry := helps.UpstreamRequestLog{
		URL:      url,
		Method:   http.MethodPost,
		Headers:  headers,
		Body:     body,
		Provider: provider,
	}
	if auth != nil {
		entry.AuthID = auth.ID
		entry.AuthLabel = auth.Label
		entry.AuthType, entry.AuthValue = auth.AccountInfo()
	}
	helps.RecordAPIRequest(ctx, cfg, entry)
}
// codexPrepareOpenAIImageRequest translates an Images API request into a
// Responses request, choosing the translator from the request path: image
// generations are always JSON, image edits are multipart or JSON depending on
// the Content-Type header. Any other path is rejected with an error.
func codexPrepareOpenAIImageRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (codexOpenAIImagePreparedRequest, error) {
	path := helps.PayloadRequestPath(opts)
	switch {
	case strings.HasSuffix(path, codexImagesGenerationsPath):
		return codexPrepareOpenAIImageGenerationJSON(req.Payload, req.Model)

	case strings.HasSuffix(path, codexImagesEditsPath):
		contentType := codexImageContentType(opts.Headers)
		// A header that fails to parse is handled like a non-multipart request.
		mediaType, _, _ := mime.ParseMediaType(contentType)
		if strings.HasPrefix(strings.ToLower(mediaType), "multipart/") {
			return codexPrepareOpenAIImageEditMultipart(req.Payload, req.Model, contentType)
		}
		return codexPrepareOpenAIImageEditJSON(req.Payload, req.Model)

	default:
		return codexOpenAIImagePreparedRequest{}, fmt.Errorf("unsupported OpenAI image endpoint path %q", path)
	}
}
// codexPrepareOpenAIImageGenerationJSON translates a JSON image generation
// request into a Responses request with a "generate" image tool. It fails
// when rawJSON is not valid JSON.
func codexPrepareOpenAIImageGenerationJSON(rawJSON []byte, routeModel string) (codexOpenAIImagePreparedRequest, error) {
	if !json.Valid(rawJSON) {
		return codexOpenAIImagePreparedRequest{}, fmt.Errorf("invalid OpenAI image generation request JSON")
	}

	stringOptions := []string{"size", "quality", "background", "output_format", "moderation"}
	numberOptions := []string{"output_compression", "partial_images"}
	prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())
	tool := codexBuildOpenAIImageTool(rawJSON, routeModel, "generate", stringOptions, numberOptions)

	prepared := codexOpenAIImagePreparedRequest{
		Body:           codexBuildImagesResponsesRequest(prompt, nil, tool),
		ResponseFormat: codexOpenAIImageResponseFormatFromJSON(rawJSON),
		StreamPrefix:   "image_generation",
	}
	return prepared, nil
}
// codexPrepareOpenAIImageEditJSON translates a JSON image edit request into a
// Responses request with an "edit" image tool. Input images come from the
// non-blank image_url of each entry in "images"; an optional mask comes from
// mask.image_url. It fails when rawJSON is not valid JSON.
func codexPrepareOpenAIImageEditJSON(rawJSON []byte, routeModel string) (codexOpenAIImagePreparedRequest, error) {
	if !json.Valid(rawJSON) {
		return codexOpenAIImagePreparedRequest{}, fmt.Errorf("invalid OpenAI image edit request JSON")
	}

	inputImages := []string{}
	if list := gjson.GetBytes(rawJSON, "images"); list.IsArray() {
		list.ForEach(func(_, entry gjson.Result) bool {
			if imageURL := strings.TrimSpace(entry.Get("image_url").String()); imageURL != "" {
				inputImages = append(inputImages, imageURL)
			}
			return true
		})
	}

	stringOptions := []string{"size", "quality", "background", "output_format", "input_fidelity", "moderation"}
	numberOptions := []string{"output_compression", "partial_images"}
	tool := codexBuildOpenAIImageTool(rawJSON, routeModel, "edit", stringOptions, numberOptions)
	if maskURL := strings.TrimSpace(gjson.GetBytes(rawJSON, "mask.image_url").String()); maskURL != "" {
		tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", maskURL)
	}

	prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())
	prepared := codexOpenAIImagePreparedRequest{
		Body:           codexBuildImagesResponsesRequest(prompt, inputImages, tool),
		ResponseFormat: codexOpenAIImageResponseFormatFromJSON(rawJSON),
		StreamPrefix:   "image_edit",
	}
	return prepared, nil
}
// codexPrepareOpenAIImageEditMultipart translates a multipart/form-data image
// edit request into a Responses request with an "edit" image tool.
//
// contentType must carry the multipart boundary. Uploaded images and the
// optional "mask" file are embedded as base64 data URLs. Numeric form fields
// that do not parse as integers are left out of the tool definition.
func codexPrepareOpenAIImageEditMultipart(rawBody []byte, routeModel string, contentType string) (codexOpenAIImagePreparedRequest, error) {
	var none codexOpenAIImagePreparedRequest

	_, params, errMedia := mime.ParseMediaType(contentType)
	if errMedia != nil {
		return none, fmt.Errorf("parse multipart content type failed: %w", errMedia)
	}
	boundary := strings.TrimSpace(params["boundary"])
	if boundary == "" {
		return none, fmt.Errorf("multipart boundary is required")
	}

	form, errForm := multipart.NewReader(bytes.NewReader(rawBody), boundary).ReadForm(32 << 20)
	if errForm != nil {
		return none, fmt.Errorf("parse multipart form failed: %w", errForm)
	}
	// ReadForm may have spilled parts to temp files; remove them on exit.
	defer func() {
		if errRemove := form.RemoveAll(); errRemove != nil {
			log.Errorf("codex openai images: remove multipart temp files error: %v", errRemove)
		}
	}()

	formField := func(name string) string {
		return strings.TrimSpace(codexFormValue(form, name))
	}

	tool := []byte(`{"type":"image_generation","action":"edit"}`)
	tool, _ = sjson.SetBytes(tool, "model", codexOpenAIImageToolModel(codexFormValue(form, "model"), routeModel))
	for _, name := range []string{"size", "quality", "background", "output_format", "input_fidelity", "moderation"} {
		if text := formField(name); text != "" {
			tool, _ = sjson.SetBytes(tool, name, text)
		}
	}
	for _, name := range []string{"output_compression", "partial_images"} {
		text := formField(name)
		if text == "" {
			continue
		}
		number, errParse := strconv.ParseInt(text, 10, 64)
		if errParse != nil {
			continue
		}
		tool, _ = sjson.SetBytes(tool, name, number)
	}

	inputImages := []string{}
	for _, upload := range codexMultipartImageFiles(form) {
		dataURL, errData := codexMultipartFileToDataURL(upload)
		if errData != nil {
			return none, errData
		}
		inputImages = append(inputImages, dataURL)
	}

	if masks := form.File["mask"]; len(masks) > 0 && masks[0] != nil {
		maskURL, errData := codexMultipartFileToDataURL(masks[0])
		if errData != nil {
			return none, errData
		}
		tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", maskURL)
	}

	prepared := codexOpenAIImagePreparedRequest{
		Body:           codexBuildImagesResponsesRequest(formField("prompt"), inputImages, tool),
		ResponseFormat: codexNormalizeImageResponseFormat(codexFormValue(form, "response_format")),
		StreamPrefix:   "image_edit",
	}
	return prepared, nil
}
// codexImageContentType returns the whitespace-trimmed Content-Type header,
// or "" when headers is nil or the header is absent.
func codexImageContentType(headers http.Header) string {
	// Get on a nil http.Header yields "", so nil needs no separate branch.
	return strings.TrimSpace(headers.Get("Content-Type"))
}
func codexOpenAIImageResponseFormatFromJSON(rawJSON []byte) string {
return codexNormalizeImageResponseFormat(gjson.GetBytes(rawJSON, "response_format").String())
}
// codexNormalizeImageResponseFormat maps a requested response format to one
// of the two supported values: "url" when the input is "url" (ignoring case
// and surrounding whitespace), "b64_json" for everything else.
func codexNormalizeImageResponseFormat(responseFormat string) string {
	normalized := "b64_json"
	if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
		normalized = "url"
	}
	return normalized
}
// codexOpenAIImageToolModel picks the image tool model: the first non-blank
// of requestModel and routeModel (trimmed), otherwise the default image tool
// model.
func codexOpenAIImageToolModel(requestModel string, routeModel string) string {
	for _, candidate := range []string{requestModel, routeModel} {
		if model := strings.TrimSpace(candidate); model != "" {
			return model
		}
	}
	return codexDefaultImageToolModel
}
// codexBuildOpenAIImageTool builds the image_generation tool definition for a
// JSON image request. action and the resolved model are always set. Each
// name in stringFields is copied when its trimmed value is non-blank, and
// each name in numberFields is copied as an integer when the request value is
// a JSON number.
func codexBuildOpenAIImageTool(rawJSON []byte, routeModel string, action string, stringFields []string, numberFields []string) []byte {
	def := []byte(`{"type":"image_generation","action":""}`)
	def, _ = sjson.SetBytes(def, "action", action)

	requestedModel := gjson.GetBytes(rawJSON, "model").String()
	def, _ = sjson.SetBytes(def, "model", codexOpenAIImageToolModel(requestedModel, routeModel))

	for _, name := range stringFields {
		text := strings.TrimSpace(gjson.GetBytes(rawJSON, name).String())
		if text == "" {
			continue
		}
		def, _ = sjson.SetBytes(def, name, text)
	}
	for _, name := range numberFields {
		// A missing field has type Null, so the Number check also covers absence.
		number := gjson.GetBytes(rawJSON, name)
		if number.Type != gjson.Number {
			continue
		}
		def, _ = sjson.SetBytes(def, name, number.Int())
	}
	return def
}
// codexBuildImagesResponsesRequest assembles the Responses API request body
// for an image call: a single user message holding the prompt followed by one
// input_image part per non-blank image, and a tools array holding toolJSON
// when it is valid JSON (otherwise the array stays empty).
func codexBuildImagesResponsesRequest(prompt string, images []string, toolJSON []byte) []byte {
	body := []byte(`{"instructions":"","stream":true,"reasoning":{"effort":"medium","summary":"auto"},"parallel_tool_calls":true,"include":["reasoning.encrypted_content"],"model":"","store":false,"tool_choice":{"type":"image_generation"}}`)
	body, _ = sjson.SetBytes(body, "model", codexOpenAIImagesMainModel)

	message := []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`)
	message, _ = sjson.SetBytes(message, "0.content.0.text", prompt)

	// Content index 0 is the prompt text; images follow from index 1.
	next := 1
	for _, imageURL := range images {
		if strings.TrimSpace(imageURL) == "" {
			continue
		}
		imagePart := []byte(`{"type":"input_image","image_url":""}`)
		imagePart, _ = sjson.SetBytes(imagePart, "image_url", imageURL)
		message, _ = sjson.SetRawBytes(message, fmt.Sprintf("0.content.%d", next), imagePart)
		next++
	}
	body, _ = sjson.SetRawBytes(body, "input", message)

	body, _ = sjson.SetRawBytes(body, "tools", []byte(`[]`))
	if hasTool := len(toolJSON) > 0 && json.Valid(toolJSON); hasTool {
		body, _ = sjson.SetRawBytes(body, "tools.-1", toolJSON)
	}
	return body
}
// codexFormValue returns the first value of a multipart form field with
// surrounding whitespace removed, or "" when the form is nil or the field has
// no values.
func codexFormValue(form *multipart.Form, key string) string {
	if form == nil {
		return ""
	}
	values := form.Value[key]
	if len(values) == 0 {
		return ""
	}
	return strings.TrimSpace(values[0])
}
// codexMultipartImageFiles returns the uploaded image parts of a multipart
// form. Files posted under "image[]" take precedence; otherwise the files
// posted under "image" are returned. A nil form yields nil.
func codexMultipartImageFiles(form *multipart.Form) []*multipart.FileHeader {
	if form == nil {
		return nil
	}
	bracketed := form.File["image[]"]
	if len(bracketed) == 0 {
		return form.File["image"]
	}
	return bracketed
}
// codexMultipartFileToDataURL reads an uploaded multipart file fully into
// memory and returns it as a base64 "data:" URL.
//
// The media type is taken from the part's Content-Type header. When the
// header is absent, or only declares the generic "application/octet-stream"
// (which says nothing about the payload), the type is sniffed from the
// content with http.DetectContentType so the URL advertises the actual image
// type whenever it can be recognised.
//
// An error is returned when fileHeader is nil or the file cannot be opened or
// read.
func codexMultipartFileToDataURL(fileHeader *multipart.FileHeader) (string, error) {
	if fileHeader == nil {
		return "", fmt.Errorf("upload file is nil")
	}
	f, errOpen := fileHeader.Open()
	if errOpen != nil {
		return "", fmt.Errorf("open upload file failed: %w", errOpen)
	}
	defer func() {
		if errClose := f.Close(); errClose != nil {
			log.Errorf("codex openai images: close upload file error: %v", errClose)
		}
	}()
	data, errRead := io.ReadAll(f)
	if errRead != nil {
		return "", fmt.Errorf("read upload file failed: %w", errRead)
	}
	mediaType := strings.TrimSpace(fileHeader.Header.Get("Content-Type"))
	// DetectContentType itself falls back to application/octet-stream, so
	// sniffing a generic declaration can only make the label more specific.
	if mediaType == "" || strings.EqualFold(mediaType, "application/octet-stream") {
		mediaType = http.DetectContentType(data)
	}
	return "data:" + mediaType + ";base64," + base64.StdEncoding.EncodeToString(data), nil
}
// codexExtractImagesFromResponsesCompleted pulls generated images out of a
// "response.completed" event payload.
//
// It returns every response.output item of type "image_generation_call" that
// has a non-blank result, the response creation time (the current Unix time
// is substituted when the payload carries no positive value), the raw
// response.tool_usage.image_gen object when it exists, and the first collected
// entry as firstMeta. The only error case is a payload whose type is not
// "response.completed".
func codexExtractImagesFromResponsesCompleted(payload []byte) (results []codexImageCallResult, createdAt int64, usageRaw []byte, firstMeta codexImageCallResult, err error) {
	if eventType := gjson.GetBytes(payload, "type").String(); eventType != "response.completed" {
		return nil, 0, nil, codexImageCallResult{}, fmt.Errorf("unexpected event type")
	}

	if createdAt = gjson.GetBytes(payload, "response.created_at").Int(); createdAt <= 0 {
		createdAt = time.Now().Unix()
	}

	if items := gjson.GetBytes(payload, "response.output"); items.IsArray() {
		for _, call := range items.Array() {
			if call.Get("type").String() != "image_generation_call" {
				continue
			}
			// text reads one string field of the current call, trimmed.
			text := func(field string) string {
				return strings.TrimSpace(call.Get(field).String())
			}
			encoded := text("result")
			if encoded == "" {
				// Calls without image data contribute nothing.
				continue
			}
			results = append(results, codexImageCallResult{
				Result:        encoded,
				RevisedPrompt: text("revised_prompt"),
				OutputFormat:  text("output_format"),
				Size:          text("size"),
				Background:    text("background"),
				Quality:       text("quality"),
			})
			if len(results) == 1 {
				firstMeta = results[0]
			}
		}
	}

	imageUsage := gjson.GetBytes(payload, "response.tool_usage.image_gen")
	if imageUsage.Exists() && imageUsage.IsObject() {
		usageRaw = []byte(imageUsage.Raw)
	}
	return results, createdAt, usageRaw, firstMeta, nil
}
// codexBuildImagesAPIResponse assembles the JSON body of an Images API style
// response: {"created": ..., "data": [...]} plus optional metadata.
//
// Each result becomes one data entry carrying either a "url" (a base64 data
// URL) or a "b64_json" field depending on the normalized responseFormat, and
// a "revised_prompt" when one is set. Non-empty background, output_format,
// quality and size values of firstMeta are copied to the top level, and
// usageRaw is attached as "usage" when it is non-empty valid JSON. The
// returned error is always nil.
func codexBuildImagesAPIResponse(results []codexImageCallResult, createdAt int64, usageRaw []byte, firstMeta codexImageCallResult, responseFormat string) ([]byte, error) {
	body, _ := sjson.SetBytes([]byte(`{"created":0,"data":[]}`), "created", createdAt)
	asURL := codexNormalizeImageResponseFormat(responseFormat) == "url"

	for _, img := range results {
		key, value := "b64_json", img.Result
		if asURL {
			key = "url"
			value = "data:" + codexMimeTypeFromOutputFormat(img.OutputFormat) + ";base64," + img.Result
		}
		entry, _ := sjson.SetBytes([]byte(`{}`), key, value)
		if img.RevisedPrompt != "" {
			entry, _ = sjson.SetBytes(entry, "revised_prompt", img.RevisedPrompt)
		}
		body, _ = sjson.SetRawBytes(body, "data.-1", entry)
	}

	// Top-level metadata is taken from the first image only.
	metadata := []struct {
		key   string
		value string
	}{
		{"background", firstMeta.Background},
		{"output_format", firstMeta.OutputFormat},
		{"quality", firstMeta.Quality},
		{"size", firstMeta.Size},
	}
	for _, field := range metadata {
		if field.value == "" {
			continue
		}
		body, _ = sjson.SetBytes(body, field.key, field.value)
	}

	if len(usageRaw) > 0 && json.Valid(usageRaw) {
		body, _ = sjson.SetRawBytes(body, "usage", usageRaw)
	}
	return body, nil
}
// codexBuildImagePartialFrame converts an upstream partial-image event into an
// SSE frame named "<streamPrefix>.partial_image".
//
// It returns nil when the payload has no partial_image_b64 data. The frame's
// JSON carries the event type, the partial_image_index copied from the
// payload, and the image either as a base64 data URL ("url") or as raw base64
// ("b64_json"), depending on the normalized responseFormat.
func codexBuildImagePartialFrame(payload []byte, responseFormat string, streamPrefix string) []byte {
	encoded := strings.TrimSpace(gjson.GetBytes(payload, "partial_image_b64").String())
	if encoded == "" {
		return nil
	}

	format := strings.TrimSpace(gjson.GetBytes(payload, "output_format").String())
	event := strings.TrimSpace(streamPrefix) + ".partial_image"
	index := gjson.GetBytes(payload, "partial_image_index").Int()

	frame, _ := sjson.SetBytes([]byte(`{"type":"","partial_image_index":0}`), "type", event)
	frame, _ = sjson.SetBytes(frame, "partial_image_index", index)

	switch codexNormalizeImageResponseFormat(responseFormat) {
	case "url":
		frame, _ = sjson.SetBytes(frame, "url", "data:"+codexMimeTypeFromOutputFormat(format)+";base64,"+encoded)
	default:
		frame, _ = sjson.SetBytes(frame, "b64_json", encoded)
	}
	return codexBuildSSEFrame(event, frame)
}
// codexBuildImageCompletedFrame renders a finished image as an SSE frame named
// "<streamPrefix>.completed".
//
// The frame's JSON holds the event type and the image, either as a base64
// data URL ("url") or as raw base64 ("b64_json") depending on the normalized
// responseFormat. usageRaw is attached as "usage" when it is non-empty valid
// JSON.
func codexBuildImageCompletedFrame(img codexImageCallResult, usageRaw []byte, responseFormat string, streamPrefix string) []byte {
	event := strings.TrimSpace(streamPrefix) + ".completed"
	frame, _ := sjson.SetBytes([]byte(`{"type":""}`), "type", event)

	switch codexNormalizeImageResponseFormat(responseFormat) {
	case "url":
		frame, _ = sjson.SetBytes(frame, "url", "data:"+codexMimeTypeFromOutputFormat(img.OutputFormat)+";base64,"+img.Result)
	default:
		frame, _ = sjson.SetBytes(frame, "b64_json", img.Result)
	}

	if hasUsage := len(usageRaw) > 0 && json.Valid(usageRaw); hasUsage {
		frame, _ = sjson.SetRawBytes(frame, "usage", usageRaw)
	}
	return codexBuildSSEFrame(event, frame)
}
// codexBuildSSEFrame serializes one server-sent-events frame.
//
// An "event: <name>" line is emitted only when eventName is not blank (the
// name itself is written untrimmed), followed by a "data: <data>" line and
// the blank line that terminates the frame.
func codexBuildSSEFrame(eventName string, data []byte) []byte {
	var eventLine []byte
	if strings.TrimSpace(eventName) != "" {
		eventLine = []byte("event: " + eventName + "\n")
	}
	segments := [][]byte{
		eventLine,
		[]byte("data: "),
		data,
		[]byte("\n\n"),
	}
	return bytes.Join(segments, nil)
}
// codexMimeTypeFromOutputFormat maps an image output format name to its MIME
// type. Matching ignores case and surrounding whitespace; "jpg"/"jpeg" and
// "webp" are recognized and everything else is reported as "image/png".
func codexMimeTypeFromOutputFormat(outputFormat string) string {
	normalized := strings.ToLower(strings.TrimSpace(outputFormat))
	if normalized == "jpg" || normalized == "jpeg" {
		return "image/jpeg"
	}
	if normalized == "webp" {
		return "image/webp"
	}
	return "image/png"
}
@@ -13,6 +13,7 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
@@ -135,6 +136,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRequest(e.cfg, baseModel, to.String(), from.String(), "", body, originalTranslated, requestedModel, requestPath, opts.Headers)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = capGeminiMaxOutputTokens(body, baseModel)
action := "generateContent"
if req.Metadata != nil {
@@ -243,6 +245,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRequest(e.cfg, baseModel, to.String(), from.String(), "", body, originalTranslated, requestedModel, requestPath, opts.Headers)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = capGeminiMaxOutputTokens(body, baseModel)
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, "streamGenerateContent")
@@ -527,6 +530,26 @@ func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
util.ApplyCustomHeadersFromAttrs(req, attrs)
}
// capGeminiMaxOutputTokens clamps generationConfig.maxOutputTokens in a Gemini
// request body to the output limit registered for modelName.
//
// The body is returned unchanged when the field is absent or not a JSON
// number, when the model is not found in the registry under the "gemini"
// provider, when the registry entry has no positive limit
// (OutputTokenLimit, falling back to MaxCompletionTokens), or when the
// requested value already fits within the limit. If rewriting the field
// fails, the original body is returned rather than whatever the failed
// rewrite produced.
func capGeminiMaxOutputTokens(body []byte, modelName string) []byte {
	const maxOutputTokensPath = "generationConfig.maxOutputTokens"

	requested := gjson.GetBytes(body, maxOutputTokensPath)
	if !requested.Exists() || requested.Type != gjson.Number {
		return body
	}

	modelInfo := registry.LookupModelInfo(modelName, "gemini")
	if modelInfo == nil {
		return body
	}

	limit := modelInfo.OutputTokenLimit
	if limit <= 0 {
		limit = modelInfo.MaxCompletionTokens
	}
	if limit <= 0 || requested.Int() <= int64(limit) {
		return body
	}

	capped, errSet := sjson.SetBytes(body, maxOutputTokensPath, limit)
	if errSet != nil {
		// Never trade a usable request body for a failed rewrite.
		return body
	}
	return capped
}
func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
if modelName == "gemini-2.5-flash-image-preview" {
aspectRatioResult := gjson.GetBytes(rawJSON, "generationConfig.imageConfig.aspectRatio")
@@ -0,0 +1,90 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
"github.com/tidwall/gjson"
)
// TestCapGeminiMaxOutputTokensUsesOutputTokenLimit checks that an oversized
// maxOutputTokens is reduced to 65536 for gemini-3.1-pro-preview and that
// sibling generationConfig fields are left untouched.
func TestCapGeminiMaxOutputTokensUsesOutputTokenLimit(t *testing.T) {
	const model = "gemini-3.1-pro-preview"
	input := []byte(`{"generationConfig":{"maxOutputTokens":500000,"temperature":0.2},"contents":[]}`)

	capped := capGeminiMaxOutputTokens(input, model)

	gotTokens := gjson.GetBytes(capped, "generationConfig.maxOutputTokens").Int()
	if gotTokens != 65536 {
		t.Fatalf("maxOutputTokens = %d, want 65536", gotTokens)
	}
	gotTemperature := gjson.GetBytes(capped, "generationConfig.temperature").Float()
	if gotTemperature != 0.2 {
		t.Fatalf("temperature = %v, want 0.2", gotTemperature)
	}
}
// TestCapGeminiMaxOutputTokensLeavesAllowedOrUnknown checks the pass-through
// cases: a value already within the model limit, and a model name the test
// expects to have no registry entry.
func TestCapGeminiMaxOutputTokensLeavesAllowedOrUnknown(t *testing.T) {
	type passthroughCase struct {
		name       string
		model      string
		body       []byte
		wantTokens int64
	}

	cases := []passthroughCase{
		{
			name:       "allowed value",
			model:      "gemini-3.1-pro-preview",
			body:       []byte(`{"generationConfig":{"maxOutputTokens":64000}}`),
			wantTokens: 64000,
		},
		{
			name:       "unknown model",
			model:      "custom-gemini-model",
			body:       []byte(`{"generationConfig":{"maxOutputTokens":500000}}`),
			wantTokens: 500000,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			result := capGeminiMaxOutputTokens(tc.body, tc.model)
			got := gjson.GetBytes(result, "generationConfig.maxOutputTokens").Int()
			if got != tc.wantTokens {
				t.Fatalf("maxOutputTokens = %d, want %d", got, tc.wantTokens)
			}
		})
	}
}
// TestGeminiExecutorExecuteCapsMaxOutputTokensBeforeUpstream runs Execute
// against a stub upstream and verifies that the maxOutputTokens value that
// reaches the upstream has been capped to 65536.
func TestGeminiExecutorExecuteCapsMaxOutputTokensBeforeUpstream(t *testing.T) {
	var upstreamMaxOutputTokens int64
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			// The handler runs on the server's goroutine, where FailNow-style
			// helpers (t.Fatalf) must not be used; record the failure and
			// answer with an error instead.
			t.Errorf("read request body: %v", err)
			http.Error(w, "read request body failed", http.StatusInternalServerError)
			return
		}
		upstreamMaxOutputTokens = gjson.GetBytes(body, "generationConfig.maxOutputTokens").Int()
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}`))
	}))
	defer server.Close()

	exec := NewGeminiExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "test-key",
		"base_url": server.URL,
	}}
	req := cliproxyexecutor.Request{
		Model:   "gemini-3.1-pro-preview",
		Payload: []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"maxOutputTokens":500000}}`),
	}

	if _, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{SourceFormat: sdktranslator.FormatGemini}); err != nil {
		t.Fatalf("Execute() error = %v", err)
	}
	if upstreamMaxOutputTokens != 65536 {
		t.Fatalf("upstream maxOutputTokens = %d, want 65536", upstreamMaxOutputTokens)
	}
}
@@ -102,6 +102,7 @@ func RecordAPIRequest(ctx context.Context, cfg *config.Config, info UpstreamRequ
// RecordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
func RecordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
logging.SetResponseHeaders(ctx, headers)
if cfg == nil || !cfg.RequestLog {
return
}
@@ -227,6 +228,7 @@ func RecordAPIWebsocketRequest(ctx context.Context, cfg *config.Config, info Ups
// RecordAPIWebsocketHandshake stores the upstream websocket handshake response metadata.
func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
logging.SetResponseHeaders(ctx, headers)
if cfg == nil || !cfg.RequestLog {
return
}
@@ -250,6 +252,7 @@ func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status
// RecordAPIWebsocketUpgradeRejection stores a rejected websocket upgrade as an HTTP attempt.
func RecordAPIWebsocketUpgradeRejection(ctx context.Context, cfg *config.Config, info UpstreamRequestLog, status int, headers http.Header, body []byte) {
logging.SetResponseHeaders(ctx, headers)
if cfg == nil || !cfg.RequestLog {
return
}
@@ -0,0 +1,24 @@
package helps
import (
"context"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
)
// TestRecordAPIResponseMetadataStoresHeadersWhenRequestLogDisabled verifies
// that response headers are stored on the context even with a zero-value
// config (RequestLog off), and that the stored value is not affected by later
// mutation of the caller's header map.
func TestRecordAPIResponseMetadataStoresHeadersWhenRequestLogDisabled(t *testing.T) {
	const headerName = "X-Upstream-Request-Id"

	ctx := logging.WithResponseHeadersHolder(context.Background())
	upstream := http.Header{}
	upstream.Add(headerName, "upstream-req-1")

	RecordAPIResponseMetadata(ctx, &config.Config{}, http.StatusOK, upstream)

	// Change the caller's copy after recording; the stored headers must keep
	// the value seen at record time.
	upstream.Set(headerName, "mutated")

	stored := logging.GetResponseHeaders(ctx)
	if value := stored.Get(headerName); value != "upstream-req-1" {
		t.Fatalf("response header = %q, want %q", stored.Get(headerName), "upstream-req-1")
	}
}
@@ -50,7 +50,7 @@ func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
return httpClient
}
// If proxy setup failed, log and fall through to context RoundTripper
log.Debugf("failed to setup proxy from URL: %s, falling back to context transport", proxyURL)
log.Debugf("failed to setup proxy from URL: %s, falling back to context transport", proxyutil.Redact(proxyURL))
}
// Priority 3: Use RoundTripper from context (typically from RoundTripperFor)
@@ -8,4 +8,5 @@ import (
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/geminicli"
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/kimi"
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/openai"
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/xai"
)
@@ -10,6 +10,7 @@ import (
"time"
"github.com/gin-gonic/gin"
internallogging "github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage"
"github.com/tidwall/gjson"
@@ -25,6 +26,7 @@ type UsageReporter struct {
authType string
apiKey string
source string
reasoning string
requestedAt time.Time
once sync.Once
}
@@ -43,6 +45,7 @@ func NewUsageReporter(ctx context.Context, provider, model string, auth *cliprox
apiKey: apiKey,
source: resolveUsageSource(auth, apiKey),
authType: resolveUsageAuthType(auth),
reasoning: usage.ReasoningEffortFromContext(ctx),
}
if auth != nil {
reporter.authID = auth.ID
@@ -60,7 +63,7 @@ func (r *UsageReporter) PublishAdditionalModel(ctx context.Context, model string
if !ok {
return
}
usage.PublishRecord(ctx, record)
r.publishRecord(ctx, record)
}
func (r *UsageReporter) buildAdditionalModelRecord(model string, detail usage.Detail) (usage.Record, bool) {
@@ -97,7 +100,7 @@ func (r *UsageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
}
detail = normalizeUsageDetailTotal(detail)
r.once.Do(func() {
usage.PublishRecord(ctx, r.buildRecord(detail, failed, fail))
r.publishRecord(ctx, r.buildRecord(detail, failed, fail))
})
}
@@ -130,10 +133,15 @@ func (r *UsageReporter) EnsurePublished(ctx context.Context) {
return
}
r.once.Do(func() {
usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false, usage.Failure{}))
r.publishRecord(ctx, r.buildRecord(usage.Detail{}, false, usage.Failure{}))
})
}
// publishRecord attaches the response headers currently held on ctx to the
// record and then hands it to usage.PublishRecord.
func (r *UsageReporter) publishRecord(ctx context.Context, record usage.Record) {
	headers := internallogging.GetResponseHeaders(ctx)
	record.ResponseHeaders = headers
	usage.PublishRecord(ctx, record)
}
func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool, failures ...usage.Failure) usage.Record {
var fail usage.Failure
if len(failures) > 0 {
@@ -150,19 +158,20 @@ func (r *UsageReporter) buildRecordForModel(model string, detail usage.Detail, f
return usage.Record{Model: model, Detail: detail, Failed: failed, Fail: fail}
}
return usage.Record{
Provider: r.provider,
Model: model,
Alias: r.alias,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
AuthIndex: r.authIndex,
AuthType: r.authType,
RequestedAt: r.requestedAt,
Latency: r.latency(),
Failed: failed,
Fail: fail,
Detail: detail,
Provider: r.provider,
Model: model,
Alias: r.alias,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
AuthIndex: r.authIndex,
AuthType: r.authType,
ReasoningEffort: r.reasoning,
RequestedAt: r.requestedAt,
Latency: r.latency(),
Failed: failed,
Fail: fail,
Detail: detail,
}
}
@@ -159,6 +159,16 @@ func TestUsageReporterBuildRecordIncludesRequestedModelAlias(t *testing.T) {
}
}
// TestUsageReporterBuildRecordIncludesReasoningEffort checks that a reasoning
// effort placed on the context at reporter construction time is carried into
// the records the reporter builds.
func TestUsageReporterBuildRecordIncludesReasoningEffort(t *testing.T) {
	const effort = "medium"

	ctx := usage.WithReasoningEffort(context.Background(), effort)
	reporter := NewUsageReporter(ctx, "openai", "gpt-5.4", nil)

	got := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false).ReasoningEffort
	if got != effort {
		t.Fatalf("reasoning effort = %q, want %q", got, "medium")
	}
}
func TestUsageReporterBuildAdditionalModelRecordSkipsZeroTokens(t *testing.T) {
reporter := &UsageReporter{
provider: "codex",
@@ -30,7 +30,7 @@ func newUtlsRoundTripper(proxyURL string) *utlsRoundTripper {
if proxyURL != "" {
proxyDialer, mode, errBuild := proxyutil.BuildDialer(proxyURL)
if errBuild != nil {
log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyURL, errBuild)
log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyutil.Redact(proxyURL), errBuild)
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
dialer = proxyDialer
}
@@ -4,9 +4,13 @@ import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"strings"
"time"
@@ -21,6 +25,14 @@ import (
"github.com/tidwall/sjson"
)
// Constants used by the OpenAI-compatible image endpoints.
const (
// openAICompatImageHandlerType is the source-format name that marks a request as an image request.
openAICompatImageHandlerType = "openai-image"
// openAICompatImagesGenerationsPath is the upstream path suffix for image generation.
openAICompatImagesGenerationsPath = "/images/generations"
// openAICompatImagesEditsPath is the upstream path suffix for image edits.
openAICompatImagesEditsPath = "/images/edits"
// openAICompatDefaultImageEndpoint is used when the request path matches neither suffix.
openAICompatDefaultImageEndpoint = openAICompatImagesGenerationsPath
// openAICompatMultipartMemory is the maxMemory argument given to multipart ReadForm (32 MiB).
openAICompatMultipartMemory int64 = 32 << 20
)
// OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers.
// It performs request/response translation and executes against the provider base URL
// using per-auth credentials (API key) and per-auth HTTP transport (proxy) from context.
@@ -71,6 +83,10 @@ func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
}
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if endpointPath := openAICompatImageEndpointPath(opts); endpointPath != "" {
return e.executeImages(ctx, auth, req, opts, endpointPath)
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -179,7 +195,98 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
return resp, nil
}
// executeImages forwards a non-streaming image request to the provider and
// returns the upstream response body and headers unchanged.
//
// The payload is normalized by prepareOpenAICompatImagesPayload with
// stream=false and POSTed to baseURL+endpointPath. A missing base URL yields a
// statusErr with http.StatusUnauthorized; any non-2xx upstream reply yields a
// statusErr carrying the upstream status code and body text.
func (e *OpenAICompatExecutor) executeImages(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, endpointPath string) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
// err is the named result, so the deferred call sees whichever error is returned below.
defer reporter.TrackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" {
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
return resp, err
}
// Rewrite the model to the upstream name; the returned content type may differ from the
// client's (e.g. a regenerated multipart boundary).
payload, contentType, errPrepare := prepareOpenAICompatImagesPayload(req.Payload, baseModel, opts.Headers.Get("Content-Type"), false)
if errPrepare != nil {
err = errPrepare
return resp, err
}
if contentType == "" {
contentType = "application/json"
}
url := strings.TrimSuffix(baseURL, "/") + endpointPath
// This := declares httpReq only; err is still the named result.
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
if err != nil {
return resp, err
}
httpReq.Header.Set("Content-Type", contentType)
if apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
}
httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
// Applied after the defaults above. NOTE(review): presumably lets per-auth custom headers
// override them — confirm in util.ApplyCustomHeadersFromAttrs.
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: payload,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
}()
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
body, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
// Any status outside 2xx is reported to the caller with the upstream body as the message.
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), body))
err = statusErr{code: httpResp.StatusCode, msg: string(body)}
return resp, err
}
// Report usage parsed from the body, then call EnsurePublished.
// NOTE(review): presumably guarantees a record even when no usage was parsed — confirm.
reporter.Publish(ctx, helps.ParseOpenAIUsage(body))
reporter.EnsurePublished(ctx)
resp = cliproxyexecutor.Response{Payload: body, Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if endpointPath := openAICompatImageEndpointPath(opts); endpointPath != "" {
return e.executeImagesStream(ctx, auth, req, opts, endpointPath)
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
@@ -342,6 +449,121 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// executeImagesStream forwards a streaming image request to the provider and
// relays the upstream response body to the caller as raw byte chunks.
//
// The payload is normalized by prepareOpenAICompatImagesPayload with
// stream=true and POSTed to baseURL+endpointPath with an SSE Accept header.
// A missing base URL yields a statusErr with http.StatusUnauthorized; a
// non-2xx upstream reply is read fully and returned as a statusErr. On
// success the returned StreamResult carries the upstream headers and a channel
// that is closed when the upstream body ends, a read error occurs, or ctx is
// cancelled.
func (e *OpenAICompatExecutor) executeImagesStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, endpointPath string) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
// err is the named result, so the deferred call sees whichever error is returned below.
defer reporter.TrackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" {
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
return nil, err
}
// Rewrite the model to the upstream name and force the stream flag on.
payload, contentType, errPrepare := prepareOpenAICompatImagesPayload(req.Payload, baseModel, opts.Headers.Get("Content-Type"), true)
if errPrepare != nil {
err = errPrepare
return nil, err
}
if contentType == "" {
contentType = "application/json"
}
url := strings.TrimSuffix(baseURL, "/") + endpointPath
// This := declares httpReq only; err is still the named result.
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", contentType)
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
if apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
}
httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
// Applied after the defaults above. NOTE(review): presumably lets per-auth custom headers
// override them — confirm in util.ApplyCustomHeadersFromAttrs.
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: payload,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
// Error replies are not streamed: read the whole body, close it here, and return a statusErr.
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
body, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return nil, errRead
}
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), body))
return nil, statusErr{code: httpResp.StatusCode, msg: string(body)}
}
// From here on the goroutine owns the response body and the output channel.
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
// Deferred calls run in reverse order: the body is closed and EnsurePublished is called
// first, then the channel is closed.
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
reporter.EnsurePublished(ctx)
}()
// Relay the body as-is in reads of up to 32 KiB; no SSE parsing is done here.
buffer := make([]byte, 32*1024)
for {
n, errRead := httpResp.Body.Read(buffer)
if n > 0 {
// Clone because buffer is reused by the next Read while the receiver may still hold the chunk.
chunk := bytes.Clone(buffer[:n])
helps.AppendAPIResponseChunk(ctx, e.cfg, chunk)
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
case <-ctx.Done():
return
}
}
if errRead != nil {
// io.EOF is the normal end of stream; any other error is recorded and delivered as an error chunk.
if errRead != io.EOF {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
reporter.PublishFailure(ctx, errRead)
select {
case out <- cliproxyexecutor.StreamChunk{Err: errRead}:
case <-ctx.Done():
}
}
return
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
@@ -380,6 +602,124 @@ func (e *OpenAICompatExecutor) Refresh(ctx context.Context, auth *cliproxyauth.A
return auth, nil
}
// openAICompatImageEndpointPath reports which upstream image endpoint a
// request should be sent to.
//
// It returns "" when the request's source format is not the image handler
// type, meaning the caller should take the regular (non-image) path.
// Otherwise the request path's suffix selects the edits or generations
// endpoint, and the default image endpoint is used when neither matches.
func openAICompatImageEndpointPath(opts cliproxyexecutor.Options) string {
	if opts.SourceFormat.String() != openAICompatImageHandlerType {
		return ""
	}

	requestPath := helps.PayloadRequestPath(opts)
	// Match against the shared path constants so this check cannot drift from
	// the endpoints that are actually returned.
	for _, endpoint := range []string{openAICompatImagesEditsPath, openAICompatImagesGenerationsPath} {
		if strings.HasSuffix(requestPath, endpoint) {
			return endpoint
		}
	}
	return openAICompatDefaultImageEndpoint
}
func prepareOpenAICompatImagesPayload(payload []byte, model string, contentType string, stream bool) ([]byte, string, error) {
model = strings.TrimSpace(model)
contentType = strings.TrimSpace(contentType)
if json.Valid(payload) {
if model != "" {
payload, _ = sjson.SetBytes(payload, "model", model)
}
if stream {
payload, _ = sjson.SetBytes(payload, "stream", true)
} else {
payload, _ = sjson.DeleteBytes(payload, "stream")
}
return payload, "application/json", nil
}
mediaType, params, errParse := mime.ParseMediaType(contentType)
if errParse != nil || !strings.HasPrefix(strings.ToLower(strings.TrimSpace(mediaType)), "multipart/") {
return payload, contentType, nil
}
boundary := strings.TrimSpace(params["boundary"])
if boundary == "" {
return nil, "", fmt.Errorf("multipart boundary is missing")
}
return rewriteOpenAICompatImagesMultipartPayload(payload, model, boundary, stream)
}
// cloneOpenAICompatMIMEHeader returns a copy of src whose value slices are
// independent of the originals, so the copy can be modified without touching
// src.
func cloneOpenAICompatMIMEHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
	clone := make(textproto.MIMEHeader, len(src))
	for name := range src {
		var copied []string
		copied = append(copied, src[name]...)
		clone[name] = copied
	}
	return clone
}
// rewriteOpenAICompatImagesMultipartPayload re-encodes a multipart/form-data
// body so that its "model" and "stream" fields reflect the upstream request.
//
// The incoming body is parsed with the given boundary and written again with a
// new multipart writer: "model" is written first when model is non-empty,
// "stream" is written as "true" only when stream is set, any client-supplied
// "model"/"stream" values are dropped, and every other field and file part is
// copied over. It returns the new body and its Content-Type (which carries the
// new writer's boundary), or an error when parsing or re-encoding fails.
func rewriteOpenAICompatImagesMultipartPayload(payload []byte, model string, boundary string, stream bool) ([]byte, string, error) {
reader := multipart.NewReader(bytes.NewReader(payload), boundary)
form, errRead := reader.ReadForm(openAICompatMultipartMemory)
if errRead != nil {
return nil, "", fmt.Errorf("read multipart form failed: %w", errRead)
}
// Clean up the parsed form (including any temporary files it created) on every exit path.
defer func() {
if errRemove := form.RemoveAll(); errRemove != nil {
log.Errorf("openai compat executor: remove multipart form files error: %v", errRemove)
}
}()
var body bytes.Buffer
writer := multipart.NewWriter(&body)
if model != "" {
if errWrite := writer.WriteField("model", model); errWrite != nil {
return nil, "", fmt.Errorf("write model field failed: %w", errWrite)
}
}
if stream {
if errWrite := writer.WriteField("stream", "true"); errWrite != nil {
return nil, "", fmt.Errorf("write stream field failed: %w", errWrite)
}
}
// form.Value is a map, so the relative order of differently named fields is not preserved;
// values under the same key keep their order.
for key, values := range form.Value {
// Already written above (or intentionally omitted).
if key == "model" || key == "stream" {
continue
}
for _, value := range values {
if errWrite := writer.WriteField(key, value); errWrite != nil {
return nil, "", fmt.Errorf("write form field %s failed: %w", key, errWrite)
}
}
}
for key, files := range form.File {
for _, fileHeader := range files {
if fileHeader == nil {
continue
}
// Keep the part's original headers but regenerate Content-Disposition from the field
// name and filename, and default a missing Content-Type.
header := cloneOpenAICompatMIMEHeader(fileHeader.Header)
header.Set("Content-Disposition", multipart.FileContentDisposition(key, fileHeader.Filename))
if header.Get("Content-Type") == "" {
header.Set("Content-Type", "application/octet-stream")
}
part, errCreate := writer.CreatePart(header)
if errCreate != nil {
return nil, "", fmt.Errorf("create file field %s failed: %w", key, errCreate)
}
src, errOpen := fileHeader.Open()
if errOpen != nil {
return nil, "", fmt.Errorf("open upload file failed: %w", errOpen)
}
_, errCopy := io.Copy(part, src)
// Close immediately instead of deferring so files are not held open across loop iterations;
// a close error is reported only when the copy itself succeeded.
if errClose := src.Close(); errClose != nil {
log.Errorf("openai compat executor: close upload file error: %v", errClose)
if errCopy == nil {
errCopy = errClose
}
}
if errCopy != nil {
return nil, "", fmt.Errorf("copy upload file failed: %w", errCopy)
}
}
}
// Close writes the trailing boundary; the body is incomplete without it.
if errClose := writer.Close(); errClose != nil {
return nil, "", fmt.Errorf("close multipart writer failed: %w", errClose)
}
return body.Bytes(), writer.FormDataContentType(), nil
}
func (e *OpenAICompatExecutor) resolveCredentials(auth *cliproxyauth.Auth) (baseURL, apiKey string) {
if auth == nil {
return "", ""
@@ -1,10 +1,14 @@
package executor
import (
"bytes"
"context"
"io"
"mime"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/textproto"
"strings"
"testing"
@@ -102,6 +106,265 @@ func TestOpenAICompatExecutorPayloadOverrideWinsOverThinkingSuffix(t *testing.T)
}
}
// TestOpenAICompatExecutorImagesGenerationsPassthrough sends a JSON image
// generation request through Execute and checks that the upstream receives it
// on /v1/images/generations as JSON with the model rewritten to the upstream
// name, and that the upstream body is returned to the caller as-is.
func TestOpenAICompatExecutorImagesGenerationsPassthrough(t *testing.T) {
	var seen struct {
		path        string
		contentType string
		body        []byte
	}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seen.path = r.URL.Path
		seen.contentType = r.Header.Get("Content-Type")
		seen.body, _ = io.ReadAll(r.Body)
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"created":123,"data":[{"b64_json":"AA=="}],"usage":{"total_tokens":1}}`))
	}))
	defer server.Close()

	compat := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": server.URL + "/v1",
		"api_key":  "test",
	}}
	request := cliproxyexecutor.Request{
		Model:   "upstream-image",
		Payload: []byte(`{"model":"compat-image","prompt":"draw"}`),
	}
	options := cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-image"),
		Stream:       false,
		Headers: http.Header{
			"Content-Type": []string{"application/json"},
		},
		Metadata: map[string]any{
			cliproxyexecutor.RequestPathMetadataKey: "/v1/images/generations",
		},
	}

	resp, err := compat.Execute(context.Background(), auth, request, options)
	if err != nil {
		t.Fatalf("Execute error: %v", err)
	}

	if seen.path != "/v1/images/generations" {
		t.Fatalf("path = %q, want %q", seen.path, "/v1/images/generations")
	}
	if seen.contentType != "application/json" {
		t.Fatalf("content type = %q, want application/json", seen.contentType)
	}
	if model := gjson.GetBytes(seen.body, "model").String(); model != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image; body=%s", model, string(seen.body))
	}
	if image := gjson.GetBytes(resp.Payload, "data.0.b64_json").String(); image != "AA==" {
		t.Fatalf("response payload = %s", string(resp.Payload))
	}
}
// TestOpenAICompatExecutorImagesGenerationsStreamsUpstream sends a streaming
// image generation request through ExecuteStream and checks that the upstream
// is called on /v1/images/generations with an SSE Accept header, the rewritten
// model and the stream flag, and that the upstream SSE bytes reach the caller.
func TestOpenAICompatExecutorImagesGenerationsStreamsUpstream(t *testing.T) {
	var seen struct {
		path   string
		accept string
		body   []byte
	}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seen.path = r.URL.Path
		seen.accept = r.Header.Get("Accept")
		seen.body, _ = io.ReadAll(r.Body)
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = w.Write([]byte("event: image_generation.partial\ndata: {\"type\":\"image_generation.partial\"}\n\n"))
		// Flush between the two writes when the writer supports it.
		if flusher, ok := w.(http.Flusher); ok {
			flusher.Flush()
		}
		_, _ = w.Write([]byte("data: [DONE]\n\n"))
	}))
	defer server.Close()

	compat := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": server.URL + "/v1",
		"api_key":  "test",
	}}
	request := cliproxyexecutor.Request{
		Model:   "upstream-image",
		Payload: []byte(`{"model":"compat-image","prompt":"draw","stream":true}`),
	}
	options := cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-image"),
		Stream:       true,
		Headers: http.Header{
			"Content-Type": []string{"application/json"},
		},
		Metadata: map[string]any{
			cliproxyexecutor.RequestPathMetadataKey: "/v1/images/generations",
		},
	}

	streamResult, err := compat.ExecuteStream(context.Background(), auth, request, options)
	if err != nil {
		t.Fatalf("ExecuteStream error: %v", err)
	}

	var received bytes.Buffer
	for chunk := range streamResult.Chunks {
		if chunk.Err != nil {
			t.Fatalf("stream chunk error: %v", chunk.Err)
		}
		received.Write(chunk.Payload)
	}

	if seen.path != "/v1/images/generations" {
		t.Fatalf("path = %q, want %q", seen.path, "/v1/images/generations")
	}
	if seen.accept != "text/event-stream" {
		t.Fatalf("accept = %q, want text/event-stream", seen.accept)
	}
	if model := gjson.GetBytes(seen.body, "model").String(); model != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image; body=%s", model, string(seen.body))
	}
	if !gjson.GetBytes(seen.body, "stream").Bool() {
		t.Fatalf("stream flag missing from upstream body: %s", string(seen.body))
	}
	text := received.String()
	hasPartial := strings.Contains(text, "event: image_generation.partial")
	hasDone := strings.Contains(text, "data: [DONE]")
	if !hasPartial || !hasDone {
		t.Fatalf("streamed body = %q", received.String())
	}
}
// TestOpenAICompatExecutorImagesEditsMultipartRewritesModel verifies that a
// multipart /v1/images/edits request reaches the upstream with the model field
// set to the request model ("upstream-image") while the prompt, the file
// bytes, and the file part's Content-Type are unchanged.
func TestOpenAICompatExecutorImagesEditsMultipartRewritesModel(t *testing.T) {
	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	if errWrite := writer.WriteField("model", "compat-image"); errWrite != nil {
		t.Fatalf("write model field: %v", errWrite)
	}
	if errWrite := writer.WriteField("prompt", "edit"); errWrite != nil {
		t.Fatalf("write prompt field: %v", errWrite)
	}
	// Build the file part by hand so it carries an explicit Content-Type.
	header := make(textproto.MIMEHeader)
	header.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.png"))
	header.Set("Content-Type", "image/png")
	part, errCreate := writer.CreatePart(header)
	if errCreate != nil {
		t.Fatalf("create image field: %v", errCreate)
	}
	if _, errWrite := part.Write([]byte("png-data")); errWrite != nil {
		t.Fatalf("write image field: %v", errWrite)
	}
	if errClose := writer.Close(); errClose != nil {
		t.Fatalf("close multipart writer: %v", errClose)
	}
	contentType := writer.FormDataContentType()

	var gotPath string
	var gotModel string
	var gotPrompt string
	var gotFile string
	var gotFileContentType string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on the server's goroutine, where FailNow-style
		// calls such as t.Fatalf must not be used. Record the failure with
		// t.Errorf and answer with an error status instead.
		fail := func(status int, format string, args ...any) {
			t.Errorf(format, args...)
			http.Error(w, "test handler failure", status)
		}

		gotPath = r.URL.Path
		if errParse := r.ParseMultipartForm(32 << 20); errParse != nil {
			fail(http.StatusBadRequest, "parse multipart form: %v", errParse)
			return
		}
		gotModel = r.FormValue("model")
		gotPrompt = r.FormValue("prompt")
		file, fileHeader, errFile := r.FormFile("image")
		if errFile != nil {
			fail(http.StatusBadRequest, "read image file: %v", errFile)
			return
		}
		gotFileContentType = fileHeader.Header.Get("Content-Type")
		data, errRead := io.ReadAll(file)
		if errClose := file.Close(); errClose != nil {
			fail(http.StatusInternalServerError, "close image file: %v", errClose)
			return
		}
		if errRead != nil {
			fail(http.StatusInternalServerError, "read image file: %v", errRead)
			return
		}
		gotFile = string(data)
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"created":123,"data":[{"b64_json":"AA=="}]}`))
	}))
	defer server.Close()

	executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": server.URL + "/v1",
		"api_key":  "test",
	}}
	_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "upstream-image",
		Payload: body.Bytes(),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-image"),
		Stream:       false,
		Headers: http.Header{
			"Content-Type": []string{contentType},
		},
		Metadata: map[string]any{
			cliproxyexecutor.RequestPathMetadataKey: "/v1/images/edits",
		},
	})
	if err != nil {
		t.Fatalf("Execute error: %v", err)
	}
	if gotPath != "/v1/images/edits" {
		t.Fatalf("path = %q, want %q", gotPath, "/v1/images/edits")
	}
	if gotModel != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image", gotModel)
	}
	if gotPrompt != "edit" {
		t.Fatalf("prompt = %q, want edit", gotPrompt)
	}
	if gotFile != "png-data" {
		t.Fatalf("file = %q, want png-data", gotFile)
	}
	if gotFileContentType != "image/png" {
		t.Fatalf("file content type = %q, want image/png", gotFileContentType)
	}
}
// TestRewriteOpenAICompatImagesMultipartPayloadPreservesStreamAndFileContentType
// feeds a multipart form (model, stream=false, one webp file part) through
// prepareOpenAICompatImagesPayload with a final argument of true and expects
// the rewritten form to hold exactly one model value ("upstream-image"),
// exactly one stream value ("true"), and the file part's original
// Content-Type.
func TestRewriteOpenAICompatImagesMultipartPayloadPreservesStreamAndFileContentType(t *testing.T) {
	var payload bytes.Buffer
	mw := multipart.NewWriter(&payload)
	if errWrite := mw.WriteField("model", "compat-image"); errWrite != nil {
		t.Fatalf("write model field: %v", errWrite)
	}
	if errWrite := mw.WriteField("stream", "false"); errWrite != nil {
		t.Fatalf("write stream field: %v", errWrite)
	}

	// Hand-built part header so the file carries an explicit Content-Type.
	partHeader := textproto.MIMEHeader{}
	partHeader.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.webp"))
	partHeader.Set("Content-Type", "image/webp")
	filePart, errCreate := mw.CreatePart(partHeader)
	if errCreate != nil {
		t.Fatalf("create image field: %v", errCreate)
	}
	if _, errWrite := filePart.Write([]byte("webp-data")); errWrite != nil {
		t.Fatalf("write image field: %v", errWrite)
	}
	if errClose := mw.Close(); errClose != nil {
		t.Fatalf("close multipart writer: %v", errClose)
	}

	rewritten, rewrittenType, err := prepareOpenAICompatImagesPayload(payload.Bytes(), "upstream-image", mw.FormDataContentType(), true)
	if err != nil {
		t.Fatalf("prepareOpenAICompatImagesPayload error: %v", err)
	}

	mediaType, params, errParse := mime.ParseMediaType(rewrittenType)
	if errParse != nil {
		t.Fatalf("parse content type: %v", errParse)
	}
	if mediaType != "multipart/form-data" {
		t.Fatalf("media type = %q, want multipart/form-data", mediaType)
	}

	// Re-parse the output using the boundary advertised in the returned type.
	form, errRead := multipart.NewReader(bytes.NewReader(rewritten), params["boundary"]).ReadForm(32 << 20)
	if errRead != nil {
		t.Fatalf("read rewritten form: %v", errRead)
	}
	defer func() {
		if errRemove := form.RemoveAll(); errRemove != nil {
			t.Fatalf("remove form files: %v", errRemove)
		}
	}()

	models := form.Value["model"]
	if len(models) != 1 || models[0] != "upstream-image" {
		t.Fatalf("model values = %#v, want upstream-image", models)
	}
	streams := form.Value["stream"]
	if len(streams) != 1 || streams[0] != "true" {
		t.Fatalf("stream values = %#v, want true", streams)
	}
	images := form.File["image"]
	if len(images) != 1 || images[0].Header.Get("Content-Type") != "image/webp" {
		t.Fatalf("image headers = %#v, want image/webp", images)
	}
}
func TestOpenAICompatExecutorStreamRejectsPlainJSONAfterBlankLines(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
+1 -1
View File
@@ -487,7 +487,7 @@ func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxye
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
var err error
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
body, err = thinking.ApplyThinking(body, req.Model, from.String(), e.Identifier(), e.Identifier())
if err != nil {
return nil, err
}
@@ -196,6 +196,48 @@ func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) {
}
}
// TestXAIExecutorAppliesThinkingSuffix verifies that a thinking suffix on the
// request model ("grok-4.3(low)") is removed from the upstream model name and
// applied as reasoning.effort in the upstream request body.
func TestXAIExecutorAppliesThinkingSuffix(t *testing.T) {
	var gotBody []byte
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var errRead error
		gotBody, errRead = io.ReadAll(r.Body)
		if errRead != nil {
			// The handler runs on the server's goroutine, where t.Fatalf
			// must not be called; record the failure and reply with an error.
			t.Errorf("read body: %v", errRead)
			http.Error(w, "read body", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4.3\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}]}}\n\n"))
	}))
	defer server.Close()

	exec := NewXAIExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{
		Provider: "xai",
		Attributes: map[string]string{
			"base_url":  server.URL,
			"auth_kind": "oauth",
		},
		Metadata: map[string]any{"access_token": "xai-token"},
	}
	_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "grok-4.3(low)",
		Payload: []byte(`{"model":"grok-4.3","input":"hello"}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FormatOpenAIResponse,
		Stream:       false,
	})
	if err != nil {
		t.Fatalf("Execute() error = %v", err)
	}
	if got := gjson.GetBytes(gotBody, "model").String(); got != "grok-4.3" {
		t.Fatalf("model = %q, want grok-4.3; body=%s", got, string(gotBody))
	}
	if got := gjson.GetBytes(gotBody, "reasoning.effort").String(); got != "low" {
		t.Fatalf("reasoning.effort = %q, want low; body=%s", got, string(gotBody))
	}
}
func TestXAIExecutorExecuteStreamFiltersToolSearchTool(t *testing.T) {
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+53 -2
View File
@@ -18,6 +18,7 @@ var providerAppliers = map[string]ProviderApplier{
"codex": nil,
"antigravity": nil,
"kimi": nil,
"xai": nil,
}
// GetProviderApplier returns the ProviderApplier for the given provider name.
@@ -62,7 +63,7 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
// - body: Original request body JSON
// - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
// - fromFormat: Source request format (e.g., openai, codex, gemini)
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, kimi)
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, kimi, xai)
// - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai)
//
// Returns:
@@ -324,7 +325,7 @@ func extractThinkingConfig(body []byte, provider string) ThinkingConfig {
return extractGeminiConfig(body, provider)
case "openai":
return extractOpenAIConfig(body)
case "codex":
case "codex", "xai":
return extractCodexConfig(body)
case "kimi":
// Kimi uses OpenAI-compatible reasoning_effort format
@@ -338,6 +339,56 @@ func hasThinkingConfig(config ThinkingConfig) bool {
return config.Mode != ModeBudget || config.Budget != 0 || config.Level != ""
}
// ExtractReasoningEffort reports the thinking setting of a request as a
// canonical reasoning_effort label for usage logging.
//
// A thinking suffix on model that yields a label wins over anything in body,
// mirroring the priority used by ApplyThinking. Otherwise the label comes from
// the thinking configuration extracted from body for provider (compared
// case-insensitively, surrounding whitespace ignored). For the "openai" and
// "openai-response" providers, a body without a thinking configuration is
// read a second time with extractCodexConfig.
//
// It returns "" when no thinking setting can be determined.
func ExtractReasoningEffort(body []byte, provider, model string) string {
	suffixEffort := reasoningEffortFromSuffix(ParseSuffix(model))
	if suffixEffort != "" {
		return suffixEffort
	}

	provider = strings.ToLower(strings.TrimSpace(provider))
	config := extractThinkingConfig(body, provider)

	// Only the OpenAI-style providers get a second look via the Codex extractor.
	usesCodexFallback := provider == "openai-response" || provider == "openai"
	if usesCodexFallback && !hasThinkingConfig(config) {
		config = extractCodexConfig(body)
	}
	return reasoningEffortFromConfig(config)
}
// reasoningEffortFromSuffix converts a parsed model suffix into a
// reasoning_effort label. It returns "" when the model carried no suffix or
// when the suffix does not map to a label.
func reasoningEffortFromSuffix(suffix SuffixResult) string {
	if suffix.HasSuffix {
		config := parseSuffixToConfig(suffix.RawSuffix, "", suffix.ModelName)
		return reasoningEffortFromConfig(config)
	}
	return ""
}
// reasoningEffortFromConfig maps a ThinkingConfig to a reasoning_effort label.
//
// It returns "" for a config that hasThinkingConfig reports as unset, for an
// unknown mode, and for a budget that ConvertBudgetToLevel cannot convert.
// Level values are trimmed and lower-cased before being returned.
func reasoningEffortFromConfig(config ThinkingConfig) string {
	if !hasThinkingConfig(config) {
		return ""
	}

	switch config.Mode {
	case ModeLevel:
		return strings.ToLower(strings.TrimSpace(string(config.Level)))
	case ModeBudget:
		if level, ok := ConvertBudgetToLevel(config.Budget); ok {
			return level
		}
		return ""
	case ModeAuto:
		return string(LevelAuto)
	case ModeNone:
		return string(LevelNone)
	}
	return ""
}
// extractClaudeConfig extracts thinking configuration from Claude format request body.
//
// Claude API format:
+26
View File
@@ -0,0 +1,26 @@
// Package xai implements thinking configuration for xAI Grok Responses API models.
//
// xAI models use the OpenAI Responses API compatible reasoning.effort format
// with discrete levels.
package xai
import (
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/codex"
)
// Applier implements thinking.ProviderApplier for xAI models.
//
// It embeds codex.Applier and declares no methods of its own, so its
// behavior comes entirely from the embedded Codex applier.
type Applier struct {
codex.Applier
}
var _ thinking.ProviderApplier = (*Applier)(nil)
// NewApplier returns a zero-valued xAI thinking applier.
func NewApplier() *Applier {
	return new(Applier)
}
// init registers the xAI applier with the thinking package under the "xai"
// provider name when this package is loaded.
func init() {
thinking.RegisterProvider("xai", NewApplier())
}
@@ -0,0 +1,51 @@
package xai
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
"github.com/tidwall/gjson"
)
func TestApplySetsReasoningEffort(t *testing.T) {
applier := NewApplier()
modelInfo := &registry.ModelInfo{
ID: "grok-4.3",
Thinking: &registry.ThinkingSupport{
ZeroAllowed: true,
Levels: []string{"none", "low", "medium", "high"},
},
}
out, err := applier.Apply([]byte(`{"input":"hello"}`), thinking.ThinkingConfig{
Mode: thinking.ModeLevel,
Level: thinking.LevelHigh,
}, modelInfo)
if err != nil {
t.Fatalf("Apply() error = %v", err)
}
if got := gjson.GetBytes(out, "reasoning.effort").String(); got != "high" {
t.Fatalf("reasoning.effort = %q, want high; body=%s", got, string(out))
}
}
func TestApplyNoneFallsBackToLowestLevelWhenDisableUnsupported(t *testing.T) {
applier := NewApplier()
modelInfo := &registry.ModelInfo{
ID: "grok-3-mini",
Thinking: &registry.ThinkingSupport{
Levels: []string{"low", "medium", "high"},
},
}
out, err := applier.Apply([]byte(`{"input":"hello"}`), thinking.ThinkingConfig{
Mode: thinking.ModeNone,
}, modelInfo)
if err != nil {
t.Fatalf("Apply() error = %v", err)
}
if got := gjson.GetBytes(out, "reasoning.effort").String(); got != "low" {
t.Fatalf("reasoning.effort = %q, want low; body=%s", got, string(out))
}
}
@@ -0,0 +1,31 @@
package thinking
import "testing"
// TestExtractReasoningEffortUsesSuffixOverBody checks that a model suffix
// takes priority over the reasoning_effort field in the request body.
func TestExtractReasoningEffortUsesSuffixOverBody(t *testing.T) {
	const want = "high"
	body := []byte(`{"reasoning_effort":"low"}`)
	if got := ExtractReasoningEffort(body, "openai", "gpt-5.4(high)"); got != want {
		t.Fatalf("ExtractReasoningEffort() = %q, want %q", got, want)
	}
}
// TestExtractReasoningEffortConvertsBudgetToLevel checks that a Claude-style
// token budget of 8192 is reported as the "medium" effort label.
func TestExtractReasoningEffortConvertsBudgetToLevel(t *testing.T) {
	const want = "medium"
	body := []byte(`{"thinking":{"type":"enabled","budget_tokens":8192}}`)
	if got := ExtractReasoningEffort(body, "claude", "claude-sonnet-4-5"); got != want {
		t.Fatalf("ExtractReasoningEffort() = %q, want %q", got, want)
	}
}
// TestExtractReasoningEffortSupportsOpenAIResponses checks that the
// reasoning.effort field is read for the "openai-response" provider.
func TestExtractReasoningEffortSupportsOpenAIResponses(t *testing.T) {
	const want = "medium"
	body := []byte(`{"reasoning":{"effort":"medium"}}`)
	if got := ExtractReasoningEffort(body, "openai-response", "gpt-5.4"); got != want {
		t.Fatalf("ExtractReasoningEffort() = %q, want %q", got, want)
	}
}
// TestExtractReasoningEffortMissingConfigIsEmpty checks that a request with
// neither a model suffix nor a thinking field yields an empty label.
func TestExtractReasoningEffortMissingConfigIsEmpty(t *testing.T) {
	body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
	if got := ExtractReasoningEffort(body, "openai", "gpt-5.4"); got != "" {
		t.Fatalf("ExtractReasoningEffort() = %q, want empty", got)
	}
}
+1 -1
View File
@@ -42,7 +42,7 @@ func StripThinkingConfig(body []byte, provider string) []byte {
"reasoning_effort",
"thinking",
}
case "codex":
case "codex", "xai":
paths = []string{"reasoning.effort"}
default:
return body
+1 -1
View File
@@ -1,7 +1,7 @@
// Package thinking provides unified thinking configuration processing.
//
// This package offers a unified interface for parsing, validating, and applying
// thinking configurations across various AI providers (Claude, Gemini, OpenAI, Codex, Antigravity, Kimi).
// thinking configurations across various AI providers (Claude, Gemini, OpenAI, Codex, Antigravity, Kimi, xAI).
package thinking
import "github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
+1 -1
View File
@@ -357,7 +357,7 @@ func isGeminiFamily(provider string) bool {
func isOpenAIFamily(provider string) bool {
switch provider {
case "openai", "openai-response", "codex":
case "openai", "openai-response", "codex", "xai":
return true
default:
return false
@@ -101,7 +101,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
if strings.HasPrefix(systemPrompt, "x-anthropic-billing-header:") {
if util.IsClaudeCodeAttributionSystemText(systemPrompt) {
continue
}
partJSON := []byte(`{}`)
@@ -112,7 +112,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
hasSystemInstruction = true
}
}
} else if systemResult.Type == gjson.String {
} else if systemResult.Type == gjson.String && !util.IsClaudeCodeAttributionSystemText(systemResult.String()) {
systemInstructionJSON = []byte(`{"role":"user","parts":[{"text":""}]}`)
systemInstructionJSON, _ = sjson.SetBytes(systemInstructionJSON, "parts.0.text", systemResult.String())
hasSystemInstruction = true
@@ -70,6 +70,28 @@ func uint64Ptr(v uint64) *uint64 {
return &v
}
// TestConvertClaudeRequestToAntigravity_StripsClaudeCodeAttribution checks
// that a system text block starting with the x-anthropic-billing-header
// attribution is dropped, leaving only the real system prompt in
// request.systemInstruction.parts.
func TestConvertClaudeRequestToAntigravity_StripsClaudeCodeAttribution(t *testing.T) {
	request := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [
{"type": "text", "text": "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;"},
{"type": "text", "text": "Antigravity system prompt"}
]
}`)

	converted := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", request, false))
	systemParts := gjson.Get(converted, "request.systemInstruction.parts")
	parts := systemParts.Array()
	if len(parts) != 1 {
		t.Fatalf("Expected 1 system part after attribution strip, got %d: %s", len(parts), systemParts.Raw)
	}

	text := parts[0].Get("text").String()
	if text != "Antigravity system prompt" {
		t.Fatalf("Unexpected system part: %q", text)
	}
}
func testNonAnthropicRawSignature(t *testing.T) string {
t.Helper()
@@ -99,35 +99,19 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
// Gemini-specific handling for non-Claude models:
// - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation.
// - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them).
if !strings.Contains(modelName, "claude") {
// - Replace client-provided thoughtSignature values with the skip sentinel.
// - Add the same sentinel to functionCall and thinking parts so upstream can bypass signature validation.
if !strings.Contains(strings.ToLower(modelName), "claude") {
const skipSentinel = "skip_thought_signature_validator"
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
if content.Get("role").String() == "model" {
// First pass: collect indices of thinking parts to mark with skip sentinel
var thinkingIndicesToSkipSignature []int64
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
// Collect indices of thinking blocks to mark with skip sentinel
if part.Get("thought").Bool() {
thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int())
}
// Add skip sentinel to functionCall parts
if part.Get("functionCall").Exists() {
existingSig := part.Get("thoughtSignature").String()
if existingSig == "" || len(existingSig) < 50 {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
}
if part.Get("functionCall").Exists() || part.Get("thought").Exists() || part.Get("thoughtSignature").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
}
return true
})
// Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices
for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- {
idx := thinkingIndicesToSkipSignature[i]
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel)
}
}
return true
})
@@ -7,8 +7,8 @@ import (
"github.com/tidwall/gjson"
)
func TestConvertGeminiRequestToAntigravity_PreserveValidSignature(t *testing.T) {
// Valid signature on functionCall should be preserved
func TestConvertGeminiRequestToAntigravity_ReplacesClientSignatureOnFunctionCall(t *testing.T) {
// Client signatures on Gemini function calls are not portable to Antigravity.
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(fmt.Sprintf(`{
"model": "gemini-3-pro-preview",
@@ -25,15 +25,83 @@ func TestConvertGeminiRequestToAntigravity_PreserveValidSignature(t *testing.T)
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
outputStr := string(output)
// Check that valid thoughtSignature is preserved
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part, got %d", len(parts))
}
sig := parts[0].Get("thoughtSignature").String()
if sig != validSignature {
t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, sig)
expectedSig := "skip_thought_signature_validator"
if sig != expectedSig {
t.Errorf("Expected thoughtSignature '%s', got '%s'", expectedSig, sig)
}
}
// TestConvertGeminiRequestToAntigravity_ReplacesClientSignatureOnTextPart
// checks that a client-supplied thoughtSignature on a plain text part of a
// model turn is overwritten with the skip sentinel.
func TestConvertGeminiRequestToAntigravity_ReplacesClientSignatureOnTextPart(t *testing.T) {
	const clientSignature = "abc123validSignature1234567890123456789012345678901234567890"
	const wantSignature = "skip_thought_signature_validator"

	request := []byte(fmt.Sprintf(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"text": "previous answer", "thoughtSignature": "%s"}
]
}
]
}`, clientSignature))

	converted := string(ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", request, false))
	got := gjson.Get(converted, "request.contents.0.parts.0.thoughtSignature").String()
	if got != wantSignature {
		t.Errorf("Expected thoughtSignature '%s', got '%s'", wantSignature, got)
	}
}
// TestConvertGeminiRequestToAntigravity_AddsSkipSentinelToStringThoughtPart
// checks that a part whose "thought" field is a string (rather than a
// boolean) still receives the skip sentinel as its thoughtSignature.
func TestConvertGeminiRequestToAntigravity_AddsSkipSentinelToStringThoughtPart(t *testing.T) {
	const wantSignature = "skip_thought_signature_validator"

	request := []byte(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"thought": "internal reasoning"}
]
}
]
}`)

	converted := string(ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", request, false))
	got := gjson.Get(converted, "request.contents.0.parts.0.thoughtSignature").String()
	if got != wantSignature {
		t.Errorf("Expected thoughtSignature '%s', got '%s'", wantSignature, got)
	}
}
// TestConvertGeminiRequestToAntigravity_SkipsUppercaseClaudeModel checks that
// a model name containing "Claude" in mixed case is still treated as a Claude
// model, so no thoughtSignature is added to its functionCall part.
func TestConvertGeminiRequestToAntigravity_SkipsUppercaseClaudeModel(t *testing.T) {
	request := []byte(`{
"model": "Claude-Test",
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "test_tool", "args": {}}}
]
}
]
}`)

	converted := string(ConvertGeminiRequestToAntigravity("Claude-Test", request, false))
	signature := gjson.Get(converted, "request.contents.0.parts.0.thoughtSignature")
	if signature.Exists() {
		t.Fatalf("Expected no thoughtSignature for Claude model, got %s", signature.Raw)
	}
}
@@ -6,12 +6,15 @@
package claude
import (
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"strconv"
"strings"
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -50,7 +53,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
contentIndex := 0
appendSystemText := func(text string) {
if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
if text == "" || util.IsClaudeCodeAttributionSystemText(text) {
return
}
@@ -84,6 +87,9 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
messageRole := messageResult.Get("role").String()
if messageRole == "system" {
messageRole = "developer"
}
newMessage := func() []byte {
msg := []byte(`{"type":"message","role":"","content":[]}`)
@@ -172,7 +178,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
case "tool_use":
flushMessage()
functionCallMessage := []byte(`{"type":"function_call"}`)
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "call_id", messageContentResult.Get("id").String())
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "call_id", shortenCodexCallIDIfNeeded(messageContentResult.Get("id").String()))
{
name := messageContentResult.Get("name").String()
if short, ok := toolNameMap[name]; ok {
@@ -187,7 +193,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
case "tool_result":
flushMessage()
functionCallOutputMessage := []byte(`{"type":"function_call_output"}`)
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "call_id", shortenCodexCallIDIfNeeded(messageContentResult.Get("tool_use_id").String()))
contentResult := messageContentResult.Get("content")
if contentResult.IsArray() {
@@ -361,6 +367,23 @@ func isFernetLikeReasoningSignature(signature string) bool {
return ciphertextLen > 0 && ciphertextLen%aesBlockSize == 0
}
// shortenCodexCallIDIfNeeded keeps Claude tool IDs within the OpenAI Responses
// API call_id limit while preserving a stable, low-collision mapping.
//
// IDs of at most 64 bytes are returned unchanged. A longer ID is cut down to a
// prefix and suffixed with "_" plus the hex encoding of the first 8 bytes of
// the ID's SHA-256 digest, so the same input always yields the same output.
// The cut is moved back to a UTF-8 rune boundary so that a multi-byte
// character is never split; the result can therefore be a few bytes shorter
// than 64 for non-ASCII IDs, and is exactly 64 bytes for ASCII IDs.
func shortenCodexCallIDIfNeeded(id string) string {
	const limit = 64
	if len(id) <= limit {
		return id
	}

	sum := sha256.Sum256([]byte(id))
	suffix := "_" + hex.EncodeToString(sum[:8]) // always 17 bytes

	prefixLen := limit - len(suffix)
	// Step back over UTF-8 continuation bytes (10xxxxxx) so the prefix never
	// ends in the middle of a rune, which would make the result invalid UTF-8.
	for prefixLen > 0 && id[prefixLen]&0xC0 == 0x80 {
		prefixLen--
	}
	return id[:prefixLen] + suffix
}
func isClaudeWebSearchToolType(toolType string) bool {
return toolType == "web_search_20250305" || toolType == "web_search_20260209"
}
@@ -42,6 +42,18 @@ func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
wantHasDeveloper: true,
wantTexts: []string{"Be helpful"},
},
{
name: "System role in messages",
inputJSON: `{
"model": "claude-3-opus",
"messages": [
{"role": "system", "content": "Follow the project instructions"},
{"role": "user", "content": "hello"}
]
}`,
wantHasDeveloper: true,
wantTexts: []string{"Follow the project instructions"},
},
{
name: "Array system field with filtered billing header",
inputJSON: `{
@@ -136,6 +148,56 @@ func TestConvertClaudeRequestToCodex_ParallelToolCalls(t *testing.T) {
}
}
// TestConvertClaudeRequestToCodex_ShortenLongToolUseIDs checks that a Claude
// tool_use id longer than 64 bytes is shortened in the converted request, and
// that the function_call and its function_call_output end up with the same
// shortened call_id.
func TestConvertClaudeRequestToCodex_ShortenLongToolUseIDs(t *testing.T) {
	longID := "toolu_" + strings.Repeat("a", 62)
	if len(longID) <= 64 {
		t.Fatalf("test setup error: longID length = %d, want > 64", len(longID))
	}

	inputJSON := `{
"model": "claude-3-opus",
"messages": [
{"role": "user", "content": [{"type":"text","text":"run pwd"}]},
{"role": "assistant", "content": [
{"type":"tool_use","id":"` + longID + `","name":"Bash","input":{"cmd":"pwd"}}
]},
{"role": "user", "content": [
{"type":"tool_result","tool_use_id":"` + longID + `","content":"ok"}
]}
]
}`

	converted := ConvertClaudeRequestToCodex("test-model", []byte(inputJSON), false)

	// Pick up the call_id from each side of the tool round trip.
	var callID, outputCallID string
	for _, item := range gjson.GetBytes(converted, "input").Array() {
		itemType := item.Get("type").String()
		if itemType == "function_call" {
			callID = item.Get("call_id").String()
		} else if itemType == "function_call_output" {
			outputCallID = item.Get("call_id").String()
		}
	}

	if callID == "" {
		t.Fatalf("missing function_call item. Output: %s", string(converted))
	}
	if outputCallID == "" {
		t.Fatalf("missing function_call_output item. Output: %s", string(converted))
	}
	if callID != outputCallID {
		t.Fatalf("call_id mismatch: function_call=%q function_call_output=%q. Output: %s", callID, outputCallID, string(converted))
	}
	if len(callID) > 64 {
		t.Fatalf("call_id length = %d, want <= 64: %q", len(callID), callID)
	}
	if callID == longID {
		t.Fatalf("long call_id was not shortened: %q", callID)
	}
}
func TestConvertClaudeRequestToCodex_ToolChoiceModeMapping(t *testing.T) {
tests := []struct {
name string
@@ -140,7 +140,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
params.HasReceivedArgumentsDelta = false
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
template, _ = sjson.SetBytes(template, "content_block.id", shortenCodexCallIDIfNeeded(util.SanitizeClaudeToolID(itemResult.Get("call_id").String())))
{
name := itemResult.Get("name").String()
rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
@@ -350,7 +350,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
}
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
toolBlock, _ = sjson.SetBytes(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.SetBytes(toolBlock, "id", shortenCodexCallIDIfNeeded(util.SanitizeClaudeToolID(item.Get("call_id").String())))
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
inputRaw := "{}"
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
@@ -459,6 +459,70 @@ func TestConvertCodexResponseToClaude_StreamEmptyOutputUsesOutputItemDoneMessage
}
}
// TestConvertCodexResponseToClaude_ShortensLongToolUseIDs checks, for both the
// streaming and the non-streaming converters, that a Codex call_id longer than
// 64 bytes is emitted as a Claude tool_use id of at most 64 bytes.
func TestConvertCodexResponseToClaude_ShortensLongToolUseIDs(t *testing.T) {
	longCallID := "call_" + strings.Repeat("a", 62)
	if len(longCallID) <= 64 {
		t.Fatalf("test setup error: longCallID length = %d, want > 64", len(longCallID))
	}

	t.Run("stream", func(t *testing.T) {
		originalRequest := []byte(`{"tools":[{"name":"lookup","input_schema":{"type":"object","properties":{}}}]}`)
		event := []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"` + longCallID + `","name":"lookup"}}`)
		var param any
		outputs := ConvertCodexResponseToClaude(context.Background(), "", originalRequest, nil, event, &param)

		// Scan every emitted SSE data line for the tool_use block start.
		var toolID string
		for _, out := range outputs {
			for _, line := range strings.Split(string(out), "\n") {
				if !strings.HasPrefix(line, "data: ") {
					continue
				}
				data := gjson.Parse(strings.TrimPrefix(line, "data: "))
				isBlockStart := data.Get("type").String() == "content_block_start"
				if isBlockStart && data.Get("content_block.type").String() == "tool_use" {
					toolID = data.Get("content_block.id").String()
				}
			}
		}

		if toolID == "" {
			t.Fatalf("missing stream tool_use block. Outputs=%q", outputs)
		}
		if len(toolID) > 64 {
			t.Fatalf("stream tool_use id length = %d, want <= 64: %q", len(toolID), toolID)
		}
		if toolID == longCallID {
			t.Fatalf("stream tool_use id was not shortened: %q", toolID)
		}
	})

	t.Run("nonstream", func(t *testing.T) {
		originalRequest := []byte(`{"tools":[{"name":"lookup","input_schema":{"type":"object","properties":{}}}]}`)
		response := []byte(`{
"type":"response.completed",
"response":{
"id":"resp_1",
"model":"gpt-5",
"usage":{"input_tokens":1,"output_tokens":1},
"output":[{"type":"function_call","call_id":"` + longCallID + `","name":"lookup","arguments":"{}"}]
}
}`)

		out := ConvertCodexResponseToClaudeNonStream(context.Background(), "", originalRequest, nil, response, nil)
		toolID := gjson.GetBytes(out, "content.0.id").String()

		if toolID == "" {
			t.Fatalf("missing nonstream tool_use id. Output: %s", string(out))
		}
		if len(toolID) > 64 {
			t.Fatalf("nonstream tool_use id length = %d, want <= 64: %q", len(toolID), toolID)
		}
		if toolID == longCallID {
			t.Fatalf("nonstream tool_use id was not shortened: %q", toolID)
		}
	})
}
func TestConvertCodexResponseToClaude_StreamStopReasonMapping(t *testing.T) {
tests := []struct {
name string
@@ -49,6 +49,9 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
if systemPromptResult.Get("type").String() == "text" {
textResult := systemPromptResult.Get("text")
if textResult.Type == gjson.String {
if util.IsClaudeCodeAttributionSystemText(textResult.String()) {
return true
}
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", textResult.String())
systemInstruction, _ = sjson.SetRawBytes(systemInstruction, "parts.-1", part)
@@ -60,7 +63,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
if hasSystemParts {
out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstruction)
}
} else if systemResult.Type == gjson.String {
} else if systemResult.Type == gjson.String && !util.IsClaudeCodeAttributionSystemText(systemResult.String()) {
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.-1.text", systemResult.String())
}
@@ -40,3 +40,24 @@ func TestConvertClaudeRequestToCLI_ToolChoice_SpecificTool(t *testing.T) {
t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Raw)
}
}
// TestConvertClaudeRequestToCLI_StripsClaudeCodeAttribution feeds a Claude
// request whose system array starts with an x-anthropic-billing-header block
// and checks that only the user-authored system part reaches
// request.systemInstruction.parts.
func TestConvertClaudeRequestToCLI_StripsClaudeCodeAttribution(t *testing.T) {
	request := []byte(`{
"model": "claude-sonnet-4-5",
"system": [
{"type": "text", "text": "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;"},
{"type": "text", "text": "User system prompt"}
],
"messages": [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
}`)

	converted := ConvertClaudeRequestToCLI("gemini-3-flash-preview", request, false)

	systemParts := gjson.GetBytes(converted, "request.systemInstruction.parts")
	parts := systemParts.Array()
	if count := len(parts); count != 1 {
		t.Fatalf("Expected 1 system part after attribution strip, got %d: %s", count, systemParts.Raw)
	}

	// The single surviving part must be the user prompt, not the billing header.
	text := parts[0].Get("text").String()
	if text != "User system prompt" {
		t.Fatalf("Unexpected system part: %q", text)
	}
}
@@ -43,6 +43,9 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if systemPromptResult.Get("type").String() == "text" {
textResult := systemPromptResult.Get("text")
if textResult.Type == gjson.String {
if util.IsClaudeCodeAttributionSystemText(textResult.String()) {
return true
}
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", textResult.String())
systemInstruction, _ = sjson.SetRawBytes(systemInstruction, "parts.-1", part)
@@ -54,7 +57,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if hasSystemParts {
out, _ = sjson.SetRawBytes(out, "system_instruction", systemInstruction)
}
} else if systemResult.Type == gjson.String {
} else if systemResult.Type == gjson.String && !util.IsClaudeCodeAttributionSystemText(systemResult.String()) {
out, _ = sjson.SetBytes(out, "system_instruction.parts.-1.text", systemResult.String())
}
@@ -78,8 +81,12 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
switch contentResult.Get("type").String() {
case "text":
text := contentResult.Get("text").String()
if text == "" {
return true
}
part := []byte(`{"text":""}`)
part, _ = sjson.SetBytes(part, "text", contentResult.Get("text").String())
part, _ = sjson.SetBytes(part, "text", text)
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
case "tool_use":
@@ -78,3 +78,57 @@ func TestConvertClaudeRequestToGemini_ImageContent(t *testing.T) {
t.Fatalf("Expected image data 'aGVsbG8=', got '%s'", got)
}
}
// TestConvertClaudeRequestToGemini_StripsClaudeCodeAttribution checks that the
// x-anthropic-billing-header system block is dropped while the remaining two
// system parts are forwarded, in order, under system_instruction.parts.
func TestConvertClaudeRequestToGemini_StripsClaudeCodeAttribution(t *testing.T) {
	request := []byte(`{
"model": "claude-sonnet-4-5",
"system": [
{"type": "text", "text": "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;"},
{"type": "text", "text": "You are a Claude agent, built on Anthropic's Claude Agent SDK."},
{"type": "text", "text": "User system prompt"}
],
"messages": [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
}`)

	converted := ConvertClaudeRequestToGemini("gemini-3-flash-preview", request, false)

	systemParts := gjson.GetBytes(converted, "system_instruction.parts")
	parts := systemParts.Array()
	if count := len(parts); count != 2 {
		t.Fatalf("Expected 2 system parts after attribution strip, got %d: %s", count, systemParts.Raw)
	}

	first := parts[0].Get("text").String()
	if first != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
		t.Fatalf("Unexpected first system part: %q", first)
	}

	second := parts[1].Get("text").String()
	if second != "User system prompt" {
		t.Fatalf("Unexpected second system part: %q", second)
	}

	// Belt and braces: no forwarded part may start with the billing header.
	leaked := gjson.GetBytes(converted, `system_instruction.parts.#(text%"x-anthropic-billing-header:*")`)
	if leaked.Exists() {
		t.Fatalf("Claude Code attribution block was forwarded: %s", systemParts.Raw)
	}
}
// TestConvertClaudeRequestToGemini_SkipsEmptyTextParts checks that text content
// items with an empty string are not turned into parts: of the three items in
// the assistant message only "hello" should appear in contents.0.parts.
func TestConvertClaudeRequestToGemini_SkipsEmptyTextParts(t *testing.T) {
	request := []byte(`{
"model": "claude-3-5-sonnet",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": ""},
{"type": "text", "text": "hello"},
{"type": "text", "text": ""}
]
}
]
}`)

	converted := ConvertClaudeRequestToGemini("gemini-3-flash-preview", request, false)

	parts := gjson.GetBytes(converted, "contents.0.parts").Array()
	if count := len(parts); count != 1 {
		t.Fatalf("Expected 1 part after skipping empty text, got %d: %s", count, converted)
	}

	text := parts[0].Get("text").String()
	if text != "hello" {
		t.Fatalf("Expected part text 'hello', got '%s'", text)
	}
}
@@ -9,6 +9,7 @@ import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -103,7 +104,7 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
hasSystemContent := false
if system := root.Get("system"); system.Exists() {
if system.Type == gjson.String {
if system.String() != "" {
if system.String() != "" && !util.IsClaudeCodeAttributionSystemText(system.String()) {
oldSystem := []byte(`{"type":"text","text":""}`)
oldSystem, _ = sjson.SetBytes(oldSystem, "text", system.String())
systemMsgJSON, _ = sjson.SetRawBytes(systemMsgJSON, "content.-1", oldSystem)
@@ -334,7 +335,7 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
switch partType {
case "text":
text := part.Get("text").String()
if strings.TrimSpace(text) == "" {
if strings.TrimSpace(text) == "" || util.IsClaudeCodeAttributionSystemText(text) {
return "", false
}
textContent := []byte(`{"type":"text","text":""}`)
@@ -696,3 +696,28 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t
t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got)
}
}
// TestConvertClaudeRequestToOpenAI_StripsClaudeCodeAttribution checks that the
// generated system message keeps only the user system prompt and omits the
// x-anthropic-billing-header attribution block.
func TestConvertClaudeRequestToOpenAI_StripsClaudeCodeAttribution(t *testing.T) {
	request := []byte(`{
"model": "claude-sonnet-4-5",
"system": [
{"type": "text", "text": "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;"},
{"type": "text", "text": "User system prompt"}
],
"messages": [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
}`)

	converted := ConvertClaudeRequestToOpenAI("gpt-5", request, false)

	allMessages := gjson.GetBytes(converted, "messages")
	messages := allMessages.Array()
	if len(messages) == 0 || messages[0].Get("role").String() != "system" {
		t.Fatalf("Expected first message to be system, got: %s", allMessages.Raw)
	}

	systemContent := messages[0].Get("content")
	items := systemContent.Array()
	if count := len(items); count != 1 {
		t.Fatalf("Expected 1 system content item after attribution strip, got %d: %s", count, systemContent.Raw)
	}

	text := items[0].Get("text").String()
	if text != "User system prompt" {
		t.Fatalf("Unexpected system content: %q", text)
	}
}
@@ -8,6 +8,7 @@ package claude
import (
"bytes"
"context"
"sort"
"strings"
translatorcommon "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/common"
@@ -26,6 +27,9 @@ type ConvertOpenAIResponseToAnthropicParams struct {
Model string
CreatedAt int64
ToolNameMap map[string]string
// SawToolCall is true once at least one tool_use content_block_start has
// been emitted on the wire. Using raw upstream tool_calls presence here
// can produce stop_reason=tool_use with zero announced tool blocks.
SawToolCall bool
// Content accumulator for streaming
ContentAccumulator strings.Builder
@@ -60,6 +64,9 @@ type ToolCallAccumulator struct {
ID string
Name string
Arguments strings.Builder
// StartEmitted tracks whether content_block_start has already been sent
// for this tool index.
StartEmitted bool
}
// ConvertOpenAIResponseToClaude converts OpenAI streaming response format to Anthropic API format.
@@ -218,9 +225,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
}
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
param.SawToolCall = true
index := int(toolCall.Get("index").Int())
blockIndex := param.toolContentBlockIndex(index)
// Initialize accumulator if needed
if _, exists := param.ToolCallsAccumulator[index]; !exists {
@@ -229,27 +234,25 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
accumulator := param.ToolCallsAccumulator[index]
// Handle tool call ID
if id := toolCall.Get("id"); id.Exists() {
accumulator.ID = id.String()
// Handle tool call ID. Only accept JSON-string, non-empty
// values so malformed upstream fields do not overwrite a
// valid ID or coerce into a content_block.id.
if id := toolCall.Get("id"); id.Exists() && id.Type == gjson.String {
if idStr := id.String(); idStr != "" {
accumulator.ID = idStr
}
}
// Handle function name
// Handle function name and arguments
if function := toolCall.Get("function"); function.Exists() {
if name := function.Get("name"); name.Exists() && name.String() != "" {
accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
stopThinkingContentBlock(param, &results)
stopTextContentBlock(param, &results)
// Send content_block_start for tool_use
contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
contentBlockStartJSONBytes := []byte(contentBlockStartJSON)
contentBlockStartJSONBytes, _ = sjson.SetBytes(contentBlockStartJSONBytes, "index", blockIndex)
contentBlockStartJSONBytes, _ = sjson.SetBytes(contentBlockStartJSONBytes, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
contentBlockStartJSONBytes, _ = sjson.SetBytes(contentBlockStartJSONBytes, "content_block.name", accumulator.Name)
results = append(results, translatorcommon.AppendSSEEventBytes(nil, "content_block_start", contentBlockStartJSONBytes, 2))
// Only record the name until content_block_start has been
// emitted. Some upstreams send "name": "" or repeat the
// field across chunks; reassigning after start could drift
// from what was already announced.
if !accumulator.StartEmitted {
if name := function.Get("name"); name.Exists() && name.Type == gjson.String && name.String() != "" {
accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
}
}
// Handle function arguments
@@ -261,6 +264,13 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
}
}
// Re-check on every chunk, not only chunks with a function
// object. Some upstreams split function.name and id across
// separate deltas.
if !accumulator.StartEmitted && accumulator.Name != "" && accumulator.ID != "" && !param.ContentBlocksStopped {
emitToolUseStart(param, index, accumulator, &results)
}
return true
})
}
@@ -269,9 +279,12 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Handle finish_reason (but don't send message_delta/message_stop yet)
if finishReason := root.Get("choices.0.finish_reason"); finishReason.Exists() && finishReason.String() != "" {
reason := finishReason.String()
if param.SawToolCall {
switch {
case param.SawToolCall:
param.FinishReason = "tool_calls"
} else {
case reason == "tool_calls":
param.FinishReason = "stop"
default:
param.FinishReason = reason
}
@@ -289,8 +302,17 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Send content_block_stop for any tool calls
if !param.ContentBlocksStopped {
for index := range param.ToolCallsAccumulator {
for _, index := range toolCallAccumulatorIndexes(param.ToolCallsAccumulator) {
accumulator := param.ToolCallsAccumulator[index]
if !accumulator.StartEmitted {
// Belated emit for streams that supplied a valid name but
// never sent an id. SanitizeClaudeToolID("") produces the
// expected stable synthetic toolu_<nanos>_<n> ID shape.
if accumulator.Name == "" {
continue
}
emitToolUseStart(param, index, accumulator, &results)
}
blockIndex := param.toolContentBlockIndex(index)
// Send complete input_json_delta with all accumulated arguments
@@ -353,8 +375,16 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
stopTextContentBlock(param, &results)
if !param.ContentBlocksStopped {
for index := range param.ToolCallsAccumulator {
for _, index := range toolCallAccumulatorIndexes(param.ToolCallsAccumulator) {
accumulator := param.ToolCallsAccumulator[index]
if !accumulator.StartEmitted {
// Belated emit at [DONE]; same behavior as the finish_reason
// path for name-but-no-id streams.
if accumulator.Name == "" {
continue
}
emitToolUseStart(param, index, accumulator, &results)
}
blockIndex := param.toolContentBlockIndex(index)
if accumulator.Arguments.Len() > 0 {
@@ -547,6 +577,29 @@ func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results
param.TextContentBlockIndex = -1
}
// emitToolUseStart appends a content_block_start SSE event announcing a
// tool_use block for the given OpenAI tool index, then records on both the
// accumulator and the params that a start has been emitted.
//
// The thinking and text stop helpers run first, before the tool block index is
// resolved (presumably so any open thinking/text block is closed ahead of the
// tool block; confirm against those helpers). The announced id is
// accumulator.ID passed through util.SanitizeClaudeToolID and the announced
// name is accumulator.Name.
func emitToolUseStart(param *ConvertOpenAIResponseToAnthropicParams, openAIToolIndex int, accumulator *ToolCallAccumulator, results *[][]byte) {
	stopThinkingContentBlock(param, results)
	stopTextContentBlock(param, results)

	// Fill the template in place: index, then id, then name.
	startEvent := []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
	startEvent, _ = sjson.SetBytes(startEvent, "index", param.toolContentBlockIndex(openAIToolIndex))
	startEvent, _ = sjson.SetBytes(startEvent, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
	startEvent, _ = sjson.SetBytes(startEvent, "content_block.name", accumulator.Name)

	sseFrame := translatorcommon.AppendSSEEventBytes(nil, "content_block_start", startEvent, 2)
	*results = append(*results, sseFrame)

	// From here on the tool block is on the wire for this index.
	accumulator.StartEmitted = true
	param.SawToolCall = true
}
// toolCallAccumulatorIndexes returns the keys of accumulators in ascending
// order. Map iteration order is random in Go, so callers that emit events per
// tool index use this to get a deterministic sequence. The result is always a
// non-nil slice, even for an empty or nil map.
func toolCallAccumulatorIndexes(accumulators map[int]*ToolCallAccumulator) []int {
	ordered := make([]int, len(accumulators))
	next := 0
	for toolIndex := range accumulators {
		ordered[next] = toolIndex
		next++
	}
	sort.Ints(ordered)
	return ordered
}
// ConvertOpenAIResponseToClaudeNonStream converts a non-streaming OpenAI response to a non-streaming Anthropic response.
//
// Parameters:
@@ -3,11 +3,108 @@ package claude
import (
"bytes"
"context"
"strings"
"testing"
"github.com/tidwall/gjson"
)
// sseEvent is one parsed server-sent event as produced by runStream: the text
// after "event: " on the first line and the text after "data: " on the rest.
type sseEvent struct {
// Type is the SSE event name, e.g. "content_block_start".
Type string
// Payload is the data portion with trailing newlines trimmed.
Payload string
}
// runStream drives ConvertOpenAIResponseToClaude over the given upstream
// chunks (each sent as "data: <chunk>"), followed by a final "data: [DONE]",
// sharing one converter state across all calls. It returns every emitted frame
// that has the form "event: <type>\ndata: <payload>", in emission order;
// frames of any other shape are dropped.
func runStream(t *testing.T, originalReq string, chunks ...string) []sseEvent {
	t.Helper()

	var state any
	var frames [][]byte

	// feed pushes one upstream data line through the converter and collects
	// whatever frames it produces.
	feed := func(data string) {
		out := ConvertOpenAIResponseToClaude(
			context.Background(),
			"",
			[]byte(originalReq),
			nil,
			[]byte("data: "+data),
			&state,
		)
		frames = append(frames, out...)
	}

	for _, chunk := range chunks {
		feed(chunk)
	}
	feed("[DONE]")

	var events []sseEvent
	for _, frame := range frames {
		header, body, hasBody := strings.Cut(string(frame), "\n")
		if !hasBody {
			continue
		}
		if !strings.HasPrefix(header, "event: ") || !strings.HasPrefix(body, "data: ") {
			continue
		}
		events = append(events, sseEvent{
			Type:    strings.TrimPrefix(header, "event: "),
			Payload: strings.TrimRight(strings.TrimPrefix(body, "data: "), "\n"),
		})
	}
	return events
}
// countByType reports how many events carry exactly the given event type.
func countByType(events []sseEvent, typ string) int {
	matches := 0
	for i := range events {
		if events[i].Type != typ {
			continue
		}
		matches++
	}
	return matches
}
// toolUseStarts filters events down to the content_block_start events whose
// payload has content_block.type == "tool_use", preserving order. It returns
// nil when there are none.
func toolUseStarts(events []sseEvent) []sseEvent {
	var starts []sseEvent
	for _, ev := range events {
		isStart := ev.Type == "content_block_start"
		if isStart && gjson.Get(ev.Payload, "content_block.type").String() == "tool_use" {
			starts = append(starts, ev)
		}
	}
	return starts
}
// blockIndices collects the "index" field of every content_block_start event,
// in emission order. It returns nil when no such event exists.
func blockIndices(events []sseEvent) []int64 {
	var indices []int64
	for i := range events {
		if events[i].Type != "content_block_start" {
			continue
		}
		indices = append(indices, gjson.Get(events[i].Payload, "index").Int())
	}
	return indices
}
// lastStopReason returns delta.stop_reason from the most recent message_delta
// event, or "" when the stream contains no message_delta.
func lastStopReason(events []sseEvent) string {
	// Walk backwards by repeatedly trimming the tail of the slice.
	for rest := events; len(rest) > 0; rest = rest[:len(rest)-1] {
		tail := rest[len(rest)-1]
		if tail.Type == "message_delta" {
			return gjson.Get(tail.Payload, "delta.stop_reason").String()
		}
	}
	return ""
}
// streamReq is the minimal original-request JSON shared by the streaming
// tests; it only sets "stream" to true.
const streamReq = `{"stream":true}`
func TestConvertOpenAIResponseToClaude_StreamIgnoresNullToolNameDelta(t *testing.T) {
originalRequest := []byte(`{"stream":true}`)
originalRequest := []byte(streamReq)
var param any
firstChunks := ConvertOpenAIResponseToClaude(
@@ -39,3 +136,231 @@ func TestConvertOpenAIResponseToClaude_StreamIgnoresNullToolNameDelta(t *testing
t.Fatalf("did not expect null tool name delta to emit an empty tool name, got %s", string(secondOutput))
}
}
// TestStreamingTool_EmptyNameThroughout covers a tool call whose function name
// is "" in every delta: no tool_use block may be started, no block deltas or
// stops may follow, and the final stop_reason must not be tool_use.
func TestStreamingTool_EmptyNameThroughout(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_a","function":{"name":"","arguments":""}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":"{\"x\":1}"}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	startCount := len(toolUseStarts(events))
	if startCount != 0 {
		t.Fatalf("expected zero tool_use content_block_start, got %d (events=%+v)", startCount, events)
	}

	deltaCount := countByType(events, "content_block_delta")
	if deltaCount != 0 {
		t.Fatalf("expected zero content_block_delta when start was suppressed, got %d", deltaCount)
	}

	stopCount := countByType(events, "content_block_stop")
	if stopCount != 0 {
		t.Fatalf("expected zero content_block_stop when start was suppressed, got %d", stopCount)
	}

	stopReason := lastStopReason(events)
	if stopReason == "tool_use" {
		t.Fatalf("stop_reason must not be tool_use when zero tool_use blocks were emitted; got %q", stopReason)
	}
}
// TestStreamingTool_NullName covers a tool call whose function name is JSON
// null: neither a tool_use start nor a content_block_stop may be emitted.
func TestStreamingTool_NullName(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_a","function":{"name":null,"arguments":""}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	startCount := len(toolUseStarts(events))
	if startCount != 0 {
		t.Fatalf("null name must not produce a tool_use start; got %d", startCount)
	}

	stopCount := countByType(events, "content_block_stop")
	if stopCount != 0 {
		t.Fatalf("null name must not produce content_block_stop; got %d", stopCount)
	}
}
// TestStreamingTool_NonStringName covers a tool call whose function name is a
// JSON number: it must not be coerced into a tool_use start.
func TestStreamingTool_NonStringName(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_a","function":{"name":123,"arguments":""}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	startCount := len(toolUseStarts(events))
	if startCount != 0 {
		t.Fatalf("non-string name must not produce a tool_use start; got %d", startCount)
	}
}
// TestStreamingTool_RepeatedName covers an upstream that repeats the function
// name on every delta: exactly one tool_use start (named "do_it") and exactly
// one content_block_stop are expected.
func TestStreamingTool_RepeatedName(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_a","function":{"name":"do_it","arguments":""}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":"do_it","arguments":"{\"x\""}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":"do_it","arguments":":1}"}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	toolStarts := toolUseStarts(events)
	if len(toolStarts) != 1 {
		t.Fatalf("expected exactly one tool_use start, got %d", len(toolStarts))
	}

	announcedName := gjson.Get(toolStarts[0].Payload, "content_block.name").String()
	if announcedName != "do_it" {
		t.Fatalf("announced tool name = %q, want %q", announcedName, "do_it")
	}

	stopCount := countByType(events, "content_block_stop")
	if stopCount != 1 {
		t.Fatalf("expected exactly one content_block_stop, got %d", stopCount)
	}
}
// TestStreamingTool_MixedSuppressedAndValid sends two tool calls in one delta,
// one with an empty name and one valid. Only the valid call may be announced,
// it must get exactly one stop, and the first content block must use index 0.
func TestStreamingTool_MixedSuppressedAndValid(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[
{"index":0,"id":"call_skip","function":{"name":"","arguments":""}},
{"index":1,"id":"call_real","function":{"name":"do_it","arguments":""}}
]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[
{"index":1,"function":{"arguments":"{}"}}
]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	toolStarts := toolUseStarts(events)
	if len(toolStarts) != 1 {
		t.Fatalf("expected exactly one tool_use start, got %d", len(toolStarts))
	}

	stopCount := countByType(events, "content_block_stop")
	if stopCount != 1 {
		t.Fatalf("expected exactly one content_block_stop, got %d", stopCount)
	}

	// The suppressed call must not consume a block index.
	startIndices := blockIndices(events)
	if len(startIndices) == 0 || startIndices[0] != 0 {
		t.Fatalf("first content_block_start index must be 0, got %v", startIndices)
	}
}
// TestStreamingTool_EmptyIDDeferStart covers a first delta with an empty id
// followed by a delta carrying the real id: the tool_use start must be emitted
// once, announcing the real id.
func TestStreamingTool_EmptyIDDeferStart(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"","function":{"name":"do_it","arguments":""}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_real","function":{"arguments":"{}"}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	toolStarts := toolUseStarts(events)
	if len(toolStarts) != 1 {
		t.Fatalf("expected exactly one tool_use start once id arrived, got %d", len(toolStarts))
	}

	announcedID := gjson.Get(toolStarts[0].Payload, "content_block.id").String()
	if announcedID != "call_real" {
		t.Fatalf("announced tool id = %q, want %q", announcedID, "call_real")
	}
}
// TestStreamingTool_IDInDeltaWithoutFunction covers an upstream that sends the
// name first and the id later in a delta with no function object. The start
// must still be emitted exactly once with the right id and name, followed by
// exactly one stop.
func TestStreamingTool_IDInDeltaWithoutFunction(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"function":{"name":"do_it"}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_real"}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{}"}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	toolStarts := toolUseStarts(events)
	if len(toolStarts) != 1 {
		t.Fatalf("expected exactly one tool_use start when id arrives in a function-less delta, got %d", len(toolStarts))
	}

	announced := toolStarts[0].Payload
	announcedID := gjson.Get(announced, "content_block.id").String()
	if announcedID != "call_real" {
		t.Fatalf("announced tool id = %q, want %q", announcedID, "call_real")
	}
	announcedName := gjson.Get(announced, "content_block.name").String()
	if announcedName != "do_it" {
		t.Fatalf("announced tool name = %q, want %q", announcedName, "do_it")
	}

	stopCount := countByType(events, "content_block_stop")
	if stopCount != 1 {
		t.Fatalf("expected exactly one content_block_stop, got %d", stopCount)
	}
}
// TestStreamingTool_StopReasonWithEmittedTool checks the ordinary case: a
// fully specified tool call followed by finish_reason "tool_calls" must end
// with stop_reason "tool_use".
func TestStreamingTool_StopReasonWithEmittedTool(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_a","function":{"name":"do_it","arguments":"{}"}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`,
	}
	events := runStream(t, streamReq, chunks...)

	stopReason := lastStopReason(events)
	if stopReason != "tool_use" {
		t.Fatalf("stop_reason = %q, want %q", stopReason, "tool_use")
	}
}
// TestStreamingTool_StopReasonWhenIDNeverArrives covers a tool call that has a
// name but never receives an id. A single belated tool_use start is expected
// with an id beginning "toolu_", the original name, and stop_reason
// "tool_use".
func TestStreamingTool_StopReasonWhenIDNeverArrives(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"function":{"name":"do_it","arguments":""}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{}"}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	toolStarts := toolUseStarts(events)
	if len(toolStarts) != 1 {
		t.Fatalf("expected one belated tool_use start with synthetic id, got %d", len(toolStarts))
	}

	announced := toolStarts[0].Payload
	syntheticID := gjson.Get(announced, "content_block.id").String()
	if !strings.HasPrefix(syntheticID, "toolu_") {
		t.Fatalf("synthetic id should match toolu_<nanos>_<n>, got %q", syntheticID)
	}
	announcedName := gjson.Get(announced, "content_block.name").String()
	if announcedName != "do_it" {
		t.Fatalf("announced tool name = %q, want %q", announcedName, "do_it")
	}

	stopReason := lastStopReason(events)
	if stopReason != "tool_use" {
		t.Fatalf("stop_reason = %q, want %q", stopReason, "tool_use")
	}
}
// TestStreamingTool_BelatedStartsUseOpenAIToolIndexOrder sends three id-less
// tool calls with OpenAI indexes 2, 0, 1 in a single delta. The belated starts
// must come out ordered by OpenAI tool index (first, second, third) with
// content block indexes 0, 1, 2.
func TestStreamingTool_BelatedStartsUseOpenAIToolIndexOrder(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[
{"index":2,"function":{"name":"third_tool","arguments":"{}"}},
{"index":0,"function":{"name":"first_tool","arguments":"{}"}},
{"index":1,"function":{"name":"second_tool","arguments":"{}"}}
]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	toolStarts := toolUseStarts(events)
	if len(toolStarts) != 3 {
		t.Fatalf("expected three belated tool_use starts, got %d", len(toolStarts))
	}

	expectedNames := []string{"first_tool", "second_tool", "third_tool"}
	for pos := 0; pos < len(expectedNames); pos++ {
		payload := toolStarts[pos].Payload

		gotName := gjson.Get(payload, "content_block.name").String()
		if gotName != expectedNames[pos] {
			t.Fatalf("tool_use start %d name = %q, want %q (starts=%+v)", pos, gotName, expectedNames[pos], toolStarts)
		}

		gotIndex := gjson.Get(payload, "index").Int()
		if gotIndex != int64(pos) {
			t.Fatalf("tool_use start %d block index = %d, want %d", pos, gotIndex, pos)
		}
	}
}
// TestStreamingTool_LateIDAfterFinalization sends a tool id only after the
// finish_reason/usage chunk. One belated tool_use start is expected, and no
// content_block_* event may appear after the first message_stop.
func TestStreamingTool_LateIDAfterFinalization(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"function":{"name":"do_it"}}]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_late"}]}}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	toolStarts := toolUseStarts(events)
	if len(toolStarts) != 1 {
		t.Fatalf("expected one belated tool_use start, got %d", len(toolStarts))
	}

	// Locate the first message_stop; nothing to verify if the stream has none.
	stopAt := -1
	for i := range events {
		if events[i].Type == "message_stop" {
			stopAt = i
			break
		}
	}
	if stopAt < 0 {
		return
	}

	for _, late := range events[stopAt+1:] {
		if late.Type == "content_block_start" || late.Type == "content_block_delta" || late.Type == "content_block_stop" {
			t.Fatalf("event %q emitted after message_stop (events=%+v)", late.Type, events)
		}
	}
}
// TestStreamingTool_StopReasonMixedSuppressedAndValid checks that when one
// tool call is suppressed (empty name) and another is valid, the stream still
// finishes with stop_reason "tool_use".
func TestStreamingTool_StopReasonMixedSuppressedAndValid(t *testing.T) {
	chunks := []string{
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[
{"index":0,"id":"call_skip","function":{"name":"","arguments":""}},
{"index":1,"id":"call_real","function":{"name":"do_it","arguments":"{}"}}
]}}]}`,
		`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
	}
	events := runStream(t, streamReq, chunks...)

	stopReason := lastStopReason(events)
	if stopReason != "tool_use" {
		t.Fatalf("stop_reason = %q, want %q", stopReason, "tool_use")
	}
}
+15
View File
@@ -0,0 +1,15 @@
package util
import (
"strings"
"unicode"
)
// claudeCodeAttributionSystemPrefix is the literal marker that opens a Claude
// Code attribution system block.
const claudeCodeAttributionSystemPrefix = "x-anthropic-billing-header:"

// IsClaudeCodeAttributionSystemText reports whether text is the Claude Code
// attribution block that carries per-request billing and prompt fingerprint
// data. Leading Unicode whitespace is ignored; the marker itself is matched
// case-sensitively.
func IsClaudeCodeAttributionSystemText(text string) bool {
	firstVisible := strings.IndexFunc(text, func(r rune) bool {
		return !unicode.IsSpace(r)
	})
	if firstVisible < 0 {
		// Empty or whitespace-only text cannot carry the marker.
		return false
	}
	return strings.HasPrefix(text[firstVisible:], claudeCodeAttributionSystemPrefix)
}
+40
View File
@@ -0,0 +1,40 @@
package util
import "testing"
// TestIsClaudeCodeAttributionSystemText is a table-driven check of the
// attribution detector: the bare header, the header behind leading whitespace,
// an ordinary prompt, and the empty string.
func TestIsClaudeCodeAttributionSystemText(t *testing.T) {
	type testCase struct {
		name string
		text string
		want bool
	}

	cases := []testCase{
		{name: "Claude Code attribution block", text: "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;", want: true},
		{name: "leading whitespace", text: "\n\t x-anthropic-billing-header: cc_version=2.1.63.abc; cch=12345;", want: true},
		{name: "regular system prompt", text: "You are helpful.", want: false},
		{name: "empty text", text: "", want: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := IsClaudeCodeAttributionSystemText(tc.text)
			if got != tc.want {
				t.Fatalf("IsClaudeCodeAttributionSystemText(%q) = %v, want %v", tc.text, got, tc.want)
			}
		})
	}
}
+2 -1
View File
@@ -4,6 +4,7 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"sort"
"strings"
@@ -20,7 +21,7 @@ func ComputeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str
if name == "" && alias == "" {
continue
}
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
out(strings.ToLower(name) + "|" + strings.ToLower(alias) + "|" + fmt.Sprintf("image=%t", model.Image))
}
})
return hashJoined(keys)
+11
View File
@@ -25,6 +25,17 @@ func TestComputeOpenAICompatModelsHash_Deterministic(t *testing.T) {
}
}
// TestComputeOpenAICompatModelsHash_IncludesImageFlag checks that two model
// lists differing only in the Image flag hash to different, non-empty values.
func TestComputeOpenAICompatModelsHash_IncludesImageFlag(t *testing.T) {
	// hashWithImage hashes a single-model list that varies only in Image.
	hashWithImage := func(image bool) string {
		return ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{
			{Name: "gpt-image", Alias: "image", Image: image},
		})
	}

	textModel := hashWithImage(false)
	imageModel := hashWithImage(true)

	if textModel == "" || imageModel == "" {
		t.Fatal("hashes should not be empty")
	}
	if textModel == imageModel {
		t.Fatal("hash should change when image flag changes")
	}
}
func TestComputeOpenAICompatModelsHash_NormalizesAndDedups(t *testing.T) {
a := []config.OpenAICompatibilityModel{
{Name: "gpt-4", Alias: "gpt4"},
+1 -1
View File
@@ -153,7 +153,7 @@ func openAICompatSignature(entry config.OpenAICompatibility) string {
if name == "" && alias == "" {
continue
}
models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias))
models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias)+"|"+fmt.Sprintf("image=%t", model.Image))
}
if len(models) > 0 {
sort.Strings(models)
+152 -2
View File
@@ -14,6 +14,8 @@ import (
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v7/internal/constant"
@@ -257,6 +259,15 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
return
case chunk, ok := <-dataChan:
if !ok {
if errMsg, okPendingErr := pendingClaudeStreamError(errChan); okPendingErr {
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
}
// Stream closed without data? Send DONE or just headers.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
@@ -282,6 +293,21 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
}
}
// pendingClaudeStreamError performs a non-blocking read of errs. It returns
// (msg, true) only when a value is immediately available on an open channel;
// a nil channel, a closed channel, or an empty channel all yield (nil, false).
// Note that the returned message itself may be nil if a nil was sent.
func pendingClaudeStreamError(errs <-chan *interfaces.ErrorMessage) (*interfaces.ErrorMessage, bool) {
	// A receive from a nil channel is never ready, so the default branch also
	// covers errs == nil without a separate check.
	select {
	case pending, open := <-errs:
		if open {
			return pending, true
		}
	default:
	}
	return nil, false
}
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
WriteChunk: func(chunk []byte) {
@@ -317,11 +343,135 @@ type claudeErrorResponse struct {
}
func (h *ClaudeCodeAPIHandler) toClaudeError(msg *interfaces.ErrorMessage) claudeErrorResponse {
status := http.StatusInternalServerError
errText := http.StatusText(status)
if msg != nil {
if msg.StatusCode > 0 {
status = msg.StatusCode
errText = http.StatusText(status)
}
if msg.Error != nil {
if v := strings.TrimSpace(msg.Error.Error()); v != "" {
errText = v
}
}
}
errType, message := claudeErrorDetailFromText(status, errText)
return claudeErrorResponse{
Type: "error",
Error: claudeErrorDetail{
Type: "api_error",
Message: msg.Error.Error(),
Type: errType,
Message: message,
},
}
}
// WriteErrorResponse writes msg to the client as a Claude-format JSON error
// body. The HTTP status is msg.StatusCode when positive, otherwise 500. When
// header passthrough is enabled in the handler config, every non-empty header
// in msg.Addon replaces the same-named response header. The body is also
// recorded via appendClaudeAPIResponse before being written. A nil msg is
// accepted.
func (h *ClaudeCodeAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
	status := http.StatusInternalServerError
	if msg != nil {
		if msg.StatusCode > 0 {
			status = msg.StatusCode
		}
		if msg.Addon != nil && handlers.PassthroughHeadersEnabled(h.Cfg) {
			respHeader := c.Writer.Header()
			for key, values := range msg.Addon {
				if len(values) == 0 {
					continue
				}
				// Replace rather than merge: drop whatever was set before.
				respHeader.Del(key)
				for _, value := range values {
					respHeader.Add(key, value)
				}
			}
		}
	}

	body, err := json.Marshal(h.toClaudeError(msg))
	if err != nil {
		// Fall back to a fixed, always-valid error document.
		body = []byte(`{"type":"error","error":{"type":"api_error","message":"Internal Server Error"}}`)
	}
	appendClaudeAPIResponse(c, body)

	// Only set the content type while nothing has been written yet.
	if !c.Writer.Written() {
		c.Writer.Header().Set("Content-Type", "application/json")
	}
	c.Status(status)
	_, _ = c.Writer.Write(body)
}
// claudeErrorDetailFromText derives the Claude error type and message from an
// HTTP status and a raw error text.
//
// The message starts as the trimmed errText (or the status text when empty)
// and the type starts as claudeErrorTypeFromStatus(status). If the message is
// a JSON object, its fields refine the result:
//   - with a nested "error" object: "type" overrides the type; "message"
//     overrides the message, falling back to "code";
//   - otherwise: top-level "type" (unless it is "error") and "message" are used.
//
// Only non-blank string fields are considered; values are trimmed.
func claudeErrorDetailFromText(status int, errText string) (string, string) {
	message := strings.TrimSpace(errText)
	if message == "" {
		message = http.StatusText(status)
	}
	errType := claudeErrorTypeFromStatus(status)

	raw := []byte(message)
	if !json.Valid(raw) {
		return errType, message
	}
	var payload map[string]any
	if errDecode := json.Unmarshal(raw, &payload); errDecode != nil {
		return errType, message
	}

	// field returns the trimmed string stored under key, reporting false when
	// the value is missing, not a string, or blank.
	field := func(obj map[string]any, key string) (string, bool) {
		s, isString := obj[key].(string)
		if !isString {
			return "", false
		}
		s = strings.TrimSpace(s)
		return s, s != ""
	}

	nested, hasNested := payload["error"].(map[string]any)
	if hasNested {
		if t, found := field(nested, "type"); found {
			errType = t
		}
		if m, found := field(nested, "message"); found {
			message = m
		} else if code, found := field(nested, "code"); found {
			message = code
		}
		return errType, message
	}

	// Flat payload: a top-level type of "error" is the envelope marker, not
	// an error category, so it is ignored.
	if t, found := field(payload, "type"); found && t != "error" {
		errType = t
	}
	if m, found := field(payload, "message"); found {
		message = m
	}
	return errType, message
}
// claudeErrorTypeFromStatus maps an HTTP status code to a Claude error type
// string. Specific codes are matched first; any other status of 500 or above
// is "api_error", and everything else is "invalid_request_error".
func claudeErrorTypeFromStatus(status int) string {
	switch {
	case status == http.StatusUnauthorized:
		return "authentication_error"
	case status == http.StatusPaymentRequired:
		return "billing_error"
	case status == http.StatusForbidden:
		return "permission_error"
	case status == http.StatusNotFound:
		return "not_found_error"
	case status == http.StatusRequestEntityTooLarge:
		return "request_too_large"
	case status == http.StatusTooManyRequests:
		return "rate_limit_error"
	case status == http.StatusGatewayTimeout:
		return "timeout_error"
	case status == 529:
		return "overloaded_error"
	case status >= http.StatusInternalServerError:
		// Must stay after the 504 and 529 cases, which are also >= 500.
		return "api_error"
	}
	return "invalid_request_error"
}
// appendClaudeAPIResponse records data under the "API_RESPONSE" context key.
//
// The first call also stamps "API_RESPONSE_TIMESTAMP" with the current time.
// If a non-empty []byte is already stored, data is appended to a copy of it,
// separated by a newline when the stored bytes do not already end with one.
// Otherwise a clone of data is stored. A nil context or empty data is a no-op.
func appendClaudeAPIResponse(c *gin.Context, data []byte) {
	if c == nil || len(data) == 0 {
		return
	}
	if _, stamped := c.Get("API_RESPONSE_TIMESTAMP"); !stamped {
		c.Set("API_RESPONSE_TIMESTAMP", time.Now())
	}

	stored, _ := c.Get("API_RESPONSE")
	prior, isBytes := stored.([]byte)
	if !isBytes || len(prior) == 0 {
		// Nothing usable recorded yet: store an independent copy.
		c.Set("API_RESPONSE", bytes.Clone(data))
		return
	}

	merged := make([]byte, 0, len(prior)+len(data)+1)
	merged = append(merged, prior...)
	if prior[len(prior)-1] != '\n' {
		merged = append(merged, '\n')
	}
	merged = append(merged, data...)
	c.Set("API_RESPONSE", merged)
}
@@ -0,0 +1,94 @@
package claude
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
"github.com/tidwall/gjson"
)
// TestClaudeErrorExtractsOpenAIStyleUpstreamJSON checks that an OpenAI-style
// {"error":{...}} upstream body is unpacked into the Claude envelope's
// error.type and error.message.
func TestClaudeErrorExtractsOpenAIStyleUpstreamJSON(t *testing.T) {
	const wantMessage = "Your input exceeds the context window of this model. Please adjust your input and try again."
	upstream := errors.New(`{"error":{"message":"Your input exceeds the context window of this model. Please adjust your input and try again.","type":"invalid_request_error","code":"context_too_large"}}`)

	got := (&ClaudeCodeAPIHandler{}).toClaudeError(&interfaces.ErrorMessage{
		StatusCode: http.StatusBadRequest,
		Error:      upstream,
	})

	if got.Type != "error" {
		t.Fatalf("type = %q, want error", got.Type)
	}
	if got.Error.Type != "invalid_request_error" {
		t.Fatalf("error.type = %q, want invalid_request_error", got.Error.Type)
	}
	if got.Error.Message != wantMessage {
		t.Fatalf("error.message = %q", got.Error.Message)
	}
}
// TestClaudeErrorExtractsClaudeStyleUpstreamJSON checks that a Claude-style
// upstream envelope has its nested error.type and error.message carried over.
func TestClaudeErrorExtractsClaudeStyleUpstreamJSON(t *testing.T) {
	const wantMessage = "This request would exceed your account's rate limit. Please try again later."
	upstream := errors.New(`{"type":"error","error":{"type":"rate_limit_error","message":"This request would exceed your account's rate limit. Please try again later."},"request_id":"req_123"}`)

	got := (&ClaudeCodeAPIHandler{}).toClaudeError(&interfaces.ErrorMessage{
		StatusCode: http.StatusTooManyRequests,
		Error:      upstream,
	})

	if got.Error.Type != "rate_limit_error" {
		t.Fatalf("error.type = %q, want rate_limit_error", got.Error.Type)
	}
	if got.Error.Message != wantMessage {
		t.Fatalf("error.message = %q", got.Error.Message)
	}
}
// TestWriteClaudeErrorResponseUsesClaudeEnvelope checks that
// WriteErrorResponse keeps the upstream status code and writes a body in the
// Claude error envelope shape.
func TestWriteClaudeErrorResponseUsesClaudeEnvelope(t *testing.T) {
	gin.SetMode(gin.TestMode)
	const wantMessage = "Your input exceeds the context window of this model. Please adjust your input and try again."

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	upstream := &interfaces.ErrorMessage{
		StatusCode: http.StatusBadRequest,
		Error:      errors.New(`{"error":{"message":"Your input exceeds the context window of this model. Please adjust your input and try again.","type":"invalid_request_error","code":"context_too_large"}}`),
	}

	(&ClaudeCodeAPIHandler{}).WriteErrorResponse(c, upstream)

	if recorder.Code != http.StatusBadRequest {
		t.Fatalf("status = %d, want %d", recorder.Code, http.StatusBadRequest)
	}
	body := recorder.Body.Bytes()
	envelopeType := gjson.GetBytes(body, "type").String()
	if envelopeType != "error" {
		t.Fatalf("type = %q, want error; body=%s", envelopeType, body)
	}
	errorType := gjson.GetBytes(body, "error.type").String()
	if errorType != "invalid_request_error" {
		t.Fatalf("error.type = %q, want invalid_request_error; body=%s", errorType, body)
	}
	errorMessage := gjson.GetBytes(body, "error.message").String()
	if errorMessage != wantMessage {
		t.Fatalf("error.message = %q; body=%s", errorMessage, body)
	}
}
// TestPendingClaudeStreamErrorUsesBufferedError checks that an error buffered
// on an already-closed channel is still returned by pendingClaudeStreamError.
func TestPendingClaudeStreamErrorUsesBufferedError(t *testing.T) {
	buffered := &interfaces.ErrorMessage{
		StatusCode: http.StatusBadRequest,
		Error:      errors.New(`{"error":{"message":"Your input exceeds the context window of this model. Please adjust your input and try again.","type":"invalid_request_error","code":"context_too_large"}}`),
	}

	// Queue the error, then close: the value must remain receivable.
	errs := make(chan *interfaces.ErrorMessage, 1)
	errs <- buffered
	close(errs)

	received, found := pendingClaudeStreamError(errs)
	if !found {
		t.Fatal("expected pending stream error")
	}
	if received != buffered {
		t.Fatalf("pending error = %p, want %p", received, buffered)
	}
}
+49 -4
View File
@@ -231,6 +231,17 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
return meta
}
// setReasoningEffortMetadata stores the reasoning effort extracted from the
// request under coreexecutor.ReasoningEffortMetadataKey.
//
// Nothing is written when meta is nil or when no effort can be extracted.
func setReasoningEffortMetadata(meta map[string]any, handlerType, model string, rawJSON []byte) {
	if meta == nil {
		return
	}
	if effort := thinking.ExtractReasoningEffort(rawJSON, handlerType, model); effort != "" {
		meta[coreexecutor.ReasoningEffortMetadataKey] = effort
	}
}
// headersFromContext extracts the original HTTP request headers from the gin context
// embedded in the provided context. This allows session affinity selectors to read
// client headers like X-Amp-Thread-Id.
@@ -400,6 +411,7 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
newCtx = logging.WithEndpoint(newCtx, endpoint)
}
newCtx = logging.WithResponseStatusHolder(newCtx)
newCtx = logging.WithResponseHeadersHolder(newCtx)
cancelCtx := newCtx
if requestCtx != nil && requestCtx != parentCtx {
@@ -534,12 +546,22 @@ func appendAPIResponse(c *gin.Context, data []byte) {
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
//
// It delegates to executeWithAuthManager with allowImageModel set to false.
// Request detail resolution happens inside the delegate, so it is not repeated
// here.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
	return h.executeWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, false)
}
// ExecuteImageWithAuthManager executes an OpenAI-compatible image endpoint request.
//
// It is the non-streaming counterpart used by the image handlers: it calls
// executeWithAuthManager with allowImageModel set to true and returns that
// call's response body, upstream headers, and error unchanged.
func (h *BaseAPIHandler) ExecuteImageWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
return h.executeWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, true)
}
func (h *BaseAPIHandler) executeWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string, allowImageModel bool) ([]byte, http.Header, *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetailsWithOptions(modelName, allowImageModel)
if errMsg != nil {
return nil, nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
setReasoningEffortMetadata(reqMeta, handlerType, normalizedModel, rawJSON)
payload := rawJSON
if len(payload) == 0 {
payload = nil
@@ -588,6 +610,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
setReasoningEffortMetadata(reqMeta, handlerType, normalizedModel, rawJSON)
payload := rawJSON
if len(payload) == 0 {
payload = nil
@@ -631,7 +654,16 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
// This path is the only supported execution route.
// The returned http.Header carries upstream response headers captured before streaming begins.
//
// It delegates to executeStreamWithAuthManager with allowImageModel set to
// false. Request detail resolution happens inside the delegate, so it is not
// repeated here.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
	return h.executeStreamWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, false)
}
// ExecuteImageStreamWithAuthManager executes a streaming OpenAI-compatible image endpoint request.
//
// It calls executeStreamWithAuthManager with allowImageModel set to true and
// returns that call's data channel, upstream headers, and error channel
// unchanged.
func (h *BaseAPIHandler) ExecuteImageStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
return h.executeStreamWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, true)
}
func (h *BaseAPIHandler) executeStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string, allowImageModel bool) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetailsWithOptions(modelName, allowImageModel)
if errMsg != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- errMsg
@@ -640,6 +672,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
setReasoningEffortMetadata(reqMeta, handlerType, normalizedModel, rawJSON)
payload := rawJSON
if len(payload) == 0 {
payload = nil
@@ -847,6 +880,10 @@ func statusFromError(err error) int {
}
// getRequestDetails resolves the providers and normalized model name for
// modelName. It is a convenience wrapper around getRequestDetailsWithOptions
// with allowImageModel set to false.
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
return h.getRequestDetailsWithOptions(modelName, false)
}
func (h *BaseAPIHandler) getRequestDetailsWithOptions(modelName string, allowImageModel bool) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
resolvedModelName := modelName
initialSuffix := thinking.ParseSuffix(modelName)
if initialSuffix.ModelName == "auto" {
@@ -871,10 +908,10 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
parsed := thinking.ParseSuffix(resolvedModelName)
baseModel := strings.TrimSpace(parsed.ModelName)
if strings.EqualFold(baseModel, "gpt-image-2") {
if strings.EqualFold(routeModelBaseName(baseModel), "gpt-image-2") && !allowImageModel {
return nil, "", &interfaces.ErrorMessage{
StatusCode: http.StatusServiceUnavailable,
Error: fmt.Errorf("model %s is only supported on /v1/images/generations and /v1/images/edits", baseModel),
Error: fmt.Errorf("model %s is only supported on /v1/images/generations and /v1/images/edits", routeModelBaseName(baseModel)),
}
}
@@ -901,6 +938,14 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
return providers, resolvedModelName, nil
}
// routeModelBaseName returns the portion of model after its last "/", with
// surrounding whitespace removed. When model has no "/" or ends with one, the
// trimmed model is returned as-is.
func routeModelBaseName(model string) string {
	trimmed := strings.TrimSpace(model)
	slash := strings.LastIndex(trimmed, "/")
	// No separator, or a trailing separator with nothing after it.
	if slash < 0 || slash >= len(trimmed)-1 {
		return trimmed
	}
	return strings.TrimSpace(trimmed[slash+1:])
}
func cloneBytes(src []byte) []byte {
if len(src) == 0 {
return nil
@@ -18,3 +18,23 @@ func TestRequestExecutionMetadataIncludesExecutionSessionWithoutIdempotencyKey(t
t.Fatalf("unexpected idempotency key in metadata: %v", meta[idempotencyKeyMetadataKey])
}
}
// TestSetReasoningEffortMetadataUsesSuffixOverBody checks that an effort given
// as a model-name suffix takes precedence over the one in the request body.
func TestSetReasoningEffortMetadataUsesSuffixOverBody(t *testing.T) {
	body := []byte(`{"reasoning_effort":"low"}`)
	meta := map[string]any{}

	setReasoningEffortMetadata(meta, "openai", "gpt-5.4(high)", body)

	got := meta[coreexecutor.ReasoningEffortMetadataKey]
	if got != "high" {
		t.Fatalf("ReasoningEffortMetadataKey = %v, want %q", got, "high")
	}
}
// TestSetReasoningEffortMetadataSupportsOpenAIResponses checks that the
// nested reasoning.effort field is read for the "openai-response" handler type.
func TestSetReasoningEffortMetadataSupportsOpenAIResponses(t *testing.T) {
	body := []byte(`{"reasoning":{"effort":"medium"}}`)
	meta := map[string]any{}

	setReasoningEffortMetadata(meta, "openai-response", "gpt-5.4", body)

	got := meta[coreexecutor.ReasoningEffortMetadataKey]
	if got != "medium" {
		t.Fatalf("ReasoningEffortMetadataKey = %v, want %q", got, "medium")
	}
}
+71 -5
View File
@@ -20,6 +20,14 @@ var (
codexClientModelTemplatesErr error
)
// codexClientAllowedReasoningLevels is the set of reasoning effort names that
// normalizeCodexClientReasoningLevel accepts. Keys are lowercase; any level
// not listed here (for example "minimal") normalizes to the empty string.
var codexClientAllowedReasoningLevels = map[string]struct{}{
"none": {},
"low": {},
"medium": {},
"high": {},
"xhigh": {},
}
// codexClientModelsResponse builds the Codex client models payload from the
// handler's current model list by passing h.Models() to
// CodexClientModelsResponse.
func (h *OpenAIAPIHandler) codexClientModelsResponse() map[string]any {
return CodexClientModelsResponse(h.Models())
}
@@ -45,6 +53,7 @@ func buildCodexClientModels(models []map[string]any) []map[string]any {
if template, ok := templates[id]; ok {
entry := cloneCodexClientModelMap(template)
sanitizeCodexClientReasoningMetadata(entry)
applyCodexClientVisibilityOverride(entry, id)
result = append(result, entry)
continue
@@ -52,6 +61,7 @@ func buildCodexClientModels(models []map[string]any) []map[string]any {
entry := cloneCodexClientModelMap(defaultTemplate)
applyCodexClientModelMetadata(entry, id, model)
sanitizeCodexClientReasoningMetadata(entry)
applyCodexClientVisibilityOverride(entry, id)
result = append(result, entry)
}
@@ -104,6 +114,9 @@ func applyCodexClientModelMetadata(entry map[string]any, id string, model map[st
if info.ContextLength > 0 {
contextWindow = info.ContextLength
}
if info.Type == registry.OpenAIImageModelType {
entry["visibility"] = "hide"
}
applyCodexClientThinkingMetadata(entry, info.Thinking)
}
@@ -150,12 +163,16 @@ func applyCodexClientThinkingMetadata(entry map[string]any, thinking *registry.T
levels := make([]any, 0, len(thinking.Levels))
defaultLevel := ""
firstLevel := ""
for _, rawLevel := range thinking.Levels {
level := strings.ToLower(strings.TrimSpace(rawLevel))
if level == "" || level == "none" {
level := normalizeCodexClientReasoningLevel(rawLevel)
if level == "" {
continue
}
if defaultLevel == "" || level == "medium" {
if firstLevel == "" {
firstLevel = level
}
if (defaultLevel == "" && level != "none") || level == "medium" {
defaultLevel = level
}
levels = append(levels, map[string]any{
@@ -166,15 +183,64 @@ func applyCodexClientThinkingMetadata(entry map[string]any, thinking *registry.T
if len(levels) == 0 {
return
}
if defaultLevel == "" {
defaultLevel = firstLevel
}
entry["supported_reasoning_levels"] = levels
entry["default_reasoning_level"] = defaultLevel
}
// sanitizeCodexClientReasoningMetadata rewrites the reasoning fields of a
// model entry so they only reference allowed effort levels.
//
// Each item of entry["supported_reasoning_levels"] that is a map with an
// allowed "effort" is cloned and kept with its effort normalized; all other
// items are dropped. If nothing survives, both reasoning keys are removed.
// Otherwise entry["default_reasoning_level"] is kept when it names a surviving
// level and is replaced by the first surviving level's effort when it does not.
// Entries without a []any "supported_reasoning_levels" value are left alone.
func sanitizeCodexClientReasoningMetadata(entry map[string]any) {
	declared, isList := entry["supported_reasoning_levels"].([]any)
	if !isList {
		return
	}

	kept := make([]any, 0, len(declared))
	keptEfforts := make(map[string]struct{}, len(declared))
	for _, item := range declared {
		levelMap, isMap := item.(map[string]any)
		if !isMap {
			continue
		}
		effort := normalizeCodexClientReasoningLevel(stringModelValue(levelMap, "effort"))
		if effort == "" {
			continue
		}
		// Clone before mutating so the source level map is not modified.
		sanitized := cloneCodexClientModelMap(levelMap)
		sanitized["effort"] = effort
		kept = append(kept, sanitized)
		keptEfforts[effort] = struct{}{}
	}

	if len(kept) == 0 {
		delete(entry, "supported_reasoning_levels")
		delete(entry, "default_reasoning_level")
		return
	}

	chosenDefault := normalizeCodexClientReasoningLevel(stringModelValue(entry, "default_reasoning_level"))
	if _, stillValid := keptEfforts[chosenDefault]; !stillValid {
		chosenDefault = stringModelValue(kept[0].(map[string]any), "effort")
	}
	entry["supported_reasoning_levels"] = kept
	entry["default_reasoning_level"] = chosenDefault
}
// normalizeCodexClientReasoningLevel trims and lowercases rawLevel and returns
// it when it is a key of codexClientAllowedReasoningLevels; otherwise it
// returns "".
func normalizeCodexClientReasoningLevel(rawLevel string) string {
	candidate := strings.ToLower(strings.TrimSpace(rawLevel))
	if _, allowed := codexClientAllowedReasoningLevels[candidate]; allowed {
		return candidate
	}
	return ""
}
func codexClientReasoningDescription(level string) string {
switch level {
case "minimal":
return "Fastest responses with minimal reasoning"
case "none":
return "No reasoning"
case "low":
return "Fast responses with lighter reasoning"
case "medium":
@@ -9,6 +9,7 @@ import (
"io"
"mime/multipart"
"net/http"
"net/textproto"
"strconv"
"strings"
"time"
@@ -16,6 +17,7 @@ import (
"github.com/gin-gonic/gin"
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -143,7 +145,20 @@ func isSupportedImagesModel(model string) bool {
if baseModel == defaultImagesToolModel {
return true
}
return isXAIImagesModel(model)
return isXAIImagesModel(model) || isOpenAICompatImagesModel(model)
}
// isDefaultImagesToolModel reports whether the base name of model, as computed
// by imagesModelBase, equals defaultImagesToolModel.
func isDefaultImagesToolModel(model string) bool {
return imagesModelBase(model) == defaultImagesToolModel
}
// isOpenAICompatImagesModel reports whether model is registered with type
// registry.OpenAIImageModelType. Blank names and unknown models return false.
func isOpenAICompatImagesModel(model string) bool {
	name := strings.TrimSpace(model)
	if name == "" {
		return false
	}
	if info := registry.LookupModelInfo(name); info != nil {
		return info.Type == registry.OpenAIImageModelType
	}
	return false
}
func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
@@ -153,7 +168,7 @@ func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, or %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel),
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, %s, or a configured openai-compatibility image model.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel),
Type: "invalid_request_error",
},
})
@@ -376,6 +391,90 @@ func multipartFileToDataURL(fileHeader *multipart.FileHeader) (string, error) {
return "data:" + mediaType + ";base64," + b64, nil
}
// buildOpenAICompatImagesJSONRequest returns rawJSON adjusted for forwarding:
// "model" is set to the trimmed imageModel when that is non-empty, and
// "stream" is set to true when stream is requested or deleted otherwise.
// Errors from the sjson edits are ignored, as in the rest of this handler.
func buildOpenAICompatImagesJSONRequest(rawJSON []byte, imageModel string, stream bool) []byte {
	out := rawJSON
	trimmedModel := strings.TrimSpace(imageModel)
	if trimmedModel != "" {
		out, _ = sjson.SetBytes(out, "model", trimmedModel)
	}
	if !stream {
		out, _ = sjson.DeleteBytes(out, "stream")
		return out
	}
	out, _ = sjson.SetBytes(out, "stream", true)
	return out
}
// cloneMIMEHeader returns a copy of src whose value slices do not share
// backing arrays with the original. The result is always a non-nil map.
func cloneMIMEHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
	clone := make(textproto.MIMEHeader, len(src))
	for name := range src {
		// Appending to a nil slice forces a fresh backing array.
		var vals []string
		clone[name] = append(vals, src[name]...)
	}
	return clone
}
// buildOpenAICompatImagesMultipartRequest re-encodes a parsed multipart form
// for forwarding to an OpenAI-compatible image endpoint.
//
// The "model" field is always written first using imageModel, and "stream" is
// written as "true" only when stream is set; any "model" or "stream" values in
// the incoming form are dropped. All other value fields and every uploaded
// file are copied across, with each file keeping its original part headers
// apart from a rebuilt Content-Disposition and a default Content-Type of
// application/octet-stream when none was supplied.
//
// It returns the encoded body, the matching multipart Content-Type (including
// the boundary), and an error if form is nil or any write/open/copy fails.
func buildOpenAICompatImagesMultipartRequest(form *multipart.Form, imageModel string, stream bool) ([]byte, string, error) {
	if form == nil {
		return nil, "", fmt.Errorf("multipart form is nil")
	}

	var encoded bytes.Buffer
	mw := multipart.NewWriter(&encoded)

	if errField := mw.WriteField("model", imageModel); errField != nil {
		return nil, "", fmt.Errorf("write model field failed: %w", errField)
	}
	if stream {
		if errField := mw.WriteField("stream", "true"); errField != nil {
			return nil, "", fmt.Errorf("write stream field failed: %w", errField)
		}
	}

	for name, vals := range form.Value {
		// These two were already emitted above from the resolved arguments.
		switch name {
		case "model", "stream":
			continue
		}
		for _, v := range vals {
			if errField := mw.WriteField(name, v); errField != nil {
				return nil, "", fmt.Errorf("write form field %s failed: %w", name, errField)
			}
		}
	}

	for field, uploads := range form.File {
		for _, upload := range uploads {
			if upload == nil {
				continue
			}

			partHeader := cloneMIMEHeader(upload.Header)
			partHeader.Set("Content-Disposition", multipart.FileContentDisposition(field, upload.Filename))
			if partHeader.Get("Content-Type") == "" {
				partHeader.Set("Content-Type", "application/octet-stream")
			}

			part, errPart := mw.CreatePart(partHeader)
			if errPart != nil {
				return nil, "", fmt.Errorf("create file field %s failed: %w", field, errPart)
			}
			file, errFile := upload.Open()
			if errFile != nil {
				return nil, "", fmt.Errorf("open upload file failed: %w", errFile)
			}

			// Always close the upload; a close failure only becomes the
			// reported error when the copy itself succeeded.
			_, errTransfer := io.Copy(part, file)
			if errShut := file.Close(); errShut != nil {
				log.Errorf("openai images: close upload file error: %v", errShut)
				if errTransfer == nil {
					errTransfer = errShut
				}
			}
			if errTransfer != nil {
				return nil, "", fmt.Errorf("copy upload file failed: %w", errTransfer)
			}
		}
	}

	if errFinish := mw.Close(); errFinish != nil {
		return nil, "", fmt.Errorf("close multipart writer failed: %w", errFinish)
	}
	return encoded.Bytes(), mw.FormDataContentType(), nil
}
func parseIntField(raw string, fallback int64) int64 {
raw = strings.TrimSpace(raw)
if raw == "" {
@@ -454,11 +553,21 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
}
stream := gjson.GetBytes(rawJSON, "stream").Bool()
if isDefaultImagesToolModel(imageModel) {
imageReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleRoutedImages(c, imageReq, imageModel, stream)
return
}
if isXAIImagesModel(imageModel) {
xaiReq := buildXAIImagesGenerationsRequest(rawJSON, imageModel, responseFormat)
h.handleXAIImages(c, xaiReq, responseFormat, "image_generation", stream)
return
}
if isOpenAICompatImagesModel(imageModel) {
compatReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_generation", stream)
return
}
tool := []byte(`{"type":"image_generation","action":"generate"}`)
tool, _ = sjson.SetBytes(tool, "model", imageModel)
@@ -589,6 +698,21 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
}
stream := parseBoolField(c.PostForm("stream"), false)
if isDefaultImagesToolModel(imageModel) {
imageReq, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, imageModel, stream)
if errBuild != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", errBuild),
Type: "invalid_request_error",
},
})
return
}
c.Request.Header.Set("Content-Type", contentType)
h.handleRoutedImages(c, imageReq, imageModel, stream)
return
}
if isXAIImagesModel(imageModel) {
aspectRatio := xaiImagesAspectRatio(c.PostForm("aspect_ratio"), "")
aspectRatio = xaiImagesAspectRatioFromSize(c.PostForm("size"), aspectRatio)
@@ -598,6 +722,21 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
return
}
if isOpenAICompatImagesModel(imageModel) {
compatReq, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, imageModel, stream)
if errBuild != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", errBuild),
Type: "invalid_request_error",
},
})
return
}
c.Request.Header.Set("Content-Type", contentType)
h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_edit", stream)
return
}
var maskDataURL *string
if maskFiles := form.File["mask"]; len(maskFiles) > 0 && maskFiles[0] != nil {
@@ -701,6 +840,11 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
}
stream := gjson.GetBytes(rawJSON, "stream").Bool()
if isDefaultImagesToolModel(imageModel) {
imageReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleRoutedImages(c, imageReq, imageModel, stream)
return
}
if isXAIImagesModel(imageModel) {
images := collectXAIImagesFromJSON(rawJSON)
if len(images) == 0 {
@@ -717,6 +861,11 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
return
}
if isOpenAICompatImagesModel(imageModel) {
compatReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_edit", stream)
return
}
var images []string
imagesResult := gjson.GetBytes(rawJSON, "images")
@@ -904,14 +1053,247 @@ func (h *OpenAIAPIHandler) handleXAIImages(c *gin.Context, xaiReq []byte, respon
h.collectXAIImages(c, xaiReq, responseFormat)
}
// handleOpenAICompatImages dispatches an OpenAI-compatible image request to
// the streaming or the collecting implementation depending on stream.
// responseFormat is only used by the non-streaming path; streamPrefix is
// accepted but not used here.
func (h *OpenAIAPIHandler) handleOpenAICompatImages(c *gin.Context, compatReq []byte, imageModel string, responseFormat string, streamPrefix string, stream bool) {
	if !stream {
		h.collectImagesWithModel(c, compatReq, imageModel, responseFormat)
		return
	}
	h.streamOpenAICompatImages(c, compatReq, imageModel)
}
// handleRoutedImages dispatches a routed image request to streamRoutedImages
// when stream is set and to collectRoutedImages otherwise.
func (h *OpenAIAPIHandler) handleRoutedImages(c *gin.Context, imageReq []byte, imageModel string, stream bool) {
	if !stream {
		h.collectRoutedImages(c, imageReq, imageModel)
		return
	}
	h.streamRoutedImages(c, imageReq, imageModel)
}
// collectRoutedImages runs a non-streaming image request through
// ExecuteImageWithAuthManager and writes the upstream response body verbatim.
//
// The execution context is marked with handlers.WithDisallowFreeAuth, and a
// non-streaming keep-alive runs for the duration of the upstream call. On
// failure the error is written with WriteErrorResponse and passed to the
// cancel function; on success upstream headers are copied before the body.
func (h *OpenAIAPIHandler) collectRoutedImages(c *gin.Context, imageReq []byte, imageModel string) {
	c.Header("Content-Type", "application/json")

	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	cliCtx = handlers.WithDisallowFreeAuth(cliCtx)

	stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
	trimmedModel := strings.TrimSpace(imageModel)
	resp, upstreamHeaders, errMsg := h.ExecuteImageWithAuthManager(cliCtx, xaiImagesHandlerType, trimmedModel, imageReq, "")
	stopKeepAlive()

	if errMsg == nil {
		handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
		_, _ = c.Writer.Write(resp)
		cliCancel(nil)
		return
	}

	h.WriteErrorResponse(c, errMsg)
	if errMsg.Error == nil {
		cliCancel(nil)
		return
	}
	cliCancel(errMsg.Error)
}
// streamRoutedImages runs a streaming image request through
// ExecuteImageStreamWithAuthManager and relays the upstream chunks to the
// client unchanged.
//
// SSE headers are deliberately not set until the data channel first yields a
// chunk or closes, so an error that arrives before any data can still be
// written with WriteErrorResponse. After the first chunk is written, the rest
// of the stream is handed to forwardRawImageStream.
func (h *OpenAIAPIHandler) streamRoutedImages(c *gin.Context, imageReq []byte, imageModel string) {
// Streaming requires a writer that can flush; otherwise answer with a 500.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
cliCtx = handlers.WithDisallowFreeAuth(cliCtx)
model := strings.TrimSpace(imageModel)
dataChan, upstreamHeaders, errChan := h.ExecuteImageStreamWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
// setSSEHeaders applies the event-stream response headers; it is only
// called from the data branch below.
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Wait for the first event: client disconnect, upstream error, or data.
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case errMsg, ok := <-errChan:
if !ok {
// A closed error channel is disabled (nil never becomes ready) so the
// loop keeps waiting on the remaining cases.
errChan = nil
continue
}
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
// The stream ended without any chunk: send headers and a single newline.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
cliCancel(nil)
return
}
// First chunk: commit the SSE headers, write it, then forward the rest.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(chunk)
flusher.Flush()
h.forwardRawImageStream(cliCtx, c, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
}
}
// forwardRawImageStream copies chunks from data to the response writer until
// the stream ends, an error arrives, or either context is cancelled.
//
// Chunks are written exactly as received and flushed after each write. A
// non-nil message on errs is emitted as an SSE "event: error" frame before
// cancel is called with its Error. cancel is invoked exactly once on every
// return path: with the context error on cancellation, with the upstream
// error on failure, and with nil when data is closed.
func (h *OpenAIAPIHandler) forwardRawImageStream(ctx context.Context, c *gin.Context, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
// emitError writes errMsg to the client as an SSE error event. The status
// defaults to 500 and the text defaults to the status text when the error
// is missing or blank.
emitError := func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
errText := http.StatusText(status)
if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
errText = errMsg.Error.Error()
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case <-ctx.Done():
cancel(ctx.Err())
return
case errMsg, ok := <-errs:
if ok && errMsg != nil {
emitError(errMsg)
cancel(errMsg.Error)
return
}
// Closed channel or nil message: stop selecting on errs and keep
// draining data.
errs = nil
case chunk, ok := <-data:
if !ok {
cancel(nil)
return
}
_, _ = c.Writer.Write(chunk)
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
}
}
// streamOpenAICompatImages runs a streaming image request for an
// OpenAI-compatible image model through ExecuteStreamWithAuthManager and
// relays the upstream chunks to the client unchanged.
//
// SSE headers are only set once the data channel first yields a chunk or
// closes, so an error that arrives before any data can still be written with
// WriteErrorResponse. After the first chunk, the remainder is forwarded with
// ForwardStream, which reports a terminal error as an SSE "event: error" frame.
func (h *OpenAIAPIHandler) streamOpenAICompatImages(c *gin.Context, compatReq []byte, imageModel string) {
// Streaming requires a writer that can flush; otherwise answer with a 500.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
model := strings.TrimSpace(imageModel)
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, xaiImagesHandlerType, model, compatReq, "")
// setSSEHeaders applies the event-stream response headers; it is only
// called from the data branch below.
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Wait for the first event: client disconnect, upstream error, or data.
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case errMsg, ok := <-errChan:
if !ok {
// A closed error channel is disabled (nil never becomes ready) so the
// loop keeps waiting on the remaining cases.
errChan = nil
continue
}
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
// The stream ended without any chunk: send headers only, no body bytes.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
flusher.Flush()
cliCancel(nil)
return
}
// First chunk: commit the SSE headers, write it, then forward the rest.
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(chunk)
flusher.Flush()
h.ForwardStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, handlers.StreamForwardOptions{
WriteChunk: func(next []byte) {
_, _ = c.Writer.Write(next)
},
// The status defaults to 500 and the text to the status text when
// the error is missing or has an empty message.
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
errText := http.StatusText(status)
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
},
})
return
}
}
}
// collectXAIImages handles a non-streaming xAI image request. The model name
// is read (and trimmed) from the "model" field of xaiReq and the work is
// delegated to collectImagesWithModel.
func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, responseFormat string) {
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
h.collectImagesWithModel(c, xaiReq, model, responseFormat)
}
func (h *OpenAIAPIHandler) collectImagesWithModel(c *gin.Context, imageReq []byte, model string, responseFormat string) {
c.Header("Content-Type", "application/json")
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "")
model = strings.TrimSpace(model)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
@@ -937,6 +1319,11 @@ func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, respo
}
// streamXAIImages handles a streaming xAI image request. The model name is
// read (and trimmed) from the "model" field of xaiReq and the work is
// delegated to streamImagesWithModel.
func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, responseFormat string, streamPrefix string) {
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
h.streamImagesWithModel(c, xaiReq, model, responseFormat, streamPrefix)
}
func (h *OpenAIAPIHandler) streamImagesWithModel(c *gin.Context, imageReq []byte, model string, responseFormat string, streamPrefix string) {
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
@@ -949,8 +1336,8 @@ func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, respon
}
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "")
model = strings.TrimSpace(model)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
if errMsg.Error != nil {
@@ -3,14 +3,17 @@ package openai
import (
"bytes"
"io"
"mime"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/textproto"
"strings"
"testing"
"github.com/gin-gonic/gin"
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
"github.com/tidwall/gjson"
@@ -40,7 +43,7 @@ func assertUnsupportedImagesModelResponse(t *testing.T, resp *httptest.ResponseR
}
message := gjson.GetBytes(resp.Body.Bytes(), "error.message").String()
expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", or " + xaiImagesQualityModel + "."
expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", " + xaiImagesQualityModel + ", or a configured openai-compatibility image model."
if message != expectedMessage {
t.Fatalf("error message = %q, want %q", message, expectedMessage)
}
@@ -63,6 +66,25 @@ func TestImagesModelValidationAllowsGPTImage2AndXAIModels(t *testing.T) {
}
}
// TestImagesModelValidationAllowsOpenAICompatImageModels registers one
// image-typed and one chat-typed model under an openai-compatibility client
// and checks that isSupportedImagesModel accepts only the image-typed one.
func TestImagesModelValidationAllowsOpenAICompatImageModels(t *testing.T) {
	const clientID = "test-openai-compat-image-model-validation"

	reg := registry.GetGlobalRegistry()
	registered := []*registry.ModelInfo{
		{ID: "compat-image-model", Object: "model", OwnedBy: "compat", Type: registry.OpenAIImageModelType},
		{ID: "compat-chat-model", Object: "model", OwnedBy: "compat", Type: "openai-compatibility"},
	}
	reg.RegisterClient(clientID, "openai-compatibility", registered)
	t.Cleanup(func() { reg.UnregisterClient(clientID) })

	if supported := isSupportedImagesModel("compat-image-model"); !supported {
		t.Fatal("expected configured openai-compatibility image model to be supported")
	}
	if supported := isSupportedImagesModel("compat-chat-model"); supported {
		t.Fatal("expected non-image openai-compatibility model to be rejected")
	}
}
func TestBuildXAIImagesGenerationsRequest(t *testing.T) {
rawJSON := []byte(`{"model":"xai/grok-imagine-image-quality","prompt":"abstract art","aspect_ratio":"landscape","resolution":"2k","n":2,"response_format":"url"}`)
@@ -122,6 +144,100 @@ func TestBuildXAIImagesEditRequestSingleImage(t *testing.T) {
}
}
// TestBuildOpenAICompatImagesJSONRequestPreservesStreamForStreaming checks
// that a streaming build rewrites the model to the upstream name and yields a
// truthy "stream" flag even though the client body carried stream=false.
func TestBuildOpenAICompatImagesJSONRequestPreservesStreamForStreaming(t *testing.T) {
	clientBody := []byte(`{"model":"compat-image","prompt":"draw","stream":false}`)
	req := buildOpenAICompatImagesJSONRequest(clientBody, "upstream-image", true)

	got := gjson.GetBytes(req, "model").String()
	if got != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image; body=%s", got, string(req))
	}
	if streaming := gjson.GetBytes(req, "stream").Bool(); !streaming {
		t.Fatalf("stream flag missing: %s", string(req))
	}
}
// TestBuildOpenAICompatImagesJSONRequestDropsStreamForNonStreaming checks
// that a non-streaming build rewrites the model to the upstream name and
// removes the "stream" key even though the client body carried stream=true.
func TestBuildOpenAICompatImagesJSONRequestDropsStreamForNonStreaming(t *testing.T) {
	clientBody := []byte(`{"model":"compat-image","prompt":"draw","stream":true}`)
	req := buildOpenAICompatImagesJSONRequest(clientBody, "upstream-image", false)

	got := gjson.GetBytes(req, "model").String()
	if got != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image; body=%s", got, string(req))
	}
	if streamField := gjson.GetBytes(req, "stream"); streamField.Exists() {
		t.Fatalf("stream flag should be removed from non-streaming request: %s", string(req))
	}
}
// TestBuildOpenAICompatImagesMultipartRequestPreservesStreamAndFileContentType
// sends a multipart form (model, stream=false, prompt, and a PNG file part)
// through buildOpenAICompatImagesMultipartRequest in streaming mode and checks
// the rewritten form: model replaced by the upstream name, stream set to
// "true", prompt unchanged, and the file part still labelled image/png.
func TestBuildOpenAICompatImagesMultipartRequestPreservesStreamAndFileContentType(t *testing.T) {
	// Assemble the client-side form.
	var source bytes.Buffer
	formWriter := multipart.NewWriter(&source)
	if errWrite := formWriter.WriteField("model", "compat-image"); errWrite != nil {
		t.Fatalf("write model field: %v", errWrite)
	}
	if errWrite := formWriter.WriteField("stream", "false"); errWrite != nil {
		t.Fatalf("write stream field: %v", errWrite)
	}
	if errWrite := formWriter.WriteField("prompt", "edit"); errWrite != nil {
		t.Fatalf("write prompt field: %v", errWrite)
	}

	// The file part is created by hand so it carries an explicit Content-Type.
	imageHeader := textproto.MIMEHeader{}
	imageHeader.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.png"))
	imageHeader.Set("Content-Type", "image/png")
	imagePart, errCreate := formWriter.CreatePart(imageHeader)
	if errCreate != nil {
		t.Fatalf("create image field: %v", errCreate)
	}
	if _, errWrite := imagePart.Write([]byte("png-data")); errWrite != nil {
		t.Fatalf("write image field: %v", errWrite)
	}
	if errClose := formWriter.Close(); errClose != nil {
		t.Fatalf("close multipart writer: %v", errClose)
	}

	// Parse the bytes back into the form the builder takes as input.
	sourceReader := multipart.NewReader(bytes.NewReader(source.Bytes()), formWriter.Boundary())
	sourceForm, errSource := sourceReader.ReadForm(32 << 20)
	if errSource != nil {
		t.Fatalf("read source form: %v", errSource)
	}
	defer func() {
		if errRemove := sourceForm.RemoveAll(); errRemove != nil {
			t.Fatalf("remove source form files: %v", errRemove)
		}
	}()

	rewritten, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(sourceForm, "upstream-image", true)
	if errBuild != nil {
		t.Fatalf("buildOpenAICompatImagesMultipartRequest error: %v", errBuild)
	}

	mediaType, params, errParse := mime.ParseMediaType(contentType)
	if errParse != nil {
		t.Fatalf("parse content type: %v", errParse)
	}
	if mediaType != "multipart/form-data" {
		t.Fatalf("media type = %q, want multipart/form-data", mediaType)
	}

	// Parse the builder's output using the boundary it advertised.
	outReader := multipart.NewReader(bytes.NewReader(rewritten), params["boundary"])
	outForm, errOut := outReader.ReadForm(32 << 20)
	if errOut != nil {
		t.Fatalf("read rewritten form: %v", errOut)
	}
	defer func() {
		if errRemove := outForm.RemoveAll(); errRemove != nil {
			t.Fatalf("remove rewritten form files: %v", errRemove)
		}
	}()

	if got := outForm.Value["model"]; !(len(got) == 1 && got[0] == "upstream-image") {
		t.Fatalf("model values = %#v, want upstream-image", got)
	}
	if got := outForm.Value["stream"]; !(len(got) == 1 && got[0] == "true") {
		t.Fatalf("stream values = %#v, want true", got)
	}
	if got := outForm.Value["prompt"]; !(len(got) == 1 && got[0] == "edit") {
		t.Fatalf("prompt values = %#v, want edit", got)
	}
	if got := outForm.File["image"]; !(len(got) == 1 && got[0].Header.Get("Content-Type") == "image/png") {
		t.Fatalf("image headers = %#v, want image/png", got)
	}
}
func TestBuildImagesAPIResponseFromXAI(t *testing.T) {
payload := []byte(`{"created":123,"data":[{"b64_json":"AA==","revised_prompt":"refined","mime_type":"image/png"}],"usage":{"total_tokens":0}}`)
+6 -3
View File
@@ -177,12 +177,15 @@ waitForCallback:
if accessToken != "" {
fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
if errProject != nil {
log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
return nil, fmt.Errorf("antigravity: failed to fetch project ID: %w", errProject)
} else {
projectID = fetchedProjectID
log.Infof("antigravity: obtained project ID %s", projectID)
log.Infof("antigravity: obtained project ID %s", util.HideAPIKey(projectID))
}
}
if strings.TrimSpace(projectID) == "" {
return nil, fmt.Errorf("antigravity: project ID discovery returned empty project")
}
now := time.Now()
metadata := map[string]any{
@@ -208,7 +211,7 @@ waitForCallback:
fmt.Println("Antigravity authentication successful")
if projectID != "" {
fmt.Printf("Using GCP project: %s\n", projectID)
fmt.Printf("Using GCP project: %s\n", util.HideAPIKey(projectID))
}
return &coreauth.Auth{
ID: fileName,
@@ -4,12 +4,14 @@ import (
"context"
"fmt"
"net/http"
"strings"
"testing"
"time"
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
)
type antigravityCreditsFallbackExecutor struct {
@@ -48,6 +50,43 @@ func (e *antigravityCreditsFallbackExecutor) HttpRequest(context.Context, *Auth,
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
}
type codexOnlyFailureExecutor struct{}
func (codexOnlyFailureExecutor) Identifier() string { return "codex" }
func (codexOnlyFailureExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusTooManyRequests, Message: "codex quota exhausted"}
}
func (codexOnlyFailureExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
return nil, &Error{HTTPStatus: http.StatusTooManyRequests, Message: "codex quota exhausted"}
}
func (codexOnlyFailureExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (codexOnlyFailureExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusTooManyRequests, Message: "codex quota exhausted"}
}
func (codexOnlyFailureExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
return nil, &Error{HTTPStatus: http.StatusTooManyRequests, Message: "codex quota exhausted"}
}
// captureLogHook is a log hook that records the message of every entry it is
// fired with, so a test can inspect what was logged.
type captureLogHook struct {
	messages []string
}

// Levels subscribes the hook to all log levels.
func (h *captureLogHook) Levels() []log.Level { return log.AllLevels }

// Fire appends the entry's message to the captured list and never fails.
func (h *captureLogHook) Fire(entry *log.Entry) error {
	captured := append(h.messages, entry.Message)
	h.messages = captured
	return nil
}
func TestManagerExecuteStream_AntigravityCreditsFallbackAfterBootstrap429(t *testing.T) {
const model = "claude-opus-4-6-thinking"
executor := &antigravityCreditsFallbackExecutor{}
@@ -88,6 +127,51 @@ func TestManagerExecuteStream_AntigravityCreditsFallbackAfterBootstrap429(t *tes
}
}
// TestManagerExecuteStream_CodexOnlyDoesNotEnterAntigravityCreditsFallback
// runs a failing codex-only stream request while credits fallback is enabled
// and an unrelated antigravity auth is registered, then checks that no
// captured log message mentions shouldAttemptAntigravityCreditsFallback.
func TestManagerExecuteStream_CodexOnlyDoesNotEnterAntigravityCreditsFallback(t *testing.T) {
	const model = "gpt-5.5"

	// Swap the standard logger's hooks and level for the duration of the test
	// so every message, debug included, lands in the capture hook.
	stdLogger := log.StandardLogger()
	savedLevel := stdLogger.GetLevel()
	savedHooks := stdLogger.ReplaceHooks(make(log.LevelHooks))
	captured := &captureLogHook{}
	stdLogger.SetLevel(log.DebugLevel)
	stdLogger.AddHook(captured)
	t.Cleanup(func() {
		stdLogger.SetLevel(savedLevel)
		stdLogger.ReplaceHooks(savedHooks)
	})

	cfg := &internalconfig.Config{
		QuotaExceeded: internalconfig.QuotaExceeded{AntigravityCredits: true},
	}
	manager := NewManager(nil, nil, nil)
	manager.SetConfig(cfg)
	manager.RegisterExecutor(codexOnlyFailureExecutor{})
	manager.RegisterExecutor(&antigravityCreditsFallbackExecutor{})

	modelRegistry := registry.GetGlobalRegistry()
	modelRegistry.RegisterClient("codex-only", "codex", []*registry.ModelInfo{{ID: model}})
	modelRegistry.RegisterClient("ag-unrelated", "antigravity", []*registry.ModelInfo{{ID: "gemini-3-flash"}})
	t.Cleanup(func() {
		modelRegistry.UnregisterClient("codex-only")
		modelRegistry.UnregisterClient("ag-unrelated")
	})

	codexAuth := &Auth{ID: "codex-only", Provider: "codex"}
	if _, errRegister := manager.Register(context.Background(), codexAuth); errRegister != nil {
		t.Fatalf("register codex auth: %v", errRegister)
	}
	antigravityAuth := &Auth{ID: "ag-unrelated", Provider: "antigravity"}
	if _, errRegister := manager.Register(context.Background(), antigravityAuth); errRegister != nil {
		t.Fatalf("register antigravity auth: %v", errRegister)
	}

	request := cliproxyexecutor.Request{Model: model}
	if _, errExecute := manager.ExecuteStream(context.Background(), []string{"codex"}, request, cliproxyexecutor.Options{}); errExecute == nil {
		t.Fatal("expected codex execution failure")
	}

	for i := range captured.messages {
		if !strings.Contains(captured.messages[i], "shouldAttemptAntigravityCreditsFallback") {
			continue
		}
		t.Fatalf("codex-only request entered antigravity credits fallback gate; messages=%v", captured.messages)
	}
}
func TestStatusCodeFromError_UnwrapsStreamBootstrap429(t *testing.T) {
bootstrapErr := newStreamBootstrapError(&Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota exhausted"}, nil)
wrappedErr := fmt.Errorf("conductor stream failed: %w", bootstrapErr)
+142 -15
View File
@@ -45,6 +45,13 @@ type ProviderExecutor interface {
HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
}
// RequestAuthPreparer lets an executor update missing auth metadata immediately
// before a request. Manager serializes and persists returned updates.
//
// Manager.prepareRequestAuth only consults executors that implement this
// interface. Preparation for auths with a non-blank ID is serialized per ID.
type RequestAuthPreparer interface {
// ShouldPrepareRequestAuth reports whether auth still needs preparation.
// Manager calls it before preparing and again, under the per-auth lock, on
// its current copy of the auth; PrepareRequestAuth is skipped when it
// returns false.
ShouldPrepareRequestAuth(auth *Auth) bool
// PrepareRequestAuth returns the updated auth. Manager passes it a clone.
// On the locked path a nil result with a nil error makes Manager use the
// auth it passed in, and a non-nil result is handed to Manager.Update.
PrepareRequestAuth(ctx context.Context, auth *Auth) (*Auth, error)
}
// ExecutionSessionCloser allows executors to release per-session runtime resources.
type ExecutionSessionCloser interface {
CloseExecutionSession(sessionID string)
@@ -182,6 +189,8 @@ type Manager struct {
// Auto refresh state
refreshCancel context.CancelFunc
refreshLoop *authAutoRefreshLoop
requestPrepareLocks sync.Map
}
// NewManager constructs a manager with optional custom selector and hook.
@@ -1238,7 +1247,7 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
}
}
if lastErr != nil {
if shouldAttemptAntigravityCreditsFallback(m, lastErr, normalized) {
if hasAntigravityProvider(normalized) && shouldAttemptAntigravityCreditsFallback(m, lastErr, normalized) {
if resp, ok := m.tryAntigravityCreditsExecute(ctx, req, opts); ok {
return resp, nil
}
@@ -1304,7 +1313,7 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
}
}
if lastErr != nil {
if shouldAttemptAntigravityCreditsFallback(m, lastErr, normalized) {
if hasAntigravityProvider(normalized) && shouldAttemptAntigravityCreditsFallback(m, lastErr, normalized) {
if result, ok := m.tryAntigravityCreditsExecuteStream(ctx, req, opts); ok {
return result, nil
}
@@ -1365,6 +1374,17 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
continue
}
attempted[auth.ID] = struct{}{}
var errPrepare error
auth, errPrepare = m.prepareRequestAuth(execCtx, executor, auth)
if errPrepare != nil {
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: &Error{Message: errPrepare.Error()}}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errPrepare); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
m.MarkResult(execCtx, result)
lastErr = errPrepare
continue
}
var authErr error
for _, upstreamModel := range models {
resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled)
@@ -1453,6 +1473,17 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
continue
}
attempted[auth.ID] = struct{}{}
var errPrepare error
auth, errPrepare = m.prepareRequestAuth(execCtx, executor, auth)
if errPrepare != nil {
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: &Error{Message: errPrepare.Error()}}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errPrepare); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
m.MarkResult(execCtx, result)
lastErr = errPrepare
continue
}
var authErr error
for _, upstreamModel := range models {
resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled)
@@ -1539,6 +1570,17 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
continue
}
attempted[auth.ID] = struct{}{}
var errPrepare error
auth, errPrepare = m.prepareRequestAuth(execCtx, executor, auth)
if errPrepare != nil {
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: &Error{Message: errPrepare.Error()}}
if se, ok := errors.AsType[cliproxyexecutor.StatusError](errPrepare); ok && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
m.MarkResult(execCtx, result)
lastErr = errPrepare
continue
}
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
@@ -1630,9 +1672,69 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
}
}
// requestAuthPrepareLock is the value stored in Manager.requestPrepareLocks,
// keyed by trimmed auth ID. prepareRequestAuth holds mu while it prepares and
// updates that auth.
type requestAuthPrepareLock struct {
mu sync.Mutex
}
// prepareRequestAuth lets an executor that implements RequestAuthPreparer fill
// in missing auth metadata immediately before a request is executed.
//
// auth is returned unchanged when the manager, executor, or auth is nil, when
// the executor is not a RequestAuthPreparer, or when the preparer reports that
// no preparation is needed. Otherwise preparation is serialized per trimmed
// auth ID: under the lock the manager's current copy of the auth is re-checked,
// so a request that waited on the lock reuses what an earlier request already
// prepared. A non-nil prepared auth is stored through m.Update and the stored
// copy is returned.
//
// Auths with a blank ID (or an unexpected value in the lock map) are prepared
// without locking and without being stored.
//
// For a non-nil auth the returned *Auth is non-nil even when an error is
// returned (assuming Auth.Clone does not return nil for a non-nil receiver —
// TODO confirm). Callers read auth.ID from the result to record the failure,
// and a preparer may follow the usual convention of returning (nil, err).
func (m *Manager) prepareRequestAuth(ctx context.Context, executor ProviderExecutor, auth *Auth) (*Auth, error) {
	if m == nil || executor == nil || auth == nil {
		return auth, nil
	}
	preparer, ok := executor.(RequestAuthPreparer)
	if !ok || preparer == nil || !preparer.ShouldPrepareRequestAuth(auth) {
		return auth, nil
	}

	// prepareUnlocked covers auths that cannot be serialized. It normalizes a
	// nil result and an error the same way the locked path below does, so the
	// caller never receives a nil auth.
	prepareUnlocked := func() (*Auth, error) {
		candidate := auth.Clone()
		updated, errPrepare := preparer.PrepareRequestAuth(ctx, candidate)
		if errPrepare != nil {
			return auth, errPrepare
		}
		if updated == nil {
			return candidate, nil
		}
		return updated, nil
	}

	id := strings.TrimSpace(auth.ID)
	if id == "" {
		return prepareUnlocked()
	}
	lockValue, _ := m.requestPrepareLocks.LoadOrStore(id, &requestAuthPrepareLock{})
	lock, ok := lockValue.(*requestAuthPrepareLock)
	if !ok || lock == nil {
		return prepareUnlocked()
	}
	lock.mu.Lock()
	defer lock.mu.Unlock()

	// Prefer the manager's current copy: another request may have prepared
	// and stored this auth while we were waiting for the lock.
	target := auth.Clone()
	m.mu.RLock()
	if current := m.auths[id]; current != nil {
		target = current.Clone()
	}
	m.mu.RUnlock()
	if !preparer.ShouldPrepareRequestAuth(target) {
		return target, nil
	}

	updated, errPrepare := preparer.PrepareRequestAuth(ctx, target)
	if errPrepare != nil {
		return auth, errPrepare
	}
	if updated == nil {
		return target, nil
	}
	saved, errUpdate := m.Update(ctx, updated)
	if errUpdate != nil {
		return updated, errUpdate
	}
	if saved != nil {
		return saved, nil
	}
	return updated, nil
}
// contextWithRequestedModelAlias returns ctx annotated for usage reporting.
// It always stores the alias resolved by requestedModelAliasFromOptions from
// opts and fallback, and additionally stores the reasoning effort when opts
// metadata carries a non-empty one.
func contextWithRequestedModelAlias(ctx context.Context, opts cliproxyexecutor.Options, fallback string) context.Context {
	alias := requestedModelAliasFromOptions(opts, fallback)
	// Reassign rather than return here: returning early would skip the
	// reasoning-effort annotation below.
	ctx = coreusage.WithRequestedModelAlias(ctx, alias)
	if effort := reasoningEffortFromOptions(opts); effort != "" {
		ctx = coreusage.WithReasoningEffort(ctx, effort)
	}
	return ctx
}
func requestedModelAliasFromOptions(opts cliproxyexecutor.Options, fallback string) string {
@@ -1660,6 +1762,24 @@ func requestedModelAliasFromOptions(opts cliproxyexecutor.Options, fallback stri
}
}
// reasoningEffortFromOptions returns the trimmed reasoning effort stored in
// opts.Metadata under ReasoningEffortMetadataKey. It returns "" when the
// metadata is empty, the key is absent or nil, or the value is neither a
// string nor a []byte.
func reasoningEffortFromOptions(opts cliproxyexecutor.Options) string {
	if len(opts.Metadata) == 0 {
		return ""
	}
	stored, found := opts.Metadata[cliproxyexecutor.ReasoningEffortMetadataKey]
	if !found || stored == nil {
		return ""
	}
	if text, isString := stored.(string); isString {
		return strings.TrimSpace(text)
	}
	if data, isBytes := stored.([]byte); isBytes {
		return strings.TrimSpace(string(data))
	}
	return ""
}
func pinnedAuthIDFromMetadata(meta map[string]any) string {
if len(meta) == 0 {
return ""
@@ -3587,6 +3707,15 @@ type creditsCandidateEntry struct {
provider string
}
// hasAntigravityProvider reports whether providers contains "antigravity",
// ignoring letter case and surrounding whitespace.
func hasAntigravityProvider(providers []string) bool {
	for i := range providers {
		name := strings.TrimSpace(providers[i])
		if !strings.EqualFold(name, "antigravity") {
			continue
		}
		return true
	}
	return false
}
func shouldAttemptAntigravityCreditsFallback(m *Manager, lastErr error, providers []string) bool {
status := statusCodeFromError(lastErr)
log.WithFields(log.Fields{
@@ -3597,18 +3726,6 @@ func shouldAttemptAntigravityCreditsFallback(m *Manager, lastErr error, provider
if m == nil || lastErr == nil {
return false
}
if len(providers) > 0 {
hasAntigravity := false
for _, p := range providers {
if strings.EqualFold(strings.TrimSpace(p), "antigravity") {
hasAntigravity = true
break
}
}
if !hasAntigravity {
return false
}
}
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
if cfg == nil || !cfg.QuotaExceeded.AntigravityCredits {
return false
@@ -3645,6 +3762,11 @@ func (m *Manager) tryAntigravityCreditsExecute(ctx context.Context, req cliproxy
}
creditsOpts := ensureRequestedModelMetadata(opts, routeModel)
creditsCtx = contextWithRequestedModelAlias(creditsCtx, creditsOpts, routeModel)
preparedAuth, errPrepare := m.prepareRequestAuth(creditsCtx, c.executor, c.auth)
if errPrepare != nil {
continue
}
c.auth = preparedAuth
publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID)
models := m.executionModelCandidates(c.auth, routeModel)
if len(models) == 0 {
@@ -3687,6 +3809,11 @@ func (m *Manager) tryAntigravityCreditsExecuteStream(ctx context.Context, req cl
creditsCtx = context.WithValue(creditsCtx, "cliproxy.roundtripper", rt)
}
creditsOpts := ensureRequestedModelMetadata(opts, routeModel)
preparedAuth, errPrepare := m.prepareRequestAuth(creditsCtx, c.executor, c.auth)
if errPrepare != nil {
continue
}
c.auth = preparedAuth
publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID)
models := m.executionModelCandidates(c.auth, routeModel)
if len(models) == 0 {
+25
View File
@@ -0,0 +1,25 @@
package auth
import (
"context"
"testing"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
coreusage "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage"
)
func TestContextWithRequestedModelAliasIncludesReasoningEffort(t *testing.T) {
ctx := contextWithRequestedModelAlias(context.Background(), cliproxyexecutor.Options{
Metadata: map[string]any{
cliproxyexecutor.RequestedModelMetadataKey: "client-model",
cliproxyexecutor.ReasoningEffortMetadataKey: "medium",
},
}, "fallback-model")
if got := coreusage.RequestedModelAliasFromContext(ctx); got != "client-model" {
t.Fatalf("requested model alias = %q, want %q", got, "client-model")
}
if got := coreusage.ReasoningEffortFromContext(ctx); got != "medium" {
t.Fatalf("reasoning effort = %q, want %q", got, "medium")
}
}
@@ -0,0 +1,146 @@
package auth
import (
"context"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
)
// requestPrepareStore is a test store that counts Save calls and keeps a
// clone of the most recently saved auth. List and Delete do nothing.
type requestPrepareStore struct {
	saveCount atomic.Int32

	mu   sync.Mutex
	last *Auth
}

// List reports no stored auths.
func (s *requestPrepareStore) List(_ context.Context) ([]*Auth, error) {
	return nil, nil
}

// Save bumps the save counter and records a clone of auth as the latest save.
func (s *requestPrepareStore) Save(_ context.Context, auth *Auth) (string, error) {
	s.saveCount.Add(1)

	s.mu.Lock()
	s.last = auth.Clone()
	s.mu.Unlock()
	return "", nil
}

// Delete is a no-op.
func (s *requestPrepareStore) Delete(_ context.Context, _ string) error {
	return nil
}

// lastAuth returns a clone of the most recently saved auth.
func (s *requestPrepareStore) lastAuth() *Auth {
	s.mu.Lock()
	latest := s.last.Clone()
	s.mu.Unlock()
	return latest
}
// requestPrepareExecutor is a test executor registered as "antigravity". It
// implements the request-auth preparation hooks by filling in a fixed
// project_id, and its Execute succeeds only when that project_id is present.
// Both kinds of call are counted.
type requestPrepareExecutor struct {
	prepareCalls atomic.Int32
	executeCalls atomic.Int32
}

// Identifier returns the provider key this executor is registered under.
func (e *requestPrepareExecutor) Identifier() string {
	return "antigravity"
}

// ShouldPrepareRequestAuth reports true while the auth has no usable
// project_id in its metadata (a nil auth or nil metadata counts as missing).
func (e *requestPrepareExecutor) ShouldPrepareRequestAuth(auth *Auth) bool {
	if auth == nil || auth.Metadata == nil {
		return true
	}
	return testStringValue(auth.Metadata["project_id"]) == ""
}

// PrepareRequestAuth counts the call and returns a clone of auth whose
// metadata carries project_id "prepared-project".
func (e *requestPrepareExecutor) PrepareRequestAuth(_ context.Context, auth *Auth) (*Auth, error) {
	e.prepareCalls.Add(1)
	prepared := auth.Clone()
	if prepared.Metadata == nil {
		prepared.Metadata = map[string]any{}
	}
	prepared.Metadata["project_id"] = "prepared-project"
	return prepared, nil
}

// Execute counts the call and answers "ok" when the auth carries the prepared
// project_id; otherwise it fails with HTTP 400.
func (e *requestPrepareExecutor) Execute(_ context.Context, auth *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	e.executeCalls.Add(1)
	if testStringValue(auth.Metadata["project_id"]) == "prepared-project" {
		return cliproxyexecutor.Response{Payload: []byte("ok")}, nil
	}
	return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusBadRequest, Message: "missing prepared project"}
}

// ExecuteStream is not implemented by this test executor.
func (e *requestPrepareExecutor) ExecuteStream(_ context.Context, _ *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "stream not implemented"}
}

// Refresh returns the auth unchanged.
func (e *requestPrepareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

// CountTokens is not implemented by this test executor.
func (e *requestPrepareExecutor) CountTokens(_ context.Context, _ *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusNotImplemented, Message: "count not implemented"}
}

// HttpRequest is not implemented by this test executor.
func (e *requestPrepareExecutor) HttpRequest(_ context.Context, _ *Auth, _ *http.Request) (*http.Response, error) {
	return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "http not implemented"}
}
// TestManagerExecute_PreparesAndPersistsMissingRequestAuthMetadata registers
// an auth without project_id and executes twice. The first Execute must
// prepare the auth exactly once, save it to the store, and update the
// manager's copy; the second Execute must not prepare again.
func TestManagerExecute_PreparesAndPersistsMissingRequestAuthMetadata(t *testing.T) {
	const model = "gemini-3.1-pro"

	store := &requestPrepareStore{}
	executor := &requestPrepareExecutor{}
	manager := NewManager(store, nil, nil)
	manager.RegisterExecutor(executor)

	auth := &Auth{
		ID:       "auth-request-prepare",
		Provider: "antigravity",
		Metadata: map[string]any{"access_token": "token"},
	}
	registerCtx := WithSkipPersist(context.Background())
	if _, errRegister := manager.Register(registerCtx, auth); errRegister != nil {
		t.Fatalf("register auth: %v", errRegister)
	}
	registry.GetGlobalRegistry().RegisterClient(auth.ID, "antigravity", []*registry.ModelInfo{{ID: model}})
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(auth.ID)
	})

	// execute issues the same antigravity request each time it is called.
	execute := func() (cliproxyexecutor.Response, error) {
		return manager.Execute(context.Background(), []string{"antigravity"}, cliproxyexecutor.Request{Model: model}, cliproxyexecutor.Options{})
	}

	resp, errExecute := execute()
	if errExecute != nil {
		t.Fatalf("Execute error: %v", errExecute)
	}
	if payload := string(resp.Payload); payload != "ok" {
		t.Fatalf("payload = %q, want ok", payload)
	}
	if calls := executor.prepareCalls.Load(); calls != 1 {
		t.Fatalf("prepare calls = %d, want 1", calls)
	}
	if saves := store.saveCount.Load(); saves < 1 {
		t.Fatalf("save count = %d, want at least 1", saves)
	}
	if persisted := testStringValue(store.lastAuth().Metadata["project_id"]); persisted != "prepared-project" {
		t.Fatalf("persisted project_id = %q, want prepared-project", persisted)
	}

	current, found := manager.GetByID(auth.ID)
	if !found {
		t.Fatal("expected auth in manager")
	}
	if inManager := testStringValue(current.Metadata["project_id"]); inManager != "prepared-project" {
		t.Fatalf("manager project_id = %q, want prepared-project", inManager)
	}

	// A second request must reuse the stored project_id without preparing again.
	if _, errExecute = execute(); errExecute != nil {
		t.Fatalf("second Execute error: %v", errExecute)
	}
	if calls := executor.prepareCalls.Load(); calls != 1 {
		t.Fatalf("prepare calls after second execute = %d, want 1", calls)
	}
}
// testStringValue returns value as a whitespace-trimmed string when it holds a
// string or a []byte, and "" for nil or any other type.
func testStringValue(value any) string {
	if text, isString := value.(string); isString {
		return strings.TrimSpace(text)
	}
	if data, isBytes := value.([]byte); isBytes {
		return strings.TrimSpace(string(data))
	}
	return ""
}
+3
View File
@@ -17,6 +17,9 @@ const RequestPathMetadataKey = "request_path"
// DisallowFreeAuthMetadataKey instructs auth selection to skip known free-tier credentials.
const DisallowFreeAuthMetadataKey = "disallow_free_auth"
// ReasoningEffortMetadataKey stores the client-requested reasoning effort for usage logs.
const ReasoningEffortMetadataKey = "reasoning_effort"
const (
// PinnedAuthMetadataKey locks execution to a specific auth ID.
PinnedAuthMetadataKey = "pinned_auth_id"
+38 -24
View File
@@ -1208,30 +1208,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
if strings.EqualFold(compat.Name, compatName) {
isCompatAuth = true
// Convert compatibility models to registry models
ms := make([]*ModelInfo, 0, len(compat.Models))
for j := range compat.Models {
m := compat.Models[j]
// Use alias as model ID, fallback to name if alias is empty
modelID := m.Alias
if modelID == "" {
modelID = m.Name
}
thinking := m.Thinking
if thinking == nil {
thinking = &registry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}
}
ms = append(ms, &ModelInfo{
ID: modelID,
Object: "model",
Created: time.Now().Unix(),
OwnedBy: compat.Name,
Type: "openai-compatibility",
DisplayName: modelID,
UserDefined: false,
Thinking: thinking,
})
}
ms := buildOpenAICompatibilityConfigModels(compat)
// Register and return
if len(ms) > 0 {
if providerKey == "" {
@@ -1578,6 +1555,43 @@ type modelEntry interface {
GetAlias() string
}
// buildOpenAICompatibilityConfigModels converts the models configured on an
// OpenAI-compatibility provider into registry entries.
//
// An entry's ID is the trimmed alias, or the trimmed name when the alias is
// blank; models where both are blank are skipped. Models flagged as image
// models are typed registry.OpenAIImageModelType and keep whatever thinking
// config they declare (possibly none). All other models are typed
// "openai-compatibility" and default to low/medium/high thinking levels when
// none is configured. Every entry shares one creation timestamp. It returns
// nil when compat is nil or has no models.
func buildOpenAICompatibilityConfigModels(compat *config.OpenAICompatibility) []*ModelInfo {
	if compat == nil || len(compat.Models) == 0 {
		return nil
	}

	createdAt := time.Now().Unix()
	out := make([]*ModelInfo, 0, len(compat.Models))
	for _, entry := range compat.Models {
		id := strings.TrimSpace(entry.Alias)
		if id == "" {
			id = strings.TrimSpace(entry.Name)
		}
		if id == "" {
			continue
		}

		info := &ModelInfo{
			ID:          id,
			Object:      "model",
			Created:     createdAt,
			OwnedBy:     compat.Name,
			Type:        "openai-compatibility",
			DisplayName: id,
			UserDefined: false,
			Thinking:    entry.Thinking,
		}
		switch {
		case entry.Image:
			info.Type = registry.OpenAIImageModelType
		case entry.Thinking == nil:
			info.Thinking = &registry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}
		}
		out = append(out, info)
	}
	return out
}
func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo {
if len(models) == 0 {
return nil
@@ -4,6 +4,7 @@ import (
"strings"
"testing"
internalregistry "github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
)
@@ -63,3 +64,71 @@ func TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute(t *testing.T
t.Fatal("expected global excluded model to be present when attribute override is set")
}
}
func TestRegisterModelsForAuth_OpenAICompatibilityImageModelType(t *testing.T) {
service := &Service{
cfg: &config.Config{
OpenAICompatibility: []config.OpenAICompatibility{
{
Name: "images",
BaseURL: "https://example.com/v1",
Models: []config.OpenAICompatibilityModel{
{Name: "upstream-image", Alias: "compat-image", Image: true},
{Name: "upstream-chat", Alias: "compat-chat"},
},
},
},
},
}
auth := &coreauth.Auth{
ID: "auth-openai-compat-image",
Provider: "openai-compatibility",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "api_key",
"compat_name": "images",
"provider_key": "images",
},
}
modelRegistry := internalregistry.GetGlobalRegistry()
modelRegistry.UnregisterClient(auth.ID)
t.Cleanup(func() {
modelRegistry.UnregisterClient(auth.ID)
})
service.registerModelsForAuth(auth)
models := modelRegistry.GetModelsForClient(auth.ID)
var imageModel *internalregistry.ModelInfo
var chatModel *internalregistry.ModelInfo
for _, model := range models {
if model == nil {
continue
}
switch strings.TrimSpace(model.ID) {
case "compat-image":
imageModel = model
case "compat-chat":
chatModel = model
}
}
if imageModel == nil {
t.Fatal("expected compat-image to be registered")
}
if imageModel.Type != internalregistry.OpenAIImageModelType {
t.Fatalf("image model type = %q, want %q", imageModel.Type, internalregistry.OpenAIImageModelType)
}
if imageModel.Thinking != nil {
t.Fatalf("image model thinking = %+v, want nil", imageModel.Thinking)
}
if chatModel == nil {
t.Fatal("expected compat-chat to be registered")
}
if chatModel.Type != "openai-compatibility" {
t.Fatalf("chat model type = %q, want openai-compatibility", chatModel.Type)
}
if chatModel.Thinking == nil {
t.Fatal("expected chat model to keep default thinking support")
}
}
+47 -13
View File
@@ -2,6 +2,7 @@ package usage
import (
"context"
"net/http"
"strings"
"sync"
"time"
@@ -11,19 +12,23 @@ import (
// Record contains the usage statistics captured for a single provider request.
type Record struct {
	Provider  string
	Model     string
	Alias     string
	APIKey    string
	AuthID    string
	AuthIndex string
	AuthType  string
	Source    string
	// ReasoningEffort stores the client-requested thinking level for request event logs.
	ReasoningEffort string
	RequestedAt     time.Time
	Latency         time.Duration
	Failed          bool
	Fail            Failure
	Detail          Detail
	// ResponseHeaders stores a snapshot of upstream response headers for usage sinks.
	ResponseHeaders http.Header
}
// Failure holds HTTP failure metadata for an upstream request attempt.
@@ -44,6 +49,7 @@ type Detail struct {
}
// requestedModelAliasContextKey is the unexported context key type for the
// client-requested model alias.
type requestedModelAliasContextKey struct{}
// reasoningEffortContextKey is the unexported context key type under which
// WithReasoningEffort stores the client-requested reasoning effort.
type reasoningEffortContextKey struct{}
// WithRequestedModelAlias stores the client-requested model name for usage sinks.
func WithRequestedModelAlias(ctx context.Context, alias string) context.Context {
@@ -73,6 +79,34 @@ func RequestedModelAliasFromContext(ctx context.Context) string {
}
}
// WithReasoningEffort stores the client-requested reasoning effort for usage sinks.
// The effort is trimmed before it is stored; a blank effort stores nothing and
// the context is returned as is. A nil ctx is replaced by context.Background().
func WithReasoningEffort(ctx context.Context, effort string) context.Context {
	base := ctx
	if base == nil {
		base = context.Background()
	}
	if trimmed := strings.TrimSpace(effort); trimmed != "" {
		return context.WithValue(base, reasoningEffortContextKey{}, trimmed)
	}
	return base
}
// ReasoningEffortFromContext returns the client-requested reasoning effort stored in ctx.
// The value is trimmed. The result is "" when ctx is nil, when nothing is
// stored, or when the stored value is neither a string nor a []byte.
func ReasoningEffortFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	stored := ctx.Value(reasoningEffortContextKey{})
	if text, isString := stored.(string); isString {
		return strings.TrimSpace(text)
	}
	if data, isBytes := stored.([]byte); isBytes {
		return strings.TrimSpace(string(data))
	}
	return ""
}
// Plugin consumes usage records emitted by the proxy runtime.
type Plugin interface {
HandleUsage(ctx context.Context, record Record)
+122 -1
View File
@@ -1,7 +1,10 @@
package proxyutil
import (
"bufio"
"context"
"crypto/tls"
"encoding/base64"
"fmt"
"net"
"net/http"
@@ -50,7 +53,7 @@ func Parse(raw string) (Setting, error) {
parsedURL, errParse := url.Parse(trimmed)
if errParse != nil {
setting.Mode = ModeInvalid
return setting, fmt.Errorf("parse proxy URL failed: %w", errParse)
return setting, fmt.Errorf("parse proxy URL failed")
}
if parsedURL.Scheme == "" || parsedURL.Host == "" {
setting.Mode = ModeInvalid
@@ -134,6 +137,9 @@ func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
case ModeDirect:
return proxy.Direct, setting.Mode, nil
case ModeProxy:
if setting.URL.Scheme == "http" || setting.URL.Scheme == "https" {
return &httpConnectDialer{proxyURL: setting.URL, dialer: proxy.Direct}, setting.Mode, nil
}
dialer, errDialer := proxy.FromURL(setting.URL, proxy.Direct)
if errDialer != nil {
return nil, setting.Mode, fmt.Errorf("create proxy dialer failed: %w", errDialer)
@@ -143,3 +149,118 @@ func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
return nil, setting.Mode, nil
}
}
// httpConnectDialer dials targets through an HTTP or HTTPS proxy by issuing a
// CONNECT request over a connection opened with dialer.
type httpConnectDialer struct {
// proxyURL is the proxy endpoint; Dial reads its scheme, host and user info.
proxyURL *url.URL
// dialer opens the underlying connection to the proxy itself.
dialer proxy.Dialer
}
// Dial opens a tunnel to addr through the configured HTTP or HTTPS proxy.
//
// It connects to the proxy using network, wraps the connection in TLS when the
// proxy scheme is "https", sends a CONNECT request for addr (adding
// Proxy-Authorization when the proxy URL carries user info) and requires a
// 200 response. On any failure after the proxy connection is open, that
// connection is closed and a close failure, if any, is appended to the
// returned error. When the proxy sent bytes right after the response head,
// the returned connection is a bufferedConn so those bytes remain readable.
func (d *httpConnectDialer) Dial(network, addr string) (net.Conn, error) {
	proxyConn, errDial := d.dialer.Dial(network, proxyDialAddr(d.proxyURL))
	if errDial != nil {
		return nil, fmt.Errorf("dial HTTP proxy failed: %w", errDial)
	}

	conn := proxyConn
	// abort closes whatever conn currently refers to and folds a close failure
	// into cause, keeping the same error text the call sites always produced.
	abort := func(cause error) (net.Conn, error) {
		if errClose := conn.Close(); errClose != nil {
			return nil, fmt.Errorf("%w; close failed: %v", cause, errClose)
		}
		return nil, cause
	}

	if d.proxyURL.Scheme == "https" {
		tlsConn := tls.Client(conn, &tls.Config{ServerName: d.proxyURL.Hostname()})
		if errHandshake := tlsConn.Handshake(); errHandshake != nil {
			// conn still refers to the raw proxy connection here.
			return abort(fmt.Errorf("HTTPS proxy TLS handshake failed: %w", errHandshake))
		}
		conn = tlsConn
	}

	req := &http.Request{
		Method: http.MethodConnect,
		URL:    &url.URL{Host: addr},
		Host:   addr,
		Header: make(http.Header),
	}
	if d.proxyURL.User != nil {
		req.Header.Set("Proxy-Authorization", proxyAuthorization(d.proxyURL.User))
	}
	if errWrite := req.Write(conn); errWrite != nil {
		return abort(fmt.Errorf("write CONNECT request failed: %w", errWrite))
	}

	reader := bufio.NewReader(conn)
	resp, errRead := http.ReadResponse(reader, req)
	if errRead != nil {
		return abort(fmt.Errorf("read CONNECT response failed: %w", errRead))
	}
	if resp.StatusCode != http.StatusOK {
		// Close the connection before the body. Closing a keep-alive response
		// body with a declared length drains it first, so a proxy that stalls
		// mid-body would otherwise block Dial indefinitely.
		tunnel, errAbort := abort(fmt.Errorf("proxy CONNECT returned status %s", resp.Status))
		if resp.Body != nil {
			// The connection is already closed; any error here is expected.
			_ = resp.Body.Close()
		}
		return tunnel, errAbort
	}

	// Bytes read past the response head belong to the tunnel and must not be lost.
	if reader.Buffered() > 0 {
		return &bufferedConn{Conn: conn, reader: reader}, nil
	}
	return conn, nil
}
// proxyDialAddr returns the host:port address to dial for proxyURL.
//
// An explicit port in the URL is used as-is. Without one, the port is 443 for
// the "https" scheme and 80 for every other scheme.
func proxyDialAddr(proxyURL *url.URL) string {
	host := proxyURL.Hostname()
	if explicit := proxyURL.Port(); explicit != "" {
		return net.JoinHostPort(host, explicit)
	}
	if proxyURL.Scheme == "https" {
		return net.JoinHostPort(host, "443")
	}
	return net.JoinHostPort(host, "80")
}
// proxyAuthorization builds the Basic credentials value for a
// Proxy-Authorization header from user.
//
// Whether a password was set is ignored, so a missing password is encoded as
// an empty one ("username:").
func proxyAuthorization(user *url.Userinfo) string {
	secret, _ := user.Password()
	credentials := []byte(user.Username() + ":" + secret)
	return "Basic " + base64.StdEncoding.EncodeToString(credentials)
}
// Redact returns a log-safe form of the proxy URL in raw.
//
// Only the scheme and host are kept; path, query and fragment are dropped, and
// any user info is replaced by the fixed user name "redacted". Blank input
// yields "", and input that fails to parse or lacks a scheme or host yields
// the placeholder "<invalid proxy URL>".
func Redact(raw string) string {
	const invalid = "<invalid proxy URL>"

	candidate := strings.TrimSpace(raw)
	if candidate == "" {
		return ""
	}

	parsed, errParse := url.Parse(candidate)
	if errParse != nil {
		return invalid
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return invalid
	}

	safe := url.URL{Scheme: parsed.Scheme, Host: parsed.Host}
	if parsed.User != nil {
		safe.User = url.User("redacted")
	}
	return safe.String()
}
// bufferedConn is a net.Conn whose reads first drain bytes already held in
// reader before falling through to the embedded connection.
type bufferedConn struct {
	net.Conn
	reader *bufio.Reader
}

// Read fills p from reader while it still holds buffered bytes, and from the
// embedded connection once that buffer is empty.
func (c *bufferedConn) Read(p []byte) (int, error) {
	if c.reader.Buffered() == 0 {
		return c.Conn.Read(p)
	}
	return c.reader.Read(p)
}
+161
View File
@@ -1,8 +1,15 @@
package proxyutil
import (
"bufio"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
)
func mustDefaultTransport(t *testing.T) *http.Transport {
@@ -159,3 +166,157 @@ func TestBuildHTTPTransportSOCKS5HProxy(t *testing.T) {
t.Fatal("expected SOCKS5H transport to have custom DialContext")
}
}
// TestBuildDialerHTTPProxyCONNECT runs a one-shot fake HTTP proxy on a loopback
// listener and checks that the dialer built for an http:// proxy URL sends a
// CONNECT request with the expected host and Basic credentials, exposes bytes
// the proxy sent right after the response head, and forwards later writes.
func TestBuildDialerHTTPProxyCONNECT(t *testing.T) {
t.Parallel()
listener, errListen := net.Listen("tcp", "127.0.0.1:0")
if errListen != nil {
t.Fatalf("net.Listen returned error: %v", errListen)
}
defer func() {
if errClose := listener.Close(); errClose != nil {
t.Errorf("listener.Close returned error: %v", errClose)
}
}()
// done carries the fake proxy's verdict; buffered so the goroutine never
// blocks on send if the test has already failed.
done := make(chan error, 1)
go func() {
conn, errAccept := listener.Accept()
if errAccept != nil {
done <- errAccept
return
}
defer func() { _ = conn.Close() }()
// Bound every proxy-side read and write so a client bug cannot hang the test.
if errDeadline := conn.SetDeadline(time.Now().Add(5 * time.Second)); errDeadline != nil {
done <- errDeadline
return
}
req, errRead := http.ReadRequest(bufio.NewReader(conn))
if errRead != nil {
done <- fmt.Errorf("read CONNECT request failed: %w", errRead)
return
}
if req.Method != http.MethodConnect {
done <- fmt.Errorf("method = %s, want CONNECT", req.Method)
return
}
if req.Host != "target.example.com:443" {
done <- fmt.Errorf("host = %s, want target.example.com:443", req.Host)
return
}
wantAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))
if gotAuth := req.Header.Get("Proxy-Authorization"); gotAuth != wantAuth {
done <- fmt.Errorf("Proxy-Authorization = %q, want %q", gotAuth, wantAuth)
return
}
// "ok" follows the response head in the same write so the client sees tunnel
// bytes arriving together with the CONNECT response.
if _, errWrite := io.WriteString(conn, "HTTP/1.1 200 Connection Established\r\n\r\nok"); errWrite != nil {
done <- fmt.Errorf("write CONNECT response failed: %w", errWrite)
return
}
// Expect the client's "ping" to arrive through the established tunnel.
buf := make([]byte, 4)
n, errReadTunnel := io.ReadFull(conn, buf)
if errReadTunnel != nil {
done <- fmt.Errorf("read tunneled payload failed after %d bytes: %w", n, errReadTunnel)
return
}
if string(buf) != "ping" {
done <- fmt.Errorf("tunneled payload = %q, want ping", string(buf))
return
}
done <- nil
}()
dialer, mode, errBuild := BuildDialer("http://user:pass@" + listener.Addr().String())
if errBuild != nil {
t.Fatalf("BuildDialer returned error: %v", errBuild)
}
if mode != ModeProxy {
t.Fatalf("mode = %d, want %d", mode, ModeProxy)
}
if dialer == nil {
t.Fatal("expected dialer, got nil")
}
conn, errDial := dialer.Dial("tcp", "target.example.com:443")
if errDial != nil {
t.Fatalf("dialer.Dial returned error: %v", errDial)
}
defer func() {
if errClose := conn.Close(); errClose != nil {
t.Errorf("conn.Close returned error: %v", errClose)
}
}()
// The bytes sent right after the CONNECT response must be readable from the
// returned connection.
buf := make([]byte, 2)
n, errRead := io.ReadFull(conn, buf)
if errRead != nil {
t.Fatalf("conn.Read returned error after %d bytes: %v", n, errRead)
}
if string(buf) != "ok" {
t.Fatalf("buffered tunnel payload = %q, want ok", string(buf))
}
if _, errWrite := conn.Write([]byte("ping")); errWrite != nil {
t.Fatalf("conn.Write returned error: %v", errWrite)
}
// Wait for the proxy goroutine so its assertions are part of the result.
if errServer := <-done; errServer != nil {
t.Fatalf("proxy server returned error: %v", errServer)
}
}
// TestRedactProxyURL checks Redact against a URL with credentials, path and
// query, a URL without credentials, and a value that is not a proxy URL.
func TestRedactProxyURL(t *testing.T) {
	t.Parallel()

	type redactCase struct {
		name  string
		input string
		want  string
	}
	cases := []redactCase{
		{name: "with credentials", input: "http://user:pass@proxy.example.com:8080/path?token=secret", want: "http://redacted@proxy.example.com:8080"},
		{name: "without credentials", input: "socks5://proxy.example.com:1080", want: "socks5://proxy.example.com:1080"},
		{name: "invalid", input: "bad-value", want: "<invalid proxy URL>"},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for the parallel subtest closure
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			got := Redact(tc.input)
			if got != tc.want {
				t.Fatalf("Redact() = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestParseErrorDoesNotExposeProxyCredentials feeds Parse a proxy URL whose
// password ends in a bare '%' and checks that the resulting error text contains
// neither the raw input nor its user name or password.
func TestParseErrorDoesNotExposeProxyCredentials(t *testing.T) {
	t.Parallel()

	const rawProxy = "http://user:secret%@proxy.example.com:8080"
	_, errParse := Parse(rawProxy)
	if errParse == nil {
		t.Fatal("expected Parse to return an error")
	}

	message := errParse.Error()
	for _, leaked := range []string{rawProxy, "user", "secret"} {
		if strings.Contains(message, leaked) {
			t.Fatalf("parse error exposes proxy credentials: %q", message)
		}
	}
}