Merge branch 'router-for-me:main' into my-fix

This commit is contained in:
AhDEV
2026-05-06 16:41:14 +08:00
committed by GitHub
81 changed files with 4470 additions and 1903 deletions
+1
View File
@@ -50,3 +50,4 @@ _bmad-output/*
# macOS
.DS_Store
._*
.gocache/
+2
View File
@@ -19,6 +19,8 @@ builds:
archives:
- id: "cli-proxy-api"
format: tar.gz
name_template: >-
{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{- if eq .Arch "arm64" -}}aarch64{{- else -}}{{ .Arch }}{{- end -}}
format_overrides:
- goos: windows
format: zip
+1 -1
View File
@@ -19,7 +19,7 @@ ARG BUILD_DATE=unknown
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPI ./cmd/server/
FROM alpine:3.22.0
FROM alpine:3.23
RUN apk add --no-cache tzdata
+28 -16
View File
@@ -10,23 +10,19 @@ So you can use local or multi-account CLI access with OpenAI(include Responses)/
## Sponsor
[![z.ai](https://assets.router-for.me/english-5-0.jpg)](https://z.ai/subscribe?ic=8JVLJQFSKB)
[![https://www.packyapi.com/register?aff=cliproxyapi](./assets/packycode-en.png)](https://www.packyapi.com/register?aff=cliproxyapi)
This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN.
Thanks to PackyCode for sponsoring this project!
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 & GLM-5 Only Available for Pro Usersmodel across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more.
Get 10% OFF GLM CODING PLANhttps://z.ai/subscribe?ic=8JVLJQFSKB
PackyCode provides special discounts for our software users: register using <a href="https://www.packyapi.com/register?aff=cliproxyapi">this link</a> and enter the "cliproxyapi" promo code during recharge to get 10% off.
---
<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using <a href="https://www.packyapi.com/register?aff=cliproxyapi">this link</a> and enter the "cliproxyapi" promo code during recharge to get 10% off.</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via <a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
</tr>
@@ -35,10 +31,6 @@ Get 10% OFF GLM CODING PLANhttps://z.ai/subscribe?ic=8JVLJQFSKB
<td>Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through <a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus - Premium AI Accounts & Top-ups</a>, users can unlock the mind-blowing rate of <b>10% of the official GPT subscription price (90% OFF)</b>!</td>
</tr>
<tr>
<td width="180"><a href="https://poixe.com/i/m8kvep"><img src="./assets/poixeai.png" alt="PoixeAI" width="150"></a></td>
<td>Thanks to Poixe AI for sponsoring this project! Poixe AI provides reliable LLM API services. You can leverage the platform's API endpoints to seamlessly build AI-powered products. Additionally, you can become a vendor by providing AI API resources to the platform and earn revenue. Register through the exclusive CLIProxyAPI <a href="https://poixe.com/i/m8kvep">referral link</a> and receive a bonus of $5 USD on your first top-up.</td>
</tr>
<tr>
<td width="180"><a href="https://coder.visioncoder.cn"><img src="./assets/visioncoder.png" alt="VisionCoder" width="150"></a></td>
<td>Thanks to VisionCoder for supporting this project. <a href="https://coder.visioncoder.cn" target="_blank">VisionCoder Developer Platform</a> is a reliable and efficient API relay service provider, offering access to mainstream AI models such as Claude Code, Codex, and Gemini. It helps developers and teams integrate AI capabilities more easily and improve productivity.
<p></p>
@@ -53,7 +45,7 @@ VisionCoder is also offering our users a limited-time <a href="https://coder.vis
- OpenAI Codex support (GPT models) via OAuth login
- Claude Code support via OAuth login
- Amp CLI and IDE extensions support with provider routing
- Streaming and non-streaming responses
- Streaming, non-streaming, and WebSocket responses where supported
- Function calling/tools support
- Multimodal input support (text and images)
- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude)
@@ -74,6 +66,18 @@ CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/)
see [MANAGEMENT_API.md](https://help.router-for.me/management/api)
## Usage Statistics
Since v6.10.0, CLIProxyAPI and [CPAMC](https://github.com/router-for-me/Cli-Proxy-API-Management-Center) no longer ship built-in usage statistics. If you need usage statistics, use:
### [CPA Usage Keeper](https://github.com/Willxup/cpa-usage-keeper)
Standalone persistence and visualization service for CLIProxyAPI, with periodic data sync, SQLite storage, aggregate APIs, and a built-in dashboard for usage and statistics.
### [CLIProxyAPI Usage Dashboard](https://github.com/zhanglunet/cliproxyapi-usage-dashboard)
Local-first usage and quota dashboard for CLIProxyAPI. It collects per-request token usage from the Redis-compatible usage queue into SQLite, visualizes daily and recent-window usage by account and model, and shows Codex 5h/7d quota remaining in a local web UI.
## Amp CLI Support
CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and Amp IDE extensions, enabling you to use your Google/ChatGPT/Claude OAuth subscriptions with Amp's coding tools:
@@ -122,7 +126,7 @@ Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with A
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed
A cross-platform desktop and web app to translate and validate SRT subtitles using your existing LLM subscriptions (Gemini, ChatGPT, Claude, etc.) via CLIProxyAPI - no API keys needed.
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
@@ -183,9 +187,13 @@ Cross-platform desktop app (macOS, Windows, Linux) wrapping CLIProxyAPI with a n
Ready-to-use cross-platform quota inspector for CLIProxyAPI, supporting per-account codex 5h/7d quota windows, plan-based sorting, status coloring, and multi-account summary analytics.
### [CPA Usage Keeper](https://github.com/Willxup/cpa-usage-keeper)
### [CodexCliPlus](https://github.com/C4AL/CodexCliPlus)
Standalone persistence and visualization service for CLIProxyAPI, with periodic data sync, SQLite storage, aggregate APIs, and a built-in dashboard for usage and statistics.
Windows-focused, local-first desktop management platform for Codex CLI built on CLIProxyAPI, focused on simplifying local setup, account and runtime management, and providing a more complete Codex CLI experience for local users.
### [CLIProxy Pool Watch](https://github.com/murasame612/CLIProxyPoolWidget)
Native macOS SwiftUI app for monitoring ChatGPT/Codex account quotas in CLIProxyAPI pools. Displays account availability, Plus-base capacity, 5-hour and weekly quota bars, plan weights, and restore forecasts through the Management API.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
@@ -204,6 +212,10 @@ Never stop coding. Smart routing to FREE & low-cost AI models with automatic fal
OmniRoute is an AI gateway for multi-provider LLMs: an OpenAI-compatible endpoint with smart routing, load balancing, retries, and fallbacks. Add policies, rate limits, caching, and observability for reliable, cost-aware inference.
### [Playful Proxy API Panel (PPAP)](https://github.com/daishuge/playful-proxy-api-panel)
A public CLIProxyAPI-compatible fork and bundled management panel. It keeps upstream-style usage while restoring built-in usage statistics, adding cache hit rate, first-byte latency, TPS tracking, and Docker-oriented self-hosted installation docs.
> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
+28 -16
View File
@@ -10,23 +10,19 @@
## 赞助商
[![bigmodel.cn](https://assets.router-for.me/chinese-5-0.jpg)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
[![https://www.packyapi.com/register?aff=cliproxyapi](./assets/packycode-cn.png)](https://www.packyapi.com/register?aff=cliproxyapi)
本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。
感谢 PackyCode 对本项目的赞助!
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7(受限于算力,目前仅限Pro用户开放),为开发者提供顶尖的编码体验
PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转
智谱AI为本产品提供了特别优惠使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
PackyCode 为本软件用户提供了特别优惠使用<a href="https://www.packyapi.com/register?aff=cliproxyapi" target="_blank">此链接</a>注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。
---
<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>感谢 PackyCode 对本项目的赞助!PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用<a href="https://www.packyapi.com/register?aff=cliproxyapi" target="_blank">此链接</a>注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF" target="_blank">此链接</a>注册的用户,可享受首充8折,企业客户最高可享 7.5 折!</td>
</tr>
@@ -35,10 +31,6 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元
<td>感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过<a href="https://shop.bmoplus.com/?utm_source=github" target="_blank">BmoPlus AI成品号专卖/代充</a>注册下单的用户,可享GPT <b>官网订阅一折</b> 的震撼价格!</td>
</tr>
<tr>
<td width="180"><a href="https://poixe.com/i/m8kvep"><img src="./assets/poixeai.png" alt="PoixeAI" width="150"></a></td>
<td>感谢 Poixe AI 对本项目的赞助!Poixe AI 提供可靠的 AI 模型接口服务,您可以使用平台提供的 LLM API 接口轻松构建 AI 产品,同时也可以成为供应商,为平台提供大模型资源以赚取收益。通过 CLIProxyAPI <a href="https://poixe.com/i/m8kvep" target="_blank">专属链接</a>注册,充值额外赠送 $5 美金</td>
</tr>
<tr>
<td width="180"><a href="https://coder.visioncoder.cn"><img src="./assets/visioncoder.png" alt="VisionCoder" width="150"></a></td>
<td>感谢 VisionCoder 对本项目的支持。<a href="https://coder.visioncoder.cn" target="_blank">VisionCoder 开发平台</a> 是一个可靠高效的 API 中继服务提供商,提供 Claude Code、Codex、Gemini 等主流 AI 模型,帮助开发者和团队更轻松地集成 AI 功能,提升工作效率。
<p></p>
@@ -53,7 +45,7 @@ VisionCoder 还为我们的用户提供 <a href="https://coder.visioncoder.cn" t
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
- 新增 OpenAI CodexGPT 系列)支持(OAuth 登录)
- 新增 Claude Code 支持(OAuth 登录)
- 支持流式非流式响应
- 支持流式非流式响应,以及受支持场景下的 WebSocket 响应
- 函数调用/工具支持
- 多模态输入(文本、图片)
- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude
@@ -74,6 +66,18 @@ CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-fo
请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api)
## 使用量统计
自v6.10.0版本以后,CLIProxyAPI及 [CPAMC](https://github.com/router-for-me/Cli-Proxy-API-Management-Center) 项目不再预置数据统计功能,如果有数据统计需求的请使用以下项目:
### [CPA Usage Keeper](https://github.com/Willxup/cpa-usage-keeper)
独立的 CLIProxyAPI 使用量持久化与可视化服务,定期同步 CLIProxyAPI 数据,存储到 SQLite,提供聚合 API,并内置使用量分析与统计仪表盘。
### [CLIProxyAPI Usage Dashboard](https://github.com/zhanglunet/cliproxyapi-usage-dashboard)
面向 CLIProxyAPI 的本地优先使用量与配额看板。它从 Redis 兼容使用量队列采集每次请求的 Token 消耗并写入 SQLite,按账号和模型可视化每日及最近时间窗口的用量,并在本地网页中显示 Codex 5h/7d 配额余量。
## Amp CLI 支持
CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支持,可让你使用自己的 Google/ChatGPT/Claude OAuth 订阅来配合 Amp 编码工具:
@@ -121,7 +125,7 @@ CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。
一款跨平台的桌面和 Web 应用程序,可通过 CLIProxyAPI 使用您现有的 LLM 订阅(Gemini、ChatGPT、Claude, etc.)来翻译和验证 SRT 字幕 - 无需 API 密钥。
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
@@ -179,9 +183,13 @@ Shadow AI 是一款专为受限环境设计的 AI 辅助工具。提供无窗口
上手即用的面向 CLIProxyAPI 跨平台配额查询工具,支持按账号展示 codex 5h/7d 配额窗口、按计划排序、状态着色及多账号汇总分析。
### [CPA Usage Keeper](https://github.com/Willxup/cpa-usage-keeper)
### [CodexCliPlus](https://github.com/C4AL/CodexCliPlus)
独立的 CLIProxyAPI 使用量持久化与可视化服务,定期同步 CPA 数据,存储到 SQLite,提供聚合 API,并内置使用量分析与统计仪表盘
基于 CLIProxyAPI 的 Windows Codex CLI 本地优先桌面管理平台,聚焦简化本机配置、账号与运行状态管理,并为本地用户提供更完整的 Codex CLI 使用体验
### [CLIProxy Pool Watch](https://github.com/murasame612/CLIProxyPoolWidget)
原生 macOS SwiftUI 应用,用于监控 CLIProxyAPI 池中的 ChatGPT/Codex 账号额度。通过 Management API 展示账号可用状态、Plus 基准容量、5 小时与周额度进度条、套餐权重和恢复预测。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
@@ -200,6 +208,10 @@ Shadow AI 是一款专为受限环境设计的 AI 辅助工具。提供无窗口
OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼容 OpenAI 的端点,具备智能路由、负载均衡、重试及回退机制。通过添加策略、速率限制、缓存和可观测性,确保推理过程既可靠又具备成本意识。
### [Playful Proxy API Panel (PPAP)](https://github.com/daishuge/playful-proxy-api-panel)
一个公开的 CLIProxyAPI 兼容二开版本和配套管理面板,尽量保持与上游一致的使用方式,同时恢复内置使用量统计,并补充缓存命中率、首字响应时间、TPS 记录和面向 Docker 自托管的安装说明。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
+28 -16
View File
@@ -10,23 +10,19 @@ OAuth経由でOpenAI CodexGPTモデル)およびClaude Codeもサポート
## スポンサー
[![z.ai](https://assets.router-for.me/english-5-0.jpg)](https://z.ai/subscribe?ic=8JVLJQFSKB)
[![https://www.packyapi.com/register?aff=cliproxyapi](./assets/packycode-en.png)](https://www.packyapi.com/register?aff=cliproxyapi)
本プロジェクトはZ.aiにスポンサーされており、GLM CODING PLANの提供を受けています
PackyCodeのスポンサーシップに感謝します
GLM CODING PLANはAIコーディング向けに設計されたサブスクリプションサービスで、月額わずか$10から利用可能です。フラッグシップのGLM-4.7および(GLM-5はProユーザーのみ利用可能)モデルを10以上の人気AIコーディングツール(Claude Code、Cline、Roo Codeなど)で利用でき、開発者にトップクラスの高速かつ安定したコーディング体験を提供します。
PackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。
GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.packyapi.com/register?aff=cliproxyapi">こちらのリンク</a>から登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。
---
<table>
<tbody>
<tr>
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>PackyCodeのスポンサーシップに感謝します!PackyCodeは信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどのリレーサービスを提供しています。PackyCodeは当ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.packyapi.com/register?aff=cliproxyapi">こちらのリンク</a>から登録し、チャージ時にプロモーションコード「cliproxyapi」を入力すると10%割引になります。</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
</tr>
@@ -35,10 +31,6 @@ GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割(90% OFF)</b> という驚異的な価格でご利用いただけます!</td>
</tr>
<tr>
<td width="180"><a href="https://poixe.com/i/m8kvep"><img src="./assets/poixeai.png" alt="PoixeAI" width="150"></a></td>
<td>Poixe AIのスポンサーシップに感謝します!Poixe AIは信頼できるAIモデルAPIサービスを提供しており、プラットフォームが提供するLLM APIを使って簡単にAI製品を構築できます。また、サプライヤーとしてプラットフォームに大規模モデルのリソースを提供し、収益を得ることも可能です。CLIProxyAPIの<a href="https://poixe.com/i/m8kvep">専用リンク</a>から登録すると、チャージ時に追加で$5が付与されます。</td>
</tr>
<tr>
<td width="180"><a href="https://coder.visioncoder.cn"><img src="./assets/visioncoder.png" alt="VisionCoder" width="150"></a></td>
<td>VisionCoderのご支援に感謝します!<a href="https://coder.visioncoder.cn">VisionCoder 開発プラットフォーム</a> は、信頼性が高く効率的なAPIリレーサービスプロバイダーで、Claude Code、Codex、Geminiなどの主要AIモデルを提供し、開発者やチームがより簡単にAI機能を統合して生産性を向上できるよう支援します。さらに、VisionCoderはユーザー向けに <a href="https://coder.visioncoder.cn">Token Plan</a> の期間限定キャンペーン(1か月購入で1か月分プレゼント)も提供しています。</td>
</tr>
@@ -51,7 +43,7 @@ GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
- OAuthログインによるOpenAI Codexサポート(GPTモデル)
- OAuthログインによるClaude Codeサポート
- プロバイダールーティングによるAmp CLIおよびIDE拡張機能のサポート
- ストリーミングおよび非ストリーミングレスポンス
- ストリーミング非ストリーミング、および対応環境でのWebSocketレスポンス
- 関数呼び出し/ツールのサポート
- マルチモーダル入力サポート(テキストと画像)
- ラウンドロビン負荷分散による複数アカウント対応(Gemini、OpenAI、Claude
@@ -72,6 +64,18 @@ CLIProxyAPIガイド:[https://help.router-for.me/](https://help.router-for.me/
[MANAGEMENT_API.md](https://help.router-for.me/management/api)を参照
## 使用量統計
v6.10.0以降、CLIProxyAPIおよび [CPAMC](https://github.com/router-for-me/Cli-Proxy-API-Management-Center) プロジェクトには使用量統計機能がプリセットされなくなりました。使用量統計が必要な場合は、次のプロジェクトをご利用ください:
### [CPA Usage Keeper](https://github.com/Willxup/cpa-usage-keeper)
CLIProxyAPI向けの独立した使用量永続化・可視化サービス。CLIProxyAPIデータを定期同期してSQLiteに保存し、集計APIと、使用量や各種統計を確認できる組み込みダッシュボードを提供します。
### [CLIProxyAPI Usage Dashboard](https://github.com/zhanglunet/cliproxyapi-usage-dashboard)
CLIProxyAPI向けのローカル優先の使用量・クォータダッシュボード。Redis互換の使用量キューからリクエストごとのToken使用量を収集してSQLiteに保存し、アカウント別・モデル別の日次および直近時間枠の使用量を可視化し、Codex 5h/7dクォータ残量をローカルWeb UIで表示します。
## Amp CLIサポート
CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統合サポートを含んでおり、Google/ChatGPT/ClaudeのOAuthサブスクリプションをAmpのコーディングツールで使用できます:
@@ -120,7 +124,7 @@ macOSネイティブのメニューバーアプリで、Claude CodeとChatGPTの
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
CLIProxyAPI経由でGeminiサブスクリプションを使用してSRT字幕を翻訳するブラウザベースのツール。自動検証/エラー修正機能付き - APIキー不要
CLIProxyAPI経由で既存のLLMサブスクリプションGemini、ChatGPT、Claude, etc.)を使用してSRT字幕を翻訳および検証する、クロスプラットフォームのデスクトップおよびWebアプリ - APIキー不要
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
@@ -178,9 +182,13 @@ CLIProxyAPIをネイティブGUIでラップしたクロスプラットフォー
CLIProxyAPI向けのすぐに使えるクロスプラットフォームのクォータ確認ツール。アカウントごとの codex 5h/7d クォータ表示、プラン別ソート、ステータス色分け、複数アカウントの集計分析に対応。
### [CPA Usage Keeper](https://github.com/Willxup/cpa-usage-keeper)
### [CodexCliPlus](https://github.com/C4AL/CodexCliPlus)
CLIProxyAPI向けの独立した使用量永続化・可視化サービス。CPAデータを定期同期してSQLiteに保存し、集計APIと、使用量や各種統計を確認できる組み込みダッシュボードを提供します。
CLIProxyAPIを基盤にしたWindows向けのローカル優先Codex CLIデスクトップ管理プラットフォーム。ローカル設定、アカウント、実行状態の管理を簡素化し、ローカルユーザーにより包括的なCodex CLI体験を提供します。
### [CLIProxy Pool Watch](https://github.com/murasame612/CLIProxyPoolWidget)
CLIProxyAPIプール内のChatGPT/Codexアカウントクォータを監視するmacOSネイティブSwiftUIアプリ。Management APIを通じて、アカウントの可用性、Plus基準の容量、5時間/週次クォータバー、プラン重み、復元予測を表示します。
> [!NOTE]
> CLIProxyAPIをベースにプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
@@ -199,6 +207,10 @@ CLIProxyAPIに触発されたNext.js実装。インストールと使用が簡
OmniRouteはマルチプロバイダーLLM向けのAIゲートウェイです:スマートルーティング、負荷分散、リトライ、フォールバックを備えたOpenAI互換エンドポイント。ポリシー、レート制限、キャッシュ、可観測性を追加して、信頼性が高くコストを意識した推論を実現します。
### [Playful Proxy API Panel (PPAP)](https://github.com/daishuge/playful-proxy-api-panel)
上流に近い使い方を維持する公開CLIProxyAPI互換フォーク兼管理パネルです。内蔵の使用量統計を復元し、キャッシュヒット率、初回バイト待ち時間、TPSの記録、Docker向けのセルフホスト手順を追加しています。
> [!NOTE]
> CLIProxyAPIの移植版またはそれに触発されたプロジェクトを開発した場合は、PRを送ってこのリストに追加してください。
Binary file not shown.

After

Width:  |  Height:  |  Size: 170 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 401 KiB

+3 -2
View File
@@ -24,11 +24,11 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -417,7 +417,8 @@ func main() {
configFileExists = true
}
}
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
redisqueue.SetUsageStatisticsEnabled(cfg.UsageStatisticsEnabled)
redisqueue.SetRetentionSeconds(cfg.RedisUsageQueueRetentionSeconds)
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg); err != nil {
+4
View File
@@ -66,6 +66,10 @@ error-logs-max-files: 10
# When false, disable in-memory usage statistics aggregation
usage-statistics-enabled: false
# How long (in seconds) Redis usage queue items are retained in memory for the RESP interface (LPOP/RPOP).
# Default: 60. Max: 3600.
redis-usage-queue-retention-seconds: 60
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
proxy-url: ""
+3 -129
View File
@@ -5,123 +5,13 @@
# This script automates the process of building and running the Docker container
# with version information dynamically injected at build time.
# Hidden feature: Preserve usage statistics across rebuilds
# Usage: ./docker-build.sh --with-usage
# First run prompts for management API key, saved to temp/stats/.api_secret
set -euo pipefail
STATS_DIR="temp/stats"
STATS_FILE="${STATS_DIR}/.usage_backup.json"
SECRET_FILE="${STATS_DIR}/.api_secret"
WITH_USAGE=false
get_port() {
if [[ -f "config.yaml" ]]; then
grep -E "^port:" config.yaml | sed -E 's/^port: *["'"'"']?([0-9]+)["'"'"']?.*$/\1/'
else
echo "8317"
fi
}
export_stats_api_secret() {
if [[ -f "${SECRET_FILE}" ]]; then
API_SECRET=$(cat "${SECRET_FILE}")
else
if [[ ! -d "${STATS_DIR}" ]]; then
mkdir -p "${STATS_DIR}"
fi
echo "First time using --with-usage. Management API key required."
read -r -p "Enter management key: " -s API_SECRET
echo
echo "${API_SECRET}" > "${SECRET_FILE}"
chmod 600 "${SECRET_FILE}"
fi
}
check_container_running() {
local port
port=$(get_port)
if ! curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
echo "Error: cli-proxy-api service is not responding at localhost:${port}"
echo "Please start the container first or use without --with-usage flag."
if [[ "${1:-}" != "" ]]; then
echo "Error: unknown option '${1}'."
echo "Usage: ./docker-build.sh"
exit 1
fi
}
export_stats() {
local port
port=$(get_port)
if [[ ! -d "${STATS_DIR}" ]]; then
mkdir -p "${STATS_DIR}"
fi
check_container_running
echo "Exporting usage statistics..."
EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" \
"http://localhost:${port}/v0/management/usage/export")
HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')
if [[ "${HTTP_CODE}" != "200" ]]; then
echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
exit 1
fi
echo "${RESPONSE_BODY}" > "${STATS_FILE}"
echo "Statistics exported to ${STATS_FILE}"
}
import_stats() {
local port
port=$(get_port)
echo "Importing usage statistics..."
IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
-H "X-Management-Key: ${API_SECRET}" \
-H "Content-Type: application/json" \
-d @"${STATS_FILE}" \
"http://localhost:${port}/v0/management/usage/import")
IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')
if [[ "${IMPORT_CODE}" == "200" ]]; then
echo "Statistics imported successfully"
else
echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
fi
rm -f "${STATS_FILE}"
}
wait_for_service() {
local port
port=$(get_port)
echo "Waiting for service to be ready..."
for i in {1..30}; do
if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
break
fi
sleep 1
done
sleep 2
}
case "${1:-}" in
"")
;;
"--with-usage")
WITH_USAGE=true
export_stats_api_secret
;;
*)
echo "Error: unknown option '${1}'. Did you mean '--with-usage'?"
echo "Usage: ./docker-build.sh [--with-usage]"
exit 1
;;
esac
# --- Step 1: Choose Environment ---
echo "Please select an option:"
@@ -133,14 +23,7 @@ read -r -p "Enter choice [1-2]: " choice
case "$choice" in
1)
echo "--- Running with Pre-built Image ---"
if [[ "${WITH_USAGE}" == "true" ]]; then
export_stats
fi
docker compose up -d --remove-orphans --no-build
if [[ "${WITH_USAGE}" == "true" ]]; then
wait_for_service
import_stats
fi
echo "Services are starting from remote image."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -167,18 +50,9 @@ case "$choice" in
--build-arg COMMIT="${COMMIT}" \
--build-arg BUILD_DATE="${BUILD_DATE}"
if [[ "${WITH_USAGE}" == "true" ]]; then
export_stats
fi
echo "Starting the services..."
docker compose up -d --remove-orphans --pull never
if [[ "${WITH_USAGE}" == "true" ]]; then
wait_for_service
import_stats
fi
echo "Build complete. Services are starting."
echo "Run 'docker compose logs -f' to see the logs."
;;
@@ -0,0 +1,107 @@
package management
import (
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
type apiKeyUsageEntry struct {
Success int64 `json:"success"`
Failed int64 `json:"failed"`
RecentRequests []coreauth.RecentRequestBucket `json:"recent_requests"`
}
func mergeRecentRequestBuckets(dst, src []coreauth.RecentRequestBucket) []coreauth.RecentRequestBucket {
if len(dst) == 0 {
return src
}
if len(src) == 0 {
return dst
}
if len(dst) != len(src) {
n := len(dst)
if len(src) < n {
n = len(src)
}
for i := 0; i < n; i++ {
dst[i].Success += src[i].Success
dst[i].Failed += src[i].Failed
}
return dst
}
for i := range dst {
dst[i].Success += src[i].Success
dst[i].Failed += src[i].Failed
}
return dst
}
// GetAPIKeyUsage returns recent request buckets for all in-memory api_key auths,
// grouped by provider and keyed by "base_url|api_key".
func (h *Handler) GetAPIKeyUsage(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler not initialized"})
return
}
h.mu.Lock()
manager := h.authManager
h.mu.Unlock()
if manager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
now := time.Now()
out := make(map[string]map[string]apiKeyUsageEntry)
for _, auth := range manager.List() {
if auth == nil {
continue
}
kind, apiKey := auth.AccountInfo()
if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
continue
}
apiKey = strings.TrimSpace(apiKey)
if apiKey == "" {
continue
}
baseURL := ""
if auth.Attributes != nil {
baseURL = strings.TrimSpace(auth.Attributes["base_url"])
if baseURL == "" {
baseURL = strings.TrimSpace(auth.Attributes["base-url"])
}
}
compositeKey := baseURL + "|" + apiKey
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if provider == "" {
provider = "unknown"
}
recent := auth.RecentRequestsSnapshot(now)
providerBucket, ok := out[provider]
if !ok {
providerBucket = make(map[string]apiKeyUsageEntry)
out[provider] = providerBucket
}
if existing, exists := providerBucket[compositeKey]; exists {
existing.Success += auth.Success
existing.Failed += auth.Failed
existing.RecentRequests = mergeRecentRequestBuckets(existing.RecentRequests, recent)
providerBucket[compositeKey] = existing
continue
}
providerBucket[compositeKey] = apiKeyUsageEntry{
Success: auth.Success,
Failed: auth.Failed,
RecentRequests: recent,
}
}
c.JSON(http.StatusOK, out)
}
@@ -0,0 +1,95 @@
package management
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func sumRecentRequestBuckets(buckets []coreauth.RecentRequestBucket) (int64, int64) {
var success int64
var failed int64
for _, bucket := range buckets {
success += bucket.Success
failed += bucket.Failed
}
return success, failed
}
func TestGetAPIKeyUsage_GroupsByProviderAndAPIKey(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
manager := coreauth.NewManager(nil, nil, nil)
if _, err := manager.Register(context.Background(), &coreauth.Auth{
ID: "codex-auth",
Provider: "codex",
Attributes: map[string]string{
"api_key": "codex-key",
"base_url": "https://codex.example.com",
},
}); err != nil {
t.Fatalf("register codex auth: %v", err)
}
if _, err := manager.Register(context.Background(), &coreauth.Auth{
ID: "claude-auth",
Provider: "claude",
Attributes: map[string]string{
"api_key": "claude-key",
"base_url": "https://claude.example.com",
},
}); err != nil {
t.Fatalf("register claude auth: %v", err)
}
manager.MarkResult(context.Background(), coreauth.Result{AuthID: "codex-auth", Provider: "codex", Model: "gpt-5", Success: true})
manager.MarkResult(context.Background(), coreauth.Result{AuthID: "codex-auth", Provider: "codex", Model: "gpt-5", Success: false})
manager.MarkResult(context.Background(), coreauth.Result{AuthID: "claude-auth", Provider: "claude", Model: "claude-4", Success: true})
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodGet, "/v0/management/api-key-usage", nil)
ginCtx.Request = req
h.GetAPIKeyUsage(ginCtx)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var payload map[string]map[string]apiKeyUsageEntry
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("decode payload: %v", err)
}
codexEntry := payload["codex"]["https://codex.example.com|codex-key"]
if codexEntry.Success != 1 || codexEntry.Failed != 1 {
t.Fatalf("codex totals = %d/%d, want 1/1", codexEntry.Success, codexEntry.Failed)
}
if len(codexEntry.RecentRequests) != 20 {
t.Fatalf("codex buckets len = %d, want 20", len(codexEntry.RecentRequests))
}
codexSuccess, codexFailed := sumRecentRequestBuckets(codexEntry.RecentRequests)
if codexSuccess != 1 || codexFailed != 1 {
t.Fatalf("codex totals = %d/%d, want 1/1", codexSuccess, codexFailed)
}
claudeEntry := payload["claude"]["https://claude.example.com|claude-key"]
if claudeEntry.Success != 1 || claudeEntry.Failed != 0 {
t.Fatalf("claude totals = %d/%d, want 1/0", claudeEntry.Success, claudeEntry.Failed)
}
if len(claudeEntry.RecentRequests) != 20 {
t.Fatalf("claude buckets len = %d, want 20", len(claudeEntry.RecentRequests))
}
claudeSuccess, claudeFailed := sumRecentRequestBuckets(claudeEntry.RecentRequests)
if claudeSuccess != 1 || claudeFailed != 0 {
t.Fatalf("claude totals = %d/%d, want 1/0", claudeSuccess, claudeFailed)
}
}
+5 -15
View File
@@ -388,6 +388,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
"source": "memory",
"size": int64(0),
}
entry["success"] = auth.Success
entry["failed"] = auth.Failed
entry["recent_requests"] = auth.RecentRequestsSnapshot(time.Now())
if email := authEmail(auth); email != "" {
entry["email"] = email
}
@@ -2395,24 +2398,11 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
finalProjectID := projectID
if responseProjectID != "" {
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
strings.EqualFold(tierID, "FREE") ||
strings.EqualFold(tierID, "LEGACY")
if isFreeUser {
// For free users, use backend project ID for preview model access
log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID)
log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID)
finalProjectID = responseProjectID
} else {
// Pro users: keep requested project ID (original behavior)
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID)
log.Infof("Using backend project ID: %s", responseProjectID)
}
} else {
finalProjectID = responseProjectID
}
}
storage.ProjectID = strings.TrimSpace(finalProjectID)
if storage.ProjectID == "" {
@@ -0,0 +1,94 @@
package management
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestListAuthFiles_IncludesRecentRequestsBuckets(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
manager := coreauth.NewManager(nil, nil, nil)
record := &coreauth.Auth{
ID: "runtime-only-auth-1",
Provider: "codex",
Attributes: map[string]string{
"runtime_only": "true",
},
Metadata: map[string]any{
"type": "codex",
},
}
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
t.Fatalf("failed to register auth record: %v", errRegister)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
h.tokenStore = &memoryAuthStore{}
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil)
ginCtx.Request = req
h.ListAuthFiles(ginCtx)
if rec.Code != http.StatusOK {
t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var payload map[string]any
if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil {
t.Fatalf("failed to decode list payload: %v", errUnmarshal)
}
filesRaw, ok := payload["files"].([]any)
if !ok {
t.Fatalf("expected files array, payload: %#v", payload)
}
if len(filesRaw) != 1 {
t.Fatalf("expected 1 auth entry, got %d", len(filesRaw))
}
fileEntry, ok := filesRaw[0].(map[string]any)
if !ok {
t.Fatalf("expected file entry object, got %#v", filesRaw[0])
}
if _, ok := fileEntry["success"].(float64); !ok {
t.Fatalf("expected success number, got %#v", fileEntry["success"])
}
if _, ok := fileEntry["failed"].(float64); !ok {
t.Fatalf("expected failed number, got %#v", fileEntry["failed"])
}
recentRaw, ok := fileEntry["recent_requests"].([]any)
if !ok {
t.Fatalf("expected recent_requests array, got %#v", fileEntry["recent_requests"])
}
if len(recentRaw) != 20 {
t.Fatalf("expected 20 recent_requests buckets, got %d", len(recentRaw))
}
for idx, item := range recentRaw {
bucket, ok := item.(map[string]any)
if !ok {
t.Fatalf("expected bucket object at %d, got %#v", idx, item)
}
if _, ok := bucket["time"].(string); !ok {
t.Fatalf("expected bucket time string at %d, got %#v", idx, bucket["time"])
}
if _, ok := bucket["success"].(float64); !ok {
t.Fatalf("expected bucket success number at %d, got %#v", idx, bucket["success"])
}
if _, ok := bucket["failed"].(float64); !ok {
t.Fatalf("expected bucket failed number at %d, got %#v", idx, bucket["failed"])
}
}
}
@@ -15,7 +15,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"golang.org/x/crypto/bcrypt"
@@ -41,7 +40,6 @@ type Handler struct {
attemptsMu sync.Mutex
failedAttempts map[string]*attemptInfo // keyed by client IP
authManager *coreauth.Manager
usageStats *usage.RequestStatistics
tokenStore coreauth.Store
localPassword string
allowRemoteOverride bool
@@ -60,7 +58,6 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
configFilePath: configFilePath,
failedAttempts: make(map[string]*attemptInfo),
authManager: manager,
usageStats: usage.GetRequestStatistics(),
tokenStore: sdkAuth.GetTokenStore(),
allowRemoteOverride: envSecret != "",
envSecret: envSecret,
@@ -124,9 +121,6 @@ func (h *Handler) SetAuthManager(manager *coreauth.Manager) {
h.mu.Unlock()
}
// SetUsageStatistics allows replacing the usage statistics reference.
func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats }
// SetLocalPassword configures the runtime-local password accepted for localhost requests.
func (h *Handler) SetLocalPassword(password string) { h.localPassword = password }
+35 -59
View File
@@ -2,78 +2,54 @@ package management
import (
"encoding/json"
"errors"
"net/http"
"time"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
)
type usageExportPayload struct {
Version int `json:"version"`
ExportedAt time.Time `json:"exported_at"`
Usage usage.StatisticsSnapshot `json:"usage"`
type usageQueueRecord []byte
func (r usageQueueRecord) MarshalJSON() ([]byte, error) {
if json.Valid(r) {
return append([]byte(nil), r...), nil
}
return json.Marshal(string(r))
}
type usageImportPayload struct {
Version int `json:"version"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
// GetUsageStatistics returns the in-memory request statistics snapshot.
func (h *Handler) GetUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
}
c.JSON(http.StatusOK, gin.H{
"usage": snapshot,
"failed_requests": snapshot.FailureCount,
})
}
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
}
c.JSON(http.StatusOK, usageExportPayload{
Version: 1,
ExportedAt: time.Now().UTC(),
Usage: snapshot,
})
}
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
if h == nil || h.usageStats == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
// GetUsageQueue pops queued usage records from the usage queue.
func (h *Handler) GetUsageQueue(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
data, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
count, errCount := parseUsageQueueCount(c.Query("count"))
if errCount != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": errCount.Error()})
return
}
var payload usageImportPayload
if err := json.Unmarshal(data, &payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
return
}
if payload.Version != 0 && payload.Version != 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
return
items := redisqueue.PopOldest(count)
records := make([]usageQueueRecord, 0, len(items))
for _, item := range items {
records = append(records, usageQueueRecord(append([]byte(nil), item...)))
}
result := h.usageStats.MergeSnapshot(payload.Usage)
snapshot := h.usageStats.Snapshot()
c.JSON(http.StatusOK, gin.H{
"added": result.Added,
"skipped": result.Skipped,
"total_requests": snapshot.TotalRequests,
"failed_requests": snapshot.FailureCount,
})
c.JSON(http.StatusOK, records)
}
func parseUsageQueueCount(value string) (int, error) {
value = strings.TrimSpace(value)
if value == "" {
return 1, nil
}
count, errCount := strconv.Atoi(value)
if errCount != nil || count <= 0 {
return 0, errors.New("count must be a positive integer")
}
return count, nil
}
@@ -0,0 +1,98 @@
package management
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
)
func TestGetUsageQueuePopsRequestedRecords(t *testing.T) {
gin.SetMode(gin.TestMode)
withManagementUsageQueue(t, func() {
redisqueue.Enqueue([]byte(`{"id":1}`))
redisqueue.Enqueue([]byte(`{"id":2}`))
redisqueue.Enqueue([]byte(`{"id":3}`))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
h := &Handler{}
h.GetUsageQueue(ginCtx)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var payload []json.RawMessage
if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil {
t.Fatalf("unmarshal response: %v", errUnmarshal)
}
if len(payload) != 2 {
t.Fatalf("response records = %d, want 2", len(payload))
}
requireRecordID(t, payload[0], 1)
requireRecordID(t, payload[1], 2)
remaining := redisqueue.PopOldest(10)
if len(remaining) != 1 || string(remaining[0]) != `{"id":3}` {
t.Fatalf("remaining queue = %q, want third item only", remaining)
}
})
}
func TestGetUsageQueueInvalidCountDoesNotPop(t *testing.T) {
gin.SetMode(gin.TestMode)
withManagementUsageQueue(t, func() {
redisqueue.Enqueue([]byte(`{"id":1}`))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=0", nil)
h := &Handler{}
h.GetUsageQueue(ginCtx)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
}
remaining := redisqueue.PopOldest(10)
if len(remaining) != 1 || string(remaining[0]) != `{"id":1}` {
t.Fatalf("remaining queue = %q, want original item", remaining)
}
})
}
func withManagementUsageQueue(t *testing.T, fn func()) {
t.Helper()
prevQueueEnabled := redisqueue.Enabled()
redisqueue.SetEnabled(false)
redisqueue.SetEnabled(true)
defer func() {
redisqueue.SetEnabled(false)
redisqueue.SetEnabled(prevQueueEnabled)
}()
fn()
}
func requireRecordID(t *testing.T, raw json.RawMessage, want int) {
t.Helper()
var payload struct {
ID int `json:"id"`
}
if errUnmarshal := json.Unmarshal(raw, &payload); errUnmarshal != nil {
t.Fatalf("unmarshal record: %v", errUnmarshal)
}
if payload.ID != want {
t.Fatalf("record id = %d, want %d", payload.ID, want)
}
}
@@ -123,6 +123,52 @@ func (rw *ResponseRewriter) Flush() {
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
// ampCanonicalToolNames maps tool names to the exact casing expected by the
// Amp mode tool whitelist (case-sensitive match).
var ampCanonicalToolNames = map[string]string{
"bash": "Bash",
"read": "Read",
"grep": "Grep",
"glob": "glob",
"task": "Task",
"check": "Check",
}
// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing.
// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash")
// which causes Amp's case-sensitive mode whitelist to reject them.
func normalizeAmpToolNames(data []byte) []byte {
// Non-streaming: content[].name in tool_use blocks
for index, block := range gjson.GetBytes(data, "content").Array() {
if block.Get("type").String() != "tool_use" {
continue
}
name := block.Get("name").String()
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
path := fmt.Sprintf("content.%d.name", index)
var err error
data, err = sjson.SetBytes(data, path, canonical)
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err)
}
}
}
// Streaming: content_block.name in content_block_start events
if gjson.GetBytes(data, "content_block.type").String() == "tool_use" {
name := gjson.GetBytes(data, "content_block.name").String()
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
var err error
data, err = sjson.SetBytes(data, "content_block.name", canonical)
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err)
}
}
}
return data
}
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
// in API responses so that the Amp TUI does not crash on P.signature.length.
func ensureAmpSignature(data []byte) []byte {
@@ -179,6 +225,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
data = ensureAmpSignature(data)
data = normalizeAmpToolNames(data)
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return data
@@ -278,6 +325,9 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
// Inject empty signature where needed
data = ensureAmpSignature(data)
// Normalize tool names to canonical casing
data = normalizeAmpToolNames(data)
// Rewrite model name
if rw.originalModel != "" {
for _, path := range modelFieldPaths {
@@ -175,6 +175,57 @@ func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testi
}
}
func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`)
result := normalizeAmpToolNames(input)
if !contains(result, []byte(`"name":"Bash"`)) {
t.Errorf("expected bash->Bash, got %s", string(result))
}
if !contains(result, []byte(`"name":"Read"`)) {
t.Errorf("expected read->Read, got %s", string(result))
}
if contains(result, []byte(`"name":"bash"`)) {
t.Errorf("expected lowercase bash to be replaced, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_Streaming(t *testing.T) {
input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`)
result := normalizeAmpToolNames(input)
if !contains(result, []byte(`"name":"Grep"`)) {
t.Errorf("expected grep->Grep in streaming, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
result := normalizeAmpToolNames(input)
if string(result) != string(input) {
t.Errorf("expected no modification for correctly-cased tool, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
result := normalizeAmpToolNames(input)
if string(result) != string(input) {
t.Errorf("expected glob to remain lowercase, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`)
result := normalizeAmpToolNames(input)
if string(result) != string(input) {
t.Errorf("expected no modification for unknown tool, got %s", string(result))
}
}
func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {
+7 -5
View File
@@ -31,7 +31,6 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
@@ -507,9 +506,6 @@ func (s *Server) registerManagementRoutes() {
mgmt := s.engine.Group("/v0/management")
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
{
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
@@ -554,6 +550,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
mgmt.DELETE("/api-keys", s.mgmt.DeleteAPIKeys)
mgmt.GET("/api-key-usage", s.mgmt.GetAPIKeyUsage)
mgmt.GET("/usage-queue", s.mgmt.GetUsageQueue)
mgmt.GET("/gemini-api-key", s.mgmt.GetGeminiKeys)
mgmt.PUT("/gemini-api-key", s.mgmt.PutGeminiKeys)
@@ -1000,7 +998,11 @@ func (s *Server) UpdateClients(cfg *config.Config) {
}
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
redisqueue.SetUsageStatisticsEnabled(cfg.UsageStatisticsEnabled)
}
if oldCfg == nil || oldCfg.RedisUsageQueueRetentionSeconds != cfg.RedisUsageQueueRetentionSeconds {
redisqueue.SetRetentionSeconds(cfg.RedisUsageQueueRetentionSeconds)
}
if s.requestLogger != nil && (oldCfg == nil || oldCfg.ErrorLogsMaxFiles != cfg.ErrorLogsMaxFiles) {
+63
View File
@@ -13,6 +13,7 @@ import (
gin "github.com/gin-gonic/gin"
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
@@ -84,6 +85,68 @@ func TestHealthz(t *testing.T) {
})
}
func TestManagementUsageRequiresManagementAuthAndPopsArray(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "test-management-key")
prevQueueEnabled := redisqueue.Enabled()
redisqueue.SetEnabled(false)
t.Cleanup(func() {
redisqueue.SetEnabled(false)
redisqueue.SetEnabled(prevQueueEnabled)
})
server := newTestServer(t)
redisqueue.Enqueue([]byte(`{"id":1}`))
redisqueue.Enqueue([]byte(`{"id":2}`))
missingKeyReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
missingKeyRR := httptest.NewRecorder()
server.engine.ServeHTTP(missingKeyRR, missingKeyReq)
if missingKeyRR.Code != http.StatusUnauthorized {
t.Fatalf("missing key status = %d, want %d body=%s", missingKeyRR.Code, http.StatusUnauthorized, missingKeyRR.Body.String())
}
legacyReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage?count=2", nil)
legacyReq.Header.Set("Authorization", "Bearer test-management-key")
legacyRR := httptest.NewRecorder()
server.engine.ServeHTTP(legacyRR, legacyReq)
if legacyRR.Code != http.StatusNotFound {
t.Fatalf("legacy usage status = %d, want %d body=%s", legacyRR.Code, http.StatusNotFound, legacyRR.Body.String())
}
authReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
authReq.Header.Set("Authorization", "Bearer test-management-key")
authRR := httptest.NewRecorder()
server.engine.ServeHTTP(authRR, authReq)
if authRR.Code != http.StatusOK {
t.Fatalf("authenticated status = %d, want %d body=%s", authRR.Code, http.StatusOK, authRR.Body.String())
}
var payload []json.RawMessage
if errUnmarshal := json.Unmarshal(authRR.Body.Bytes(), &payload); errUnmarshal != nil {
t.Fatalf("unmarshal response: %v body=%s", errUnmarshal, authRR.Body.String())
}
if len(payload) != 2 {
t.Fatalf("response records = %d, want 2", len(payload))
}
for i, raw := range payload {
var record struct {
ID int `json:"id"`
}
if errUnmarshal := json.Unmarshal(raw, &record); errUnmarshal != nil {
t.Fatalf("unmarshal record %d: %v", i, errUnmarshal)
}
if record.ID != i+1 {
t.Fatalf("record %d id = %d, want %d", i, record.ID, i+1)
}
}
if remaining := redisqueue.PopOldest(1); len(remaining) != 0 {
t.Fatalf("remaining queue = %q, want empty", remaining)
}
}
func TestAmpProviderModelRoutes(t *testing.T) {
testCases := []struct {
name string
+134 -1
View File
@@ -6,15 +6,18 @@ package claude
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
)
// OAuth configuration constants for Claude/Anthropic
@@ -23,8 +26,94 @@ const (
TokenURL = "https://api.anthropic.com/v1/oauth/token"
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
RedirectURI = "http://localhost:54545/callback"
claudeRefreshMinBackoff = 5 * time.Second
claudeRefreshMaxBackoff = 5 * time.Minute
)
var (
claudeRefreshGroup singleflight.Group
claudeRefreshMu sync.Mutex
claudeRefreshBlock = make(map[string]time.Time)
)
type refreshHTTPError struct {
status int
message string
retryable bool
}
func (e *refreshHTTPError) Error() string {
return fmt.Sprintf("token refresh failed with status %d: %s", e.status, e.message)
}
func (e *refreshHTTPError) Retryable() bool {
return e != nil && e.retryable
}
func resetClaudeRefreshState() {
claudeRefreshMu.Lock()
defer claudeRefreshMu.Unlock()
claudeRefreshBlock = make(map[string]time.Time)
claudeRefreshGroup = singleflight.Group{}
}
func claudeRefreshBlockedUntil(refreshToken string) time.Time {
claudeRefreshMu.Lock()
defer claudeRefreshMu.Unlock()
return claudeRefreshBlock[refreshToken]
}
func setClaudeRefreshBlockedUntil(refreshToken string, until time.Time) {
claudeRefreshMu.Lock()
defer claudeRefreshMu.Unlock()
claudeRefreshBlock[refreshToken] = until
}
func clearClaudeRefreshBlockedUntil(refreshToken string) {
claudeRefreshMu.Lock()
defer claudeRefreshMu.Unlock()
delete(claudeRefreshBlock, refreshToken)
}
func clampClaudeRefreshBackoff(d time.Duration) time.Duration {
if d < claudeRefreshMinBackoff {
return claudeRefreshMinBackoff
}
if d > claudeRefreshMaxBackoff {
return claudeRefreshMaxBackoff
}
return d
}
func parseClaudeRetryAfter(resp *http.Response) time.Duration {
if resp == nil {
return claudeRefreshMinBackoff
}
if raw := strings.TrimSpace(resp.Header.Get("Retry-After")); raw != "" {
if seconds, err := time.ParseDuration(raw + "s"); err == nil {
return clampClaudeRefreshBackoff(seconds)
}
if when, err := http.ParseTime(raw); err == nil {
return clampClaudeRefreshBackoff(time.Until(when))
}
}
if raw := strings.TrimSpace(resp.Header.Get("Retry-After-Ms")); raw != "" {
if ms, err := time.ParseDuration(raw + "ms"); err == nil {
return clampClaudeRefreshBackoff(ms)
}
}
return claudeRefreshMinBackoff
}
func isClaudeRefreshRetryable(err error) bool {
var httpErr *refreshHTTPError
if errors.As(err, &httpErr) {
return httpErr.Retryable()
}
return true
}
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
// It contains access token, refresh token, and associated user/organization information.
type tokenResponse struct {
@@ -242,6 +331,35 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
if refreshToken == "" {
return nil, fmt.Errorf("refresh token is required")
}
if blockedUntil := claudeRefreshBlockedUntil(refreshToken); blockedUntil.After(time.Now()) {
return nil, &refreshHTTPError{
status: http.StatusTooManyRequests,
message: fmt.Sprintf("refresh temporarily blocked until %s", blockedUntil.Format(time.RFC3339)),
retryable: false,
}
}
result, err, _ := claudeRefreshGroup.Do(refreshToken, func() (interface{}, error) {
return o.refreshTokensSingleFlight(context.WithoutCancel(ctx), refreshToken)
})
if err != nil {
return nil, err
}
tokenData, ok := result.(*ClaudeTokenData)
if !ok || tokenData == nil {
return nil, fmt.Errorf("token refresh failed: invalid single-flight result")
}
return tokenData, nil
}
func (o *ClaudeAuth) refreshTokensSingleFlight(ctx context.Context, refreshToken string) (*ClaudeTokenData, error) {
if blockedUntil := claudeRefreshBlockedUntil(refreshToken); blockedUntil.After(time.Now()) {
return nil, &refreshHTTPError{
status: http.StatusTooManyRequests,
message: fmt.Sprintf("refresh temporarily blocked until %s", blockedUntil.Format(time.RFC3339)),
retryable: false,
}
}
reqBody := map[string]interface{}{
"client_id": ClientID,
@@ -276,7 +394,17 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
message := string(body)
if resp.StatusCode == http.StatusTooManyRequests {
retryAfter := parseClaudeRetryAfter(resp)
setClaudeRefreshBlockedUntil(refreshToken, time.Now().Add(retryAfter))
return nil, &refreshHTTPError{status: resp.StatusCode, message: message, retryable: false}
}
return nil, &refreshHTTPError{
status: resp.StatusCode,
message: message,
retryable: resp.StatusCode >= http.StatusInternalServerError,
}
}
// log.Debugf("Token response: %s", string(body))
@@ -287,6 +415,8 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
}
// Create token data
clearClaudeRefreshBlockedUntil(refreshToken)
return &ClaudeTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
@@ -348,6 +478,9 @@ func (o *ClaudeAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken st
lastErr = err
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
if !isClaudeRefreshRetryable(err) {
break
}
}
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
+123
View File
@@ -0,0 +1,123 @@
package claude
import (
"context"
"io"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestRefreshTokensWithRetry_429BlocksImmediateReplay(t *testing.T) {
resetClaudeRefreshState()
defer resetClaudeRefreshState()
var calls int32
auth := &ClaudeAuth{
httpClient: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
return &http.Response{
StatusCode: http.StatusTooManyRequests,
Body: io.NopCloser(strings.NewReader(`{"error":"rate_limited"}`)),
Header: http.Header{"Retry-After": []string{"60"}},
Request: req,
}, nil
}),
},
}
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected 429 refresh error")
}
if !strings.Contains(err.Error(), "status 429") {
t.Fatalf("expected status 429 in error, got %v", err)
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected 1 refresh attempt after 429, got %d", got)
}
_, err = auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected immediate blocked refresh error")
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected blocked retry to avoid a second refresh call, got %d attempts", got)
}
if blockedUntil := claudeRefreshBlockedUntil("dummy_refresh_token"); !blockedUntil.After(time.Now()) {
t.Fatalf("expected blocked-until timestamp to be set, got %v", blockedUntil)
}
}
func TestRefreshTokens_DeduplicatesConcurrentRefresh(t *testing.T) {
resetClaudeRefreshState()
defer resetClaudeRefreshState()
var calls int32
started := make(chan struct{})
release := make(chan struct{})
var once sync.Once
auth := &ClaudeAuth{
httpClient: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
once.Do(func() { close(started) })
<-release
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{
"access_token":"new-access",
"refresh_token":"new-refresh",
"token_type":"Bearer",
"expires_in":3600,
"account":{"email_address":"shared@example.com"}
}`)),
Header: make(http.Header),
Request: req,
}, nil
}),
},
}
results := make(chan *ClaudeTokenData, 2)
errs := make(chan error, 2)
runRefresh := func() {
td, err := auth.RefreshTokens(context.Background(), "shared-refresh-token")
results <- td
errs <- err
}
go runRefresh()
go runRefresh()
<-started
time.Sleep(20 * time.Millisecond)
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected concurrent refresh to share a single upstream call, got %d", got)
}
close(release)
for i := 0; i < 2; i++ {
if err := <-errs; err != nil {
t.Fatalf("expected refresh to succeed, got %v", err)
}
td := <-results
if td == nil || td.AccessToken != "new-access" {
t.Fatalf("expected refreshed access token, got %#v", td)
}
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected exactly 1 upstream refresh call, got %d", got)
}
}
+3 -35
View File
@@ -333,43 +333,11 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
finalProjectID := projectID
if responseProjectID != "" {
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
strings.EqualFold(tierID, "FREE") ||
strings.EqualFold(tierID, "LEGACY")
if isFreeUser {
// Interactive prompt for free users
fmt.Printf("\nGoogle returned a different project ID:\n")
fmt.Printf(" Requested (frontend): %s\n", projectID)
fmt.Printf(" Returned (backend): %s\n\n", responseProjectID)
fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n")
fmt.Printf(" This is normal for free tier users.\n\n")
fmt.Printf("Which project ID would you like to use?\n")
fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID)
fmt.Printf(" [2] Frontend: %s\n\n", projectID)
fmt.Printf("Enter choice [1]: ")
reader := bufio.NewReader(os.Stdin)
choice, _ := reader.ReadString('\n')
choice = strings.TrimSpace(choice)
if choice == "2" {
log.Infof("Using frontend project ID: %s", projectID)
fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
finalProjectID = projectID
} else {
log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID)
log.Infof("Using backend project ID: %s", responseProjectID)
}
finalProjectID = responseProjectID
}
} else {
// Pro users: keep requested project ID (original behavior)
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
}
} else {
finalProjectID = responseProjectID
}
}
storage.ProjectID = strings.TrimSpace(finalProjectID)
if storage.ProjectID == "" {
+13
View File
@@ -65,6 +65,11 @@ type Config struct {
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
// RedisUsageQueueRetentionSeconds controls how long (in seconds) usage queue items
// are retained in memory for the Redis RESP interface (LPOP/RPOP).
// Default: 60. Max: 3600.
RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"`
// DisableCooling disables quota cooldown scheduling when true.
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`
@@ -609,6 +614,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.LogsMaxTotalSizeMB = 0
cfg.ErrorLogsMaxFiles = 10
cfg.UsageStatisticsEnabled = false
cfg.RedisUsageQueueRetentionSeconds = 60
cfg.DisableCooling = false
cfg.DisableImageGeneration = DisableImageGenerationOff
cfg.Pprof.Enable = false
@@ -671,6 +677,13 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.ErrorLogsMaxFiles = 10
}
if cfg.RedisUsageQueueRetentionSeconds <= 0 {
cfg.RedisUsageQueueRetentionSeconds = 60
} else if cfg.RedisUsageQueueRetentionSeconds > 3600 {
log.WithField("value", cfg.RedisUsageQueueRetentionSeconds).Warn("redis-usage-queue-retention-seconds too large; clamping to 3600")
cfg.RedisUsageQueueRetentionSeconds = 3600
}
if cfg.MaxRetryCredentials < 0 {
cfg.MaxRetryCredentials = 0
}
+62
View File
@@ -0,0 +1,62 @@
package logging
import (
"context"
"sync/atomic"
)
type endpointKey struct{}
type responseStatusKey struct{}
type responseStatusHolder struct {
status atomic.Int32
}
func WithEndpoint(ctx context.Context, endpoint string) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, endpointKey{}, endpoint)
}
func GetEndpoint(ctx context.Context) string {
if ctx == nil {
return ""
}
if endpoint, ok := ctx.Value(endpointKey{}).(string); ok {
return endpoint
}
return ""
}
func WithResponseStatusHolder(ctx context.Context) context.Context {
if ctx == nil {
ctx = context.Background()
}
if holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder); ok && holder != nil {
return ctx
}
return context.WithValue(ctx, responseStatusKey{}, &responseStatusHolder{})
}
func SetResponseStatus(ctx context.Context, status int) {
if ctx == nil || status <= 0 {
return
}
holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder)
if !ok || holder == nil {
return
}
holder.status.Store(int32(status))
}
func GetResponseStatus(ctx context.Context) int {
if ctx == nil {
return 0
}
holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder)
if !ok || holder == nil {
return 0
}
return int(holder.status.Load())
}
+2 -2
View File
@@ -12,7 +12,7 @@ import (
const (
// GeminiCLIVersion is the version string reported in the User-Agent for upstream requests.
GeminiCLIVersion = "0.31.0"
GeminiCLIVersion = "0.34.0"
// GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream.
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
@@ -46,7 +46,7 @@ func GeminiCLIUserAgent(model string) string {
if model == "" {
model = "unknown"
}
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s; terminal)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
}
// ScrubProxyAndFingerprintHeaders removes all headers that could reveal
+32 -42
View File
@@ -3,13 +3,10 @@ package redisqueue
import (
"context"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
@@ -23,7 +20,7 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
if p == nil {
return
}
if !Enabled() || !internalusage.StatisticsEnabled() {
if !Enabled() || !UsageStatisticsEnabled() {
return
}
@@ -36,6 +33,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
if modelName == "" {
modelName = "unknown"
}
aliasName := strings.TrimSpace(record.Alias)
if aliasName == "" {
aliasName = modelName
}
provider := strings.TrimSpace(record.Provider)
if provider == "" {
provider = "unknown"
@@ -46,13 +47,8 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
}
apiKey := strings.TrimSpace(record.APIKey)
requestID := strings.TrimSpace(internallogging.GetRequestID(ctx))
if requestID == "" {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
requestID = strings.TrimSpace(internallogging.GetGinRequestID(ginCtx))
}
}
tokens := internalusage.TokenStats{
tokens := tokenStats{
InputTokens: record.Detail.InputTokens,
OutputTokens: record.Detail.OutputTokens,
ReasoningTokens: record.Detail.ReasoningTokens,
@@ -71,7 +67,7 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
failed = !resolveSuccess(ctx)
}
detail := internalusage.RequestDetail{
detail := requestDetail{
Timestamp: timestamp,
LatencyMs: record.Latency.Milliseconds(),
Source: record.Source,
@@ -81,9 +77,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
}
payload, err := json.Marshal(queuedUsageDetail{
RequestDetail: detail,
requestDetail: detail,
Provider: provider,
Model: modelName,
Alias: aliasName,
Endpoint: resolveEndpoint(ctx),
AuthType: authType,
APIKey: apiKey,
@@ -96,50 +93,43 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
}
type queuedUsageDetail struct {
internalusage.RequestDetail
requestDetail
Provider string `json:"provider"`
Model string `json:"model"`
Alias string `json:"alias"`
Endpoint string `json:"endpoint"`
AuthType string `json:"auth_type"`
APIKey string `json:"api_key"`
RequestID string `json:"request_id"`
}
type requestDetail struct {
Timestamp time.Time `json:"timestamp"`
LatencyMs int64 `json:"latency_ms"`
Source string `json:"source"`
AuthIndex string `json:"auth_index"`
Tokens tokenStats `json:"tokens"`
Failed bool `json:"failed"`
}
type tokenStats struct {
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
ReasoningTokens int64 `json:"reasoning_tokens"`
CachedTokens int64 `json:"cached_tokens"`
TotalTokens int64 `json:"total_tokens"`
}
func resolveSuccess(ctx context.Context) bool {
if ctx == nil {
return true
}
ginCtx, ok := ctx.Value("gin").(*gin.Context)
if !ok || ginCtx == nil {
return true
}
status := ginCtx.Writer.Status()
status := internallogging.GetResponseStatus(ctx)
if status == 0 {
return true
}
return status < http.StatusBadRequest
return status < httpStatusBadRequest
}
func resolveEndpoint(ctx context.Context) string {
if ctx == nil {
return ""
}
ginCtx, ok := ctx.Value("gin").(*gin.Context)
if !ok || ginCtx == nil || ginCtx.Request == nil {
return ""
return strings.TrimSpace(internallogging.GetEndpoint(ctx))
}
path := strings.TrimSpace(ginCtx.FullPath())
if path == "" && ginCtx.Request.URL != nil {
path = strings.TrimSpace(ginCtx.Request.URL.Path)
}
if path == "" {
return ""
}
method := strings.TrimSpace(ginCtx.Request.Method)
if method == "" {
return path
}
return method + " " + path
}
const httpStatusBadRequest = 400
+87 -10
View File
@@ -10,20 +10,21 @@ import (
"github.com/gin-gonic/gin"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
withEnabledQueue(t, func() {
ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK)
internallogging.SetGinRequestID(ginCtx, "gin-request-id-ignored")
ctx := context.WithValue(internallogging.WithRequestID(context.Background(), "ctx-request-id"), "gin", ginCtx)
ctx := internallogging.WithRequestID(context.Background(), "ctx-request-id")
ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
ctx = internallogging.WithResponseStatusHolder(ctx)
internallogging.SetResponseStatus(ctx, http.StatusOK)
plugin := &usageQueuePlugin{}
plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4",
Alias: "client-gpt",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
@@ -40,6 +41,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4")
requireStringField(t, payload, "alias", "client-gpt")
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
requireStringField(t, payload, "auth_type", "apikey")
requireStringField(t, payload, "request_id", "ctx-request-id")
@@ -49,14 +51,16 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t *testing.T) {
withEnabledQueue(t, func() {
ginCtx := newTestGinContext(t, http.MethodGet, "/v1/responses", http.StatusInternalServerError)
internallogging.SetGinRequestID(ginCtx, "gin-request-id")
ctx := context.WithValue(context.Background(), "gin", ginCtx)
ctx := internallogging.WithRequestID(context.Background(), "gin-request-id")
ctx = internallogging.WithEndpoint(ctx, "GET /v1/responses")
ctx = internallogging.WithResponseStatusHolder(ctx)
internallogging.SetResponseStatus(ctx, http.StatusInternalServerError)
plugin := &usageQueuePlugin{}
plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4-mini",
Alias: "client-mini",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
@@ -73,6 +77,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t
payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4-mini")
requireStringField(t, payload, "alias", "client-mini")
requireStringField(t, payload, "endpoint", "GET /v1/responses")
requireStringField(t, payload, "auth_type", "apikey")
requireStringField(t, payload, "request_id", "gin-request-id")
@@ -80,20 +85,63 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t
})
}
func TestUsageQueuePluginAsyncIgnoresRecycledGinContext(t *testing.T) {
withEnabledQueue(t, func() {
ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK)
ctx := context.WithValue(context.Background(), "gin", ginCtx)
ctx = internallogging.WithRequestID(ctx, "ctx-request-id")
ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
ctx = internallogging.WithResponseStatusHolder(ctx)
internallogging.SetResponseStatus(ctx, http.StatusInternalServerError)
mgr := coreusage.NewManager(16)
defer mgr.Stop()
mgr.Register(pluginFunc(func(_ context.Context, _ coreusage.Record) {
ginCtx.Request = httptest.NewRequest(http.MethodGet, "http://example.com/v1/responses", nil)
ginCtx.Status(http.StatusOK)
}))
mgr.Register(&usageQueuePlugin{})
mgr.Publish(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4",
Alias: "client-gpt",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
Source: "user@example.com",
RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC),
Latency: 1500 * time.Millisecond,
Detail: coreusage.Detail{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
})
payload := waitForSinglePayload(t, 2*time.Second)
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
requireStringField(t, payload, "alias", "client-gpt")
requireStringField(t, payload, "request_id", "ctx-request-id")
requireBoolField(t, payload, "failed", true)
})
}
func withEnabledQueue(t *testing.T, fn func()) {
t.Helper()
prevQueueEnabled := Enabled()
prevStatsEnabled := internalusage.StatisticsEnabled()
prevUsageEnabled := UsageStatisticsEnabled()
SetEnabled(false)
SetEnabled(true)
internalusage.SetStatisticsEnabled(true)
SetUsageStatisticsEnabled(true)
defer func() {
SetEnabled(false)
SetEnabled(prevQueueEnabled)
internalusage.SetStatisticsEnabled(prevStatsEnabled)
SetUsageStatisticsEnabled(prevUsageEnabled)
}()
fn()
@@ -127,6 +175,29 @@ func popSinglePayload(t *testing.T) map[string]json.RawMessage {
return payload
}
func waitForSinglePayload(t *testing.T, timeout time.Duration) map[string]json.RawMessage {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
items := PopOldest(10)
if len(items) == 0 {
time.Sleep(10 * time.Millisecond)
continue
}
if len(items) != 1 {
t.Fatalf("PopOldest() items = %d, want 1", len(items))
}
var payload map[string]json.RawMessage
if err := json.Unmarshal(items[0], &payload); err != nil {
t.Fatalf("unmarshal payload: %v", err)
}
return payload
}
t.Fatalf("timeout waiting for queued payload")
return nil
}
func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, want string) {
t.Helper()
@@ -143,6 +214,12 @@ func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, w
}
}
type pluginFunc func(context.Context, coreusage.Record)
func (fn pluginFunc) HandleUsage(ctx context.Context, record coreusage.Record) {
fn(ctx, record)
}
func requireBoolField(t *testing.T, payload map[string]json.RawMessage, key string, want bool) {
t.Helper()
+24 -2
View File
@@ -6,7 +6,10 @@ import (
"time"
)
const retentionWindow = time.Minute
const (
defaultRetentionSeconds int64 = 60
maxRetentionSeconds int64 = 3600
)
type queueItem struct {
enqueuedAt time.Time
@@ -21,9 +24,14 @@ type queue struct {
var (
enabled atomic.Bool
retentionSeconds atomic.Int64
global queue
)
func init() {
retentionSeconds.Store(defaultRetentionSeconds)
}
func SetEnabled(value bool) {
enabled.Store(value)
if !value {
@@ -35,6 +43,16 @@ func Enabled() bool {
return enabled.Load()
}
func SetRetentionSeconds(value int) {
normalized := int64(value)
if normalized <= 0 {
normalized = defaultRetentionSeconds
} else if normalized > maxRetentionSeconds {
normalized = maxRetentionSeconds
}
retentionSeconds.Store(normalized)
}
func Enqueue(payload []byte) {
if !Enabled() {
return
@@ -110,7 +128,11 @@ func (q *queue) pruneLocked(now time.Time) {
return
}
cutoff := now.Add(-retentionWindow)
windowSeconds := retentionSeconds.Load()
if windowSeconds <= 0 {
windowSeconds = defaultRetentionSeconds
}
cutoff := now.Add(-time.Duration(windowSeconds) * time.Second)
for q.head < len(q.items) && q.items[q.head].enqueuedAt.Before(cutoff) {
q.head++
}
+16
View File
@@ -0,0 +1,16 @@
package redisqueue
import "sync/atomic"
var usageStatisticsEnabled atomic.Bool
func init() {
usageStatisticsEnabled.Store(true)
}
// SetUsageStatisticsEnabled toggles whether usage records are enqueued into the redisqueue payload buffer.
// This is controlled by the config field `usage-statistics-enabled` and the corresponding management API.
func SetUsageStatisticsEnabled(enabled bool) { usageStatisticsEnabled.Store(enabled) }
// UsageStatisticsEnabled reports whether the usage queue plugin should publish records.
func UsageStatisticsEnabled() bool { return usageStatisticsEnabled.Load() }
+18 -4
View File
@@ -285,7 +285,10 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
if event.Err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
select {
case out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}:
case <-ctx.Done():
}
return false
}
switch event.Type {
@@ -303,7 +306,11 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}:
case <-ctx.Done():
return false
}
}
break
}
@@ -319,14 +326,21 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}:
case <-ctx.Done():
return false
}
}
reporter.Publish(ctx, helps.ParseGeminiUsage(event.Payload))
return false
case wsrelay.MessageTypeError:
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
select {
case out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}:
case <-ctx.Done():
}
return false
}
return true
@@ -1357,17 +1357,28 @@ attemptLoop:
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), &param)
for i := range tail {
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
} else {
reporter.EnsurePublished(ctx)
}
+168 -71
View File
@@ -65,14 +65,13 @@ var oauthToolRenameMap = map[string]string{
"notebookedit": "NotebookEdit",
}
// oauthToolRenameReverseMap is the inverse of oauthToolRenameMap for response decoding.
var oauthToolRenameReverseMap = func() map[string]string {
m := make(map[string]string, len(oauthToolRenameMap))
for k, v := range oauthToolRenameMap {
m[v] = k
}
return m
}()
// The reverse map is now computed per-request in remapOAuthToolNames so that
// only names the client actually caused us to rewrite are restored on the
// response. A global reverse map — as used previously — corrupted responses
// for clients that sent mixed casing (e.g. Amp CLI sends `Bash` TitleCase
// alongside `glob` lowercase; the request flagged renames via `glob→Glob`,
// then the global reverse map incorrectly rewrote every `Bash` in the
// response to `bash`, causing Amp to reject the tool_use as unknown).
// oauthToolsToRemove lists tool names that must be stripped from OAuth requests
// even after remapping. Currently empty — all tools are mapped instead of removed.
@@ -192,15 +191,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
bodyForTranslation := body
bodyForUpstream := body
oauthToken := isClaudeOAuthToken(apiKey)
oauthToolNamesRemapped := false
if oauthToken && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
// Remap third-party tool names to Claude Code equivalents and remove
// tools without official counterparts. This prevents Anthropic from
// fingerprinting the request as third-party via tool naming patterns.
var oauthToolNamesReverseMap map[string]string
if oauthToken {
bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled())
}
// Enable cch signing by default for OAuth tokens (not just experimental flag).
// Claude Code always computes cch; missing or invalid cch is a detectable fingerprint.
@@ -285,6 +278,10 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if stream {
if errValidate := validateClaudeStreamingResponse(data); errValidate != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errValidate)
return resp, errValidate
}
lines := bytes.Split(data, []byte("\n"))
for _, line := range lines {
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
@@ -294,13 +291,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.Publish(ctx, helps.ParseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
// Reverse the OAuth tool name remap so the downstream client sees original names.
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
data = reverseRemapOAuthToolNames(data)
}
data = restoreClaudeOAuthToolNamesFromResponse(data, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
var param any
out := sdktranslator.TranslateNonStream(
ctx,
@@ -375,15 +366,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
bodyForTranslation := body
bodyForUpstream := body
oauthToken := isClaudeOAuthToken(apiKey)
oauthToolNamesRemapped := false
if oauthToken && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
// Remap third-party tool names to Claude Code equivalents and remove
// tools without official counterparts. This prevents Anthropic from
// fingerprinting the request as third-party via tool naming patterns.
var oauthToolNamesReverseMap map[string]string
if oauthToken {
bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled())
}
// Enable cch signing by default for OAuth tokens (not just experimental flag).
if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) {
@@ -474,22 +459,24 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
line = reverseRemapOAuthToolNamesFromStreamLine(line)
}
line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
// Forward the line as-is to preserve SSE format
cloned := make([]byte, len(line)+1)
copy(cloned, line)
cloned[len(line)] = '\n'
out <- cliproxyexecutor.StreamChunk{Payload: cloned}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: cloned}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
return
}
@@ -504,12 +491,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
line = reverseRemapOAuthToolNamesFromStreamLine(line)
}
line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
chunks := sdktranslator.TranslateStream(
ctx,
to,
@@ -521,18 +503,83 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
&param,
)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func validateClaudeStreamingResponse(data []byte) error {
scanner := bufio.NewScanner(bytes.NewReader(data))
scanner.Buffer(nil, 52_428_800)
hasData := false
hasMessageStart := false
hasMessageDelta := false
for scanner.Scan() {
line := bytes.TrimSpace(scanner.Bytes())
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
continue
}
payload := bytes.TrimSpace(line[len("data:"):])
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
continue
}
hasData = true
if !gjson.ValidBytes(payload) {
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned malformed stream data"}
}
root := gjson.ParseBytes(payload)
switch root.Get("type").String() {
case "error":
message := strings.TrimSpace(root.Get("error.message").String())
if message == "" {
message = strings.TrimSpace(root.Get("error.type").String())
}
if message == "" {
message = "unknown upstream error"
}
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned error event: " + message}
case "message_start":
message := root.Get("message")
if strings.TrimSpace(message.Get("id").String()) == "" || strings.TrimSpace(message.Get("model").String()) == "" {
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream message_start is missing id or model"}
}
hasMessageStart = true
case "message_delta":
hasMessageDelta = true
}
}
if errScan := scanner.Err(); errScan != nil {
return errScan
}
if !hasData {
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned empty stream response"}
}
if !hasMessageStart {
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response is missing message_start"}
}
if !hasMessageDelta {
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response ended before message completion"}
}
return nil
}
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
@@ -559,12 +606,8 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
// Remap tool names for OAuth token requests to avoid third-party fingerprinting.
if isClaudeOAuthToken(apiKey) {
body, _ = remapOAuthToolNames(body)
body, _ = prepareClaudeOAuthToolNamesForUpstream(body, claudeToolPrefix, auth.ToolPrefixDisabled())
}
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
@@ -661,7 +704,7 @@ func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (
return auth, nil
}
svc := claudeauth.NewClaudeAuthWithProxyURL(e.cfg, auth.ProxyURL)
td, err := svc.RefreshTokens(ctx, refreshToken)
td, err := svc.RefreshTokensWithRetry(ctx, refreshToken, 3)
if err != nil {
return nil, err
}
@@ -1004,6 +1047,36 @@ func isClaudeOAuthToken(apiKey string) bool {
return strings.Contains(apiKey, "sk-ant-oat")
}
// prepareClaudeOAuthToolNamesForUpstream applies the Claude OAuth tool-name
// transforms in the same order across request paths. Remap runs before prefixing
// so any future non-empty prefix still composes correctly with the per-request
// reverse map.
func prepareClaudeOAuthToolNamesForUpstream(body []byte, prefix string, prefixDisabled bool) ([]byte, map[string]string) {
body, reverseMap := remapOAuthToolNames(body)
if !prefixDisabled {
body = applyClaudeToolPrefix(body, prefix)
}
return body, reverseMap
}
// restoreClaudeOAuthToolNamesFromResponse undoes the Claude OAuth tool-name
// transforms for non-stream responses in reverse order.
func restoreClaudeOAuthToolNamesFromResponse(body []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte {
if !prefixDisabled {
body = stripClaudeToolPrefixFromResponse(body, prefix)
}
return reverseRemapOAuthToolNames(body, reverseMap)
}
// restoreClaudeOAuthToolNamesFromStreamLine undoes the Claude OAuth tool-name
// transforms for SSE lines in reverse order.
func restoreClaudeOAuthToolNamesFromStreamLine(line []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte {
if !prefixDisabled {
line = stripClaudeToolPrefixFromStreamLine(line, prefix)
}
return reverseRemapOAuthToolNamesFromStreamLine(line, reverseMap)
}
// remapOAuthToolNames renames third-party tool names to Claude Code equivalents
// and removes tools without an official counterpart. This prevents Anthropic from
// fingerprinting the request as a third-party client via tool naming patterns.
@@ -1011,8 +1084,25 @@ func isClaudeOAuthToken(apiKey string) bool {
// It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference
// references in messages. Removed tools' corresponding tool_result blocks are preserved
// (they just become orphaned, which is safe for Claude).
func remapOAuthToolNames(body []byte) ([]byte, bool) {
renamed := false
//
// The returned map is keyed on the upstream (TitleCase) name and maps to the
// client-supplied original name. Callers MUST pass this map to the reverse
// functions so only names the client actually caused us to rewrite are restored
// on the response. A global reverse map (the previous implementation) incorrectly
// rewrote names the client originally sent in TitleCase (e.g. Amp CLI's `Bash`)
// when any OTHER tool in the same request triggered a forward rename (e.g.
// Amp's `glob`→`Glob`), because the global reverse map contained `Bash`→`bash`
// regardless of what the client originally sent.
func remapOAuthToolNames(body []byte) ([]byte, map[string]string) {
reverseMap := make(map[string]string, len(oauthToolRenameMap))
recordRename := func(original, renamed string) {
// Preserve the first-seen original name if the same upstream name is
// produced from multiple call sites; they all map back identically.
if _, exists := reverseMap[renamed]; !exists {
reverseMap[renamed] = original
}
}
// 1. Rewrite tools array in a single pass (if present).
// IMPORTANT: do not mutate names first and then rebuild from an older gjson
// snapshot. gjson results are snapshots of the original bytes; rebuilding from a
@@ -1045,7 +1135,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
updatedTool, err := sjson.Set(toolJSON, "name", newName)
if err == nil {
toolJSON = updatedTool
renamed = true
recordRename(name, newName)
}
}
@@ -1070,7 +1160,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
body, _ = sjson.DeleteBytes(body, "tool_choice")
} else if newName, ok := oauthToolRenameMap[tcName]; ok && newName != tcName {
body, _ = sjson.SetBytes(body, "tool_choice.name", newName)
renamed = true
recordRename(tcName, newName)
}
}
@@ -1090,14 +1180,14 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
if newName, ok := oauthToolRenameMap[name]; ok && newName != name {
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, newName)
renamed = true
recordRename(name, newName)
}
case "tool_reference":
toolName := part.Get("tool_name").String()
if newName, ok := oauthToolRenameMap[toolName]; ok && newName != toolName {
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, newName)
renamed = true
recordRename(toolName, newName)
}
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
@@ -1111,7 +1201,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
if newName, ok := oauthToolRenameMap[nestedToolName]; ok && newName != nestedToolName {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, newName)
renamed = true
recordRename(nestedToolName, newName)
}
}
return true
@@ -1124,13 +1214,16 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
})
}
return body, renamed
return body, reverseMap
}
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses.
// It maps Claude Code TitleCase names back to the original lowercase names so the
// downstream client receives tool names it recognizes.
func reverseRemapOAuthToolNames(body []byte) []byte {
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses
// using the per-request map produced by remapOAuthToolNames. Names the client sent
// that were NOT forward-renamed are passed through unchanged.
func reverseRemapOAuthToolNames(body []byte, reverseMap map[string]string) []byte {
if len(reverseMap) == 0 {
return body
}
content := gjson.GetBytes(body, "content")
if !content.Exists() || !content.IsArray() {
return body
@@ -1140,13 +1233,13 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
switch partType {
case "tool_use":
name := part.Get("name").String()
if origName, ok := oauthToolRenameReverseMap[name]; ok {
if origName, ok := reverseMap[name]; ok {
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, origName)
}
case "tool_reference":
toolName := part.Get("tool_name").String()
if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
if origName, ok := reverseMap[toolName]; ok {
path := fmt.Sprintf("content.%d.tool_name", index.Int())
body, _ = sjson.SetBytes(body, path, origName)
}
@@ -1156,8 +1249,12 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
return body
}
// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE stream lines.
func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE
// stream lines, using the per-request reverseMap produced by remapOAuthToolNames.
func reverseRemapOAuthToolNamesFromStreamLine(line []byte, reverseMap map[string]string) []byte {
if len(reverseMap) == 0 {
return line
}
payload := helps.JSONPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return line
@@ -1175,7 +1272,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
switch blockType {
case "tool_use":
name := contentBlock.Get("name").String()
if origName, ok := oauthToolRenameReverseMap[name]; ok {
if origName, ok := reverseMap[name]; ok {
updated, err = sjson.SetBytes(payload, "content_block.name", origName)
if err != nil {
return line
@@ -1185,7 +1282,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
}
case "tool_reference":
toolName := contentBlock.Get("tool_name").String()
if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
if origName, ok := reverseMap[toolName]; ok {
updated, err = sjson.SetBytes(payload, "content_block.tool_name", origName)
if err != nil {
return line
+248 -14
View File
@@ -936,6 +936,113 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
}
}
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsEmptyClaudeStream(t *testing.T) {
_, err := executeOpenAIChatCompletionThroughClaude(t, "")
if err == nil {
t.Fatal("Execute error = nil, want empty stream error")
}
assertStatusErr(t, err, http.StatusBadGateway)
if !strings.Contains(err.Error(), "empty stream response") {
t.Fatalf("Execute error = %q, want empty stream response", err.Error())
}
}
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsClaudeErrorEvent(t *testing.T) {
body := `data: {"type":"error","error":{"type":"overloaded_error","message":"upstream overloaded"}}` + "\n"
_, err := executeOpenAIChatCompletionThroughClaude(t, body)
if err == nil {
t.Fatal("Execute error = nil, want upstream error event")
}
assertStatusErr(t, err, http.StatusBadGateway)
if !strings.Contains(err.Error(), "upstream overloaded") {
t.Fatalf("Execute error = %q, want upstream overloaded", err.Error())
}
}
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsIncompleteClaudeStream(t *testing.T) {
body := strings.Join([]string{
`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`,
`data: {"type":"message_stop"}`,
``,
}, "\n")
_, err := executeOpenAIChatCompletionThroughClaude(t, body)
if err == nil {
t.Fatal("Execute error = nil, want incomplete stream error")
}
assertStatusErr(t, err, http.StatusBadGateway)
if !strings.Contains(err.Error(), "ended before message completion") {
t.Fatalf("Execute error = %q, want incomplete stream error", err.Error())
}
}
func TestClaudeExecutor_ExecuteOpenAINonStreamConvertsValidClaudeStream(t *testing.T) {
body := strings.Join([]string{
`event: message_start`,
`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`,
`event: content_block_delta`,
`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ok"}}`,
`event: message_delta`,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":2,"output_tokens":1}}`,
`event: message_stop`,
`data: {"type":"message_stop"}`,
``,
}, "\n")
resp, err := executeOpenAIChatCompletionThroughClaude(t, body)
if err != nil {
t.Fatalf("Execute error: %v", err)
}
if got := gjson.GetBytes(resp.Payload, "id").String(); got != "msg_123" {
t.Fatalf("response id = %q, want msg_123; payload=%s", got, string(resp.Payload))
}
if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-3-5-sonnet-20241022" {
t.Fatalf("response model = %q, want claude-3-5-sonnet-20241022", got)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "ok" {
t.Fatalf("response content = %q, want ok", got)
}
if got := gjson.GetBytes(resp.Payload, "usage.total_tokens").Int(); got != 3 {
t.Fatalf("usage.total_tokens = %d, want 3", got)
}
}
func executeOpenAIChatCompletionThroughClaude(t *testing.T, upstreamBody string) (cliproxyexecutor.Response, error) {
t.Helper()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte(upstreamBody))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"hi"}]}`)
return executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
}
func assertStatusErr(t *testing.T, err error, want int) {
t.Helper()
status, ok := err.(interface{ StatusCode() int })
if !ok {
t.Fatalf("error %T does not expose StatusCode", err)
}
if got := status.StatusCode(); got != want {
t.Fatalf("StatusCode() = %d, want %d", got, want)
}
}
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -1989,19 +2096,16 @@ func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOrigina
func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
body := []byte(`{"tools":[{"name":"Bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
out, renamed := remapOAuthToolNames(body)
if renamed {
t.Fatalf("renamed = true, want false")
out, reverseMap := remapOAuthToolNames(body)
if len(reverseMap) != 0 {
t.Fatalf("reverseMap = %v, want empty", reverseMap)
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
}
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := resp
if renamed {
reversed = reverseRemapOAuthToolNames(resp)
}
reversed := reverseRemapOAuthToolNames(resp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
}
@@ -2010,20 +2114,150 @@ func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
func TestRemapOAuthToolNames_Lowercase_ReverseApplied(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
out, renamed := remapOAuthToolNames(body)
if !renamed {
t.Fatalf("renamed = false, want true")
out, reverseMap := remapOAuthToolNames(body)
if reverseMap["Bash"] != "bash" {
t.Fatalf("reverseMap = %v, want entry Bash->bash", reverseMap)
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
}
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := resp
if renamed {
reversed = reverseRemapOAuthToolNames(resp)
}
reversed := reverseRemapOAuthToolNames(resp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "bash" {
t.Fatalf("content.0.name = %q, want %q", got, "bash")
}
}
// TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed is the regression
// test for a case where a single request contains both a TitleCase tool (which
// must pass through unchanged) and a lowercase tool that we forward-rename.
// Before the fix, triggering ANY forward rename caused the reverse pass to
// lowercase every TitleCase tool in the response using a global reverse map,
// corrupting tool names the client originally sent in TitleCase (notably Amp
// CLI's `Bash`, which its registry lookup cannot find as `bash`).
func TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed(t *testing.T) {
body := []byte(`{"tools":[` +
`{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
`{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
`]}`)
out, reverseMap := remapOAuthToolNames(body)
// Forward: TitleCase `Bash` is not a forward-map key, must pass through.
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q (TitleCase tool must not be renamed)", got, "Bash")
}
// Forward: `glob` is a forward-map key, upstream sees `Glob`.
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "Glob" {
t.Fatalf("tools.1.name = %q, want %q", got, "Glob")
}
// Reverse map records ONLY the rename that happened.
if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
}
// Upstream responds with a `Bash` tool_use. Since we never renamed `Bash`,
// reverseRemap MUST leave it alone.
bashResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := reverseRemapOAuthToolNames(bashResp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q (Bash must be preserved; was never forward-renamed)", got, "Bash")
}
// Upstream responds with a `Glob` tool_use. Since we renamed `glob`→`Glob`,
// reverseRemap MUST restore the original `glob`.
globResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_02","name":"Glob","input":{"filePattern":"**/*.go"}}]}`)
reversed = reverseRemapOAuthToolNames(globResp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "glob" {
t.Fatalf("content.0.name = %q, want %q (Glob must be restored to client's original `glob`)", got, "glob")
}
}
// TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap guards the
// SSE streaming code path against the same mixed-case bug.
func TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap(t *testing.T) {
reverseMap := map[string]string{"Glob": "glob"}
// Bash block was never renamed, must pass through as-is.
bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}}}`)
out := reverseRemapOAuthToolNamesFromStreamLine(bashLine, reverseMap)
if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
t.Fatalf("Bash should be preserved, got: %s", string(out))
}
if bytes.Contains(out, []byte(`"name":"bash"`)) {
t.Fatalf("Bash must not be lowercased, got: %s", string(out))
}
// Glob block IS in the reverseMap, must be restored to `glob`.
globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"Glob","input":{}}}`)
out = reverseRemapOAuthToolNamesFromStreamLine(globLine, reverseMap)
if !bytes.Contains(out, []byte(`"name":"glob"`)) {
t.Fatalf("Glob should be restored to glob, got: %s", string(out))
}
}
func TestPrepareClaudeOAuthToolNamesForUpstream_MixedCaseWithPrefix(t *testing.T) {
body := []byte(`{"tools":[` +
`{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
`{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
`],"messages":[{"role":"assistant","content":[` +
`{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}},` +
`{"type":"tool_use","id":"toolu_02","name":"glob","input":{}}` +
`]}]}`)
out, reverseMap := prepareClaudeOAuthToolNamesForUpstream(body, "proxy_", false)
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Bash")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Glob" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Glob")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Bash" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Bash")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Glob" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Glob")
}
if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
}
}
func TestRestoreClaudeOAuthToolNamesFromResponse_MixedCaseWithPrefix(t *testing.T) {
reverseMap := map[string]string{"Glob": "glob"}
resp := []byte(`{"content":[` +
`{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}},` +
`{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}` +
`]}`)
out := restoreClaudeOAuthToolNamesFromResponse(resp, "proxy_", false, reverseMap)
if got := gjson.GetBytes(out, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
}
if got := gjson.GetBytes(out, "content.1.name").String(); got != "glob" {
t.Fatalf("content.1.name = %q, want %q", got, "glob")
}
}
func TestRestoreClaudeOAuthToolNamesFromStreamLine_MixedCaseWithPrefix(t *testing.T) {
reverseMap := map[string]string{"Glob": "glob"}
bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}}}`)
out := restoreClaudeOAuthToolNamesFromStreamLine(bashLine, "proxy_", false, reverseMap)
if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
t.Fatalf("Bash should be preserved, got: %s", string(out))
}
if bytes.Contains(out, []byte(`"name":"bash"`)) {
t.Fatalf("Bash must not be lowercased, got: %s", string(out))
}
globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}}`)
out = restoreClaudeOAuthToolNamesFromStreamLine(globLine, "proxy_", false, reverseMap)
if !bytes.Contains(out, []byte(`"name":"glob"`)) {
t.Fatalf("Glob should be restored to glob, got: %s", string(out))
}
}
+11 -4
View File
@@ -30,8 +30,8 @@ import (
)
const (
codexUserAgent = "codex-tui/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9 (codex-tui; 0.118.0)"
codexOriginator = "codex-tui"
codexUserAgent = "codex_cli_rs/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9"
codexOriginator = "codex_cli_rs"
codexDefaultImageToolModel = "gpt-image-2"
)
@@ -515,13 +515,20 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, translatedLine, &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -188,7 +188,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body = normalizeCodexInstructions(body)
@@ -776,6 +775,11 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
parsed.Scheme = "ws"
case "https":
parsed.Scheme = "wss"
default:
return "", fmt.Errorf("codex websockets executor: unsupported responses websocket URL scheme %q", parsed.Scheme)
}
if strings.TrimSpace(parsed.Host) == "" {
return "", fmt.Errorf("codex websockets executor: responses websocket URL host is empty")
}
return parsed.String(), nil
}
@@ -809,6 +813,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
if cache.ID != "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
setHeaderCasePreserved(headers, "session_id", cache.ID)
headers.Set("Conversation_id", cache.ID)
}
@@ -828,13 +833,19 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
ginHeaders = ginCtx.Request.Header.Clone()
}
_, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
isAPIKey := codexAuthUsesAPIKey(auth)
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
misc.EnsureHeader(headers, ginHeaders, "Version", "")
if isAPIKey {
ensureHeaderWithPriority(headers, ginHeaders, "User-Agent", "", "")
} else {
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
}
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
if betaHeader == "" && ginHeaders != nil {
@@ -845,16 +856,9 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
}
headers.Set("OpenAI-Beta", betaHeader)
if strings.Contains(headers.Get("User-Agent"), "Mac OS") {
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
}
headers.Del("User-Agent")
isAPIKey := false
if auth != nil && auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
isAPIKey = true
}
ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", uuid.NewString())
}
ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", "")
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
headers.Set("Originator", originator)
} else if !isAPIKey {
@@ -864,7 +868,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
if auth != nil && auth.Metadata != nil {
if accountID, ok := auth.Metadata["account_id"].(string); ok {
if trimmed := strings.TrimSpace(accountID); trimmed != "" {
headers.Set("Chatgpt-Account-Id", trimmed)
setHeaderCasePreserved(headers, "ChatGPT-Account-ID", trimmed)
}
}
}
@@ -879,6 +883,77 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
return headers
}
func codexAuthUsesAPIKey(auth *cliproxyauth.Auth) bool {
if auth == nil || auth.Attributes == nil {
return false
}
return strings.TrimSpace(auth.Attributes["api_key"]) != ""
}
func ensureHeaderCasePreserved(target http.Header, source http.Header, key, configValue, fallbackValue string) {
if target == nil {
return
}
if strings.TrimSpace(headerValueCaseInsensitive(target, key)) != "" {
return
}
if source != nil {
if val := strings.TrimSpace(headerValueCaseInsensitive(source, key)); val != "" {
setHeaderCasePreserved(target, key, val)
return
}
}
if val := strings.TrimSpace(configValue); val != "" {
setHeaderCasePreserved(target, key, val)
return
}
if val := strings.TrimSpace(fallbackValue); val != "" {
setHeaderCasePreserved(target, key, val)
}
}
func setHeaderCasePreserved(headers http.Header, key string, value string) {
if headers == nil {
return
}
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" || value == "" {
return
}
deleteHeaderCaseInsensitive(headers, key)
headers[key] = []string{value}
}
func headerValueCaseInsensitive(headers http.Header, key string) string {
key = strings.TrimSpace(key)
if headers == nil || key == "" {
return ""
}
if val := strings.TrimSpace(headers.Get(key)); val != "" {
return val
}
for existingKey, values := range headers {
if !strings.EqualFold(existingKey, key) {
continue
}
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
}
return ""
}
func deleteHeaderCaseInsensitive(headers http.Header, key string) {
for existingKey := range headers {
if strings.EqualFold(existingKey, key) {
delete(headers, existingKey)
}
}
}
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
if cfg == nil || auth == nil {
return "", ""
@@ -962,25 +1037,55 @@ func parseCodexWebsocketError(payload []byte) (error, bool) {
return nil, false
}
out := []byte(`{}`)
if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
raw := errNode.Raw
if errNode.Type == gjson.String {
raw = errNode.Raw
}
out, _ = sjson.SetRawBytes(out, "error", []byte(raw))
} else {
out, _ = sjson.SetBytes(out, "error.type", "server_error")
out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
}
out := buildCodexWebsocketErrorPayload(payload, status)
headers := parseCodexWebsocketErrorHeaders(payload)
statusError := statusErr{code: status, msg: string(out)}
if retryAfter := parseCodexRetryAfter(status, out, time.Now()); retryAfter != nil {
statusError.retryAfter = retryAfter
} else if isCodexWebsocketConnectionLimitError(payload) {
retryAfter := time.Duration(0)
statusError.retryAfter = &retryAfter
}
return statusErrWithHeaders{
statusErr: statusErr{code: status, msg: string(out)},
statusErr: statusError,
headers: headers,
}, true
}
func buildCodexWebsocketErrorPayload(payload []byte, status int) []byte {
out := []byte(`{}`)
out, _ = sjson.SetBytes(out, "status", status)
if bodyNode := gjson.GetBytes(payload, "body"); bodyNode.Exists() {
out, _ = sjson.SetRawBytes(out, "body", []byte(bodyNode.Raw))
if bodyErrorNode := bodyNode.Get("error"); bodyErrorNode.Exists() {
out, _ = sjson.SetRawBytes(out, "error", []byte(bodyErrorNode.Raw))
return out
}
}
if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
out, _ = sjson.SetRawBytes(out, "error", []byte(errNode.Raw))
return out
}
out, _ = sjson.SetBytes(out, "error.type", "server_error")
out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
return out
}
func isCodexWebsocketConnectionLimitError(payload []byte) bool {
if len(payload) == 0 {
return false
}
for _, path := range []string{"error.code", "error.type", "body.error.code", "body.error.type", "code", "error"} {
if strings.TrimSpace(gjson.GetBytes(payload, path).String()) == "websocket_connection_limit_reached" {
return true
}
}
return false
}
func parseCodexWebsocketErrorHeaders(payload []byte) http.Header {
headersNode := gjson.GetBytes(payload, "headers")
if !headersNode.Exists() || !headersNode.IsObject() {
@@ -1,15 +1,21 @@
package executor
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
@@ -32,14 +38,80 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
}
}
func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T) {
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
capturedPayload := make(chan []byte, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
t.Fatalf("request path = %s, want /responses", r.URL.Path)
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Fatalf("upgrade websocket: %v", err)
}
defer func() { _ = conn.Close() }()
msgType, payload, err := conn.ReadMessage()
if err != nil {
t.Fatalf("read upstream websocket message: %v", err)
}
if msgType != websocket.TextMessage {
t.Fatalf("message type = %d, want text", msgType)
}
capturedPayload <- bytes.Clone(payload)
completed := []byte(`{"type":"response.completed","response":{"id":"resp-2","output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil {
t.Fatalf("write completed websocket message: %v", errWrite)
}
}))
defer server.Close()
exec := NewCodexWebsocketsExecutor(&config.Config{SDKConfig: config.SDKConfig{DisableImageGeneration: config.DisableImageGenerationAll}})
auth := &cliproxyauth.Auth{Attributes: map[string]string{"api_key": "sk-test", "base_url": server.URL}}
req := cliproxyexecutor.Request{
Model: "gpt-5-codex",
Payload: []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`),
}
opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("codex")}
if _, err := exec.Execute(context.Background(), auth, req, opts); err != nil {
t.Fatalf("Execute() error = %v", err)
}
select {
case payload := <-capturedPayload:
if got := gjson.GetBytes(payload, "type").String(); got != "response.create" {
t.Fatalf("upstream type = %s, want response.create; payload=%s", got, payload)
}
if got := gjson.GetBytes(payload, "previous_response_id").String(); got != "resp-1" {
t.Fatalf("upstream previous_response_id = %s, want resp-1; payload=%s", got, payload)
}
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for upstream websocket payload")
}
}
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
}
if !strings.HasPrefix(codexUserAgent, codexOriginator+"/") {
t.Fatalf("default Codex User-Agent = %s, want prefix %s/", codexUserAgent, codexOriginator)
}
if strings.HasPrefix(codexUserAgent, "codex-tui/") {
t.Fatalf("default Codex User-Agent = %s, must not use stale codex-tui prefix", codexUserAgent)
}
if strings.Contains(codexUserAgent, "(codex-tui;") {
t.Fatalf("default Codex User-Agent = %s, must not include stale codex-tui suffix", codexUserAgent)
}
if got := headers.Get("Originator"); got != codexOriginator {
t.Fatalf("Originator = %s, want %s", got, codexOriginator)
}
if got := headers.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
@@ -62,9 +134,11 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
}
ctx := contextWithGinHeaders(map[string]string{
"Originator": "Codex Desktop",
"User-Agent": "codex_cli_rs/0.1.0",
"Version": "0.115.0-alpha.27",
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
"session_id": "sess-client",
})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil)
@@ -72,6 +146,9 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
if got := headers.Get("Originator"); got != "Codex Desktop" {
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
}
if got := headers.Get("User-Agent"); got != "codex_cli_rs/0.1.0" {
t.Fatalf("User-Agent = %s, want %s", got, "codex_cli_rs/0.1.0")
}
if got := headers.Get("Version"); got != "0.115.0-alpha.27" {
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
}
@@ -81,6 +158,12 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
}
if got := headerValueCaseInsensitive(headers, "session_id"); got != "sess-client" {
t.Fatalf("session_id = %s, want sess-client", got)
}
if _, ok := headers["session_id"]; !ok {
t.Fatalf("expected lowercase session_id header key, got %#v", headers)
}
}
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
@@ -97,8 +180,8 @@ func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
}
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
@@ -129,8 +212,8 @@ func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
if gotVal := got.Get("User-Agent"); gotVal != "" {
t.Fatalf("User-Agent = %s, want empty", gotVal)
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
}
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
@@ -155,8 +238,8 @@ func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testi
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
if got := headers.Get("User-Agent"); got != "config-ua" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
}
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
@@ -183,6 +266,131 @@ func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
if got := headers.Get("Originator"); got != "" {
t.Fatalf("Originator = %s, want empty", got)
}
}
func TestApplyCodexWebsocketHeadersPreservesExplicitAPIKeyUserAgent(t *testing.T) {
auth := &cliproxyauth.Auth{Provider: "codex", Attributes: map[string]string{"api_key": "sk-test"}}
ctx := contextWithGinHeaders(map[string]string{"User-Agent": "api-key-client/1.0", "Originator": "explicit-origin"})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "sk-test", nil)
if got := headers.Get("User-Agent"); got != "api-key-client/1.0" {
t.Fatalf("User-Agent = %s, want api-key-client/1.0", got)
}
if got := headers.Get("Originator"); got != "explicit-origin" {
t.Fatalf("Originator = %s, want explicit-origin", got)
}
}
func TestApplyCodexPromptCacheHeadersSetsLowercaseSessionAndLegacyConversation(t *testing.T) {
req := cliproxyexecutor.Request{Model: "gpt-5-codex", Payload: []byte(`{"prompt_cache_key":"cache-1"}`)}
_, headers := applyCodexPromptCacheHeaders("openai-response", req, []byte(`{"model":"gpt-5-codex"}`))
if got := headerValueCaseInsensitive(headers, "session_id"); got != "cache-1" {
t.Fatalf("session_id = %s, want cache-1", got)
}
if _, ok := headers["session_id"]; !ok {
t.Fatalf("expected lowercase session_id key, got %#v", headers)
}
if got := headers.Get("Conversation_id"); got != "cache-1" {
t.Fatalf("Conversation_id = %s, want cache-1", got)
}
}
func TestApplyCodexWebsocketHeadersUsesCanonicalAccountHeader(t *testing.T) {
auth := &cliproxyauth.Auth{Provider: "codex", Metadata: map[string]any{"account_id": "acct-1"}}
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", nil)
if got := headerValueCaseInsensitive(headers, "ChatGPT-Account-ID"); got != "acct-1" {
t.Fatalf("ChatGPT-Account-ID = %s, want acct-1", got)
}
values, ok := headers["ChatGPT-Account-ID"]
if !ok {
t.Fatalf("expected exact ChatGPT-Account-ID key, got %#v", headers)
}
if len(values) != 1 || values[0] != "acct-1" {
t.Fatalf("ChatGPT-Account-ID values = %#v, want [acct-1]", values)
}
}
func TestBuildCodexResponsesWebsocketURLRequiresHTTPURL(t *testing.T) {
if got, err := buildCodexResponsesWebsocketURL("https://example.com/backend/responses"); err != nil || got != "wss://example.com/backend/responses" {
t.Fatalf("https URL = %q, %v; want wss URL", got, err)
}
if _, err := buildCodexResponsesWebsocketURL("ftp://example.com/responses"); err == nil {
t.Fatalf("expected unsupported scheme error")
}
if _, err := buildCodexResponsesWebsocketURL("https:///responses"); err == nil {
t.Fatalf("expected empty host error")
}
}
func TestParseCodexWebsocketErrorMarksConnectionLimitRetryable(t *testing.T) {
err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"error":{"code":"websocket_connection_limit_reached","message":"too many websockets"},"headers":{"retry-after":"1"}}`))
if !ok {
t.Fatalf("expected websocket error")
}
status, ok := err.(interface{ StatusCode() int })
if !ok || status.StatusCode() != http.StatusTooManyRequests {
t.Fatalf("status = %#v, want 429", err)
}
retryable, ok := err.(interface{ RetryAfter() *time.Duration })
if !ok || retryable.RetryAfter() == nil {
t.Fatalf("expected retryable websocket connection limit error")
}
if got := *retryable.RetryAfter(); got != 0 {
t.Fatalf("retryAfter = %v, want connection-limit fallback 0", got)
}
withHeaders, ok := err.(interface{ Headers() http.Header })
if !ok || withHeaders.Headers().Get("retry-after") != "1" {
t.Fatalf("headers = %#v, want retry-after", err)
}
}
func TestParseCodexWebsocketErrorUsesUsageLimitRetryMetadata(t *testing.T) {
err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"body":{"error":{"type":"usage_limit_reached","message":"usage limit reached","resets_in_seconds":7}}}`))
if !ok {
t.Fatalf("expected websocket error")
}
retryable, ok := err.(interface{ RetryAfter() *time.Duration })
if !ok || retryable.RetryAfter() == nil {
t.Fatalf("expected retryable usage limit websocket error")
}
if got := *retryable.RetryAfter(); got != 7*time.Second {
t.Fatalf("retryAfter = %v, want 7s", got)
}
}
func TestParseCodexWebsocketErrorPreservesWrappedBodyAndHeaders(t *testing.T) {
err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"body":{"error":{"code":"websocket_connection_limit_reached","type":"server_error","message":"too many websocket connections"}},"headers":{"x-request-id":"req-1"}}`))
if !ok {
t.Fatalf("expected websocket error")
}
parsed := gjson.Parse(err.Error())
if got := parsed.Get("status").Int(); got != http.StatusTooManyRequests {
t.Fatalf("wrapped status = %d, want 429; payload=%s", got, err.Error())
}
if got := parsed.Get("body.error.code").String(); got != "websocket_connection_limit_reached" {
t.Fatalf("wrapped body error code = %s, want websocket_connection_limit_reached; payload=%s", got, err.Error())
}
if got := parsed.Get("error.code").String(); got != "websocket_connection_limit_reached" {
t.Fatalf("surface error code = %s, want websocket_connection_limit_reached; payload=%s", got, err.Error())
}
retryable, ok := err.(interface{ RetryAfter() *time.Duration })
if !ok || retryable.RetryAfter() == nil {
t.Fatalf("expected body.error.code websocket connection limit to be retryable")
}
withHeaders, ok := err.(interface{ Headers() http.Header })
if !ok || withHeaders.Headers().Get("x-request-id") != "req-1" {
t.Fatalf("headers = %#v, want x-request-id", err)
}
}
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
@@ -411,19 +411,30 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
case <-ctx.Done():
return
}
}
}
}
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
return
}
reporter.EnsurePublished(ctx)
@@ -434,7 +445,10 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errRead}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errRead}:
case <-ctx.Done():
}
return
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
@@ -442,12 +456,20 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
case <-ctx.Done():
return
}
}
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
case <-ctx.Done():
return
}
}
}(httpResp, append([]byte(nil), payload...), attemptModel)
+14 -3
View File
@@ -324,17 +324,28 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -338,6 +338,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
}
action := getVertexAction(baseModel, false)
@@ -459,6 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
action := getVertexAction(baseModel, false)
if req.Metadata != nil {
@@ -570,6 +572,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
action := getVertexAction(baseModel, true)
baseURL := vertexBaseURL(location)
@@ -656,17 +659,28 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -700,6 +714,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
action := getVertexAction(baseModel, true)
// For API key auth, use simpler URL format without project/location
@@ -786,17 +801,28 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -818,6 +844,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
translatedReq = helps.StripVertexOpenAIResponsesToolCallIDs(translatedReq, from.String())
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -907,6 +934,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
translatedReq = helps.StripVertexOpenAIResponsesToolCallIDs(translatedReq, from.String())
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -18,6 +18,7 @@ import (
type UsageReporter struct {
provider string
model string
alias string
authID string
authIndex string
authType string
@@ -29,9 +30,14 @@ type UsageReporter struct {
func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
apiKey := APIKeyFromContext(ctx)
alias := usage.RequestedModelAliasFromContext(ctx)
if alias == "" {
alias = model
}
reporter := &UsageReporter{
provider: provider,
model: model,
alias: strings.TrimSpace(alias),
requestedAt: time.Now(),
apiKey: apiKey,
source: resolveUsageSource(auth, apiKey),
@@ -139,6 +145,7 @@ func (r *UsageReporter) buildRecordForModel(model string, detail usage.Detail, f
return usage.Record{
Provider: r.provider,
Model: model,
Alias: r.alias,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
@@ -1,6 +1,7 @@
package helps
import (
"context"
"testing"
"time"
@@ -107,6 +108,19 @@ func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
}
}
func TestUsageReporterBuildRecordIncludesRequestedModelAlias(t *testing.T) {
ctx := usage.WithRequestedModelAlias(context.Background(), "client-gpt")
reporter := NewUsageReporter(ctx, "openai", "gpt-5.4", nil)
record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
if record.Model != "gpt-5.4" {
t.Fatalf("model = %q, want %q", record.Model, "gpt-5.4")
}
if record.Alias != "client-gpt" {
t.Fatalf("alias = %q, want %q", record.Alias, "client-gpt")
}
}
func TestUsageReporterBuildAdditionalModelRecordSkipsZeroTokens(t *testing.T) {
reporter := &UsageReporter{
provider: "codex",
@@ -0,0 +1,43 @@
package helps
import (
"fmt"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// StripVertexOpenAIResponsesToolCallIDs removes OpenAI Responses call IDs that
// Vertex rejects in Gemini functionCall/functionResponse payloads.
func StripVertexOpenAIResponsesToolCallIDs(payload []byte, sourceFormat string) []byte {
if !strings.EqualFold(strings.TrimSpace(sourceFormat), "openai-response") {
return payload
}
contents := gjson.GetBytes(payload, "contents")
if !contents.IsArray() {
return payload
}
out := payload
for contentIndex, content := range contents.Array() {
parts := content.Get("parts")
if !parts.IsArray() {
continue
}
for partIndex, part := range parts.Array() {
if part.Get("functionCall.id").Exists() {
if updated, errDelete := sjson.DeleteBytes(out, fmt.Sprintf("contents.%d.parts.%d.functionCall.id", contentIndex, partIndex)); errDelete == nil {
out = updated
}
}
if part.Get("functionResponse.id").Exists() {
if updated, errDelete := sjson.DeleteBytes(out, fmt.Sprintf("contents.%d.parts.%d.functionResponse.id", contentIndex, partIndex)); errDelete == nil {
out = updated
}
}
}
}
return out
}
+115 -5
View File
@@ -290,17 +290,28 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range doneChunks {
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -322,7 +333,17 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
return body, nil
}
out := body
msgs := messages.Array()
out, dropped, err := filterKimiEmptyAssistantMessages(body, msgs)
if err != nil {
return body, err
}
if dropped > 0 {
log.WithField("dropped_assistant_messages", dropped).Debug("kimi executor: dropped empty assistant messages")
}
messages = gjson.GetBytes(out, "messages")
msgs = messages.Array()
pending := make([]string, 0)
patched := 0
patchedReasoning := 0
@@ -340,7 +361,6 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
}
}
msgs := messages.Array()
for msgIdx := range msgs {
msg := msgs[msgIdx]
role := strings.TrimSpace(msg.Get("role").String())
@@ -428,6 +448,96 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
return out, nil
}
func filterKimiEmptyAssistantMessages(body []byte, msgs []gjson.Result) ([]byte, int, error) {
kept := make([]string, 0, len(msgs))
dropped := 0
for _, msg := range msgs {
if shouldDropKimiAssistantMessage(msg) {
dropped++
continue
}
kept = append(kept, msg.Raw)
}
if dropped == 0 {
return body, 0, nil
}
rawMessages := []byte("[" + strings.Join(kept, ",") + "]")
out, err := sjson.SetRawBytes(body, "messages", rawMessages)
if err != nil {
return body, 0, fmt.Errorf("kimi executor: failed to drop empty assistant messages: %w", err)
}
return out, dropped, nil
}
func shouldDropKimiAssistantMessage(msg gjson.Result) bool {
if strings.TrimSpace(msg.Get("role").String()) != "assistant" {
return false
}
if hasKimiToolCalls(msg) || hasKimiLegacyFunctionCall(msg) || hasKimiAssistantReasoning(msg) {
return false
}
return isKimiAssistantContentEmpty(msg.Get("content"))
}
func hasKimiToolCalls(msg gjson.Result) bool {
toolCalls := msg.Get("tool_calls")
return toolCalls.Exists() && toolCalls.IsArray() && len(toolCalls.Array()) > 0
}
func hasKimiLegacyFunctionCall(msg gjson.Result) bool {
functionCall := msg.Get("function_call")
if !functionCall.Exists() || functionCall.Type == gjson.Null {
return false
}
if functionCall.IsObject() && strings.TrimSpace(functionCall.Raw) == "{}" {
return false
}
return strings.TrimSpace(functionCall.Raw) != ""
}
func hasKimiAssistantReasoning(msg gjson.Result) bool {
reasoning := msg.Get("reasoning_content")
return reasoning.Exists() && strings.TrimSpace(reasoning.String()) != ""
}
func isKimiAssistantContentEmpty(content gjson.Result) bool {
if !content.Exists() || content.Type == gjson.Null {
return true
}
if content.Type == gjson.String {
return strings.TrimSpace(content.String()) == ""
}
if !content.IsArray() {
return false
}
for _, part := range content.Array() {
if !isKimiAssistantContentPartEmpty(part) {
return false
}
}
return true
}
func isKimiAssistantContentPartEmpty(part gjson.Result) bool {
if !part.Exists() || part.Type == gjson.Null {
return true
}
if part.Type == gjson.String {
return strings.TrimSpace(part.String()) == ""
}
if !part.IsObject() {
return false
}
if text := part.Get("text"); text.Exists() {
return strings.TrimSpace(text.String()) == ""
}
if strings.TrimSpace(part.Get("type").String()) == "text" {
return true
}
return strings.TrimSpace(part.Raw) == "{}"
}
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
if hasLatest && strings.TrimSpace(latest) != "" {
return latest
@@ -203,3 +203,70 @@ func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "r1")
}
}
func TestNormalizeKimiToolMessageLinks_DropsEmptyAssistantWithoutToolLink(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"user","content":"start"},
{"role":"assistant","content":""},
{"role":"assistant","content":" "},
{"role":"assistant","content":"","tool_calls":null},
{"role":"assistant","content":[{"type":"text","text":" "}]},
{"role":"assistant"},
{"role":"assistant","content":"keep"},
{"role":"user","content":"next"}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
messages := gjson.GetBytes(out, "messages").Array()
if len(messages) != 3 {
t.Fatalf("messages length = %d, want 3, raw = %s", len(messages), gjson.GetBytes(out, "messages").Raw)
}
if got := messages[0].Get("content").String(); got != "start" {
t.Fatalf("messages.0.content = %q, want %q", got, "start")
}
if got := messages[1].Get("content").String(); got != "keep" {
t.Fatalf("messages.1.content = %q, want %q", got, "keep")
}
if got := messages[2].Get("content").String(); got != "next" {
t.Fatalf("messages.2.content = %q, want %q", got, "next")
}
}
func TestNormalizeKimiToolMessageLinks_PreservesAssistantWithToolLinkOrReasoning(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
{"role":"assistant","content":"","function_call":{"name":"legacy_call","arguments":"{}"}},
{"role":"assistant","content":"","reasoning_content":"thought"},
{"role":"assistant","content":[{"type":"text","text":" visible "}]}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
messages := gjson.GetBytes(out, "messages").Array()
if len(messages) != 4 {
t.Fatalf("messages length = %d, want 4, raw = %s", len(messages), gjson.GetBytes(out, "messages").Raw)
}
if !messages[0].Get("tool_calls").Exists() {
t.Fatalf("messages.0.tool_calls should exist")
}
if !messages[1].Get("function_call").Exists() {
t.Fatalf("messages.1.function_call should exist")
}
if got := messages[2].Get("reasoning_content").String(); got != "thought" {
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "thought")
}
if got := messages[3].Get("content.0.text").String(); got != " visible " {
t.Fatalf("messages.3.content.0.text = %q, want %q", got, " visible ")
}
}
@@ -96,6 +96,12 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
requestPath := helps.PayloadRequestPath(opts)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
@@ -105,11 +111,6 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
}
}
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
url := strings.TrimSuffix(baseURL, "/") + endpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
@@ -199,15 +200,16 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
requestPath := helps.PayloadRequestPath(opts)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
requestPath := helps.PayloadRequestPath(opts)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
// Request usage data in the final streaming chunk so that token statistics
// are captured even when the upstream is an OpenAI-compatible provider.
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
@@ -281,32 +283,57 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if len(line) == 0 {
trimmedLine := bytes.TrimSpace(line)
if len(trimmedLine) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
if !bytes.HasPrefix(trimmedLine, []byte("data:")) {
if bytes.HasPrefix(trimmedLine, []byte(":")) || bytes.HasPrefix(trimmedLine, []byte("event:")) ||
bytes.HasPrefix(trimmedLine, []byte("id:")) || bytes.HasPrefix(trimmedLine, []byte("retry:")) {
continue
}
if bytes.HasPrefix(trimmedLine, []byte("{")) || bytes.HasPrefix(trimmedLine, []byte("[")) {
streamErr := statusErr{code: http.StatusBadGateway, msg: string(trimmedLine)}
helps.RecordAPIResponseError(ctx, e.cfg, streamErr)
reporter.PublishFailure(ctx)
select {
case out <- cliproxyexecutor.StreamChunk{Err: streamErr}:
case <-ctx.Done():
}
return
}
continue
}
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
// Pass through translator; it yields one or more chunks for the target schema.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
// OpenAI-compatible streams must use SSE data lines.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(trimmedLine), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
} else {
// In case the upstream close the stream without a terminal [DONE] marker.
// Feed a synthetic done marker through the translator so pending
// response.completed events are still emitted exactly once.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
// Ensure we record the request if no usage chunk was ever seen
@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -56,3 +57,125 @@ func TestOpenAICompatExecutorCompactPassthrough(t *testing.T) {
t.Fatalf("payload = %s", string(resp.Payload))
}
}
func TestOpenAICompatExecutorPayloadOverrideWinsOverThinkingSuffix(t *testing.T) {
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
gotBody = body
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"chatcmpl_1","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`))
}))
defer server.Close()
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{
Payload: config.PayloadConfig{
Override: []config.PayloadRule{
{
Models: []config.PayloadModelRule{
{Name: "custom-openai", Protocol: "openai"},
},
Params: map[string]any{
"reasoning_effort": "low",
},
},
},
},
})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL + "/v1",
"api_key": "test",
}}
payload := []byte(`{"model":"custom-openai(high)","messages":[{"role":"user","content":"hi"}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "custom-openai(high)",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
Stream: false,
})
if err != nil {
t.Fatalf("Execute error: %v", err)
}
if got := gjson.GetBytes(gotBody, "reasoning_effort").String(); got != "low" {
t.Fatalf("reasoning_effort = %q, want %q; body=%s", got, "low", string(gotBody))
}
}
func TestOpenAICompatExecutorStreamRejectsPlainJSONAfterBlankLines(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("\n\n: openrouter processing\n\nevent: error\n"))
_, _ = w.Write([]byte(`{"error":{"message":"upstream failed","type":"server_error"}}` + "\n"))
}))
defer server.Close()
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL + "/v1",
"api_key": "test",
}}
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "openrouter-model",
Payload: []byte(`{"model":"openrouter-model","messages":[{"role":"user","content":"hi"}],"stream":true}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
Stream: true,
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
var gotErr error
for chunk := range result.Chunks {
if chunk.Err != nil {
gotErr = chunk.Err
break
}
}
if gotErr == nil {
t.Fatalf("expected plain JSON stream error")
}
if status, ok := gotErr.(interface{ StatusCode() int }); !ok || status.StatusCode() != http.StatusBadGateway {
t.Fatalf("stream error status = %v, want %d", gotErr, http.StatusBadGateway)
}
if !strings.Contains(gotErr.Error(), "upstream failed") {
t.Fatalf("stream error = %v", gotErr)
}
}
func TestOpenAICompatExecutorStreamSkipsKeepAliveUntilDataLine(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("\n\n: openrouter processing\n\nevent: ping\nid: 1\nretry: 1000\n"))
_, _ = w.Write([]byte(`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hello"},"finish_reason":null}]}` + "\n"))
}))
defer server.Close()
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL + "/v1",
"api_key": "test",
}}
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "openrouter-model",
Payload: []byte(`{"model":"openrouter-model","messages":[{"role":"user","content":"hi"}],"stream":true}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
Stream: true,
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
var got strings.Builder
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected stream error: %v", chunk.Err)
}
got.Write(chunk.Payload)
}
if gjson.Get(got.String(), "choices.0.delta.content").String() != "hello" {
t.Fatalf("stream payload = %s", got.String())
}
}
@@ -25,10 +25,19 @@ type ConvertAnthropicResponseToOpenAIParams struct {
CreatedAt int64
ResponseID string
FinishReason string
Usage claudeUsageTokens
// Tool calls accumulator for streaming
ToolCallsAccumulator map[int]*ToolCallAccumulator
}
type claudeUsageTokens struct {
InputTokens int64
OutputTokens int64
CacheCreationInputTokens int64
CacheReadInputTokens int64
HasUsage bool
}
// ToolCallAccumulator holds the state for accumulating tool call data
type ToolCallAccumulator struct {
ID string
@@ -36,15 +45,30 @@ type ToolCallAccumulator struct {
Arguments strings.Builder
}
func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
inputTokens := usage.Get("input_tokens").Int()
completionTokens = usage.Get("output_tokens").Int()
cachedTokens = usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
func (u *claudeUsageTokens) Merge(usage gjson.Result) {
if !usage.Exists() {
return
}
u.HasUsage = true
if inputTokens := usage.Get("input_tokens"); inputTokens.Exists() {
u.InputTokens = inputTokens.Int()
}
if outputTokens := usage.Get("output_tokens"); outputTokens.Exists() {
u.OutputTokens = outputTokens.Int()
}
if cacheCreationInputTokens := usage.Get("cache_creation_input_tokens"); cacheCreationInputTokens.Exists() {
u.CacheCreationInputTokens = cacheCreationInputTokens.Int()
}
if cacheReadInputTokens := usage.Get("cache_read_input_tokens"); cacheReadInputTokens.Exists() {
u.CacheReadInputTokens = cacheReadInputTokens.Int()
}
}
promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
func (u claudeUsageTokens) OpenAIUsage() (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
cachedTokens = u.CacheReadInputTokens
promptTokens = u.InputTokens + u.CacheCreationInputTokens + cachedTokens
completionTokens = u.OutputTokens
totalTokens = promptTokens + completionTokens
return promptTokens, completionTokens, totalTokens, cachedTokens
}
@@ -112,6 +136,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
}
(*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(message.Get("usage"))
}
return [][]byte{template}
@@ -215,7 +240,8 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
// Handle usage information for token counts
if usage := root.Get("usage"); usage.Exists() {
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
(*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(usage)
promptTokens, completionTokens, totalTokens, cachedTokens := (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.OpenAIUsage()
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
@@ -296,6 +322,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
var stopReason string
var contentParts []string
var reasoningParts []string
usageTokens := claudeUsageTokens{}
toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
for _, chunk := range chunks {
@@ -309,6 +336,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
messageID = message.Get("id").String()
model = message.Get("model").String()
createdAt = time.Now().Unix()
usageTokens.Merge(message.Get("usage"))
}
case "content_block_start":
@@ -371,14 +399,18 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
}
if usage := root.Get("usage"); usage.Exists() {
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
usageTokens.Merge(usage)
}
}
}
if usageTokens.HasUsage {
promptTokens, completionTokens, totalTokens, cachedTokens := usageTokens.OpenAIUsage()
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
}
}
}
// Set basic response fields including message ID, creation time, and model
out, _ = sjson.SetBytes(out, "id", messageID)
@@ -37,6 +37,44 @@ func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testin
}
}
func TestConvertClaudeResponseToOpenAI_StreamUsageMergesMessageStartUsage(t *testing.T) {
ctx := context.Background()
var param any
ConvertClaudeResponseToOpenAI(
ctx,
"claude-opus-4-6",
nil,
nil,
[]byte(`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-opus-4-6","usage":{"input_tokens":13,"output_tokens":1,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}}`),
&param,
)
out := ConvertClaudeResponseToOpenAI(
ctx,
"claude-opus-4-6",
nil,
nil,
[]byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":4}}`),
&param,
)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
}
if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
}
if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
}
if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
@@ -56,3 +94,23 @@ func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *tes
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
func TestConvertClaudeResponseToOpenAINonStream_UsageMergesMessageStartUsage(t *testing.T) {
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\",\"usage\":{\"input_tokens\":13,\"output_tokens\":1,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}}\n" +
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":4}}\n")
out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
}
if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
}
if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
}
if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
@@ -339,25 +339,21 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
})
}
includedToolNames := map[string]struct{}{}
toolNameMap := map[string]string{}
// tools mapping: parameters -> input_schema
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
toolsJSON := []byte("[]")
tools.ForEach(func(_, tool gjson.Result) bool {
tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
if n := tool.Get("name"); n.Exists() {
tJSON, _ = sjson.SetBytes(tJSON, "name", n.String())
convertedTools := convertResponsesToolToClaudeTools(tool, toolNameMap)
for _, tJSON := range convertedTools {
toolName := gjson.GetBytes(tJSON, "name").String()
if toolName != "" {
includedToolNames[toolName] = struct{}{}
}
if d := tool.Get("description"); d.Exists() {
tJSON, _ = sjson.SetBytes(tJSON, "description", d.String())
}
if params := tool.Get("parameters"); params.Exists() {
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
} else if params = tool.Get("parametersJsonSchema"); params.Exists() {
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
}
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
}
return true
})
if parsedTools := gjson.ParseBytes(toolsJSON); parsedTools.IsArray() && len(parsedTools.Array()) > 0 {
@@ -375,15 +371,25 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
case "none":
// Leave unset; implies no tools
case "required":
if len(includedToolNames) > 0 {
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
}
}
case gjson.JSON:
if toolChoice.Get("type").String() == "function" {
fn := toolChoice.Get("function.name").String()
if fn == "" {
fn = toolChoice.Get("name").String()
}
if mappedName := toolNameMap[fn]; mappedName != "" {
fn = mappedName
}
if _, ok := includedToolNames[fn]; ok {
toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
}
}
default:
}
@@ -391,3 +397,167 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
return out
}
func convertResponsesToolToClaudeTools(tool gjson.Result, toolNameMap map[string]string) [][]byte {
toolType := strings.TrimSpace(tool.Get("type").String())
switch toolType {
case "", "function":
if tJSON, ok := convertResponsesFunctionToolToClaude(tool, ""); ok {
return [][]byte{tJSON}
}
case "namespace":
return convertResponsesNamespaceToolToClaude(tool, toolNameMap)
case "web_search":
if tJSON, ok := convertResponsesWebSearchToolToClaude(tool); ok {
if name := gjson.GetBytes(tJSON, "name").String(); name != "" {
toolNameMap[name] = name
}
return [][]byte{tJSON}
}
default:
if isUnsupportedOpenAIBuiltinToolType(toolType) {
return nil
}
if tool.Get("name").String() != "" {
return [][]byte{[]byte(tool.Raw)}
}
}
return nil
}
func convertResponsesNamespaceToolToClaude(tool gjson.Result, toolNameMap map[string]string) [][]byte {
namespaceName := strings.TrimSpace(tool.Get("name").String())
children := tool.Get("tools")
if !children.Exists() || !children.IsArray() {
return nil
}
var out [][]byte
children.ForEach(func(_, child gjson.Result) bool {
childName := responsesToolName(child)
qualifiedName := qualifyResponsesNamespaceToolName(namespaceName, childName)
if tJSON, ok := convertResponsesFunctionToolToClaude(child, qualifiedName); ok {
out = append(out, tJSON)
toolNameMap[qualifiedName] = qualifiedName
if childName != "" {
toolNameMap[childName] = qualifiedName
}
}
return true
})
return out
}
func convertResponsesFunctionToolToClaude(tool gjson.Result, overrideName string) ([]byte, bool) {
name := strings.TrimSpace(overrideName)
if name == "" {
name = responsesToolName(tool)
}
if name == "" {
return nil, false
}
tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
tJSON, _ = sjson.SetBytes(tJSON, "name", name)
if d := responsesToolDescription(tool); d != "" {
tJSON, _ = sjson.SetBytes(tJSON, "description", d)
}
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", normalizeClaudeToolInputSchema(responsesToolParameters(tool)))
return tJSON, true
}
func convertResponsesWebSearchToolToClaude(tool gjson.Result) ([]byte, bool) {
if externalWebAccess := tool.Get("external_web_access"); externalWebAccess.Exists() && !externalWebAccess.Bool() {
return nil, false
}
name := strings.TrimSpace(tool.Get("name").String())
if name == "" {
name = "web_search"
}
tJSON := []byte(`{"type":"web_search_20250305","name":""}`)
tJSON, _ = sjson.SetBytes(tJSON, "name", name)
if maxUses := tool.Get("max_uses"); maxUses.Exists() {
tJSON, _ = sjson.SetBytes(tJSON, "max_uses", maxUses.Int())
}
if allowedDomains := tool.Get("filters.allowed_domains"); allowedDomains.Exists() && allowedDomains.IsArray() {
tJSON, _ = sjson.SetRawBytes(tJSON, "allowed_domains", []byte(allowedDomains.Raw))
}
if userLocation := tool.Get("user_location"); userLocation.Exists() && userLocation.IsObject() {
tJSON, _ = sjson.SetRawBytes(tJSON, "user_location", []byte(userLocation.Raw))
}
return tJSON, true
}
func responsesToolName(tool gjson.Result) string {
if name := strings.TrimSpace(tool.Get("name").String()); name != "" {
return name
}
return strings.TrimSpace(tool.Get("function.name").String())
}
func responsesToolDescription(tool gjson.Result) string {
if description := tool.Get("description").String(); description != "" {
return description
}
return tool.Get("function.description").String()
}
func responsesToolParameters(tool gjson.Result) gjson.Result {
for _, path := range []string{
"parameters",
"parametersJsonSchema",
"input_schema",
"function.parameters",
"function.parametersJsonSchema",
} {
if parameters := tool.Get(path); parameters.Exists() {
return parameters
}
}
return gjson.Result{}
}
func normalizeClaudeToolInputSchema(parameters gjson.Result) []byte {
raw := strings.TrimSpace(parameters.Raw)
if raw == "" || raw == "null" || !gjson.Valid(raw) {
return []byte(`{"type":"object","properties":{}}`)
}
result := gjson.Parse(raw)
if !result.IsObject() {
return []byte(`{"type":"object","properties":{}}`)
}
schema := []byte(raw)
schemaType := result.Get("type").String()
if schemaType == "" {
schema, _ = sjson.SetBytes(schema, "type", "object")
schemaType = "object"
}
if schemaType == "object" && !result.Get("properties").Exists() {
schema, _ = sjson.SetRawBytes(schema, "properties", []byte(`{}`))
}
return schema
}
func qualifyResponsesNamespaceToolName(namespaceName, childName string) string {
childName = strings.TrimSpace(childName)
if childName == "" || namespaceName == "" || strings.HasPrefix(childName, "mcp__") {
return childName
}
if strings.HasPrefix(childName, namespaceName) {
return childName
}
if strings.HasSuffix(namespaceName, "__") {
return namespaceName + childName
}
return namespaceName + "__" + childName
}
func isUnsupportedOpenAIBuiltinToolType(toolType string) bool {
switch toolType {
case "image_generation", "file_search", "code_interpreter", "computer_use_preview":
return true
default:
return false
}
}
@@ -27,6 +27,7 @@ type claudeToResponsesState struct {
FuncCallIDs map[int]string // index -> call id
// message text aggregation
TextBuf strings.Builder
CurrentTextBuf strings.Builder
// reasoning state
ReasoningActive bool
ReasoningItemID string
@@ -80,6 +81,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
st.CreatedAt = time.Now().Unix()
// Reset per-message aggregation state
st.TextBuf.Reset()
st.CurrentTextBuf.Reset()
st.ReasoningBuf.Reset()
st.ReasoningActive = false
st.InTextBlock = false
@@ -128,6 +130,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
if typ == "text" {
// open message item + content part
st.InTextBlock = true
st.CurrentTextBuf.Reset()
st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
@@ -189,6 +192,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
out = append(out, emitEvent("response.output_text.delta", msg))
// aggregate text for response.output
st.TextBuf.WriteString(t.String())
st.CurrentTextBuf.WriteString(t.String())
}
} else if dt == "input_json_delta" {
idx := int(root.Get("index").Int())
@@ -220,17 +224,21 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
case "content_block_stop":
idx := int(root.Get("index").Int())
if st.InTextBlock {
fullText := st.CurrentTextBuf.String()
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
done, _ = sjson.SetBytes(done, "item_id", st.CurrentMsgID)
done, _ = sjson.SetBytes(done, "text", fullText)
out = append(out, emitEvent("response.output_text.done", done))
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", st.CurrentMsgID)
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
out = append(out, emitEvent("response.content_part.done", partDone))
final := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`)
final, _ = sjson.SetBytes(final, "sequence_number", nextSeq())
final, _ = sjson.SetBytes(final, "item.id", st.CurrentMsgID)
final, _ = sjson.SetBytes(final, "item.content.0.text", fullText)
out = append(out, emitEvent("response.output_item.done", final))
st.InTextBlock = false
} else if st.InFuncBlock {
@@ -0,0 +1,78 @@
package geminiCLI
import (
"testing"
"github.com/tidwall/gjson"
)
func TestConvertGeminiCLIRequestToCodex_PreservesSchemaPropertyNamedType(t *testing.T) {
input := []byte(`{
"request": {
"tools": [
{
"functionDeclarations": [
{
"name": "ask_user",
"description": "Ask the user one or more questions.",
"parametersJsonSchema": {
"type": "object",
"properties": {
"questions": {
"type": "array",
"items": {
"type": "object",
"properties": {
"header": {
"type": "string"
},
"type": {
"default": "choice",
"description": "Question type.",
"enum": [
"choice",
"text",
"yesno"
],
"type": "string"
}
},
"required": [
"question",
"header",
"type"
]
}
}
},
"required": [
"questions"
]
}
}
]
}
]
}
}`)
out := ConvertGeminiCLIRequestToCodex("gpt-5.2", input, true)
tool := gjson.GetBytes(out, "tools.0")
if got := tool.Get("type").String(); got != "function" {
t.Fatalf("expected tool type %q, got %q; output=%s", "function", got, string(out))
}
typeProperty := tool.Get("parameters.properties.questions.items.properties.type")
if !typeProperty.IsObject() {
t.Fatalf("expected schema property named type to stay an object; output=%s", string(out))
}
if got := typeProperty.Get("type").String(); got != "string" {
t.Fatalf("expected schema property type %q, got %q; output=%s", "string", got, string(out))
}
if got := typeProperty.Get("default").String(); got != "choice" {
t.Fatalf("expected default %q, got %q; output=%s", "choice", got, string(out))
}
if got := typeProperty.Get("enum.2").String(); got != "yesno" {
t.Fatalf("expected enum value %q, got %q; output=%s", "yesno", got, string(out))
}
}
@@ -284,7 +284,11 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
util.Walk(toolsResult, "", "type", &pathsToLower)
for _, p := range pathsToLower {
fullPath := fmt.Sprintf("tools.%s", p)
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
typeValue := gjson.GetBytes(out, fullPath)
if typeValue.Type != gjson.String {
continue
}
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(typeValue.String()))
}
return out
@@ -121,13 +121,13 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
case "tool":
// Handle tool response messages as top-level function_call_output objects
toolCallID := m.Get("tool_call_id").String()
content := m.Get("content").String()
content := m.Get("content")
// Create function_call_output object
funcOutput := []byte(`{}`)
funcOutput, _ = sjson.SetBytes(funcOutput, "type", "function_call_output")
funcOutput, _ = sjson.SetBytes(funcOutput, "call_id", toolCallID)
funcOutput, _ = sjson.SetBytes(funcOutput, "output", content)
funcOutput = setToolCallOutputContent(funcOutput, content)
out, _ = sjson.SetRawBytes(out, "input.-1", funcOutput)
default:
@@ -359,6 +359,91 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
return out
}
func setToolCallOutputContent(funcOutput []byte, content gjson.Result) []byte {
switch {
case content.Type == gjson.String:
funcOutput, _ = sjson.SetBytes(funcOutput, "output", content.String())
case content.IsArray():
output := []byte(`[]`)
for _, item := range content.Array() {
output = appendToolOutputContentPart(output, item)
}
funcOutput, _ = sjson.SetRawBytes(funcOutput, "output", output)
default:
fallbackOutput := content.Raw
if fallbackOutput == "" {
fallbackOutput = content.String()
}
funcOutput, _ = sjson.SetBytes(funcOutput, "output", fallbackOutput)
}
return funcOutput
}
func appendToolOutputContentPart(output []byte, item gjson.Result) []byte {
switch item.Get("type").String() {
case "text":
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_text")
part, _ = sjson.SetBytes(part, "text", item.Get("text").String())
output, _ = sjson.SetRawBytes(output, "-1", part)
case "image_url":
imageURL := item.Get("image_url.url").String()
fileID := item.Get("image_url.file_id").String()
if imageURL == "" && fileID == "" {
return appendToolOutputFallbackPart(output, item)
}
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_image")
if imageURL != "" {
part, _ = sjson.SetBytes(part, "image_url", imageURL)
}
if fileID != "" {
part, _ = sjson.SetBytes(part, "file_id", fileID)
}
if detail := item.Get("image_url.detail").String(); detail != "" {
part, _ = sjson.SetBytes(part, "detail", detail)
}
output, _ = sjson.SetRawBytes(output, "-1", part)
case "file":
fileID := item.Get("file.file_id").String()
fileData := item.Get("file.file_data").String()
fileURL := item.Get("file.file_url").String()
if fileID == "" && fileData == "" && fileURL == "" {
return appendToolOutputFallbackPart(output, item)
}
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_file")
if fileID != "" {
part, _ = sjson.SetBytes(part, "file_id", fileID)
}
if fileData != "" {
part, _ = sjson.SetBytes(part, "file_data", fileData)
}
if fileURL != "" {
part, _ = sjson.SetBytes(part, "file_url", fileURL)
}
if filename := item.Get("file.filename").String(); filename != "" {
part, _ = sjson.SetBytes(part, "filename", filename)
}
output, _ = sjson.SetRawBytes(output, "-1", part)
default:
output = appendToolOutputFallbackPart(output, item)
}
return output
}
func appendToolOutputFallbackPart(output []byte, item gjson.Result) []byte {
text := item.Raw
if text == "" {
text = item.String()
}
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_text")
part, _ = sjson.SetBytes(part, "text", text)
output, _ = sjson.SetRawBytes(output, "-1", part)
return output
}
// shortenNameIfNeeded applies the simple shortening rule for a single name.
// If the name length exceeds 64, it will try to preserve the "mcp__" prefix and last segment.
// Otherwise it truncates to 64 characters.
@@ -176,6 +176,182 @@ func TestToolCallWithContent(t *testing.T) {
}
}
func TestToolCallOutputWithMultimodalContent(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Show me the generated result."},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_output_1",
"type": "function",
"function": {"name": "render_output", "arguments": "{}"}
}
]
},
{
"role": "tool",
"tool_call_id": "call_output_1",
"content": [
{"type":"text","text":"Rendered result attached."},
{"type":"image_url","image_url":{"url":"https://example.com/generated.png","detail":"high"}},
{"type":"image_url","image_url":{"file_id":"file-img-123"}},
{"type":"file","file":{"file_id":"file-doc-123","filename":"doc.pdf"}},
{"type":"file","file":{"file_data":"SGVsbG8=","filename":"inline.txt"}},
{"type":"file","file":{"file_url":"https://example.com/report.pdf","filename":"report.pdf"}}
]
}
],
"tools": [
{
"type": "function",
"function": {"name": "render_output", "description": "Render output", "parameters": {"type": "object", "properties": {}}}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
output := gjson.Get(result, "input.2.output")
if !output.IsArray() {
t.Fatalf("expected tool output to be an array, got: %s", output.Raw)
}
parts := output.Array()
if len(parts) != 6 {
t.Fatalf("expected 6 output parts, got %d: %s", len(parts), output.Raw)
}
if parts[0].Get("type").String() != "input_text" || parts[0].Get("text").String() != "Rendered result attached." {
t.Fatalf("part 0: expected input_text with rendered text, got %s", parts[0].Raw)
}
if parts[1].Get("type").String() != "input_image" {
t.Fatalf("part 1: expected input_image, got %s", parts[1].Raw)
}
if parts[1].Get("image_url").String() != "https://example.com/generated.png" {
t.Errorf("part 1: unexpected image_url %s", parts[1].Get("image_url").String())
}
if parts[1].Get("detail").String() != "high" {
t.Errorf("part 1: unexpected detail %s", parts[1].Get("detail").String())
}
if parts[2].Get("type").String() != "input_image" || parts[2].Get("file_id").String() != "file-img-123" {
t.Fatalf("part 2: expected file_id-backed input_image, got %s", parts[2].Raw)
}
if parts[3].Get("type").String() != "input_file" || parts[3].Get("file_id").String() != "file-doc-123" {
t.Fatalf("part 3: expected file_id-backed input_file, got %s", parts[3].Raw)
}
if parts[3].Get("filename").String() != "doc.pdf" {
t.Errorf("part 3: unexpected filename %s", parts[3].Get("filename").String())
}
if parts[4].Get("type").String() != "input_file" || parts[4].Get("file_data").String() != "SGVsbG8=" {
t.Fatalf("part 4: expected file_data-backed input_file, got %s", parts[4].Raw)
}
if parts[5].Get("type").String() != "input_file" || parts[5].Get("file_url").String() != "https://example.com/report.pdf" {
t.Fatalf("part 5: expected file_url-backed input_file, got %s", parts[5].Raw)
}
}
func TestToolCallOutputFallsBackForInvalidStructuredParts(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Check tool output."},
{
"role": "assistant",
"content": null,
"tool_calls": [
{"id": "call_invalid_parts", "type": "function", "function": {"name": "inspect", "arguments": "{}"}}
]
},
{
"role": "tool",
"tool_call_id": "call_invalid_parts",
"content": [
{"type":"image_url","image_url":{"detail":"low"}},
{"type":"file","file":{"filename":"orphan.txt"}},
{"type":"unknown_type","foo":"bar","nested":{"a":1}}
]
}
],
"tools": [
{"type": "function", "function": {"name": "inspect", "description": "Inspect", "parameters": {"type": "object", "properties": {}}}}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
parts := gjson.Get(result, "input.2.output").Array()
if len(parts) != 3 {
t.Fatalf("expected 3 output parts, got %d: %s", len(parts), gjson.Get(result, "input.2.output").Raw)
}
expectedFallbacks := []string{
`{"type":"image_url","image_url":{"detail":"low"}}`,
`{"type":"file","file":{"filename":"orphan.txt"}}`,
`{"type":"unknown_type","foo":"bar","nested":{"a":1}}`,
}
for i, expectedFallback := range expectedFallbacks {
if parts[i].Get("type").String() != "input_text" {
t.Fatalf("part %d: expected input_text fallback, got %s", i, parts[i].Raw)
}
if parts[i].Get("text").String() != expectedFallback {
t.Fatalf("part %d: expected fallback %s, got %s", i, expectedFallback, parts[i].Get("text").String())
}
}
}
func TestToolCallOutputWithNonStringJSONContent(t *testing.T) {
tests := []struct {
name string
content string
expectedOutput string
}{
{name: "null", content: `null`, expectedOutput: `null`},
{name: "object", content: `{"status":"ok","count":2}`, expectedOutput: `{"status":"ok","count":2}`},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Check tool output."},
{
"role": "assistant",
"content": null,
"tool_calls": [
{"id": "call_json", "type": "function", "function": {"name": "inspect", "arguments": "{}"}}
]
},
{
"role": "tool",
"tool_call_id": "call_json",
"content": ` + tt.content + `
}
],
"tools": [
{"type": "function", "function": {"name": "inspect", "description": "Inspect", "parameters": {"type": "object", "properties": {}}}}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
output := gjson.Get(result, "input.2.output")
if !output.Exists() {
t.Fatalf("expected output field to exist: %s", gjson.Get(result, "input.2").Raw)
}
if output.String() != tt.expectedOutput {
t.Fatalf("expected output %s, got %s", tt.expectedOutput, output.String())
}
})
}
}
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
// and outputs must be translated and paired correctly.
func TestMultipleToolCalls(t *testing.T) {
@@ -236,7 +236,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Handle function name
if function := toolCall.Get("function"); function.Exists() {
if name := function.Get("name"); name.Exists() {
if name := function.Get("name"); name.Exists() && name.String() != "" {
accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
stopThinkingContentBlock(param, &results)
@@ -0,0 +1,41 @@
package claude
import (
"bytes"
"context"
"testing"
)
func TestConvertOpenAIResponseToClaude_StreamIgnoresNullToolNameDelta(t *testing.T) {
originalRequest := []byte(`{"stream":true}`)
var param any
firstChunks := ConvertOpenAIResponseToClaude(
context.Background(),
"test-model",
originalRequest,
nil,
[]byte(`data: {"id":"chatcmpl_1","model":"test-model","created":1,"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"read_file","arguments":""}}]},"finish_reason":null}]}`),
&param,
)
firstOutput := bytes.Join(firstChunks, nil)
if !bytes.Contains(firstOutput, []byte(`"name":"read_file"`)) {
t.Fatalf("expected first chunk to start read_file tool block, got %s", string(firstOutput))
}
secondChunks := ConvertOpenAIResponseToClaude(
context.Background(),
"test-model",
originalRequest,
nil,
[]byte(`data: {"id":"chatcmpl_1","model":"test-model","created":1,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":null,"arguments":"{\"path\":\"/tmp/a\"}"}}]},"finish_reason":null}]}`),
&param,
)
secondOutput := bytes.Join(secondChunks, nil)
if bytes.Contains(secondOutput, []byte(`content_block_start`)) {
t.Fatalf("did not expect null tool name delta to start a new content block, got %s", string(secondOutput))
}
if bytes.Contains(secondOutput, []byte(`"name":""`)) {
t.Fatalf("did not expect null tool name delta to emit an empty tool name, got %s", string(secondOutput))
}
}
+4 -18
View File
@@ -18,7 +18,6 @@ const (
tabAuthFiles
tabAPIKeys
tabOAuth
tabUsage
tabLogs
)
@@ -40,7 +39,6 @@ type App struct {
auth authTabModel
keys keysTabModel
oauth oauthTabModel
usage usageTabModel
logs logsTabModel
client *Client
@@ -50,7 +48,7 @@ type App struct {
ready bool
// Track which tabs have been initialized (fetched data)
initialized [7]bool
initialized [6]bool
}
type authConnectMsg struct {
@@ -81,10 +79,9 @@ func NewApp(port int, secretKey string, hook *LogHook) App {
auth: newAuthTabModel(client),
keys: newKeysTabModel(client),
oauth: newOAuthTabModel(client),
usage: newUsageTabModel(client),
logs: newLogsTabModel(client, hook),
client: client,
initialized: [7]bool{
initialized: [6]bool{
tabDashboard: true,
tabLogs: true,
},
@@ -92,7 +89,7 @@ func NewApp(port int, secretKey string, hook *LogHook) App {
app.refreshTabs()
if authRequired {
app.initialized = [7]bool{}
app.initialized = [6]bool{}
}
app.setAuthInputPrompt()
return app
@@ -128,7 +125,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.auth.SetSize(contentW, contentH)
a.keys.SetSize(contentW, contentH)
a.oauth.SetSize(contentW, contentH)
a.usage.SetSize(contentW, contentH)
a.logs.SetSize(contentW, contentH)
return a, nil
@@ -142,7 +138,7 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.authenticated = true
a.logsEnabled = a.standalone || isLogsEnabledFromConfig(msg.cfg)
a.refreshTabs()
a.initialized = [7]bool{}
a.initialized = [6]bool{}
a.initialized[tabDashboard] = true
cmds := []tea.Cmd{a.dashboard.Init()}
if a.logsEnabled {
@@ -258,8 +254,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.keys, cmd = a.keys.Update(msg)
case tabOAuth:
a.oauth, cmd = a.oauth.Update(msg)
case tabUsage:
a.usage, cmd = a.usage.Update(msg)
case tabLogs:
a.logs, cmd = a.logs.Update(msg)
}
@@ -322,8 +316,6 @@ func (a *App) initTabIfNeeded(_ int) tea.Cmd {
return a.keys.Init()
case tabOAuth:
return a.oauth.Init()
case tabUsage:
return a.usage.Init()
case tabLogs:
if !a.logsEnabled {
return nil
@@ -360,8 +352,6 @@ func (a App) View() string {
sb.WriteString(a.keys.View())
case tabOAuth:
sb.WriteString(a.oauth.View())
case tabUsage:
sb.WriteString(a.usage.View())
case tabLogs:
if a.logsEnabled {
sb.WriteString(a.logs.View())
@@ -529,10 +519,6 @@ func (a App) broadcastToAllTabs(msg tea.Msg) (tea.Model, tea.Cmd) {
if cmd != nil {
cmds = append(cmds, cmd)
}
a.usage, cmd = a.usage.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.logs, cmd = a.logs.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
-5
View File
@@ -140,11 +140,6 @@ func (c *Client) PutConfigYAML(yamlContent string) error {
return err
}
// GetUsage fetches usage statistics.
func (c *Client) GetUsage() (map[string]any, error) {
return c.getJSON("/v0/management/usage")
}
// GetAuthFiles lists auth credential files.
// API returns {"files": [...]}.
func (c *Client) GetAuthFiles() ([]map[string]any, error) {
+7 -70
View File
@@ -22,14 +22,12 @@ type dashboardModel struct {
// Cached data for re-rendering on locale change
lastConfig map[string]any
lastUsage map[string]any
lastAuthFiles []map[string]any
lastAPIKeys []string
}
type dashboardDataMsg struct {
config map[string]any
usage map[string]any
authFiles []map[string]any
apiKeys []string
err error
@@ -47,25 +45,24 @@ func (m dashboardModel) Init() tea.Cmd {
func (m dashboardModel) fetchData() tea.Msg {
cfg, cfgErr := m.client.GetConfig()
usage, usageErr := m.client.GetUsage()
authFiles, authErr := m.client.GetAuthFiles()
apiKeys, keysErr := m.client.GetAPIKeys()
var err error
for _, e := range []error{cfgErr, usageErr, authErr, keysErr} {
for _, e := range []error{cfgErr, authErr, keysErr} {
if e != nil {
err = e
break
}
}
return dashboardDataMsg{config: cfg, usage: usage, authFiles: authFiles, apiKeys: apiKeys, err: err}
return dashboardDataMsg{config: cfg, authFiles: authFiles, apiKeys: apiKeys, err: err}
}
func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
// Re-render immediately with cached data using new locale
m.content = m.renderDashboard(m.lastConfig, m.lastUsage, m.lastAuthFiles, m.lastAPIKeys)
m.content = m.renderDashboard(m.lastConfig, m.lastAuthFiles, m.lastAPIKeys)
m.viewport.SetContent(m.content)
// Also fetch fresh data in background
return m, m.fetchData
@@ -78,11 +75,10 @@ func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
m.err = nil
// Cache data for locale switching
m.lastConfig = msg.config
m.lastUsage = msg.usage
m.lastAuthFiles = msg.authFiles
m.lastAPIKeys = msg.apiKeys
m.content = m.renderDashboard(msg.config, msg.usage, msg.authFiles, msg.apiKeys)
m.content = m.renderDashboard(msg.config, msg.authFiles, msg.apiKeys)
}
m.viewport.SetContent(m.content)
return m, nil
@@ -121,7 +117,7 @@ func (m dashboardModel) View() string {
return m.viewport.View()
}
func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []map[string]any, apiKeys []string) string {
func (m dashboardModel) renderDashboard(cfg map[string]any, authFiles []map[string]any, apiKeys []string) string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("dashboard_title")))
@@ -138,7 +134,7 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
// ━━━ Stats Cards ━━━
cardWidth := 25
if m.width > 0 {
cardWidth = (m.width - 6) / 4
cardWidth = (m.width - 2) / 2
if cardWidth < 18 {
cardWidth = 18
}
@@ -173,34 +169,7 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (%d %s)", T("auth_files_label"), activeAuth, T("active_suffix"))),
))
// Card 3: Total Requests
totalReqs := int64(0)
successReqs := int64(0)
failedReqs := int64(0)
totalTokens := int64(0)
if usage != nil {
if usageMap, ok := usage["usage"].(map[string]any); ok {
totalReqs = int64(getFloat(usageMap, "total_requests"))
successReqs = int64(getFloat(usageMap, "success_count"))
failedReqs = int64(getFloat(usageMap, "failure_count"))
totalTokens = int64(getFloat(usageMap, "total_tokens"))
}
}
card3 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(fmt.Sprintf("📈 %d", totalReqs)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (✓%d ✗%d)", T("total_requests"), successReqs, failedReqs)),
))
// Card 4: Total Tokens
tokenStr := formatLargeNumber(totalTokens)
card4 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("🔤 %s", tokenStr)),
lipgloss.NewStyle().Foreground(colorMuted).Render(T("total_tokens")),
))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2))
sb.WriteString("\n\n")
// ━━━ Current Config ━━━
@@ -258,38 +227,6 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
sb.WriteString("\n")
// ━━━ Per-Model Usage ━━━
if usage != nil {
if usageMap, ok := usage["usage"].(map[string]any); ok {
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("model_stats")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
header := fmt.Sprintf(" %-40s %10s %12s", T("model"), T("requests"), T("tokens"))
sb.WriteString(tableHeaderStyle.Render(header))
sb.WriteString("\n")
for _, apiSnap := range apis {
if apiMap, ok := apiSnap.(map[string]any); ok {
if models, ok := apiMap["models"].(map[string]any); ok {
for model, v := range models {
if stats, ok := v.(map[string]any); ok {
reqs := int64(getFloat(stats, "total_requests"))
toks := int64(getFloat(stats, "total_tokens"))
row := fmt.Sprintf(" %-40s %10d %12s", truncate(model, 40), reqs, formatLargeNumber(toks))
sb.WriteString(tableCellStyle.Render(row))
sb.WriteString("\n")
}
}
}
}
}
}
}
}
return sb.String()
}
+2 -2
View File
@@ -50,8 +50,8 @@ var locales = map[string]map[string]string{
// ──────────────────────────────────────────
// Tab names
// ──────────────────────────────────────────
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "使用统计", "日志"}
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Usage", "Logs"}
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "日志"}
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Logs"}
// TabNames returns tab names in the current locale.
func TabNames() []string {
-418
View File
@@ -1,418 +0,0 @@
package tui
import (
"fmt"
"sort"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// usageTabModel displays usage statistics with charts and breakdowns.
type usageTabModel struct {
client *Client
viewport viewport.Model
usage map[string]any
err error
width int
height int
ready bool
}
type usageDataMsg struct {
usage map[string]any
err error
}
func newUsageTabModel(client *Client) usageTabModel {
return usageTabModel{
client: client,
}
}
func (m usageTabModel) Init() tea.Cmd {
return m.fetchData
}
func (m usageTabModel) fetchData() tea.Msg {
usage, err := m.client.GetUsage()
return usageDataMsg{usage: usage, err: err}
}
func (m usageTabModel) Update(msg tea.Msg) (usageTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case usageDataMsg:
if msg.err != nil {
m.err = msg.err
} else {
m.err = nil
m.usage = msg.usage
}
m.viewport.SetContent(m.renderContent())
return m, nil
case tea.KeyMsg:
if msg.String() == "r" {
return m, m.fetchData
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m *usageTabModel) SetSize(w, h int) {
m.width = w
m.height = h
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m usageTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m usageTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("usage_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("usage_help")))
sb.WriteString("\n\n")
if m.err != nil {
sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
sb.WriteString("\n")
return sb.String()
}
if m.usage == nil {
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
sb.WriteString("\n")
return sb.String()
}
usageMap, _ := m.usage["usage"].(map[string]any)
if usageMap == nil {
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
sb.WriteString("\n")
return sb.String()
}
totalReqs := int64(getFloat(usageMap, "total_requests"))
successCnt := int64(getFloat(usageMap, "success_count"))
failureCnt := int64(getFloat(usageMap, "failure_count"))
totalTokens := int64(getFloat(usageMap, "total_tokens"))
// ━━━ Overview Cards ━━━
cardWidth := 20
if m.width > 0 {
cardWidth = (m.width - 6) / 4
if cardWidth < 16 {
cardWidth = 16
}
}
cardStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("240")).
Padding(0, 1).
Width(cardWidth).
Height(3)
// Total Requests
card1 := cardStyle.Copy().BorderForeground(lipgloss.Color("111")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_reqs")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("%d", totalReqs)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("● %s: %d ● %s: %d", T("usage_success"), successCnt, T("usage_failure"), failureCnt)),
))
// Total Tokens
card2 := cardStyle.Copy().BorderForeground(lipgloss.Color("214")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_tokens")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(formatLargeNumber(totalTokens)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_token_l"), formatLargeNumber(totalTokens))),
))
// RPM
rpm := float64(0)
if totalReqs > 0 {
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
rpm = float64(totalReqs) / float64(len(rByH)) / 60.0
}
}
card3 := cardStyle.Copy().BorderForeground(lipgloss.Color("76")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_rpm")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("%.2f", rpm)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %d", T("usage_total_reqs"), totalReqs)),
))
// TPM
tpm := float64(0)
if totalTokens > 0 {
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
tpm = float64(totalTokens) / float64(len(tByH)) / 60.0
}
}
card4 := cardStyle.Copy().BorderForeground(lipgloss.Color("170")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_tpm")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("%.2f", tpm)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_tokens"), formatLargeNumber(totalTokens))),
))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
sb.WriteString("\n\n")
// ━━━ Requests by Hour (ASCII bar chart) ━━━
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_hour")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(rByH, m.width-6, lipgloss.Color("111")))
sb.WriteString("\n")
}
// ━━━ Tokens by Hour ━━━
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_tok_by_hour")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(tByH, m.width-6, lipgloss.Color("214")))
sb.WriteString("\n")
}
// ━━━ Requests by Day ━━━
if rByD, ok := usageMap["requests_by_day"].(map[string]any); ok && len(rByD) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_day")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(rByD, m.width-6, lipgloss.Color("76")))
sb.WriteString("\n")
}
// ━━━ API Detail Stats ━━━
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_api_detail")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 80)))
sb.WriteString("\n")
header := fmt.Sprintf(" %-30s %10s %12s", "API", T("requests"), T("tokens"))
sb.WriteString(tableHeaderStyle.Render(header))
sb.WriteString("\n")
for apiName, apiSnap := range apis {
if apiMap, ok := apiSnap.(map[string]any); ok {
apiReqs := int64(getFloat(apiMap, "total_requests"))
apiToks := int64(getFloat(apiMap, "total_tokens"))
row := fmt.Sprintf(" %-30s %10d %12s",
truncate(maskKey(apiName), 30), apiReqs, formatLargeNumber(apiToks))
sb.WriteString(lipgloss.NewStyle().Bold(true).Render(row))
sb.WriteString("\n")
// Per-model breakdown
if models, ok := apiMap["models"].(map[string]any); ok {
for model, v := range models {
if stats, ok := v.(map[string]any); ok {
mReqs := int64(getFloat(stats, "total_requests"))
mToks := int64(getFloat(stats, "total_tokens"))
mRow := fmt.Sprintf(" ├─ %-28s %10d %12s",
truncate(model, 28), mReqs, formatLargeNumber(mToks))
sb.WriteString(tableCellStyle.Render(mRow))
sb.WriteString("\n")
// Token type breakdown from details
sb.WriteString(m.renderTokenBreakdown(stats))
// Latency breakdown from details
sb.WriteString(m.renderLatencyBreakdown(stats))
}
}
}
}
}
}
sb.WriteString("\n")
return sb.String()
}
// renderTokenBreakdown aggregates input/output/cached/reasoning tokens from model details.
func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
details, ok := modelStats["details"]
if !ok {
return ""
}
detailList, ok := details.([]any)
if !ok || len(detailList) == 0 {
return ""
}
var inputTotal, outputTotal, cachedTotal, reasoningTotal int64
for _, d := range detailList {
dm, ok := d.(map[string]any)
if !ok {
continue
}
tokens, ok := dm["tokens"].(map[string]any)
if !ok {
continue
}
inputTotal += int64(getFloat(tokens, "input_tokens"))
outputTotal += int64(getFloat(tokens, "output_tokens"))
cachedTotal += int64(getFloat(tokens, "cached_tokens"))
reasoningTotal += int64(getFloat(tokens, "reasoning_tokens"))
}
if inputTotal == 0 && outputTotal == 0 && cachedTotal == 0 && reasoningTotal == 0 {
return ""
}
parts := []string{}
if inputTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_input"), formatLargeNumber(inputTotal)))
}
if outputTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_output"), formatLargeNumber(outputTotal)))
}
if cachedTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_cached"), formatLargeNumber(cachedTotal)))
}
if reasoningTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_reasoning"), formatLargeNumber(reasoningTotal)))
}
return fmt.Sprintf(" │ %s\n",
lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
}
// renderLatencyBreakdown aggregates latency_ms from model details and displays avg/min/max.
func (m usageTabModel) renderLatencyBreakdown(modelStats map[string]any) string {
details, ok := modelStats["details"]
if !ok {
return ""
}
detailList, ok := details.([]any)
if !ok || len(detailList) == 0 {
return ""
}
var totalLatency int64
var count int
var minLatency, maxLatency int64
first := true
for _, d := range detailList {
dm, ok := d.(map[string]any)
if !ok {
continue
}
latencyMs := int64(getFloat(dm, "latency_ms"))
if latencyMs <= 0 {
continue
}
totalLatency += latencyMs
count++
if first {
minLatency = latencyMs
maxLatency = latencyMs
first = false
} else {
if latencyMs < minLatency {
minLatency = latencyMs
}
if latencyMs > maxLatency {
maxLatency = latencyMs
}
}
}
if count == 0 {
return ""
}
avgLatency := totalLatency / int64(count)
return fmt.Sprintf(" │ %s: avg %dms min %dms max %dms\n",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_time")),
avgLatency, minLatency, maxLatency)
}
// renderBarChart renders a simple ASCII horizontal bar chart.
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
if maxBarWidth < 10 {
maxBarWidth = 10
}
// Sort keys
keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
sort.Strings(keys)
// Find max value
maxVal := float64(0)
for _, k := range keys {
v := getFloat(data, k)
if v > maxVal {
maxVal = v
}
}
if maxVal == 0 {
return ""
}
barStyle := lipgloss.NewStyle().Foreground(barColor)
var sb strings.Builder
labelWidth := 12
barAvail := maxBarWidth - labelWidth - 12
if barAvail < 5 {
barAvail = 5
}
for _, k := range keys {
v := getFloat(data, k)
barLen := int(v / maxVal * float64(barAvail))
if barLen < 1 && v > 0 {
barLen = 1
}
bar := strings.Repeat("█", barLen)
label := k
if len(label) > labelWidth {
label = label[:labelWidth]
}
sb.WriteString(fmt.Sprintf(" %-*s %s %s\n",
labelWidth, label,
barStyle.Render(bar),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%.0f", v)),
))
}
return sb.String()
}
-134
View File
@@ -1,134 +0,0 @@
package tui
import (
"strings"
"testing"
)
func TestRenderLatencyBreakdown(t *testing.T) {
tests := []struct {
name string
modelStats map[string]any
wantEmpty bool
wantContains string
}{
{
name: "no details",
modelStats: map[string]any{},
wantEmpty: true,
},
{
name: "empty details",
modelStats: map[string]any{
"details": []any{},
},
wantEmpty: true,
},
{
name: "details with zero latency",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(0),
},
},
},
wantEmpty: true,
},
{
name: "single request with latency",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(1500),
},
},
},
wantEmpty: false,
wantContains: "avg 1500ms min 1500ms max 1500ms",
},
{
name: "multiple requests with varying latency",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(100),
},
map[string]any{
"latency_ms": float64(200),
},
map[string]any{
"latency_ms": float64(300),
},
},
},
wantEmpty: false,
wantContains: "avg 200ms min 100ms max 300ms",
},
{
name: "mixed valid and invalid latency values",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(500),
},
map[string]any{
"latency_ms": float64(0),
},
map[string]any{
"latency_ms": float64(1500),
},
},
},
wantEmpty: false,
wantContains: "avg 1000ms min 500ms max 1500ms",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := usageTabModel{}
result := m.renderLatencyBreakdown(tt.modelStats)
if tt.wantEmpty {
if result != "" {
t.Errorf("renderLatencyBreakdown() = %q, want empty string", result)
}
return
}
if result == "" {
t.Errorf("renderLatencyBreakdown() = empty, want non-empty string")
return
}
if tt.wantContains != "" && !strings.Contains(result, tt.wantContains) {
t.Errorf("renderLatencyBreakdown() = %q, want to contain %q", result, tt.wantContains)
}
})
}
}
func TestUsageTimeTranslations(t *testing.T) {
prevLocale := CurrentLocale()
t.Cleanup(func() {
SetLocale(prevLocale)
})
tests := []struct {
locale string
want string
}{
{locale: "en", want: "Time"},
{locale: "zh", want: "时间"},
}
for _, tt := range tests {
t.Run(tt.locale, func(t *testing.T) {
SetLocale(tt.locale)
if got := T("usage_time"); got != tt.want {
t.Fatalf("T(usage_time) = %q, want %q", got, tt.want)
}
})
}
}
-484
View File
@@ -1,484 +0,0 @@
// Package usage provides usage tracking and logging functionality for the CLI Proxy API server.
// It includes plugins for monitoring API usage, token consumption, and other metrics
// to help with observability and billing purposes.
package usage
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
var statisticsEnabled atomic.Bool
func init() {
statisticsEnabled.Store(true)
coreusage.RegisterPlugin(NewLoggerPlugin())
}
// LoggerPlugin collects in-memory request statistics for usage analysis.
// It implements coreusage.Plugin to receive usage records emitted by the runtime.
type LoggerPlugin struct {
stats *RequestStatistics
}
// NewLoggerPlugin constructs a new logger plugin instance.
//
// Returns:
// - *LoggerPlugin: A new logger plugin instance wired to the shared statistics store.
func NewLoggerPlugin() *LoggerPlugin { return &LoggerPlugin{stats: defaultRequestStatistics} }
// HandleUsage implements coreusage.Plugin.
// It updates the in-memory statistics store whenever a usage record is received.
//
// Parameters:
// - ctx: The context for the usage record
// - record: The usage record to aggregate
func (p *LoggerPlugin) HandleUsage(ctx context.Context, record coreusage.Record) {
if !statisticsEnabled.Load() {
return
}
if p == nil || p.stats == nil {
return
}
p.stats.Record(ctx, record)
}
// SetStatisticsEnabled toggles whether in-memory statistics are recorded.
func SetStatisticsEnabled(enabled bool) { statisticsEnabled.Store(enabled) }
// StatisticsEnabled reports the current recording state.
func StatisticsEnabled() bool { return statisticsEnabled.Load() }
// RequestStatistics maintains aggregated request metrics in memory.
type RequestStatistics struct {
mu sync.RWMutex
totalRequests int64
successCount int64
failureCount int64
totalTokens int64
apis map[string]*apiStats
requestsByDay map[string]int64
requestsByHour map[int]int64
tokensByDay map[string]int64
tokensByHour map[int]int64
}
// apiStats holds aggregated metrics for a single API key.
type apiStats struct {
TotalRequests int64
TotalTokens int64
Models map[string]*modelStats
}
// modelStats holds aggregated metrics for a specific model within an API.
type modelStats struct {
TotalRequests int64
TotalTokens int64
Details []RequestDetail
}
// RequestDetail stores the timestamp, latency, and token usage for a single request.
type RequestDetail struct {
Timestamp time.Time `json:"timestamp"`
LatencyMs int64 `json:"latency_ms"`
Source string `json:"source"`
AuthIndex string `json:"auth_index"`
Tokens TokenStats `json:"tokens"`
Failed bool `json:"failed"`
}
// TokenStats captures the token usage breakdown for a request.
type TokenStats struct {
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
ReasoningTokens int64 `json:"reasoning_tokens"`
CachedTokens int64 `json:"cached_tokens"`
TotalTokens int64 `json:"total_tokens"`
}
// StatisticsSnapshot represents an immutable view of the aggregated metrics.
type StatisticsSnapshot struct {
TotalRequests int64 `json:"total_requests"`
SuccessCount int64 `json:"success_count"`
FailureCount int64 `json:"failure_count"`
TotalTokens int64 `json:"total_tokens"`
APIs map[string]APISnapshot `json:"apis"`
RequestsByDay map[string]int64 `json:"requests_by_day"`
RequestsByHour map[string]int64 `json:"requests_by_hour"`
TokensByDay map[string]int64 `json:"tokens_by_day"`
TokensByHour map[string]int64 `json:"tokens_by_hour"`
}
// APISnapshot summarises metrics for a single API key.
type APISnapshot struct {
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
Models map[string]ModelSnapshot `json:"models"`
}
// ModelSnapshot summarises metrics for a specific model.
type ModelSnapshot struct {
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
Details []RequestDetail `json:"details"`
}
var defaultRequestStatistics = NewRequestStatistics()
// GetRequestStatistics returns the shared statistics store.
func GetRequestStatistics() *RequestStatistics { return defaultRequestStatistics }
// NewRequestStatistics constructs an empty statistics store.
func NewRequestStatistics() *RequestStatistics {
return &RequestStatistics{
apis: make(map[string]*apiStats),
requestsByDay: make(map[string]int64),
requestsByHour: make(map[int]int64),
tokensByDay: make(map[string]int64),
tokensByHour: make(map[int]int64),
}
}
// Record ingests a new usage record and updates the aggregates.
func (s *RequestStatistics) Record(ctx context.Context, record coreusage.Record) {
if s == nil {
return
}
if !statisticsEnabled.Load() {
return
}
timestamp := record.RequestedAt
if timestamp.IsZero() {
timestamp = time.Now()
}
detail := normaliseDetail(record.Detail)
totalTokens := detail.TotalTokens
statsKey := record.APIKey
if statsKey == "" {
statsKey = resolveAPIIdentifier(ctx, record)
}
failed := record.Failed
if !failed {
failed = !resolveSuccess(ctx)
}
success := !failed
modelName := record.Model
if modelName == "" {
modelName = "unknown"
}
dayKey := timestamp.Format("2006-01-02")
hourKey := timestamp.Hour()
s.mu.Lock()
defer s.mu.Unlock()
s.totalRequests++
if success {
s.successCount++
} else {
s.failureCount++
}
s.totalTokens += totalTokens
stats, ok := s.apis[statsKey]
if !ok {
stats = &apiStats{Models: make(map[string]*modelStats)}
s.apis[statsKey] = stats
}
s.updateAPIStats(stats, modelName, RequestDetail{
Timestamp: timestamp,
LatencyMs: normaliseLatency(record.Latency),
Source: record.Source,
AuthIndex: record.AuthIndex,
Tokens: detail,
Failed: failed,
})
s.requestsByDay[dayKey]++
s.requestsByHour[hourKey]++
s.tokensByDay[dayKey] += totalTokens
s.tokensByHour[hourKey] += totalTokens
}
func (s *RequestStatistics) updateAPIStats(stats *apiStats, model string, detail RequestDetail) {
stats.TotalRequests++
stats.TotalTokens += detail.Tokens.TotalTokens
modelStatsValue, ok := stats.Models[model]
if !ok {
modelStatsValue = &modelStats{}
stats.Models[model] = modelStatsValue
}
modelStatsValue.TotalRequests++
modelStatsValue.TotalTokens += detail.Tokens.TotalTokens
modelStatsValue.Details = append(modelStatsValue.Details, detail)
}
// Snapshot returns a copy of the aggregated metrics for external consumption.
func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
result := StatisticsSnapshot{}
if s == nil {
return result
}
s.mu.RLock()
defer s.mu.RUnlock()
result.TotalRequests = s.totalRequests
result.SuccessCount = s.successCount
result.FailureCount = s.failureCount
result.TotalTokens = s.totalTokens
result.APIs = make(map[string]APISnapshot, len(s.apis))
for apiName, stats := range s.apis {
apiSnapshot := APISnapshot{
TotalRequests: stats.TotalRequests,
TotalTokens: stats.TotalTokens,
Models: make(map[string]ModelSnapshot, len(stats.Models)),
}
for modelName, modelStatsValue := range stats.Models {
requestDetails := make([]RequestDetail, len(modelStatsValue.Details))
copy(requestDetails, modelStatsValue.Details)
apiSnapshot.Models[modelName] = ModelSnapshot{
TotalRequests: modelStatsValue.TotalRequests,
TotalTokens: modelStatsValue.TotalTokens,
Details: requestDetails,
}
}
result.APIs[apiName] = apiSnapshot
}
result.RequestsByDay = make(map[string]int64, len(s.requestsByDay))
for k, v := range s.requestsByDay {
result.RequestsByDay[k] = v
}
result.RequestsByHour = make(map[string]int64, len(s.requestsByHour))
for hour, v := range s.requestsByHour {
key := formatHour(hour)
result.RequestsByHour[key] = v
}
result.TokensByDay = make(map[string]int64, len(s.tokensByDay))
for k, v := range s.tokensByDay {
result.TokensByDay[k] = v
}
result.TokensByHour = make(map[string]int64, len(s.tokensByHour))
for hour, v := range s.tokensByHour {
key := formatHour(hour)
result.TokensByHour[key] = v
}
return result
}
type MergeResult struct {
Added int64 `json:"added"`
Skipped int64 `json:"skipped"`
}
// MergeSnapshot merges an exported statistics snapshot into the current store.
// Existing data is preserved and duplicate request details are skipped.
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
result := MergeResult{}
if s == nil {
return result
}
s.mu.Lock()
defer s.mu.Unlock()
seen := make(map[string]struct{})
for apiName, stats := range s.apis {
if stats == nil {
continue
}
for modelName, modelStatsValue := range stats.Models {
if modelStatsValue == nil {
continue
}
for _, detail := range modelStatsValue.Details {
seen[dedupKey(apiName, modelName, detail)] = struct{}{}
}
}
}
for apiName, apiSnapshot := range snapshot.APIs {
apiName = strings.TrimSpace(apiName)
if apiName == "" {
continue
}
stats, ok := s.apis[apiName]
if !ok || stats == nil {
stats = &apiStats{Models: make(map[string]*modelStats)}
s.apis[apiName] = stats
} else if stats.Models == nil {
stats.Models = make(map[string]*modelStats)
}
for modelName, modelSnapshot := range apiSnapshot.Models {
modelName = strings.TrimSpace(modelName)
if modelName == "" {
modelName = "unknown"
}
for _, detail := range modelSnapshot.Details {
detail.Tokens = normaliseTokenStats(detail.Tokens)
if detail.LatencyMs < 0 {
detail.LatencyMs = 0
}
if detail.Timestamp.IsZero() {
detail.Timestamp = time.Now()
}
key := dedupKey(apiName, modelName, detail)
if _, exists := seen[key]; exists {
result.Skipped++
continue
}
seen[key] = struct{}{}
s.recordImported(apiName, modelName, stats, detail)
result.Added++
}
}
}
return result
}
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
totalTokens := detail.Tokens.TotalTokens
if totalTokens < 0 {
totalTokens = 0
}
s.totalRequests++
if detail.Failed {
s.failureCount++
} else {
s.successCount++
}
s.totalTokens += totalTokens
s.updateAPIStats(stats, modelName, detail)
dayKey := detail.Timestamp.Format("2006-01-02")
hourKey := detail.Timestamp.Hour()
s.requestsByDay[dayKey]++
s.requestsByHour[hourKey]++
s.tokensByDay[dayKey] += totalTokens
s.tokensByHour[hourKey] += totalTokens
}
func dedupKey(apiName, modelName string, detail RequestDetail) string {
timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
tokens := normaliseTokenStats(detail.Tokens)
return fmt.Sprintf(
"%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
apiName,
modelName,
timestamp,
detail.Source,
detail.AuthIndex,
detail.Failed,
tokens.InputTokens,
tokens.OutputTokens,
tokens.ReasoningTokens,
tokens.CachedTokens,
tokens.TotalTokens,
)
}
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
path := ginCtx.FullPath()
if path == "" && ginCtx.Request != nil {
path = ginCtx.Request.URL.Path
}
method := ""
if ginCtx.Request != nil {
method = ginCtx.Request.Method
}
if path != "" {
if method != "" {
return method + " " + path
}
return path
}
}
}
if record.Provider != "" {
return record.Provider
}
return "unknown"
}
func resolveSuccess(ctx context.Context) bool {
if ctx == nil {
return true
}
ginCtx, ok := ctx.Value("gin").(*gin.Context)
if !ok || ginCtx == nil {
return true
}
status := ginCtx.Writer.Status()
if status == 0 {
return true
}
return status < httpStatusBadRequest
}
const httpStatusBadRequest = 400
func normaliseDetail(detail coreusage.Detail) TokenStats {
tokens := TokenStats{
InputTokens: detail.InputTokens,
OutputTokens: detail.OutputTokens,
ReasoningTokens: detail.ReasoningTokens,
CachedTokens: detail.CachedTokens,
TotalTokens: detail.TotalTokens,
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens + detail.CachedTokens
}
return tokens
}
func normaliseTokenStats(tokens TokenStats) TokenStats {
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
}
return tokens
}
func normaliseLatency(latency time.Duration) int64 {
if latency <= 0 {
return 0
}
return latency.Milliseconds()
}
func formatHour(hour int) string {
if hour < 0 {
hour = 0
}
hour = hour % 24
return fmt.Sprintf("%02d", hour)
}
-96
View File
@@ -1,96 +0,0 @@
package usage
import (
"context"
"testing"
"time"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
func TestRequestStatisticsRecordIncludesLatency(t *testing.T) {
stats := NewRequestStatistics()
stats.Record(context.Background(), coreusage.Record{
APIKey: "test-key",
Model: "gpt-5.4",
RequestedAt: time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC),
Latency: 1500 * time.Millisecond,
Detail: coreusage.Detail{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
})
snapshot := stats.Snapshot()
details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
if len(details) != 1 {
t.Fatalf("details len = %d, want 1", len(details))
}
if details[0].LatencyMs != 1500 {
t.Fatalf("latency_ms = %d, want 1500", details[0].LatencyMs)
}
}
func TestRequestStatisticsMergeSnapshotDedupIgnoresLatency(t *testing.T) {
stats := NewRequestStatistics()
timestamp := time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC)
first := StatisticsSnapshot{
APIs: map[string]APISnapshot{
"test-key": {
Models: map[string]ModelSnapshot{
"gpt-5.4": {
Details: []RequestDetail{{
Timestamp: timestamp,
LatencyMs: 0,
Source: "user@example.com",
AuthIndex: "0",
Tokens: TokenStats{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
}},
},
},
},
},
}
second := StatisticsSnapshot{
APIs: map[string]APISnapshot{
"test-key": {
Models: map[string]ModelSnapshot{
"gpt-5.4": {
Details: []RequestDetail{{
Timestamp: timestamp,
LatencyMs: 2500,
Source: "user@example.com",
AuthIndex: "0",
Tokens: TokenStats{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
}},
},
},
},
},
}
result := stats.MergeSnapshot(first)
if result.Added != 1 || result.Skipped != 0 {
t.Fatalf("first merge = %+v, want added=1 skipped=0", result)
}
result = stats.MergeSnapshot(second)
if result.Added != 0 || result.Skipped != 1 {
t.Fatalf("second merge = %+v, want added=0 skipped=1", result)
}
snapshot := stats.Snapshot()
details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
if len(details) != 1 {
t.Fatalf("details len = %d, want 1", len(details))
}
}
+3
View File
@@ -39,6 +39,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled {
changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled))
}
if oldCfg.RedisUsageQueueRetentionSeconds != newCfg.RedisUsageQueueRetentionSeconds {
changes = append(changes, fmt.Sprintf("redis-usage-queue-retention-seconds: %d -> %d", oldCfg.RedisUsageQueueRetentionSeconds, newCfg.RedisUsageQueueRetentionSeconds))
}
if oldCfg.DisableCooling != newCfg.DisableCooling {
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
}
+28 -4
View File
@@ -375,11 +375,32 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
if requestCtx != nil && logging.GetRequestID(parentCtx) == "" {
if requestID := logging.GetRequestID(requestCtx); requestID != "" {
parentCtx = logging.WithRequestID(parentCtx, requestID)
} else if requestID := logging.GetGinRequestID(c); requestID != "" {
} else if requestID = logging.GetGinRequestID(c); requestID != "" {
parentCtx = logging.WithRequestID(parentCtx, requestID)
}
}
newCtx, cancel := context.WithCancel(parentCtx)
endpoint := ""
if c != nil && c.Request != nil {
path := strings.TrimSpace(c.FullPath())
if path == "" && c.Request.URL != nil {
path = strings.TrimSpace(c.Request.URL.Path)
}
if path != "" {
method := strings.TrimSpace(c.Request.Method)
if method != "" {
endpoint = method + " " + path
} else {
endpoint = path
}
}
}
if endpoint != "" {
newCtx = logging.WithEndpoint(newCtx, endpoint)
}
newCtx = logging.WithResponseStatusHolder(newCtx)
cancelCtx := newCtx
if requestCtx != nil && requestCtx != parentCtx {
go func() {
@@ -393,6 +414,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
newCtx = context.WithValue(newCtx, "gin", c)
newCtx = context.WithValue(newCtx, "handler", handler)
return newCtx, func(params ...interface{}) {
if c != nil {
logging.SetResponseStatus(cancelCtx, c.Writer.Status())
}
if h.Cfg.RequestLog && len(params) == 1 {
if existing, exists := c.Get("API_RESPONSE"); exists {
if existingBytes, ok := existing.([]byte); ok && len(bytes.TrimSpace(existingBytes)) > 0 {
@@ -515,7 +539,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
return nil, nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
payload := rawJSON
if len(payload) == 0 {
payload = nil
@@ -563,7 +587,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
return nil, nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
payload := rawJSON
if len(payload) == 0 {
payload = nil
@@ -615,7 +639,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return nil, nil, errChan
}
reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
payload := rawJSON
if len(payload) == 0 {
payload = nil
@@ -13,6 +13,7 @@ import (
"fmt"
"io"
"net/http"
"sort"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
@@ -46,6 +47,9 @@ func writeResponsesSSEChunk(w io.Writer, chunk []byte) {
type responsesSSEFramer struct {
pending []byte
outputItems map[int][]byte
outputOrder []int
unindexedOutputItems [][]byte
}
func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
@@ -61,7 +65,7 @@ func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
if frameLen == 0 {
break
}
writeResponsesSSEChunk(w, f.pending[:frameLen])
f.writeFrame(w, f.pending[:frameLen])
copy(f.pending, f.pending[frameLen:])
f.pending = f.pending[:len(f.pending)-frameLen]
}
@@ -72,7 +76,7 @@ func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) {
if len(f.pending) == 0 || !responsesSSECanEmitWithoutDelimiter(f.pending) {
return
}
writeResponsesSSEChunk(w, f.pending)
f.writeFrame(w, f.pending)
f.pending = f.pending[:0]
}
@@ -88,10 +92,133 @@ func (f *responsesSSEFramer) Flush(w io.Writer) {
f.pending = f.pending[:0]
return
}
writeResponsesSSEChunk(w, f.pending)
f.writeFrame(w, f.pending)
f.pending = f.pending[:0]
}
func (f *responsesSSEFramer) writeFrame(w io.Writer, frame []byte) {
writeResponsesSSEChunk(w, f.repairFrame(frame))
}
func (f *responsesSSEFramer) repairFrame(frame []byte) []byte {
payload, ok := responsesSSEDataPayload(frame)
if !ok || len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) || !json.Valid(payload) {
return frame
}
switch gjson.GetBytes(payload, "type").String() {
case "response.output_item.done":
f.recordOutputItem(payload)
case "response.completed":
repaired := f.repairCompletedPayload(payload)
if !bytes.Equal(repaired, payload) {
return responsesSSEFrameWithData(frame, repaired)
}
}
return frame
}
func responsesSSEDataPayload(frame []byte) ([]byte, bool) {
var payload []byte
found := false
for _, line := range bytes.Split(frame, []byte("\n")) {
line = bytes.TrimRight(line, "\r")
trimmed := bytes.TrimSpace(line)
if !bytes.HasPrefix(trimmed, []byte("data:")) {
continue
}
data := bytes.TrimSpace(trimmed[len("data:"):])
if found {
payload = append(payload, '\n')
}
payload = append(payload, data...)
found = true
}
return payload, found
}
func responsesSSEFrameWithData(frame, payload []byte) []byte {
var out bytes.Buffer
for _, line := range bytes.Split(frame, []byte("\n")) {
line = bytes.TrimRight(line, "\r")
trimmed := bytes.TrimSpace(line)
if len(trimmed) == 0 || bytes.HasPrefix(trimmed, []byte("data:")) {
continue
}
out.Write(line)
out.WriteByte('\n')
}
for _, line := range bytes.Split(payload, []byte("\n")) {
out.WriteString("data: ")
out.Write(line)
out.WriteByte('\n')
}
out.WriteByte('\n')
return out.Bytes()
}
func (f *responsesSSEFramer) recordOutputItem(payload []byte) {
item := gjson.GetBytes(payload, "item")
if !item.Exists() || !item.IsObject() || item.Get("type").String() == "" {
return
}
if outputIndex := gjson.GetBytes(payload, "output_index"); outputIndex.Exists() {
index := int(outputIndex.Int())
if f.outputItems == nil {
f.outputItems = make(map[int][]byte)
}
if _, exists := f.outputItems[index]; !exists {
f.outputOrder = append(f.outputOrder, index)
}
f.outputItems[index] = append([]byte(nil), item.Raw...)
return
}
f.unindexedOutputItems = append(f.unindexedOutputItems, append([]byte(nil), item.Raw...))
}
func (f *responsesSSEFramer) repairCompletedPayload(payload []byte) []byte {
if len(f.outputOrder) == 0 && len(f.unindexedOutputItems) == 0 {
return payload
}
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && (!output.IsArray() || len(output.Array()) > 0) {
return payload
}
var outputJSON bytes.Buffer
outputJSON.WriteByte('[')
indexes := append([]int(nil), f.outputOrder...)
sort.Ints(indexes)
written := 0
for _, index := range indexes {
item, ok := f.outputItems[index]
if !ok {
continue
}
if written > 0 {
outputJSON.WriteByte(',')
}
outputJSON.Write(item)
written++
}
for _, item := range f.unindexedOutputItems {
if written > 0 {
outputJSON.WriteByte(',')
}
outputJSON.Write(item)
written++
}
outputJSON.WriteByte(']')
repaired, err := sjson.SetRawBytes(payload, "response.output", outputJSON.Bytes())
if err != nil {
return payload
}
return repaired
}
func responsesSSEFrameLen(chunk []byte) int {
if len(chunk) == 0 {
return 0
@@ -10,6 +10,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
func newResponsesStreamTestHandler(t *testing.T) (*OpenAIResponsesAPIHandler, *httptest.ResponseRecorder, *gin.Context, http.Flusher) {
@@ -53,12 +54,108 @@ func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) {
t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1)
}
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}"
expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"function_call\",\"arguments\":\"{}\"}]}}"
if parts[1] != expectedPart2 {
t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2)
}
}
func TestForwardResponsesStreamRepairsEmptyCompletedOutputFromDoneItems(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 3)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte(`data: {"type":"response.output_item.done","output_index":0,"item":{"type":"reasoning","id":"rs-1","summary":[]}}`)
data <- []byte(`data: {"type":"response.output_item.done","output_index":1,"item":{"type":"function_call","id":"fc-1","call_id":"call-1","name":"shell","arguments":"{\"cmd\":\"pwd\"}","status":"completed"}}`)
data <- []byte(`data: {"type":"response.completed","response":{"id":"resp-1","output":[]}}`)
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n")
if len(parts) != 3 {
t.Fatalf("expected 3 SSE events, got %d. Body: %q", len(parts), recorder.Body.String())
}
payload := strings.TrimPrefix(parts[2], "data: ")
output := gjson.Get(payload, "response.output")
if !output.IsArray() || len(output.Array()) != 2 {
t.Fatalf("expected repaired completed output with 2 items, got %s", output.Raw)
}
if got := gjson.Get(payload, "response.output.1.name").String(); got != "shell" {
t.Fatalf("expected function_call name to be preserved, got %q in %s", got, payload)
}
if got := gjson.Get(payload, "response.output.1.arguments").String(); got != `{"cmd":"pwd"}` {
t.Fatalf("expected function_call arguments to be preserved, got %q in %s", got, payload)
}
}
func TestForwardResponsesStreamRepairsMixedIndexedAndUnindexedDoneItems(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 3)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte(`data: {"type":"response.output_item.done","output_index":1,"item":{"type":"function_call","id":"fc-1","call_id":"call-1","name":"shell","arguments":"{}","status":"completed"}}`)
data <- []byte(`data: {"type":"response.output_item.done","item":{"type":"message","id":"msg-1","role":"assistant","content":[{"type":"output_text","text":"done"}]}}`)
data <- []byte(`data: {"type":"response.completed","response":{"id":"resp-1","output":[]}}`)
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n")
if len(parts) != 3 {
t.Fatalf("expected 3 SSE events, got %d. Body: %q", len(parts), recorder.Body.String())
}
payload := strings.TrimPrefix(parts[2], "data: ")
output := gjson.Get(payload, "response.output")
if !output.IsArray() || len(output.Array()) != 2 {
t.Fatalf("expected repaired completed output with 2 items, got %s", output.Raw)
}
if got := gjson.Get(payload, "response.output.0.name").String(); got != "shell" {
t.Fatalf("expected indexed function_call to be preserved first, got %q in %s", got, payload)
}
if got := gjson.Get(payload, "response.output.1.id").String(); got != "msg-1" {
t.Fatalf("expected unindexed message to be appended, got %q in %s", got, payload)
}
}
func TestForwardResponsesStreamRepairsMultilineCompletedOutputAsSSEDataLines(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
data := make(chan []byte, 2)
errs := make(chan *interfaces.ErrorMessage)
data <- []byte(`data: {"type":"response.output_item.done","item":{"type":"function_call","arguments":"{}"}}`)
data <- []byte("data: {\"type\":\"response.completed\",\ndata: \"response\":{\"id\":\"resp-1\",\"output\":[]}}\n\n")
close(data)
close(errs)
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs, nil)
parts := strings.Split(strings.TrimSpace(recorder.Body.String()), "\n\n")
if len(parts) != 2 {
t.Fatalf("expected 2 SSE events, got %d. Body: %q", len(parts), recorder.Body.String())
}
completedFrame := []byte(parts[1])
for _, line := range strings.Split(parts[1], "\n") {
if line != "" && !strings.HasPrefix(line, "data: ") {
t.Fatalf("expected every completed payload line to be an SSE data line, got %q in %q", line, parts[1])
}
}
payload, ok := responsesSSEDataPayload(completedFrame)
if !ok {
t.Fatalf("expected completed frame to contain data payload: %q", parts[1])
}
output := gjson.GetBytes(payload, "response.output")
if !output.IsArray() || len(output.Array()) != 1 {
t.Fatalf("expected repaired completed output with 1 item, got %s from %q", output.Raw, payload)
}
}
func TestForwardResponsesStreamReassemblesSplitSSEEventChunks(t *testing.T) {
h, recorder, c, flusher := newResponsesStreamTestHandler(t)
@@ -79,6 +79,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
var lastRequest []byte
lastResponseOutput := []byte("[]")
pinnedAuthID := ""
forceTranscriptReplayNextRequest := false
for {
msgType, payload, errReadMessage := conn.ReadMessage()
@@ -115,6 +116,22 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
}
if forceTranscriptReplayNextRequest {
allowIncrementalInputWithPreviousResponseID = false
}
allowCompactionReplayBypass := false
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
allowCompactionReplayBypass = responsesWebsocketAuthSupportsCompactionReplay(pinnedAuth)
}
} else {
requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if requestModelName == "" {
requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
}
allowCompactionReplayBypass = h.websocketUpstreamSupportsCompactionReplayForModel(requestModelName)
}
var requestJSON []byte
var updatedLastRequest []byte
@@ -124,6 +141,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
lastRequest,
lastResponseOutput,
allowIncrementalInputWithPreviousResponseID,
allowCompactionReplayBypass,
)
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
@@ -165,7 +183,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
updatedLastRequest = bytes.Clone(requestJSON)
previousLastRequest := bytes.Clone(lastRequest)
previousLastResponseOutput := bytes.Clone(lastResponseOutput)
forcedTranscriptReplay := forceTranscriptReplayNextRequest
lastRequest = updatedLastRequest
if forcedTranscriptReplay {
forceTranscriptReplayNextRequest = false
}
modelName := gjson.GetBytes(requestJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
@@ -190,12 +214,19 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
}
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
completedOutput, forwardErrMsg, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
if errForward != nil {
wsTerminateErr = errForward
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return
}
if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) {
pinnedAuthID = ""
forceTranscriptReplayNextRequest = true
lastRequest = previousLastRequest
lastResponseOutput = previousLastResponseOutput
continue
}
lastResponseOutput = completedOutput
}
}
@@ -222,10 +253,10 @@ func websocketUpgradeHeaders(req *http.Request) http.Header {
}
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true, true)
}
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
switch requestType {
case wsRequestTypeCreate:
@@ -233,10 +264,10 @@ func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []by
if len(lastRequest) == 0 {
return normalizeResponseCreateRequest(rawJSON)
}
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
case wsRequestTypeAppend:
// log.Infof("responses websocket: response.append request")
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
default:
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
@@ -265,7 +296,7 @@ func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces
return normalized, bytes.Clone(normalized), nil
}
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
if len(lastRequest) == 0 {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
@@ -315,8 +346,24 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
}
}
// When the client sends a compact replay for a downstream that can consume it
// directly, the input already carries the canonical history. In that case,
// skip merging with stale lastRequest/lastResponseOutput to avoid breaking
// function_call / function_call_output pairings.
// See: https://github.com/router-for-me/CLIProxyAPI/issues/2207
var mergedInput string
if allowCompactionReplayBypass && inputContainsFullTranscript(nextInput) {
log.Infof("responses websocket: full transcript detected, skipping stale merge (input items=%d)", len(nextInput.Array()))
mergedInput = nextInput.Raw
} else {
appendInputRaw := nextInput.Raw
if inputContainsFullTranscript(nextInput) {
appendInputRaw = inputWithoutCompactionItems(nextInput)
}
existingInput := gjson.GetBytes(lastRequest, "input")
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
var errMerge error
mergedInput, errMerge = mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
if errMerge != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
@@ -324,13 +371,14 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
}
}
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, appendInputRaw)
if errMerge != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("invalid request input: %w", errMerge),
}
}
}
dedupedInput, errDedupeFunctionCalls := dedupeFunctionCallsByCallID(mergedInput)
if errDedupeFunctionCalls == nil {
mergedInput = dedupedInput
@@ -480,73 +528,105 @@ func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, met
}
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
if h == nil || h.AuthManager == nil {
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
for _, auth := range auths {
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
return true
}
}
return false
}
resolvedModelName := modelName
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsCompactionReplayForModel(modelName string) bool {
auths, _ := h.responsesWebsocketAvailableAuthsForModel(modelName)
if len(auths) == 0 {
return false
}
for _, auth := range auths {
if !responsesWebsocketAuthSupportsCompactionReplay(auth) {
return false
}
}
return true
}
func (h *OpenAIResponsesAPIHandler) responsesWebsocketAvailableAuthsForModel(modelName string) ([]*coreauth.Auth, string) {
if h == nil || h.AuthManager == nil {
return nil, ""
}
resolvedModelName := responsesWebsocketResolvedModelName(modelName)
providerSet, modelKey := responsesWebsocketProviderSetForModel(resolvedModelName)
if len(providerSet) == 0 {
return nil, modelKey
}
registryRef := registry.GetGlobalRegistry()
now := time.Now()
auths := h.AuthManager.List()
available := make([]*coreauth.Auth, 0, len(auths))
for _, auth := range auths {
if !responsesWebsocketAuthMatchesModel(auth, providerSet, modelKey, registryRef, now) {
continue
}
available = append(available, auth)
}
return available, modelKey
}
func responsesWebsocketResolvedModelName(modelName string) string {
initialSuffix := thinking.ParseSuffix(modelName)
if initialSuffix.ModelName == "auto" {
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
if initialSuffix.HasSuffix {
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
} else {
resolvedModelName = resolvedBase
return fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
}
} else {
resolvedModelName = util.ResolveAutoModel(modelName)
return resolvedBase
}
return util.ResolveAutoModel(modelName)
}
func responsesWebsocketProviderSetForModel(resolvedModelName string) (map[string]struct{}, string) {
parsed := thinking.ParseSuffix(resolvedModelName)
baseModel := strings.TrimSpace(parsed.ModelName)
providers := util.GetProviderName(baseModel)
if len(providers) == 0 && baseModel != resolvedModelName {
providers = util.GetProviderName(resolvedModelName)
}
if len(providers) == 0 {
return false
}
providerSet := make(map[string]struct{}, len(providers))
for i := 0; i < len(providers); i++ {
providerKey := strings.TrimSpace(strings.ToLower(providers[i]))
for _, provider := range providers {
providerKey := strings.TrimSpace(strings.ToLower(provider))
if providerKey == "" {
continue
}
providerSet[providerKey] = struct{}{}
}
if len(providerSet) == 0 {
return false
}
modelKey := baseModel
if modelKey == "" {
modelKey = strings.TrimSpace(resolvedModelName)
}
registryRef := registry.GetGlobalRegistry()
now := time.Now()
auths := h.AuthManager.List()
for i := 0; i < len(auths); i++ {
auth := auths[i]
return providerSet, modelKey
}
func responsesWebsocketAuthMatchesModel(auth *coreauth.Auth, providerSet map[string]struct{}, modelKey string, registryRef *registry.ModelRegistry, now time.Time) bool {
if auth == nil {
continue
return false
}
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
if _, ok := providerSet[providerKey]; !ok {
continue
return false
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
continue
}
if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
continue
}
if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
return true
}
}
return false
}
return responsesWebsocketAuthAvailableForModel(auth, modelKey, now)
}
func responsesWebsocketAuthSupportsCompactionReplay(auth *coreauth.Auth) bool {
if auth == nil {
return false
}
return strings.EqualFold(strings.TrimSpace(auth.Provider), "codex")
}
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
if auth == nil {
@@ -691,6 +771,42 @@ func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
return string(out), nil
}
// inputContainsFullTranscript returns true when the input array carries compact
// replay markers that indicate the client already sent the full conversation
// transcript. Merging that input with stale lastRequest/lastResponseOutput
// would duplicate or break function_call/function_call_output pairings, so the
// caller should use the input as-is.
//
// Assistant messages alone are not enough to classify the payload as a replay:
// incremental websocket requests may legitimately append assistant items.
func inputContainsFullTranscript(input gjson.Result) bool {
if !input.IsArray() {
return false
}
for _, item := range input.Array() {
t := item.Get("type").String()
if t == "compaction" || t == "compaction_summary" {
return true
}
}
return false
}
func inputWithoutCompactionItems(input gjson.Result) string {
if !input.IsArray() {
return normalizeJSONArrayRaw([]byte(input.Raw))
}
filtered := make([]string, 0, len(input.Array()))
for _, item := range input.Array() {
t := item.Get("type").String()
if t == "compaction" || t == "compaction_summary" {
continue
}
filtered = append(filtered, item.Raw)
}
return "[" + strings.Join(filtered, ",") + "]"
}
func normalizeJSONArrayRaw(raw []byte) string {
trimmed := strings.TrimSpace(string(raw))
if trimmed == "" {
@@ -711,7 +827,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errs <-chan *interfaces.ErrorMessage,
wsTimelineLog *strings.Builder,
sessionID string,
) ([]byte, error) {
) ([]byte, *interfaces.ErrorMessage, error) {
completed := false
completedOutput := []byte("[]")
downstreamSessionKey := ""
@@ -723,7 +839,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return completedOutput, c.Request.Context().Err()
return completedOutput, nil, c.Request.Context().Err()
case errMsg, ok := <-errs:
if !ok {
errs = nil
@@ -748,7 +864,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
// errWrite,
// )
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, errMsg, errWrite
}
}
if errMsg != nil {
@@ -756,7 +872,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
} else {
cancel(nil)
}
return completedOutput, nil
return completedOutput, errMsg, nil
case chunk, ok := <-data:
if !ok {
if !completed {
@@ -782,13 +898,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errMsg.Error)
return completedOutput, errWrite
return completedOutput, errMsg, errWrite
}
cancel(errMsg.Error)
return completedOutput, nil
return completedOutput, errMsg, nil
}
cancel(nil)
return completedOutput, nil
return completedOutput, nil, nil
}
payloads := websocketJSONPayloadsFromChunk(chunk)
@@ -815,13 +931,31 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
errWrite,
)
cancel(errWrite)
return completedOutput, errWrite
return completedOutput, nil, errWrite
}
}
}
}
}
func shouldReleaseResponsesWebsocketPinnedAuth(errMsg *interfaces.ErrorMessage) bool {
if errMsg == nil {
return false
}
status := errMsg.StatusCode
if status <= 0 && errMsg.Error != nil {
if se, ok := errMsg.Error.(interface{ StatusCode() int }); ok && se != nil {
status = se.StatusCode()
}
}
switch status {
case http.StatusUnauthorized, http.StatusPaymentRequired, http.StatusForbidden, http.StatusTooManyRequests:
return true
default:
return false
}
}
func responseCompletedOutputFromPayload(payload []byte) []byte {
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && output.IsArray() {
@@ -69,6 +69,22 @@ type websocketAuthCaptureExecutor struct {
authIDs []string
}
type websocketPinnedFailoverExecutor struct {
mu sync.Mutex
authIDs []string
calls map[string]int
payloads map[string][][]byte
}
type websocketPinnedFailoverStatusError struct {
status int
msg string
}
func (e websocketPinnedFailoverStatusError) Error() string { return e.msg }
func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status }
func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" }
func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -106,6 +122,76 @@ func (e *websocketAuthCaptureExecutor) AuthIDs() []string {
return append([]string(nil), e.authIDs...)
}
func (e *websocketPinnedFailoverExecutor) Identifier() string { return "test-provider" }
func (e *websocketPinnedFailoverExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketPinnedFailoverExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
authID := ""
if auth != nil {
authID = auth.ID
}
e.mu.Lock()
if e.calls == nil {
e.calls = make(map[string]int)
}
if e.payloads == nil {
e.payloads = make(map[string][][]byte)
}
e.authIDs = append(e.authIDs, authID)
e.calls[authID]++
call := e.calls[authID]
e.payloads[authID] = append(e.payloads[authID], bytes.Clone(req.Payload))
e.mu.Unlock()
if authID == "auth-a" && call == 2 {
chunks := make(chan coreexecutor.StreamChunk, 1)
chunks <- coreexecutor.StreamChunk{Err: websocketPinnedFailoverStatusError{
status: http.StatusTooManyRequests,
msg: `{"error":{"message":"quota exhausted","type":"rate_limit_error","code":"rate_limit_exceeded"}}`,
}}
close(chunks)
return &coreexecutor.StreamResult{Chunks: chunks}, nil
}
chunks := make(chan coreexecutor.StreamChunk, 1)
chunks <- coreexecutor.StreamChunk{Payload: []byte(fmt.Sprintf(`{"type":"response.completed","response":{"id":"resp-%s-%d","output":[{"type":"message","id":"out-%s-%d"}]}}`, authID, call, authID, call))}
close(chunks)
return &coreexecutor.StreamResult{Chunks: chunks}, nil
}
func (e *websocketPinnedFailoverExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *websocketPinnedFailoverExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, errors.New("not implemented")
}
func (e *websocketPinnedFailoverExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
return nil, errors.New("not implemented")
}
func (e *websocketPinnedFailoverExecutor) AuthIDs() []string {
e.mu.Lock()
defer e.mu.Unlock()
return append([]string(nil), e.authIDs...)
}
func (e *websocketPinnedFailoverExecutor) Payloads(authID string) [][]byte {
e.mu.Lock()
defer e.mu.Unlock()
src := e.payloads[authID]
out := make([][]byte, len(src))
for i := range src {
out[i] = bytes.Clone(src[i])
}
return out
}
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -242,7 +328,7 @@ func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *
]`)
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true)
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true, false)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
@@ -278,7 +364,7 @@ func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncre
]`)
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false)
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, false)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
@@ -681,7 +767,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
close(errCh)
var timelineLog strings.Builder
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
completedOutput, errMsg, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
@@ -694,6 +780,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
serverErrCh <- err
return
}
if errMsg != nil {
serverErrCh <- fmt.Errorf("unexpected websocket error message: %v", errMsg.Error)
return
}
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
serverErrCh <- errors.New("completed output not captured")
return
@@ -760,7 +850,7 @@ func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing
return
}
_, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
_, _, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
ctx,
conn,
func(...interface{}) {},
@@ -867,6 +957,53 @@ func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
}
}
func TestWebsocketUpstreamSupportsCompactionReplayForModel(t *testing.T) {
manager := coreauth.NewManager(nil, nil, nil)
auth := &coreauth.Auth{
ID: "auth-codex",
Provider: "codex",
Status: coreauth.StatusActive,
}
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
if !h.websocketUpstreamSupportsCompactionReplayForModel("test-model") {
t.Fatalf("expected codex upstream to support compaction replay")
}
}
func TestWebsocketUpstreamSupportsCompactionReplayForModelFalseWhenMixedBackends(t *testing.T) {
manager := coreauth.NewManager(nil, nil, nil)
auths := []*coreauth.Auth{
{ID: "auth-codex", Provider: "codex", Status: coreauth.StatusActive},
{ID: "auth-claude", Provider: "claude", Status: coreauth.StatusActive},
}
for _, auth := range auths {
if _, err := manager.Register(context.Background(), auth); err != nil {
t.Fatalf("Register auth %s: %v", auth.ID, err)
}
registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
}
t.Cleanup(func() {
for _, auth := range auths {
registry.GetGlobalRegistry().UnregisterClient(auth.ID)
}
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
if h.websocketUpstreamSupportsCompactionReplayForModel("test-model") {
t.Fatalf("expected mixed backend model to disable compaction replay bypass")
}
}
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -1066,6 +1203,99 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
}
}
func TestResponsesWebsocketReleasesPinnedAuthAfterQuotaError(t *testing.T) {
gin.SetMode(gin.TestMode)
selector := &orderedWebsocketSelector{order: []string{"auth-a", "auth-b"}}
executor := &websocketPinnedFailoverExecutor{}
manager := coreauth.NewManager(nil, selector, nil)
manager.RegisterExecutor(executor)
authA := &coreauth.Auth{
ID: "auth-a",
Provider: executor.Identifier(),
Status: coreauth.StatusActive,
Attributes: map[string]string{"websockets": "true"},
}
if _, err := manager.Register(context.Background(), authA); err != nil {
t.Fatalf("Register auth A: %v", err)
}
authB := &coreauth.Auth{
ID: "auth-b",
Provider: executor.Identifier(),
Status: coreauth.StatusActive,
Attributes: map[string]string{"websockets": "true"},
}
if _, err := manager.Register(context.Background(), authB); err != nil {
t.Fatalf("Register auth B: %v", err)
}
registry.GetGlobalRegistry().RegisterClient(authA.ID, authA.Provider, []*registry.ModelInfo{{ID: "quota-model"}})
registry.GetGlobalRegistry().RegisterClient(authB.ID, authB.Provider, []*registry.ModelInfo{{ID: "quota-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(authA.ID)
registry.GetGlobalRegistry().UnregisterClient(authB.ID)
})
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
h := NewOpenAIResponsesAPIHandler(base)
router := gin.New()
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
server := httptest.NewServer(router)
defer server.Close()
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
t.Fatalf("dial websocket: %v", err)
}
defer func() {
if errClose := conn.Close(); errClose != nil {
t.Fatalf("close websocket: %v", errClose)
}
}()
requests := []string{
`{"type":"response.create","model":"quota-model","input":[{"type":"message","id":"msg-1"}]}`,
`{"type":"response.create","previous_response_id":"resp-auth-a-1","input":[{"type":"message","id":"msg-2"}]}`,
`{"type":"response.create","previous_response_id":"resp-auth-a-1","input":[{"type":"message","id":"msg-3"}]}`,
}
wantTypes := []string{wsEventTypeCompleted, wsEventTypeError, wsEventTypeCompleted}
for i := range requests {
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
t.Fatalf("write websocket message %d: %v", i+1, errWrite)
}
_, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
}
if got := gjson.GetBytes(payload, "type").String(); got != wantTypes[i] {
t.Fatalf("message %d payload type = %s, want %s: %s", i+1, got, wantTypes[i], payload)
}
if i == 1 && int(gjson.GetBytes(payload, "status").Int()) != http.StatusTooManyRequests {
t.Fatalf("quota payload status = %d, want %d: %s", gjson.GetBytes(payload, "status").Int(), http.StatusTooManyRequests, payload)
}
}
if got := executor.AuthIDs(); len(got) != 3 || got[0] != "auth-a" || got[1] != "auth-a" || got[2] != "auth-b" {
t.Fatalf("selected auth IDs = %v, want [auth-a auth-a auth-b]", got)
}
authBPayloads := executor.Payloads("auth-b")
if len(authBPayloads) != 1 {
t.Fatalf("auth-b payload count = %d, want 1", len(authBPayloads))
}
authBPayload := authBPayloads[0]
if gjson.GetBytes(authBPayload, "previous_response_id").Exists() {
t.Fatalf("previous_response_id leaked after auth failover: %s", authBPayload)
}
authBInput := gjson.GetBytes(authBPayload, "input").Raw
if !strings.Contains(authBInput, `"id":"msg-1"`) || !strings.Contains(authBInput, `"id":"msg-3"`) {
t.Fatalf("auth-b replay input missing expected transcript items: %s", authBInput)
}
}
func TestNormalizeResponsesWebsocketRequestTreatsTranscriptReplacementAsReset(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
lastResponseOutput := []byte(`[
@@ -1400,3 +1630,171 @@ func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *t
t.Fatalf("post-compact function call id = %s, want call-1", items[0].Get("call_id").String())
}
}
func TestInputContainsFullTranscriptFalseForAssistantMessageOnly(t *testing.T) {
input := gjson.Parse(`[
{"type":"message","role":"user","content":"hello"},
{"type":"message","role":"assistant","content":"hi there"}
]`)
if inputContainsFullTranscript(input) {
t.Fatal("assistant message alone must not be treated as full transcript")
}
}
func TestInputContainsFullTranscriptDetectsCompactionItem(t *testing.T) {
for _, typ := range []string{"compaction", "compaction_summary"} {
input := gjson.Parse(`[{"type":"message","role":"user","content":"hello"},{"type":"` + typ + `","encrypted_content":"summary"}]`)
if !inputContainsFullTranscript(input) {
t.Fatalf("expected full transcript for type=%s", typ)
}
}
}
func TestInputContainsFullTranscriptFalseForIncremental(t *testing.T) {
// Normal incremental turns: user messages or function_call_output only.
for _, raw := range []string{
`[{"type":"function_call_output","call_id":"call-1","output":"result"}]`,
`[{"type":"message","role":"user","content":"next question"}]`,
`[]`,
} {
if inputContainsFullTranscript(gjson.Parse(raw)) {
t.Fatalf("incremental input must not be detected as full transcript: %s", raw)
}
}
}
func TestNormalizeSubsequentRequestCompactSkipsMerge(t *testing.T) {
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
{"type":"message","role":"user","id":"msg-1","content":"original long prompt"},
{"type":"message","role":"assistant","id":"msg-2","content":"original long response"},
{"type":"function_call","id":"fc-1","call_id":"call-old","name":"bash","arguments":"{}"},
{"type":"function_call_output","id":"fco-1","call_id":"call-old","output":"old result"}
]}`)
lastResponseOutput := []byte(`[
{"type":"message","role":"assistant","id":"msg-3","content":"another assistant reply"},
{"type":"function_call","id":"fc-2","call_id":"call-stale","name":"read","arguments":"{}"}
]`)
// Remote compact response: user messages + compaction item, NO assistant message.
// This is the primary compact scenario from Codex CLI.
raw := []byte(`{"type":"response.create","input":[
{"type":"message","role":"user","id":"msg-1c","content":"compacted user msg"},
{"type":"compaction","encrypted_content":"conversation summary"}
]}`)
normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 2 {
t.Fatalf("input len = %d, want 2 (compacted only); stale state was not skipped", len(input))
}
if input[0].Get("id").String() != "msg-1c" {
t.Fatalf("input[0].id = %q, want %q", input[0].Get("id").String(), "msg-1c")
}
if input[1].Get("type").String() != "compaction" {
t.Fatalf("input[1].type = %q, want %q", input[1].Get("type").String(), "compaction")
}
}
func TestNormalizeSubsequentRequestCompactMergesWhenCompactionReplayUnsupported(t *testing.T) {
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
{"type":"message","role":"user","id":"msg-1","content":"original long prompt"},
{"type":"message","role":"assistant","id":"msg-2","content":"original long response"},
{"type":"function_call","id":"fc-1","call_id":"call-old","name":"bash","arguments":"{}"},
{"type":"function_call_output","id":"fco-1","call_id":"call-old","output":"old result"}
]}`)
lastResponseOutput := []byte(`[
{"type":"message","role":"assistant","id":"msg-3","content":"another assistant reply"},
{"type":"function_call","id":"fc-2","call_id":"call-stale","name":"read","arguments":"{}"}
]`)
raw := []byte(`{"type":"response.create","input":[
{"type":"message","role":"user","id":"msg-1c","content":"compacted user msg"},
{"type":"compaction","encrypted_content":"conversation summary"}
]}`)
normalized, _, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, false)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 7 {
t.Fatalf("input len = %d, want 7 (merged fallback without compaction items)", len(input))
}
wantIDs := []string{"msg-1", "msg-2", "fc-1", "fco-1", "msg-3", "fc-2", "msg-1c"}
for i, want := range wantIDs {
got := input[i].Get("id").String()
if got != want {
t.Fatalf("input[%d].id = %q, want %q", i, got, want)
}
}
for _, item := range input {
if item.Get("type").String() == "compaction" || item.Get("type").String() == "compaction_summary" {
t.Fatalf("compaction items must be stripped for unsupported downstream fallback: %s", item.Raw)
}
}
}
func TestNormalizeSubsequentRequestIncrementalInputStillMerges(t *testing.T) {
// Normal incremental flow: user sends function_call_output (no assistant message).
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
{"type":"message","role":"user","id":"msg-1","content":"hello"}
]}`)
lastResponseOutput := []byte(`[
{"type":"message","role":"assistant","id":"msg-2","content":"let me check"},
{"type":"function_call","id":"fc-1","call_id":"call-1","name":"bash","arguments":"{}"}
]`)
raw := []byte(`{"type":"response.create","input":[
{"type":"function_call_output","call_id":"call-1","id":"fco-1","output":"done"}
]}`)
normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
input := gjson.GetBytes(normalized, "input").Array()
// Should be merged: msg-1 + msg-2 + fc-1 + fco-1 = 4 items
if len(input) != 4 {
t.Fatalf("input len = %d, want 4 (merged)", len(input))
}
wantIDs := []string{"msg-1", "msg-2", "fc-1", "fco-1"}
for i, want := range wantIDs {
got := input[i].Get("id").String()
if got != want {
t.Fatalf("input[%d].id = %q, want %q", i, got, want)
}
}
}
func TestNormalizeSubsequentRequestAssistantInputTriggersTranscriptReplacement(t *testing.T) {
// After dev's shouldReplaceWebsocketTranscript, assistant messages in input
// trigger transcript replacement (no merge with prior state).
lastRequest := []byte(`{"model":"gpt-5.4","stream":true,"input":[
{"type":"message","role":"user","id":"msg-1","content":"hello"}
]}`)
lastResponseOutput := []byte(`[
{"type":"message","role":"assistant","id":"msg-2","content":"prior assistant"},
{"type":"function_call","id":"fc-1","call_id":"call-1","name":"bash","arguments":"{}"}
]`)
raw := []byte(`{"type":"response.append","input":[
{"type":"message","role":"assistant","id":"msg-3","content":"patched assistant turn"}
]}`)
normalized, _, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 1 {
t.Fatalf("input len = %d, want 1 (transcript replacement, not merge)", len(input))
}
if input[0].Get("id").String() != "msg-3" {
t.Fatalf("input[0].id = %q, want %q", input[0].Get("id").String(), "msg-3")
}
}
+44
View File
@@ -22,6 +22,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
log "github.com/sirupsen/logrus"
)
@@ -827,6 +828,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
if executor == nil {
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
}
ctx = contextWithRequestedModelAlias(ctx, opts, routeModel)
var lastErr error
for idx, execModel := range execModels {
resultModel := m.stateModelForExecution(auth, routeModel, execModel, pooled)
@@ -1126,6 +1128,9 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
auth.Index = existing.Index
auth.indexAssigned = existing.indexAssigned
}
auth.Success = existing.Success
auth.Failed = existing.Failed
auth.recentRequests = existing.recentRequests
if !existing.Disabled && existing.Status != StatusDisabled && !auth.Disabled && auth.Status != StatusDisabled {
if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
auth.ModelStates = existing.ModelStates
@@ -1316,6 +1321,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execCtx = contextWithRequestedModelAlias(execCtx, opts, routeModel)
models, pooled := m.preparedExecutionModels(auth, routeModel)
if len(models) == 0 {
@@ -1394,6 +1400,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execCtx = contextWithRequestedModelAlias(execCtx, opts, routeModel)
models, pooled := m.preparedExecutionModels(auth, routeModel)
if len(models) == 0 {
@@ -1531,6 +1538,36 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
}
}
func contextWithRequestedModelAlias(ctx context.Context, opts cliproxyexecutor.Options, fallback string) context.Context {
alias := requestedModelAliasFromOptions(opts, fallback)
return coreusage.WithRequestedModelAlias(ctx, alias)
}
func requestedModelAliasFromOptions(opts cliproxyexecutor.Options, fallback string) string {
fallback = strings.TrimSpace(fallback)
if len(opts.Metadata) == 0 {
return fallback
}
raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
if !ok || raw == nil {
return fallback
}
switch value := raw.(type) {
case string:
if strings.TrimSpace(value) == "" {
return fallback
}
return strings.TrimSpace(value)
case []byte:
if len(value) == 0 {
return fallback
}
return strings.TrimSpace(string(value))
default:
return fallback
}
}
func pinnedAuthIDFromMetadata(meta map[string]any) string {
if len(meta) == 0 {
return ""
@@ -2021,6 +2058,12 @@ func (m *Manager) MarkResult(ctx context.Context, result Result) {
m.mu.Lock()
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
now := time.Now()
auth.recordRecentRequest(now, result.Success)
if result.Success {
auth.Success++
} else {
auth.Failed++
}
if result.Success {
if result.Model != "" {
@@ -3087,6 +3130,7 @@ func (m *Manager) tryAntigravityCreditsExecute(ctx context.Context, req cliproxy
creditsCtx = context.WithValue(creditsCtx, "cliproxy.roundtripper", rt)
}
creditsOpts := ensureRequestedModelMetadata(opts, routeModel)
creditsCtx = contextWithRequestedModelAlias(creditsCtx, creditsOpts, routeModel)
publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID)
models := m.executionModelCandidates(c.auth, routeModel)
if len(models) == 0 {
@@ -10,6 +10,7 @@ import (
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
type aliasRoutingExecutor struct {
@@ -17,13 +18,15 @@ type aliasRoutingExecutor struct {
mu sync.Mutex
executeModels []string
executeAliases []string
}
func (e *aliasRoutingExecutor) Identifier() string { return e.id }
func (e *aliasRoutingExecutor) Execute(_ context.Context, _ *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
func (e *aliasRoutingExecutor) Execute(ctx context.Context, _ *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
e.mu.Lock()
e.executeModels = append(e.executeModels, req.Model)
e.executeAliases = append(e.executeAliases, coreusage.RequestedModelAliasFromContext(ctx))
e.mu.Unlock()
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
}
@@ -52,6 +55,14 @@ func (e *aliasRoutingExecutor) ExecuteModels() []string {
return out
}
func (e *aliasRoutingExecutor) ExecuteAliases() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.executeAliases))
copy(out, e.executeAliases)
return out
}
func TestManagerExecute_OAuthAliasBypassesBlockedRouteModel(t *testing.T) {
const (
provider = "antigravity"
@@ -108,4 +119,12 @@ func TestManagerExecute_OAuthAliasBypassesBlockedRouteModel(t *testing.T) {
if gotModels[0] != targetModel {
t.Fatalf("execute model = %q, want %q", gotModels[0], targetModel)
}
gotAliases := executor.ExecuteAliases()
if len(gotAliases) != 1 {
t.Fatalf("execute aliases len = %d, want 1", len(gotAliases))
}
if gotAliases[0] != routeModel {
t.Fatalf("execute alias = %q, want %q", gotAliases[0], routeModel)
}
}
@@ -0,0 +1,95 @@
package auth
import (
"context"
"testing"
"time"
)
func TestManagerMarkResultRecordsRecentRequests(t *testing.T) {
mgr := NewManager(nil, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "antigravity",
Attributes: map[string]string{
"runtime_only": "true",
},
Metadata: map[string]any{
"type": "antigravity",
},
}
if _, err := mgr.Register(WithSkipPersist(context.Background()), auth); err != nil {
t.Fatalf("Register returned error: %v", err)
}
mgr.MarkResult(context.Background(), Result{AuthID: "auth-1", Provider: "antigravity", Model: "gpt-5", Success: true})
mgr.MarkResult(context.Background(), Result{AuthID: "auth-1", Provider: "antigravity", Model: "gpt-5", Success: false})
gotAuth, ok := mgr.GetByID("auth-1")
if !ok || gotAuth == nil {
t.Fatalf("GetByID returned ok=%v auth=%v", ok, gotAuth)
}
if gotAuth.Success != 1 || gotAuth.Failed != 1 {
t.Fatalf("auth totals = success=%d failed=%d, want 1/1", gotAuth.Success, gotAuth.Failed)
}
snapshot := gotAuth.RecentRequestsSnapshot(time.Now())
var successTotal int64
var failedTotal int64
for _, bucket := range snapshot {
successTotal += bucket.Success
failedTotal += bucket.Failed
}
if successTotal != 1 || failedTotal != 1 {
t.Fatalf("totals = success=%d failed=%d, want 1/1", successTotal, failedTotal)
}
}
func TestManagerUpdatePreservesRecentRequestsAndTotals(t *testing.T) {
mgr := NewManager(nil, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
},
}
if _, err := mgr.Register(WithSkipPersist(context.Background()), auth); err != nil {
t.Fatalf("Register returned error: %v", err)
}
mgr.MarkResult(context.Background(), Result{AuthID: "auth-1", Provider: "antigravity", Model: "gpt-5", Success: true})
updated := &Auth{
ID: "auth-1",
Provider: "antigravity",
Metadata: map[string]any{
"type": "antigravity",
"note": "updated",
},
}
if _, err := mgr.Update(WithSkipPersist(context.Background()), updated); err != nil {
t.Fatalf("Update returned error: %v", err)
}
gotAuth, ok := mgr.GetByID("auth-1")
if !ok || gotAuth == nil {
t.Fatalf("GetByID returned ok=%v auth=%v", ok, gotAuth)
}
if gotAuth.Success != 1 || gotAuth.Failed != 0 {
t.Fatalf("auth totals = success=%d failed=%d, want 1/0", gotAuth.Success, gotAuth.Failed)
}
snapshot := gotAuth.RecentRequestsSnapshot(time.Now())
var successTotal int64
var failedTotal int64
for _, bucket := range snapshot {
successTotal += bucket.Success
failedTotal += bucket.Failed
}
if successTotal != 1 || failedTotal != 0 {
t.Fatalf("bucket totals = success=%d failed=%d, want 1/0", successTotal, failedTotal)
}
}
+89
View File
@@ -92,9 +92,34 @@ type Auth struct {
// Runtime carries non-serialisable data used during execution (in-memory only).
Runtime any `json:"-"`
Success int64 `json:"-"`
Failed int64 `json:"-"`
recentRequests recentRequestRing `json:"-"`
indexAssigned bool `json:"-"`
}
const (
recentRequestBucketSeconds int64 = 10 * 60
recentRequestBucketCount = 20
)
type recentRequestBucket struct {
bucketID int64
success int64
failed int64
}
type recentRequestRing struct {
buckets [recentRequestBucketCount]recentRequestBucket
}
type RecentRequestBucket struct {
Time string `json:"time"`
Success int64 `json:"success"`
Failed int64 `json:"failed"`
}
// QuotaState contains limiter tracking data for a credential.
type QuotaState struct {
// Exceeded indicates the credential recently hit a quota error.
@@ -125,6 +150,70 @@ type ModelState struct {
UpdatedAt time.Time `json:"updated_at"`
}
func recentRequestBucketID(now time.Time) int64 {
if now.IsZero() {
return 0
}
return now.Unix() / recentRequestBucketSeconds
}
func recentRequestBucketIndex(bucketID int64) int {
mod := bucketID % int64(recentRequestBucketCount)
if mod < 0 {
mod += int64(recentRequestBucketCount)
}
return int(mod)
}
func formatRecentRequestBucketLabel(bucketID int64) string {
start := time.Unix(bucketID*recentRequestBucketSeconds, 0).In(time.Local)
end := start.Add(time.Duration(recentRequestBucketSeconds) * time.Second)
return start.Format("15:04") + "-" + end.Format("15:04")
}
func (a *Auth) recordRecentRequest(now time.Time, success bool) {
if a == nil {
return
}
bucketID := recentRequestBucketID(now)
idx := recentRequestBucketIndex(bucketID)
bucket := &a.recentRequests.buckets[idx]
if bucket.bucketID != bucketID {
bucket.bucketID = bucketID
bucket.success = 0
bucket.failed = 0
}
if success {
bucket.success++
return
}
bucket.failed++
}
func (a *Auth) RecentRequestsSnapshot(now time.Time) []RecentRequestBucket {
out := make([]RecentRequestBucket, 0, recentRequestBucketCount)
if a == nil {
return out
}
currentBucketID := recentRequestBucketID(now)
for i := recentRequestBucketCount - 1; i >= 0; i-- {
bucketID := currentBucketID - int64(i)
idx := recentRequestBucketIndex(bucketID)
bucket := a.recentRequests.buckets[idx]
entry := RecentRequestBucket{
Time: formatRecentRequestBucketLabel(bucketID),
}
if bucket.bucketID == bucketID {
entry.Success = bucket.success
entry.Failed = bucket.failed
}
out = append(out, entry)
}
return out
}
// Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation.
func (a *Auth) Clone() *Auth {
if a == nil {
+74 -1
View File
@@ -1,6 +1,10 @@
package auth
import "testing"
import (
"strings"
"testing"
"time"
)
func TestToolPrefixDisabled(t *testing.T) {
var a *Auth
@@ -96,3 +100,72 @@ func TestEnsureIndexUsesCredentialIdentity(t *testing.T) {
t.Fatalf("duplicate config entries should be separated by source-derived seed, got %q", geminiIndex)
}
}
func TestRecentRequestsSnapshotEmptyReturnsTwentyBuckets(t *testing.T) {
now := time.Unix(1_700_000_000, 0).In(time.Local)
a := &Auth{}
got := a.RecentRequestsSnapshot(now)
if len(got) != recentRequestBucketCount {
t.Fatalf("len = %d, want %d", len(got), recentRequestBucketCount)
}
currentBucketID := now.Unix() / recentRequestBucketSeconds
baseBucketID := currentBucketID - int64(recentRequestBucketCount-1)
for i, bucket := range got {
if bucket.Success != 0 || bucket.Failed != 0 {
t.Fatalf("bucket[%d] counts = %d/%d, want 0/0", i, bucket.Success, bucket.Failed)
}
if strings.TrimSpace(bucket.Time) == "" {
t.Fatalf("bucket[%d] time label is empty", i)
}
expectedBucketID := baseBucketID + int64(i)
start := time.Unix(expectedBucketID*recentRequestBucketSeconds, 0).In(time.Local)
end := start.Add(10 * time.Minute)
expected := start.Format("15:04") + "-" + end.Format("15:04")
if bucket.Time != expected {
t.Fatalf("bucket[%d] time = %q, want %q", i, bucket.Time, expected)
}
}
}
func TestRecentRequestsSnapshotIncludesCounts(t *testing.T) {
now := time.Unix(1_700_000_000, 0).In(time.Local)
a := &Auth{}
a.recordRecentRequest(now, true)
a.recordRecentRequest(now, false)
got := a.RecentRequestsSnapshot(now)
if len(got) != recentRequestBucketCount {
t.Fatalf("len = %d, want %d", len(got), recentRequestBucketCount)
}
newest := got[len(got)-1]
if newest.Success != 1 || newest.Failed != 1 {
t.Fatalf("newest bucket = success=%d failed=%d, want 1/1", newest.Success, newest.Failed)
}
}
func TestRecentRequestsSnapshotBucketAdvanceMovesCounts(t *testing.T) {
now := time.Unix(1_700_000_000, 0).In(time.Local)
next := now.Add(10 * time.Minute)
a := &Auth{}
a.recordRecentRequest(now, true)
a.recordRecentRequest(next, false)
got := a.RecentRequestsSnapshot(next)
if len(got) != recentRequestBucketCount {
t.Fatalf("len = %d, want %d", len(got), recentRequestBucketCount)
}
secondNewest := got[len(got)-2]
newest := got[len(got)-1]
if secondNewest.Success != 1 || secondNewest.Failed != 0 {
t.Fatalf("second newest bucket = success=%d failed=%d, want 1/0", secondNewest.Success, secondNewest.Failed)
}
if newest.Success != 0 || newest.Failed != 1 {
t.Fatalf("newest bucket = success=%d failed=%d, want 0/1", newest.Success, newest.Failed)
}
}
-1
View File
@@ -16,7 +16,6 @@ import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
+32
View File
@@ -2,6 +2,7 @@ package usage
import (
"context"
"strings"
"sync"
"time"
@@ -12,6 +13,7 @@ import (
type Record struct {
Provider string
Model string
Alias string
APIKey string
AuthID string
AuthIndex string
@@ -32,6 +34,36 @@ type Detail struct {
TotalTokens int64
}
type requestedModelAliasContextKey struct{}
// WithRequestedModelAlias stores the client-requested model name for usage sinks.
func WithRequestedModelAlias(ctx context.Context, alias string) context.Context {
if ctx == nil {
ctx = context.Background()
}
alias = strings.TrimSpace(alias)
if alias == "" {
return ctx
}
return context.WithValue(ctx, requestedModelAliasContextKey{}, alias)
}
// RequestedModelAliasFromContext returns the client-requested model name stored in ctx.
func RequestedModelAliasFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
raw := ctx.Value(requestedModelAliasContextKey{})
switch value := raw.(type) {
case string:
return strings.TrimSpace(value)
case []byte:
return strings.TrimSpace(string(value))
default:
return ""
}
}
// Plugin consumes usage records emitted by the proxy runtime.
type Plugin interface {
HandleUsage(ctx context.Context, record Record)
+49 -24
View File
@@ -2,6 +2,7 @@ package test
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
@@ -9,14 +10,14 @@ import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
runtimeexecutor "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
func TestGeminiExecutorRecordsSuccessfulZeroUsageInStatistics(t *testing.T) {
func TestGeminiExecutorRecordsSuccessfulZeroUsageInQueue(t *testing.T) {
model := fmt.Sprintf("gemini-2.5-flash-zero-usage-%d", time.Now().UnixNano())
source := fmt.Sprintf("zero-usage-%d@example.com", time.Now().UnixNano())
@@ -42,10 +43,15 @@ func TestGeminiExecutorRecordsSuccessfulZeroUsageInStatistics(t *testing.T) {
},
}
prevStatsEnabled := internalusage.StatisticsEnabled()
internalusage.SetStatisticsEnabled(true)
prevQueueEnabled := redisqueue.Enabled()
prevUsageEnabled := redisqueue.UsageStatisticsEnabled()
redisqueue.SetEnabled(false)
redisqueue.SetEnabled(true)
redisqueue.SetUsageStatisticsEnabled(true)
t.Cleanup(func() {
internalusage.SetStatisticsEnabled(prevStatsEnabled)
redisqueue.SetEnabled(false)
redisqueue.SetEnabled(prevQueueEnabled)
redisqueue.SetUsageStatisticsEnabled(prevUsageEnabled)
})
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
@@ -59,39 +65,58 @@ func TestGeminiExecutorRecordsSuccessfulZeroUsageInStatistics(t *testing.T) {
t.Fatalf("Execute error: %v", err)
}
detail := waitForStatisticsDetail(t, "gemini", model, source)
if detail.Failed {
t.Fatalf("detail failed = true, want false")
}
if detail.Tokens.TotalTokens != 0 {
t.Fatalf("total tokens = %d, want 0", detail.Tokens.TotalTokens)
}
waitForQueuedUsageModelTotalTokens(t, "gemini", model, 0)
}
func waitForStatisticsDetail(t *testing.T, apiName, model, source string) internalusage.RequestDetail {
func waitForQueuedUsageModelTotalTokens(t *testing.T, wantProvider, wantModel string, wantTokens int64) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
snapshot := internalusage.GetRequestStatistics().Snapshot()
apiSnapshot, ok := snapshot.APIs[apiName]
items := redisqueue.PopOldest(10)
for _, item := range items {
got, ok := parseQueuedUsagePayload(t, item)
if !ok {
time.Sleep(10 * time.Millisecond)
continue
}
modelSnapshot, ok := apiSnapshot.Models[model]
if !ok {
time.Sleep(10 * time.Millisecond)
if got.Provider != wantProvider || got.Model != wantModel {
continue
}
for _, detail := range modelSnapshot.Details {
if detail.Source == source {
return detail
if got.Failed {
t.Fatalf("payload failed = true, want false")
}
if got.Tokens.TotalTokens != wantTokens {
t.Fatalf("payload total tokens = %d, want %d", got.Tokens.TotalTokens, wantTokens)
}
return
}
time.Sleep(10 * time.Millisecond)
}
t.Fatalf("timed out waiting for statistics detail for api=%q model=%q source=%q", apiName, model, source)
return internalusage.RequestDetail{}
t.Fatalf("timed out waiting for queued usage payload for provider=%q model=%q", wantProvider, wantModel)
}
type queuedUsagePayload struct {
Provider string `json:"provider"`
Model string `json:"model"`
Failed bool `json:"failed"`
Tokens struct {
TotalTokens int64 `json:"total_tokens"`
} `json:"tokens"`
}
func parseQueuedUsagePayload(t *testing.T, payload []byte) (queuedUsagePayload, bool) {
t.Helper()
var parsed queuedUsagePayload
if len(payload) == 0 {
return parsed, false
}
if err := json.Unmarshal(payload, &parsed); err != nil {
return parsed, false
}
if parsed.Provider == "" || parsed.Model == "" {
return parsed, false
}
return parsed, true
}