 |
Thanks to VisionCoder for supporting this project. VisionCoder Developer Platform is a reliable and efficient API relay service provider, offering access to mainstream AI models such as Claude Code, Codex, and Gemini. It helps developers and teams integrate AI capabilities more easily and improve productivity.
@@ -53,7 +45,7 @@ VisionCoder is also offering our users a limited-time 此链接注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。
---
- |
-感谢 PackyCode 对本项目的赞助!PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用此链接注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。 |
-
-
 |
感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过此链接注册的用户,可享受首充8折,企业客户最高可享 7.5 折! |
@@ -35,10 +31,6 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元
感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过BmoPlus AI成品号专卖/代充注册下单的用户,可享GPT 官网订阅一折 的震撼价格! |
- |
-感谢 Poixe AI 对本项目的赞助!Poixe AI 提供可靠的 AI 模型接口服务,您可以使用平台提供的 LLM API 接口轻松构建 AI 产品,同时也可以成为供应商,为平台提供大模型资源以赚取收益。通过 CLIProxyAPI 专属链接注册,充值额外赠送 $5 美金 |
-
-
 |
感谢 VisionCoder 对本项目的支持。VisionCoder 开发平台 是一个可靠高效的 API 中继服务提供商,提供 Claude Code、Codex、Gemini 等主流 AI 模型,帮助开发者和团队更轻松地集成 AI 功能,提升工作效率。
@@ -53,7 +45,7 @@ VisionCoder 还为我们的用户提供 [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
@@ -200,6 +208,10 @@ Shadow AI 是一款专为受限环境设计的 AI 辅助工具。提供无窗口
OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼容 OpenAI 的端点,具备智能路由、负载均衡、重试及回退机制。通过添加策略、速率限制、缓存和可观测性,确保推理过程既可靠又具备成本意识。
+### [Playful Proxy API Panel (PPAP)](https://github.com/daishuge/playful-proxy-api-panel)
+
+一个公开的 CLIProxyAPI 兼容二开版本和配套管理面板,尽量保持与上游一致的使用方式,同时恢复内置使用量统计,并补充缓存命中率、首字响应时间、TPS 记录和面向 Docker 自托管的安装说明。
+
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
diff --git a/README_JA.md b/README_JA.md
index 6360320c..debe4ae5 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -10,23 +10,19 @@ OAuth経由でOpenAI Codex(GPTモデル)およびClaude Codeもサポート
## スポンサー
-[](https://z.ai/subscribe?ic=8JVLJQFSKB)
+[](https://www.packyapi.com/register?aff=cliproxyapi)
-本プロジェクトはZ.aiにスポンサーされており、GLM CODING PLANの提供を受けています。
+PackyCodeのスポンサーシップに感謝します!
-GLM CODING PLANはAIコーディング向けに設計されたサブスクリプションサービスで、月額わずか$10から利用可能です。フラッグシップのGLM-4.7および(GLM-5はProユーザーのみ利用可能)モデルを10以上の人気AIコーディングツール(Claude Code、Cline、Roo Codeなど)で利用でき、開発者にトップクラスの高速かつ安定したコーディング体験を提供します。
+PackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。
-GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
+PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています:こちらのリンクから登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。
---
- |
-PackyCodeのスポンサーシップに感謝します!PackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています:こちらのリンクから登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。 |
-
-
 |
AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:こちらのリンクから登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます! |
@@ -35,10 +31,6 @@ GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらのBmoPlus AIアカウント専門店/代行チャージ経由でご登録・ご注文いただいたユーザー様は、GPTを 公式サイト価格の約1割(90% OFF) という驚異的な価格でご利用いただけます! |
- |
-Poixe AIのスポンサーシップに感謝します!Poixe AIは信頼できるAIモデルAPIサービスを提供しており、プラットフォームが提供するLLM APIを使って簡単にAI製品を構築できます。また、サプライヤーとしてプラットフォームに大規模モデルのリソースを提供し、収益を得ることも可能です。CLIProxyAPIの専用リンクから登録すると、チャージ時に追加で$5が付与されます。 |
-
-
 |
VisionCoderのご支援に感謝します!VisionCoder 開発プラットフォーム は、信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどの主要AIモデルを提供し、開発者やチームがより簡単にAI機能を統合して生産性を向上できるよう支援します。さらに、VisionCoderはユーザー向けに Token Plan の期間限定キャンペーン(1か月購入で1か月分プレゼント)も提供しています。 |
@@ -51,7 +43,7 @@ GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
- OAuthログインによるOpenAI Codexサポート(GPTモデル)
- OAuthログインによるClaude Codeサポート
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
-- ストリーミングおよび非ストリーミングレスポンス
+- ストリーミング、非ストリーミング、および対応環境でのWebSocketレスポンス
- 関数呼び出し/ツールのサポート
- マルチモーダル入力サポート(テキストと画像)
- ラウンドロビン負荷分散による複数アカウント対応(Gemini、OpenAI、Claude)
@@ -72,6 +64,18 @@ CLIProxyAPIガイド:[https://help.router-for.me/](https://help.router-for.me/
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
+## 使用量統計
+
+v6.10.0以降、CLIProxyAPIおよび [CPAMC](https://github.com/router-for-me/Cli-Proxy-API-Management-Center) プロジェクトには使用量統計機能がプリセットされなくなりました。使用量統計が必要な場合は、次のプロジェクトをご利用ください:
+
+### [CPA Usage Keeper](https://github.com/Willxup/cpa-usage-keeper)
+
+CLIProxyAPI向けの独立した使用量永続化・可視化サービス。CLIProxyAPIデータを定期同期してSQLiteに保存し、集計APIと、使用量や各種統計を確認できる組み込みダッシュボードを提供します。
+
+### [CLIProxyAPI Usage Dashboard](https://github.com/zhanglunet/cliproxyapi-usage-dashboard)
+
+CLIProxyAPI向けのローカル優先の使用量・クォータダッシュボード。Redis互換の使用量キューからリクエストごとのToken使用量を収集してSQLiteに保存し、アカウント別・モデル別の日次および直近時間枠の使用量を可視化し、Codex 5h/7dクォータ残量をローカルWeb UIで表示します。
+
## Amp CLIサポート
CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統合サポートを含んでおり、Google/ChatGPT/ClaudeのOAuthサブスクリプションをAmpのコーディングツールで使用できます:
@@ -120,7 +124,7 @@ macOSネイティブのメニューバーアプリで、Claude CodeとChatGPTの
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
-CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を翻訳するブラウザベースのツール。自動検証/エラー修正機能付き - APIキー不要
+CLIProxyAPI経由で既存のLLMサブスクリプション(Gemini、ChatGPT、Claude, etc.)を使用してSRT字幕を翻訳および検証する、クロスプラットフォームのデスクトップおよびWebアプリ - APIキー不要。
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
@@ -178,9 +182,13 @@ CLIProxyAPIをネイティブGUIでラップしたクロスプラットフォー
CLIProxyAPI向けのすぐに使えるクロスプラットフォームのクォータ確認ツール。アカウントごとの codex 5h/7d クォータ表示、プラン別ソート、ステータス色分け、複数アカウントの集計分析に対応。
-### [CPA Usage Keeper](https://github.com/Willxup/cpa-usage-keeper)
+### [CodexCliPlus](https://github.com/C4AL/CodexCliPlus)
-CLIProxyAPI向けの独立した使用量永続化・可視化サービス。CPAデータを定期同期してSQLiteに保存し、集計APIと、使用量や各種統計を確認できる組み込みダッシュボードを提供します。
+CLIProxyAPIを基盤にしたWindows向けのローカル優先Codex CLIデスクトップ管理プラットフォーム。ローカル設定、アカウント、実行状態の管理を簡素化し、ローカルユーザーにより包括的なCodex CLI体験を提供します。
+
+### [CLIProxy Pool Watch](https://github.com/murasame612/CLIProxyPoolWidget)
+
+CLIProxyAPIプール内のChatGPT/Codexアカウントクォータを監視するmacOSネイティブSwiftUIアプリ。Management APIを通じて、アカウントの可用性、Plus基準の容量、5時間/週次クォータバー、プラン重み、復元予測を表示します。
> [!NOTE]
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
@@ -199,6 +207,10 @@ CLIProxyAPIに触発されたNext.js実装。インストールと使用が簡
OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです:スマートルーティング、負荷分散、リトライ、フォールバックを備えたOpenAI互換エンドポイント。ポリシー、レート制限、キャッシュ、可観測性を追加して、信頼性が高くコストを意識した推論を実現します。
+### [Playful Proxy API Panel (PPAP)](https://github.com/daishuge/playful-proxy-api-panel)
+
+上流に近い使い方を維持する公開CLIProxyAPI互換フォーク兼管理パネルです。内蔵の使用量統計を復元し、キャッシュヒット率、初回バイト待ち時間、TPSの記録、Docker向けのセルフホスト手順を追加しています。
+
> [!NOTE]
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
diff --git a/assets/packycode-cn.png b/assets/packycode-cn.png
new file mode 100644
index 00000000..3e34d6ca
Binary files /dev/null and b/assets/packycode-cn.png differ
diff --git a/assets/packycode-en.png b/assets/packycode-en.png
new file mode 100644
index 00000000..90f716e2
Binary files /dev/null and b/assets/packycode-en.png differ
diff --git a/cmd/server/main.go b/cmd/server/main.go
index b8707f0a..b10bc9c8 100644
--- a/cmd/server/main.go
+++ b/cmd/server/main.go
@@ -24,11 +24,11 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
- "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -417,7 +417,8 @@ func main() {
configFileExists = true
}
}
- usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
+ redisqueue.SetUsageStatisticsEnabled(cfg.UsageStatisticsEnabled)
+ redisqueue.SetRetentionSeconds(cfg.RedisUsageQueueRetentionSeconds)
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg); err != nil {
diff --git a/config.example.yaml b/config.example.yaml
index 172e961f..d7d5a9f5 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -66,6 +66,10 @@ error-logs-max-files: 10
# When false, disable in-memory usage statistics aggregation
usage-statistics-enabled: false
+# How long (in seconds) Redis usage queue items are retained in memory for the RESP interface (LPOP/RPOP).
+# Default: 60. Max: 3600.
+redis-usage-queue-retention-seconds: 60
+
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
proxy-url: ""
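The new `redis-usage-queue-retention-seconds` option documents a default of 60 and a maximum of 3600, but this diff does not show how out-of-range values are handled. The following is a minimal sketch of one plausible normalization, assuming non-positive values fall back to the default and oversized values are capped; `normalizeRetentionSeconds` and both constants are hypothetical names, not taken from the repository.

```go
package main

import "fmt"

const (
	defaultRetentionSeconds = 60   // "Default: 60" from the config comment
	maxRetentionSeconds     = 3600 // "Max: 3600" from the config comment
)

// normalizeRetentionSeconds is a hypothetical helper. It assumes that
// non-positive values mean "use the default" and that values above the
// documented maximum are capped; the real behavior is not shown in the diff.
func normalizeRetentionSeconds(v int) int {
	if v <= 0 {
		return defaultRetentionSeconds
	}
	if v > maxRetentionSeconds {
		return maxRetentionSeconds
	}
	return v
}

func main() {
	for _, v := range []int{0, 60, 120, 7200} {
		fmt.Printf("%d -> %d\n", v, normalizeRetentionSeconds(v))
	}
}
```

Capping rather than rejecting keeps a misconfigured value from blocking startup, but whether `redisqueue.SetRetentionSeconds` does this is an assumption.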
diff --git a/docker-build.sh b/docker-build.sh
index 4538b807..ebe7d923 100644
--- a/docker-build.sh
+++ b/docker-build.sh
@@ -5,123 +5,13 @@
# This script automates the process of building and running the Docker container
# with version information dynamically injected at build time.
-# Hidden feature: Preserve usage statistics across rebuilds
-# Usage: ./docker-build.sh --with-usage
-# First run prompts for management API key, saved to temp/stats/.api_secret
-
set -euo pipefail
-STATS_DIR="temp/stats"
-STATS_FILE="${STATS_DIR}/.usage_backup.json"
-SECRET_FILE="${STATS_DIR}/.api_secret"
-WITH_USAGE=false
-
-get_port() {
- if [[ -f "config.yaml" ]]; then
- grep -E "^port:" config.yaml | sed -E 's/^port: *["'"'"']?([0-9]+)["'"'"']?.*$/\1/'
- else
- echo "8317"
- fi
-}
-
-export_stats_api_secret() {
- if [[ -f "${SECRET_FILE}" ]]; then
- API_SECRET=$(cat "${SECRET_FILE}")
- else
- if [[ ! -d "${STATS_DIR}" ]]; then
- mkdir -p "${STATS_DIR}"
- fi
- echo "First time using --with-usage. Management API key required."
- read -r -p "Enter management key: " -s API_SECRET
- echo
- echo "${API_SECRET}" > "${SECRET_FILE}"
- chmod 600 "${SECRET_FILE}"
- fi
-}
-
-check_container_running() {
- local port
- port=$(get_port)
-
- if ! curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
- echo "Error: cli-proxy-api service is not responding at localhost:${port}"
- echo "Please start the container first or use without --with-usage flag."
- exit 1
- fi
-}
-
-export_stats() {
- local port
- port=$(get_port)
-
- if [[ ! -d "${STATS_DIR}" ]]; then
- mkdir -p "${STATS_DIR}"
- fi
- check_container_running
- echo "Exporting usage statistics..."
- EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" \
- "http://localhost:${port}/v0/management/usage/export")
- HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
- RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')
-
- if [[ "${HTTP_CODE}" != "200" ]]; then
- echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
- exit 1
- fi
-
- echo "${RESPONSE_BODY}" > "${STATS_FILE}"
- echo "Statistics exported to ${STATS_FILE}"
-}
-
-import_stats() {
- local port
- port=$(get_port)
-
- echo "Importing usage statistics..."
- IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
- -H "X-Management-Key: ${API_SECRET}" \
- -H "Content-Type: application/json" \
- -d @"${STATS_FILE}" \
- "http://localhost:${port}/v0/management/usage/import")
- IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
- IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')
-
- if [[ "${IMPORT_CODE}" == "200" ]]; then
- echo "Statistics imported successfully"
- else
- echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
- fi
-
- rm -f "${STATS_FILE}"
-}
-
-wait_for_service() {
- local port
- port=$(get_port)
-
- echo "Waiting for service to be ready..."
- for i in {1..30}; do
- if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
- break
- fi
- sleep 1
- done
- sleep 2
-}
-
-case "${1:-}" in
- "")
- ;;
- "--with-usage")
- WITH_USAGE=true
- export_stats_api_secret
- ;;
- *)
- echo "Error: unknown option '${1}'. Did you mean '--with-usage'?"
- echo "Usage: ./docker-build.sh [--with-usage]"
- exit 1
- ;;
-esac
+if [[ "${1:-}" != "" ]]; then
+ echo "Error: unknown option '${1}'."
+ echo "Usage: ./docker-build.sh"
+ exit 1
+fi
# --- Step 1: Choose Environment ---
echo "Please select an option:"
@@ -133,14 +23,7 @@ read -r -p "Enter choice [1-2]: " choice
case "$choice" in
1)
echo "--- Running with Pre-built Image ---"
- if [[ "${WITH_USAGE}" == "true" ]]; then
- export_stats
- fi
docker compose up -d --remove-orphans --no-build
- if [[ "${WITH_USAGE}" == "true" ]]; then
- wait_for_service
- import_stats
- fi
echo "Services are starting from remote image."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -167,18 +50,9 @@ case "$choice" in
--build-arg COMMIT="${COMMIT}" \
--build-arg BUILD_DATE="${BUILD_DATE}"
- if [[ "${WITH_USAGE}" == "true" ]]; then
- export_stats
- fi
-
echo "Starting the services..."
docker compose up -d --remove-orphans --pull never
- if [[ "${WITH_USAGE}" == "true" ]]; then
- wait_for_service
- import_stats
- fi
-
echo "Build complete. Services are starting."
echo "Run 'docker compose logs -f' to see the logs."
;;
diff --git a/internal/api/handlers/management/api_key_usage.go b/internal/api/handlers/management/api_key_usage.go
new file mode 100644
index 00000000..3361da5d
--- /dev/null
+++ b/internal/api/handlers/management/api_key_usage.go
@@ -0,0 +1,107 @@
+package management
+
+import (
+ "net/http"
+ "strings"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+)
+
+type apiKeyUsageEntry struct {
+ Success int64 `json:"success"`
+ Failed int64 `json:"failed"`
+ RecentRequests []coreauth.RecentRequestBucket `json:"recent_requests"`
+}
+
+func mergeRecentRequestBuckets(dst, src []coreauth.RecentRequestBucket) []coreauth.RecentRequestBucket {
+ if len(dst) == 0 {
+ return src
+ }
+ if len(src) == 0 {
+ return dst
+ }
+ if len(dst) != len(src) {
+ n := len(dst)
+ if len(src) < n {
+ n = len(src)
+ }
+ for i := 0; i < n; i++ {
+ dst[i].Success += src[i].Success
+ dst[i].Failed += src[i].Failed
+ }
+ return dst
+ }
+ for i := range dst {
+ dst[i].Success += src[i].Success
+ dst[i].Failed += src[i].Failed
+ }
+ return dst
+}
+
+// GetAPIKeyUsage returns success/failure totals and recent request buckets for all
+// in-memory api_key auths, grouped by provider and keyed by "base_url|api_key".
+func (h *Handler) GetAPIKeyUsage(c *gin.Context) {
+ if h == nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "handler not initialized"})
+ return
+ }
+
+ h.mu.Lock()
+ manager := h.authManager
+ h.mu.Unlock()
+ if manager == nil {
+ c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
+ return
+ }
+
+ now := time.Now()
+ out := make(map[string]map[string]apiKeyUsageEntry)
+ for _, auth := range manager.List() {
+ if auth == nil {
+ continue
+ }
+ kind, apiKey := auth.AccountInfo()
+ if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
+ continue
+ }
+ apiKey = strings.TrimSpace(apiKey)
+ if apiKey == "" {
+ continue
+ }
+ baseURL := ""
+ if auth.Attributes != nil {
+ baseURL = strings.TrimSpace(auth.Attributes["base_url"])
+ if baseURL == "" {
+ baseURL = strings.TrimSpace(auth.Attributes["base-url"])
+ }
+ }
+ compositeKey := baseURL + "|" + apiKey
+ provider := strings.ToLower(strings.TrimSpace(auth.Provider))
+ if provider == "" {
+ provider = "unknown"
+ }
+
+ recent := auth.RecentRequestsSnapshot(now)
+ providerBucket, ok := out[provider]
+ if !ok {
+ providerBucket = make(map[string]apiKeyUsageEntry)
+ out[provider] = providerBucket
+ }
+ if existing, exists := providerBucket[compositeKey]; exists {
+ existing.Success += auth.Success
+ existing.Failed += auth.Failed
+ existing.RecentRequests = mergeRecentRequestBuckets(existing.RecentRequests, recent)
+ providerBucket[compositeKey] = existing
+ continue
+ }
+ providerBucket[compositeKey] = apiKeyUsageEntry{
+ Success: auth.Success,
+ Failed: auth.Failed,
+ RecentRequests: recent,
+ }
+ }
+
+ c.JSON(http.StatusOK, out)
+}
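For consumers of this endpoint, the response is a two-level map: provider name, then a `"base_url|api_key"` composite key, then an entry with `success`, `failed`, and `recent_requests`. Below is a minimal client-side sketch of decoding that shape. The struct names and `splitCompositeKey` are hypothetical, the bucket field names follow the JSON keys asserted in the tests (`time`, `success`, `failed`), and the `time` value is kept as an opaque string because its format is not shown in the diff.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// recentBucket mirrors the bucket keys asserted in the tests.
// The time format is not visible in the diff, so it stays a string.
type recentBucket struct {
	Time    string `json:"time"`
	Success int64  `json:"success"`
	Failed  int64  `json:"failed"`
}

// usageEntry mirrors the JSON tags of apiKeyUsageEntry.
type usageEntry struct {
	Success        int64          `json:"success"`
	Failed         int64          `json:"failed"`
	RecentRequests []recentBucket `json:"recent_requests"`
}

// splitCompositeKey is a hypothetical helper. It splits at the first "|"
// and therefore assumes the base URL itself contains no "|".
func splitCompositeKey(key string) (baseURL, apiKey string) {
	baseURL, apiKey, _ = strings.Cut(key, "|")
	return baseURL, apiKey
}

// decodeAPIKeyUsage parses provider -> composite key -> entry.
func decodeAPIKeyUsage(body []byte) (map[string]map[string]usageEntry, error) {
	var out map[string]map[string]usageEntry
	if err := json.Unmarshal(body, &out); err != nil {
		return nil, err
	}
	return out, nil
}

func main() {
	// Hand-written sample payload; the "time" value is a placeholder.
	body := []byte(`{"codex":{"https://codex.example.com|codex-key":{"success":1,"failed":1,"recent_requests":[{"time":"placeholder","success":1,"failed":1}]}}}`)
	usage, err := decodeAPIKeyUsage(body)
	if err != nil {
		panic(err)
	}
	for provider, keys := range usage {
		for composite, entry := range keys {
			baseURL, _ := splitCompositeKey(composite) // avoid printing the key itself
			fmt.Println(provider, baseURL, entry.Success, entry.Failed, len(entry.RecentRequests))
		}
	}
}
```

Because the map keys embed the raw API key, a client would likely want to avoid logging them verbatim, which is why the sketch prints only the base URL.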
diff --git a/internal/api/handlers/management/api_key_usage_test.go b/internal/api/handlers/management/api_key_usage_test.go
new file mode 100644
index 00000000..2880567f
--- /dev/null
+++ b/internal/api/handlers/management/api_key_usage_test.go
@@ -0,0 +1,95 @@
+package management
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+)
+
+func sumRecentRequestBuckets(buckets []coreauth.RecentRequestBucket) (int64, int64) {
+ var success int64
+ var failed int64
+ for _, bucket := range buckets {
+ success += bucket.Success
+ failed += bucket.Failed
+ }
+ return success, failed
+}
+
+func TestGetAPIKeyUsage_GroupsByProviderAndAPIKey(t *testing.T) {
+ t.Setenv("MANAGEMENT_PASSWORD", "")
+ gin.SetMode(gin.TestMode)
+
+ manager := coreauth.NewManager(nil, nil, nil)
+ if _, err := manager.Register(context.Background(), &coreauth.Auth{
+ ID: "codex-auth",
+ Provider: "codex",
+ Attributes: map[string]string{
+ "api_key": "codex-key",
+ "base_url": "https://codex.example.com",
+ },
+ }); err != nil {
+ t.Fatalf("register codex auth: %v", err)
+ }
+ if _, err := manager.Register(context.Background(), &coreauth.Auth{
+ ID: "claude-auth",
+ Provider: "claude",
+ Attributes: map[string]string{
+ "api_key": "claude-key",
+ "base_url": "https://claude.example.com",
+ },
+ }); err != nil {
+ t.Fatalf("register claude auth: %v", err)
+ }
+
+ manager.MarkResult(context.Background(), coreauth.Result{AuthID: "codex-auth", Provider: "codex", Model: "gpt-5", Success: true})
+ manager.MarkResult(context.Background(), coreauth.Result{AuthID: "codex-auth", Provider: "codex", Model: "gpt-5", Success: false})
+ manager.MarkResult(context.Background(), coreauth.Result{AuthID: "claude-auth", Provider: "claude", Model: "claude-4", Success: true})
+
+ h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := httptest.NewRequest(http.MethodGet, "/v0/management/api-key-usage", nil)
+ ginCtx.Request = req
+ h.GetAPIKeyUsage(ginCtx)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var payload map[string]map[string]apiKeyUsageEntry
+ if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
+ t.Fatalf("decode payload: %v", err)
+ }
+
+ codexEntry := payload["codex"]["https://codex.example.com|codex-key"]
+ if codexEntry.Success != 1 || codexEntry.Failed != 1 {
+ t.Fatalf("codex totals = %d/%d, want 1/1", codexEntry.Success, codexEntry.Failed)
+ }
+ if len(codexEntry.RecentRequests) != 20 {
+ t.Fatalf("codex buckets len = %d, want 20", len(codexEntry.RecentRequests))
+ }
+ codexSuccess, codexFailed := sumRecentRequestBuckets(codexEntry.RecentRequests)
+ if codexSuccess != 1 || codexFailed != 1 {
+ t.Fatalf("codex bucket totals = %d/%d, want 1/1", codexSuccess, codexFailed)
+ }
+
+ claudeEntry := payload["claude"]["https://claude.example.com|claude-key"]
+ if claudeEntry.Success != 1 || claudeEntry.Failed != 0 {
+ t.Fatalf("claude totals = %d/%d, want 1/0", claudeEntry.Success, claudeEntry.Failed)
+ }
+ if len(claudeEntry.RecentRequests) != 20 {
+ t.Fatalf("claude buckets len = %d, want 20", len(claudeEntry.RecentRequests))
+ }
+ claudeSuccess, claudeFailed := sumRecentRequestBuckets(claudeEntry.RecentRequests)
+ if claudeSuccess != 1 || claudeFailed != 0 {
+ t.Fatalf("claude bucket totals = %d/%d, want 1/0", claudeSuccess, claudeFailed)
+ }
+}
diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go
index 8f7b8c5e..285b3ae2 100644
--- a/internal/api/handlers/management/auth_files.go
+++ b/internal/api/handlers/management/auth_files.go
@@ -388,6 +388,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
"source": "memory",
"size": int64(0),
}
+ entry["success"] = auth.Success
+ entry["failed"] = auth.Failed
+ entry["recent_requests"] = auth.RecentRequestsSnapshot(time.Now())
if email := authEmail(auth); email != "" {
entry["email"] = email
}
@@ -2395,23 +2398,10 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
finalProjectID := projectID
if responseProjectID != "" {
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
- // Check if this is a free user (gen-lang-client projects or free/legacy tier)
- isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
- strings.EqualFold(tierID, "FREE") ||
- strings.EqualFold(tierID, "LEGACY")
-
- if isFreeUser {
- // For free users, use backend project ID for preview model access
- log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID)
- log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID)
- finalProjectID = responseProjectID
- } else {
- // Pro users: keep requested project ID (original behavior)
- log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
- }
- } else {
- finalProjectID = responseProjectID
+ log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID)
+ log.Infof("Using backend project ID: %s", responseProjectID)
}
+ finalProjectID = responseProjectID
}
storage.ProjectID = strings.TrimSpace(finalProjectID)
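After this hunk, the free-tier versus Pro distinction is gone: whenever onboarding returns a non-empty project ID, that ID is used, and the requested ID is kept only as a fallback. A condensed sketch of the resulting decision, with logging omitted; `resolveProjectID` is a hypothetical name used for illustration and does not exist in the source.

```go
package main

import (
	"fmt"
	"strings"
)

// resolveProjectID is a hypothetical condensation of the updated branch:
// a non-empty backend (response) project ID always wins, otherwise the
// requested ID is kept. The result is trimmed, as in the original code.
func resolveProjectID(requested, response string) string {
	final := requested
	if response != "" {
		final = response
	}
	return strings.TrimSpace(final)
}

func main() {
	fmt.Println(resolveProjectID("my-project", "backend-project")) // backend ID wins
	fmt.Println(resolveProjectID("my-project", ""))                // falls back to requested ID
}
```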
diff --git a/internal/api/handlers/management/auth_files_recent_requests_test.go b/internal/api/handlers/management/auth_files_recent_requests_test.go
new file mode 100644
index 00000000..979040f5
--- /dev/null
+++ b/internal/api/handlers/management/auth_files_recent_requests_test.go
@@ -0,0 +1,94 @@
+package management
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+)
+
+func TestListAuthFiles_IncludesRecentRequestsBuckets(t *testing.T) {
+ t.Setenv("MANAGEMENT_PASSWORD", "")
+ gin.SetMode(gin.TestMode)
+
+ manager := coreauth.NewManager(nil, nil, nil)
+ record := &coreauth.Auth{
+ ID: "runtime-only-auth-1",
+ Provider: "codex",
+ Attributes: map[string]string{
+ "runtime_only": "true",
+ },
+ Metadata: map[string]any{
+ "type": "codex",
+ },
+ }
+ if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
+ t.Fatalf("failed to register auth record: %v", errRegister)
+ }
+
+ h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
+ h.tokenStore = &memoryAuthStore{}
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil)
+ ginCtx.Request = req
+
+ h.ListAuthFiles(ginCtx)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
+ }
+
+ var payload map[string]any
+ if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil {
+ t.Fatalf("failed to decode list payload: %v", errUnmarshal)
+ }
+ filesRaw, ok := payload["files"].([]any)
+ if !ok {
+ t.Fatalf("expected files array, payload: %#v", payload)
+ }
+ if len(filesRaw) != 1 {
+ t.Fatalf("expected 1 auth entry, got %d", len(filesRaw))
+ }
+
+ fileEntry, ok := filesRaw[0].(map[string]any)
+ if !ok {
+ t.Fatalf("expected file entry object, got %#v", filesRaw[0])
+ }
+
+ if _, ok := fileEntry["success"].(float64); !ok {
+ t.Fatalf("expected success number, got %#v", fileEntry["success"])
+ }
+ if _, ok := fileEntry["failed"].(float64); !ok {
+ t.Fatalf("expected failed number, got %#v", fileEntry["failed"])
+ }
+
+ recentRaw, ok := fileEntry["recent_requests"].([]any)
+ if !ok {
+ t.Fatalf("expected recent_requests array, got %#v", fileEntry["recent_requests"])
+ }
+ if len(recentRaw) != 20 {
+ t.Fatalf("expected 20 recent_requests buckets, got %d", len(recentRaw))
+ }
+ for idx, item := range recentRaw {
+ bucket, ok := item.(map[string]any)
+ if !ok {
+ t.Fatalf("expected bucket object at %d, got %#v", idx, item)
+ }
+ if _, ok := bucket["time"].(string); !ok {
+ t.Fatalf("expected bucket time string at %d, got %#v", idx, bucket["time"])
+ }
+ if _, ok := bucket["success"].(float64); !ok {
+ t.Fatalf("expected bucket success number at %d, got %#v", idx, bucket["success"])
+ }
+ if _, ok := bucket["failed"].(float64); !ok {
+ t.Fatalf("expected bucket failed number at %d, got %#v", idx, bucket["failed"])
+ }
+ }
+}
diff --git a/internal/api/handlers/management/handler.go b/internal/api/handlers/management/handler.go
index af11366c..9abc8a5c 100644
--- a/internal/api/handlers/management/handler.go
+++ b/internal/api/handlers/management/handler.go
@@ -15,7 +15,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
- "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"golang.org/x/crypto/bcrypt"
@@ -41,7 +40,6 @@ type Handler struct {
attemptsMu sync.Mutex
failedAttempts map[string]*attemptInfo // keyed by client IP
authManager *coreauth.Manager
- usageStats *usage.RequestStatistics
tokenStore coreauth.Store
localPassword string
allowRemoteOverride bool
@@ -60,7 +58,6 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
configFilePath: configFilePath,
failedAttempts: make(map[string]*attemptInfo),
authManager: manager,
- usageStats: usage.GetRequestStatistics(),
tokenStore: sdkAuth.GetTokenStore(),
allowRemoteOverride: envSecret != "",
envSecret: envSecret,
@@ -124,9 +121,6 @@ func (h *Handler) SetAuthManager(manager *coreauth.Manager) {
h.mu.Unlock()
}
-// SetUsageStatistics allows replacing the usage statistics reference.
-func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats }
-
// SetLocalPassword configures the runtime-local password accepted for localhost requests.
func (h *Handler) SetLocalPassword(password string) { h.localPassword = password }
diff --git a/internal/api/handlers/management/usage.go b/internal/api/handlers/management/usage.go
index 5f794089..dfddf503 100644
--- a/internal/api/handlers/management/usage.go
+++ b/internal/api/handlers/management/usage.go
@@ -2,78 +2,54 @@ package management
import (
"encoding/json"
+ "errors"
"net/http"
- "time"
+ "strconv"
+ "strings"
"github.com/gin-gonic/gin"
- "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
)
-type usageExportPayload struct {
- Version int `json:"version"`
- ExportedAt time.Time `json:"exported_at"`
- Usage usage.StatisticsSnapshot `json:"usage"`
-}
+type usageQueueRecord []byte
-type usageImportPayload struct {
- Version int `json:"version"`
- Usage usage.StatisticsSnapshot `json:"usage"`
-}
-
-// GetUsageStatistics returns the in-memory request statistics snapshot.
-func (h *Handler) GetUsageStatistics(c *gin.Context) {
- var snapshot usage.StatisticsSnapshot
- if h != nil && h.usageStats != nil {
- snapshot = h.usageStats.Snapshot()
+func (r usageQueueRecord) MarshalJSON() ([]byte, error) {
+ if json.Valid(r) {
+ return append([]byte(nil), r...), nil
}
- c.JSON(http.StatusOK, gin.H{
- "usage": snapshot,
- "failed_requests": snapshot.FailureCount,
- })
+ return json.Marshal(string(r))
}
-// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
-func (h *Handler) ExportUsageStatistics(c *gin.Context) {
- var snapshot usage.StatisticsSnapshot
- if h != nil && h.usageStats != nil {
- snapshot = h.usageStats.Snapshot()
- }
- c.JSON(http.StatusOK, usageExportPayload{
- Version: 1,
- ExportedAt: time.Now().UTC(),
- Usage: snapshot,
- })
-}
-
-// ImportUsageStatistics merges a previously exported usage snapshot into memory.
-func (h *Handler) ImportUsageStatistics(c *gin.Context) {
- if h == nil || h.usageStats == nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
+// GetUsageQueue pops up to count of the oldest queued usage records (default 1) and returns them as a JSON array.
+func (h *Handler) GetUsageQueue(c *gin.Context) {
+ if h == nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
- data, err := c.GetRawData()
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
+ count, errCount := parseUsageQueueCount(c.Query("count"))
+ if errCount != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": errCount.Error()})
return
}
- var payload usageImportPayload
- if err := json.Unmarshal(data, &payload); err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
- return
- }
- if payload.Version != 0 && payload.Version != 1 {
- c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
- return
+ items := redisqueue.PopOldest(count)
+ records := make([]usageQueueRecord, 0, len(items))
+ for _, item := range items {
+ records = append(records, usageQueueRecord(append([]byte(nil), item...)))
}
- result := h.usageStats.MergeSnapshot(payload.Usage)
- snapshot := h.usageStats.Snapshot()
- c.JSON(http.StatusOK, gin.H{
- "added": result.Added,
- "skipped": result.Skipped,
- "total_requests": snapshot.TotalRequests,
- "failed_requests": snapshot.FailureCount,
- })
+ c.JSON(http.StatusOK, records)
+}
+
+func parseUsageQueueCount(value string) (int, error) {
+ value = strings.TrimSpace(value)
+ if value == "" {
+ return 1, nil
+ }
+ count, errCount := strconv.Atoi(value)
+ if errCount != nil || count <= 0 {
+ return 0, errors.New("count must be a positive integer")
+ }
+ return count, nil
}
diff --git a/internal/api/handlers/management/usage_test.go b/internal/api/handlers/management/usage_test.go
new file mode 100644
index 00000000..ca46d976
--- /dev/null
+++ b/internal/api/handlers/management/usage_test.go
@@ -0,0 +1,98 @@
+package management
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
+)
+
+func TestGetUsageQueuePopsRequestedRecords(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ withManagementUsageQueue(t, func() {
+ redisqueue.Enqueue([]byte(`{"id":1}`))
+ redisqueue.Enqueue([]byte(`{"id":2}`))
+ redisqueue.Enqueue([]byte(`{"id":3}`))
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
+
+ h := &Handler{}
+ h.GetUsageQueue(ginCtx)
+
+ if rec.Code != http.StatusOK {
+ t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusOK, rec.Body.String())
+ }
+
+ var payload []json.RawMessage
+ if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil {
+ t.Fatalf("unmarshal response: %v", errUnmarshal)
+ }
+ if len(payload) != 2 {
+ t.Fatalf("response records = %d, want 2", len(payload))
+ }
+ requireRecordID(t, payload[0], 1)
+ requireRecordID(t, payload[1], 2)
+
+ remaining := redisqueue.PopOldest(10)
+ if len(remaining) != 1 || string(remaining[0]) != `{"id":3}` {
+ t.Fatalf("remaining queue = %q, want third item only", remaining)
+ }
+ })
+}
+
+func TestGetUsageQueueInvalidCountDoesNotPop(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ withManagementUsageQueue(t, func() {
+ redisqueue.Enqueue([]byte(`{"id":1}`))
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=0", nil)
+
+ h := &Handler{}
+ h.GetUsageQueue(ginCtx)
+
+ if rec.Code != http.StatusBadRequest {
+ t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
+ }
+
+ remaining := redisqueue.PopOldest(10)
+ if len(remaining) != 1 || string(remaining[0]) != `{"id":1}` {
+ t.Fatalf("remaining queue = %q, want original item", remaining)
+ }
+ })
+}
+
+func withManagementUsageQueue(t *testing.T, fn func()) {
+ t.Helper()
+
+ prevQueueEnabled := redisqueue.Enabled()
+ redisqueue.SetEnabled(false)
+ redisqueue.SetEnabled(true)
+
+ defer func() {
+ redisqueue.SetEnabled(false)
+ redisqueue.SetEnabled(prevQueueEnabled)
+ }()
+
+ fn()
+}
+
+func requireRecordID(t *testing.T, raw json.RawMessage, want int) {
+ t.Helper()
+
+ var payload struct {
+ ID int `json:"id"`
+ }
+ if errUnmarshal := json.Unmarshal(raw, &payload); errUnmarshal != nil {
+ t.Fatalf("unmarshal record: %v", errUnmarshal)
+ }
+ if payload.ID != want {
+ t.Fatalf("record id = %d, want %d", payload.ID, want)
+ }
+}
diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go
index 707fe576..895c494e 100644
--- a/internal/api/modules/amp/response_rewriter.go
+++ b/internal/api/modules/amp/response_rewriter.go
@@ -123,6 +123,52 @@ func (rw *ResponseRewriter) Flush() {
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
+// ampCanonicalToolNames maps lowercased tool names to the exact casing expected by the
+// Amp mode tool whitelist (case-sensitive match).
+var ampCanonicalToolNames = map[string]string{
+ "bash": "Bash",
+ "read": "Read",
+ "grep": "Grep",
+ "glob": "glob",
+ "task": "Task",
+ "check": "Check",
+}
+
+// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing.
+// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash")
+// which causes Amp's case-sensitive mode whitelist to reject them.
+func normalizeAmpToolNames(data []byte) []byte {
+ // Non-streaming: content[].name in tool_use blocks
+ for index, block := range gjson.GetBytes(data, "content").Array() {
+ if block.Get("type").String() != "tool_use" {
+ continue
+ }
+ name := block.Get("name").String()
+ if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
+ path := fmt.Sprintf("content.%d.name", index)
+ var err error
+ data, err = sjson.SetBytes(data, path, canonical)
+ if err != nil {
+ log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err)
+ }
+ }
+ }
+
+ // Streaming: content_block.name in content_block_start events
+ if gjson.GetBytes(data, "content_block.type").String() == "tool_use" {
+ name := gjson.GetBytes(data, "content_block.name").String()
+ if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
+ var err error
+ data, err = sjson.SetBytes(data, "content_block.name", canonical)
+ if err != nil {
+ log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err)
+ }
+ }
+ }
+
+ return data
+}
+
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
// in API responses so that the Amp TUI does not crash on P.signature.length.
func ensureAmpSignature(data []byte) []byte {
@@ -179,6 +225,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
data = ensureAmpSignature(data)
+ data = normalizeAmpToolNames(data)
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return data
@@ -278,6 +325,9 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
// Inject empty signature where needed
data = ensureAmpSignature(data)
+ // Normalize tool names to canonical casing
+ data = normalizeAmpToolNames(data)
+
// Rewrite model name
if rw.originalModel != "" {
for _, path := range modelFieldPaths {
diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go
index ac95dfc6..a3a350cb 100644
--- a/internal/api/modules/amp/response_rewriter_test.go
+++ b/internal/api/modules/amp/response_rewriter_test.go
@@ -175,6 +175,57 @@ func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testi
}
}
+func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) {
+ input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`)
+ result := normalizeAmpToolNames(input)
+
+ if !contains(result, []byte(`"name":"Bash"`)) {
+ t.Errorf("expected bash->Bash, got %s", string(result))
+ }
+ if !contains(result, []byte(`"name":"Read"`)) {
+ t.Errorf("expected read->Read, got %s", string(result))
+ }
+ if contains(result, []byte(`"name":"bash"`)) {
+ t.Errorf("expected lowercase bash to be replaced, got %s", string(result))
+ }
+}
+
+func TestNormalizeAmpToolNames_Streaming(t *testing.T) {
+ input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`)
+ result := normalizeAmpToolNames(input)
+
+ if !contains(result, []byte(`"name":"Grep"`)) {
+ t.Errorf("expected grep->Grep in streaming, got %s", string(result))
+ }
+}
+
+func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) {
+ input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
+ result := normalizeAmpToolNames(input)
+
+ if string(result) != string(input) {
+ t.Errorf("expected no modification for correctly-cased tool, got %s", string(result))
+ }
+}
+
+func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) {
+ input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
+ result := normalizeAmpToolNames(input)
+
+ if string(result) != string(input) {
+ t.Errorf("expected glob to remain lowercase, got %s", string(result))
+ }
+}
+
+func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) {
+ input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`)
+ result := normalizeAmpToolNames(input)
+
+ if string(result) != string(input) {
+ t.Errorf("expected no modification for unknown tool, got %s", string(result))
+ }
+}
+
func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {
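Reviewer note: the lookup rule used by `normalizeAmpToolNames` can be illustrated without the JSON handling. This is a simplified sketch under that assumption; `canonicalToolName` is a hypothetical helper and not part of the change. It shows that the map is keyed by the lowercased name, so any casing of a known tool is rewritten, including `Glob`, which becomes lowercase `glob`.

```go
package main

import (
	"fmt"
	"strings"
)

// canonicalToolNames mirrors the map added in the diff.
var canonicalToolNames = map[string]string{
	"bash":  "Bash",
	"read":  "Read",
	"grep":  "Grep",
	"glob":  "glob",
	"task":  "Task",
	"check": "Check",
}

// canonicalToolName is a hypothetical helper: it returns the canonical
// casing for a known tool and leaves unknown names unchanged.
func canonicalToolName(name string) string {
	if canonical, ok := canonicalToolNames[strings.ToLower(name)]; ok {
		return canonical
	}
	return name
}

func main() {
	for _, name := range []string{"bash", "BASH", "Glob", "edit_file"} {
		fmt.Printf("%s -> %s\n", name, canonicalToolName(name))
	}
}
```

The tests in the diff cover lowercase input and the already-correct case; mixed-case input such as `BASH` or `Glob` follows the same path.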
diff --git a/internal/api/server.go b/internal/api/server.go
index 8421357b..487ea571 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -31,7 +31,6 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
- "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
@@ -507,9 +506,6 @@ func (s *Server) registerManagementRoutes() {
mgmt := s.engine.Group("/v0/management")
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
{
- mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
- mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
- mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
@@ -554,6 +550,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
mgmt.DELETE("/api-keys", s.mgmt.DeleteAPIKeys)
+ mgmt.GET("/api-key-usage", s.mgmt.GetAPIKeyUsage)
+ mgmt.GET("/usage-queue", s.mgmt.GetUsageQueue)
mgmt.GET("/gemini-api-key", s.mgmt.GetGeminiKeys)
mgmt.PUT("/gemini-api-key", s.mgmt.PutGeminiKeys)
@@ -1000,7 +998,11 @@ func (s *Server) UpdateClients(cfg *config.Config) {
}
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
- usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
+ redisqueue.SetUsageStatisticsEnabled(cfg.UsageStatisticsEnabled)
+ }
+
+ if oldCfg == nil || oldCfg.RedisUsageQueueRetentionSeconds != cfg.RedisUsageQueueRetentionSeconds {
+ redisqueue.SetRetentionSeconds(cfg.RedisUsageQueueRetentionSeconds)
}
if s.requestLogger != nil && (oldCfg == nil || oldCfg.ErrorLogsMaxFiles != cfg.ErrorLogsMaxFiles) {
diff --git a/internal/api/server_test.go b/internal/api/server_test.go
index db1ef27d..fe37cb72 100644
--- a/internal/api/server_test.go
+++ b/internal/api/server_test.go
@@ -13,6 +13,7 @@ import (
gin "github.com/gin-gonic/gin"
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
@@ -84,6 +85,68 @@ func TestHealthz(t *testing.T) {
})
}
+func TestManagementUsageRequiresManagementAuthAndPopsArray(t *testing.T) {
+ t.Setenv("MANAGEMENT_PASSWORD", "test-management-key")
+
+ prevQueueEnabled := redisqueue.Enabled()
+ redisqueue.SetEnabled(false)
+ t.Cleanup(func() {
+ redisqueue.SetEnabled(false)
+ redisqueue.SetEnabled(prevQueueEnabled)
+ })
+
+ server := newTestServer(t)
+
+ redisqueue.Enqueue([]byte(`{"id":1}`))
+ redisqueue.Enqueue([]byte(`{"id":2}`))
+
+ missingKeyReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
+ missingKeyRR := httptest.NewRecorder()
+ server.engine.ServeHTTP(missingKeyRR, missingKeyReq)
+ if missingKeyRR.Code != http.StatusUnauthorized {
+ t.Fatalf("missing key status = %d, want %d body=%s", missingKeyRR.Code, http.StatusUnauthorized, missingKeyRR.Body.String())
+ }
+
+ legacyReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage?count=2", nil)
+ legacyReq.Header.Set("Authorization", "Bearer test-management-key")
+ legacyRR := httptest.NewRecorder()
+ server.engine.ServeHTTP(legacyRR, legacyReq)
+ if legacyRR.Code != http.StatusNotFound {
+ t.Fatalf("legacy usage status = %d, want %d body=%s", legacyRR.Code, http.StatusNotFound, legacyRR.Body.String())
+ }
+
+ authReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
+ authReq.Header.Set("Authorization", "Bearer test-management-key")
+ authRR := httptest.NewRecorder()
+ server.engine.ServeHTTP(authRR, authReq)
+ if authRR.Code != http.StatusOK {
+ t.Fatalf("authenticated status = %d, want %d body=%s", authRR.Code, http.StatusOK, authRR.Body.String())
+ }
+
+ var payload []json.RawMessage
+ if errUnmarshal := json.Unmarshal(authRR.Body.Bytes(), &payload); errUnmarshal != nil {
+ t.Fatalf("unmarshal response: %v body=%s", errUnmarshal, authRR.Body.String())
+ }
+ if len(payload) != 2 {
+ t.Fatalf("response records = %d, want 2", len(payload))
+ }
+ for i, raw := range payload {
+ var record struct {
+ ID int `json:"id"`
+ }
+ if errUnmarshal := json.Unmarshal(raw, &record); errUnmarshal != nil {
+ t.Fatalf("unmarshal record %d: %v", i, errUnmarshal)
+ }
+ if record.ID != i+1 {
+ t.Fatalf("record %d id = %d, want %d", i, record.ID, i+1)
+ }
+ }
+
+ if remaining := redisqueue.PopOldest(1); len(remaining) != 0 {
+ t.Fatalf("remaining queue = %q, want empty", remaining)
+ }
+}
+
func TestAmpProviderModelRoutes(t *testing.T) {
testCases := []struct {
name string
diff --git a/internal/auth/claude/anthropic_auth.go b/internal/auth/claude/anthropic_auth.go
index 6c770abf..60c71b35 100644
--- a/internal/auth/claude/anthropic_auth.go
+++ b/internal/auth/claude/anthropic_auth.go
@@ -6,15 +6,18 @@ package claude
import (
"context"
"encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
+ "sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
+ "golang.org/x/sync/singleflight"
)
// OAuth configuration constants for Claude/Anthropic
@@ -23,8 +26,94 @@ const (
TokenURL = "https://api.anthropic.com/v1/oauth/token"
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
RedirectURI = "http://localhost:54545/callback"
+
+ claudeRefreshMinBackoff = 5 * time.Second
+ claudeRefreshMaxBackoff = 5 * time.Minute
)
+var (
+ claudeRefreshGroup singleflight.Group
+ claudeRefreshMu sync.Mutex
+ claudeRefreshBlock = make(map[string]time.Time)
+)
+
+type refreshHTTPError struct {
+ status int
+ message string
+ retryable bool
+}
+
+func (e *refreshHTTPError) Error() string {
+ return fmt.Sprintf("token refresh failed with status %d: %s", e.status, e.message)
+}
+
+func (e *refreshHTTPError) Retryable() bool {
+ return e != nil && e.retryable
+}
+
+func resetClaudeRefreshState() {
+ claudeRefreshMu.Lock()
+ defer claudeRefreshMu.Unlock()
+ claudeRefreshBlock = make(map[string]time.Time)
+ claudeRefreshGroup = singleflight.Group{}
+}
+
+func claudeRefreshBlockedUntil(refreshToken string) time.Time {
+ claudeRefreshMu.Lock()
+ defer claudeRefreshMu.Unlock()
+ return claudeRefreshBlock[refreshToken]
+}
+
+func setClaudeRefreshBlockedUntil(refreshToken string, until time.Time) {
+ claudeRefreshMu.Lock()
+ defer claudeRefreshMu.Unlock()
+ claudeRefreshBlock[refreshToken] = until
+}
+
+func clearClaudeRefreshBlockedUntil(refreshToken string) {
+ claudeRefreshMu.Lock()
+ defer claudeRefreshMu.Unlock()
+ delete(claudeRefreshBlock, refreshToken)
+}
+
+func clampClaudeRefreshBackoff(d time.Duration) time.Duration {
+ if d < claudeRefreshMinBackoff {
+ return claudeRefreshMinBackoff
+ }
+ if d > claudeRefreshMaxBackoff {
+ return claudeRefreshMaxBackoff
+ }
+ return d
+}
+
+func parseClaudeRetryAfter(resp *http.Response) time.Duration {
+ if resp == nil {
+ return claudeRefreshMinBackoff
+ }
+ if raw := strings.TrimSpace(resp.Header.Get("Retry-After")); raw != "" {
+ if seconds, err := time.ParseDuration(raw + "s"); err == nil {
+ return clampClaudeRefreshBackoff(seconds)
+ }
+ if when, err := http.ParseTime(raw); err == nil {
+ return clampClaudeRefreshBackoff(time.Until(when))
+ }
+ }
+ if raw := strings.TrimSpace(resp.Header.Get("Retry-After-Ms")); raw != "" {
+ if ms, err := time.ParseDuration(raw + "ms"); err == nil {
+ return clampClaudeRefreshBackoff(ms)
+ }
+ }
+ return claudeRefreshMinBackoff
+}
+
+func isClaudeRefreshRetryable(err error) bool {
+ var httpErr *refreshHTTPError
+ if errors.As(err, &httpErr) {
+ return httpErr.Retryable()
+ }
+ return true
+}
+
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
// It contains access token, refresh token, and associated user/organization information.
type tokenResponse struct {
@@ -242,6 +331,35 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
if refreshToken == "" {
return nil, fmt.Errorf("refresh token is required")
}
+ if blockedUntil := claudeRefreshBlockedUntil(refreshToken); blockedUntil.After(time.Now()) {
+ return nil, &refreshHTTPError{
+ status: http.StatusTooManyRequests,
+ message: fmt.Sprintf("refresh temporarily blocked until %s", blockedUntil.Format(time.RFC3339)),
+ retryable: false,
+ }
+ }
+
+ result, err, _ := claudeRefreshGroup.Do(refreshToken, func() (interface{}, error) {
+ return o.refreshTokensSingleFlight(context.WithoutCancel(ctx), refreshToken)
+ })
+ if err != nil {
+ return nil, err
+ }
+ tokenData, ok := result.(*ClaudeTokenData)
+ if !ok || tokenData == nil {
+ return nil, fmt.Errorf("token refresh failed: invalid single-flight result")
+ }
+ return tokenData, nil
+}
+
+func (o *ClaudeAuth) refreshTokensSingleFlight(ctx context.Context, refreshToken string) (*ClaudeTokenData, error) {
+ if blockedUntil := claudeRefreshBlockedUntil(refreshToken); blockedUntil.After(time.Now()) {
+ return nil, &refreshHTTPError{
+ status: http.StatusTooManyRequests,
+ message: fmt.Sprintf("refresh temporarily blocked until %s", blockedUntil.Format(time.RFC3339)),
+ retryable: false,
+ }
+ }
reqBody := map[string]interface{}{
"client_id": ClientID,
@@ -276,7 +394,17 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
}
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
+ message := string(body)
+ if resp.StatusCode == http.StatusTooManyRequests {
+ retryAfter := parseClaudeRetryAfter(resp)
+ setClaudeRefreshBlockedUntil(refreshToken, time.Now().Add(retryAfter))
+ return nil, &refreshHTTPError{status: resp.StatusCode, message: message, retryable: false}
+ }
+ return nil, &refreshHTTPError{
+ status: resp.StatusCode,
+ message: message,
+ retryable: resp.StatusCode >= http.StatusInternalServerError,
+ }
}
// log.Debugf("Token response: %s", string(body))
@@ -287,6 +415,8 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
}
// Create token data
+ clearClaudeRefreshBlockedUntil(refreshToken)
+
return &ClaudeTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
@@ -348,6 +478,9 @@ func (o *ClaudeAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken st
lastErr = err
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
+ if !isClaudeRefreshRetryable(err) {
+ break
+ }
}
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
diff --git a/internal/auth/claude/anthropic_auth_test.go b/internal/auth/claude/anthropic_auth_test.go
new file mode 100644
index 00000000..0b14d083
--- /dev/null
+++ b/internal/auth/claude/anthropic_auth_test.go
@@ -0,0 +1,123 @@
+package claude
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+type roundTripFunc func(*http.Request) (*http.Response, error)
+
+func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+ return f(req)
+}
+
+func TestRefreshTokensWithRetry_429BlocksImmediateReplay(t *testing.T) {
+ resetClaudeRefreshState()
+ defer resetClaudeRefreshState()
+
+ var calls int32
+ auth := &ClaudeAuth{
+ httpClient: &http.Client{
+ Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ atomic.AddInt32(&calls, 1)
+ return &http.Response{
+ StatusCode: http.StatusTooManyRequests,
+ Body: io.NopCloser(strings.NewReader(`{"error":"rate_limited"}`)),
+ Header: http.Header{"Retry-After": []string{"60"}},
+ Request: req,
+ }, nil
+ }),
+ },
+ }
+
+ _, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
+ if err == nil {
+ t.Fatalf("expected 429 refresh error")
+ }
+ if !strings.Contains(err.Error(), "status 429") {
+ t.Fatalf("expected status 429 in error, got %v", err)
+ }
+ if got := atomic.LoadInt32(&calls); got != 1 {
+ t.Fatalf("expected 1 refresh attempt after 429, got %d", got)
+ }
+
+ _, err = auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
+ if err == nil {
+ t.Fatalf("expected immediate blocked refresh error")
+ }
+ if got := atomic.LoadInt32(&calls); got != 1 {
+ t.Fatalf("expected blocked retry to avoid a second refresh call, got %d attempts", got)
+ }
+ if blockedUntil := claudeRefreshBlockedUntil("dummy_refresh_token"); !blockedUntil.After(time.Now()) {
+ t.Fatalf("expected blocked-until timestamp to be set, got %v", blockedUntil)
+ }
+}
+
+func TestRefreshTokens_DeduplicatesConcurrentRefresh(t *testing.T) {
+ resetClaudeRefreshState()
+ defer resetClaudeRefreshState()
+
+ var calls int32
+ started := make(chan struct{})
+ release := make(chan struct{})
+ var once sync.Once
+
+ auth := &ClaudeAuth{
+ httpClient: &http.Client{
+ Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
+ atomic.AddInt32(&calls, 1)
+ once.Do(func() { close(started) })
+ <-release
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader(`{
+ "access_token":"new-access",
+ "refresh_token":"new-refresh",
+ "token_type":"Bearer",
+ "expires_in":3600,
+ "account":{"email_address":"shared@example.com"}
+ }`)),
+ Header: make(http.Header),
+ Request: req,
+ }, nil
+ }),
+ },
+ }
+
+ results := make(chan *ClaudeTokenData, 2)
+ errs := make(chan error, 2)
+ runRefresh := func() {
+ td, err := auth.RefreshTokens(context.Background(), "shared-refresh-token")
+ results <- td
+ errs <- err
+ }
+
+ go runRefresh()
+ go runRefresh()
+
+ <-started
+ time.Sleep(20 * time.Millisecond)
+ if got := atomic.LoadInt32(&calls); got != 1 {
+ t.Fatalf("expected concurrent refresh to share a single upstream call, got %d", got)
+ }
+ close(release)
+
+ for i := 0; i < 2; i++ {
+ if err := <-errs; err != nil {
+ t.Fatalf("expected refresh to succeed, got %v", err)
+ }
+ td := <-results
+ if td == nil || td.AccessToken != "new-access" {
+ t.Fatalf("expected refreshed access token, got %#v", td)
+ }
+ }
+ if got := atomic.LoadInt32(&calls); got != 1 {
+ t.Fatalf("expected exactly 1 upstream refresh call, got %d", got)
+ }
+}
diff --git a/internal/cmd/login.go b/internal/cmd/login.go
index 16af718e..22404dac 100644
--- a/internal/cmd/login.go
+++ b/internal/cmd/login.go
@@ -333,42 +333,10 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
finalProjectID := projectID
if responseProjectID != "" {
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
- // Check if this is a free user (gen-lang-client projects or free/legacy tier)
- isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
- strings.EqualFold(tierID, "FREE") ||
- strings.EqualFold(tierID, "LEGACY")
-
- if isFreeUser {
- // Interactive prompt for free users
- fmt.Printf("\nGoogle returned a different project ID:\n")
- fmt.Printf(" Requested (frontend): %s\n", projectID)
- fmt.Printf(" Returned (backend): %s\n\n", responseProjectID)
- fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n")
- fmt.Printf(" This is normal for free tier users.\n\n")
- fmt.Printf("Which project ID would you like to use?\n")
- fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID)
- fmt.Printf(" [2] Frontend: %s\n\n", projectID)
- fmt.Printf("Enter choice [1]: ")
-
- reader := bufio.NewReader(os.Stdin)
- choice, _ := reader.ReadString('\n')
- choice = strings.TrimSpace(choice)
-
- if choice == "2" {
- log.Infof("Using frontend project ID: %s", projectID)
- fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
- finalProjectID = projectID
- } else {
- log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
- finalProjectID = responseProjectID
- }
- } else {
- // Pro users: keep requested project ID (original behavior)
- log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
- }
- } else {
- finalProjectID = responseProjectID
+ log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID)
+ log.Infof("Using backend project ID: %s", responseProjectID)
}
+ finalProjectID = responseProjectID
}
storage.ProjectID = strings.TrimSpace(finalProjectID)
diff --git a/internal/config/config.go b/internal/config/config.go
index 39c91127..46ce4f50 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -65,6 +65,11 @@ type Config struct {
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
+ // RedisUsageQueueRetentionSeconds controls how long (in seconds) usage queue items
+ // are retained in memory for the Redis RESP interface (LPOP/RPOP).
+ // Default: 60 (also used when the value is <= 0). Values above 3600 are clamped to 3600.
+ RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"`
+
// DisableCooling disables quota cooldown scheduling when true.
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`
@@ -609,6 +614,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.LogsMaxTotalSizeMB = 0
cfg.ErrorLogsMaxFiles = 10
cfg.UsageStatisticsEnabled = false
+ cfg.RedisUsageQueueRetentionSeconds = 60
cfg.DisableCooling = false
cfg.DisableImageGeneration = DisableImageGenerationOff
cfg.Pprof.Enable = false
@@ -671,6 +677,13 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.ErrorLogsMaxFiles = 10
}
+ if cfg.RedisUsageQueueRetentionSeconds <= 0 {
+ cfg.RedisUsageQueueRetentionSeconds = 60
+ } else if cfg.RedisUsageQueueRetentionSeconds > 3600 {
+ log.WithField("value", cfg.RedisUsageQueueRetentionSeconds).Warn("redis-usage-queue-retention-seconds too large; clamping to 3600")
+ cfg.RedisUsageQueueRetentionSeconds = 3600
+ }
+
if cfg.MaxRetryCredentials < 0 {
cfg.MaxRetryCredentials = 0
}
diff --git a/internal/logging/requestmeta.go b/internal/logging/requestmeta.go
new file mode 100644
index 00000000..a28d7c62
--- /dev/null
+++ b/internal/logging/requestmeta.go
@@ -0,0 +1,62 @@
+package logging
+
+import (
+ "context"
+ "sync/atomic"
+)
+
+type endpointKey struct{}
+type responseStatusKey struct{}
+
+type responseStatusHolder struct {
+ status atomic.Int32
+}
+
+func WithEndpoint(ctx context.Context, endpoint string) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return context.WithValue(ctx, endpointKey{}, endpoint)
+}
+
+func GetEndpoint(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+ if endpoint, ok := ctx.Value(endpointKey{}).(string); ok {
+ return endpoint
+ }
+ return ""
+}
+
+func WithResponseStatusHolder(ctx context.Context) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ if holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder); ok && holder != nil {
+ return ctx
+ }
+ return context.WithValue(ctx, responseStatusKey{}, &responseStatusHolder{})
+}
+
+func SetResponseStatus(ctx context.Context, status int) {
+ if ctx == nil || status <= 0 {
+ return
+ }
+ holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder)
+ if !ok || holder == nil {
+ return
+ }
+ holder.status.Store(int32(status))
+}
+
+func GetResponseStatus(ctx context.Context) int {
+ if ctx == nil {
+ return 0
+ }
+ holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder)
+ if !ok || holder == nil {
+ return 0
+ }
+ return int(holder.status.Load())
+}
diff --git a/internal/misc/header_utils.go b/internal/misc/header_utils.go
index 5752a269..ac022a96 100644
--- a/internal/misc/header_utils.go
+++ b/internal/misc/header_utils.go
@@ -12,7 +12,7 @@ import (
const (
// GeminiCLIVersion is the version string reported in the User-Agent for upstream requests.
- GeminiCLIVersion = "0.31.0"
+ GeminiCLIVersion = "0.34.0"
// GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream.
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
@@ -46,7 +46,7 @@ func GeminiCLIUserAgent(model string) string {
if model == "" {
model = "unknown"
}
- return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
+ return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s; terminal)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
}
// ScrubProxyAndFingerprintHeaders removes all headers that could reveal
diff --git a/internal/redisqueue/plugin.go b/internal/redisqueue/plugin.go
index a805e5da..b33bc8fd 100644
--- a/internal/redisqueue/plugin.go
+++ b/internal/redisqueue/plugin.go
@@ -3,13 +3,10 @@ package redisqueue
import (
"context"
"encoding/json"
- "net/http"
"strings"
"time"
- "github.com/gin-gonic/gin"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
- internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
@@ -23,7 +20,7 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
if p == nil {
return
}
- if !Enabled() || !internalusage.StatisticsEnabled() {
+ if !Enabled() || !UsageStatisticsEnabled() {
return
}
@@ -36,6 +33,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
if modelName == "" {
modelName = "unknown"
}
+ aliasName := strings.TrimSpace(record.Alias)
+ if aliasName == "" {
+ aliasName = modelName
+ }
provider := strings.TrimSpace(record.Provider)
if provider == "" {
provider = "unknown"
@@ -46,13 +47,8 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
}
apiKey := strings.TrimSpace(record.APIKey)
requestID := strings.TrimSpace(internallogging.GetRequestID(ctx))
- if requestID == "" {
- if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
- requestID = strings.TrimSpace(internallogging.GetGinRequestID(ginCtx))
- }
- }
- tokens := internalusage.TokenStats{
+ tokens := tokenStats{
InputTokens: record.Detail.InputTokens,
OutputTokens: record.Detail.OutputTokens,
ReasoningTokens: record.Detail.ReasoningTokens,
@@ -71,7 +67,7 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
failed = !resolveSuccess(ctx)
}
- detail := internalusage.RequestDetail{
+ detail := requestDetail{
Timestamp: timestamp,
LatencyMs: record.Latency.Milliseconds(),
Source: record.Source,
@@ -81,9 +77,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
}
payload, err := json.Marshal(queuedUsageDetail{
- RequestDetail: detail,
+ requestDetail: detail,
Provider: provider,
Model: modelName,
+ Alias: aliasName,
Endpoint: resolveEndpoint(ctx),
AuthType: authType,
APIKey: apiKey,
@@ -96,50 +93,43 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
}
type queuedUsageDetail struct {
- internalusage.RequestDetail
+ requestDetail
Provider string `json:"provider"`
Model string `json:"model"`
+ Alias string `json:"alias"`
Endpoint string `json:"endpoint"`
AuthType string `json:"auth_type"`
APIKey string `json:"api_key"`
RequestID string `json:"request_id"`
}
+type requestDetail struct {
+ Timestamp time.Time `json:"timestamp"`
+ LatencyMs int64 `json:"latency_ms"`
+ Source string `json:"source"`
+ AuthIndex string `json:"auth_index"`
+ Tokens tokenStats `json:"tokens"`
+ Failed bool `json:"failed"`
+}
+
+type tokenStats struct {
+ InputTokens int64 `json:"input_tokens"`
+ OutputTokens int64 `json:"output_tokens"`
+ ReasoningTokens int64 `json:"reasoning_tokens"`
+ CachedTokens int64 `json:"cached_tokens"`
+ TotalTokens int64 `json:"total_tokens"`
+}
+
func resolveSuccess(ctx context.Context) bool {
- if ctx == nil {
- return true
- }
- ginCtx, ok := ctx.Value("gin").(*gin.Context)
- if !ok || ginCtx == nil {
- return true
- }
- status := ginCtx.Writer.Status()
+ status := internallogging.GetResponseStatus(ctx)
if status == 0 {
return true
}
- return status < http.StatusBadRequest
+ return status < httpStatusBadRequest
}
func resolveEndpoint(ctx context.Context) string {
- if ctx == nil {
- return ""
- }
- ginCtx, ok := ctx.Value("gin").(*gin.Context)
- if !ok || ginCtx == nil || ginCtx.Request == nil {
- return ""
- }
-
- path := strings.TrimSpace(ginCtx.FullPath())
- if path == "" && ginCtx.Request.URL != nil {
- path = strings.TrimSpace(ginCtx.Request.URL.Path)
- }
- if path == "" {
- return ""
- }
-
- method := strings.TrimSpace(ginCtx.Request.Method)
- if method == "" {
- return path
- }
- return method + " " + path
+ return strings.TrimSpace(internallogging.GetEndpoint(ctx))
}
+
+const httpStatusBadRequest = 400
diff --git a/internal/redisqueue/plugin_test.go b/internal/redisqueue/plugin_test.go
index 907b8aee..8dcade90 100644
--- a/internal/redisqueue/plugin_test.go
+++ b/internal/redisqueue/plugin_test.go
@@ -10,20 +10,21 @@ import (
"github.com/gin-gonic/gin"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
- internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
withEnabledQueue(t, func() {
- ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK)
- internallogging.SetGinRequestID(ginCtx, "gin-request-id-ignored")
- ctx := context.WithValue(internallogging.WithRequestID(context.Background(), "ctx-request-id"), "gin", ginCtx)
+ ctx := internallogging.WithRequestID(context.Background(), "ctx-request-id")
+ ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
+ ctx = internallogging.WithResponseStatusHolder(ctx)
+ internallogging.SetResponseStatus(ctx, http.StatusOK)
plugin := &usageQueuePlugin{}
plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4",
+ Alias: "client-gpt",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
@@ -40,6 +41,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4")
+ requireStringField(t, payload, "alias", "client-gpt")
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
requireStringField(t, payload, "auth_type", "apikey")
requireStringField(t, payload, "request_id", "ctx-request-id")
@@ -49,14 +51,16 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t *testing.T) {
withEnabledQueue(t, func() {
- ginCtx := newTestGinContext(t, http.MethodGet, "/v1/responses", http.StatusInternalServerError)
- internallogging.SetGinRequestID(ginCtx, "gin-request-id")
- ctx := context.WithValue(context.Background(), "gin", ginCtx)
+ ctx := internallogging.WithRequestID(context.Background(), "gin-request-id")
+ ctx = internallogging.WithEndpoint(ctx, "GET /v1/responses")
+ ctx = internallogging.WithResponseStatusHolder(ctx)
+ internallogging.SetResponseStatus(ctx, http.StatusInternalServerError)
plugin := &usageQueuePlugin{}
plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4-mini",
+ Alias: "client-mini",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
@@ -73,6 +77,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t
payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4-mini")
+ requireStringField(t, payload, "alias", "client-mini")
requireStringField(t, payload, "endpoint", "GET /v1/responses")
requireStringField(t, payload, "auth_type", "apikey")
requireStringField(t, payload, "request_id", "gin-request-id")
@@ -80,20 +85,63 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t
})
}
+func TestUsageQueuePluginAsyncIgnoresRecycledGinContext(t *testing.T) {
+ withEnabledQueue(t, func() {
+ ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK)
+ ctx := context.WithValue(context.Background(), "gin", ginCtx)
+ ctx = internallogging.WithRequestID(ctx, "ctx-request-id")
+ ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
+ ctx = internallogging.WithResponseStatusHolder(ctx)
+ internallogging.SetResponseStatus(ctx, http.StatusInternalServerError)
+
+ mgr := coreusage.NewManager(16)
+ defer mgr.Stop()
+
+ mgr.Register(pluginFunc(func(_ context.Context, _ coreusage.Record) {
+ ginCtx.Request = httptest.NewRequest(http.MethodGet, "http://example.com/v1/responses", nil)
+ ginCtx.Status(http.StatusOK)
+ }))
+ mgr.Register(&usageQueuePlugin{})
+
+ mgr.Publish(ctx, coreusage.Record{
+ Provider: "openai",
+ Model: "gpt-5.4",
+ Alias: "client-gpt",
+ APIKey: "test-key",
+ AuthIndex: "0",
+ AuthType: "apikey",
+ Source: "user@example.com",
+ RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC),
+ Latency: 1500 * time.Millisecond,
+ Detail: coreusage.Detail{
+ InputTokens: 10,
+ OutputTokens: 20,
+ TotalTokens: 30,
+ },
+ })
+
+ payload := waitForSinglePayload(t, 2*time.Second)
+ requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
+ requireStringField(t, payload, "alias", "client-gpt")
+ requireStringField(t, payload, "request_id", "ctx-request-id")
+ requireBoolField(t, payload, "failed", true)
+ })
+}
+
func withEnabledQueue(t *testing.T, fn func()) {
t.Helper()
prevQueueEnabled := Enabled()
- prevStatsEnabled := internalusage.StatisticsEnabled()
+ prevUsageEnabled := UsageStatisticsEnabled()
SetEnabled(false)
SetEnabled(true)
- internalusage.SetStatisticsEnabled(true)
+ SetUsageStatisticsEnabled(true)
defer func() {
SetEnabled(false)
SetEnabled(prevQueueEnabled)
- internalusage.SetStatisticsEnabled(prevStatsEnabled)
+ SetUsageStatisticsEnabled(prevUsageEnabled)
}()
fn()
@@ -127,6 +175,29 @@ func popSinglePayload(t *testing.T) map[string]json.RawMessage {
return payload
}
+func waitForSinglePayload(t *testing.T, timeout time.Duration) map[string]json.RawMessage {
+ t.Helper()
+
+ deadline := time.Now().Add(timeout)
+ for time.Now().Before(deadline) {
+ items := PopOldest(10)
+ if len(items) == 0 {
+ time.Sleep(10 * time.Millisecond)
+ continue
+ }
+ if len(items) != 1 {
+ t.Fatalf("PopOldest() items = %d, want 1", len(items))
+ }
+ var payload map[string]json.RawMessage
+ if err := json.Unmarshal(items[0], &payload); err != nil {
+ t.Fatalf("unmarshal payload: %v", err)
+ }
+ return payload
+ }
+ t.Fatalf("timeout waiting for queued payload")
+ return nil
+}
+
func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, want string) {
t.Helper()
@@ -143,6 +214,12 @@ func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, w
}
}
+type pluginFunc func(context.Context, coreusage.Record)
+
+func (fn pluginFunc) HandleUsage(ctx context.Context, record coreusage.Record) {
+ fn(ctx, record)
+}
+
func requireBoolField(t *testing.T, payload map[string]json.RawMessage, key string, want bool) {
t.Helper()
diff --git a/internal/redisqueue/queue.go b/internal/redisqueue/queue.go
index 8a4b6742..2fea5839 100644
--- a/internal/redisqueue/queue.go
+++ b/internal/redisqueue/queue.go
@@ -6,7 +6,10 @@ import (
"time"
)
-const retentionWindow = time.Minute
+const (
+ defaultRetentionSeconds int64 = 60
+ maxRetentionSeconds int64 = 3600
+)
type queueItem struct {
enqueuedAt time.Time
@@ -20,10 +23,15 @@ type queue struct {
}
var (
- enabled atomic.Bool
- global queue
+ enabled atomic.Bool
+ retentionSeconds atomic.Int64
+ global queue
)
+func init() {
+ retentionSeconds.Store(defaultRetentionSeconds)
+}
+
func SetEnabled(value bool) {
enabled.Store(value)
if !value {
@@ -35,6 +43,16 @@ func Enabled() bool {
return enabled.Load()
}
+func SetRetentionSeconds(value int) {
+ normalized := int64(value)
+ if normalized <= 0 {
+ normalized = defaultRetentionSeconds
+ } else if normalized > maxRetentionSeconds {
+ normalized = maxRetentionSeconds
+ }
+ retentionSeconds.Store(normalized)
+}
+
func Enqueue(payload []byte) {
if !Enabled() {
return
@@ -110,7 +128,11 @@ func (q *queue) pruneLocked(now time.Time) {
return
}
- cutoff := now.Add(-retentionWindow)
+ windowSeconds := retentionSeconds.Load()
+ if windowSeconds <= 0 {
+ windowSeconds = defaultRetentionSeconds
+ }
+ cutoff := now.Add(-time.Duration(windowSeconds) * time.Second)
for q.head < len(q.items) && q.items[q.head].enqueuedAt.Before(cutoff) {
q.head++
}
diff --git a/internal/redisqueue/usage_toggle.go b/internal/redisqueue/usage_toggle.go
new file mode 100644
index 00000000..dddbeca6
--- /dev/null
+++ b/internal/redisqueue/usage_toggle.go
@@ -0,0 +1,16 @@
+package redisqueue
+
+import "sync/atomic"
+
+var usageStatisticsEnabled atomic.Bool
+
+func init() {
+ usageStatisticsEnabled.Store(true)
+}
+
+// SetUsageStatisticsEnabled toggles whether usage records are enqueued into the redisqueue payload buffer.
+// This is controlled by the config field `usage-statistics-enabled` and the corresponding management API.
+func SetUsageStatisticsEnabled(enabled bool) { usageStatisticsEnabled.Store(enabled) }
+
+// UsageStatisticsEnabled reports whether the usage queue plugin should publish records.
+func UsageStatisticsEnabled() bool { return usageStatisticsEnabled.Load() }
diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go
index 73491d82..37e85377 100644
--- a/internal/runtime/executor/aistudio_executor.go
+++ b/internal/runtime/executor/aistudio_executor.go
@@ -285,7 +285,10 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
if event.Err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}:
+ case <-ctx.Done():
+ }
return false
}
switch event.Type {
@@ -303,7 +306,11 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param)
for i := range lines {
- out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}:
+ case <-ctx.Done():
+ return false
+ }
}
break
}
@@ -319,14 +326,21 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param)
for i := range lines {
- out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}:
+ case <-ctx.Done():
+ return false
+ }
}
reporter.Publish(ctx, helps.ParseGeminiUsage(event.Payload))
return false
case wsrelay.MessageTypeError:
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}:
+ case <-ctx.Done():
+ }
return false
}
return true
diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go
index 280c799a..418ed7b1 100644
--- a/internal/runtime/executor/antigravity_executor.go
+++ b/internal/runtime/executor/antigravity_executor.go
@@ -1357,17 +1357,28 @@ attemptLoop:
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param)
for i := range chunks {
- out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), &param)
for i := range tail {
- out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}:
+ case <-ctx.Done():
+ return
+ }
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
+ case <-ctx.Done():
+ }
} else {
reporter.EnsurePublished(ctx)
}
diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go
index 66432ac4..b22f4e44 100644
--- a/internal/runtime/executor/claude_executor.go
+++ b/internal/runtime/executor/claude_executor.go
@@ -65,14 +65,13 @@ var oauthToolRenameMap = map[string]string{
"notebookedit": "NotebookEdit",
}
-// oauthToolRenameReverseMap is the inverse of oauthToolRenameMap for response decoding.
-var oauthToolRenameReverseMap = func() map[string]string {
- m := make(map[string]string, len(oauthToolRenameMap))
- for k, v := range oauthToolRenameMap {
- m[v] = k
- }
- return m
-}()
+// The reverse map is now computed per-request in remapOAuthToolNames so that
+// only names the client actually caused us to rewrite are restored on the
+// response. A global reverse map — as used previously — corrupted responses
+// for clients that sent mixed casing (e.g. Amp CLI sends `Bash` TitleCase
+// alongside `glob` lowercase; the request flagged renames via `glob→Glob`,
+// then the global reverse map incorrectly rewrote every `Bash` in the
+// response to `bash`, causing Amp to reject the tool_use as unknown).
// oauthToolsToRemove lists tool names that must be stripped from OAuth requests
// even after remapping. Currently empty — all tools are mapped instead of removed.
@@ -192,15 +191,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
bodyForTranslation := body
bodyForUpstream := body
oauthToken := isClaudeOAuthToken(apiKey)
- oauthToolNamesRemapped := false
- if oauthToken && !auth.ToolPrefixDisabled() {
- bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
- }
- // Remap third-party tool names to Claude Code equivalents and remove
- // tools without official counterparts. This prevents Anthropic from
- // fingerprinting the request as third-party via tool naming patterns.
+ var oauthToolNamesReverseMap map[string]string
if oauthToken {
- bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
+ bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled())
}
// Enable cch signing by default for OAuth tokens (not just experimental flag).
// Claude Code always computes cch; missing or invalid cch is a detectable fingerprint.
@@ -285,6 +278,10 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if stream {
+ if errValidate := validateClaudeStreamingResponse(data); errValidate != nil {
+ helps.RecordAPIResponseError(ctx, e.cfg, errValidate)
+ return resp, errValidate
+ }
lines := bytes.Split(data, []byte("\n"))
for _, line := range lines {
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
@@ -294,13 +291,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.Publish(ctx, helps.ParseClaudeUsage(data))
}
- if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
- data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
- }
- // Reverse the OAuth tool name remap so the downstream client sees original names.
- if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
- data = reverseRemapOAuthToolNames(data)
- }
+ data = restoreClaudeOAuthToolNamesFromResponse(data, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
var param any
out := sdktranslator.TranslateNonStream(
ctx,
@@ -375,15 +366,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
bodyForTranslation := body
bodyForUpstream := body
oauthToken := isClaudeOAuthToken(apiKey)
- oauthToolNamesRemapped := false
- if oauthToken && !auth.ToolPrefixDisabled() {
- bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
- }
- // Remap third-party tool names to Claude Code equivalents and remove
- // tools without official counterparts. This prevents Anthropic from
- // fingerprinting the request as third-party via tool naming patterns.
+ var oauthToolNamesReverseMap map[string]string
if oauthToken {
- bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
+ bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled())
}
// Enable cch signing by default for OAuth tokens (not just experimental flag).
if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) {
@@ -474,22 +459,24 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
- if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
- line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
- }
- if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
- line = reverseRemapOAuthToolNamesFromStreamLine(line)
- }
+ line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
// Forward the line as-is to preserve SSE format
cloned := make([]byte, len(line)+1)
copy(cloned, line)
cloned[len(line)] = '\n'
- out <- cliproxyexecutor.StreamChunk{Payload: cloned}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: cloned}:
+ case <-ctx.Done():
+ return
+ }
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
+ case <-ctx.Done():
+ }
}
return
}
@@ -504,12 +491,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
- if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
- line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
- }
- if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
- line = reverseRemapOAuthToolNamesFromStreamLine(line)
- }
+ line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
chunks := sdktranslator.TranslateStream(
ctx,
to,
@@ -521,18 +503,83 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
&param,
)
for i := range chunks {
- out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
+ case <-ctx.Done():
+ }
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
+func validateClaudeStreamingResponse(data []byte) error {
+ scanner := bufio.NewScanner(bytes.NewReader(data))
+ scanner.Buffer(nil, 52_428_800)
+
+ hasData := false
+ hasMessageStart := false
+ hasMessageDelta := false
+
+ for scanner.Scan() {
+ line := bytes.TrimSpace(scanner.Bytes())
+ if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
+ continue
+ }
+ payload := bytes.TrimSpace(line[len("data:"):])
+ if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
+ continue
+ }
+ hasData = true
+ if !gjson.ValidBytes(payload) {
+ return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned malformed stream data"}
+ }
+
+ root := gjson.ParseBytes(payload)
+ switch root.Get("type").String() {
+ case "error":
+ message := strings.TrimSpace(root.Get("error.message").String())
+ if message == "" {
+ message = strings.TrimSpace(root.Get("error.type").String())
+ }
+ if message == "" {
+ message = "unknown upstream error"
+ }
+ return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned error event: " + message}
+ case "message_start":
+ message := root.Get("message")
+ if strings.TrimSpace(message.Get("id").String()) == "" || strings.TrimSpace(message.Get("model").String()) == "" {
+ return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream message_start is missing id or model"}
+ }
+ hasMessageStart = true
+ case "message_delta":
+ hasMessageDelta = true
+ }
+ }
+ if errScan := scanner.Err(); errScan != nil {
+ return errScan
+ }
+ if !hasData {
+ return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned empty stream response"}
+ }
+ if !hasMessageStart {
+ return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response is missing message_start"}
+ }
+ if !hasMessageDelta {
+ return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response ended before message completion"}
+ }
+ return nil
+}
+
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
@@ -559,12 +606,8 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
- if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
- body = applyClaudeToolPrefix(body, claudeToolPrefix)
- }
- // Remap tool names for OAuth token requests to avoid third-party fingerprinting.
if isClaudeOAuthToken(apiKey) {
- body, _ = remapOAuthToolNames(body)
+ body, _ = prepareClaudeOAuthToolNamesForUpstream(body, claudeToolPrefix, auth.ToolPrefixDisabled())
}
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
@@ -661,7 +704,7 @@ func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (
return auth, nil
}
svc := claudeauth.NewClaudeAuthWithProxyURL(e.cfg, auth.ProxyURL)
- td, err := svc.RefreshTokens(ctx, refreshToken)
+ td, err := svc.RefreshTokensWithRetry(ctx, refreshToken, 3)
if err != nil {
return nil, err
}
@@ -1004,6 +1047,36 @@ func isClaudeOAuthToken(apiKey string) bool {
return strings.Contains(apiKey, "sk-ant-oat")
}
+// prepareClaudeOAuthToolNamesForUpstream applies the Claude OAuth tool-name
+// transforms in the same order across request paths. Remap runs before prefixing
+// so any future non-empty prefix still composes correctly with the per-request
+// reverse map.
+func prepareClaudeOAuthToolNamesForUpstream(body []byte, prefix string, prefixDisabled bool) ([]byte, map[string]string) {
+ body, reverseMap := remapOAuthToolNames(body)
+ if !prefixDisabled {
+ body = applyClaudeToolPrefix(body, prefix)
+ }
+ return body, reverseMap
+}
+
+// restoreClaudeOAuthToolNamesFromResponse undoes the Claude OAuth tool-name
+// transforms for non-stream responses in reverse order.
+func restoreClaudeOAuthToolNamesFromResponse(body []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte {
+ if !prefixDisabled {
+ body = stripClaudeToolPrefixFromResponse(body, prefix)
+ }
+ return reverseRemapOAuthToolNames(body, reverseMap)
+}
+
+// restoreClaudeOAuthToolNamesFromStreamLine undoes the Claude OAuth tool-name
+// transforms for SSE lines in reverse order.
+func restoreClaudeOAuthToolNamesFromStreamLine(line []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte {
+ if !prefixDisabled {
+ line = stripClaudeToolPrefixFromStreamLine(line, prefix)
+ }
+ return reverseRemapOAuthToolNamesFromStreamLine(line, reverseMap)
+}
+
// remapOAuthToolNames renames third-party tool names to Claude Code equivalents
// and removes tools without an official counterpart. This prevents Anthropic from
// fingerprinting the request as a third-party client via tool naming patterns.
@@ -1011,8 +1084,25 @@ func isClaudeOAuthToken(apiKey string) bool {
// It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference
// references in messages. Removed tools' corresponding tool_result blocks are preserved
// (they just become orphaned, which is safe for Claude).
-func remapOAuthToolNames(body []byte) ([]byte, bool) {
- renamed := false
+//
+// The returned map is keyed on the upstream (TitleCase) name and maps to the
+// client-supplied original name. Callers MUST pass this map to the reverse
+// functions so only names the client actually caused us to rewrite are restored
+// on the response. A global reverse map (the previous implementation) incorrectly
+// rewrote names the client originally sent in TitleCase (e.g. Amp CLI's `Bash`)
+// when any OTHER tool in the same request triggered a forward rename (e.g.
+// Amp's `glob`→`Glob`), because the global reverse map contained `Bash`→`bash`
+// regardless of what the client originally sent.
+func remapOAuthToolNames(body []byte) ([]byte, map[string]string) {
+ reverseMap := make(map[string]string, len(oauthToolRenameMap))
+ recordRename := func(original, renamed string) {
+ // Preserve the first-seen original name if the same upstream name is
+ // produced from multiple call sites; they all map back identically.
+ if _, exists := reverseMap[renamed]; !exists {
+ reverseMap[renamed] = original
+ }
+ }
+
// 1. Rewrite tools array in a single pass (if present).
// IMPORTANT: do not mutate names first and then rebuild from an older gjson
// snapshot. gjson results are snapshots of the original bytes; rebuilding from a
@@ -1045,7 +1135,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
updatedTool, err := sjson.Set(toolJSON, "name", newName)
if err == nil {
toolJSON = updatedTool
- renamed = true
+ recordRename(name, newName)
}
}
@@ -1070,7 +1160,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
body, _ = sjson.DeleteBytes(body, "tool_choice")
} else if newName, ok := oauthToolRenameMap[tcName]; ok && newName != tcName {
body, _ = sjson.SetBytes(body, "tool_choice.name", newName)
- renamed = true
+ recordRename(tcName, newName)
}
}
@@ -1090,14 +1180,14 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
if newName, ok := oauthToolRenameMap[name]; ok && newName != name {
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, newName)
- renamed = true
+ recordRename(name, newName)
}
case "tool_reference":
toolName := part.Get("tool_name").String()
if newName, ok := oauthToolRenameMap[toolName]; ok && newName != toolName {
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, newName)
- renamed = true
+ recordRename(toolName, newName)
}
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
@@ -1111,7 +1201,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
if newName, ok := oauthToolRenameMap[nestedToolName]; ok && newName != nestedToolName {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, newName)
- renamed = true
+ recordRename(nestedToolName, newName)
}
}
return true
@@ -1124,13 +1214,16 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
})
}
- return body, renamed
+ return body, reverseMap
}
-// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses.
-// It maps Claude Code TitleCase names back to the original lowercase names so the
-// downstream client receives tool names it recognizes.
-func reverseRemapOAuthToolNames(body []byte) []byte {
+// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses
+// using the per-request map produced by remapOAuthToolNames. Names the client sent
+// that were NOT forward-renamed are passed through unchanged.
+func reverseRemapOAuthToolNames(body []byte, reverseMap map[string]string) []byte {
+ if len(reverseMap) == 0 {
+ return body
+ }
content := gjson.GetBytes(body, "content")
if !content.Exists() || !content.IsArray() {
return body
@@ -1140,13 +1233,13 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
switch partType {
case "tool_use":
name := part.Get("name").String()
- if origName, ok := oauthToolRenameReverseMap[name]; ok {
+ if origName, ok := reverseMap[name]; ok {
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, origName)
}
case "tool_reference":
toolName := part.Get("tool_name").String()
- if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
+ if origName, ok := reverseMap[toolName]; ok {
path := fmt.Sprintf("content.%d.tool_name", index.Int())
body, _ = sjson.SetBytes(body, path, origName)
}
@@ -1156,8 +1249,12 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
return body
}
-// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE stream lines.
-func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
+// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE
+// stream lines, using the per-request reverseMap produced by remapOAuthToolNames.
+func reverseRemapOAuthToolNamesFromStreamLine(line []byte, reverseMap map[string]string) []byte {
+ if len(reverseMap) == 0 {
+ return line
+ }
payload := helps.JSONPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return line
@@ -1175,7 +1272,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
switch blockType {
case "tool_use":
name := contentBlock.Get("name").String()
- if origName, ok := oauthToolRenameReverseMap[name]; ok {
+ if origName, ok := reverseMap[name]; ok {
updated, err = sjson.SetBytes(payload, "content_block.name", origName)
if err != nil {
return line
@@ -1185,7 +1282,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
}
case "tool_reference":
toolName := contentBlock.Get("tool_name").String()
- if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
+ if origName, ok := reverseMap[toolName]; ok {
updated, err = sjson.SetBytes(payload, "content_block.tool_name", origName)
if err != nil {
return line
diff --git a/internal/runtime/executor/claude_executor_test.go b/internal/runtime/executor/claude_executor_test.go
index c1ce8fc0..2e914044 100644
--- a/internal/runtime/executor/claude_executor_test.go
+++ b/internal/runtime/executor/claude_executor_test.go
@@ -936,6 +936,113 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
}
}
+func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsEmptyClaudeStream(t *testing.T) {
+ _, err := executeOpenAIChatCompletionThroughClaude(t, "")
+ if err == nil {
+ t.Fatal("Execute error = nil, want empty stream error")
+ }
+ assertStatusErr(t, err, http.StatusBadGateway)
+ if !strings.Contains(err.Error(), "empty stream response") {
+ t.Fatalf("Execute error = %q, want empty stream response", err.Error())
+ }
+}
+
+func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsClaudeErrorEvent(t *testing.T) {
+ body := `data: {"type":"error","error":{"type":"overloaded_error","message":"upstream overloaded"}}` + "\n"
+ _, err := executeOpenAIChatCompletionThroughClaude(t, body)
+ if err == nil {
+ t.Fatal("Execute error = nil, want upstream error event")
+ }
+ assertStatusErr(t, err, http.StatusBadGateway)
+ if !strings.Contains(err.Error(), "upstream overloaded") {
+ t.Fatalf("Execute error = %q, want upstream overloaded", err.Error())
+ }
+}
+
+func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsIncompleteClaudeStream(t *testing.T) {
+ body := strings.Join([]string{
+ `data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`,
+ `data: {"type":"message_stop"}`,
+ ``,
+ }, "\n")
+
+ _, err := executeOpenAIChatCompletionThroughClaude(t, body)
+ if err == nil {
+ t.Fatal("Execute error = nil, want incomplete stream error")
+ }
+ assertStatusErr(t, err, http.StatusBadGateway)
+ if !strings.Contains(err.Error(), "ended before message completion") {
+ t.Fatalf("Execute error = %q, want incomplete stream error", err.Error())
+ }
+}
+
+func TestClaudeExecutor_ExecuteOpenAINonStreamConvertsValidClaudeStream(t *testing.T) {
+ body := strings.Join([]string{
+ `event: message_start`,
+ `data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`,
+ `event: content_block_delta`,
+ `data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ok"}}`,
+ `event: message_delta`,
+ `data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":2,"output_tokens":1}}`,
+ `event: message_stop`,
+ `data: {"type":"message_stop"}`,
+ ``,
+ }, "\n")
+
+ resp, err := executeOpenAIChatCompletionThroughClaude(t, body)
+ if err != nil {
+ t.Fatalf("Execute error: %v", err)
+ }
+ if got := gjson.GetBytes(resp.Payload, "id").String(); got != "msg_123" {
+ t.Fatalf("response id = %q, want msg_123; payload=%s", got, string(resp.Payload))
+ }
+ if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-3-5-sonnet-20241022" {
+ t.Fatalf("response model = %q, want claude-3-5-sonnet-20241022", got)
+ }
+ if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "ok" {
+ t.Fatalf("response content = %q, want ok", got)
+ }
+ if got := gjson.GetBytes(resp.Payload, "usage.total_tokens").Int(); got != 3 {
+ t.Fatalf("usage.total_tokens = %d, want 3", got)
+ }
+}
+
+func executeOpenAIChatCompletionThroughClaude(t *testing.T, upstreamBody string) (cliproxyexecutor.Response, error) {
+ t.Helper()
+
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ _, _ = w.Write([]byte(upstreamBody))
+ }))
+ defer server.Close()
+
+ executor := NewClaudeExecutor(&config.Config{})
+ auth := &cliproxyauth.Auth{Attributes: map[string]string{
+ "api_key": "key-123",
+ "base_url": server.URL,
+ }}
+ payload := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"hi"}]}`)
+
+ return executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
+ Model: "claude-3-5-sonnet-20241022",
+ Payload: payload,
+ }, cliproxyexecutor.Options{
+ SourceFormat: sdktranslator.FromString("openai"),
+ })
+}
+
+func assertStatusErr(t *testing.T, err error, want int) {
+ t.Helper()
+
+ status, ok := err.(interface{ StatusCode() int })
+ if !ok {
+ t.Fatalf("error %T does not expose StatusCode", err)
+ }
+ if got := status.StatusCode(); got != want {
+ t.Fatalf("StatusCode() = %d, want %d", got, want)
+ }
+}
+
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -1989,19 +2096,16 @@ func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOrigina
func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
body := []byte(`{"tools":[{"name":"Bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
- out, renamed := remapOAuthToolNames(body)
- if renamed {
- t.Fatalf("renamed = true, want false")
+ out, reverseMap := remapOAuthToolNames(body)
+ if len(reverseMap) != 0 {
+ t.Fatalf("reverseMap = %v, want empty", reverseMap)
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
}
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
- reversed := resp
- if renamed {
- reversed = reverseRemapOAuthToolNames(resp)
- }
+ reversed := reverseRemapOAuthToolNames(resp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
}
@@ -2010,20 +2114,150 @@ func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
func TestRemapOAuthToolNames_Lowercase_ReverseApplied(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
- out, renamed := remapOAuthToolNames(body)
- if !renamed {
- t.Fatalf("renamed = false, want true")
+ out, reverseMap := remapOAuthToolNames(body)
+ if reverseMap["Bash"] != "bash" {
+ t.Fatalf("reverseMap = %v, want entry Bash->bash", reverseMap)
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
}
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
- reversed := resp
- if renamed {
- reversed = reverseRemapOAuthToolNames(resp)
- }
+ reversed := reverseRemapOAuthToolNames(resp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "bash" {
t.Fatalf("content.0.name = %q, want %q", got, "bash")
}
}
+
+// TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed is the regression
+// test for a case where a single request contains both a TitleCase tool (which
+// must pass through unchanged) and a lowercase tool that we forward-rename.
+// Before the fix, triggering ANY forward rename caused the reverse pass to
+// lowercase every TitleCase tool in the response using a global reverse map,
+// corrupting tool names the client originally sent in TitleCase (notably Amp
+// CLI's `Bash`, which its registry lookup cannot find as `bash`).
+func TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed(t *testing.T) {
+ body := []byte(`{"tools":[` +
+ `{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
+ `{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
+ `]}`)
+
+ out, reverseMap := remapOAuthToolNames(body)
+
+ // Forward: TitleCase `Bash` is not a forward-map key, must pass through.
+ if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
+ t.Fatalf("tools.0.name = %q, want %q (TitleCase tool must not be renamed)", got, "Bash")
+ }
+ // Forward: `glob` is a forward-map key, upstream sees `Glob`.
+ if got := gjson.GetBytes(out, "tools.1.name").String(); got != "Glob" {
+ t.Fatalf("tools.1.name = %q, want %q", got, "Glob")
+ }
+
+ // Reverse map records ONLY the rename that happened.
+ if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
+ t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
+ }
+
+ // Upstream responds with a `Bash` tool_use. Since we never renamed `Bash`,
+ // reverseRemap MUST leave it alone.
+ bashResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
+ reversed := reverseRemapOAuthToolNames(bashResp, reverseMap)
+ if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
+ t.Fatalf("content.0.name = %q, want %q (Bash must be preserved; was never forward-renamed)", got, "Bash")
+ }
+
+ // Upstream responds with a `Glob` tool_use. Since we renamed `glob`→`Glob`,
+ // reverseRemap MUST restore the original `glob`.
+ globResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_02","name":"Glob","input":{"filePattern":"**/*.go"}}]}`)
+ reversed = reverseRemapOAuthToolNames(globResp, reverseMap)
+ if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "glob" {
+ t.Fatalf("content.0.name = %q, want %q (Glob must be restored to client's original `glob`)", got, "glob")
+ }
+}
+
+// TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap guards the
+// SSE streaming code path against the same mixed-case bug.
+func TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap(t *testing.T) {
+ reverseMap := map[string]string{"Glob": "glob"}
+
+ // Bash block was never renamed, must pass through as-is.
+ bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}}}`)
+ out := reverseRemapOAuthToolNamesFromStreamLine(bashLine, reverseMap)
+ if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
+ t.Fatalf("Bash should be preserved, got: %s", string(out))
+ }
+ if bytes.Contains(out, []byte(`"name":"bash"`)) {
+ t.Fatalf("Bash must not be lowercased, got: %s", string(out))
+ }
+
+ // Glob block IS in the reverseMap, must be restored to `glob`.
+ globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"Glob","input":{}}}`)
+ out = reverseRemapOAuthToolNamesFromStreamLine(globLine, reverseMap)
+ if !bytes.Contains(out, []byte(`"name":"glob"`)) {
+ t.Fatalf("Glob should be restored to glob, got: %s", string(out))
+ }
+}
+
+func TestPrepareClaudeOAuthToolNamesForUpstream_MixedCaseWithPrefix(t *testing.T) {
+ body := []byte(`{"tools":[` +
+ `{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
+ `{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
+ `],"messages":[{"role":"assistant","content":[` +
+ `{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}},` +
+ `{"type":"tool_use","id":"toolu_02","name":"glob","input":{}}` +
+ `]}]}`)
+
+ out, reverseMap := prepareClaudeOAuthToolNamesForUpstream(body, "proxy_", false)
+
+ if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Bash" {
+ t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Bash")
+ }
+ if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Glob" {
+ t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Glob")
+ }
+ if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Bash" {
+ t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Bash")
+ }
+ if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Glob" {
+ t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Glob")
+ }
+ if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
+ t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
+ }
+}
+
+func TestRestoreClaudeOAuthToolNamesFromResponse_MixedCaseWithPrefix(t *testing.T) {
+ reverseMap := map[string]string{"Glob": "glob"}
+ resp := []byte(`{"content":[` +
+ `{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}},` +
+ `{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}` +
+ `]}`)
+
+ out := restoreClaudeOAuthToolNamesFromResponse(resp, "proxy_", false, reverseMap)
+
+ if got := gjson.GetBytes(out, "content.0.name").String(); got != "Bash" {
+ t.Fatalf("content.0.name = %q, want %q", got, "Bash")
+ }
+ if got := gjson.GetBytes(out, "content.1.name").String(); got != "glob" {
+ t.Fatalf("content.1.name = %q, want %q", got, "glob")
+ }
+}
+
+func TestRestoreClaudeOAuthToolNamesFromStreamLine_MixedCaseWithPrefix(t *testing.T) {
+ reverseMap := map[string]string{"Glob": "glob"}
+
+ bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}}}`)
+ out := restoreClaudeOAuthToolNamesFromStreamLine(bashLine, "proxy_", false, reverseMap)
+ if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
+ t.Fatalf("Bash should be preserved, got: %s", string(out))
+ }
+ if bytes.Contains(out, []byte(`"name":"bash"`)) {
+ t.Fatalf("Bash must not be lowercased, got: %s", string(out))
+ }
+
+ globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}}`)
+ out = restoreClaudeOAuthToolNamesFromStreamLine(globLine, "proxy_", false, reverseMap)
+ if !bytes.Contains(out, []byte(`"name":"glob"`)) {
+ t.Fatalf("Glob should be restored to glob, got: %s", string(out))
+ }
+}
diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go
index aa8223f4..19cc8e75 100644
--- a/internal/runtime/executor/codex_executor.go
+++ b/internal/runtime/executor/codex_executor.go
@@ -30,8 +30,8 @@ import (
)
const (
- codexUserAgent = "codex-tui/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9 (codex-tui; 0.118.0)"
- codexOriginator = "codex-tui"
+ codexUserAgent = "codex_cli_rs/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9"
+ codexOriginator = "codex_cli_rs"
codexDefaultImageToolModel = "gpt-image-2"
)
@@ -515,13 +515,20 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, translatedLine, &param)
for i := range chunks {
- out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
+ case <-ctx.Done():
+ }
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go
index 40ba7e92..d6f1de86 100644
--- a/internal/runtime/executor/codex_websockets_executor.go
+++ b/internal/runtime/executor/codex_websockets_executor.go
@@ -188,7 +188,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
- body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body = normalizeCodexInstructions(body)
@@ -776,6 +775,11 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
parsed.Scheme = "ws"
case "https":
parsed.Scheme = "wss"
+ default:
+ return "", fmt.Errorf("codex websockets executor: unsupported responses websocket URL scheme %q", parsed.Scheme)
+ }
+ if strings.TrimSpace(parsed.Host) == "" {
+ return "", fmt.Errorf("codex websockets executor: responses websocket URL host is empty")
}
return parsed.String(), nil
}
@@ -809,6 +813,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
if cache.ID != "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
+ setHeaderCasePreserved(headers, "session_id", cache.ID)
headers.Set("Conversation_id", cache.ID)
}
@@ -828,13 +833,19 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
ginHeaders = ginCtx.Request.Header.Clone()
}
- _, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
+ isAPIKey := codexAuthUsesAPIKey(auth)
+ cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
misc.EnsureHeader(headers, ginHeaders, "Version", "")
+ if isAPIKey {
+ ensureHeaderWithPriority(headers, ginHeaders, "User-Agent", "", "")
+ } else {
+ ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
+ }
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
if betaHeader == "" && ginHeaders != nil {
@@ -845,16 +856,9 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
}
headers.Set("OpenAI-Beta", betaHeader)
if strings.Contains(headers.Get("User-Agent"), "Mac OS") {
- misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
- }
- headers.Del("User-Agent")
-
- isAPIKey := false
- if auth != nil && auth.Attributes != nil {
- if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
- isAPIKey = true
- }
+ ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", uuid.NewString())
}
+ ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", "")
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
headers.Set("Originator", originator)
} else if !isAPIKey {
@@ -864,7 +868,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
if auth != nil && auth.Metadata != nil {
if accountID, ok := auth.Metadata["account_id"].(string); ok {
if trimmed := strings.TrimSpace(accountID); trimmed != "" {
- headers.Set("Chatgpt-Account-Id", trimmed)
+ setHeaderCasePreserved(headers, "ChatGPT-Account-ID", trimmed)
}
}
}
@@ -879,6 +883,77 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
return headers
}
+func codexAuthUsesAPIKey(auth *cliproxyauth.Auth) bool {
+ if auth == nil || auth.Attributes == nil {
+ return false
+ }
+ return strings.TrimSpace(auth.Attributes["api_key"]) != ""
+}
+
+func ensureHeaderCasePreserved(target http.Header, source http.Header, key, configValue, fallbackValue string) {
+ if target == nil {
+ return
+ }
+ if strings.TrimSpace(headerValueCaseInsensitive(target, key)) != "" {
+ return
+ }
+ if source != nil {
+ if val := strings.TrimSpace(headerValueCaseInsensitive(source, key)); val != "" {
+ setHeaderCasePreserved(target, key, val)
+ return
+ }
+ }
+ if val := strings.TrimSpace(configValue); val != "" {
+ setHeaderCasePreserved(target, key, val)
+ return
+ }
+ if val := strings.TrimSpace(fallbackValue); val != "" {
+ setHeaderCasePreserved(target, key, val)
+ }
+}
+
+func setHeaderCasePreserved(headers http.Header, key string, value string) {
+ if headers == nil {
+ return
+ }
+ key = strings.TrimSpace(key)
+ value = strings.TrimSpace(value)
+ if key == "" || value == "" {
+ return
+ }
+ deleteHeaderCaseInsensitive(headers, key)
+ headers[key] = []string{value}
+}
+
+func headerValueCaseInsensitive(headers http.Header, key string) string {
+ key = strings.TrimSpace(key)
+ if headers == nil || key == "" {
+ return ""
+ }
+ if val := strings.TrimSpace(headers.Get(key)); val != "" {
+ return val
+ }
+ for existingKey, values := range headers {
+ if !strings.EqualFold(existingKey, key) {
+ continue
+ }
+ for _, value := range values {
+ if trimmed := strings.TrimSpace(value); trimmed != "" {
+ return trimmed
+ }
+ }
+ }
+ return ""
+}
+
+func deleteHeaderCaseInsensitive(headers http.Header, key string) {
+ for existingKey := range headers {
+ if strings.EqualFold(existingKey, key) {
+ delete(headers, existingKey)
+ }
+ }
+}
+
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
if cfg == nil || auth == nil {
return "", ""
@@ -962,25 +1037,55 @@ func parseCodexWebsocketError(payload []byte) (error, bool) {
return nil, false
}
- out := []byte(`{}`)
- if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
- raw := errNode.Raw
- if errNode.Type == gjson.String {
- raw = errNode.Raw
- }
- out, _ = sjson.SetRawBytes(out, "error", []byte(raw))
- } else {
- out, _ = sjson.SetBytes(out, "error.type", "server_error")
- out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
- }
-
+ out := buildCodexWebsocketErrorPayload(payload, status)
headers := parseCodexWebsocketErrorHeaders(payload)
+ statusError := statusErr{code: status, msg: string(out)}
+ if retryAfter := parseCodexRetryAfter(status, out, time.Now()); retryAfter != nil {
+ statusError.retryAfter = retryAfter
+ } else if isCodexWebsocketConnectionLimitError(payload) {
+ retryAfter := time.Duration(0)
+ statusError.retryAfter = &retryAfter
+ }
return statusErrWithHeaders{
- statusErr: statusErr{code: status, msg: string(out)},
+ statusErr: statusError,
headers: headers,
}, true
}
+func buildCodexWebsocketErrorPayload(payload []byte, status int) []byte {
+ out := []byte(`{}`)
+ out, _ = sjson.SetBytes(out, "status", status)
+
+ if bodyNode := gjson.GetBytes(payload, "body"); bodyNode.Exists() {
+ out, _ = sjson.SetRawBytes(out, "body", []byte(bodyNode.Raw))
+ if bodyErrorNode := bodyNode.Get("error"); bodyErrorNode.Exists() {
+ out, _ = sjson.SetRawBytes(out, "error", []byte(bodyErrorNode.Raw))
+ return out
+ }
+ }
+
+ if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
+ out, _ = sjson.SetRawBytes(out, "error", []byte(errNode.Raw))
+ return out
+ }
+
+ out, _ = sjson.SetBytes(out, "error.type", "server_error")
+ out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
+ return out
+}
+
+func isCodexWebsocketConnectionLimitError(payload []byte) bool {
+ if len(payload) == 0 {
+ return false
+ }
+ for _, path := range []string{"error.code", "error.type", "body.error.code", "body.error.type", "code", "error"} {
+ if strings.TrimSpace(gjson.GetBytes(payload, path).String()) == "websocket_connection_limit_reached" {
+ return true
+ }
+ }
+ return false
+}
+
func parseCodexWebsocketErrorHeaders(payload []byte) http.Header {
headersNode := gjson.GetBytes(payload, "headers")
if !headersNode.Exists() || !headersNode.IsObject() {
diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go
index dec356de..9c7bb591 100644
--- a/internal/runtime/executor/codex_websockets_executor_test.go
+++ b/internal/runtime/executor/codex_websockets_executor_test.go
@@ -1,15 +1,21 @@
package executor
import (
+ "bytes"
"context"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
+ "time"
"github.com/gin-gonic/gin"
+ "github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+ cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
+ sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
@@ -32,14 +38,80 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
}
}
+func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T) {
+ upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
+ capturedPayload := make(chan []byte, 1)
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/responses" {
+ t.Fatalf("request path = %s, want /responses", r.URL.Path)
+ }
+ conn, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ t.Fatalf("upgrade websocket: %v", err)
+ }
+ defer func() { _ = conn.Close() }()
+
+ msgType, payload, err := conn.ReadMessage()
+ if err != nil {
+ t.Fatalf("read upstream websocket message: %v", err)
+ }
+ if msgType != websocket.TextMessage {
+ t.Fatalf("message type = %d, want text", msgType)
+ }
+ capturedPayload <- bytes.Clone(payload)
+
+ completed := []byte(`{"type":"response.completed","response":{"id":"resp-2","output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
+ if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil {
+ t.Fatalf("write completed websocket message: %v", errWrite)
+ }
+ }))
+ defer server.Close()
+
+ exec := NewCodexWebsocketsExecutor(&config.Config{SDKConfig: config.SDKConfig{DisableImageGeneration: config.DisableImageGenerationAll}})
+ auth := &cliproxyauth.Auth{Attributes: map[string]string{"api_key": "sk-test", "base_url": server.URL}}
+ req := cliproxyexecutor.Request{
+ Model: "gpt-5-codex",
+ Payload: []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`),
+ }
+ opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("codex")}
+
+ if _, err := exec.Execute(context.Background(), auth, req, opts); err != nil {
+ t.Fatalf("Execute() error = %v", err)
+ }
+
+ select {
+ case payload := <-capturedPayload:
+ if got := gjson.GetBytes(payload, "type").String(); got != "response.create" {
+ t.Fatalf("upstream type = %s, want response.create; payload=%s", got, payload)
+ }
+ if got := gjson.GetBytes(payload, "previous_response_id").String(); got != "resp-1" {
+ t.Fatalf("upstream previous_response_id = %s, want resp-1; payload=%s", got, payload)
+ }
+ case <-time.After(5 * time.Second):
+ t.Fatal("timed out waiting for upstream websocket payload")
+ }
+}
+
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
- if got := headers.Get("User-Agent"); got != "" {
- t.Fatalf("User-Agent = %s, want empty", got)
+ if got := headers.Get("User-Agent"); got != codexUserAgent {
+ t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
+ }
+ if !strings.HasPrefix(codexUserAgent, codexOriginator+"/") {
+ t.Fatalf("default Codex User-Agent = %s, want prefix %s/", codexUserAgent, codexOriginator)
+ }
+ if strings.HasPrefix(codexUserAgent, "codex-tui/") {
+ t.Fatalf("default Codex User-Agent = %s, must not use stale codex-tui prefix", codexUserAgent)
+ }
+ if strings.Contains(codexUserAgent, "(codex-tui;") {
+ t.Fatalf("default Codex User-Agent = %s, must not include stale codex-tui suffix", codexUserAgent)
+ }
+ if got := headers.Get("Originator"); got != codexOriginator {
+ t.Fatalf("Originator = %s, want %s", got, codexOriginator)
}
if got := headers.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
@@ -62,9 +134,11 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
}
ctx := contextWithGinHeaders(map[string]string{
"Originator": "Codex Desktop",
+ "User-Agent": "codex_cli_rs/0.1.0",
"Version": "0.115.0-alpha.27",
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
+ "session_id": "sess-client",
})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil)
@@ -72,6 +146,9 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
if got := headers.Get("Originator"); got != "Codex Desktop" {
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
}
+ if got := headers.Get("User-Agent"); got != "codex_cli_rs/0.1.0" {
+ t.Fatalf("User-Agent = %s, want %s", got, "codex_cli_rs/0.1.0")
+ }
if got := headers.Get("Version"); got != "0.115.0-alpha.27" {
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
}
@@ -81,6 +158,12 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
}
+ if got := headerValueCaseInsensitive(headers, "session_id"); got != "sess-client" {
+ t.Fatalf("session_id = %s, want sess-client", got)
+ }
+ if _, ok := headers["session_id"]; !ok {
+ t.Fatalf("expected lowercase session_id header key, got %#v", headers)
+ }
}
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
@@ -97,8 +180,8 @@ func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
- if got := headers.Get("User-Agent"); got != "" {
- t.Fatalf("User-Agent = %s, want empty", got)
+ if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
+ t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
}
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
@@ -129,8 +212,8 @@ func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
- if gotVal := got.Get("User-Agent"); gotVal != "" {
- t.Fatalf("User-Agent = %s, want empty", gotVal)
+ if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
+ t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
}
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
@@ -155,8 +238,8 @@ func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testi
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
- if got := headers.Get("User-Agent"); got != "" {
- t.Fatalf("User-Agent = %s, want empty", got)
+ if got := headers.Get("User-Agent"); got != "config-ua" {
+ t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
}
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
@@ -183,6 +266,131 @@ func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
+ if got := headers.Get("Originator"); got != "" {
+ t.Fatalf("Originator = %s, want empty", got)
+ }
+}
+
+func TestApplyCodexWebsocketHeadersPreservesExplicitAPIKeyUserAgent(t *testing.T) {
+ auth := &cliproxyauth.Auth{Provider: "codex", Attributes: map[string]string{"api_key": "sk-test"}}
+ ctx := contextWithGinHeaders(map[string]string{"User-Agent": "api-key-client/1.0", "Originator": "explicit-origin"})
+
+ headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "sk-test", nil)
+
+ if got := headers.Get("User-Agent"); got != "api-key-client/1.0" {
+ t.Fatalf("User-Agent = %s, want api-key-client/1.0", got)
+ }
+ if got := headers.Get("Originator"); got != "explicit-origin" {
+ t.Fatalf("Originator = %s, want explicit-origin", got)
+ }
+}
+
+func TestApplyCodexPromptCacheHeadersSetsLowercaseSessionAndLegacyConversation(t *testing.T) {
+ req := cliproxyexecutor.Request{Model: "gpt-5-codex", Payload: []byte(`{"prompt_cache_key":"cache-1"}`)}
+
+ _, headers := applyCodexPromptCacheHeaders("openai-response", req, []byte(`{"model":"gpt-5-codex"}`))
+
+ if got := headerValueCaseInsensitive(headers, "session_id"); got != "cache-1" {
+ t.Fatalf("session_id = %s, want cache-1", got)
+ }
+ if _, ok := headers["session_id"]; !ok {
+ t.Fatalf("expected lowercase session_id key, got %#v", headers)
+ }
+ if got := headers.Get("Conversation_id"); got != "cache-1" {
+ t.Fatalf("Conversation_id = %s, want cache-1", got)
+ }
+}
+
+func TestApplyCodexWebsocketHeadersUsesCanonicalAccountHeader(t *testing.T) {
+ auth := &cliproxyauth.Auth{Provider: "codex", Metadata: map[string]any{"account_id": "acct-1"}}
+
+ headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", nil)
+
+ if got := headerValueCaseInsensitive(headers, "ChatGPT-Account-ID"); got != "acct-1" {
+ t.Fatalf("ChatGPT-Account-ID = %s, want acct-1", got)
+ }
+ values, ok := headers["ChatGPT-Account-ID"]
+ if !ok {
+ t.Fatalf("expected exact ChatGPT-Account-ID key, got %#v", headers)
+ }
+ if len(values) != 1 || values[0] != "acct-1" {
+ t.Fatalf("ChatGPT-Account-ID values = %#v, want [acct-1]", values)
+ }
+}
+
+func TestBuildCodexResponsesWebsocketURLRequiresHTTPURL(t *testing.T) {
+ if got, err := buildCodexResponsesWebsocketURL("https://example.com/backend/responses"); err != nil || got != "wss://example.com/backend/responses" {
+ t.Fatalf("https URL = %q, %v; want wss URL", got, err)
+ }
+ if _, err := buildCodexResponsesWebsocketURL("ftp://example.com/responses"); err == nil {
+ t.Fatalf("expected unsupported scheme error")
+ }
+ if _, err := buildCodexResponsesWebsocketURL("https:///responses"); err == nil {
+ t.Fatalf("expected empty host error")
+ }
+}
+
+func TestParseCodexWebsocketErrorMarksConnectionLimitRetryable(t *testing.T) {
+ err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"error":{"code":"websocket_connection_limit_reached","message":"too many websockets"},"headers":{"retry-after":"1"}}`))
+ if !ok {
+ t.Fatalf("expected websocket error")
+ }
+ status, ok := err.(interface{ StatusCode() int })
+ if !ok || status.StatusCode() != http.StatusTooManyRequests {
+ t.Fatalf("status = %#v, want 429", err)
+ }
+ retryable, ok := err.(interface{ RetryAfter() *time.Duration })
+ if !ok || retryable.RetryAfter() == nil {
+ t.Fatalf("expected retryable websocket connection limit error")
+ }
+ if got := *retryable.RetryAfter(); got != 0 {
+ t.Fatalf("retryAfter = %v, want connection-limit fallback 0", got)
+ }
+ withHeaders, ok := err.(interface{ Headers() http.Header })
+ if !ok || withHeaders.Headers().Get("retry-after") != "1" {
+ t.Fatalf("headers = %#v, want retry-after", err)
+ }
+}
+
+func TestParseCodexWebsocketErrorUsesUsageLimitRetryMetadata(t *testing.T) {
+ err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"body":{"error":{"type":"usage_limit_reached","message":"usage limit reached","resets_in_seconds":7}}}`))
+ if !ok {
+ t.Fatalf("expected websocket error")
+ }
+
+ retryable, ok := err.(interface{ RetryAfter() *time.Duration })
+ if !ok || retryable.RetryAfter() == nil {
+ t.Fatalf("expected retryable usage limit websocket error")
+ }
+ if got := *retryable.RetryAfter(); got != 7*time.Second {
+ t.Fatalf("retryAfter = %v, want 7s", got)
+ }
+}
+
+func TestParseCodexWebsocketErrorPreservesWrappedBodyAndHeaders(t *testing.T) {
+ err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"body":{"error":{"code":"websocket_connection_limit_reached","type":"server_error","message":"too many websocket connections"}},"headers":{"x-request-id":"req-1"}}`))
+ if !ok {
+ t.Fatalf("expected websocket error")
+ }
+
+ parsed := gjson.Parse(err.Error())
+ if got := parsed.Get("status").Int(); got != http.StatusTooManyRequests {
+ t.Fatalf("wrapped status = %d, want 429; payload=%s", got, err.Error())
+ }
+ if got := parsed.Get("body.error.code").String(); got != "websocket_connection_limit_reached" {
+ t.Fatalf("wrapped body error code = %s, want websocket_connection_limit_reached; payload=%s", got, err.Error())
+ }
+ if got := parsed.Get("error.code").String(); got != "websocket_connection_limit_reached" {
+ t.Fatalf("surface error code = %s, want websocket_connection_limit_reached; payload=%s", got, err.Error())
+ }
+ retryable, ok := err.(interface{ RetryAfter() *time.Duration })
+ if !ok || retryable.RetryAfter() == nil {
+ t.Fatalf("expected body.error.code websocket connection limit to be retryable")
+ }
+ withHeaders, ok := err.(interface{ Headers() http.Header })
+ if !ok || withHeaders.Headers().Get("x-request-id") != "req-1" {
+ t.Fatalf("headers = %#v, want x-request-id", err)
+ }
}
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go
index 15e84572..b6210e6a 100644
--- a/internal/runtime/executor/gemini_cli_executor.go
+++ b/internal/runtime/executor/gemini_cli_executor.go
@@ -411,19 +411,30 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param)
for i := range segments {
- out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}
}
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments {
- out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
+ case <-ctx.Done():
+ return
+ }
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
+ case <-ctx.Done():
+ }
return
}
reporter.EnsurePublished(ctx)
@@ -434,7 +445,10 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errRead}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errRead}:
+ case <-ctx.Done():
+ }
return
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
@@ -442,12 +456,20 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param)
for i := range segments {
- out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
+ case <-ctx.Done():
+ return
+ }
}
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments {
- out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}(httpResp, append([]byte(nil), payload...), attemptModel)
diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go
index 0e3c3ec6..2a6e9a6e 100644
--- a/internal/runtime/executor/gemini_executor.go
+++ b/internal/runtime/executor/gemini_executor.go
@@ -324,17 +324,28 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param)
for i := range lines {
- out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
- out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
+ case <-ctx.Done():
+ return
+ }
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
+ case <-ctx.Done():
+ }
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go
index b147fde9..17a93d51 100644
--- a/internal/runtime/executor/gemini_vertex_executor.go
+++ b/internal/runtime/executor/gemini_vertex_executor.go
@@ -338,6 +338,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
+ body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
}
action := getVertexAction(baseModel, false)
@@ -459,6 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
+ body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
action := getVertexAction(baseModel, false)
if req.Metadata != nil {
@@ -570,6 +572,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
+ body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
action := getVertexAction(baseModel, true)
baseURL := vertexBaseURL(location)
@@ -656,17 +659,28 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
- out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
- out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
+ case <-ctx.Done():
+ return
+ }
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
+ case <-ctx.Done():
+ }
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -700,6 +714,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
+ body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
action := getVertexAction(baseModel, true)
// For API key auth, use simpler URL format without project/location
@@ -786,17 +801,28 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
- out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
- out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
+ case <-ctx.Done():
+ return
+ }
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
+ case <-ctx.Done():
+ }
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -818,6 +844,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
+ translatedReq = helps.StripVertexOpenAIResponsesToolCallIDs(translatedReq, from.String())
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -907,6 +934,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
+ translatedReq = helps.StripVertexOpenAIResponsesToolCallIDs(translatedReq, from.String())
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
diff --git a/internal/runtime/executor/helps/usage_helpers.go b/internal/runtime/executor/helps/usage_helpers.go
index c5e258c8..312a1d35 100644
--- a/internal/runtime/executor/helps/usage_helpers.go
+++ b/internal/runtime/executor/helps/usage_helpers.go
@@ -18,6 +18,7 @@ import (
type UsageReporter struct {
provider string
model string
+ alias string
authID string
authIndex string
authType string
@@ -29,9 +30,14 @@ type UsageReporter struct {
func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
apiKey := APIKeyFromContext(ctx)
+ alias := usage.RequestedModelAliasFromContext(ctx)
+ if alias == "" {
+ alias = model
+ }
reporter := &UsageReporter{
provider: provider,
model: model,
+ alias: strings.TrimSpace(alias),
requestedAt: time.Now(),
apiKey: apiKey,
source: resolveUsageSource(auth, apiKey),
@@ -139,6 +145,7 @@ func (r *UsageReporter) buildRecordForModel(model string, detail usage.Detail, f
return usage.Record{
Provider: r.provider,
Model: model,
+ Alias: r.alias,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
diff --git a/internal/runtime/executor/helps/usage_helpers_test.go b/internal/runtime/executor/helps/usage_helpers_test.go
index c77335fd..ef2c7de5 100644
--- a/internal/runtime/executor/helps/usage_helpers_test.go
+++ b/internal/runtime/executor/helps/usage_helpers_test.go
@@ -1,6 +1,7 @@
package helps
import (
+ "context"
"testing"
"time"
@@ -107,6 +108,19 @@ func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
}
}
+func TestUsageReporterBuildRecordIncludesRequestedModelAlias(t *testing.T) {
+ ctx := usage.WithRequestedModelAlias(context.Background(), "client-gpt")
+ reporter := NewUsageReporter(ctx, "openai", "gpt-5.4", nil)
+
+ record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
+ if record.Model != "gpt-5.4" {
+ t.Fatalf("model = %q, want %q", record.Model, "gpt-5.4")
+ }
+ if record.Alias != "client-gpt" {
+ t.Fatalf("alias = %q, want %q", record.Alias, "client-gpt")
+ }
+}
+
func TestUsageReporterBuildAdditionalModelRecordSkipsZeroTokens(t *testing.T) {
reporter := &UsageReporter{
provider: "codex",
diff --git a/internal/runtime/executor/helps/vertex_payload_helpers.go b/internal/runtime/executor/helps/vertex_payload_helpers.go
new file mode 100644
index 00000000..4c84fae4
--- /dev/null
+++ b/internal/runtime/executor/helps/vertex_payload_helpers.go
@@ -0,0 +1,43 @@
+package helps
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+)
+
+// StripVertexOpenAIResponsesToolCallIDs removes OpenAI Responses call IDs that
+// Vertex rejects in Gemini functionCall/functionResponse payloads.
+func StripVertexOpenAIResponsesToolCallIDs(payload []byte, sourceFormat string) []byte {
+ if !strings.EqualFold(strings.TrimSpace(sourceFormat), "openai-response") {
+ return payload
+ }
+
+ contents := gjson.GetBytes(payload, "contents")
+ if !contents.IsArray() {
+ return payload
+ }
+
+ out := payload
+ for contentIndex, content := range contents.Array() {
+ parts := content.Get("parts")
+ if !parts.IsArray() {
+ continue
+ }
+ for partIndex, part := range parts.Array() {
+ if part.Get("functionCall.id").Exists() {
+ if updated, errDelete := sjson.DeleteBytes(out, fmt.Sprintf("contents.%d.parts.%d.functionCall.id", contentIndex, partIndex)); errDelete == nil {
+ out = updated
+ }
+ }
+ if part.Get("functionResponse.id").Exists() {
+ if updated, errDelete := sjson.DeleteBytes(out, fmt.Sprintf("contents.%d.parts.%d.functionResponse.id", contentIndex, partIndex)); errDelete == nil {
+ out = updated
+ }
+ }
+ }
+ }
+ return out
+}
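Reviewer note: `StripVertexOpenAIResponsesToolCallIDs` walks `contents[*].parts[*]` and deletes `functionCall.id` and `functionResponse.id`. The sketch below shows the same traversal using only the standard library's `encoding/json` instead of gjson/sjson. That is a deliberate swap for self-containment, not the code in this patch, and unlike the sjson version it re-marshals the whole document, so key order is not preserved. The `stripCallIDs` name is hypothetical, and the source-format check (`openai-response`) from the real helper is omitted.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// stripCallIDs (hypothetical) removes the "id" field from every
// functionCall / functionResponse object under contents[*].parts[*].
func stripCallIDs(payload []byte) ([]byte, error) {
	var doc map[string]any
	if err := json.Unmarshal(payload, &doc); err != nil {
		return nil, err
	}
	// Failed type assertions yield nil slices/maps, which are safe to
	// range over and index, so malformed shapes are skipped silently.
	contents, _ := doc["contents"].([]any)
	for _, c := range contents {
		content, _ := c.(map[string]any)
		parts, _ := content["parts"].([]any)
		for _, p := range parts {
			part, _ := p.(map[string]any)
			for _, key := range []string{"functionCall", "functionResponse"} {
				if call, ok := part[key].(map[string]any); ok {
					delete(call, "id")
				}
			}
		}
	}
	return json.Marshal(doc)
}

func main() {
	in := []byte(`{"contents":[{"parts":[{"functionCall":{"id":"call_1","name":"ls"}}]}]}`)
	out, err := stripCallIDs(in)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
```

The patch's sjson-based approach edits the raw bytes in place, which avoids the re-encoding cost and keeps the rest of the payload byte-for-byte identical.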
diff --git a/internal/runtime/executor/kimi_executor.go b/internal/runtime/executor/kimi_executor.go
index 3588c962..93125d9f 100644
--- a/internal/runtime/executor/kimi_executor.go
+++ b/internal/runtime/executor/kimi_executor.go
@@ -290,17 +290,28 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
- out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range doneChunks {
- out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}:
+ case <-ctx.Done():
+ return
+ }
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
+ case <-ctx.Done():
+ }
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -322,7 +333,17 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
return body, nil
}
- out := body
+ msgs := messages.Array()
+ out, dropped, err := filterKimiEmptyAssistantMessages(body, msgs)
+ if err != nil {
+ return body, err
+ }
+ if dropped > 0 {
+ log.WithField("dropped_assistant_messages", dropped).Debug("kimi executor: dropped empty assistant messages")
+ }
+
+ messages = gjson.GetBytes(out, "messages")
+ msgs = messages.Array()
pending := make([]string, 0)
patched := 0
patchedReasoning := 0
@@ -340,7 +361,6 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
}
}
- msgs := messages.Array()
for msgIdx := range msgs {
msg := msgs[msgIdx]
role := strings.TrimSpace(msg.Get("role").String())
@@ -428,6 +448,96 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
return out, nil
}
+func filterKimiEmptyAssistantMessages(body []byte, msgs []gjson.Result) ([]byte, int, error) {
+ kept := make([]string, 0, len(msgs))
+ dropped := 0
+ for _, msg := range msgs {
+ if shouldDropKimiAssistantMessage(msg) {
+ dropped++
+ continue
+ }
+ kept = append(kept, msg.Raw)
+ }
+ if dropped == 0 {
+ return body, 0, nil
+ }
+
+ rawMessages := []byte("[" + strings.Join(kept, ",") + "]")
+ out, err := sjson.SetRawBytes(body, "messages", rawMessages)
+ if err != nil {
+ return body, 0, fmt.Errorf("kimi executor: failed to drop empty assistant messages: %w", err)
+ }
+ return out, dropped, nil
+}
+
+func shouldDropKimiAssistantMessage(msg gjson.Result) bool {
+ if strings.TrimSpace(msg.Get("role").String()) != "assistant" {
+ return false
+ }
+ if hasKimiToolCalls(msg) || hasKimiLegacyFunctionCall(msg) || hasKimiAssistantReasoning(msg) {
+ return false
+ }
+ return isKimiAssistantContentEmpty(msg.Get("content"))
+}
+
+func hasKimiToolCalls(msg gjson.Result) bool {
+ toolCalls := msg.Get("tool_calls")
+ return toolCalls.Exists() && toolCalls.IsArray() && len(toolCalls.Array()) > 0
+}
+
+func hasKimiLegacyFunctionCall(msg gjson.Result) bool {
+ functionCall := msg.Get("function_call")
+ if !functionCall.Exists() || functionCall.Type == gjson.Null {
+ return false
+ }
+ if functionCall.IsObject() && strings.TrimSpace(functionCall.Raw) == "{}" {
+ return false
+ }
+ return strings.TrimSpace(functionCall.Raw) != ""
+}
+
+func hasKimiAssistantReasoning(msg gjson.Result) bool {
+ reasoning := msg.Get("reasoning_content")
+ return reasoning.Exists() && strings.TrimSpace(reasoning.String()) != ""
+}
+
+func isKimiAssistantContentEmpty(content gjson.Result) bool {
+ if !content.Exists() || content.Type == gjson.Null {
+ return true
+ }
+ if content.Type == gjson.String {
+ return strings.TrimSpace(content.String()) == ""
+ }
+ if !content.IsArray() {
+ return false
+ }
+ for _, part := range content.Array() {
+ if !isKimiAssistantContentPartEmpty(part) {
+ return false
+ }
+ }
+ return true
+}
+
+func isKimiAssistantContentPartEmpty(part gjson.Result) bool {
+ if !part.Exists() || part.Type == gjson.Null {
+ return true
+ }
+ if part.Type == gjson.String {
+ return strings.TrimSpace(part.String()) == ""
+ }
+ if !part.IsObject() {
+ return false
+ }
+ if text := part.Get("text"); text.Exists() {
+ return strings.TrimSpace(text.String()) == ""
+ }
+ if strings.TrimSpace(part.Get("type").String()) == "text" {
+ return true
+ }
+ return strings.TrimSpace(part.Raw) == "{}"
+}
+
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
if hasLatest && strings.TrimSpace(latest) != "" {
return latest
diff --git a/internal/runtime/executor/kimi_executor_test.go b/internal/runtime/executor/kimi_executor_test.go
index 210ddb0e..f3de70f1 100644
--- a/internal/runtime/executor/kimi_executor_test.go
+++ b/internal/runtime/executor/kimi_executor_test.go
@@ -203,3 +203,70 @@ func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "r1")
}
}
+
+func TestNormalizeKimiToolMessageLinks_DropsEmptyAssistantWithoutToolLink(t *testing.T) {
+ body := []byte(`{
+ "messages":[
+ {"role":"user","content":"start"},
+ {"role":"assistant","content":""},
+ {"role":"assistant","content":" "},
+ {"role":"assistant","content":"","tool_calls":null},
+ {"role":"assistant","content":[{"type":"text","text":" "}]},
+ {"role":"assistant"},
+ {"role":"assistant","content":"keep"},
+ {"role":"user","content":"next"}
+ ]
+ }`)
+
+ out, err := normalizeKimiToolMessageLinks(body)
+ if err != nil {
+ t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
+ }
+
+ messages := gjson.GetBytes(out, "messages").Array()
+ if len(messages) != 3 {
+ t.Fatalf("messages length = %d, want 3, raw = %s", len(messages), gjson.GetBytes(out, "messages").Raw)
+ }
+ if got := messages[0].Get("content").String(); got != "start" {
+ t.Fatalf("messages.0.content = %q, want %q", got, "start")
+ }
+ if got := messages[1].Get("content").String(); got != "keep" {
+ t.Fatalf("messages.1.content = %q, want %q", got, "keep")
+ }
+ if got := messages[2].Get("content").String(); got != "next" {
+ t.Fatalf("messages.2.content = %q, want %q", got, "next")
+ }
+}
+
+func TestNormalizeKimiToolMessageLinks_PreservesAssistantWithToolLinkOrReasoning(t *testing.T) {
+ body := []byte(`{
+ "messages":[
+ {"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
+ {"role":"assistant","content":"","function_call":{"name":"legacy_call","arguments":"{}"}},
+ {"role":"assistant","content":"","reasoning_content":"thought"},
+ {"role":"assistant","content":[{"type":"text","text":" visible "}]}
+ ]
+ }`)
+
+ out, err := normalizeKimiToolMessageLinks(body)
+ if err != nil {
+ t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
+ }
+
+ messages := gjson.GetBytes(out, "messages").Array()
+ if len(messages) != 4 {
+ t.Fatalf("messages length = %d, want 4, raw = %s", len(messages), gjson.GetBytes(out, "messages").Raw)
+ }
+ if !messages[0].Get("tool_calls").Exists() {
+ t.Fatalf("messages.0.tool_calls should exist")
+ }
+ if !messages[1].Get("function_call").Exists() {
+ t.Fatalf("messages.1.function_call should exist")
+ }
+ if got := messages[2].Get("reasoning_content").String(); got != "thought" {
+ t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "thought")
+ }
+ if got := messages[3].Get("content.0.text").String(); got != " visible " {
+ t.Fatalf("messages.3.content.0.text = %q, want %q", got, " visible ")
+ }
+}
diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go
index 4e44a7ae..7e81637c 100644
--- a/internal/runtime/executor/openai_compat_executor.go
+++ b/internal/runtime/executor/openai_compat_executor.go
@@ -96,6 +96,12 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
+
+ translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
+ if err != nil {
+ return resp, err
+ }
+
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
requestPath := helps.PayloadRequestPath(opts)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
@@ -105,11 +111,6 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
}
}
- translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
- if err != nil {
- return resp, err
- }
-
url := strings.TrimSuffix(baseURL, "/") + endpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
@@ -199,15 +200,16 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
- requestedModel := helps.PayloadRequestedModel(opts, req.Model)
- requestPath := helps.PayloadRequestPath(opts)
- translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
+ requestedModel := helps.PayloadRequestedModel(opts, req.Model)
+ requestPath := helps.PayloadRequestPath(opts)
+ translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
+
// Request usage data in the final streaming chunk so that token statistics
// are captured even when the upstream is an OpenAI-compatible provider.
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
@@ -281,32 +283,57 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
- if len(line) == 0 {
+ trimmedLine := bytes.TrimSpace(line)
+ if len(trimmedLine) == 0 {
continue
}
- if !bytes.HasPrefix(line, []byte("data:")) {
+ if !bytes.HasPrefix(trimmedLine, []byte("data:")) {
+ if bytes.HasPrefix(trimmedLine, []byte(":")) || bytes.HasPrefix(trimmedLine, []byte("event:")) ||
+ bytes.HasPrefix(trimmedLine, []byte("id:")) || bytes.HasPrefix(trimmedLine, []byte("retry:")) {
+ continue
+ }
+ if bytes.HasPrefix(trimmedLine, []byte("{")) || bytes.HasPrefix(trimmedLine, []byte("[")) {
+ streamErr := statusErr{code: http.StatusBadGateway, msg: string(trimmedLine)}
+ helps.RecordAPIResponseError(ctx, e.cfg, streamErr)
+ reporter.PublishFailure(ctx)
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: streamErr}:
+ case <-ctx.Done():
+ }
+ return
+ }
continue
}
- // OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
- // Pass through translator; it yields one or more chunks for the target schema.
- chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
+ // OpenAI-compatible streams must use SSE data lines.
+ chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(trimmedLine), &param)
for i := range chunks {
- out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
- out <- cliproxyexecutor.StreamChunk{Err: errScan}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
+ case <-ctx.Done():
+ }
} else {
// In case the upstream close the stream without a terminal [DONE] marker.
// Feed a synthetic done marker through the translator so pending
// response.completed events are still emitted exactly once.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), &param)
for i := range chunks {
- out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
+ select {
+ case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
+ case <-ctx.Done():
+ return
+ }
}
}
// Ensure we record the request if no usage chunk was ever seen
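Reviewer note: the stream loop in this file now trims each line and sorts it into three outcomes: forward `data:` lines to the translator, treat a bare JSON object or array as an upstream error (reported as 502 Bad Gateway in the patch), and skip everything else, including blank lines, SSE comments, and `event:` / `id:` / `retry:` fields. A minimal sketch of that decision, assuming a hypothetical `classifySSELine` function and `lineKind` type that are not part of the diff:

```go
package main

import (
	"bytes"
	"fmt"
)

// lineKind is a hypothetical label for what the stream loop does with a line.
type lineKind int

const (
	lineSkip  lineKind = iota // blank, comment, or non-data SSE field
	lineData                  // "data:" line, passed to the translator
	lineError                 // bare JSON, surfaced as a stream error
)

// classifySSELine (hypothetical) mirrors the branching in the patched loop.
func classifySSELine(line []byte) lineKind {
	trimmed := bytes.TrimSpace(line)
	switch {
	case len(trimmed) == 0:
		return lineSkip
	case bytes.HasPrefix(trimmed, []byte("data:")):
		return lineData
	case bytes.HasPrefix(trimmed, []byte("{")), bytes.HasPrefix(trimmed, []byte("[")):
		return lineError
	default:
		// Covers ":" comments plus "event:", "id:", "retry:" and any
		// other unrecognized line, all of which the patch skips.
		return lineSkip
	}
}

func main() {
	for _, l := range []string{
		"",
		": openrouter processing",
		"event: error",
		`{"error":{"message":"upstream failed"}}`,
		`data: {"id":"chatcmpl_1"}`,
	} {
		fmt.Printf("%q -> %d\n", l, classifySSELine([]byte(l)))
	}
}
```

The patch checks the comment and field prefixes explicitly before the JSON check; since none of those prefixes start with `{` or `[`, folding them into the default case gives the same result.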
diff --git a/internal/runtime/executor/openai_compat_executor_compact_test.go b/internal/runtime/executor/openai_compat_executor_compact_test.go
index fe281262..49b2cccb 100644
--- a/internal/runtime/executor/openai_compat_executor_compact_test.go
+++ b/internal/runtime/executor/openai_compat_executor_compact_test.go
@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
+ "strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -56,3 +57,125 @@ func TestOpenAICompatExecutorCompactPassthrough(t *testing.T) {
t.Fatalf("payload = %s", string(resp.Payload))
}
}
+
+func TestOpenAICompatExecutorPayloadOverrideWinsOverThinkingSuffix(t *testing.T) {
+ var gotBody []byte
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ body, _ := io.ReadAll(r.Body)
+ gotBody = body
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"id":"chatcmpl_1","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`))
+ }))
+ defer server.Close()
+
+ executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{
+ Payload: config.PayloadConfig{
+ Override: []config.PayloadRule{
+ {
+ Models: []config.PayloadModelRule{
+ {Name: "custom-openai", Protocol: "openai"},
+ },
+ Params: map[string]any{
+ "reasoning_effort": "low",
+ },
+ },
+ },
+ },
+ })
+ auth := &cliproxyauth.Auth{Attributes: map[string]string{
+ "base_url": server.URL + "/v1",
+ "api_key": "test",
+ }}
+ payload := []byte(`{"model":"custom-openai(high)","messages":[{"role":"user","content":"hi"}]}`)
+ _, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
+ Model: "custom-openai(high)",
+ Payload: payload,
+ }, cliproxyexecutor.Options{
+ SourceFormat: sdktranslator.FromString("openai"),
+ Stream: false,
+ })
+ if err != nil {
+ t.Fatalf("Execute error: %v", err)
+ }
+ if got := gjson.GetBytes(gotBody, "reasoning_effort").String(); got != "low" {
+ t.Fatalf("reasoning_effort = %q, want %q; body=%s", got, "low", string(gotBody))
+ }
+}
+
+func TestOpenAICompatExecutorStreamRejectsPlainJSONAfterBlankLines(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ _, _ = w.Write([]byte("\n\n: openrouter processing\n\nevent: error\n"))
+ _, _ = w.Write([]byte(`{"error":{"message":"upstream failed","type":"server_error"}}` + "\n"))
+ }))
+ defer server.Close()
+
+ executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
+ auth := &cliproxyauth.Auth{Attributes: map[string]string{
+ "base_url": server.URL + "/v1",
+ "api_key": "test",
+ }}
+ result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
+ Model: "openrouter-model",
+ Payload: []byte(`{"model":"openrouter-model","messages":[{"role":"user","content":"hi"}],"stream":true}`),
+ }, cliproxyexecutor.Options{
+ SourceFormat: sdktranslator.FromString("openai"),
+ Stream: true,
+ })
+ if err != nil {
+ t.Fatalf("ExecuteStream error: %v", err)
+ }
+
+ var gotErr error
+ for chunk := range result.Chunks {
+ if chunk.Err != nil {
+ gotErr = chunk.Err
+ break
+ }
+ }
+ if gotErr == nil {
+ t.Fatalf("expected plain JSON stream error")
+ }
+ if status, ok := gotErr.(interface{ StatusCode() int }); !ok || status.StatusCode() != http.StatusBadGateway {
+ t.Fatalf("stream error status = %v, want %d", gotErr, http.StatusBadGateway)
+ }
+ if !strings.Contains(gotErr.Error(), "upstream failed") {
+ t.Fatalf("stream error = %v", gotErr)
+ }
+}
+
+func TestOpenAICompatExecutorStreamSkipsKeepAliveUntilDataLine(t *testing.T) {
+ server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ _, _ = w.Write([]byte("\n\n: openrouter processing\n\nevent: ping\nid: 1\nretry: 1000\n"))
+ _, _ = w.Write([]byte(`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hello"},"finish_reason":null}]}` + "\n"))
+ }))
+ defer server.Close()
+
+ executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
+ auth := &cliproxyauth.Auth{Attributes: map[string]string{
+ "base_url": server.URL + "/v1",
+ "api_key": "test",
+ }}
+ result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
+ Model: "openrouter-model",
+ Payload: []byte(`{"model":"openrouter-model","messages":[{"role":"user","content":"hi"}],"stream":true}`),
+ }, cliproxyexecutor.Options{
+ SourceFormat: sdktranslator.FromString("openai"),
+ Stream: true,
+ })
+ if err != nil {
+ t.Fatalf("ExecuteStream error: %v", err)
+ }
+
+ var got strings.Builder
+ for chunk := range result.Chunks {
+ if chunk.Err != nil {
+ t.Fatalf("unexpected stream error: %v", chunk.Err)
+ }
+ got.Write(chunk.Payload)
+ }
+ if gjson.Get(got.String(), "choices.0.delta.content").String() != "hello" {
+ t.Fatalf("stream payload = %s", got.String())
+ }
+}
diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_response.go b/internal/translator/claude/openai/chat-completions/claude_openai_response.go
index 1fd3f2ae..99c75238 100644
--- a/internal/translator/claude/openai/chat-completions/claude_openai_response.go
+++ b/internal/translator/claude/openai/chat-completions/claude_openai_response.go
@@ -25,10 +25,19 @@ type ConvertAnthropicResponseToOpenAIParams struct {
CreatedAt int64
ResponseID string
FinishReason string
+ Usage claudeUsageTokens
// Tool calls accumulator for streaming
ToolCallsAccumulator map[int]*ToolCallAccumulator
}
+type claudeUsageTokens struct {
+ InputTokens int64
+ OutputTokens int64
+ CacheCreationInputTokens int64
+ CacheReadInputTokens int64
+ HasUsage bool
+}
+
// ToolCallAccumulator holds the state for accumulating tool call data
type ToolCallAccumulator struct {
ID string
@@ -36,15 +45,30 @@ type ToolCallAccumulator struct {
Arguments strings.Builder
}
-func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
- inputTokens := usage.Get("input_tokens").Int()
- completionTokens = usage.Get("output_tokens").Int()
- cachedTokens = usage.Get("cache_read_input_tokens").Int()
- cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
+func (u *claudeUsageTokens) Merge(usage gjson.Result) {
+ if !usage.Exists() {
+ return
+ }
+ u.HasUsage = true
+ if inputTokens := usage.Get("input_tokens"); inputTokens.Exists() {
+ u.InputTokens = inputTokens.Int()
+ }
+ if outputTokens := usage.Get("output_tokens"); outputTokens.Exists() {
+ u.OutputTokens = outputTokens.Int()
+ }
+ if cacheCreationInputTokens := usage.Get("cache_creation_input_tokens"); cacheCreationInputTokens.Exists() {
+ u.CacheCreationInputTokens = cacheCreationInputTokens.Int()
+ }
+ if cacheReadInputTokens := usage.Get("cache_read_input_tokens"); cacheReadInputTokens.Exists() {
+ u.CacheReadInputTokens = cacheReadInputTokens.Int()
+ }
+}
- promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
+func (u claudeUsageTokens) OpenAIUsage() (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
+ cachedTokens = u.CacheReadInputTokens
+ promptTokens = u.InputTokens + u.CacheCreationInputTokens + cachedTokens
+ completionTokens = u.OutputTokens
totalTokens = promptTokens + completionTokens
-
return promptTokens, completionTokens, totalTokens, cachedTokens
}
@@ -112,6 +136,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
}
+ (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(message.Get("usage"))
}
return [][]byte{template}
@@ -215,7 +240,8 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
// Handle usage information for token counts
if usage := root.Get("usage"); usage.Exists() {
- promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
+ (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(usage)
+ promptTokens, completionTokens, totalTokens, cachedTokens := (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.OpenAIUsage()
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
@@ -296,6 +322,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
var stopReason string
var contentParts []string
var reasoningParts []string
+ usageTokens := claudeUsageTokens{}
toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
for _, chunk := range chunks {
@@ -309,6 +336,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
messageID = message.Get("id").String()
model = message.Get("model").String()
createdAt = time.Now().Unix()
+ usageTokens.Merge(message.Get("usage"))
}
case "content_block_start":
@@ -371,15 +399,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
}
if usage := root.Get("usage"); usage.Exists() {
- promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
- out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
- out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
- out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
- out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
+ usageTokens.Merge(usage)
}
}
}
+ if usageTokens.HasUsage {
+ promptTokens, completionTokens, totalTokens, cachedTokens := usageTokens.OpenAIUsage()
+ out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
+ out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
+ out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
+ out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
+ }
+
// Set basic response fields including message ID, creation time, and model
out, _ = sjson.SetBytes(out, "id", messageID)
out, _ = sjson.SetBytes(out, "created", createdAt)
diff --git a/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go b/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go
index 7bd6eb1f..5a9a6d3a 100644
--- a/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go
+++ b/internal/translator/claude/openai/chat-completions/claude_openai_response_test.go
@@ -37,6 +37,44 @@ func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testin
}
}
+func TestConvertClaudeResponseToOpenAI_StreamUsageMergesMessageStartUsage(t *testing.T) {
+ ctx := context.Background()
+ var param any
+
+ ConvertClaudeResponseToOpenAI(
+ ctx,
+ "claude-opus-4-6",
+ nil,
+ nil,
+ []byte(`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-opus-4-6","usage":{"input_tokens":13,"output_tokens":1,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}}`),
+ &param,
+ )
+ out := ConvertClaudeResponseToOpenAI(
+ ctx,
+ "claude-opus-4-6",
+ nil,
+ nil,
+ []byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":4}}`),
+ &param,
+ )
+ if len(out) != 1 {
+ t.Fatalf("expected 1 chunk, got %d", len(out))
+ }
+
+ if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
+ t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
+ }
+ if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
+ t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
+ }
+ if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
+ t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
+ }
+ if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
+ t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
+ }
+}
+
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
@@ -56,3 +94,23 @@ func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *tes
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
+
+func TestConvertClaudeResponseToOpenAINonStream_UsageMergesMessageStartUsage(t *testing.T) {
+ rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\",\"usage\":{\"input_tokens\":13,\"output_tokens\":1,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}}\n" +
+ "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":4}}\n")
+
+ out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
+
+ if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
+ t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
+ }
+ if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
+ t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
+ }
+ if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
+ t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
+ }
+ if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
+ t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
+ }
+}
diff --git a/internal/translator/claude/openai/responses/claude_openai-responses_request.go b/internal/translator/claude/openai/responses/claude_openai-responses_request.go
index 514129ca..c0479b87 100644
--- a/internal/translator/claude/openai/responses/claude_openai-responses_request.go
+++ b/internal/translator/claude/openai/responses/claude_openai-responses_request.go
@@ -339,25 +339,21 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
})
}
+ includedToolNames := map[string]struct{}{}
+ toolNameMap := map[string]string{}
+
// tools mapping: parameters -> input_schema
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
toolsJSON := []byte("[]")
tools.ForEach(func(_, tool gjson.Result) bool {
- tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
- if n := tool.Get("name"); n.Exists() {
- tJSON, _ = sjson.SetBytes(tJSON, "name", n.String())
+ convertedTools := convertResponsesToolToClaudeTools(tool, toolNameMap)
+ for _, tJSON := range convertedTools {
+ toolName := gjson.GetBytes(tJSON, "name").String()
+ if toolName != "" {
+ includedToolNames[toolName] = struct{}{}
+ }
+ toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
}
- if d := tool.Get("description"); d.Exists() {
- tJSON, _ = sjson.SetBytes(tJSON, "description", d.String())
- }
-
- if params := tool.Get("parameters"); params.Exists() {
- tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
- } else if params = tool.Get("parametersJsonSchema"); params.Exists() {
- tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
- }
-
- toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
return true
})
if parsedTools := gjson.ParseBytes(toolsJSON); parsedTools.IsArray() && len(parsedTools.Array()) > 0 {
@@ -375,14 +371,24 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
case "none":
// Leave unset; implies no tools
case "required":
- out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
+ if len(includedToolNames) > 0 {
+ out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
+ }
}
case gjson.JSON:
if toolChoice.Get("type").String() == "function" {
fn := toolChoice.Get("function.name").String()
- toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
- toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
- out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
+ if fn == "" {
+ fn = toolChoice.Get("name").String()
+ }
+ if mappedName := toolNameMap[fn]; mappedName != "" {
+ fn = mappedName
+ }
+ if _, ok := includedToolNames[fn]; ok {
+ toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
+ toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
+ out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
+ }
}
default:
@@ -391,3 +397,167 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
return out
}
+
+func convertResponsesToolToClaudeTools(tool gjson.Result, toolNameMap map[string]string) [][]byte {
+ toolType := strings.TrimSpace(tool.Get("type").String())
+ switch toolType {
+ case "", "function":
+ if tJSON, ok := convertResponsesFunctionToolToClaude(tool, ""); ok {
+ return [][]byte{tJSON}
+ }
+ case "namespace":
+ return convertResponsesNamespaceToolToClaude(tool, toolNameMap)
+ case "web_search":
+ if tJSON, ok := convertResponsesWebSearchToolToClaude(tool); ok {
+ if name := gjson.GetBytes(tJSON, "name").String(); name != "" {
+ toolNameMap[name] = name
+ }
+ return [][]byte{tJSON}
+ }
+ default:
+ if isUnsupportedOpenAIBuiltinToolType(toolType) {
+ return nil
+ }
+ if tool.Get("name").String() != "" {
+ return [][]byte{[]byte(tool.Raw)}
+ }
+ }
+ return nil
+}
+
+func convertResponsesNamespaceToolToClaude(tool gjson.Result, toolNameMap map[string]string) [][]byte {
+ namespaceName := strings.TrimSpace(tool.Get("name").String())
+ children := tool.Get("tools")
+ if !children.Exists() || !children.IsArray() {
+ return nil
+ }
+
+ var out [][]byte
+ children.ForEach(func(_, child gjson.Result) bool {
+ childName := responsesToolName(child)
+ qualifiedName := qualifyResponsesNamespaceToolName(namespaceName, childName)
+ if tJSON, ok := convertResponsesFunctionToolToClaude(child, qualifiedName); ok {
+ out = append(out, tJSON)
+ toolNameMap[qualifiedName] = qualifiedName
+ if childName != "" {
+ toolNameMap[childName] = qualifiedName
+ }
+ }
+ return true
+ })
+ return out
+}
+
+func convertResponsesFunctionToolToClaude(tool gjson.Result, overrideName string) ([]byte, bool) {
+ name := strings.TrimSpace(overrideName)
+ if name == "" {
+ name = responsesToolName(tool)
+ }
+ if name == "" {
+ return nil, false
+ }
+
+ tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
+ tJSON, _ = sjson.SetBytes(tJSON, "name", name)
+ if d := responsesToolDescription(tool); d != "" {
+ tJSON, _ = sjson.SetBytes(tJSON, "description", d)
+ }
+ tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", normalizeClaudeToolInputSchema(responsesToolParameters(tool)))
+ return tJSON, true
+}
+
+func convertResponsesWebSearchToolToClaude(tool gjson.Result) ([]byte, bool) {
+ if externalWebAccess := tool.Get("external_web_access"); externalWebAccess.Exists() && !externalWebAccess.Bool() {
+ return nil, false
+ }
+
+ name := strings.TrimSpace(tool.Get("name").String())
+ if name == "" {
+ name = "web_search"
+ }
+ tJSON := []byte(`{"type":"web_search_20250305","name":""}`)
+ tJSON, _ = sjson.SetBytes(tJSON, "name", name)
+ if maxUses := tool.Get("max_uses"); maxUses.Exists() {
+ tJSON, _ = sjson.SetBytes(tJSON, "max_uses", maxUses.Int())
+ }
+ if allowedDomains := tool.Get("filters.allowed_domains"); allowedDomains.Exists() && allowedDomains.IsArray() {
+ tJSON, _ = sjson.SetRawBytes(tJSON, "allowed_domains", []byte(allowedDomains.Raw))
+ }
+ if userLocation := tool.Get("user_location"); userLocation.Exists() && userLocation.IsObject() {
+ tJSON, _ = sjson.SetRawBytes(tJSON, "user_location", []byte(userLocation.Raw))
+ }
+ return tJSON, true
+}
+
+func responsesToolName(tool gjson.Result) string {
+ if name := strings.TrimSpace(tool.Get("name").String()); name != "" {
+ return name
+ }
+ return strings.TrimSpace(tool.Get("function.name").String())
+}
+
+func responsesToolDescription(tool gjson.Result) string {
+ if description := tool.Get("description").String(); description != "" {
+ return description
+ }
+ return tool.Get("function.description").String()
+}
+
+func responsesToolParameters(tool gjson.Result) gjson.Result {
+ for _, path := range []string{
+ "parameters",
+ "parametersJsonSchema",
+ "input_schema",
+ "function.parameters",
+ "function.parametersJsonSchema",
+ } {
+ if parameters := tool.Get(path); parameters.Exists() {
+ return parameters
+ }
+ }
+ return gjson.Result{}
+}
+
+func normalizeClaudeToolInputSchema(parameters gjson.Result) []byte {
+ raw := strings.TrimSpace(parameters.Raw)
+ if raw == "" || raw == "null" || !gjson.Valid(raw) {
+ return []byte(`{"type":"object","properties":{}}`)
+ }
+ result := gjson.Parse(raw)
+ if !result.IsObject() {
+ return []byte(`{"type":"object","properties":{}}`)
+ }
+ schema := []byte(raw)
+ schemaType := result.Get("type").String()
+ if schemaType == "" {
+ schema, _ = sjson.SetBytes(schema, "type", "object")
+ schemaType = "object"
+ }
+ if schemaType == "object" && !result.Get("properties").Exists() {
+ schema, _ = sjson.SetRawBytes(schema, "properties", []byte(`{}`))
+ }
+ return schema
+}
+
+func qualifyResponsesNamespaceToolName(namespaceName, childName string) string {
+ childName = strings.TrimSpace(childName)
+ if childName == "" || namespaceName == "" || strings.HasPrefix(childName, "mcp__") {
+ return childName
+ }
+ if strings.HasPrefix(childName, namespaceName) {
+ return childName
+ }
+ if strings.HasSuffix(namespaceName, "__") {
+ return namespaceName + childName
+ }
+ return namespaceName + "__" + childName
+}
+
+func isUnsupportedOpenAIBuiltinToolType(toolType string) bool {
+ switch toolType {
+ case "image_generation", "file_search", "code_interpreter", "computer_use_preview":
+ return true
+ default:
+ return false
+ }
+}
diff --git a/internal/translator/claude/openai/responses/claude_openai-responses_response.go b/internal/translator/claude/openai/responses/claude_openai-responses_response.go
index ef2cc1f8..10d12c99 100644
--- a/internal/translator/claude/openai/responses/claude_openai-responses_response.go
+++ b/internal/translator/claude/openai/responses/claude_openai-responses_response.go
@@ -26,7 +26,8 @@ type claudeToResponsesState struct {
FuncNames map[int]string // index -> function name
FuncCallIDs map[int]string // index -> call id
// message text aggregation
- TextBuf strings.Builder
+ TextBuf strings.Builder
+ CurrentTextBuf strings.Builder
// reasoning state
ReasoningActive bool
ReasoningItemID string
@@ -80,6 +81,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
st.CreatedAt = time.Now().Unix()
// Reset per-message aggregation state
st.TextBuf.Reset()
+ st.CurrentTextBuf.Reset()
st.ReasoningBuf.Reset()
st.ReasoningActive = false
st.InTextBlock = false
@@ -128,6 +130,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
if typ == "text" {
// open message item + content part
st.InTextBlock = true
+ st.CurrentTextBuf.Reset()
st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
@@ -189,6 +192,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
out = append(out, emitEvent("response.output_text.delta", msg))
// aggregate text for response.output
st.TextBuf.WriteString(t.String())
+ st.CurrentTextBuf.WriteString(t.String())
}
} else if dt == "input_json_delta" {
idx := int(root.Get("index").Int())
@@ -220,17 +224,21 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
case "content_block_stop":
idx := int(root.Get("index").Int())
if st.InTextBlock {
+ fullText := st.CurrentTextBuf.String()
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
done, _ = sjson.SetBytes(done, "item_id", st.CurrentMsgID)
+ done, _ = sjson.SetBytes(done, "text", fullText)
out = append(out, emitEvent("response.output_text.done", done))
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", st.CurrentMsgID)
+ partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
out = append(out, emitEvent("response.content_part.done", partDone))
final := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`)
final, _ = sjson.SetBytes(final, "sequence_number", nextSeq())
final, _ = sjson.SetBytes(final, "item.id", st.CurrentMsgID)
+ final, _ = sjson.SetBytes(final, "item.content.0.text", fullText)
out = append(out, emitEvent("response.output_item.done", final))
st.InTextBlock = false
} else if st.InFuncBlock {
diff --git a/internal/translator/codex/gemini-cli/codex_gemini-cli_request_test.go b/internal/translator/codex/gemini-cli/codex_gemini-cli_request_test.go
new file mode 100644
index 00000000..fc41452b
--- /dev/null
+++ b/internal/translator/codex/gemini-cli/codex_gemini-cli_request_test.go
@@ -0,0 +1,78 @@
+package geminiCLI
+
+import (
+ "testing"
+
+ "github.com/tidwall/gjson"
+)
+
+func TestConvertGeminiCLIRequestToCodex_PreservesSchemaPropertyNamedType(t *testing.T) {
+ input := []byte(`{
+ "request": {
+ "tools": [
+ {
+ "functionDeclarations": [
+ {
+ "name": "ask_user",
+ "description": "Ask the user one or more questions.",
+ "parametersJsonSchema": {
+ "type": "object",
+ "properties": {
+ "questions": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "header": {
+ "type": "string"
+ },
+ "type": {
+ "default": "choice",
+ "description": "Question type.",
+ "enum": [
+ "choice",
+ "text",
+ "yesno"
+ ],
+ "type": "string"
+ }
+ },
+ "required": [
+ "question",
+ "header",
+ "type"
+ ]
+ }
+ }
+ },
+ "required": [
+ "questions"
+ ]
+ }
+ }
+ ]
+ }
+ ]
+ }
+ }`)
+
+ out := ConvertGeminiCLIRequestToCodex("gpt-5.2", input, true)
+ tool := gjson.GetBytes(out, "tools.0")
+ if got := tool.Get("type").String(); got != "function" {
+ t.Fatalf("expected tool type %q, got %q; output=%s", "function", got, string(out))
+ }
+
+ typeProperty := tool.Get("parameters.properties.questions.items.properties.type")
+ if !typeProperty.IsObject() {
+ t.Fatalf("expected schema property named type to stay an object; output=%s", string(out))
+ }
+ if got := typeProperty.Get("type").String(); got != "string" {
+ t.Fatalf("expected schema property type %q, got %q; output=%s", "string", got, string(out))
+ }
+ if got := typeProperty.Get("default").String(); got != "choice" {
+ t.Fatalf("expected default %q, got %q; output=%s", "choice", got, string(out))
+ }
+ if got := typeProperty.Get("enum.2").String(); got != "yesno" {
+ t.Fatalf("expected enum value %q, got %q; output=%s", "yesno", got, string(out))
+ }
+}
diff --git a/internal/translator/codex/gemini/codex_gemini_request.go b/internal/translator/codex/gemini/codex_gemini_request.go
index 23dae7d7..37399700 100644
--- a/internal/translator/codex/gemini/codex_gemini_request.go
+++ b/internal/translator/codex/gemini/codex_gemini_request.go
@@ -284,7 +284,11 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
util.Walk(toolsResult, "", "type", &pathsToLower)
for _, p := range pathsToLower {
fullPath := fmt.Sprintf("tools.%s", p)
- out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
+ typeValue := gjson.GetBytes(out, fullPath)
+ if typeValue.Type != gjson.String {
+ continue
+ }
+ out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(typeValue.String()))
}
return out
diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_request.go b/internal/translator/codex/openai/chat-completions/codex_openai_request.go
index 6cc701e7..569e06e3 100644
--- a/internal/translator/codex/openai/chat-completions/codex_openai_request.go
+++ b/internal/translator/codex/openai/chat-completions/codex_openai_request.go
@@ -121,13 +121,13 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
case "tool":
// Handle tool response messages as top-level function_call_output objects
toolCallID := m.Get("tool_call_id").String()
- content := m.Get("content").String()
+ content := m.Get("content")
// Create function_call_output object
funcOutput := []byte(`{}`)
funcOutput, _ = sjson.SetBytes(funcOutput, "type", "function_call_output")
funcOutput, _ = sjson.SetBytes(funcOutput, "call_id", toolCallID)
- funcOutput, _ = sjson.SetBytes(funcOutput, "output", content)
+ funcOutput = setToolCallOutputContent(funcOutput, content)
out, _ = sjson.SetRawBytes(out, "input.-1", funcOutput)
default:
@@ -359,6 +359,91 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
return out
}
+func setToolCallOutputContent(funcOutput []byte, content gjson.Result) []byte {
+ switch {
+ case content.Type == gjson.String:
+ funcOutput, _ = sjson.SetBytes(funcOutput, "output", content.String())
+ case content.IsArray():
+ output := []byte(`[]`)
+ for _, item := range content.Array() {
+ output = appendToolOutputContentPart(output, item)
+ }
+ funcOutput, _ = sjson.SetRawBytes(funcOutput, "output", output)
+ default:
+ fallbackOutput := content.Raw
+ if fallbackOutput == "" {
+ fallbackOutput = content.String()
+ }
+ funcOutput, _ = sjson.SetBytes(funcOutput, "output", fallbackOutput)
+ }
+ return funcOutput
+}
+
+func appendToolOutputContentPart(output []byte, item gjson.Result) []byte {
+ switch item.Get("type").String() {
+ case "text":
+ part := []byte(`{}`)
+ part, _ = sjson.SetBytes(part, "type", "input_text")
+ part, _ = sjson.SetBytes(part, "text", item.Get("text").String())
+ output, _ = sjson.SetRawBytes(output, "-1", part)
+ case "image_url":
+ imageURL := item.Get("image_url.url").String()
+ fileID := item.Get("image_url.file_id").String()
+ if imageURL == "" && fileID == "" {
+ return appendToolOutputFallbackPart(output, item)
+ }
+ part := []byte(`{}`)
+ part, _ = sjson.SetBytes(part, "type", "input_image")
+ if imageURL != "" {
+ part, _ = sjson.SetBytes(part, "image_url", imageURL)
+ }
+ if fileID != "" {
+ part, _ = sjson.SetBytes(part, "file_id", fileID)
+ }
+ if detail := item.Get("image_url.detail").String(); detail != "" {
+ part, _ = sjson.SetBytes(part, "detail", detail)
+ }
+ output, _ = sjson.SetRawBytes(output, "-1", part)
+ case "file":
+ fileID := item.Get("file.file_id").String()
+ fileData := item.Get("file.file_data").String()
+ fileURL := item.Get("file.file_url").String()
+ if fileID == "" && fileData == "" && fileURL == "" {
+ return appendToolOutputFallbackPart(output, item)
+ }
+ part := []byte(`{}`)
+ part, _ = sjson.SetBytes(part, "type", "input_file")
+ if fileID != "" {
+ part, _ = sjson.SetBytes(part, "file_id", fileID)
+ }
+ if fileData != "" {
+ part, _ = sjson.SetBytes(part, "file_data", fileData)
+ }
+ if fileURL != "" {
+ part, _ = sjson.SetBytes(part, "file_url", fileURL)
+ }
+ if filename := item.Get("file.filename").String(); filename != "" {
+ part, _ = sjson.SetBytes(part, "filename", filename)
+ }
+ output, _ = sjson.SetRawBytes(output, "-1", part)
+ default:
+ output = appendToolOutputFallbackPart(output, item)
+ }
+ return output
+}
+
+func appendToolOutputFallbackPart(output []byte, item gjson.Result) []byte {
+ text := item.Raw
+ if text == "" {
+ text = item.String()
+ }
+ part := []byte(`{}`)
+ part, _ = sjson.SetBytes(part, "type", "input_text")
+ part, _ = sjson.SetBytes(part, "text", text)
+ output, _ = sjson.SetRawBytes(output, "-1", part)
+ return output
+}
+
// shortenNameIfNeeded applies the simple shortening rule for a single name.
// If the name length exceeds 64, it will try to preserve the "mcp__" prefix and last segment.
// Otherwise it truncates to 64 characters.
diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_request_test.go b/internal/translator/codex/openai/chat-completions/codex_openai_request_test.go
index 84c8dad2..e31db6d3 100644
--- a/internal/translator/codex/openai/chat-completions/codex_openai_request_test.go
+++ b/internal/translator/codex/openai/chat-completions/codex_openai_request_test.go
@@ -176,6 +176,182 @@ func TestToolCallWithContent(t *testing.T) {
}
}
+func TestToolCallOutputWithMultimodalContent(t *testing.T) {
+ input := []byte(`{
+ "model": "gpt-4o",
+ "messages": [
+ {"role": "user", "content": "Show me the generated result."},
+ {
+ "role": "assistant",
+ "content": null,
+ "tool_calls": [
+ {
+ "id": "call_output_1",
+ "type": "function",
+ "function": {"name": "render_output", "arguments": "{}"}
+ }
+ ]
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_output_1",
+ "content": [
+ {"type":"text","text":"Rendered result attached."},
+ {"type":"image_url","image_url":{"url":"https://example.com/generated.png","detail":"high"}},
+ {"type":"image_url","image_url":{"file_id":"file-img-123"}},
+ {"type":"file","file":{"file_id":"file-doc-123","filename":"doc.pdf"}},
+ {"type":"file","file":{"file_data":"SGVsbG8=","filename":"inline.txt"}},
+ {"type":"file","file":{"file_url":"https://example.com/report.pdf","filename":"report.pdf"}}
+ ]
+ }
+ ],
+ "tools": [
+ {
+ "type": "function",
+ "function": {"name": "render_output", "description": "Render output", "parameters": {"type": "object", "properties": {}}}
+ }
+ ]
+ }`)
+
+ out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
+ result := string(out)
+
+ output := gjson.Get(result, "input.2.output")
+ if !output.IsArray() {
+ t.Fatalf("expected tool output to be an array, got: %s", output.Raw)
+ }
+
+ parts := output.Array()
+ if len(parts) != 6 {
+ t.Fatalf("expected 6 output parts, got %d: %s", len(parts), output.Raw)
+ }
+ if parts[0].Get("type").String() != "input_text" || parts[0].Get("text").String() != "Rendered result attached." {
+ t.Fatalf("part 0: expected input_text with rendered text, got %s", parts[0].Raw)
+ }
+ if parts[1].Get("type").String() != "input_image" {
+ t.Fatalf("part 1: expected input_image, got %s", parts[1].Raw)
+ }
+ if parts[1].Get("image_url").String() != "https://example.com/generated.png" {
+ t.Errorf("part 1: unexpected image_url %s", parts[1].Get("image_url").String())
+ }
+ if parts[1].Get("detail").String() != "high" {
+ t.Errorf("part 1: unexpected detail %s", parts[1].Get("detail").String())
+ }
+ if parts[2].Get("type").String() != "input_image" || parts[2].Get("file_id").String() != "file-img-123" {
+ t.Fatalf("part 2: expected file_id-backed input_image, got %s", parts[2].Raw)
+ }
+ if parts[3].Get("type").String() != "input_file" || parts[3].Get("file_id").String() != "file-doc-123" {
+ t.Fatalf("part 3: expected file_id-backed input_file, got %s", parts[3].Raw)
+ }
+ if parts[3].Get("filename").String() != "doc.pdf" {
+ t.Errorf("part 3: unexpected filename %s", parts[3].Get("filename").String())
+ }
+ if parts[4].Get("type").String() != "input_file" || parts[4].Get("file_data").String() != "SGVsbG8=" {
+ t.Fatalf("part 4: expected file_data-backed input_file, got %s", parts[4].Raw)
+ }
+ if parts[5].Get("type").String() != "input_file" || parts[5].Get("file_url").String() != "https://example.com/report.pdf" {
+ t.Fatalf("part 5: expected file_url-backed input_file, got %s", parts[5].Raw)
+ }
+}
+
+func TestToolCallOutputFallsBackForInvalidStructuredParts(t *testing.T) {
+ input := []byte(`{
+ "model": "gpt-4o",
+ "messages": [
+ {"role": "user", "content": "Check tool output."},
+ {
+ "role": "assistant",
+ "content": null,
+ "tool_calls": [
+ {"id": "call_invalid_parts", "type": "function", "function": {"name": "inspect", "arguments": "{}"}}
+ ]
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_invalid_parts",
+ "content": [
+ {"type":"image_url","image_url":{"detail":"low"}},
+ {"type":"file","file":{"filename":"orphan.txt"}},
+ {"type":"unknown_type","foo":"bar","nested":{"a":1}}
+ ]
+ }
+ ],
+ "tools": [
+ {"type": "function", "function": {"name": "inspect", "description": "Inspect", "parameters": {"type": "object", "properties": {}}}}
+ ]
+ }`)
+
+ out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
+ result := string(out)
+
+ parts := gjson.Get(result, "input.2.output").Array()
+ if len(parts) != 3 {
+ t.Fatalf("expected 3 output parts, got %d: %s", len(parts), gjson.Get(result, "input.2.output").Raw)
+ }
+
+ expectedFallbacks := []string{
+ `{"type":"image_url","image_url":{"detail":"low"}}`,
+ `{"type":"file","file":{"filename":"orphan.txt"}}`,
+ `{"type":"unknown_type","foo":"bar","nested":{"a":1}}`,
+ }
+ for i, expectedFallback := range expectedFallbacks {
+ if parts[i].Get("type").String() != "input_text" {
+ t.Fatalf("part %d: expected input_text fallback, got %s", i, parts[i].Raw)
+ }
+ if parts[i].Get("text").String() != expectedFallback {
+ t.Fatalf("part %d: expected fallback %s, got %s", i, expectedFallback, parts[i].Get("text").String())
+ }
+ }
+}
+
+func TestToolCallOutputWithNonStringJSONContent(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ expectedOutput string
+ }{
+ {name: "null", content: `null`, expectedOutput: `null`},
+ {name: "object", content: `{"status":"ok","count":2}`, expectedOutput: `{"status":"ok","count":2}`},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ input := []byte(`{
+ "model": "gpt-4o",
+ "messages": [
+ {"role": "user", "content": "Check tool output."},
+ {
+ "role": "assistant",
+ "content": null,
+ "tool_calls": [
+ {"id": "call_json", "type": "function", "function": {"name": "inspect", "arguments": "{}"}}
+ ]
+ },
+ {
+ "role": "tool",
+ "tool_call_id": "call_json",
+ "content": ` + tt.content + `
+ }
+ ],
+ "tools": [
+ {"type": "function", "function": {"name": "inspect", "description": "Inspect", "parameters": {"type": "object", "properties": {}}}}
+ ]
+ }`)
+
+ out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
+ result := string(out)
+
+ output := gjson.Get(result, "input.2.output")
+ if !output.Exists() {
+ t.Fatalf("expected output field to exist: %s", gjson.Get(result, "input.2").Raw)
+ }
+ if output.String() != tt.expectedOutput {
+ t.Fatalf("expected output %s, got %s", tt.expectedOutput, output.String())
+ }
+ })
+ }
+}
+
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
// and outputs must be translated and paired correctly.
func TestMultipleToolCalls(t *testing.T) {
diff --git a/internal/translator/openai/claude/openai_claude_response.go b/internal/translator/openai/claude/openai_claude_response.go
index 46c75898..af49d306 100644
--- a/internal/translator/openai/claude/openai_claude_response.go
+++ b/internal/translator/openai/claude/openai_claude_response.go
@@ -236,7 +236,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Handle function name
if function := toolCall.Get("function"); function.Exists() {
- if name := function.Get("name"); name.Exists() {
+ if name := function.Get("name"); name.Exists() && name.String() != "" {
accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
stopThinkingContentBlock(param, &results)
diff --git a/internal/translator/openai/claude/openai_claude_response_test.go b/internal/translator/openai/claude/openai_claude_response_test.go
new file mode 100644
index 00000000..8c36fc3d
--- /dev/null
+++ b/internal/translator/openai/claude/openai_claude_response_test.go
@@ -0,0 +1,41 @@
+package claude
+
+import (
+ "bytes"
+ "context"
+ "testing"
+)
+
+func TestConvertOpenAIResponseToClaude_StreamIgnoresNullToolNameDelta(t *testing.T) {
+ originalRequest := []byte(`{"stream":true}`)
+ var param any
+
+ firstChunks := ConvertOpenAIResponseToClaude(
+ context.Background(),
+ "test-model",
+ originalRequest,
+ nil,
+ []byte(`data: {"id":"chatcmpl_1","model":"test-model","created":1,"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"read_file","arguments":""}}]},"finish_reason":null}]}`),
+ &param,
+ )
+ firstOutput := bytes.Join(firstChunks, nil)
+ if !bytes.Contains(firstOutput, []byte(`"name":"read_file"`)) {
+ t.Fatalf("expected first chunk to start read_file tool block, got %s", string(firstOutput))
+ }
+
+ secondChunks := ConvertOpenAIResponseToClaude(
+ context.Background(),
+ "test-model",
+ originalRequest,
+ nil,
+ []byte(`data: {"id":"chatcmpl_1","model":"test-model","created":1,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":null,"arguments":"{\"path\":\"/tmp/a\"}"}}]},"finish_reason":null}]}`),
+ &param,
+ )
+ secondOutput := bytes.Join(secondChunks, nil)
+ if bytes.Contains(secondOutput, []byte(`content_block_start`)) {
+ t.Fatalf("did not expect null tool name delta to start a new content block, got %s", string(secondOutput))
+ }
+ if bytes.Contains(secondOutput, []byte(`"name":""`)) {
+ t.Fatalf("did not expect null tool name delta to emit an empty tool name, got %s", string(secondOutput))
+ }
+}
diff --git a/internal/tui/app.go b/internal/tui/app.go
index b9ee9e1a..c0a7c3a8 100644
--- a/internal/tui/app.go
+++ b/internal/tui/app.go
@@ -18,7 +18,6 @@ const (
tabAuthFiles
tabAPIKeys
tabOAuth
- tabUsage
tabLogs
)
@@ -40,7 +39,6 @@ type App struct {
auth authTabModel
keys keysTabModel
oauth oauthTabModel
- usage usageTabModel
logs logsTabModel
client *Client
@@ -50,7 +48,7 @@ type App struct {
ready bool
// Track which tabs have been initialized (fetched data)
- initialized [7]bool
+ initialized [6]bool
}
type authConnectMsg struct {
@@ -81,10 +79,9 @@ func NewApp(port int, secretKey string, hook *LogHook) App {
auth: newAuthTabModel(client),
keys: newKeysTabModel(client),
oauth: newOAuthTabModel(client),
- usage: newUsageTabModel(client),
logs: newLogsTabModel(client, hook),
client: client,
- initialized: [7]bool{
+ initialized: [6]bool{
tabDashboard: true,
tabLogs: true,
},
@@ -92,7 +89,7 @@ func NewApp(port int, secretKey string, hook *LogHook) App {
app.refreshTabs()
if authRequired {
- app.initialized = [7]bool{}
+ app.initialized = [6]bool{}
}
app.setAuthInputPrompt()
return app
@@ -128,7 +125,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.auth.SetSize(contentW, contentH)
a.keys.SetSize(contentW, contentH)
a.oauth.SetSize(contentW, contentH)
- a.usage.SetSize(contentW, contentH)
a.logs.SetSize(contentW, contentH)
return a, nil
@@ -142,7 +138,7 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.authenticated = true
a.logsEnabled = a.standalone || isLogsEnabledFromConfig(msg.cfg)
a.refreshTabs()
- a.initialized = [7]bool{}
+ a.initialized = [6]bool{}
a.initialized[tabDashboard] = true
cmds := []tea.Cmd{a.dashboard.Init()}
if a.logsEnabled {
@@ -258,8 +254,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.keys, cmd = a.keys.Update(msg)
case tabOAuth:
a.oauth, cmd = a.oauth.Update(msg)
- case tabUsage:
- a.usage, cmd = a.usage.Update(msg)
case tabLogs:
a.logs, cmd = a.logs.Update(msg)
}
@@ -322,8 +316,6 @@ func (a *App) initTabIfNeeded(_ int) tea.Cmd {
return a.keys.Init()
case tabOAuth:
return a.oauth.Init()
- case tabUsage:
- return a.usage.Init()
case tabLogs:
if !a.logsEnabled {
return nil
@@ -360,8 +352,6 @@ func (a App) View() string {
sb.WriteString(a.keys.View())
case tabOAuth:
sb.WriteString(a.oauth.View())
- case tabUsage:
- sb.WriteString(a.usage.View())
case tabLogs:
if a.logsEnabled {
sb.WriteString(a.logs.View())
@@ -529,10 +519,6 @@ func (a App) broadcastToAllTabs(msg tea.Msg) (tea.Model, tea.Cmd) {
if cmd != nil {
cmds = append(cmds, cmd)
}
- a.usage, cmd = a.usage.Update(msg)
- if cmd != nil {
- cmds = append(cmds, cmd)
- }
a.logs, cmd = a.logs.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
diff --git a/internal/tui/client.go b/internal/tui/client.go
index 6f75d6be..747f30b9 100644
--- a/internal/tui/client.go
+++ b/internal/tui/client.go
@@ -140,11 +140,6 @@ func (c *Client) PutConfigYAML(yamlContent string) error {
return err
}
-// GetUsage fetches usage statistics.
-func (c *Client) GetUsage() (map[string]any, error) {
- return c.getJSON("/v0/management/usage")
-}
-
// GetAuthFiles lists auth credential files.
// API returns {"files": [...]}.
func (c *Client) GetAuthFiles() ([]map[string]any, error) {
diff --git a/internal/tui/dashboard.go b/internal/tui/dashboard.go
index 8561fe9c..99b5409c 100644
--- a/internal/tui/dashboard.go
+++ b/internal/tui/dashboard.go
@@ -22,14 +22,12 @@ type dashboardModel struct {
// Cached data for re-rendering on locale change
lastConfig map[string]any
- lastUsage map[string]any
lastAuthFiles []map[string]any
lastAPIKeys []string
}
type dashboardDataMsg struct {
config map[string]any
- usage map[string]any
authFiles []map[string]any
apiKeys []string
err error
@@ -47,25 +45,24 @@ func (m dashboardModel) Init() tea.Cmd {
func (m dashboardModel) fetchData() tea.Msg {
cfg, cfgErr := m.client.GetConfig()
- usage, usageErr := m.client.GetUsage()
authFiles, authErr := m.client.GetAuthFiles()
apiKeys, keysErr := m.client.GetAPIKeys()
var err error
- for _, e := range []error{cfgErr, usageErr, authErr, keysErr} {
+ for _, e := range []error{cfgErr, authErr, keysErr} {
if e != nil {
err = e
break
}
}
- return dashboardDataMsg{config: cfg, usage: usage, authFiles: authFiles, apiKeys: apiKeys, err: err}
+ return dashboardDataMsg{config: cfg, authFiles: authFiles, apiKeys: apiKeys, err: err}
}
func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
// Re-render immediately with cached data using new locale
- m.content = m.renderDashboard(m.lastConfig, m.lastUsage, m.lastAuthFiles, m.lastAPIKeys)
+ m.content = m.renderDashboard(m.lastConfig, m.lastAuthFiles, m.lastAPIKeys)
m.viewport.SetContent(m.content)
// Also fetch fresh data in background
return m, m.fetchData
@@ -78,11 +75,10 @@ func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
m.err = nil
// Cache data for locale switching
m.lastConfig = msg.config
- m.lastUsage = msg.usage
m.lastAuthFiles = msg.authFiles
m.lastAPIKeys = msg.apiKeys
- m.content = m.renderDashboard(msg.config, msg.usage, msg.authFiles, msg.apiKeys)
+ m.content = m.renderDashboard(msg.config, msg.authFiles, msg.apiKeys)
}
m.viewport.SetContent(m.content)
return m, nil
@@ -121,7 +117,7 @@ func (m dashboardModel) View() string {
return m.viewport.View()
}
-func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []map[string]any, apiKeys []string) string {
+func (m dashboardModel) renderDashboard(cfg map[string]any, authFiles []map[string]any, apiKeys []string) string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("dashboard_title")))
@@ -138,7 +134,7 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
// ━━━ Stats Cards ━━━
cardWidth := 25
if m.width > 0 {
- cardWidth = (m.width - 6) / 4
+ cardWidth = (m.width - 2) / 2
if cardWidth < 18 {
cardWidth = 18
}
@@ -173,34 +169,7 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (%d %s)", T("auth_files_label"), activeAuth, T("active_suffix"))),
))
- // Card 3: Total Requests
- totalReqs := int64(0)
- successReqs := int64(0)
- failedReqs := int64(0)
- totalTokens := int64(0)
- if usage != nil {
- if usageMap, ok := usage["usage"].(map[string]any); ok {
- totalReqs = int64(getFloat(usageMap, "total_requests"))
- successReqs = int64(getFloat(usageMap, "success_count"))
- failedReqs = int64(getFloat(usageMap, "failure_count"))
- totalTokens = int64(getFloat(usageMap, "total_tokens"))
- }
- }
- card3 := cardStyle.Render(fmt.Sprintf(
- "%s\n%s",
- lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(fmt.Sprintf("📈 %d", totalReqs)),
- lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (✓%d ✗%d)", T("total_requests"), successReqs, failedReqs)),
- ))
-
- // Card 4: Total Tokens
- tokenStr := formatLargeNumber(totalTokens)
- card4 := cardStyle.Render(fmt.Sprintf(
- "%s\n%s",
- lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("🔤 %s", tokenStr)),
- lipgloss.NewStyle().Foreground(colorMuted).Render(T("total_tokens")),
- ))
-
- sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
+ sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2))
sb.WriteString("\n\n")
// ━━━ Current Config ━━━
@@ -258,38 +227,6 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
sb.WriteString("\n")
- // ━━━ Per-Model Usage ━━━
- if usage != nil {
- if usageMap, ok := usage["usage"].(map[string]any); ok {
- if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
- sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("model_stats")))
- sb.WriteString("\n")
- sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
- sb.WriteString("\n")
-
- header := fmt.Sprintf(" %-40s %10s %12s", T("model"), T("requests"), T("tokens"))
- sb.WriteString(tableHeaderStyle.Render(header))
- sb.WriteString("\n")
-
- for _, apiSnap := range apis {
- if apiMap, ok := apiSnap.(map[string]any); ok {
- if models, ok := apiMap["models"].(map[string]any); ok {
- for model, v := range models {
- if stats, ok := v.(map[string]any); ok {
- reqs := int64(getFloat(stats, "total_requests"))
- toks := int64(getFloat(stats, "total_tokens"))
- row := fmt.Sprintf(" %-40s %10d %12s", truncate(model, 40), reqs, formatLargeNumber(toks))
- sb.WriteString(tableCellStyle.Render(row))
- sb.WriteString("\n")
- }
- }
- }
- }
- }
- }
- }
- }
-
return sb.String()
}
diff --git a/internal/tui/i18n.go b/internal/tui/i18n.go
index f6a33ca4..a4c0ac16 100644
--- a/internal/tui/i18n.go
+++ b/internal/tui/i18n.go
@@ -50,8 +50,8 @@ var locales = map[string]map[string]string{
// ──────────────────────────────────────────
// Tab names
// ──────────────────────────────────────────
-var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "使用统计", "日志"}
-var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Usage", "Logs"}
+var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "日志"}
+var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Logs"}
// TabNames returns tab names in the current locale.
func TabNames() []string {
diff --git a/internal/tui/usage_tab.go b/internal/tui/usage_tab.go
deleted file mode 100644
index 6b9fef5e..00000000
--- a/internal/tui/usage_tab.go
+++ /dev/null
@@ -1,418 +0,0 @@
-package tui
-
-import (
- "fmt"
- "sort"
- "strings"
-
- "github.com/charmbracelet/bubbles/viewport"
- tea "github.com/charmbracelet/bubbletea"
- "github.com/charmbracelet/lipgloss"
-)
-
-// usageTabModel displays usage statistics with charts and breakdowns.
-type usageTabModel struct {
- client *Client
- viewport viewport.Model
- usage map[string]any
- err error
- width int
- height int
- ready bool
-}
-
-type usageDataMsg struct {
- usage map[string]any
- err error
-}
-
-func newUsageTabModel(client *Client) usageTabModel {
- return usageTabModel{
- client: client,
- }
-}
-
-func (m usageTabModel) Init() tea.Cmd {
- return m.fetchData
-}
-
-func (m usageTabModel) fetchData() tea.Msg {
- usage, err := m.client.GetUsage()
- return usageDataMsg{usage: usage, err: err}
-}
-
-func (m usageTabModel) Update(msg tea.Msg) (usageTabModel, tea.Cmd) {
- switch msg := msg.(type) {
- case localeChangedMsg:
- m.viewport.SetContent(m.renderContent())
- return m, nil
- case usageDataMsg:
- if msg.err != nil {
- m.err = msg.err
- } else {
- m.err = nil
- m.usage = msg.usage
- }
- m.viewport.SetContent(m.renderContent())
- return m, nil
-
- case tea.KeyMsg:
- if msg.String() == "r" {
- return m, m.fetchData
- }
- var cmd tea.Cmd
- m.viewport, cmd = m.viewport.Update(msg)
- return m, cmd
- }
-
- var cmd tea.Cmd
- m.viewport, cmd = m.viewport.Update(msg)
- return m, cmd
-}
-
-func (m *usageTabModel) SetSize(w, h int) {
- m.width = w
- m.height = h
- if !m.ready {
- m.viewport = viewport.New(w, h)
- m.viewport.SetContent(m.renderContent())
- m.ready = true
- } else {
- m.viewport.Width = w
- m.viewport.Height = h
- }
-}
-
-func (m usageTabModel) View() string {
- if !m.ready {
- return T("loading")
- }
- return m.viewport.View()
-}
-
-func (m usageTabModel) renderContent() string {
- var sb strings.Builder
-
- sb.WriteString(titleStyle.Render(T("usage_title")))
- sb.WriteString("\n")
- sb.WriteString(helpStyle.Render(T("usage_help")))
- sb.WriteString("\n\n")
-
- if m.err != nil {
- sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
- sb.WriteString("\n")
- return sb.String()
- }
-
- if m.usage == nil {
- sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
- sb.WriteString("\n")
- return sb.String()
- }
-
- usageMap, _ := m.usage["usage"].(map[string]any)
- if usageMap == nil {
- sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
- sb.WriteString("\n")
- return sb.String()
- }
-
- totalReqs := int64(getFloat(usageMap, "total_requests"))
- successCnt := int64(getFloat(usageMap, "success_count"))
- failureCnt := int64(getFloat(usageMap, "failure_count"))
- totalTokens := int64(getFloat(usageMap, "total_tokens"))
-
- // ━━━ Overview Cards ━━━
- cardWidth := 20
- if m.width > 0 {
- cardWidth = (m.width - 6) / 4
- if cardWidth < 16 {
- cardWidth = 16
- }
- }
- cardStyle := lipgloss.NewStyle().
- Border(lipgloss.RoundedBorder()).
- BorderForeground(lipgloss.Color("240")).
- Padding(0, 1).
- Width(cardWidth).
- Height(3)
-
- // Total Requests
- card1 := cardStyle.Copy().BorderForeground(lipgloss.Color("111")).Render(fmt.Sprintf(
- "%s\n%s\n%s",
- lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_reqs")),
- lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("%d", totalReqs)),
- lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("● %s: %d ● %s: %d", T("usage_success"), successCnt, T("usage_failure"), failureCnt)),
- ))
-
- // Total Tokens
- card2 := cardStyle.Copy().BorderForeground(lipgloss.Color("214")).Render(fmt.Sprintf(
- "%s\n%s\n%s",
- lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_tokens")),
- lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(formatLargeNumber(totalTokens)),
- lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_token_l"), formatLargeNumber(totalTokens))),
- ))
-
- // RPM
- rpm := float64(0)
- if totalReqs > 0 {
- if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
- rpm = float64(totalReqs) / float64(len(rByH)) / 60.0
- }
- }
- card3 := cardStyle.Copy().BorderForeground(lipgloss.Color("76")).Render(fmt.Sprintf(
- "%s\n%s\n%s",
- lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_rpm")),
- lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("%.2f", rpm)),
- lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %d", T("usage_total_reqs"), totalReqs)),
- ))
-
- // TPM
- tpm := float64(0)
- if totalTokens > 0 {
- if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
- tpm = float64(totalTokens) / float64(len(tByH)) / 60.0
- }
- }
- card4 := cardStyle.Copy().BorderForeground(lipgloss.Color("170")).Render(fmt.Sprintf(
- "%s\n%s\n%s",
- lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_tpm")),
- lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("%.2f", tpm)),
- lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_tokens"), formatLargeNumber(totalTokens))),
- ))
-
- sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
- sb.WriteString("\n\n")
-
- // ━━━ Requests by Hour (ASCII bar chart) ━━━
- if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
- sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_hour")))
- sb.WriteString("\n")
- sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
- sb.WriteString("\n")
- sb.WriteString(renderBarChart(rByH, m.width-6, lipgloss.Color("111")))
- sb.WriteString("\n")
- }
-
- // ━━━ Tokens by Hour ━━━
- if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
- sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_tok_by_hour")))
- sb.WriteString("\n")
- sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
- sb.WriteString("\n")
- sb.WriteString(renderBarChart(tByH, m.width-6, lipgloss.Color("214")))
- sb.WriteString("\n")
- }
-
- // ━━━ Requests by Day ━━━
- if rByD, ok := usageMap["requests_by_day"].(map[string]any); ok && len(rByD) > 0 {
- sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_day")))
- sb.WriteString("\n")
- sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
- sb.WriteString("\n")
- sb.WriteString(renderBarChart(rByD, m.width-6, lipgloss.Color("76")))
- sb.WriteString("\n")
- }
-
- // ━━━ API Detail Stats ━━━
- if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
- sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_api_detail")))
- sb.WriteString("\n")
- sb.WriteString(strings.Repeat("─", minInt(m.width, 80)))
- sb.WriteString("\n")
-
- header := fmt.Sprintf(" %-30s %10s %12s", "API", T("requests"), T("tokens"))
- sb.WriteString(tableHeaderStyle.Render(header))
- sb.WriteString("\n")
-
- for apiName, apiSnap := range apis {
- if apiMap, ok := apiSnap.(map[string]any); ok {
- apiReqs := int64(getFloat(apiMap, "total_requests"))
- apiToks := int64(getFloat(apiMap, "total_tokens"))
-
- row := fmt.Sprintf(" %-30s %10d %12s",
- truncate(maskKey(apiName), 30), apiReqs, formatLargeNumber(apiToks))
- sb.WriteString(lipgloss.NewStyle().Bold(true).Render(row))
- sb.WriteString("\n")
-
- // Per-model breakdown
- if models, ok := apiMap["models"].(map[string]any); ok {
- for model, v := range models {
- if stats, ok := v.(map[string]any); ok {
- mReqs := int64(getFloat(stats, "total_requests"))
- mToks := int64(getFloat(stats, "total_tokens"))
- mRow := fmt.Sprintf(" ├─ %-28s %10d %12s",
- truncate(model, 28), mReqs, formatLargeNumber(mToks))
- sb.WriteString(tableCellStyle.Render(mRow))
- sb.WriteString("\n")
-
- // Token type breakdown from details
- sb.WriteString(m.renderTokenBreakdown(stats))
-
- // Latency breakdown from details
- sb.WriteString(m.renderLatencyBreakdown(stats))
- }
- }
- }
- }
- }
- }
-
- sb.WriteString("\n")
- return sb.String()
-}
-
-// renderTokenBreakdown aggregates input/output/cached/reasoning tokens from model details.
-func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
- details, ok := modelStats["details"]
- if !ok {
- return ""
- }
- detailList, ok := details.([]any)
- if !ok || len(detailList) == 0 {
- return ""
- }
-
- var inputTotal, outputTotal, cachedTotal, reasoningTotal int64
- for _, d := range detailList {
- dm, ok := d.(map[string]any)
- if !ok {
- continue
- }
- tokens, ok := dm["tokens"].(map[string]any)
- if !ok {
- continue
- }
- inputTotal += int64(getFloat(tokens, "input_tokens"))
- outputTotal += int64(getFloat(tokens, "output_tokens"))
- cachedTotal += int64(getFloat(tokens, "cached_tokens"))
- reasoningTotal += int64(getFloat(tokens, "reasoning_tokens"))
- }
-
- if inputTotal == 0 && outputTotal == 0 && cachedTotal == 0 && reasoningTotal == 0 {
- return ""
- }
-
- parts := []string{}
- if inputTotal > 0 {
- parts = append(parts, fmt.Sprintf("%s:%s", T("usage_input"), formatLargeNumber(inputTotal)))
- }
- if outputTotal > 0 {
- parts = append(parts, fmt.Sprintf("%s:%s", T("usage_output"), formatLargeNumber(outputTotal)))
- }
- if cachedTotal > 0 {
- parts = append(parts, fmt.Sprintf("%s:%s", T("usage_cached"), formatLargeNumber(cachedTotal)))
- }
- if reasoningTotal > 0 {
- parts = append(parts, fmt.Sprintf("%s:%s", T("usage_reasoning"), formatLargeNumber(reasoningTotal)))
- }
-
- return fmt.Sprintf(" │ %s\n",
- lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
-}
-
-// renderLatencyBreakdown aggregates latency_ms from model details and displays avg/min/max.
-func (m usageTabModel) renderLatencyBreakdown(modelStats map[string]any) string {
- details, ok := modelStats["details"]
- if !ok {
- return ""
- }
- detailList, ok := details.([]any)
- if !ok || len(detailList) == 0 {
- return ""
- }
-
- var totalLatency int64
- var count int
- var minLatency, maxLatency int64
- first := true
-
- for _, d := range detailList {
- dm, ok := d.(map[string]any)
- if !ok {
- continue
- }
- latencyMs := int64(getFloat(dm, "latency_ms"))
- if latencyMs <= 0 {
- continue
- }
- totalLatency += latencyMs
- count++
- if first {
- minLatency = latencyMs
- maxLatency = latencyMs
- first = false
- } else {
- if latencyMs < minLatency {
- minLatency = latencyMs
- }
- if latencyMs > maxLatency {
- maxLatency = latencyMs
- }
- }
- }
-
- if count == 0 {
- return ""
- }
-
- avgLatency := totalLatency / int64(count)
- return fmt.Sprintf(" │ %s: avg %dms min %dms max %dms\n",
- lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_time")),
- avgLatency, minLatency, maxLatency)
-}
-
-// renderBarChart renders a simple ASCII horizontal bar chart.
-func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
- if maxBarWidth < 10 {
- maxBarWidth = 10
- }
-
- // Sort keys
- keys := make([]string, 0, len(data))
- for k := range data {
- keys = append(keys, k)
- }
- sort.Strings(keys)
-
- // Find max value
- maxVal := float64(0)
- for _, k := range keys {
- v := getFloat(data, k)
- if v > maxVal {
- maxVal = v
- }
- }
- if maxVal == 0 {
- return ""
- }
-
- barStyle := lipgloss.NewStyle().Foreground(barColor)
- var sb strings.Builder
-
- labelWidth := 12
- barAvail := maxBarWidth - labelWidth - 12
- if barAvail < 5 {
- barAvail = 5
- }
-
- for _, k := range keys {
- v := getFloat(data, k)
- barLen := int(v / maxVal * float64(barAvail))
- if barLen < 1 && v > 0 {
- barLen = 1
- }
- bar := strings.Repeat("█", barLen)
- label := k
- if len(label) > labelWidth {
- label = label[:labelWidth]
- }
- sb.WriteString(fmt.Sprintf(" %-*s %s %s\n",
- labelWidth, label,
- barStyle.Render(bar),
- lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%.0f", v)),
- ))
- }
-
- return sb.String()
-}
diff --git a/internal/tui/usage_tab_test.go b/internal/tui/usage_tab_test.go
deleted file mode 100644
index 4fffcd98..00000000
--- a/internal/tui/usage_tab_test.go
+++ /dev/null
@@ -1,134 +0,0 @@
-package tui
-
-import (
- "strings"
- "testing"
-)
-
-func TestRenderLatencyBreakdown(t *testing.T) {
- tests := []struct {
- name string
- modelStats map[string]any
- wantEmpty bool
- wantContains string
- }{
- {
- name: "no details",
- modelStats: map[string]any{},
- wantEmpty: true,
- },
- {
- name: "empty details",
- modelStats: map[string]any{
- "details": []any{},
- },
- wantEmpty: true,
- },
- {
- name: "details with zero latency",
- modelStats: map[string]any{
- "details": []any{
- map[string]any{
- "latency_ms": float64(0),
- },
- },
- },
- wantEmpty: true,
- },
- {
- name: "single request with latency",
- modelStats: map[string]any{
- "details": []any{
- map[string]any{
- "latency_ms": float64(1500),
- },
- },
- },
- wantEmpty: false,
- wantContains: "avg 1500ms min 1500ms max 1500ms",
- },
- {
- name: "multiple requests with varying latency",
- modelStats: map[string]any{
- "details": []any{
- map[string]any{
- "latency_ms": float64(100),
- },
- map[string]any{
- "latency_ms": float64(200),
- },
- map[string]any{
- "latency_ms": float64(300),
- },
- },
- },
- wantEmpty: false,
- wantContains: "avg 200ms min 100ms max 300ms",
- },
- {
- name: "mixed valid and invalid latency values",
- modelStats: map[string]any{
- "details": []any{
- map[string]any{
- "latency_ms": float64(500),
- },
- map[string]any{
- "latency_ms": float64(0),
- },
- map[string]any{
- "latency_ms": float64(1500),
- },
- },
- },
- wantEmpty: false,
- wantContains: "avg 1000ms min 500ms max 1500ms",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- m := usageTabModel{}
- result := m.renderLatencyBreakdown(tt.modelStats)
-
- if tt.wantEmpty {
- if result != "" {
- t.Errorf("renderLatencyBreakdown() = %q, want empty string", result)
- }
- return
- }
-
- if result == "" {
- t.Errorf("renderLatencyBreakdown() = empty, want non-empty string")
- return
- }
-
- if tt.wantContains != "" && !strings.Contains(result, tt.wantContains) {
- t.Errorf("renderLatencyBreakdown() = %q, want to contain %q", result, tt.wantContains)
- }
- })
- }
-}
-
-func TestUsageTimeTranslations(t *testing.T) {
- prevLocale := CurrentLocale()
- t.Cleanup(func() {
- SetLocale(prevLocale)
- })
-
- tests := []struct {
- locale string
- want string
- }{
- {locale: "en", want: "Time"},
- {locale: "zh", want: "时间"},
- }
-
- for _, tt := range tests {
- t.Run(tt.locale, func(t *testing.T) {
- SetLocale(tt.locale)
- if got := T("usage_time"); got != tt.want {
- t.Fatalf("T(usage_time) = %q, want %q", got, tt.want)
- }
- })
- }
-}
diff --git a/internal/usage/logger_plugin.go b/internal/usage/logger_plugin.go
deleted file mode 100644
index 803d005e..00000000
--- a/internal/usage/logger_plugin.go
+++ /dev/null
@@ -1,484 +0,0 @@
-// Package usage provides usage tracking and logging functionality for the CLI Proxy API server.
-// It includes plugins for monitoring API usage, token consumption, and other metrics
-// to help with observability and billing purposes.
-package usage
-
-import (
- "context"
- "fmt"
- "strings"
- "sync"
- "sync/atomic"
- "time"
-
- "github.com/gin-gonic/gin"
- coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
-)
-
-var statisticsEnabled atomic.Bool
-
-func init() {
- statisticsEnabled.Store(true)
- coreusage.RegisterPlugin(NewLoggerPlugin())
-}
-
-// LoggerPlugin collects in-memory request statistics for usage analysis.
-// It implements coreusage.Plugin to receive usage records emitted by the runtime.
-type LoggerPlugin struct {
- stats *RequestStatistics
-}
-
-// NewLoggerPlugin constructs a new logger plugin instance.
-//
-// Returns:
-// - *LoggerPlugin: A new logger plugin instance wired to the shared statistics store.
-func NewLoggerPlugin() *LoggerPlugin { return &LoggerPlugin{stats: defaultRequestStatistics} }
-
-// HandleUsage implements coreusage.Plugin.
-// It updates the in-memory statistics store whenever a usage record is received.
-//
-// Parameters:
-// - ctx: The context for the usage record
-// - record: The usage record to aggregate
-func (p *LoggerPlugin) HandleUsage(ctx context.Context, record coreusage.Record) {
- if !statisticsEnabled.Load() {
- return
- }
- if p == nil || p.stats == nil {
- return
- }
- p.stats.Record(ctx, record)
-}
-
-// SetStatisticsEnabled toggles whether in-memory statistics are recorded.
-func SetStatisticsEnabled(enabled bool) { statisticsEnabled.Store(enabled) }
-
-// StatisticsEnabled reports the current recording state.
-func StatisticsEnabled() bool { return statisticsEnabled.Load() }
-
-// RequestStatistics maintains aggregated request metrics in memory.
-type RequestStatistics struct {
- mu sync.RWMutex
-
- totalRequests int64
- successCount int64
- failureCount int64
- totalTokens int64
-
- apis map[string]*apiStats
-
- requestsByDay map[string]int64
- requestsByHour map[int]int64
- tokensByDay map[string]int64
- tokensByHour map[int]int64
-}
-
-// apiStats holds aggregated metrics for a single API key.
-type apiStats struct {
- TotalRequests int64
- TotalTokens int64
- Models map[string]*modelStats
-}
-
-// modelStats holds aggregated metrics for a specific model within an API.
-type modelStats struct {
- TotalRequests int64
- TotalTokens int64
- Details []RequestDetail
-}
-
-// RequestDetail stores the timestamp, latency, and token usage for a single request.
-type RequestDetail struct {
- Timestamp time.Time `json:"timestamp"`
- LatencyMs int64 `json:"latency_ms"`
- Source string `json:"source"`
- AuthIndex string `json:"auth_index"`
- Tokens TokenStats `json:"tokens"`
- Failed bool `json:"failed"`
-}
-
-// TokenStats captures the token usage breakdown for a request.
-type TokenStats struct {
- InputTokens int64 `json:"input_tokens"`
- OutputTokens int64 `json:"output_tokens"`
- ReasoningTokens int64 `json:"reasoning_tokens"`
- CachedTokens int64 `json:"cached_tokens"`
- TotalTokens int64 `json:"total_tokens"`
-}
-
-// StatisticsSnapshot represents an immutable view of the aggregated metrics.
-type StatisticsSnapshot struct {
- TotalRequests int64 `json:"total_requests"`
- SuccessCount int64 `json:"success_count"`
- FailureCount int64 `json:"failure_count"`
- TotalTokens int64 `json:"total_tokens"`
-
- APIs map[string]APISnapshot `json:"apis"`
-
- RequestsByDay map[string]int64 `json:"requests_by_day"`
- RequestsByHour map[string]int64 `json:"requests_by_hour"`
- TokensByDay map[string]int64 `json:"tokens_by_day"`
- TokensByHour map[string]int64 `json:"tokens_by_hour"`
-}
-
-// APISnapshot summarises metrics for a single API key.
-type APISnapshot struct {
- TotalRequests int64 `json:"total_requests"`
- TotalTokens int64 `json:"total_tokens"`
- Models map[string]ModelSnapshot `json:"models"`
-}
-
-// ModelSnapshot summarises metrics for a specific model.
-type ModelSnapshot struct {
- TotalRequests int64 `json:"total_requests"`
- TotalTokens int64 `json:"total_tokens"`
- Details []RequestDetail `json:"details"`
-}
-
-var defaultRequestStatistics = NewRequestStatistics()
-
-// GetRequestStatistics returns the shared statistics store.
-func GetRequestStatistics() *RequestStatistics { return defaultRequestStatistics }
-
-// NewRequestStatistics constructs an empty statistics store.
-func NewRequestStatistics() *RequestStatistics {
- return &RequestStatistics{
- apis: make(map[string]*apiStats),
- requestsByDay: make(map[string]int64),
- requestsByHour: make(map[int]int64),
- tokensByDay: make(map[string]int64),
- tokensByHour: make(map[int]int64),
- }
-}
-
-// Record ingests a new usage record and updates the aggregates.
-func (s *RequestStatistics) Record(ctx context.Context, record coreusage.Record) {
- if s == nil {
- return
- }
- if !statisticsEnabled.Load() {
- return
- }
- timestamp := record.RequestedAt
- if timestamp.IsZero() {
- timestamp = time.Now()
- }
- detail := normaliseDetail(record.Detail)
- totalTokens := detail.TotalTokens
- statsKey := record.APIKey
- if statsKey == "" {
- statsKey = resolveAPIIdentifier(ctx, record)
- }
- failed := record.Failed
- if !failed {
- failed = !resolveSuccess(ctx)
- }
- success := !failed
- modelName := record.Model
- if modelName == "" {
- modelName = "unknown"
- }
- dayKey := timestamp.Format("2006-01-02")
- hourKey := timestamp.Hour()
-
- s.mu.Lock()
- defer s.mu.Unlock()
-
- s.totalRequests++
- if success {
- s.successCount++
- } else {
- s.failureCount++
- }
- s.totalTokens += totalTokens
-
- stats, ok := s.apis[statsKey]
- if !ok {
- stats = &apiStats{Models: make(map[string]*modelStats)}
- s.apis[statsKey] = stats
- }
- s.updateAPIStats(stats, modelName, RequestDetail{
- Timestamp: timestamp,
- LatencyMs: normaliseLatency(record.Latency),
- Source: record.Source,
- AuthIndex: record.AuthIndex,
- Tokens: detail,
- Failed: failed,
- })
-
- s.requestsByDay[dayKey]++
- s.requestsByHour[hourKey]++
- s.tokensByDay[dayKey] += totalTokens
- s.tokensByHour[hourKey] += totalTokens
-}
-
-func (s *RequestStatistics) updateAPIStats(stats *apiStats, model string, detail RequestDetail) {
- stats.TotalRequests++
- stats.TotalTokens += detail.Tokens.TotalTokens
- modelStatsValue, ok := stats.Models[model]
- if !ok {
- modelStatsValue = &modelStats{}
- stats.Models[model] = modelStatsValue
- }
- modelStatsValue.TotalRequests++
- modelStatsValue.TotalTokens += detail.Tokens.TotalTokens
- modelStatsValue.Details = append(modelStatsValue.Details, detail)
-}
-
-// Snapshot returns a copy of the aggregated metrics for external consumption.
-func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
- result := StatisticsSnapshot{}
- if s == nil {
- return result
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- result.TotalRequests = s.totalRequests
- result.SuccessCount = s.successCount
- result.FailureCount = s.failureCount
- result.TotalTokens = s.totalTokens
-
- result.APIs = make(map[string]APISnapshot, len(s.apis))
- for apiName, stats := range s.apis {
- apiSnapshot := APISnapshot{
- TotalRequests: stats.TotalRequests,
- TotalTokens: stats.TotalTokens,
- Models: make(map[string]ModelSnapshot, len(stats.Models)),
- }
- for modelName, modelStatsValue := range stats.Models {
- requestDetails := make([]RequestDetail, len(modelStatsValue.Details))
- copy(requestDetails, modelStatsValue.Details)
- apiSnapshot.Models[modelName] = ModelSnapshot{
- TotalRequests: modelStatsValue.TotalRequests,
- TotalTokens: modelStatsValue.TotalTokens,
- Details: requestDetails,
- }
- }
- result.APIs[apiName] = apiSnapshot
- }
-
- result.RequestsByDay = make(map[string]int64, len(s.requestsByDay))
- for k, v := range s.requestsByDay {
- result.RequestsByDay[k] = v
- }
-
- result.RequestsByHour = make(map[string]int64, len(s.requestsByHour))
- for hour, v := range s.requestsByHour {
- key := formatHour(hour)
- result.RequestsByHour[key] = v
- }
-
- result.TokensByDay = make(map[string]int64, len(s.tokensByDay))
- for k, v := range s.tokensByDay {
- result.TokensByDay[k] = v
- }
-
- result.TokensByHour = make(map[string]int64, len(s.tokensByHour))
- for hour, v := range s.tokensByHour {
- key := formatHour(hour)
- result.TokensByHour[key] = v
- }
-
- return result
-}
-
-type MergeResult struct {
- Added int64 `json:"added"`
- Skipped int64 `json:"skipped"`
-}
-
-// MergeSnapshot merges an exported statistics snapshot into the current store.
-// Existing data is preserved and duplicate request details are skipped.
-func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
- result := MergeResult{}
- if s == nil {
- return result
- }
-
- s.mu.Lock()
- defer s.mu.Unlock()
-
- seen := make(map[string]struct{})
- for apiName, stats := range s.apis {
- if stats == nil {
- continue
- }
- for modelName, modelStatsValue := range stats.Models {
- if modelStatsValue == nil {
- continue
- }
- for _, detail := range modelStatsValue.Details {
- seen[dedupKey(apiName, modelName, detail)] = struct{}{}
- }
- }
- }
-
- for apiName, apiSnapshot := range snapshot.APIs {
- apiName = strings.TrimSpace(apiName)
- if apiName == "" {
- continue
- }
- stats, ok := s.apis[apiName]
- if !ok || stats == nil {
- stats = &apiStats{Models: make(map[string]*modelStats)}
- s.apis[apiName] = stats
- } else if stats.Models == nil {
- stats.Models = make(map[string]*modelStats)
- }
- for modelName, modelSnapshot := range apiSnapshot.Models {
- modelName = strings.TrimSpace(modelName)
- if modelName == "" {
- modelName = "unknown"
- }
- for _, detail := range modelSnapshot.Details {
- detail.Tokens = normaliseTokenStats(detail.Tokens)
- if detail.LatencyMs < 0 {
- detail.LatencyMs = 0
- }
- if detail.Timestamp.IsZero() {
- detail.Timestamp = time.Now()
- }
- key := dedupKey(apiName, modelName, detail)
- if _, exists := seen[key]; exists {
- result.Skipped++
- continue
- }
- seen[key] = struct{}{}
- s.recordImported(apiName, modelName, stats, detail)
- result.Added++
- }
- }
- }
-
- return result
-}
-
-func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
- totalTokens := detail.Tokens.TotalTokens
- if totalTokens < 0 {
- totalTokens = 0
- }
-
- s.totalRequests++
- if detail.Failed {
- s.failureCount++
- } else {
- s.successCount++
- }
- s.totalTokens += totalTokens
-
- s.updateAPIStats(stats, modelName, detail)
-
- dayKey := detail.Timestamp.Format("2006-01-02")
- hourKey := detail.Timestamp.Hour()
-
- s.requestsByDay[dayKey]++
- s.requestsByHour[hourKey]++
- s.tokensByDay[dayKey] += totalTokens
- s.tokensByHour[hourKey] += totalTokens
-}
-
-func dedupKey(apiName, modelName string, detail RequestDetail) string {
- timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
- tokens := normaliseTokenStats(detail.Tokens)
- return fmt.Sprintf(
- "%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
- apiName,
- modelName,
- timestamp,
- detail.Source,
- detail.AuthIndex,
- detail.Failed,
- tokens.InputTokens,
- tokens.OutputTokens,
- tokens.ReasoningTokens,
- tokens.CachedTokens,
- tokens.TotalTokens,
- )
-}
-
-func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
- if ctx != nil {
- if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
- path := ginCtx.FullPath()
- if path == "" && ginCtx.Request != nil {
- path = ginCtx.Request.URL.Path
- }
- method := ""
- if ginCtx.Request != nil {
- method = ginCtx.Request.Method
- }
- if path != "" {
- if method != "" {
- return method + " " + path
- }
- return path
- }
- }
- }
- if record.Provider != "" {
- return record.Provider
- }
- return "unknown"
-}
-
-func resolveSuccess(ctx context.Context) bool {
- if ctx == nil {
- return true
- }
- ginCtx, ok := ctx.Value("gin").(*gin.Context)
- if !ok || ginCtx == nil {
- return true
- }
- status := ginCtx.Writer.Status()
- if status == 0 {
- return true
- }
- return status < httpStatusBadRequest
-}
-
-const httpStatusBadRequest = 400
-
-func normaliseDetail(detail coreusage.Detail) TokenStats {
- tokens := TokenStats{
- InputTokens: detail.InputTokens,
- OutputTokens: detail.OutputTokens,
- ReasoningTokens: detail.ReasoningTokens,
- CachedTokens: detail.CachedTokens,
- TotalTokens: detail.TotalTokens,
- }
- if tokens.TotalTokens == 0 {
- tokens.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
- }
- if tokens.TotalTokens == 0 {
- tokens.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens + detail.CachedTokens
- }
- return tokens
-}
-
-func normaliseTokenStats(tokens TokenStats) TokenStats {
- if tokens.TotalTokens == 0 {
- tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
- }
- if tokens.TotalTokens == 0 {
- tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
- }
- return tokens
-}
-
-func normaliseLatency(latency time.Duration) int64 {
- if latency <= 0 {
- return 0
- }
- return latency.Milliseconds()
-}
-
-func formatHour(hour int) string {
- if hour < 0 {
- hour = 0
- }
- hour = hour % 24
- return fmt.Sprintf("%02d", hour)
-}
diff --git a/internal/usage/logger_plugin_test.go b/internal/usage/logger_plugin_test.go
deleted file mode 100644
index 842b3f0c..00000000
--- a/internal/usage/logger_plugin_test.go
+++ /dev/null
@@ -1,96 +0,0 @@
-package usage
-
-import (
- "context"
- "testing"
- "time"
-
- coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
-)
-
-func TestRequestStatisticsRecordIncludesLatency(t *testing.T) {
- stats := NewRequestStatistics()
- stats.Record(context.Background(), coreusage.Record{
- APIKey: "test-key",
- Model: "gpt-5.4",
- RequestedAt: time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC),
- Latency: 1500 * time.Millisecond,
- Detail: coreusage.Detail{
- InputTokens: 10,
- OutputTokens: 20,
- TotalTokens: 30,
- },
- })
-
- snapshot := stats.Snapshot()
- details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
- if len(details) != 1 {
- t.Fatalf("details len = %d, want 1", len(details))
- }
- if details[0].LatencyMs != 1500 {
- t.Fatalf("latency_ms = %d, want 1500", details[0].LatencyMs)
- }
-}
-
-func TestRequestStatisticsMergeSnapshotDedupIgnoresLatency(t *testing.T) {
- stats := NewRequestStatistics()
- timestamp := time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC)
- first := StatisticsSnapshot{
- APIs: map[string]APISnapshot{
- "test-key": {
- Models: map[string]ModelSnapshot{
- "gpt-5.4": {
- Details: []RequestDetail{{
- Timestamp: timestamp,
- LatencyMs: 0,
- Source: "user@example.com",
- AuthIndex: "0",
- Tokens: TokenStats{
- InputTokens: 10,
- OutputTokens: 20,
- TotalTokens: 30,
- },
- }},
- },
- },
- },
- },
- }
- second := StatisticsSnapshot{
- APIs: map[string]APISnapshot{
- "test-key": {
- Models: map[string]ModelSnapshot{
- "gpt-5.4": {
- Details: []RequestDetail{{
- Timestamp: timestamp,
- LatencyMs: 2500,
- Source: "user@example.com",
- AuthIndex: "0",
- Tokens: TokenStats{
- InputTokens: 10,
- OutputTokens: 20,
- TotalTokens: 30,
- },
- }},
- },
- },
- },
- },
- }
-
- result := stats.MergeSnapshot(first)
- if result.Added != 1 || result.Skipped != 0 {
- t.Fatalf("first merge = %+v, want added=1 skipped=0", result)
- }
-
- result = stats.MergeSnapshot(second)
- if result.Added != 0 || result.Skipped != 1 {
- t.Fatalf("second merge = %+v, want added=0 skipped=1", result)
- }
-
- snapshot := stats.Snapshot()
- details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
- if len(details) != 1 {
- t.Fatalf("details len = %d, want 1", len(details))
- }
-}
diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go
index 2be9aa90..b414ed5a 100644
--- a/internal/watcher/diff/config_diff.go
+++ b/internal/watcher/diff/config_diff.go
@@ -39,6 +39,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled {
changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled))
}
+ if oldCfg.RedisUsageQueueRetentionSeconds != newCfg.RedisUsageQueueRetentionSeconds {
+ changes = append(changes, fmt.Sprintf("redis-usage-queue-retention-seconds: %d -> %d", oldCfg.RedisUsageQueueRetentionSeconds, newCfg.RedisUsageQueueRetentionSeconds))
+ }
if oldCfg.DisableCooling != newCfg.DisableCooling {
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
}
diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go
index 22f7c41a..e89227aa 100644
--- a/sdk/api/handlers/handlers.go
+++ b/sdk/api/handlers/handlers.go
@@ -375,11 +375,32 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
if requestCtx != nil && logging.GetRequestID(parentCtx) == "" {
if requestID := logging.GetRequestID(requestCtx); requestID != "" {
parentCtx = logging.WithRequestID(parentCtx, requestID)
- } else if requestID := logging.GetGinRequestID(c); requestID != "" {
+ } else if requestID = logging.GetGinRequestID(c); requestID != "" {
parentCtx = logging.WithRequestID(parentCtx, requestID)
}
}
newCtx, cancel := context.WithCancel(parentCtx)
+
+ endpoint := ""
+ if c != nil && c.Request != nil {
+ path := strings.TrimSpace(c.FullPath())
+ if path == "" && c.Request.URL != nil {
+ path = strings.TrimSpace(c.Request.URL.Path)
+ }
+ if path != "" {
+ method := strings.TrimSpace(c.Request.Method)
+ if method != "" {
+ endpoint = method + " " + path
+ } else {
+ endpoint = path
+ }
+ }
+ }
+ if endpoint != "" {
+ newCtx = logging.WithEndpoint(newCtx, endpoint)
+ }
+ newCtx = logging.WithResponseStatusHolder(newCtx)
+
cancelCtx := newCtx
if requestCtx != nil && requestCtx != parentCtx {
go func() {
@@ -393,6 +414,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
newCtx = context.WithValue(newCtx, "gin", c)
newCtx = context.WithValue(newCtx, "handler", handler)
return newCtx, func(params ...interface{}) {
+ if c != nil {
+ logging.SetResponseStatus(cancelCtx, c.Writer.Status())
+ }
if h.Cfg.RequestLog && len(params) == 1 {
if existing, exists := c.Get("API_RESPONSE"); exists {
if existingBytes, ok := existing.([]byte); ok && len(bytes.TrimSpace(existingBytes)) > 0 {
@@ -515,7 +539,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
return nil, nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
- reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
+ reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
payload := rawJSON
if len(payload) == 0 {
payload = nil
@@ -563,7 +587,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
return nil, nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
- reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
+ reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
payload := rawJSON
if len(payload) == 0 {
payload = nil
@@ -615,7 +639,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return nil, nil, errChan
}
reqMeta := requestExecutionMetadata(ctx)
- reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
+ reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
payload := rawJSON
if len(payload) == 0 {
payload = nil
diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go
index 8969ce2f..8dd1a0a7 100644
--- a/sdk/api/handlers/openai/openai_responses_handlers.go
+++ b/sdk/api/handlers/openai/openai_responses_handlers.go
@@ -13,6 +13,7 @@ import (
"fmt"
"io"
"net/http"
+ "sort"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
@@ -45,7 +46,10 @@ func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
}
type responsesSSEFramer struct {
- pending []byte
+ pending []byte
+ outputItems map[int][]byte
+ outputOrder []int
+ unindexedOutputItems [][]byte
}
func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
@@ -61,7 +65,7 @@ func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
if frameLen == 0 {
break
}
- writeResponsesSSEChunk(w, f.pending[:frameLen])
+ f.writeFrame(w, f.pending[:frameLen])
copy(f.pending, f.pending[frameLen:])
f.pending = f.pending[:len(f.pending)-frameLen]
}
@@ -72,7 +76,7 @@ func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
if len(f.pending) == 0 || !responsesSSECanEmitWithoutDelimiter(f.pending) {
return
}
- writeResponsesSSEChunk(w, f.pending)
+ f.writeFrame(w, f.pending)
f.pending = f.pending[:0]
}
@@ -88,10 +92,133 @@ func (f *responsesSSEFramer) Flush(w io.Writer) {
f.pending = f.pending[:0]
return
}
- writeResponsesSSEChunk(w, f.pending)
+ f.writeFrame(w, f.pending)
f.pending = f.pending[:0]
}
+func (f *responsesSSEFramer) writeFrame(w io.Writer, frame []byte) {
+ writeResponsesSSEChunk(w, f.repairFrame(frame))
+}
+
+func (f *responsesSSEFramer) repairFrame(frame []byte) []byte {
+ payload, ok := responsesSSEDataPayload(frame)
+ if !ok || len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) || !json.Valid(payload) {
+ return frame
+ }
+
+ switch gjson.GetBytes(payload, "type").String() {
+ case "response.output_item.done":
+ f.recordOutputItem(payload)
+ case "response.completed":
+ repaired := f.repairCompletedPayload(payload)
+ if !bytes.Equal(repaired, payload) {
+ return responsesSSEFrameWithData(frame, repaired)
+ }
+ }
+ return frame
+}
+
+func responsesSSEDataPayload(frame []byte) ([]byte, bool) {
+ var payload []byte
+ found := false
+ for _, line := range bytes.Split(frame, []byte("\n")) {
+ line = bytes.TrimRight(line, "\r")
+ trimmed := bytes.TrimSpace(line)
+ if !bytes.HasPrefix(trimmed, []byte("data:")) {
+ continue
+ }
+ data := bytes.TrimSpace(trimmed[len("data:"):])
+ if found {
+ payload = append(payload, '\n')
+ }
+ payload = append(payload, data...)
+ found = true
+ }
+ return payload, found
+}
+
+func responsesSSEFrameWithData(frame, payload []byte) []byte {
+ var out bytes.Buffer
+ for _, line := range bytes.Split(frame, []byte("\n")) {
+ line = bytes.TrimRight(line, "\r")
+ trimmed := bytes.TrimSpace(line)
+ if len(trimmed) == 0 || bytes.HasPrefix(trimmed, []byte("data:")) {
+ continue
+ }
+ out.Write(line)
+ out.WriteByte('\n')
+ }
+ for _, line := range bytes.Split(payload, []byte("\n")) {
+ out.WriteString("data: ")
+ out.Write(line)
+ out.WriteByte('\n')
+ }
+ out.WriteByte('\n')
+ return out.Bytes()
+}
+
+func (f *responsesSSEFramer) recordOutputItem(payload []byte) {
+ item := gjson.GetBytes(payload, "item")
+ if !item.Exists() || !item.IsObject() || item.Get("type").String() == "" {
+ return
+ }
+
+ if outputIndex := gjson.GetBytes(payload, "output_index"); outputIndex.Exists() {
+ index := int(outputIndex.Int())
+ if f.outputItems == nil {
+ f.outputItems = make(map[int][]byte)
+ }
+ if _, exists := f.outputItems[index]; !exists {
+ f.outputOrder = append(f.outputOrder, index)
+ }
+ f.outputItems[index] = append([]byte(nil), item.Raw...)
+ return
+ }
+
+ f.unindexedOutputItems = append(f.unindexedOutputItems, append([]byte(nil), item.Raw...))
+}
+
+func (f *responsesSSEFramer) repairCompletedPayload(payload []byte) []byte {
+ if len(f.outputOrder) == 0 && len(f.unindexedOutputItems) == 0 {
+ return payload
+ }
+ output := gjson.GetBytes(payload, "response.output")
+ if output.Exists() && (!output.IsArray() || len(output.Array()) > 0) {
+ return payload
+ }
+
+ var outputJSON bytes.Buffer
+ outputJSON.WriteByte('[')
+ indexes := append([]int(nil), f.outputOrder...)
+ sort.Ints(indexes)
+ written := 0
+ for _, index := range indexes {
+ item, ok := f.outputItems[index]
+ if !ok {
+ continue
+ }
+ if written > 0 {
+ outputJSON.WriteByte(',')
+ }
+ outputJSON.Write(item)
+ written++
+ }
+ for _, item := range f.unindexedOutputItems {
+ if written > 0 {
+ outputJSON.WriteByte(',')
+ }
+ outputJSON.Write(item)
+ written++
+ }
+ outputJSON.WriteByte(']')
+
+ repaired, err := sjson.SetRawBytes(payload, "response.output", outputJSON.Bytes())
+ if err != nil {
+ return payload
+ }
+ return repaired
+}
+
func responsesSSEFrameLen(chunk []byte) int {
if len(chunk) == 0 {
return 0
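The new `responsesSSEDataPayload` and `responsesSSEFrameWithData` helpers in this file join the `data:` lines of a frame into one payload and then re-emit a (possibly modified) payload as `data:` lines. A minimal sketch of that round trip using only the standard library is shown here; the function names `dataPayload` and `frameWithData` are illustrative, and this is not the code from the diff.

```go
package main

import (
	"bytes"
	"fmt"
)

// dataPayload joins the values of all "data:" lines in one SSE frame with "\n".
// It reports false when the frame has no data line at all.
func dataPayload(frame []byte) ([]byte, bool) {
	var payload []byte
	found := false
	for _, line := range bytes.Split(frame, []byte("\n")) {
		trimmed := bytes.TrimSpace(line)
		if !bytes.HasPrefix(trimmed, []byte("data:")) {
			continue
		}
		if found {
			payload = append(payload, '\n')
		}
		payload = append(payload, bytes.TrimSpace(trimmed[len("data:"):])...)
		found = true
	}
	return payload, found
}

// frameWithData keeps non-data fields (for example "event:") and writes the
// payload back as one "data: " line per payload line, ending with a blank line.
func frameWithData(frame, payload []byte) []byte {
	var out bytes.Buffer
	for _, line := range bytes.Split(frame, []byte("\n")) {
		line = bytes.TrimRight(line, "\r")
		trimmed := bytes.TrimSpace(line)
		if len(trimmed) == 0 || bytes.HasPrefix(trimmed, []byte("data:")) {
			continue
		}
		out.Write(line)
		out.WriteByte('\n')
	}
	for _, line := range bytes.Split(payload, []byte("\n")) {
		out.WriteString("data: ")
		out.Write(line)
		out.WriteByte('\n')
	}
	out.WriteByte('\n')
	return out.Bytes()
}

func main() {
	frame := []byte("event: response.completed\ndata: {\"a\":1,\ndata: \"b\":2}\n\n")
	payload, ok := dataPayload(frame)
	fmt.Printf("%q %v\n", payload, ok)
	fmt.Printf("%q\n", frameWithData(frame, []byte(`{"a":1,"b":2}`)))
}
```

Splitting the payload on newlines when writing it back keeps every emitted line a valid SSE `data:` field, which is the property the multiline test in this change checks.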
diff --git a/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go b/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go
index ef16fe80..151da9a7 100644
--- a/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go
+++ b/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go
@@ -10,6 +10,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
+ "github.com/tidwall/gjson"
)
func newResponsesStreamTestHandler(t *testing.T) (*OpenAIResponsesAPIHandler, *httptest.ResponseRecorder, *gin.Context, http.Flusher) {
@@ -53,12 +54,108 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1)
}
- expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
+ expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"function_call\",\"arguments\":\"{}\"}]}}"
if parts[1] != expectedPart2 {
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
}
}
+func TestForwardResponsesStreamRepairsEmptyCompletedOutputFromDoneItems(t *testing.T) {
+ h, recorder, c, flusher := newResponsesStreamTestHandler(t)
+
+ data := make(chan []byte, 3)
+ errs := make(chan *interfaces.ErrorMessage)
+ data <- []byte(`data: {"type":"response.output_item.done","output_index":0,"item":{"type":"reasoning","id":"rs-1","summary":[]}}`)
+ data <- []byte(`data: {"type":"response.output_item.done","output_index":1,"item":{"type":"function_call","id":"fc-1","call_id":"call-1","name":"shell","arguments":"{\"cmd\":\"pwd\"}","status":"completed"}}`)
+ data <- []byte(`data: {"type":"response.completed","response":{"id":"resp-1","output":[]}}`)
+ close(data)
+ close(errs)
+
+ h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
+
+ parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n")
+ if len(parts) != 3 {
+ t.Fatalf("expected 3 SSE events, got %d. Body: %q", len(parts), recorder.Body.String())
+ }
+
+ payload := strings.TrimPrefix(parts[2], "data: ")
+ output := gjson.Get(payload, "response.output")
+ if !output.IsArray() || len(output.Array()) != 2 {
+ t.Fatalf("expected repaired completed output with 2 items, got %s", output.Raw)
+ }
+ if got := gjson.Get(payload, "response.output.1.name").String(); got != "shell" {
+ t.Fatalf("expected function_call name to be preserved, got %q in %s", got, payload)
+ }
+ if got := gjson.Get(payload, "response.output.1.arguments").String(); got != `{"cmd":"pwd"}` {
+ t.Fatalf("expected function_call arguments to be preserved, got %q in %s", got, payload)
+ }
+}
+
+func TestForwardResponsesStreamRepairsMixedIndexedAndUnindexedDoneItems(t *testing.T) {
+ h, recorder, c, flusher := newResponsesStreamTestHandler(t)
+
+ data := make(chan []byte, 3)
+ errs := make(chan *interfaces.ErrorMessage)
+ data <- []byte(`data: {"type":"response.output_item.done","output_index":1,"item":{"type":"function_call","id":"fc-1","call_id":"call-1","name":"shell","arguments":"{}","status":"completed"}}`)
+ data <- []byte(`data: {"type":"response.output_item.done","item":{"type":"message","id":"msg-1","role":"assistant","content":[{"type":"output_text","text":"done"}]}}`)
+ data <- []byte(`data: {"type":"response.completed","response":{"id":"resp-1","output":[]}}`)
+ close(data)
+ close(errs)
+
+ h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
+
+ parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n")
+ if len(parts) != 3 {
+ t.Fatalf("expected 3 SSE events, got %d. Body: %q", len(parts), recorder.Body.String())
+ }
+
+ payload := strings.TrimPrefix(parts[2], "data: ")
+ output := gjson.Get(payload, "response.output")
+ if !output.IsArray() || len(output.Array()) != 2 {
+ t.Fatalf("expected repaired completed output with 2 items, got %s", output.Raw)
+ }
+ if got := gjson.Get(payload, "response.output.0.name").String(); got != "shell" {
+ t.Fatalf("expected indexed function_call to be preserved first, got %q in %s", got, payload)
+ }
+ if got := gjson.Get(payload, "response.output.1.id").String(); got != "msg-1" {
+ t.Fatalf("expected unindexed message to be appended, got %q in %s", got, payload)
+ }
+}
+
+func TestForwardResponsesStreamRepairsMultilineCompletedOutputAsSSEDataLines(t *testing.T) {
+ h, recorder, c, flusher := newResponsesStreamTestHandler(t)
+
+ data := make(chan []byte, 2)
+ errs := make(chan *interfaces.ErrorMessage)
+ data <- []byte(`data: {"type":"response.output_item.done","item":{"type":"function_call","arguments":"{}"}}`)
+ data <- []byte("data: {\"type\":\"response.completed\",\ndata: \"response\":{\"id\":\"resp-1\",\"output\":[]}}\n\n")
+ close(data)
+ close(errs)
+
+ h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
+
+ parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n")
+ if len(parts) != 2 {
+ t.Fatalf("expected 2 SSE events, got %d. Body: %q", len(parts), recorder.Body.String())
+ }
+
+ completedFrame := []byte(parts[1])
+ for _, line := range strings.Split(parts[1], "\n") {
+ if line != "" && !strings.HasPrefix(line, "data: ") {
+ t.Fatalf("expected every completed payload line to be an SSE data line, got %q in %q", line, parts[1])
+ }
+ }
+
+ payload, ok := responsesSSEDataPayload(completedFrame)
+ if !ok {
+ t.Fatalf("expected completed frame to contain data payload: %q", parts[1])
+ }
+ output := gjson.GetBytes(payload, "response.output")
+ if !output.IsArray() || len(output.Array()) != 1 {
+ t.Fatalf("expected repaired completed output with 1 item, got %s from %q", output.Raw, payload)
+ }
+}
+
func TestForwardResponsesStreamReassemblesSplitSSEEventChunks(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go
index 2f6b14a7..7a9d2224 100644
--- a/sdk/api/handlers/openai/openai_responses_websocket.go
+++ b/sdk/api/handlers/openai/openai_responses_websocket.go
@@ -79,6 +79,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
var lastRequest []byte
lastResponseOutput := []byte("[]")
pinnedAuthID := ""
+ forceTranscriptReplayNextRequest := false
for {
msgType, payload, errReadMessage := conn.ReadMessage()
@@ -115,6 +116,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}
+ if forceTranscriptReplayNextRequest {
+ allowIncrementalInputWithPreviousResponseID = false
+ }
+
+ allowCompactionReplayBypass := false
+ if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
+ if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
+ allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
+ }
+ } else {
+ requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
+ if requestModelName == "" {
+ requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
+ }
+ allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName)
+ }
var requestJSON []byte
var updatedLastRequest []byte
@@ -124,6 +141,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
lastRequest,
lastResponseOutput,
allowIncrementalInputWithPreviousResponseID,
+ allowCompactionReplayBypass,
)
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
@@ -165,7 +183,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
updatedLastRequest = bytes.Clone(requestJSON)
+ previousLastRequest := bytes.Clone(lastRequest)
+ previousLastResponseOutput := bytes.Clone(lastResponseOutput)
+ forcedTranscriptReplay := forceTranscriptReplayNextRequest
lastRequest = updatedLastRequest
+ if forcedTranscriptReplay {
+ forceTranscriptReplayNextRequest = false
+ }
modelName := gjson.GetBytes(requestJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
@@ -190,12 +214,19 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
- completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
+ completedOutput, forwardErrMsg, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
if errForward != nil {
wsTerminateErr = errForward
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return
}
+ if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) {
+ pinnedAuthID = ""
+ forceTranscriptReplayNextRequest = true
+ lastRequest = previousLastRequest
+ lastResponseOutput = previousLastResponseOutput
+ continue
+ }
lastResponseOutput = completedOutput
}
}
@@ -222,10 +253,10 @@ func websocketUpgradeHeaders(req *http.Request) http.Header {
}
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
- return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
+ return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true, true)
}
-func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
+func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
switch requestType {
case wsRequestTypeCreate:
@@ -233,10 +264,10 @@ func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []by
if len(lastRequest) == 0 {
return normalizeResponseCreateRequest(rawJSON)
}
- return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
+ return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
case wsRequestTypeAppend:
// log.Infof("responses websocket: response.append request")
- return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
+ return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
default:
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
@@ -265,7 +296,7 @@ func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces
return normalized, bytes.Clone(normalized), nil
}
-func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
+func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
if len(lastRequest) == 0 {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
@@ -315,20 +346,37 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
}
}
- existingInput := gjson.GetBytes(lastRequest, "input")
- mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
- if errMerge != nil {
- return nil, lastRequest, &interfaces.ErrorMessage{
- StatusCode: http.StatusBadRequest,
- Error: fmt.Errorf("invalid previous response output: %w", errMerge),
+ // When the client sends a compact replay and the upstream can consume it
+ // directly, the input already carries the canonical history. In that case,
+ // skip merging with stale lastRequest/lastResponseOutput to avoid breaking
+ // function_call / function_call_output pairings.
+ // See: https://github.com/router-for-me/CLIProxyAPI/issues/2207
+ var mergedInput string
+ if allowCompactionReplayBypass && inputContainsFullTranscript(nextInput) {
+ log.Infof("responses websocket: full transcript detected, skipping stale merge (input items=%d)", len(nextInput.Array()))
+ mergedInput = nextInput.Raw
+ } else {
+ appendInputRaw := nextInput.Raw
+ if inputContainsFullTranscript(nextInput) {
+ appendInputRaw = inputWithoutCompactionItems(nextInput)
}
- }
- mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
- if errMerge != nil {
- return nil, lastRequest, &interfaces.ErrorMessage{
- StatusCode: http.StatusBadRequest,
- Error: fmt.Errorf("invalid request input: %w", errMerge),
+ existingInput := gjson.GetBytes(lastRequest, "input")
+ var errMerge error
+ mergedInput, errMerge = mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
+ if errMerge != nil {
+ return nil, lastRequest, &interfaces.ErrorMessage{
+ StatusCode: http.StatusBadRequest,
+ Error: fmt.Errorf("invalid previous response output: %w", errMerge),
+ }
+ }
+
+ mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, appendInputRaw)
+ if errMerge != nil {
+ return nil, lastRequest, &interfaces.ErrorMessage{
+ StatusCode: http.StatusBadRequest,
+ Error: fmt.Errorf("invalid request input: %w", errMerge),
+ }
}
}
dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput)
@@ -480,72 +528,104 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
}
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
- if h == nil || h.AuthManager == nil {
+ auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
+ for _, auth := range auths {
+ if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
+ return true
+ }
+ }
+ return false
+}
+
+func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsCompactionReplayForModel(modelName string) bool {
+ auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
+ if len(auths) == 0 {
return false
}
+ for _, auth := range auths {
+ if !responsesWebsocketAuthSupportsCompactionReplay(auth) {
+ return false
+ }
+ }
+ return true
+}
- resolvedModelName := modelName
+func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(modelName string) ([]*coreauth.Auth, string) {
+ if h == nil || h.AuthManager == nil {
+ return nil, ""
+ }
+ resolvedModelName := responsesWebsocketResolvedModelName(modelName)
+ providerSet, modelKey := responsesWebsocketProviderSetForModel(resolvedModelName)
+ if len(providerSet) == 0 {
+ return nil, modelKey
+ }
+
+ registryRef := registry.GetGlobalRegistry()
+ now := time.Now()
+ auths := h.AuthManager.List()
+ available := make([]*coreauth.Auth, 0, len(auths))
+ for _, auth := range auths {
+ if !responsesWebsocketAuthMatchesModel(auth, providerSet, modelKey, registryRef, now) {
+ continue
+ }
+ available = append(available, auth)
+ }
+ return available, modelKey
+}
+
+func responsesWebsocketResolvedModelName(modelName string) string {
initialSuffix := thinking.ParseSuffix(modelName)
if initialSuffix.ModelName == "auto" {
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
if initialSuffix.HasSuffix {
- resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
- } else {
- resolvedModelName = resolvedBase
+ return fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
}
- } else {
- resolvedModelName = util.ResolveAutoModel(modelName)
+ return resolvedBase
}
+ return util.ResolveAutoModel(modelName)
+}
+func responsesWebsocketProviderSetForModel(resolvedModelName string) (map[string]struct{}, string) {
parsed := thinking.ParseSuffix(resolvedModelName)
baseModel := strings.TrimSpace(parsed.ModelName)
providers := util.GetProviderName(baseModel)
if len(providers) == 0 && baseModel != resolvedModelName {
providers = util.GetProviderName(resolvedModelName)
}
- if len(providers) == 0 {
- return false
- }
-
providerSet := make(map[string]struct{}, len(providers))
- for i := 0; i < len(providers); i++ {
- providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
+ for _, provider := range providers {
+ providerKey := strings.TrimSpace(strings.ToLower(provider))
if providerKey == "" {
continue
}
providerSet[providerKey] = struct{}{}
}
- if len(providerSet) == 0 {
- return false
- }
-
modelKey := baseModel
if modelKey == "" {
modelKey = strings.TrimSpace(resolvedModelName)
}
- registryRef := registry.GetGlobalRegistry()
- now := time.Now()
- auths := h.AuthManager.List()
- for i := 0; i < len(auths); i++ {
- auth := auths[i]
- if auth == nil {
- continue
- }
- providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
- if _, ok := providerSet[providerKey]; !ok {
- continue
- }
- if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
- continue
- }
- if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
- continue
- }
- if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
- return true
- }
+ return providerSet, modelKey
+}
+
+func responsesWebsocketAuthMatchesModel(auth *coreauth.Auth, providerSet map[string]struct{}, modelKey string, registryRef *registry.ModelRegistry, now time.Time) bool {
+ if auth == nil {
+ return false
}
- return false
+ providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
+ if _, ok := providerSet[providerKey]; !ok {
+ return false
+ }
+ if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
+ return false
+ }
+ return responsesWebsocketAuthAvailableForModel(auth, modelKey, now)
+}
+
+func responsesWebsocketAuthSupportsCompactionReplay(auth *coreauth.Auth) bool {
+ if auth == nil {
+ return false
+ }
+ return strings.EqualFold(strings.TrimSpace(auth.Provider), "codex")
}
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
@@ -691,6 +771,42 @@ func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
return string(out), nil
}
+// inputContainsFullTranscript returns true when the input array carries compact
+// replay markers that indicate the client already sent the full conversation
+// transcript. Merging that input with stale lastRequest/lastResponseOutput
+// would duplicate or break function_call/function_call_output pairings, so the
+// caller should use the input as-is.
+//
+// Assistant messages alone are not enough to classify the payload as a replay:
+// incremental websocket requests may legitimately append assistant items.
+func inputContainsFullTranscript(input gjson.Result) bool {
+ if !input.IsArray() {
+ return false
+ }
+ for _, item := range input.Array() {
+ t := item.Get("type").String()
+ if t == "compaction" || t == "compaction_summary" {
+ return true
+ }
+ }
+ return false
+}
+
+func inputWithoutCompactionItems(input gjson.Result) string {
+ if !input.IsArray() {
+ return normalizeJSONArrayRaw([]byte(input.Raw))
+ }
+ filtered := make([]string, 0, len(input.Array()))
+ for _, item := range input.Array() {
+ t := item.Get("type").String()
+ if t == "compaction" || t == "compaction_summary" {
+ continue
+ }
+ filtered = append(filtered, item.Raw)
+ }
+ return "[" + strings.Join(filtered, ",") + "]"
+}
+
func normalizeJSONArrayRaw(raw []byte) string {
trimmed := strings.TrimSpace(string(raw))
if trimmed == "" {
@@ -711,7 +827,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errs <-chan *interfaces.ErrorMessage,
wsTimelineLog *strings.Builder,
sessionID string,
-) ([]byte, error) {
+) ([]byte, *interfaces.ErrorMessage, error) {
completed := false
completedOutput := []byte("[]")
downstreamSessionKey := ""
@@ -723,7 +839,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
- return completedOutput, c.Request.Context().Err()
+ return completedOutput, nil, c.Request.Context().Err()
case errMsg, ok := <-errs:
if !ok {
errs = nil
@@ -748,7 +864,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
// errWrite,
// )
cancel(errMsg.Error)
- return completedOutput, errWrite
+ return completedOutput, errMsg, errWrite
}
}
if errMsg != nil {
@@ -756,7 +872,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
} else {
cancel(nil)
}
- return completedOutput, nil
+ return completedOutput, errMsg, nil
case chunk, ok := <-data:
if !ok {
if !completed {
@@ -782,13 +898,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errMsg.Error)
- return completedOutput, errWrite
+ return completedOutput, errMsg, errWrite
}
cancel(errMsg.Error)
- return completedOutput, nil
+ return completedOutput, errMsg, nil
}
cancel(nil)
- return completedOutput, nil
+ return completedOutput, nil, nil
}
payloads := websocketJSONPayloadsFromChunk(chunk)
@@ -815,13 +931,31 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errWrite)
- return completedOutput, errWrite
+ return completedOutput, nil, errWrite
}
}
}
}
}
+func shouldReleaseResponsesWebsocketPinnedAuth(errMsg *interfaces.ErrorMessage) bool {
+ if errMsg == nil {
+ return false
+ }
+ status := errMsg.StatusCode
+ if status <= 0 && errMsg.Error != nil {
+ if se, ok := errMsg.Error.(interface{ StatusCode() int }); ok && se != nil {
+ status = se.StatusCode()
+ }
+ }
+ switch status {
+ case http.StatusUnauthorized, http.StatusPaymentRequired, http.StatusForbidden, http.StatusTooManyRequests:
+ return true
+ default:
+ return false
+ }
+}
+
func responseCompletedOutputFromPayload(payload []byte) []byte {
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && output.IsArray() {
diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go
index ecfc90b3..1d397ecd 100644
--- a/sdk/api/handlers/openai/openai_responses_websocket_test.go
+++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go
@@ -69,6 +69,22 @@ type websocketAuthCaptureExecutor struct {
authIDs []string
}
+type websocketPinnedFailoverExecutor struct {
+ mu sync.Mutex
+ authIDs []string
+ calls map[string]int
+ payloads map[string][][]byte
+}
+
+type websocketPinnedFailoverStatusError struct {
+ status int
+ msg string
+}
+
+func (e websocketPinnedFailoverStatusError) Error() string { return e.msg }
+
+func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status }
+
func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" }
func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -106,6 +122,76 @@ func (e *websocketAuthCaptureExecutor) AuthIDs() []string {
return append([]string(nil), e.authIDs...)
}
+func (e *websocketPinnedFailoverExecutor) Identifier() string { return "test-provider" }
+
+func (e *websocketPinnedFailoverExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
+ return coreexecutor.Response{}, errors.New("not implemented")
+}
+
+func (e *websocketPinnedFailoverExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
+ authID := ""
+ if auth != nil {
+ authID = auth.ID
+ }
+
+ e.mu.Lock()
+ if e.calls == nil {
+ e.calls = make(map[string]int)
+ }
+ if e.payloads == nil {
+ e.payloads = make(map[string][][]byte)
+ }
+ e.authIDs = append(e.authIDs, authID)
+ e.calls[authID]++
+ call := e.calls[authID]
+ e.payloads[authID] = append(e.payloads[authID], bytes.Clone(req.Payload))
+ e.mu.Unlock()
+
+ if authID == "auth-a" && call == 2 {
+ chunks := make(chan coreexecutor.StreamChunk, 1)
+ chunks <- coreexecutor.StreamChunk{Err: websocketPinnedFailoverStatusError{
+ status: http.StatusTooManyRequests,
+ msg: `{"error":{"message":"quota exhausted","type":"rate_limit_error","code":"rate_limit_exceeded"}}`,
+ }}
+ close(chunks)
+ return &coreexecutor.StreamResult{Chunks: chunks}, nil
+ }
+
+ chunks := make(chan coreexecutor.StreamChunk, 1)
+ chunks <- coreexecutor.StreamChunk{Payload: []byte(fmt.Sprintf(`{"type":"response.completed","response":{"id":"resp-%s-%d","output":[{"type":"message","id":"out-%s-%d"}]}}`, authID, call, authID, call))}
+ close(chunks)
+ return &coreexecutor.StreamResult{Chunks: chunks}, nil
+}
+
+func (e *websocketPinnedFailoverExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
+ return auth, nil
+}
+
+func (e *websocketPinnedFailoverExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
+ return coreexecutor.Response{}, errors.New("not implemented")
+}
+
+func (e *websocketPinnedFailoverExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (e *websocketPinnedFailoverExecutor) AuthIDs() []string {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ return append([]string(nil), e.authIDs...)
+}
+
+func (e *websocketPinnedFailoverExecutor) Payloads(authID string) [][]byte {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ src := e.payloads[authID]
+ out := make([][]byte, len(src))
+ for i := range src {
+ out[i] = bytes.Clone(src[i])
+ }
+ return out
+}
+
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -242,7 +328,7 @@ func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *
]`)
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
- normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true)
+ normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true, false)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
@@ -278,7 +364,7 @@ func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncre
]`)
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
- normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false)
+ normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, false)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
@@ -681,7 +767,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
close(errCh)
var timelineLog strings.Builder
- completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
+ completedOutput, errMsg, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
@@ -694,6 +780,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
serverErrCh <- err
return
}
+ if errMsg != nil {
+ serverErrCh <- fmt.Errorf("unexpected websocket error message: %v", errMsg.Error)
+ return
+ }
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
serverErrCh <- errors.New("completed output not captured")
return
@@ -760,7 +850,7 @@ func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing
return
}
- _, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
+ _, _, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
@@ -867,6 +957,53 @@ func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
}
}
+func TestWebsocketUpstreamSupportsCompactionReplayForModel(t *testing.T) {
+ manager := coreauth.NewManager(nil, nil, nil)
+ auth := &coreauth.Auth{
+ ID: "auth-codex",
+ Provider: "codex",
+ Status: coreauth.StatusActive,
+ }
+ if _, err := manager.Register(context.Background(), auth); err != nil {
+ t.Fatalf("Register auth: %v", err)
+ }
+ registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
+ t.Cleanup(func() {
+ registry.GetGlobalRegistry().UnregisterClient(auth.ID)
+ })
+
+ base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
+ h := NewOpenAIResponsesAPIHandler(base)
+ if !h.websocketUpstreamSupportsCompactionReplayForModel("test-model") {
+ t.Fatalf("expected codex upstream to support compaction replay")
+ }
+}
+
+func TestWebsocketUpstreamSupportsCompactionReplayForModelFalseWhenMixedBackends(t *testing.T) {
+ manager := coreauth.NewManager(nil, nil, nil)
+ auths := []*coreauth.Auth{
+ {ID: "auth-codex", Provider: "codex", Status: coreauth.StatusActive},
+ {ID: "auth-claude", Provider: "claude", Status: coreauth.StatusActive},
+ }
+ for _, auth := range auths {
+ if _, err := manager.Register(context.Background(), auth); err != nil {
+ t.Fatalf("Register auth %s: %v", auth.ID, err)
+ }
+ registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
+ }
+ t.Cleanup(func() {
+ for _, auth := range auths {
+ registry.GetGlobalRegistry().UnregisterClient(auth.ID)
+ }
+ })
+
+ base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
+ h := NewOpenAIResponsesAPIHandler(base)
+ if h.websocketUpstreamSupportsCompactionReplayForModel("test-model") {
+ t.Fatalf("expected mixed backend model to disable compaction replay bypass")
+ }
+}
+
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -1066,6 +1203,99 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
}
}
+func TestResponsesWebsocketReleasesPinnedAuthAfterQuotaError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ selector := &orderedWebsocketSelector{order: []string{"auth-a", "auth-b"}}
+ executor := &websocketPinnedFailoverExecutor{}
+ manager := coreauth.NewManager(nil, selector, nil)
+ manager.RegisterExecutor(executor)
+
+ authA := &coreauth.Auth{
+ ID: "auth-a",
+ Provider: executor.Identifier(),
+ Status: coreauth.StatusActive,
+ Attributes: map[string]string{"websockets": "true"},
+ }
+ if _, err := manager.Register(context.Background(), authA); err != nil {
+ t.Fatalf("Register auth A: %v", err)
+ }
+ authB := &coreauth.Auth{
+ ID: "auth-b",
+ Provider: executor.Identifier(),
+ Status: coreauth.StatusActive,
+ Attributes: map[string]string{"websockets": "true"},
+ }
+ if _, err := manager.Register(context.Background(), authB); err != nil {
+ t.Fatalf("Register auth B: %v", err)
+ }
+
+ registry.GetGlobalRegistry().RegisterClient(authA.ID, authA.Provider, []*registry.ModelInfo{{ID: "quota-model"}})
+ registry.GetGlobalRegistry().RegisterClient(authB.ID, authB.Provider, []*registry.ModelInfo{{ID: "quota-model"}})
+ t.Cleanup(func() {
+ registry.GetGlobalRegistry().UnregisterClient(authA.ID)
+ registry.GetGlobalRegistry().UnregisterClient(authB.ID)
+ })
+
+ base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
+ h := NewOpenAIResponsesAPIHandler(base)
+ router := gin.New()
+ router.GET("/v1/responses/ws", h.ResponsesWebsocket)
+
+ server := httptest.NewServer(router)
+ defer server.Close()
+
+ wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
+ conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
+ if err != nil {
+ t.Fatalf("dial websocket: %v", err)
+ }
+ defer func() {
+ if errClose := conn.Close(); errClose != nil {
+ t.Fatalf("close websocket: %v", errClose)
+ }
+ }()
+
+ requests := []string{
+ `{"type":"response.create","model":"quota-model","input":[{"type":"message","id":"msg-1"}]}`,
+ `{"type":"response.create","previous_response_id":"resp-auth-a-1","input":[{"type":"message","id":"msg-2"}]}`,
+ `{"type":"response.create","previous_response_id":"resp-auth-a-1","input":[{"type":"message","id":"msg-3"}]}`,
+ }
+ wantTypes := []string{wsEventTypeCompleted, wsEventTypeError, wsEventTypeCompleted}
+ for i := range requests {
+ if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
+ t.Fatalf("write websocket message %d: %v", i+1, errWrite)
+ }
+ _, payload, errReadMessage := conn.ReadMessage()
+ if errReadMessage != nil {
+ t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
+ }
+ if got := gjson.GetBytes(payload, "type").String(); got != wantTypes[i] {
+ t.Fatalf("message %d payload type = %s, want %s: %s", i+1, got, wantTypes[i], payload)
+ }
+ if i == 1 && int(gjson.GetBytes(payload, "status").Int()) != http.StatusTooManyRequests {
+ t.Fatalf("quota payload status = %d, want %d: %s", gjson.GetBytes(payload, "status").Int(), http.StatusTooManyRequests, payload)
+ }
+ }
+
+ if got := executor.AuthIDs(); len(got) != 3 || got[0] != "auth-a" || got[1] != "auth-a" || got[2] != "auth-b" {
+ t.Fatalf("selected auth IDs = %v, want [auth-a auth-a auth-b]", got)
+ }
+
+ authBPayloads := executor.Payloads("auth-b")
+ if len(authBPayloads) != 1 {
+ t.Fatalf("auth-b payload count = %d, want 1", len(authBPayloads))
+ }
+ authBPayload := authBPayloads[0]
+ if gjson.GetBytes(authBPayload, "previous_response_id").Exists() {
+ t.Fatalf("previous_response_id leaked after auth failover: %s", authBPayload)
+ }
+ authBInput := gjson.GetBytes(authBPayload, "input").Raw
+ if !strings.Contains(authBInput, `"id":"msg-1"`) || !strings.Contains(authBInput, `"id":"msg-3"`) {
+ t.Fatalf("auth-b replay input missing expected transcript items: %s", authBInput)
+ }
+}
+
func TestNormalizeResponsesWebsocketRequestTreatsTranscriptReplacementAsReset(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
lastResponseOutput := []byte(`[
@@ -1400,3 +1630,171 @@ func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *t
t.Fatalf("post-compact function call id = %s, want call-1", items[0].Get("call_id").String())
}
}
+
+func TestInputContainsFullTranscriptFalseForAssistantMessageOnly(t *testing.T) {
+ input := gjson.Parse(`[
+ {"type":"message","role":"user","content":"hello"},
+ {"type":"message","role":"assistant","content":"hi there"}
+ ]`)
+ if inputContainsFullTranscript(input) {
+ t.Fatal("assistant message alone must not be treated as full transcript")
+ }
+}
+
+func TestInputContainsFullTranscriptDetectsCompactionItem(t *testing.T) {
+ for _, typ := range []string{"compaction", "compaction_summary"} {
+ input := gjson.Parse(`[{"type":"message","role":"user","content":"hello"},{"type":"` + typ + `","encrypted_content":"summary"}]`)
+ if !inputContainsFullTranscript(input) {
+ t.Fatalf("expected full transcript for type=%s", typ)
+ }
+ }
+}
+
+func TestInputContainsFullTranscriptFalseForIncremental(t *testing.T) {
+ // Normal incremental turns: user messages or function_call_output only.
+ for _, raw := range []string{
+ `[{"type":"function_call_output","call_id":"call-1","output":"result"}]`,
+ `[{"type":"message","role":"user","content":"next question"}]`,
+ `[]`,
+ } {
+ if inputContainsFullTranscript(gjson.Parse(raw)) {
+ t.Fatalf("incremental input must not be detected as full transcript: %s", raw)
+ }
+ }
+}
+
+func TestNormalizeSubsequentRequestCompactSkipsMerge(t *testing.T) {
+ lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
+ {"type":"message","role":"user","id":"msg-1","content":"original long prompt"},
+ {"type":"message","role":"assistant","id":"msg-2","content":"original long response"},
+ {"type":"function_call","id":"fc-1","call_id":"call-old","name":"bash","arguments":"{}"},
+ {"type":"function_call_output","id":"fco-1","call_id":"call-old","output":"old result"}
+ ]}`)
+ lastResponseOutput := []byte(`[
+ {"type":"message","role":"assistant","id":"msg-3","content":"another assistant reply"},
+ {"type":"function_call","id":"fc-2","call_id":"call-stale","name":"read","arguments":"{}"}
+ ]`)
+
+ // Remote compact response: user messages + compaction item, NO assistant message.
+ // This is the primary compact scenario from Codex CLI.
+ raw := []byte(`{"type":"response.create","input":[
+ {"type":"message","role":"user","id":"msg-1c","content":"compacted user msg"},
+ {"type":"compaction","encrypted_content":"conversation summary"}
+ ]}`)
+
+ normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
+ if errMsg != nil {
+ t.Fatalf("unexpected error: %v", errMsg.Error)
+ }
+
+ input := gjson.GetBytes(normalized, "input").Array()
+ if len(input) != 2 {
+ t.Fatalf("input len = %d, want 2 (compacted only); stale state was not skipped", len(input))
+ }
+ if input[0].Get("id").String() != "msg-1c" {
+ t.Fatalf("input[0].id = %q, want %q", input[0].Get("id").String(), "msg-1c")
+ }
+ if input[1].Get("type").String() != "compaction" {
+ t.Fatalf("input[1].type = %q, want %q", input[1].Get("type").String(), "compaction")
+ }
+}
+
+func TestNormalizeSubsequentRequestCompactMergesWhenCompactionReplayUnsupported(t *testing.T) {
+ lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
+ {"type":"message","role":"user","id":"msg-1","content":"original long prompt"},
+ {"type":"message","role":"assistant","id":"msg-2","content":"original long response"},
+ {"type":"function_call","id":"fc-1","call_id":"call-old","name":"bash","arguments":"{}"},
+ {"type":"function_call_output","id":"fco-1","call_id":"call-old","output":"old result"}
+ ]}`)
+ lastResponseOutput := []byte(`[
+ {"type":"message","role":"assistant","id":"msg-3","content":"another assistant reply"},
+ {"type":"function_call","id":"fc-2","call_id":"call-stale","name":"read","arguments":"{}"}
+ ]`)
+ raw := []byte(`{"type":"response.create","input":[
+ {"type":"message","role":"user","id":"msg-1c","content":"compacted user msg"},
+ {"type":"compaction","encrypted_content":"conversation summary"}
+ ]}`)
+
+ normalized, _, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, false)
+ if errMsg != nil {
+ t.Fatalf("unexpected error: %v", errMsg.Error)
+ }
+
+ input := gjson.GetBytes(normalized, "input").Array()
+ if len(input) != 7 {
+ t.Fatalf("input len = %d, want 7 (merged fallback without compaction items)", len(input))
+ }
+ wantIDs := []string{"msg-1", "msg-2", "fc-1", "fco-1", "msg-3", "fc-2", "msg-1c"}
+ for i, want := range wantIDs {
+ got := input[i].Get("id").String()
+ if got != want {
+ t.Fatalf("input[%d].id = %q, want %q", i, got, want)
+ }
+ }
+ for _, item := range input {
+ if item.Get("type").String() == "compaction" || item.Get("type").String() == "compaction_summary" {
+ t.Fatalf("compaction items must be stripped in the merge fallback for upstreams without compaction replay support: %s", item.Raw)
+ }
+ }
+}
+
+func TestNormalizeSubsequentRequestIncrementalInputStillMerges(t *testing.T) {
+ // Normal incremental flow: user sends function_call_output (no assistant message).
+ lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
+ {"type":"message","role":"user","id":"msg-1","content":"hello"}
+ ]}`)
+ lastResponseOutput := []byte(`[
+ {"type":"message","role":"assistant","id":"msg-2","content":"let me check"},
+ {"type":"function_call","id":"fc-1","call_id":"call-1","name":"bash","arguments":"{}"}
+ ]`)
+ raw := []byte(`{"type":"response.create","input":[
+ {"type":"function_call_output","call_id":"call-1","id":"fco-1","output":"done"}
+ ]}`)
+
+ normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
+ if errMsg != nil {
+ t.Fatalf("unexpected error: %v", errMsg.Error)
+ }
+
+ input := gjson.GetBytes(normalized, "input").Array()
+
+ // Should be merged: msg-1 + msg-2 + fc-1 + fco-1 = 4 items
+ if len(input) != 4 {
+ t.Fatalf("input len = %d, want 4 (merged)", len(input))
+ }
+ wantIDs := []string{"msg-1", "msg-2", "fc-1", "fco-1"}
+ for i, want := range wantIDs {
+ got := input[i].Get("id").String()
+ if got != want {
+ t.Fatalf("input[%d].id = %q, want %q", i, got, want)
+ }
+ }
+}
+
+func TestNormalizeSubsequentRequestAssistantInputTriggersTranscriptReplacement(t *testing.T) {
+ // With shouldReplaceWebsocketTranscript (introduced on dev), an assistant message in the
+ // input triggers transcript replacement: the input is used as-is, not merged with prior state.
+ lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
+ {"type":"message","role":"user","id":"msg-1","content":"hello"}
+ ]}`)
+ lastResponseOutput := []byte(`[
+ {"type":"message","role":"assistant","id":"msg-2","content":"prior assistant"},
+ {"type":"function_call","id":"fc-1","call_id":"call-1","name":"bash","arguments":"{}"}
+ ]`)
+ raw := []byte(`{"type":"response.append","input":[
+ {"type":"message","role":"assistant","id":"msg-3","content":"patched assistant turn"}
+ ]}`)
+
+ normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
+ if errMsg != nil {
+ t.Fatalf("unexpected error: %v", errMsg.Error)
+ }
+
+ input := gjson.GetBytes(normalized, "input").Array()
+ if len(input) != 1 {
+ t.Fatalf("input len = %d, want 1 (transcript replacement, not merge)", len(input))
+ }
+ if input[0].Get("id").String() != "msg-3" {
+ t.Fatalf("input[0].id = %q, want %q", input[0].Get("id").String(), "msg-3")
+ }
+}
diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go
index 6571518d..ab3eca49 100644
--- a/sdk/cliproxy/auth/conductor.go
+++ b/sdk/cliproxy/auth/conductor.go
@@ -22,6 +22,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
+ coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
log "github.com/sirupsen/logrus"
)
@@ -827,6 +828,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
if executor == nil {
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
+ ctx = contextWithRequestedModelAlias(ctx, opts, routeModel)
var lastErr error
for idx, execModel := range execModels {
resultModel := m.stateModelForExecution(auth, routeModel, execModel, pooled)
@@ -1126,6 +1128,9 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
auth.Index = existing.Index
auth.indexAssigned = existing.indexAssigned
}
+ auth.Success = existing.Success
+ auth.Failed = existing.Failed
+ auth.recentRequests = existing.recentRequests
if !existing.Disabled && existing.Status != StatusDisabled && !auth.Disabled && auth.Status != StatusDisabled {
if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
auth.ModelStates = existing.ModelStates
@@ -1316,6 +1321,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
+ execCtx = contextWithRequestedModelAlias(execCtx, opts, routeModel)
models, pooled := m.preparedExecutionModels(auth, routeModel)
if len(models) == 0 {
@@ -1394,6 +1400,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
+ execCtx = contextWithRequestedModelAlias(execCtx, opts, routeModel)
models, pooled := m.preparedExecutionModels(auth, routeModel)
if len(models) == 0 {
@@ -1531,6 +1538,36 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
}
}
+func contextWithRequestedModelAlias(ctx context.Context, opts cliproxyexecutor.Options, fallback string) context.Context {
+ alias := requestedModelAliasFromOptions(opts, fallback)
+ return coreusage.WithRequestedModelAlias(ctx, alias)
+}
+
+func requestedModelAliasFromOptions(opts cliproxyexecutor.Options, fallback string) string {
+ fallback = strings.TrimSpace(fallback)
+ if len(opts.Metadata) == 0 {
+ return fallback
+ }
+ raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
+ if !ok || raw == nil {
+ return fallback
+ }
+ switch value := raw.(type) {
+ case string:
+ if strings.TrimSpace(value) == "" {
+ return fallback
+ }
+ return strings.TrimSpace(value)
+ case []byte:
+ if trimmed := strings.TrimSpace(string(value)); trimmed != "" {
+ return trimmed
+ }
+ return fallback
+ default:
+ return fallback
+ }
+}
+
func pinnedAuthIDFromMetadata(meta map[string]any) string {
if len(meta) == 0 {
return ""
@@ -2021,6 +2058,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
m.mu.Lock()
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
now := time.Now()
+ auth.recordRecentRequest(now, result.Success)
+ if result.Success {
+ auth.Success++
+ } else {
+ auth.Failed++
+ }
if result.Success {
if result.Model != "" {
@@ -3087,6 +3130,7 @@ func (m *Manager) tryAntigravityCreditsExecute(ctx context.Context, req cliproxy
creditsCtx = context.WithValue(creditsCtx, "cliproxy.roundtripper", rt)
}
creditsOpts := ensureRequestedModelMetadata(opts, routeModel)
+ creditsCtx = contextWithRequestedModelAlias(creditsCtx, creditsOpts, routeModel)
publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID)
models := m.executionModelCandidates(c.auth, routeModel)
if len(models) == 0 {
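Reviewer note: the alias resolution added in conductor.go prefers the client-requested model stored in the options metadata and falls back to the route model. The standalone sketch below illustrates that fallback rule under stated assumptions. The key `requestedModelKey` and the helper `aliasOrFallback` are hypothetical names used only for illustration, not identifiers from the repository.

```go
package main

import (
	"fmt"
	"strings"
)

// requestedModelKey is a hypothetical metadata key used only in this sketch.
const requestedModelKey = "requested_model"

// aliasOrFallback returns the trimmed metadata value when it is a non-blank
// string or []byte, and the trimmed fallback in every other case.
func aliasOrFallback(meta map[string]any, fallback string) string {
	fallback = strings.TrimSpace(fallback)
	var candidate string
	switch v := meta[requestedModelKey].(type) { // indexing a nil map is safe
	case string:
		candidate = v
	case []byte:
		candidate = string(v)
	}
	if candidate = strings.TrimSpace(candidate); candidate != "" {
		return candidate
	}
	return fallback
}

func main() {
	meta := map[string]any{requestedModelKey: "  my-alias  "}
	fmt.Printf("%q\n", aliasOrFallback(meta, "route-model"))
	fmt.Printf("%q\n", aliasOrFallback(nil, " route-model "))
	fmt.Printf("%q\n", aliasOrFallback(map[string]any{requestedModelKey: 42}, "route-model"))
}
```

Trimming once and checking the result in a single place treats blank strings and blank byte slices the same way, so a whitespace-only value never replaces the fallback.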
diff --git a/sdk/cliproxy/auth/conductor_oauth_alias_suspension_test.go b/sdk/cliproxy/auth/conductor_oauth_alias_suspension_test.go
index 8bc779e5..b4b72204 100644
--- a/sdk/cliproxy/auth/conductor_oauth_alias_suspension_test.go
+++ b/sdk/cliproxy/auth/conductor_oauth_alias_suspension_test.go
@@ -10,20 +10,23 @@ import (
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
+ coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
type aliasRoutingExecutor struct {
id string
- mu sync.Mutex
- executeModels []string
+ mu sync.Mutex
+ executeModels []string
+ executeAliases []string
}
func (e *aliasRoutingExecutor) Identifier() string { return e.id }
-func (e *aliasRoutingExecutor) Execute(_ context.Context, _ *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
+func (e *aliasRoutingExecutor) Execute(ctx context.Context, _ *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
e.mu.Lock()
e.executeModels = append(e.executeModels, req.Model)
+ e.executeAliases = append(e.executeAliases, coreusage.RequestedModelAliasFromContext(ctx))
e.mu.Unlock()
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
}
@@ -52,6 +55,14 @@ func (e *aliasRoutingExecutor) ExecuteModels() []string {
return out
}
+func (e *aliasRoutingExecutor) ExecuteAliases() []string {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ out := make([]string, len(e.executeAliases))
+ copy(out, e.executeAliases)
+ return out
+}
+
func TestManagerExecute_OAuthAliasBypassesBlockedRouteModel(t *testing.T) {
const (
provider = "antigravity"
@@ -108,4 +119,12 @@ func TestManagerExecute_OAuthAliasBypassesBlockedRouteModel(t *testing.T) {
if gotModels[0] != targetModel {
t.Fatalf("execute model = %q, want %q", gotModels[0], targetModel)
}
+
+ gotAliases := executor.ExecuteAliases()
+ if len(gotAliases) != 1 {
+ t.Fatalf("execute aliases len = %d, want 1", len(gotAliases))
+ }
+ if gotAliases[0] != routeModel {
+ t.Fatalf("execute alias = %q, want %q", gotAliases[0], routeModel)
+ }
}
diff --git a/sdk/cliproxy/auth/conductor_recent_requests_test.go b/sdk/cliproxy/auth/conductor_recent_requests_test.go
new file mode 100644
index 00000000..d2003b7c
--- /dev/null
+++ b/sdk/cliproxy/auth/conductor_recent_requests_test.go
@@ -0,0 +1,95 @@
+package auth
+
+import (
+ "context"
+ "testing"
+ "time"
+)
+
+func TestManagerMarkResultRecordsRecentRequests(t *testing.T) {
+ mgr := NewManager(nil, nil, nil)
+ auth := &Auth{
+ ID: "auth-1",
+ Provider: "antigravity",
+ Attributes: map[string]string{
+ "runtime_only": "true",
+ },
+ Metadata: map[string]any{
+ "type": "antigravity",
+ },
+ }
+
+ if _, err := mgr.Register(WithSkipPersist(context.Background()), auth); err != nil {
+ t.Fatalf("Register returned error: %v", err)
+ }
+
+ mgr.MarkResult(context.Background(), Result{AuthID: "auth-1", Provider: "antigravity", Model: "gpt-5", Success: true})
+ mgr.MarkResult(context.Background(), Result{AuthID: "auth-1", Provider: "antigravity", Model: "gpt-5", Success: false})
+
+ gotAuth, ok := mgr.GetByID("auth-1")
+ if !ok || gotAuth == nil {
+ t.Fatalf("GetByID returned ok=%v auth=%v", ok, gotAuth)
+ }
+
+ if gotAuth.Success != 1 || gotAuth.Failed != 1 {
+ t.Fatalf("auth totals = success=%d failed=%d, want 1/1", gotAuth.Success, gotAuth.Failed)
+ }
+
+ snapshot := gotAuth.RecentRequestsSnapshot(time.Now())
+ var successTotal int64
+ var failedTotal int64
+ for _, bucket := range snapshot {
+ successTotal += bucket.Success
+ failedTotal += bucket.Failed
+ }
+ if successTotal != 1 || failedTotal != 1 {
+ t.Fatalf("bucket totals = success=%d failed=%d, want 1/1", successTotal, failedTotal)
+ }
+}
+
+func TestManagerUpdatePreservesRecentRequestsAndTotals(t *testing.T) {
+ mgr := NewManager(nil, nil, nil)
+ auth := &Auth{
+ ID: "auth-1",
+ Provider: "antigravity",
+ Metadata: map[string]any{
+ "type": "antigravity",
+ },
+ }
+ if _, err := mgr.Register(WithSkipPersist(context.Background()), auth); err != nil {
+ t.Fatalf("Register returned error: %v", err)
+ }
+
+ mgr.MarkResult(context.Background(), Result{AuthID: "auth-1", Provider: "antigravity", Model: "gpt-5", Success: true})
+
+ updated := &Auth{
+ ID: "auth-1",
+ Provider: "antigravity",
+ Metadata: map[string]any{
+ "type": "antigravity",
+ "note": "updated",
+ },
+ }
+ if _, err := mgr.Update(WithSkipPersist(context.Background()), updated); err != nil {
+ t.Fatalf("Update returned error: %v", err)
+ }
+
+ gotAuth, ok := mgr.GetByID("auth-1")
+ if !ok || gotAuth == nil {
+ t.Fatalf("GetByID returned ok=%v auth=%v", ok, gotAuth)
+ }
+ if gotAuth.Success != 1 || gotAuth.Failed != 0 {
+ t.Fatalf("auth totals = success=%d failed=%d, want 1/0", gotAuth.Success, gotAuth.Failed)
+ }
+
+ snapshot := gotAuth.RecentRequestsSnapshot(time.Now())
+ var successTotal int64
+ var failedTotal int64
+ for _, bucket := range snapshot {
+ successTotal += bucket.Success
+ failedTotal += bucket.Failed
+ }
+ if successTotal != 1 || failedTotal != 0 {
+ t.Fatalf("bucket totals = success=%d failed=%d, want 1/0", successTotal, failedTotal)
+ }
+}
diff --git a/sdk/cliproxy/auth/types.go b/sdk/cliproxy/auth/types.go
index f30f4dc0..76f4c396 100644
--- a/sdk/cliproxy/auth/types.go
+++ b/sdk/cliproxy/auth/types.go
@@ -92,7 +92,32 @@ type Auth struct {
// Runtime carries non-serialisable data used during execution (in-memory only).
Runtime any `json:"-"`
- indexAssigned bool `json:"-"`
+ Success int64 `json:"-"`
+ Failed int64 `json:"-"`
+
+ recentRequests recentRequestRing `json:"-"`
+ indexAssigned bool `json:"-"`
+}
+
+const (
+ recentRequestBucketSeconds int64 = 10 * 60
+ recentRequestBucketCount = 20
+)
+
+type recentRequestBucket struct {
+ bucketID int64
+ success int64
+ failed int64
+}
+
+type recentRequestRing struct {
+ buckets [recentRequestBucketCount]recentRequestBucket
+}
+
+type RecentRequestBucket struct {
+ Time string `json:"time"`
+ Success int64 `json:"success"`
+ Failed int64 `json:"failed"`
}
// QuotaState contains limiter tracking data for a credential.
@@ -125,6 +150,70 @@ type ModelState struct {
UpdatedAt time.Time `json:"updated_at"`
}
+func recentRequestBucketID(now time.Time) int64 {
+ if now.IsZero() {
+ return 0
+ }
+ return now.Unix() / recentRequestBucketSeconds
+}
+
+func recentRequestBucketIndex(bucketID int64) int {
+ mod := bucketID % int64(recentRequestBucketCount)
+ if mod < 0 {
+ mod += int64(recentRequestBucketCount)
+ }
+ return int(mod)
+}
+
+func formatRecentRequestBucketLabel(bucketID int64) string {
+ start := time.Unix(bucketID*recentRequestBucketSeconds, 0).In(time.Local)
+ end := start.Add(time.Duration(recentRequestBucketSeconds) * time.Second)
+ return start.Format("15:04") + "-" + end.Format("15:04")
+}
+
+func (a *Auth) recordRecentRequest(now time.Time, success bool) {
+ if a == nil {
+ return
+ }
+ bucketID := recentRequestBucketID(now)
+ idx := recentRequestBucketIndex(bucketID)
+ bucket := &a.recentRequests.buckets[idx]
+ if bucket.bucketID != bucketID {
+ bucket.bucketID = bucketID
+ bucket.success = 0
+ bucket.failed = 0
+ }
+ if success {
+ bucket.success++
+ return
+ }
+ bucket.failed++
+}
+
+func (a *Auth) RecentRequestsSnapshot(now time.Time) []RecentRequestBucket {
+ out := make([]RecentRequestBucket, 0, recentRequestBucketCount)
+ if a == nil {
+ return out
+ }
+
+ currentBucketID := recentRequestBucketID(now)
+ for i := recentRequestBucketCount - 1; i >= 0; i-- {
+ bucketID := currentBucketID - int64(i)
+ idx := recentRequestBucketIndex(bucketID)
+ bucket := a.recentRequests.buckets[idx]
+ entry := RecentRequestBucket{
+ Time: formatRecentRequestBucketLabel(bucketID),
+ }
+ if bucket.bucketID == bucketID {
+ entry.Success = bucket.success
+ entry.Failed = bucket.failed
+ }
+ out = append(out, entry)
+ }
+
+ return out
+}
+
// Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation.
func (a *Auth) Clone() *Auth {
if a == nil {
diff --git a/sdk/cliproxy/auth/types_test.go b/sdk/cliproxy/auth/types_test.go
index e7029385..06836da1 100644
--- a/sdk/cliproxy/auth/types_test.go
+++ b/sdk/cliproxy/auth/types_test.go
@@ -1,6 +1,10 @@
package auth
-import "testing"
+import (
+ "strings"
+ "testing"
+ "time"
+)
func TestToolPrefixDisabled(t *testing.T) {
var a *Auth
@@ -96,3 +100,72 @@ func TestEnsureIndexUsesCredentialIdentity(t *testing.T) {
t.Fatalf("duplicate config entries should be separated by source-derived seed, got %q", geminiIndex)
}
}
+
+func TestRecentRequestsSnapshotEmptyReturnsTwentyBuckets(t *testing.T) {
+ now := time.Unix(1_700_000_000, 0).In(time.Local)
+ a := &Auth{}
+
+ got := a.RecentRequestsSnapshot(now)
+ if len(got) != recentRequestBucketCount {
+ t.Fatalf("len = %d, want %d", len(got), recentRequestBucketCount)
+ }
+
+ currentBucketID := now.Unix() / recentRequestBucketSeconds
+ baseBucketID := currentBucketID - int64(recentRequestBucketCount-1)
+ for i, bucket := range got {
+ if bucket.Success != 0 || bucket.Failed != 0 {
+ t.Fatalf("bucket[%d] counts = %d/%d, want 0/0", i, bucket.Success, bucket.Failed)
+ }
+ if strings.TrimSpace(bucket.Time) == "" {
+ t.Fatalf("bucket[%d] time label is empty", i)
+ }
+ expectedBucketID := baseBucketID + int64(i)
+ start := time.Unix(expectedBucketID*recentRequestBucketSeconds, 0).In(time.Local)
+ end := start.Add(10 * time.Minute)
+ expected := start.Format("15:04") + "-" + end.Format("15:04")
+ if bucket.Time != expected {
+ t.Fatalf("bucket[%d] time = %q, want %q", i, bucket.Time, expected)
+ }
+ }
+}
+
+func TestRecentRequestsSnapshotIncludesCounts(t *testing.T) {
+ now := time.Unix(1_700_000_000, 0).In(time.Local)
+ a := &Auth{}
+
+ a.recordRecentRequest(now, true)
+ a.recordRecentRequest(now, false)
+
+ got := a.RecentRequestsSnapshot(now)
+ if len(got) != recentRequestBucketCount {
+ t.Fatalf("len = %d, want %d", len(got), recentRequestBucketCount)
+ }
+
+ newest := got[len(got)-1]
+ if newest.Success != 1 || newest.Failed != 1 {
+ t.Fatalf("newest bucket = success=%d failed=%d, want 1/1", newest.Success, newest.Failed)
+ }
+}
+
+func TestRecentRequestsSnapshotBucketAdvanceMovesCounts(t *testing.T) {
+ now := time.Unix(1_700_000_000, 0).In(time.Local)
+ next := now.Add(10 * time.Minute)
+ a := &Auth{}
+
+ a.recordRecentRequest(now, true)
+ a.recordRecentRequest(next, false)
+
+ got := a.RecentRequestsSnapshot(next)
+ if len(got) != recentRequestBucketCount {
+ t.Fatalf("len = %d, want %d", len(got), recentRequestBucketCount)
+ }
+
+ secondNewest := got[len(got)-2]
+ newest := got[len(got)-1]
+ if secondNewest.Success != 1 || secondNewest.Failed != 0 {
+ t.Fatalf("second newest bucket = success=%d failed=%d, want 1/0", secondNewest.Success, secondNewest.Failed)
+ }
+ if newest.Success != 0 || newest.Failed != 1 {
+ t.Fatalf("newest bucket = success=%d failed=%d, want 0/1", newest.Success, newest.Failed)
+ }
+}
diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go
index d9613150..9f195f56 100644
--- a/sdk/cliproxy/service.go
+++ b/sdk/cliproxy/service.go
@@ -16,7 +16,6 @@ import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
- _ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
diff --git a/sdk/cliproxy/usage/manager.go b/sdk/cliproxy/usage/manager.go
index c3d95f66..72405d75 100644
--- a/sdk/cliproxy/usage/manager.go
+++ b/sdk/cliproxy/usage/manager.go
@@ -2,6 +2,7 @@ package usage
import (
"context"
+ "strings"
"sync"
"time"
@@ -12,6 +13,7 @@ import (
type Record struct {
Provider string
Model string
+ Alias string
APIKey string
AuthID string
AuthIndex string
@@ -32,6 +34,36 @@ type Detail struct {
TotalTokens int64
}
+type requestedModelAliasContextKey struct{}
+
+// WithRequestedModelAlias stores the client-requested model name for usage sinks.
+func WithRequestedModelAlias(ctx context.Context, alias string) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ alias = strings.TrimSpace(alias)
+ if alias == "" {
+ return ctx
+ }
+ return context.WithValue(ctx, requestedModelAliasContextKey{}, alias)
+}
+
+// RequestedModelAliasFromContext returns the client-requested model name stored in ctx.
+func RequestedModelAliasFromContext(ctx context.Context) string {
+ if ctx == nil {
+ return ""
+ }
+ raw := ctx.Value(requestedModelAliasContextKey{})
+ switch value := raw.(type) {
+ case string:
+ return strings.TrimSpace(value)
+ case []byte:
+ return strings.TrimSpace(string(value))
+ default:
+ return ""
+ }
+}
+
// Plugin consumes usage records emitted by the proxy runtime.
type Plugin interface {
HandleUsage(ctx context.Context, record Record)
diff --git a/test/usage_logging_test.go b/test/usage_logging_test.go
index 41c2ee34..ee03c4d7 100644
--- a/test/usage_logging_test.go
+++ b/test/usage_logging_test.go
@@ -2,6 +2,7 @@ package test
import (
"context"
+ "encoding/json"
"fmt"
"net/http"
"net/http/httptest"
@@ -9,14 +10,14 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+ "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
runtimeexecutor "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
- internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
-func TestGeminiExecutorRecordsSuccessfulZeroUsageInStatistics(t *testing.T) {
+func TestGeminiExecutorRecordsSuccessfulZeroUsageInQueue(t *testing.T) {
model := fmt.Sprintf("gemini-2.5-flash-zero-usage-%d", time.Now().UnixNano())
source := fmt.Sprintf("zero-usage-%d@example.com", time.Now().UnixNano())
@@ -42,10 +43,15 @@ func TestGeminiExecutorRecordsSuccessfulZeroUsageInStatistics(t *testing.T) {
},
}
- prevStatsEnabled := internalusage.StatisticsEnabled()
- internalusage.SetStatisticsEnabled(true)
+ prevQueueEnabled := redisqueue.Enabled()
+ prevUsageEnabled := redisqueue.UsageStatisticsEnabled()
+ redisqueue.SetEnabled(false)
+ redisqueue.SetEnabled(true)
+ redisqueue.SetUsageStatisticsEnabled(true)
t.Cleanup(func() {
- internalusage.SetStatisticsEnabled(prevStatsEnabled)
+ redisqueue.SetEnabled(false)
+ redisqueue.SetEnabled(prevQueueEnabled)
+ redisqueue.SetUsageStatisticsEnabled(prevUsageEnabled)
})
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
@@ -59,39 +65,58 @@ func TestGeminiExecutorRecordsSuccessfulZeroUsageInStatistics(t *testing.T) {
t.Fatalf("Execute error: %v", err)
}
- detail := waitForStatisticsDetail(t, "gemini", model, source)
- if detail.Failed {
- t.Fatalf("detail failed = true, want false")
- }
- if detail.Tokens.TotalTokens != 0 {
- t.Fatalf("total tokens = %d, want 0", detail.Tokens.TotalTokens)
- }
+ waitForQueuedUsageModelTotalTokens(t, "gemini", model, 0)
}
-func waitForStatisticsDetail(t *testing.T, apiName, model, source string) internalusage.RequestDetail {
+func waitForQueuedUsageModelTotalTokens(t *testing.T, wantProvider, wantModel string, wantTokens int64) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
- snapshot := internalusage.GetRequestStatistics().Snapshot()
- apiSnapshot, ok := snapshot.APIs[apiName]
- if !ok {
- time.Sleep(10 * time.Millisecond)
- continue
- }
- modelSnapshot, ok := apiSnapshot.Models[model]
- if !ok {
- time.Sleep(10 * time.Millisecond)
- continue
- }
- for _, detail := range modelSnapshot.Details {
- if detail.Source == source {
- return detail
+ items := redisqueue.PopOldest(10)
+ for _, item := range items {
+ got, ok := parseQueuedUsagePayload(t, item)
+ if !ok {
+ continue
}
+ if got.Provider != wantProvider || got.Model != wantModel {
+ continue
+ }
+ if got.Failed {
+ t.Fatalf("payload failed = true, want false")
+ }
+ if got.Tokens.TotalTokens != wantTokens {
+ t.Fatalf("payload total tokens = %d, want %d", got.Tokens.TotalTokens, wantTokens)
+ }
+ return
}
time.Sleep(10 * time.Millisecond)
}
- t.Fatalf("timed out waiting for statistics detail for api=%q model=%q source=%q", apiName, model, source)
- return internalusage.RequestDetail{}
+ t.Fatalf("timed out waiting for queued usage payload for provider=%q model=%q", wantProvider, wantModel)
+}
+
+type queuedUsagePayload struct {
+ Provider string `json:"provider"`
+ Model string `json:"model"`
+ Failed bool `json:"failed"`
+ Tokens struct {
+ TotalTokens int64 `json:"total_tokens"`
+ } `json:"tokens"`
+}
+
+func parseQueuedUsagePayload(t *testing.T, payload []byte) (queuedUsagePayload, bool) {
+ t.Helper()
+
+ var parsed queuedUsagePayload
+ if len(payload) == 0 {
+ return parsed, false
+ }
+ if err := json.Unmarshal(payload, &parsed); err != nil {
+ return parsed, false
+ }
+ if parsed.Provider == "" || parsed.Model == "" {
+ return parsed, false
+ }
+ return parsed, true
}