Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | a9472dfdee | | |
| | 087045a5f1 | | |
| | aaec9194d5 | | |
| | 33f4904b25 | | |
| | cecd39317d | | |
| | 3c62a9a9b0 | | |
| | 21fad9dbb4 | | |
| | 48a1c88115 | | |
| | 8b9ecffc2f | | |
| | 42e9605871 | | |
| | a726e37394 | | |
| | f1ee883cd3 | | |
| | 1c632d151d | | |
| | 0ec07e57dd | | |
| | fdffe49974 | | |
| | de0394917a | | |
| | ea25949479 | | |
| | 99fa530967 | | |
| | b9589e8ed6 | | |
| | 0de0ad0d36 | | |
| | 5ef7693933 | | |
| | 7f68fa2414 | | |
| | bb5ac40a67 | | |
| | 7efc1629ba | | |
| | 67f22514ed | | |
| | ad868308c0 | | |
| | bbe30f53b5 | | |
| | feebe6c7f2 | | |
| | b67eb6f25d | | |
| | 644823529f | | |
| | bac006e72b | | |
| | ad98c9549a | | |
| | 77ba15f71b | | |
| | 32a0d69b17 | | |
| | 1583cb4ef0 | | |
| | cc0cb057b3 | | |
| | 2710f56ae1 | | |
| | 8bc2eff58a | | |
| | ec79951e7f | | |
| | 24602055a8 | | |
| | 4ad6ffefb7 | | |
| | 1c2153a2cb | | |
| | 64d233fe93 | | |
| | 66c5d60b3d | | |
| | 5f039654f0 | | |
| | ed0ac68324 | | |
| | d606faa99c | | |
| | bfdc0b3989 | | |
| | 809feb1e86 | | |
| | 33130f18d2 | | |
@@ -0,0 +1,5 @@
+# Cluster JWT example.
+# After deploying https://github.com/router-for-me/CLIProxyAPIHome, get the JWT value with:
+# curl -sS -X POST "http://<home-host>:8327/v0/management/certificates/clients" -H "X-MANAGEMENT-KEY: <management-key>" | jq -r '.home_jwt'
+# Then paste it into HOME_JWT here or export it before starting Compose.
+HOME_JWT=your-home-jwt-here
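The `curl | jq` one-liner in the comment above can also be expressed in Go. This is a minimal sketch, not code from the repository: the endpoint path, the `X-MANAGEMENT-KEY` header, and the `home_jwt` field are taken from that comment, while `fetchHomeJWT`, `demo`, the stand-in test server, and the `demo-key` / `example-jwt` values are hypothetical names used only for illustration.

```go
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"time"
)

// fetchHomeJWT POSTs to the certificates endpoint named in the env-file
// comment and returns the "home_jwt" field of the JSON response.
// The function itself is a hypothetical helper, not part of CLIProxyAPI.
func fetchHomeJWT(ctx context.Context, baseURL, managementKey string) (string, error) {
	endpoint := baseURL + "/v0/management/certificates/clients"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
	if err != nil {
		return "", err
	}
	req.Header.Set("X-MANAGEMENT-KEY", managementKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("unexpected status %s", resp.Status)
	}

	var body struct {
		HomeJWT string `json:"home_jwt"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
		return "", err
	}
	if body.HomeJWT == "" {
		return "", fmt.Errorf("response has no home_jwt field")
	}
	return body.HomeJWT, nil
}

// demo runs fetchHomeJWT against a local stand-in server so the sketch can
// be exercised without a real Home deployment.
func demo() (string, error) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost ||
			r.URL.Path != "/v0/management/certificates/clients" ||
			r.Header.Get("X-MANAGEMENT-KEY") != "demo-key" {
			http.Error(w, "forbidden", http.StatusForbidden)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(map[string]string{"home_jwt": "example-jwt"})
	}))
	defer srv.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return fetchHomeJWT(ctx, srv.URL, "demo-key")
}

func main() {
	jwt, err := demo()
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(jwt)
}
```

Against a real deployment, `baseURL` would be `http://<home-host>:8327` and the returned value is what goes into `HOME_JWT`.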
@@ -215,6 +215,7 @@ sudo /usr/local/bin/docker logs -f cli-proxy-api

 | Date | Version | Notes |
 |------|------|------|
+| 2026-05-24 | v7.1.20 | Patch from v7.1.10 to v7.1.20: fixed lost/duplicated Claude tool-use names (v7.1.12), system→developer handling in Claude request translation (v7.1.20), added the Gemini 3.5 Flash model (v7.1.18), Grok Build 0.1, Redis timeout/failover, xAI reasoning.effort. Zero downtime, no re-authentication required |
 | 2026-05-18 | v7.1.10 | Major upgrade v6→v7: new Home Control Plane (Redis), ClaudeCodeSessionAffinity removed, usage tracking removed (v6.10.0), xAI Grok image/video, Codex client models, local management password validation plus spoofed-IP rejection. Auth files remain compatible (no re-authentication required); all new config fields are optional |
 | 2026-05-04 | v6.10.4 | 69 commits: improved WebSocket compact handling, session affinity based on X-Amp-Thread-Id, stronger Codex reasoning/image handling, added the GPT-5.5 model, option to disable OpenAI-compatible providers. Zero-downtime update, no re-authentication required |
 | 2026-04-26 | v6.9.38 | Introduced the protocol multiplexer and Redis queue; added IP blocking after repeated management-key/Redis AUTH failures. Zero-downtime update, no re-authentication required |
@@ -32,9 +32,9 @@ PackyCode provides special discounts for our software users: register using <a h
 </tr>
 <tr>
 <td width="180"><a href="https://coder.visioncoder.cn"><img src="./assets/visioncoder.png" alt="VisionCoder" width="150"></a></td>
-<td>Thanks to VisionCoder for supporting this project. <a href="https://coder.visioncoder.cn" target="_blank">VisionCoder Developer Platform</a> is a reliable and efficient API relay service provider, offering access to mainstream AI models such as Claude Code, Codex, and Gemini. It helps developers and teams integrate AI capabilities more easily and improve productivity.
+<td>Thanks to <b>VisionCoder</b> for supporting this project. <a href="https://coder.visioncoder.cn" target="_blank">VisionCoder Developer Platform</a> is a reliable and efficient API relay service provider, offering access to mainstream AI models such as Claude Code, Codex, and Gemini. It helps developers and teams integrate AI capabilities more easily and improve productivity.
 <p></p>
-VisionCoder is also offering our users a limited-time <a href="https://coder.visioncoder.cn" target="_blank">Token Plan</a> promotion: buy 1 month and get 1 month free.</td>
+VisionCoder is also offering our users a limited-time <a href="https://coder.visioncoder.cn" target="_blank">Token Plan</a> promotion: <b>buy 1 month and get 1 month free</b>.</td>
 </tr>
 </tbody>
 </table>
@@ -222,6 +222,10 @@ OmniRoute is an AI gateway for multi-provider LLMs: an OpenAI-compatible endpoin

 A public CLIProxyAPI-compatible fork and bundled management panel. It keeps upstream-style usage while restoring built-in usage statistics, adding cache hit rate, first-byte latency, TPS tracking, and Docker-oriented self-hosted installation docs.

+### [Codex Switch](https://github.com/9ycrooked/CodexSwitch)
+
+This is a tool built with Tauri 2 + Vue 3 for managing multiple OpenAI Codex desktop accounts. Switch between saved ChatGPT/Codex authentication profiles, check 5-hour and weekly quota usage in real time, verify token health, view active account details, and import or save auth.json files without manual copying.
+
 > [!NOTE]
 > If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.

+6
-2
@@ -32,9 +32,9 @@ PackyCode offers a special discount to users of this software: use <a href="https://www.p
 </tr>
 <tr>
 <td width="180"><a href="https://coder.visioncoder.cn"><img src="./assets/visioncoder.png" alt="VisionCoder" width="150"></a></td>
-<td>Thanks to VisionCoder for supporting this project. <a href="https://coder.visioncoder.cn" target="_blank">VisionCoder Developer Platform</a> is a reliable and efficient API relay service provider offering mainstream AI models such as Claude Code, Codex, and Gemini, helping developers and teams integrate AI capabilities more easily and improve productivity.
+<td>Thanks to <b>VisionCoder</b> for supporting this project. <a href="https://coder.visioncoder.cn" target="_blank">VisionCoder Developer Platform</a> is a reliable and efficient API relay service provider offering mainstream AI models such as Claude Code, Codex, and Gemini, helping developers and teams integrate AI capabilities more easily and improve productivity.
 <p></p>
-VisionCoder is also offering our users a limited-time <a href="https://coder.visioncoder.cn" target="_blank">Token Plan</a> promotion: buy 1 month, get 1 month free.</td>
+VisionCoder is also offering our users a limited-time <a href="https://coder.visioncoder.cn" target="_blank">Token Plan</a> promotion: <b>buy 1 month, get 1 month free</b>.</td>
 </tr>
 </tbody>
 </table>
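@@ -218,6 +218,10 @@ OmniRoute is an AI gateway for multi-provider large language models: it provides

 A public CLIProxyAPI-compatible fork with a companion management panel. It stays as close as possible to upstream usage while restoring built-in usage statistics and adding cache hit rate, time to first token, TPS tracking, and Docker-oriented self-hosted installation docs.

+### [Codex Switch](https://github.com/9ycrooked/CodexSwitch)
+
+A tool built with Tauri 2 + Vue 3 for managing multiple OpenAI Codex desktop accounts. It switches between saved ChatGPT/Codex authentication profiles, shows 5-hour and weekly quota usage in real time, verifies token health, displays current account details, and imports or saves auth.json files without manual copying.
+
 > [!NOTE]
 > If you have developed a port of CLIProxyAPI or a project derived from it, please submit a PR to add it to this list.
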
+5
-1
@@ -32,7 +32,7 @@ PackyCode offers special discounts to users of this software
 </tr>
 <tr>
 <td width="180"><a href="https://coder.visioncoder.cn"><img src="./assets/visioncoder.png" alt="VisionCoder" width="150"></a></td>
-<td>Thanks to VisionCoder for its support! <a href="https://coder.visioncoder.cn">VisionCoder Developer Platform</a> is a reliable and efficient API relay service provider that offers major AI models such as Claude Code, Codex, and Gemini, helping developers and teams integrate AI capabilities more easily and improve productivity. VisionCoder is also running a limited-time <a href="https://coder.visioncoder.cn">Token Plan</a> promotion for users (buy 1 month, get 1 month free).</td>
+<td>Thanks to <b>VisionCoder</b> for its support! <a href="https://coder.visioncoder.cn">VisionCoder Developer Platform</a> is a reliable and efficient API relay service provider that offers major AI models such as Claude Code, Codex, and Gemini, helping developers and teams integrate AI capabilities more easily and improve productivity. VisionCoder is also running a limited-time <a href="https://coder.visioncoder.cn">Token Plan</a> promotion for users (buy 1 month, get 1 month free).</td>
 </tr>
 </tbody>
 </table>
@@ -217,6 +217,10 @@ OmniRoute is an AI gateway for multi-provider LLMs:

 A public CLIProxyAPI-compatible fork and management panel that keeps usage close to upstream. It restores built-in usage statistics and adds cache hit rate, first-byte latency, TPS tracking, and Docker-oriented self-hosting instructions.

+### [Codex Switch](https://github.com/9ycrooked/CodexSwitch)
+
+A tool built with Tauri 2 + Vue 3 for managing multiple OpenAI Codex desktop accounts. It switches between saved ChatGPT/Codex authentication profiles, shows 5-hour and weekly quota usage in real time, verifies token health, displays current account details, and imports or saves auth.json files without manual copying.
+
 > [!NOTE]
 > If you have developed a port of CLIProxyAPI or a project inspired by it, please send a PR to add it to this list.

@@ -1,77 +0,0 @@
|
||||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseHomeFlagConfigHostPort(t *testing.T) {
|
||||
cfg, err := parseHomeFlagConfig("home.example.com:8327", "secret")
|
||||
if err != nil {
|
||||
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Enabled {
|
||||
t.Fatal("Enabled = false, want true")
|
||||
}
|
||||
if cfg.Host != "home.example.com" {
|
||||
t.Fatalf("Host = %q, want home.example.com", cfg.Host)
|
||||
}
|
||||
if cfg.Port != 8327 {
|
||||
t.Fatalf("Port = %d, want 8327", cfg.Port)
|
||||
}
|
||||
if cfg.Password != "secret" {
|
||||
t.Fatalf("Password = %q, want secret", cfg.Password)
|
||||
}
|
||||
if cfg.TLS.Enable {
|
||||
t.Fatal("TLS.Enable = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHomeFlagConfigRediss(t *testing.T) {
|
||||
cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444?server-name=home.example.com&skip_verify=true&ca-cert=C%3A%2Fcerts%2Fca.pem", "")
|
||||
if err != nil {
|
||||
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Host != "home.example.com" {
|
||||
t.Fatalf("Host = %q, want home.example.com", cfg.Host)
|
||||
}
|
||||
if cfg.Port != 444 {
|
||||
t.Fatalf("Port = %d, want 444", cfg.Port)
|
||||
}
|
||||
if cfg.Password != "url-secret" {
|
||||
t.Fatalf("Password = %q, want url-secret", cfg.Password)
|
||||
}
|
||||
if !cfg.TLS.Enable {
|
||||
t.Fatal("TLS.Enable = false, want true")
|
||||
}
|
||||
if cfg.TLS.ServerName != "home.example.com" {
|
||||
t.Fatalf("TLS.ServerName = %q, want home.example.com", cfg.TLS.ServerName)
|
||||
}
|
||||
if !cfg.TLS.InsecureSkipVerify {
|
||||
t.Fatal("TLS.InsecureSkipVerify = false, want true")
|
||||
}
|
||||
if cfg.TLS.CACert != "C:/certs/ca.pem" {
|
||||
t.Fatalf("TLS.CACert = %q, want C:/certs/ca.pem", cfg.TLS.CACert)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHomeFlagConfigPasswordFlagOverridesURLPassword(t *testing.T) {
|
||||
cfg, err := parseHomeFlagConfig("rediss://:url-secret@home.example.com:444", "flag-secret")
|
||||
if err != nil {
|
||||
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.Password != "flag-secret" {
|
||||
t.Fatalf("Password = %q, want flag-secret", cfg.Password)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHomeFlagConfigDisableClusterDiscovery(t *testing.T) {
|
||||
cfg, err := parseHomeFlagConfig("redis://home.example.com:8327?disable-cluster-discovery=true", "")
|
||||
if err != nil {
|
||||
t.Fatalf("parseHomeFlagConfig() error = %v", err)
|
||||
}
|
||||
|
||||
if !cfg.DisableClusterDiscovery {
|
||||
t.Fatal("DisableClusterDiscovery = false, want true")
|
||||
}
|
||||
}
|
||||
+18
-128
@@ -10,11 +10,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -53,120 +51,6 @@ func init() {
|
||||
buildinfo.BuildDate = BuildDate
|
||||
}
|
||||
|
||||
func parseHomeFlagConfig(rawAddr string, password string) (config.HomeConfig, error) {
|
||||
rawAddr = strings.TrimSpace(rawAddr)
|
||||
if rawAddr == "" {
|
||||
return config.HomeConfig{}, fmt.Errorf("address is empty")
|
||||
}
|
||||
|
||||
if strings.Contains(rawAddr, "://") {
|
||||
return parseHomeURLConfig(rawAddr, password)
|
||||
}
|
||||
|
||||
host, portStr, errSplit := net.SplitHostPort(rawAddr)
|
||||
if errSplit != nil {
|
||||
return config.HomeConfig{}, fmt.Errorf("expected host:port, redis://host:port, or rediss://host:port: %w", errSplit)
|
||||
}
|
||||
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return config.HomeConfig{}, fmt.Errorf("host is empty")
|
||||
}
|
||||
|
||||
port, errPort := parseHomePort(portStr)
|
||||
if errPort != nil {
|
||||
return config.HomeConfig{}, errPort
|
||||
}
|
||||
|
||||
return config.HomeConfig{
|
||||
Enabled: true,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Password: password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseHomeURLConfig(rawAddr string, password string) (config.HomeConfig, error) {
|
||||
parsed, errParse := url.Parse(rawAddr)
|
||||
if errParse != nil {
|
||||
return config.HomeConfig{}, fmt.Errorf("parse URL: %w", errParse)
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(strings.TrimSpace(parsed.Scheme))
|
||||
if scheme != "redis" && scheme != "rediss" {
|
||||
return config.HomeConfig{}, fmt.Errorf("unsupported URL scheme %q", parsed.Scheme)
|
||||
}
|
||||
|
||||
host := strings.TrimSpace(parsed.Hostname())
|
||||
if host == "" {
|
||||
return config.HomeConfig{}, fmt.Errorf("host is empty")
|
||||
}
|
||||
|
||||
port, errPort := parseHomePort(parsed.Port())
|
||||
if errPort != nil {
|
||||
return config.HomeConfig{}, errPort
|
||||
}
|
||||
|
||||
if password == "" && parsed.User != nil {
|
||||
if urlPassword, ok := parsed.User.Password(); ok {
|
||||
password = urlPassword
|
||||
}
|
||||
}
|
||||
|
||||
homeCfg := config.HomeConfig{
|
||||
Enabled: true,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Password: password,
|
||||
}
|
||||
query := parsed.Query()
|
||||
homeCfg.DisableClusterDiscovery = parseHomeBoolQuery(query, "disable-cluster-discovery", "disable_cluster_discovery")
|
||||
|
||||
if scheme == "rediss" {
|
||||
homeCfg.TLS.Enable = true
|
||||
homeCfg.TLS.ServerName = strings.TrimSpace(firstHomeQueryValue(query, "server-name", "server_name"))
|
||||
homeCfg.TLS.InsecureSkipVerify = parseHomeBoolQuery(query, "insecure-skip-verify", "insecure_skip_verify", "skip_verify")
|
||||
homeCfg.TLS.CACert = strings.TrimSpace(firstHomeQueryValue(query, "ca-cert", "ca_cert"))
|
||||
}
|
||||
|
||||
return homeCfg, nil
|
||||
}
|
||||
|
||||
func parseHomePort(rawPort string) (int, error) {
|
||||
rawPort = strings.TrimSpace(rawPort)
|
||||
if rawPort == "" {
|
||||
return 0, fmt.Errorf("port is empty")
|
||||
}
|
||||
|
||||
port, errPort := strconv.Atoi(rawPort)
|
||||
if errPort != nil || port <= 0 || port > 65535 {
|
||||
return 0, fmt.Errorf("invalid port %q", rawPort)
|
||||
}
|
||||
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func firstHomeQueryValue(values url.Values, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if value := values.Get(key); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseHomeBoolQuery(values url.Values, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
value := strings.TrimSpace(values.Get(key))
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
parsed, errParse := strconv.ParseBool(value)
|
||||
return errParse == nil && parsed
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// main is the entry point of the application.
|
||||
// It parses command-line flags, loads configuration, and starts the appropriate
|
||||
// service based on the provided flags (login, codex-login, or server mode).
|
||||
@@ -188,8 +72,7 @@ func main() {
 	var vertexImportPrefix string
 	var configPath string
 	var password string
-	var homeAddr string
-	var homePassword string
+	var homeJWT string
 	var homeDisableClusterDiscovery bool
 	var tuiMode bool
 	var standalone bool
@@ -210,9 +93,8 @@ func main() {
 	flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
 	flag.StringVar(&vertexImportPrefix, "vertex-import-prefix", "", "Prefix for Vertex model namespacing (use with -vertex-import)")
 	flag.StringVar(&password, "password", "", "")
-	flag.StringVar(&homeAddr, "home", "", "Home control plane address in host:port, redis://host:port, or rediss://host:port format (loads config from home and skips local config file)")
-	flag.StringVar(&homePassword, "home-password", "", "Home control plane password (Redis AUTH)")
-	flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home address")
+	flag.StringVar(&homeJWT, "home-jwt", "", "Home control plane JWT for mTLS certificate bootstrap and connection")
+	flag.BoolVar(&homeDisableClusterDiscovery, "home-disable-cluster-discovery", false, "Disable Home CLUSTER NODES discovery and keep using the configured -home-jwt address")
 	flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
 	flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
 	flag.BoolVar(&localModel, "local-model", false, "Use embedded model catalog only, skip remote model fetching")
@@ -299,6 +181,13 @@ func main() {
 		return "", false
 	}
 	writableBase := util.WritablePath()
+
+	if strings.TrimSpace(homeJWT) == "" {
+		if v, ok := lookupEnv("HOME_JWT", "home_jwt"); ok {
+			homeJWT = v
+		}
+	}
+
 	if value, ok := lookupEnv("PGSTORE_DSN", "pgstore_dsn"); ok {
 		usePostgresStore = true
 		pgStoreDSN = value
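The hunk above gives the `-home-jwt` flag precedence over the environment: `HOME_JWT` / `home_jwt` is consulted only when the flag is blank. A minimal sketch of that precedence rule follows. `resolveHomeJWT` and the injected `lookup` function are hypothetical; the real `lookupEnv` helper is only partly visible in this diff, so its exact behavior (for example, whether it skips blank values) is an assumption here.

```go
package main

import (
	"fmt"
	"strings"
)

// resolveHomeJWT returns flagValue when it is non-blank; otherwise it falls
// back to the first environment key that holds a non-blank value.
// lookup is injected so the sketch runs without touching the real environment.
func resolveHomeJWT(flagValue string, lookup func(string) (string, bool)) string {
	if strings.TrimSpace(flagValue) != "" {
		return flagValue
	}
	// Assumed lookup order: upper-case key first, then lower-case.
	for _, key := range []string{"HOME_JWT", "home_jwt"} {
		if v, ok := lookup(key); ok && strings.TrimSpace(v) != "" {
			return v
		}
	}
	return ""
}

func main() {
	env := map[string]string{"HOME_JWT": "jwt-from-env"}
	lookup := func(key string) (string, bool) {
		v, ok := env[key]
		return v, ok
	}

	fmt.Println(resolveHomeJWT("", lookup))              // environment is used
	fmt.Println(resolveHomeJWT("jwt-from-flag", lookup)) // flag wins
}
```

In the real binary the lookup would be backed by `os.LookupEnv`; injecting it keeps the rule testable.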
@@ -362,12 +251,13 @@ func main() {
 	// Determine and load the configuration file.
 	// Prefer the Postgres store when configured, otherwise fallback to git or local files.
 	var configFilePath string
-	if strings.TrimSpace(homeAddr) != "" {
+	if strings.TrimSpace(homeJWT) != "" {
 		configLoadedFromHome = true
-		trimmedHomePassword := strings.TrimSpace(homePassword)
-		homeCfg, errHomeCfg := parseHomeFlagConfig(homeAddr, trimmedHomePassword)
+		ctxHome, cancelHome := context.WithTimeout(context.Background(), 30*time.Second)
+		homeCfg, errHomeCfg := home.ConfigFromJWT(ctxHome, homeJWT)
+		cancelHome()
 		if errHomeCfg != nil {
-			log.Errorf("invalid -home address %q: %v", homeAddr, errHomeCfg)
+			log.Errorf("invalid -home-jwt: %v", errHomeCfg)
 			return
 		}
 		if homeDisableClusterDiscovery {
@@ -376,9 +266,9 @@ func main() {
 		homeClient := home.New(homeCfg)
 		defer homeClient.Close()

-		ctxHome, cancelHome := context.WithTimeout(context.Background(), 30*time.Second)
-		raw, errGetConfig := homeClient.GetConfig(ctxHome)
-		cancelHome()
+		ctxHomeConfig, cancelHomeConfig := context.WithTimeout(context.Background(), 30*time.Second)
+		raw, errGetConfig := homeClient.GetConfig(ctxHomeConfig)
+		cancelHomeConfig()
 		if errGetConfig != nil {
 			log.Errorf("failed to fetch config from home: %v", errGetConfig)
 			return
+3
-22
@@ -11,26 +11,6 @@ tls:
   cert: ""
   key: ""

-# Optional "home" control plane integration over Redis protocol.
-home:
-  enabled: false
-  host: "127.0.0.1"
-  port: 6379
-  password: ""
-  # Keep CPA pinned to the configured home address instead of switching to CLUSTER NODES entries.
-  # Useful when Home is behind NAT, Docker networking, or a reverse proxy.
-  disable-cluster-discovery: false
-  # Optional TLS for the outbound Redis connection to the home control plane.
-  # Enable this when connecting through rediss:// or an SSL stream proxy.
-  tls:
-    enable: false
-    # Optional SNI/certificate name override. Leave empty to use the configured home host.
-    server-name: ""
-    # Trust a private CA bundle in addition to system roots.
-    ca-cert: ""
-    # Only for testing self-signed endpoints; disables certificate verification.
-    insecure-skip-verify: false
-
 # Management API settings
 remote-management:
   # Whether to allow remote (non-localhost) management access.
@@ -86,8 +66,8 @@ error-logs-max-files: 10
 # When false, disable in-memory usage statistics aggregation
 usage-statistics-enabled: false

-# How long (in seconds) Redis usage queue items are retained in memory for the RESP interface (LPOP/RPOP).
-# Note: the in-process Redis RESP usage output is disabled when home.enabled is true.
+# How long (in seconds) usage queue items are retained in memory for the Management API.
+# The local Redis RESP usage output is disabled.
 # Default: 60. Max: 3600.
 redis-usage-queue-retention-seconds: 60

@@ -277,6 +257,7 @@ nonstream-keepalive-interval: 0
 #     models: # The models supported by the provider.
 #       - name: "moonshotai/kimi-k2:free" # The actual model name.
 #         alias: "kimi-k2" # The alias used in the API.
+#         image: false # optional: set true to allow this model on /v1/images/generations and /v1/images/edits
 #         thinking: # optional: omit to default to levels ["low","medium","high"]
 #           levels: ["low", "medium", "high"]
 #       # You may repeat the same alias to build an internal model pool.
@@ -0,0 +1,29 @@
+services:
+  cli-proxy-api:
+    image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api:latest}
+    pull_policy: always
+    build:
+      context: .
+      dockerfile: Dockerfile
+      args:
+        VERSION: ${VERSION:-dev}
+        COMMIT: ${COMMIT:-none}
+        BUILD_DATE: ${BUILD_DATE:-unknown}
+    container_name: cli-proxy-api-cluster
+    environment:
+      HOME_JWT: ${HOME_JWT:-}
+    ports:
+      - "8317:8317"
+    volumes:
+      - ./home:/root/.cli-proxy-api
+      - ./logs:/CLIProxyAPI/logs
+    command: >
+      sh -eu -c '
+        if [ -z "$$HOME_JWT" ]; then
+          echo "HOME_JWT is required" >&2
+          exit 1
+        fi
+
+        exec ./CLIProxyAPI -home-jwt "$$HOME_JWT"
+      '
+    restart: unless-stopped
@@ -2081,7 +2081,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
 		log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
 	} else {
 		projectID = fetchedProjectID
-		log.Infof("antigravity: obtained project ID %s", projectID)
+		log.Infof("antigravity: obtained project ID %s", util.HideAPIKey(projectID))
 	}
 }

@@ -2125,7 +2125,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
 	CompleteOAuthSessionsByProvider("antigravity")
 	fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
 	if projectID != "" {
-		fmt.Printf("Using GCP project: %s\n", projectID)
+		fmt.Printf("Using GCP project: %s\n", util.HideAPIKey(projectID))
 	}
 	fmt.Println("You can now use Antigravity services through this CLI")
 }()
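Both hunks in this file route the project ID through `util.HideAPIKey` before it reaches logs or stdout. The diff does not show how `util.HideAPIKey` masks its input, so the sketch below is only an illustration of the general idea under an assumed rule (keep the first and last four characters, fully mask short values). `maskIdentifier` and the sample project ID are hypothetical.

```go
package main

import (
	"fmt"
	"strings"
)

// maskIdentifier hides the middle of an identifier before it is logged.
// Assumed rule for this sketch only: values of 8 characters or fewer are
// fully replaced by asterisks; longer values keep 4 characters at each end.
func maskIdentifier(s string) string {
	r := []rune(s)
	if len(r) <= 8 {
		return strings.Repeat("*", len(r))
	}
	return string(r[:4]) + "..." + string(r[len(r)-4:])
}

func main() {
	fmt.Println(maskIdentifier("my-gcp-project-123456")) // my-g...3456
	fmt.Println(maskIdentifier("short"))                 // *****
}
```

Masking at the call site, as the diff does, keeps the unmasked value available for the actual API calls while limiting what ends up in log files.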
@@ -103,18 +103,6 @@ func (s *Server) routeMuxConnection(conn net.Conn, httpListener *muxListener) {
 	}

 	if isRedisRESPPrefix(prefix[0]) {
-		if s.cfg != nil && s.cfg.Home.Enabled {
-			if errClose := conn.Close(); errClose != nil {
-				log.Errorf("failed to close redis connection while home mode is enabled: %v", errClose)
-			}
-			return
-		}
-		if !s.managementRoutesEnabled.Load() {
-			if errClose := conn.Close(); errClose != nil {
-				log.Errorf("failed to close redis connection while management is disabled: %v", errClose)
-			}
-			return
-		}
 		_ = conn.SetReadDeadline(time.Time{})
 		s.handleRedisConnection(conn, reader)
 		return
@@ -31,9 +31,12 @@ func isRedisRESPPrefix(prefix byte) bool {
|
||||
}
|
||||
|
||||
func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
|
||||
if s == nil || conn == nil || reader == nil {
|
||||
if s == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
if reader == nil {
|
||||
reader = bufio.NewReader(conn)
|
||||
}
|
||||
|
||||
clientIP, localClient := resolveRemoteIP(conn.RemoteAddr())
|
||||
authed := false
|
||||
@@ -63,10 +66,10 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
|
||||
return
|
||||
}
|
||||
|
||||
args, err := readRESPArray(reader)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
_ = writeRedisError(writer, "ERR "+err.Error())
|
||||
args, errRead := readRESPArray(reader)
|
||||
if errRead != nil {
|
||||
if !errors.Is(errRead, io.EOF) {
|
||||
_ = writeRedisError(writer, "ERR "+errRead.Error())
|
||||
_ = writer.Flush()
|
||||
}
|
||||
return
|
||||
@@ -139,13 +142,6 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
|
||||
return
|
||||
}
|
||||
case "SUBSCRIBE":
|
||||
if !authed {
|
||||
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
channel, ok := parseSubscribeChannel(args)
|
||||
if !ok {
|
||||
_ = writeRedisError(writer, "ERR wrong number of arguments for 'subscribe' command")
|
||||
@@ -174,13 +170,6 @@ func (s *Server) handleRedisConnection(conn net.Conn, reader *bufio.Reader) {
|
||||
s.streamRedisUsageSubscription(reader, writer, messages, unsubscribe)
|
||||
return
|
||||
case "LPOP", "RPOP":
|
||||
if !authed {
|
||||
_ = writeRedisError(writer, "NOAUTH Authentication required.")
|
||||
if !flush() {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
count, hasCount, ok := parsePopCount(args)
|
||||
if !ok {
|
||||
_ = writeRedisError(writer, "ERR wrong number of arguments for '"+strings.ToLower(cmd)+"' command")
|
||||
@@ -270,11 +259,11 @@ func readRedisSubscriptionCommands(reader *bufio.Reader, commands chan<- redisSu
|
||||
defer close(commands)
|
||||
|
||||
for {
|
||||
args, err := readRESPArray(reader)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
args, errRead := readRESPArray(reader)
|
||||
if errRead != nil {
|
||||
if !errors.Is(errRead, io.EOF) {
|
||||
select {
|
||||
case commands <- redisSubscriptionCommand{err: err}:
|
||||
case commands <- redisSubscriptionCommand{err: errRead}:
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
@@ -336,7 +325,7 @@ func resolveRemoteIP(addr net.Addr) (ip string, localClient bool) {
|
||||
}
|
||||
default:
|
||||
host = addr.String()
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
if h, _, errSplit := net.SplitHostPort(host); errSplit == nil {
|
||||
host = h
|
||||
}
|
||||
host = strings.TrimSpace(host)
|
||||
@@ -362,7 +351,6 @@ func parseAuthPassword(args []string) (string, bool) {
|
||||
case 2:
|
||||
return args[1], true
|
||||
case 3:
|
||||
// Support AUTH <username> <password> by ignoring username for compatibility.
|
||||
return args[2], true
|
||||
default:
|
||||
return "", false
|
||||
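The argument handling visible in this hunk can be exercised in isolation. The sketch below assumes the function switches on `len(args)` with `args[0]` holding the command name; `authPassword` is a hypothetical stand-in written for illustration, not the repository's `parseAuthPassword`.

```go
package main

import "fmt"

// authPassword extracts the password from an already-parsed AUTH command.
// args[0] is assumed to be the command name itself.
func authPassword(args []string) (string, bool) {
	switch len(args) {
	case 2: // AUTH <password>
		return args[1], true
	case 3: // AUTH <username> <password>; the username is ignored
		return args[2], true
	default:
		return "", false
	}
}

func main() {
	pw, ok := authPassword([]string{"AUTH", "default", "secret"})
	fmt.Println(pw, ok)

	pw, ok = authPassword([]string{"AUTH"})
	fmt.Printf("%q %v\n", pw, ok)
}
```

Accepting the three-argument form lets clients that always send a username authenticate without the server having to model users.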
@@ -383,34 +371,34 @@ func parsePopCount(args []string) (count int, hasCount bool, ok bool) {
|
||||
if len(args) == 2 {
|
||||
return 1, false, true
|
||||
}
|
||||
parsed, err := strconv.Atoi(strings.TrimSpace(args[2]))
|
||||
if err != nil {
|
||||
parsed, errParse := strconv.Atoi(strings.TrimSpace(args[2]))
|
||||
if errParse != nil {
|
||||
return 0, true, true
|
||||
}
|
||||
return parsed, true, true
|
||||
}
|
||||
|
||||
func readRESPArray(reader *bufio.Reader) ([]string, error) {
|
||||
prefix, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
prefix, errRead := reader.ReadByte()
|
||||
if errRead != nil {
|
||||
return nil, errRead
|
||||
}
|
||||
if prefix != '*' {
|
||||
return nil, fmt.Errorf("protocol error")
|
||||
}
|
||||
line, err := readRESPLine(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
line, errLine := readRESPLine(reader)
|
||||
if errLine != nil {
|
||||
return nil, errLine
|
||||
}
|
||||
count, err := strconv.Atoi(line)
|
||||
if err != nil || count < 0 {
|
||||
count, errParse := strconv.Atoi(line)
|
||||
if errParse != nil || count < 0 {
|
||||
return nil, fmt.Errorf("protocol error")
|
||||
}
|
||||
args := make([]string, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
value, err := readRESPString(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
value, errString := readRESPString(reader)
|
||||
if errString != nil {
|
||||
return nil, errString
|
||||
}
|
||||
args = append(args, value)
|
||||
}
|
||||
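`readRESPArray` expects a `*` prefix, a decimal element count, and then one string per element; the bulk-string reader later in this file consumes a length line followed by the payload and a trailing CRLF. As a rough illustration of the bytes a client would send, here is a minimal encoder sketch. `encodeRESPCommand` and the `queue` key are hypothetical and not part of the diff.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// encodeRESPCommand renders args as a RESP array of bulk strings:
// "*<count>\r\n" followed by "$<len>\r\n<payload>\r\n" for each argument.
// Lengths are byte lengths, which is what len() returns for a Go string.
func encodeRESPCommand(args ...string) string {
	var b strings.Builder
	b.WriteString("*" + strconv.Itoa(len(args)) + "\r\n")
	for _, arg := range args {
		b.WriteString("$" + strconv.Itoa(len(arg)) + "\r\n")
		b.WriteString(arg)
		b.WriteString("\r\n")
	}
	return b.String()
}

func main() {
	// %q makes the CR/LF framing visible.
	fmt.Printf("%q\n", encodeRESPCommand("LPOP", "queue", "2"))
}
```

Feeding this output into a `bufio.Reader` is a convenient way to unit-test a RESP parser without opening a socket.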
@@ -418,9 +406,9 @@ func readRESPArray(reader *bufio.Reader) ([]string, error) {
|
||||
}
|
||||
|
||||
func readRESPString(reader *bufio.Reader) (string, error) {
|
||||
prefix, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
prefix, errRead := reader.ReadByte()
|
||||
if errRead != nil {
|
||||
return "", errRead
|
||||
}
|
||||
switch prefix {
|
||||
case '$':
|
||||
@@ -433,20 +421,20 @@ func readRESPString(reader *bufio.Reader) (string, error) {
|
||||
}
|
||||
|
||||
func readRESPBulkString(reader *bufio.Reader) (string, error) {
|
||||
line, err := readRESPLine(reader)
|
||||
if err != nil {
|
||||
return "", err
|
||||
line, errLine := readRESPLine(reader)
|
||||
if errLine != nil {
|
||||
return "", errLine
|
||||
}
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
length, errParse := strconv.Atoi(line)
|
||||
if errParse != nil {
|
||||
return "", fmt.Errorf("protocol error")
|
||||
}
|
||||
if length < 0 {
|
||||
return "", nil
|
||||
}
|
||||
buf := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(reader, buf); err != nil {
|
||||
return "", err
|
||||
if _, errRead := io.ReadFull(reader, buf); errRead != nil {
|
||||
return "", errRead
|
||||
}
|
||||
if length+2 < 2 || buf[length] != '\r' || buf[length+1] != '\n' {
|
||||
return "", fmt.Errorf("protocol error")
|
||||
@@ -455,9 +443,9 @@ func readRESPBulkString(reader *bufio.Reader) (string, error) {
|
||||
}
|
||||
|
||||
func readRESPLine(reader *bufio.Reader) (string, error) {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
line, errRead := reader.ReadString('\n')
|
||||
if errRead != nil {
|
||||
return "", errRead
|
||||
}
|
||||
line = strings.TrimSuffix(line, "\n")
|
||||
line = strings.TrimSuffix(line, "\r")
|
||||
@@ -468,24 +456,24 @@ func writeRedisSimpleString(writer *bufio.Writer, value string) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("+" + value + "\r\n")
|
||||
return err
|
||||
_, errWrite := writer.WriteString("+" + value + "\r\n")
|
||||
return errWrite
|
||||
}
|
||||
|
||||
func writeRedisError(writer *bufio.Writer, message string) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("-" + message + "\r\n")
|
||||
return err
|
||||
_, errWrite := writer.WriteString("-" + message + "\r\n")
|
||||
return errWrite
|
||||
}
|
||||
|
||||
func writeRedisNilBulkString(writer *bufio.Writer) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
_, err := writer.WriteString("$-1\r\n")
|
||||
return err
|
||||
_, errWrite := writer.WriteString("$-1\r\n")
|
||||
return errWrite
|
||||
}
|
||||
|
||||
func writeRedisBulkString(writer *bufio.Writer, payload []byte) error {
|
||||
@@ -495,26 +483,26 @@ func writeRedisBulkString(writer *bufio.Writer, payload []byte) error {
|
||||
if payload == nil {
|
||||
return writeRedisNilBulkString(writer)
|
||||
}
|
||||
if _, err := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); err != nil {
|
||||
return err
|
||||
if _, errWrite := writer.WriteString("$" + strconv.Itoa(len(payload)) + "\r\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, err := writer.Write(payload); err != nil {
|
||||
return err
|
||||
if _, errWrite := writer.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
_, err := writer.WriteString("\r\n")
|
||||
return err
|
||||
_, errWrite := writer.WriteString("\r\n")
|
||||
return errWrite
|
||||
}
|
||||
|
||||
func writeRedisArrayOfBulkStrings(writer *bufio.Writer, items [][]byte) error {
|
||||
if writer == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if _, err := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); err != nil {
|
||||
return err
|
||||
if _, errWrite := writer.WriteString("*" + strconv.Itoa(len(items)) + "\r\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
for i := range items {
|
||||
if err := writeRedisBulkString(writer, items[i]); err != nil {
|
||||
return err
|
||||
if errWrite := writeRedisBulkString(writer, items[i]); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -524,63 +512,63 @@ func writeRedisInteger(writer *bufio.Writer, value int) error {
 	if writer == nil {
 		return net.ErrClosed
 	}
-	_, err := writer.WriteString(":" + strconv.Itoa(value) + "\r\n")
-	return err
+	_, errWrite := writer.WriteString(":" + strconv.Itoa(value) + "\r\n")
+	return errWrite
 }

 func writeRedisArrayHeader(writer *bufio.Writer, count int) error {
 	if writer == nil {
 		return net.ErrClosed
 	}
-	_, err := writer.WriteString("*" + strconv.Itoa(count) + "\r\n")
-	return err
+	_, errWrite := writer.WriteString("*" + strconv.Itoa(count) + "\r\n")
+	return errWrite
 }

 func writeRedisPubSubSubscribe(writer *bufio.Writer, channel string, count int) error {
-	if err := writeRedisArrayHeader(writer, 3); err != nil {
-		return err
+	if errWrite := writeRedisArrayHeader(writer, 3); errWrite != nil {
+		return errWrite
 	}
-	if err := writeRedisBulkString(writer, []byte("subscribe")); err != nil {
-		return err
+	if errWrite := writeRedisBulkString(writer, []byte("subscribe")); errWrite != nil {
+		return errWrite
 	}
-	if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
-		return err
+	if errWrite := writeRedisBulkString(writer, []byte(channel)); errWrite != nil {
+		return errWrite
 	}
 	return writeRedisInteger(writer, count)
 }

 func writeRedisPubSubUnsubscribe(writer *bufio.Writer, channel string, count int) error {
-	if err := writeRedisArrayHeader(writer, 3); err != nil {
-		return err
+	if errWrite := writeRedisArrayHeader(writer, 3); errWrite != nil {
+		return errWrite
 	}
-	if err := writeRedisBulkString(writer, []byte("unsubscribe")); err != nil {
-		return err
+	if errWrite := writeRedisBulkString(writer, []byte("unsubscribe")); errWrite != nil {
+		return errWrite
 	}
-	if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
-		return err
+	if errWrite := writeRedisBulkString(writer, []byte(channel)); errWrite != nil {
+		return errWrite
 	}
 	return writeRedisInteger(writer, count)
 }

 func writeRedisPubSubMessage(writer *bufio.Writer, channel string, payload []byte) error {
-	if err := writeRedisArrayHeader(writer, 3); err != nil {
-		return err
+	if errWrite := writeRedisArrayHeader(writer, 3); errWrite != nil {
+		return errWrite
 	}
-	if err := writeRedisBulkString(writer, []byte("message")); err != nil {
-		return err
+	if errWrite := writeRedisBulkString(writer, []byte("message")); errWrite != nil {
+		return errWrite
 	}
-	if err := writeRedisBulkString(writer, []byte(channel)); err != nil {
-		return err
+	if errWrite := writeRedisBulkString(writer, []byte(channel)); errWrite != nil {
+		return errWrite
 	}
 	return writeRedisBulkString(writer, payload)
 }

 func writeRedisPubSubPong(writer *bufio.Writer, payload []byte) error {
-	if err := writeRedisArrayHeader(writer, 2); err != nil {
-		return err
+	if errWrite := writeRedisArrayHeader(writer, 2); errWrite != nil {
+		return errWrite
 	}
-	if err := writeRedisBulkString(writer, []byte("pong")); err != nil {
-		return err
+	if errWrite := writeRedisBulkString(writer, []byte("pong")); errWrite != nil {
+		return errWrite
 	}
 	return writeRedisBulkString(writer, payload)
 }
|
||||
@@ -3,13 +3,10 @@ package api
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -18,18 +15,6 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue"
|
||||
)
|
||||
|
||||
type remoteAddrConn struct {
|
||||
net.Conn
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (c *remoteAddrConn) RemoteAddr() net.Addr {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func startRedisMuxListener(t *testing.T, server *Server) (addr string, stop func()) {
|
||||
t.Helper()
|
||||
|
||||
@@ -86,17 +71,6 @@ func readTestRESPLine(r *bufio.Reader) (string, error) {
|
||||
return strings.TrimSuffix(line, "\r\n"), nil
|
||||
}
|
||||
|
||||
func readTestRESPSimpleString(r *bufio.Reader) (string, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if prefix != '+' {
|
||||
return "", fmt.Errorf("expected simple string prefix '+', got %q", prefix)
|
||||
}
|
||||
return readTestRESPLine(r)
|
||||
}
|
||||
|
||||
func readTestRESPError(r *bufio.Reader) (string, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
@@ -108,22 +82,33 @@ func readTestRESPError(r *bufio.Reader) (string, error) {
|
||||
return readTestRESPLine(r)
|
||||
}
|
||||
|
||||
func readTestRESPSimpleString(r *bufio.Reader) (string, error) {
|
||||
prefix, errRead := r.ReadByte()
|
||||
if errRead != nil {
|
||||
return "", errRead
|
||||
}
|
||||
if prefix != '+' {
|
||||
return "", fmt.Errorf("expected simple string prefix '+', got %q", prefix)
|
||||
}
|
||||
return readTestRESPLine(r)
|
||||
}
|
||||
|
||||
func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
prefix, errRead := r.ReadByte()
|
||||
if errRead != nil {
|
||||
return nil, errRead
|
||||
}
|
||||
if prefix != '$' {
|
||||
return nil, fmt.Errorf("expected bulk string prefix '$', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
line, errLine := readTestRESPLine(r)
|
||||
if errLine != nil {
|
||||
return nil, errLine
|
||||
}
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid bulk string length %q: %v", line, err)
|
||||
length, errParse := strconv.Atoi(line)
|
||||
if errParse != nil {
|
||||
return nil, fmt.Errorf("invalid bulk string length %q: %v", line, errParse)
|
||||
}
|
||||
if length == -1 {
|
||||
return nil, nil
|
||||
@@ -133,8 +118,8 @@ func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
|
||||
}
|
||||
|
||||
payload := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, err
|
||||
if _, errRead := io.ReadFull(r, payload); errRead != nil {
|
||||
return nil, errRead
|
||||
}
|
||||
if payload[length] != '\r' || payload[length+1] != '\n' {
|
||||
return nil, fmt.Errorf("invalid bulk string terminator")
|
||||
@@ -143,21 +128,21 @@ func readTestRESPBulkString(r *bufio.Reader) ([]byte, error) {
|
||||
}
|
||||
|
||||
func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
prefix, errRead := r.ReadByte()
|
||||
if errRead != nil {
|
||||
return nil, errRead
|
||||
}
|
||||
if prefix != '*' {
|
||||
return nil, fmt.Errorf("expected array prefix '*', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
line, errLine := readTestRESPLine(r)
|
||||
if errLine != nil {
|
||||
return nil, errLine
|
||||
}
|
||||
count, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid array length %q: %v", line, err)
|
||||
count, errParse := strconv.Atoi(line)
|
||||
if errParse != nil {
|
||||
return nil, fmt.Errorf("invalid array length %q: %v", line, errParse)
|
||||
}
|
||||
if count < 0 {
|
||||
return nil, fmt.Errorf("invalid array length %d", count)
|
||||
@@ -165,114 +150,15 @@ func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
|
||||
|
||||
out := make([][]byte, 0, count)
|
||||
for i := 0; i < count; i++ {
|
||||
item, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
item, errItem := readTestRESPBulkString(r)
|
||||
if errItem != nil {
|
||||
return nil, errItem
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func readTestRESPInteger(r *bufio.Reader) (int, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if prefix != ':' {
|
||||
return 0, fmt.Errorf("expected integer prefix ':', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
value, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid integer %q: %v", line, err)
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func readTestRESPArrayHeader(r *bufio.Reader) (int, error) {
|
||||
prefix, err := r.ReadByte()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if prefix != '*' {
|
||||
return 0, fmt.Errorf("expected array prefix '*', got %q", prefix)
|
||||
}
|
||||
|
||||
line, err := readTestRESPLine(r)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
count, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid array length %q: %v", line, err)
|
||||
}
|
||||
if count < 0 {
|
||||
return 0, fmt.Errorf("invalid array length %d", count)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func readTestRESPPubSubSubscribe(r *bufio.Reader) (string, int, error) {
|
||||
count, err := readTestRESPArrayHeader(r)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if count != 3 {
|
||||
return "", 0, fmt.Errorf("subscribe array length = %d, want 3", count)
|
||||
}
|
||||
|
||||
kind, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if string(kind) != "subscribe" {
|
||||
return "", 0, fmt.Errorf("pubsub kind = %q, want subscribe", string(kind))
|
||||
}
|
||||
|
||||
channel, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
subscriptions, err := readTestRESPInteger(r)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
return string(channel), subscriptions, nil
|
||||
}
|
||||
|
||||
func readTestRESPPubSubMessage(r *bufio.Reader) (string, []byte, error) {
|
||||
count, err := readTestRESPArrayHeader(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if count != 3 {
|
||||
return "", nil, fmt.Errorf("message array length = %d, want 3", count)
|
||||
}
|
||||
|
||||
kind, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if string(kind) != "message" {
|
||||
return "", nil, fmt.Errorf("pubsub kind = %q, want message", string(kind))
|
||||
}
|
||||
|
||||
channel, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
payload, err := readTestRESPBulkString(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return string(channel), payload, nil
|
||||
}
|
||||
|
||||
func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
redisqueue.SetEnabled(false)
|
||||
@@ -333,13 +219,19 @@ func TestRedisProtocol_HomeEnabled_DisablesConnection(t *testing.T) {
|
||||
_ = conn.SetDeadline(time.Now().Add(2 * time.Second))
|
||||
_ = writeTestRESPCommand(conn, "PING")
|
||||
|
||||
if msg, err := readTestRESPError(bufio.NewReader(conn)); err != nil {
|
||||
t.Fatalf("failed to read home-mode RESP error: %v", err)
|
||||
} else if msg != "ERR redis usage output disabled in home mode" {
|
||||
t.Fatalf("unexpected disabled RESP error: %q", msg)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1)
|
||||
_, errRead := conn.Read(buf)
|
||||
if errRead == nil {
|
||||
t.Fatalf("expected connection to be closed when home mode is enabled")
|
||||
t.Fatalf("expected connection to be closed after home-mode RESP error")
|
||||
}
|
||||
if ne, ok := errRead.(net.Error); ok && ne.Timeout() {
|
||||
t.Fatalf("expected connection to be closed when home mode is enabled, got timeout: %v", errRead)
|
||||
t.Fatalf("expected connection to be closed after home-mode RESP error, got timeout: %v", errRead)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -368,29 +260,11 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
|
||||
|
||||
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", "test-key"); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH error: %v", err)
|
||||
} else if msg != "ERR invalid management key" {
|
||||
t.Fatalf("unexpected AUTH error: %q", msg)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
|
||||
} else if msg != "NOAUTH Authentication required." {
|
||||
t.Fatalf("unexpected LPOP NOAUTH error: %q", msg)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPSimpleString(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH response: %v", err)
|
||||
if msg, errRead := readTestRESPSimpleString(reader); errRead != nil {
|
||||
t.Fatalf("failed to read AUTH response: %v", errRead)
|
||||
} else if msg != "OK" {
|
||||
t.Fatalf("unexpected AUTH response: %q", msg)
|
||||
}
|
||||
@@ -402,25 +276,25 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
|
||||
redisqueue.Enqueue([]byte("b"))
|
||||
redisqueue.Enqueue([]byte("c"))
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue"); errWrite != nil {
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "usage"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP command: %v", errWrite)
|
||||
}
|
||||
if item, err := readTestRESPBulkString(reader); err != nil {
|
||||
t.Fatalf("failed to read RPOP response: %v", err)
|
||||
if item, errRead := readTestRESPBulkString(reader); errRead != nil {
|
||||
t.Fatalf("failed to read RPOP response: %v", errRead)
|
||||
} else if string(item) != "a" {
|
||||
t.Fatalf("unexpected RPOP item: %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "usage"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command: %v", errWrite)
|
||||
}
|
||||
if item, err := readTestRESPBulkString(reader); err != nil {
|
||||
t.Fatalf("failed to read LPOP response: %v", err)
|
||||
if item, errRead := readTestRESPBulkString(reader); errRead != nil {
|
||||
t.Fatalf("failed to read LPOP response: %v", errRead)
|
||||
} else if string(item) != "b" {
|
||||
t.Fatalf("unexpected LPOP item: %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "10"); errWrite != nil {
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "usage", "10"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP count command: %v", errWrite)
|
||||
}
|
||||
items, errItems := readRESPArrayOfBulkStrings(reader)
|
||||
@@ -431,7 +305,7 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
|
||||
t.Fatalf("unexpected RPOP count items: %#v", items)
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "queue"); errWrite != nil {
|
||||
if errWrite := writeTestRESPCommand(conn, "LPOP", "usage"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP empty command: %v", errWrite)
|
||||
}
|
||||
item, errItem := readTestRESPBulkString(reader)
|
||||
@@ -442,7 +316,7 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
|
||||
t.Fatalf("expected nil bulk string for empty queue, got %q", string(item))
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "queue", "2"); errWrite != nil {
|
||||
if errWrite := writeTestRESPCommand(conn, "RPOP", "usage", "2"); errWrite != nil {
|
||||
t.Fatalf("failed to write RPOP empty count command: %v", errWrite)
|
||||
}
|
||||
emptyItems, errEmpty := readRESPArrayOfBulkStrings(reader)
|
||||
@@ -453,284 +327,3 @@ func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
|
||||
t.Fatalf("expected empty array for empty queue with count, got %#v", emptyItems)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_SubscribeUsageBroadcastsAndSkipsQueue(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||
|
||||
server := newTestServer(t)
|
||||
if !server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||
}
|
||||
|
||||
addr, stop := startRedisMuxListener(t, server)
|
||||
t.Cleanup(stop)
|
||||
|
||||
firstConn, errDialFirst := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDialFirst != nil {
|
||||
t.Fatalf("failed to dial first redis listener: %v", errDialFirst)
|
||||
}
|
||||
t.Cleanup(func() { _ = firstConn.Close() })
|
||||
firstReader := bufio.NewReader(firstConn)
|
||||
_ = firstConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
if errWrite := writeTestRESPCommand(firstConn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write first AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPSimpleString(firstReader); err != nil {
|
||||
t.Fatalf("failed to read first AUTH response: %v", err)
|
||||
} else if msg != "OK" {
|
||||
t.Fatalf("unexpected first AUTH response: %q", msg)
|
||||
}
|
||||
if errWrite := writeTestRESPCommand(firstConn, "SUBSCRIBE", "usage"); errWrite != nil {
|
||||
t.Fatalf("failed to write first SUBSCRIBE command: %v", errWrite)
|
||||
}
|
||||
if channel, count, err := readTestRESPPubSubSubscribe(firstReader); err != nil {
|
||||
t.Fatalf("failed to read first SUBSCRIBE response: %v", err)
|
||||
} else if channel != "usage" || count != 1 {
|
||||
t.Fatalf("unexpected first SUBSCRIBE response channel=%q count=%d", channel, count)
|
||||
}
|
||||
|
||||
secondConn, errDialSecond := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDialSecond != nil {
|
||||
t.Fatalf("failed to dial second redis listener: %v", errDialSecond)
|
||||
}
|
||||
t.Cleanup(func() { _ = secondConn.Close() })
|
||||
secondReader := bufio.NewReader(secondConn)
|
||||
_ = secondConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
if errWrite := writeTestRESPCommand(secondConn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write second AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPSimpleString(secondReader); err != nil {
|
||||
t.Fatalf("failed to read second AUTH response: %v", err)
|
||||
} else if msg != "OK" {
|
||||
t.Fatalf("unexpected second AUTH response: %q", msg)
|
||||
}
|
||||
if errWrite := writeTestRESPCommand(secondConn, "SUBSCRIBE", "usage"); errWrite != nil {
|
||||
t.Fatalf("failed to write second SUBSCRIBE command: %v", errWrite)
|
||||
}
|
||||
if channel, count, err := readTestRESPPubSubSubscribe(secondReader); err != nil {
|
||||
t.Fatalf("failed to read second SUBSCRIBE response: %v", err)
|
||||
} else if channel != "usage" || count != 1 {
|
||||
t.Fatalf("unexpected second SUBSCRIBE response channel=%q count=%d", channel, count)
|
||||
}
|
||||
|
||||
redisqueue.Enqueue([]byte(`{"id":1}`))
|
||||
|
||||
if channel, payload, err := readTestRESPPubSubMessage(firstReader); err != nil {
|
||||
t.Fatalf("failed to read first pubsub message: %v", err)
|
||||
} else if channel != "usage" || string(payload) != `{"id":1}` {
|
||||
t.Fatalf("unexpected first pubsub message channel=%q payload=%q", channel, string(payload))
|
||||
}
|
||||
if channel, payload, err := readTestRESPPubSubMessage(secondReader); err != nil {
|
||||
t.Fatalf("failed to read second pubsub message: %v", err)
|
||||
} else if channel != "usage" || string(payload) != `{"id":1}` {
|
||||
t.Fatalf("unexpected second pubsub message channel=%q payload=%q", channel, string(payload))
|
||||
}
|
||||
|
||||
popConn, errDialPop := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDialPop != nil {
|
||||
t.Fatalf("failed to dial pop redis listener: %v", errDialPop)
|
||||
}
|
||||
t.Cleanup(func() { _ = popConn.Close() })
|
||||
popReader := bufio.NewReader(popConn)
|
||||
_ = popConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
if errWrite := writeTestRESPCommand(popConn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write pop AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPSimpleString(popReader); err != nil {
|
||||
t.Fatalf("failed to read pop AUTH response: %v", err)
|
||||
} else if msg != "OK" {
|
||||
t.Fatalf("unexpected pop AUTH response: %q", msg)
|
||||
}
|
||||
if errWrite := writeTestRESPCommand(popConn, "LPOP", "usage"); errWrite != nil {
|
||||
t.Fatalf("failed to write pop LPOP command: %v", errWrite)
|
||||
}
|
||||
item, errItem := readTestRESPBulkString(popReader)
|
||||
if errItem != nil {
|
||||
t.Fatalf("failed to read pop LPOP response: %v", errItem)
|
||||
}
|
||||
if item != nil {
|
||||
t.Fatalf("expected subscribed usage to skip queue, got %q", string(item))
|
||||
}
|
||||
|
||||
managementReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=1", nil)
|
||||
managementReq.Header.Set("Authorization", "Bearer "+managementPassword)
|
||||
managementRR := httptest.NewRecorder()
|
||||
server.engine.ServeHTTP(managementRR, managementReq)
|
||||
if managementRR.Code != http.StatusOK {
|
||||
t.Fatalf("management usage status = %d, want %d body=%s", managementRR.Code, http.StatusOK, managementRR.Body.String())
|
||||
}
|
||||
var managementPayload []json.RawMessage
|
||||
if errUnmarshal := json.Unmarshal(managementRR.Body.Bytes(), &managementPayload); errUnmarshal != nil {
|
||||
t.Fatalf("unmarshal management usage response: %v", errUnmarshal)
|
||||
}
|
||||
if len(managementPayload) != 0 {
|
||||
t.Fatalf("expected management usage queue to be empty, got %s", managementRR.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_IPBan_MirrorsManagementPolicy(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||
|
||||
server := newTestServer(t)
|
||||
if !server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||
}
|
||||
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() { _ = clientConn.Close() })
|
||||
t.Cleanup(func() { _ = serverConn.Close() })
|
||||
|
||||
fakeRemote := &net.TCPAddr{
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Port: 1234,
|
||||
}
|
||||
wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote}
|
||||
|
||||
go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn))
|
||||
|
||||
reader := bufio.NewReader(clientConn)
|
||||
_ = clientConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read LPOP NOAUTH error: %v", err)
|
||||
} else if msg != "NOAUTH Authentication required." {
|
||||
t.Fatalf("unexpected LPOP NOAUTH error at attempt %d: %q", i+1, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(clientConn, "LPOP", "queue"); errWrite != nil {
|
||||
t.Fatalf("failed to write LPOP command after failures: %v", errWrite)
|
||||
}
|
||||
msg, err := readTestRESPError(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read LPOP banned error: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
|
||||
t.Fatalf("unexpected LPOP banned error: %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||
|
||||
server := newTestServer(t)
|
||||
if !server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||
}
|
||||
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() { _ = clientConn.Close() })
|
||||
t.Cleanup(func() { _ = serverConn.Close() })
|
||||
|
||||
fakeRemote := &net.TCPAddr{
|
||||
IP: net.ParseIP("1.2.3.4"),
|
||||
Port: 1234,
|
||||
}
|
||||
wrappedConn := &remoteAddrConn{Conn: serverConn, remoteAddr: fakeRemote}
|
||||
|
||||
go server.handleRedisConnection(wrappedConn, bufio.NewReader(wrappedConn))
|
||||
|
||||
reader := bufio.NewReader(clientConn)
|
||||
_ = clientConn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH error: %v", err)
|
||||
} else if msg != "ERR invalid management key" {
|
||||
t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if errWrite := writeTestRESPCommand(clientConn, "AUTH", "wrong-password"); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command after failures: %v", errWrite)
|
||||
}
|
||||
msg, err := readTestRESPError(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read AUTH banned error: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
|
||||
t.Fatalf("unexpected AUTH banned error at attempt %d: %q", i+6, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(clientConn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command with correct password: %v", errWrite)
|
||||
}
|
||||
msg, err := readTestRESPError(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read AUTH banned error for correct password: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
|
||||
t.Fatalf("unexpected AUTH banned error for correct password: %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedisProtocol_LOCALHOST_AUTH_IPBan_BlocksCorrectPasswordDuringBan(t *testing.T) {
|
||||
const managementPassword = "test-management-password"
|
||||
|
||||
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() { redisqueue.SetEnabled(false) })
|
||||
|
||||
server := newTestServer(t)
|
||||
if !server.managementRoutesEnabled.Load() {
|
||||
t.Fatalf("expected managementRoutesEnabled to be true")
|
||||
}
|
||||
|
||||
addr, stop := startRedisMuxListener(t, server)
|
||||
t.Cleanup(stop)
|
||||
|
||||
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
|
||||
if errDial != nil {
|
||||
t.Fatalf("failed to dial redis listener: %v", errDial)
|
||||
}
|
||||
t.Cleanup(func() { _ = conn.Close() })
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", "wrong-password"); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command: %v", errWrite)
|
||||
}
|
||||
if msg, err := readTestRESPError(reader); err != nil {
|
||||
t.Fatalf("failed to read AUTH error: %v", err)
|
||||
} else if msg != "ERR invalid management key" {
|
||||
t.Fatalf("unexpected AUTH error at attempt %d: %q", i+1, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
|
||||
t.Fatalf("failed to write AUTH command with correct password: %v", errWrite)
|
||||
}
|
||||
msg, err := readTestRESPError(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read AUTH banned error for correct password: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(msg, "ERR IP banned due to too many failed attempts. Try again in") {
|
||||
t.Fatalf("unexpected AUTH banned error for correct password: %q", msg)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,9 +217,6 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk

 	// Create gin engine
 	engine := gin.New()
-	if errSetTrustedProxies := engine.SetTrustedProxies(nil); errSetTrustedProxies != nil {
-		log.Warnf("failed to disable trusted proxy headers: %v", errSetTrustedProxies)
-	}
 	if optionState.engineConfigurator != nil {
 		optionState.engineConfigurator(engine)
 	}
|
||||
+24
-26
@@ -21,10 +21,6 @@ import (
 )

 func newTestServer(t *testing.T) *Server {
-	return newTestServerWithOptions(t)
-}
-
-func newTestServerWithOptions(t *testing.T, opts ...ServerOption) *Server {
 	t.Helper()

 	gin.SetMode(gin.TestMode)
@@ -50,7 +46,7 @@ func newTestServerWithOptions(t *testing.T, opts ...ServerOption) *Server {
 	accessManager := sdkaccess.NewManager()

 	configPath := filepath.Join(tmpDir, "config.yaml")
-	return NewServer(cfg, authManager, accessManager, configPath, opts...)
+	return NewServer(cfg, authManager, accessManager, configPath)
 }

 func TestHealthz(t *testing.T) {
@@ -152,26 +148,6 @@ func TestManagementUsageRequiresManagementAuthAndPopsArray(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagementLocalPasswordRejectsSpoofedForwardedFor(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
|
||||
server := newTestServerWithOptions(t, WithLocalManagementPassword("test-local-key"))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/config", nil)
|
||||
req.RemoteAddr = "203.0.113.10:45678"
|
||||
req.Header.Set("X-Forwarded-For", "127.0.0.1")
|
||||
req.Header.Set("Authorization", "Bearer test-local-key")
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
server.engine.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want %d body=%s", rr.Code, http.StatusForbidden, rr.Body.String())
|
||||
}
|
||||
if body := rr.Body.String(); !strings.Contains(body, "remote management disabled") {
|
||||
t.Fatalf("body = %q, want remote management disabled", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHomeEnabledHidesManagementEndpointsAndControlPanel(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "test-management-key")
|
||||
|
||||
@@ -287,7 +263,7 @@ func TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) {
|
||||
DisplayName: "Custom Codex Model",
|
||||
Description: "Custom model from registry",
|
||||
ContextLength: 123456,
|
||||
Thinking: &registry.ThinkingSupport{Levels: []string{"low", "medium"}},
|
||||
Thinking: &registry.ThinkingSupport{Levels: []string{"none", "minimal", "low", "medium", "unsupported", "high", "xhigh"}},
|
||||
},
|
||||
{ID: "grok-imagine-image-quality", Object: "model", OwnedBy: "xai", Type: "openai"},
|
||||
{ID: "gpt-image-2", Object: "model", OwnedBy: "openai", Type: "openai"},
|
||||
@@ -358,6 +334,7 @@ func TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) {
 	if got, _ := custom["context_window"].(float64); got != 123456 {
 		t.Fatalf("custom context_window = %v, want 123456", custom["context_window"])
 	}
+	assertCodexSupportedReasoningLevels(t, custom, []string{"none", "low", "medium", "high", "xhigh"})
 	if custom["base_instructions"] != gpt55["base_instructions"] {
 		t.Fatal("expected custom model to use gpt-5.5 base_instructions fallback")
 	}
@@ -400,6 +377,27 @@ func TestModelsWithClientVersionReturnsCodexCatalog(t *testing.T) {
 	}
 }

+func assertCodexSupportedReasoningLevels(t *testing.T, model map[string]any, want []string) {
+	t.Helper()
+
+	rawLevels, ok := model["supported_reasoning_levels"].([]any)
+	if !ok {
+		t.Fatalf("expected supported_reasoning_levels, got %#v", model["supported_reasoning_levels"])
+	}
+	if len(rawLevels) != len(want) {
+		t.Fatalf("supported_reasoning_levels length = %d, want %d: %#v", len(rawLevels), len(want), rawLevels)
+	}
+	for index, rawLevel := range rawLevels {
+		levelEntry, ok := rawLevel.(map[string]any)
+		if !ok {
+			t.Fatalf("supported_reasoning_levels[%d] = %#v, want object", index, rawLevel)
+		}
+		if got, _ := levelEntry["effort"].(string); got != want[index] {
+			t.Fatalf("supported_reasoning_levels[%d].effort = %q, want %q", index, got, want[index])
+		}
+	}
+}
+
 func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
 	t.Setenv("WRITABLE_PATH", "")
 	t.Setenv("writable_path", "")

@@ -48,10 +48,76 @@ func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *Antigravit
 	}
 }

-func (o *AntigravityAuth) loadCodeAssistUserAgent() string {
+func (o *AntigravityAuth) shortUserAgent() string {
+	return misc.AntigravityRequestUserAgent("")
+}
+
+func (o *AntigravityAuth) nodeUserAgent() string {
 	return misc.AntigravityLoadCodeAssistUserAgent("")
 }

+func antigravityLoadCodeAssistMetadata() map[string]string {
+	return map[string]string{
+		"ideType": "ANTIGRAVITY",
+	}
+}
+
+func antigravityControlPlaneMetadata(userAgent string) map[string]string {
+	return map[string]string{
+		"ide_type":    "ANTIGRAVITY",
+		"ide_version": misc.AntigravityVersionFromUserAgent(userAgent),
+		"ide_name":    "antigravity",
+	}
+}
+
+func extractCloudaicompanionProject(data map[string]any) string {
+	if data == nil {
+		return ""
+	}
+	for _, key := range []string{"cloudaicompanionProject", "projectId", "project"} {
+		switch value := data[key].(type) {
+		case string:
+			if trimmed := strings.TrimSpace(value); trimmed != "" {
+				return trimmed
+			}
+		case map[string]any:
+			if id, ok := value["id"].(string); ok {
+				if trimmed := strings.TrimSpace(id); trimmed != "" {
+					return trimmed
+				}
+			}
+		}
+	}
+	return ""
+}
+
+func defaultAntigravityTierID(loadResp map[string]any) string {
+	if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
+		for _, rawTier := range tiers {
+			tier, okTier := rawTier.(map[string]any)
+			if !okTier {
+				continue
+			}
+			if isDefault, okDefault := tier["isDefault"].(bool); !okDefault || !isDefault {
+				continue
+			}
+			if id, okID := tier["id"].(string); okID {
+				if trimmed := strings.TrimSpace(id); trimmed != "" {
+					return trimmed
+				}
+			}
+		}
+	}
+	if currentTier, okTier := loadResp["currentTier"].(map[string]any); okTier {
+		if id, okID := currentTier["id"].(string); okID {
+			if trimmed := strings.TrimSpace(id); trimmed != "" {
+				return trimmed
+			}
+		}
+	}
+	return "free-tier"
+}
+
 // BuildAuthURL generates the OAuth authorization URL.
 func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
 	if strings.TrimSpace(redirectURI) == "" {
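The hunk above adds `extractCloudaicompanionProject`, which accepts the project either as a plain string or as an object carrying an `id`, under any of three keys. The following is a minimal standalone sketch of that behaviour, not part of the commit: the function body is copied from the hunk, while `projectFromJSON` and the sample payloads are hypothetical and exist only for illustration.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// extractCloudaicompanionProject is copied from the hunk above.
func extractCloudaicompanionProject(data map[string]any) string {
	if data == nil {
		return ""
	}
	for _, key := range []string{"cloudaicompanionProject", "projectId", "project"} {
		switch value := data[key].(type) {
		case string:
			if trimmed := strings.TrimSpace(value); trimmed != "" {
				return trimmed
			}
		case map[string]any:
			if id, ok := value["id"].(string); ok {
				if trimmed := strings.TrimSpace(id); trimmed != "" {
					return trimmed
				}
			}
		}
	}
	return ""
}

// projectFromJSON is a hypothetical wrapper (not in the commit): it decodes a
// JSON object and runs the extraction on it, returning "" on invalid JSON.
func projectFromJSON(raw string) string {
	var data map[string]any
	if err := json.Unmarshal([]byte(raw), &data); err != nil {
		return ""
	}
	return extractCloudaicompanionProject(data)
}

func main() {
	// String form, as in the loadCodeAssist test fixture.
	fmt.Println(projectFromJSON(`{"cloudaicompanionProject":"cogent-snow-4mnnp"}`))
	// Object form, as in the onboardUser test fixture.
	fmt.Println(projectFromJSON(`{"cloudaicompanionProject":{"id":"cogent-snow-4mnnp"}}`))
	// No project present: the helper returns the empty string.
	fmt.Println(projectFromJSON(`{"allowedTiers":[]}`) == "")
}
```

Returning an empty string rather than an error leaves the fallback decision to the caller, which is how `FetchProjectID` decides whether to call `OnboardUser`.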
@@ -123,7 +189,7 @@ func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string)
 		return "", fmt.Errorf("antigravity userinfo: create request: %w", err)
 	}
 	req.Header.Set("Authorization", "Bearer "+accessToken)
-	req.Header.Set("User-Agent", o.loadCodeAssistUserAgent())
+	req.Header.Set("User-Agent", o.shortUserAgent())

 	resp, errDo := o.httpClient.Do(req)
 	if errDo != nil {
@@ -159,13 +225,9 @@ func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string)

 // FetchProjectID retrieves the project ID for the authenticated user via loadCodeAssist
 func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
-	userAgent := o.loadCodeAssistUserAgent()
+	userAgent := o.shortUserAgent()
 	loadReqBody := map[string]any{
-		"metadata": map[string]string{
-			"ide_type":    "ANTIGRAVITY",
-			"ide_version": misc.AntigravityVersionFromUserAgent(userAgent),
-			"ide_name":    "antigravity",
-		},
+		"metadata": antigravityLoadCodeAssistMetadata(),
 	}

 	rawBody, errMarshal := json.Marshal(loadReqBody)
@@ -179,9 +241,9 @@ func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string
 		return "", fmt.Errorf("create request: %w", err)
 	}
 	req.Header.Set("Authorization", "Bearer "+accessToken)
+	req.Header.Set("Accept", "*/*")
 	req.Header.Set("Content-Type", "application/json")
 	req.Header.Set("User-Agent", userAgent)
-	req.Header.Set("X-Goog-Api-Client", misc.AntigravityGoogAPIClientUA)

 	resp, errDo := o.httpClient.Do(req)
 	if errDo != nil {
@@ -207,40 +269,16 @@ func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string
 		return "", fmt.Errorf("decode response: %w", errDecode)
 	}

-	// Extract projectID from response
-	projectID := ""
-	if id, ok := loadResp["cloudaicompanionProject"].(string); ok {
-		projectID = strings.TrimSpace(id)
-	}
-	if projectID == "" {
-		if projectMap, ok := loadResp["cloudaicompanionProject"].(map[string]any); ok {
-			if id, okID := projectMap["id"].(string); okID {
-				projectID = strings.TrimSpace(id)
-			}
-		}
-	}
+	projectID := extractCloudaicompanionProject(loadResp)

 	if projectID == "" {
-		tierID := "legacy-tier"
-		if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
-			for _, rawTier := range tiers {
-				tier, okTier := rawTier.(map[string]any)
-				if !okTier {
-					continue
-				}
-				if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
-					if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
-						tierID = strings.TrimSpace(id)
-						break
-					}
-				}
-			}
-		}
-
-		projectID, err = o.OnboardUser(ctx, accessToken, tierID)
+		projectID, err = o.OnboardUser(ctx, accessToken, defaultAntigravityTierID(loadResp))
 		if err != nil {
 			return "", err
 		}
+		if projectID == "" {
+			return "", fmt.Errorf("project id not found in loadCodeAssist or onboardUser response")
+		}
 		return projectID, nil
 	}

@@ -250,14 +288,10 @@ func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string
 // OnboardUser attempts to fetch the project ID via onboardUser by polling for completion
 func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
 	log.Infof("Antigravity: onboarding user with tier: %s", tierID)
-	userAgent := o.loadCodeAssistUserAgent()
+	userAgent := o.nodeUserAgent()
 	requestBody := map[string]any{
-		"tierId": tierID,
-		"metadata": map[string]string{
-			"ide_type":    "ANTIGRAVITY",
-			"ide_version": misc.AntigravityVersionFromUserAgent(userAgent),
-			"ide_name":    "antigravity",
-		},
+		"tier_id":  tierID,
+		"metadata": antigravityControlPlaneMetadata(userAgent),
 	}

 	rawBody, errMarshal := json.Marshal(requestBody)
@@ -276,13 +310,14 @@ func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID s
 		}
 		reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)

-		endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)
+		endpointURL := fmt.Sprintf("%s/%s:onboardUser", DailyAPIEndpoint, APIVersion)
 		req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
 		if errRequest != nil {
 			cancel()
 			return "", fmt.Errorf("create request: %w", errRequest)
 		}
 		req.Header.Set("Authorization", "Bearer "+accessToken)
+		req.Header.Set("Accept", "*/*")
 		req.Header.Set("Content-Type", "application/json")
 		req.Header.Set("User-Agent", userAgent)
 		req.Header.Set("X-Goog-Api-Client", misc.AntigravityGoogAPIClientUA)
@@ -312,18 +347,11 @@ func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID s
 			if done, okDone := data["done"].(bool); okDone && done {
 				projectID := ""
 				if responseData, okResp := data["response"].(map[string]any); okResp {
-					switch projectValue := responseData["cloudaicompanionProject"].(type) {
-					case map[string]any:
-						if id, okID := projectValue["id"].(string); okID {
-							projectID = strings.TrimSpace(id)
-						}
-					case string:
-						projectID = strings.TrimSpace(projectValue)
-					}
+					projectID = extractCloudaicompanionProject(responseData)
 				}

 				if projectID != "" {
-					log.Infof("Successfully fetched project_id: %s", projectID)
+					log.Infof("Successfully fetched project_id: %s", util.HideAPIKey(projectID))
 					return projectID, nil
 				}

@@ -346,5 +374,5 @@ func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID s
 		return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
 	}

-	return "", nil
+	return "", fmt.Errorf("onboard user did not complete after %d attempts", maxAttempts)
 }

@@ -0,0 +1,127 @@
+package antigravity
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"strings"
+	"testing"
+)
+
+type roundTripperFunc func(*http.Request) (*http.Response, error)
+
+func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
+	return f(req)
+}
+
+func TestFetchProjectIDFromLoadCodeAssist(t *testing.T) {
+	auth := NewAntigravityAuth(nil, &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
+		if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" {
+			t.Fatalf("unexpected request URL: %s", req.URL.String())
+		}
+		assertLoadCodeAssistHeaders(t, req)
+		assertJSONContains(t, req, `"ideType":"ANTIGRAVITY"`)
+		return jsonResponse(`{"cloudaicompanionProject":"cogent-snow-4mnnp"}`), nil
+	})})
+
+	projectID, err := auth.FetchProjectID(context.Background(), "access-token")
+	if err != nil {
+		t.Fatalf("FetchProjectID error: %v", err)
+	}
+	if projectID != "cogent-snow-4mnnp" {
+		t.Fatalf("projectID = %q", projectID)
+	}
+}
+
+func TestFetchProjectIDFallsBackToDailyOnboardUser(t *testing.T) {
+	var sawOnboard bool
+	auth := NewAntigravityAuth(nil, &http.Client{Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
+		switch req.URL.String() {
+		case "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist":
+			assertLoadCodeAssistHeaders(t, req)
+			return jsonResponse(`{"allowedTiers":[{"id":"free-tier","isDefault":true}]}`), nil
+		case "https://daily-cloudcode-pa.googleapis.com/v1internal:onboardUser":
+			sawOnboard = true
+			assertOnboardUserHeaders(t, req)
+			assertJSONContains(t, req, `"tier_id":"free-tier"`)
+			assertJSONContains(t, req, `"ide_type":"ANTIGRAVITY"`)
+			return jsonResponse(`{
+				"done": true,
+				"response": {
+					"cloudaicompanionProject": {
+						"id": "cogent-snow-4mnnp",
+						"name": "cogent-snow-4mnnp",
+						"projectNumber": "22597072101"
+					}
+				}
+			}`), nil
+		default:
+			t.Fatalf("unexpected request URL: %s", req.URL.String())
+			return nil, nil
+		}
+	})})
+
+	projectID, err := auth.FetchProjectID(context.Background(), "access-token")
+	if err != nil {
+		t.Fatalf("FetchProjectID error: %v", err)
+	}
+	if !sawOnboard {
+		t.Fatalf("expected onboardUser fallback")
+	}
+	if projectID != "cogent-snow-4mnnp" {
+		t.Fatalf("projectID = %q", projectID)
+	}
+}
+
+func assertLoadCodeAssistHeaders(t *testing.T, req *http.Request) {
+	t.Helper()
+	if got := req.Header.Get("Authorization"); got != "Bearer access-token" {
+		t.Fatalf("Authorization = %q", got)
+	}
+	if got := req.Header.Get("Accept"); got != "*/*" {
+		t.Fatalf("Accept = %q", got)
+	}
+	if got := req.Header.Get("X-Goog-Api-Client"); got != "" {
+		t.Fatalf("X-Goog-Api-Client = %q, want empty", got)
+	}
+	if got := req.Header.Get("User-Agent"); strings.Contains(got, "google-api-nodejs-client/") {
+		t.Fatalf("User-Agent = %q", got)
+	}
+}
+
+func assertOnboardUserHeaders(t *testing.T, req *http.Request) {
+	t.Helper()
+	if got := req.Header.Get("Authorization"); got != "Bearer access-token" {
+		t.Fatalf("Authorization = %q", got)
+	}
+	if got := req.Header.Get("Accept"); got != "*/*" {
+		t.Fatalf("Accept = %q", got)
+	}
+	if got := req.Header.Get("X-Goog-Api-Client"); got != "gl-node/22.21.1" {
+		t.Fatalf("X-Goog-Api-Client = %q", got)
+	}
+	if got := req.Header.Get("User-Agent"); !strings.Contains(got, "google-api-nodejs-client/10.3.0") {
+		t.Fatalf("User-Agent = %q", got)
+	}
+}
+
+func assertJSONContains(t *testing.T, req *http.Request, want string) {
+	t.Helper()
+	body, err := io.ReadAll(req.Body)
+	if err != nil {
+		t.Fatalf("read body: %v", err)
+	}
+	bodyText := string(body)
+	req.Body = io.NopCloser(strings.NewReader(bodyText))
+	if !strings.Contains(bodyText, want) {
+		t.Fatalf("body missing %s: %s", want, bodyText)
+	}
+}
+
+func jsonResponse(body string) *http.Response {
+	return &http.Response{
+		StatusCode: http.StatusOK,
+		Header:     make(http.Header),
+		Body:       io.NopCloser(strings.NewReader(body)),
+	}
+}
@@ -26,6 +26,7 @@ const (

 // Antigravity API configuration
 const (
-	APIEndpoint = "https://cloudcode-pa.googleapis.com"
-	APIVersion  = "v1internal"
+	APIEndpoint      = "https://cloudcode-pa.googleapis.com"
+	DailyAPIEndpoint = "https://daily-cloudcode-pa.googleapis.com"
+	APIVersion       = "v1internal"
 )

@@ -34,7 +34,7 @@ func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
 	if cfg != nil {
 		proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
 		if errBuild != nil {
-			log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
+			log.Errorf("failed to configure proxy dialer for %q: %v", proxyutil.Redact(cfg.ProxyURL), errBuild)
 		} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
 			dialer = proxyDialer
 		}

@@ -37,8 +37,8 @@ type Config struct {
 	// TLS config controls HTTPS server settings.
 	TLS TLSConfig `yaml:"tls" json:"tls"`

-	// Home config enables the Redis-based control plane integration.
-	Home HomeConfig `yaml:"home" json:"-"`
+	// Home config is runtime-only and is populated from -home-jwt.
+	Home HomeConfig `yaml:"-" json:"-"`

 	// RemoteManagement nests management-related options under 'remote-management'.
 	RemoteManagement RemoteManagement `yaml:"remote-management" json:"-"`
@@ -69,8 +69,8 @@ type Config struct {
 	// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
 	UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`

-	// RedisUsageQueueRetentionSeconds controls how long (in seconds) usage queue items
-	// are retained in memory for the Redis RESP interface (LPOP/RPOP).
+	// RedisUsageQueueRetentionSeconds controls how long usage queue items are retained
+	// in memory for Management API consumers.
 	// Default: 60. Max: 3600.
 	RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"`

@@ -585,6 +585,9 @@ type OpenAICompatibilityModel struct {
 	// Alias is the model name alias that clients will use to reference this model.
 	Alias string `yaml:"alias" json:"alias"`

+	// Image marks this model as callable through /v1/images/generations and /v1/images/edits.
+	Image bool `yaml:"image,omitempty" json:"image,omitempty"`
+
 	// Thinking configures the thinking/reasoning capability for this model.
 	// If nil, the model defaults to level-based reasoning with levels ["low", "medium", "high"].
 	Thinking *registry.ThinkingSupport `yaml:"thinking,omitempty" json:"thinking,omitempty"`

@@ -1,19 +1,21 @@
 package config

-// HomeConfig configures the optional "home" control plane integration over Redis protocol.
+// HomeConfig stores runtime-only Home control plane settings from -home-jwt.
 type HomeConfig struct {
 	Enabled                 bool          `yaml:"enabled" json:"enabled"`
 	Host                    string        `yaml:"host" json:"-"`
 	Port                    int           `yaml:"port" json:"-"`
-	Password                string        `yaml:"password" json:"-"`
 	DisableClusterDiscovery bool          `yaml:"disable-cluster-discovery" json:"-"`
 	TLS                     HomeTLSConfig `yaml:"tls" json:"-"`
 }

 // HomeTLSConfig configures client-side TLS for the home Redis connection.
 type HomeTLSConfig struct {
-	Enable             bool   `yaml:"enable" json:"-"`
-	ServerName         string `yaml:"server-name" json:"-"`
-	InsecureSkipVerify bool   `yaml:"insecure-skip-verify" json:"-"`
-	CACert             string `yaml:"ca-cert" json:"-"`
+	Enable              bool   `yaml:"enable" json:"-"`
+	ServerName          string `yaml:"server-name" json:"-"`
+	InsecureSkipVerify  bool   `yaml:"insecure-skip-verify" json:"-"`
+	CACert              string `yaml:"ca-cert" json:"-"`
+	ClientCert          string `yaml:"-" json:"-"`
+	ClientKey           string `yaml:"-" json:"-"`
+	UseTargetServerName bool   `yaml:"-" json:"-"`
 }

@@ -2,13 +2,12 @@ package config

 import "testing"

-func TestParseConfigBytesHomeTLS(t *testing.T) {
+func TestParseConfigBytesIgnoresHomeConfig(t *testing.T) {
 	cfg, err := ParseConfigBytes([]byte(`
 home:
   enabled: true
   host: home.example.com
   port: 444
-  password: secret
   disable-cluster-discovery: true
   tls:
     enable: true
@@ -20,31 +19,28 @@ home:
 		t.Fatalf("ParseConfigBytes() error = %v", err)
 	}

-	if !cfg.Home.Enabled {
-		t.Fatal("Home.Enabled = false, want true")
+	if cfg.Home.Enabled {
+		t.Fatal("Home.Enabled = true, want false")
 	}
-	if cfg.Home.Host != "home.example.com" {
-		t.Fatalf("Home.Host = %q, want home.example.com", cfg.Home.Host)
+	if cfg.Home.Host != "" {
+		t.Fatalf("Home.Host = %q, want empty", cfg.Home.Host)
 	}
-	if cfg.Home.Port != 444 {
-		t.Fatalf("Home.Port = %d, want 444", cfg.Home.Port)
+	if cfg.Home.Port != 0 {
+		t.Fatalf("Home.Port = %d, want 0", cfg.Home.Port)
 	}
-	if cfg.Home.Password != "secret" {
-		t.Fatalf("Home.Password = %q, want secret", cfg.Home.Password)
+	if cfg.Home.DisableClusterDiscovery {
+		t.Fatal("Home.DisableClusterDiscovery = true, want false")
 	}
-	if !cfg.Home.DisableClusterDiscovery {
-		t.Fatal("Home.DisableClusterDiscovery = false, want true")
+	if cfg.Home.TLS.Enable {
+		t.Fatal("Home.TLS.Enable = true, want false")
 	}
-	if !cfg.Home.TLS.Enable {
-		t.Fatal("Home.TLS.Enable = false, want true")
+	if cfg.Home.TLS.ServerName != "" {
+		t.Fatalf("Home.TLS.ServerName = %q, want empty", cfg.Home.TLS.ServerName)
 	}
-	if cfg.Home.TLS.ServerName != "home.example.com" {
-		t.Fatalf("Home.TLS.ServerName = %q, want home.example.com", cfg.Home.TLS.ServerName)
+	if cfg.Home.TLS.CACert != "" {
+		t.Fatalf("Home.TLS.CACert = %q, want empty", cfg.Home.TLS.CACert)
 	}
-	if cfg.Home.TLS.CACert != "C:/certs/ca.pem" {
-		t.Fatalf("Home.TLS.CACert = %q, want C:/certs/ca.pem", cfg.Home.TLS.CACert)
-	}
-	if !cfg.Home.TLS.InsecureSkipVerify {
-		t.Fatal("Home.TLS.InsecureSkipVerify = false, want true")
+	if cfg.Home.TLS.InsecureSkipVerify {
+		t.Fatal("Home.TLS.InsecureSkipVerify = true, want false")
 	}
 }

@@ -0,0 +1,386 @@
+package home
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/sha256"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/base64"
+	"encoding/hex"
+	"encoding/json"
+	"encoding/pem"
+	"fmt"
+	"io"
+	"net"
+	"os"
+	"path/filepath"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
+)
+
+const homeCertificateRequestTimeout = 30 * time.Second
+
+type homeJWTClaims struct {
+	CertificateID    string `json:"certificate_id"`
+	ClusterID        string `json:"cluster_id"`
+	CAFingerprint    string `json:"ca_fingerprint"`
+	EnrollmentSecret string `json:"enrollment_secret"`
+	IP               string `json:"ip"`
+	Port             int    `json:"port"`
+	IssuedAt         int64  `json:"iat"`
+}
+
+type certificateRequestResponse struct {
+	OK          bool   `json:"ok"`
+	Certificate string `json:"certificate"`
+	CA          string `json:"ca"`
+}
+
+type certificatePaths struct {
+	Dir        string
+	ClientCert string
+	ClientKey  string
+	CACert     string
+}
+
+// ConfigFromJWT prepares a Home config from the JWT and ensures local mTLS files exist.
+func ConfigFromJWT(ctx context.Context, rawJWT string) (config.HomeConfig, error) {
+	claims, errClaims := parseHomeJWTClaims(rawJWT)
+	if errClaims != nil {
+		return config.HomeConfig{}, errClaims
+	}
+	paths, errPaths := defaultCertificatePaths()
+	if errPaths != nil {
+		return config.HomeConfig{}, errPaths
+	}
+	if errEnsure := ensureHomeCertificateFiles(ctx, claims, paths); errEnsure != nil {
+		return config.HomeConfig{}, errEnsure
+	}
+	return config.HomeConfig{
+		Enabled: true,
+		Host:    strings.TrimSpace(claims.IP),
+		Port:    claims.Port,
+		TLS: config.HomeTLSConfig{
+			Enable:              true,
+			CACert:              paths.CACert,
+			ClientCert:          paths.ClientCert,
+			ClientKey:           paths.ClientKey,
+			UseTargetServerName: true,
+		},
+	}, nil
+}
+
+func parseHomeJWTClaims(rawJWT string) (homeJWTClaims, error) {
+	var claims homeJWTClaims
+	parts := strings.Split(strings.TrimSpace(rawJWT), ".")
+	if len(parts) != 3 {
+		return claims, fmt.Errorf("home jwt is invalid")
+	}
+	payload, errDecode := decodeJWTPart(parts[1])
+	if errDecode != nil {
+		return claims, errDecode
+	}
+	if errUnmarshal := json.Unmarshal(payload, &claims); errUnmarshal != nil {
+		return claims, errUnmarshal
+	}
+	if strings.TrimSpace(claims.CertificateID) == "" {
+		return claims, fmt.Errorf("home jwt certificate_id is required")
+	}
+	if strings.TrimSpace(claims.ClusterID) == "" {
+		return claims, fmt.Errorf("home jwt cluster_id is required")
+	}
+	if normalizeFingerprint(claims.CAFingerprint) == "" {
+		return claims, fmt.Errorf("home jwt ca_fingerprint is required")
+	}
+	if strings.TrimSpace(claims.EnrollmentSecret) == "" {
+		return claims, fmt.Errorf("home jwt enrollment_secret is required")
+	}
+	if strings.TrimSpace(claims.IP) == "" || claims.Port <= 0 {
+		return claims, fmt.Errorf("home jwt target address is invalid")
+	}
+	return claims, nil
+}
+
+func decodeJWTPart(part string) ([]byte, error) {
+	if decoded, errDecode := base64.RawURLEncoding.DecodeString(part); errDecode == nil {
+		return decoded, nil
+	}
+	return base64.URLEncoding.DecodeString(part)
+}
+
+func defaultCertificatePaths() (certificatePaths, error) {
+	homeDir, errHome := os.UserHomeDir()
+	if errHome != nil {
+		return certificatePaths{}, errHome
+	}
+	dir := filepath.Join(homeDir, ".cli-proxy-api")
+	return certificatePaths{
+		Dir:        dir,
+		ClientCert: filepath.Join(dir, "client-crt.pem"),
+		ClientKey:  filepath.Join(dir, "client-key.pem"),
+		CACert:     filepath.Join(dir, "home-ca-crt.pem"),
+	}, nil
+}
+
+func ensureHomeCertificateFiles(ctx context.Context, claims homeJWTClaims, paths certificatePaths) error {
+	if fileExists(paths.ClientCert) && fileExists(paths.ClientKey) {
+		if !fileExists(paths.CACert) {
+			return fmt.Errorf("home ca certificate file is missing")
+		}
+		if errVerify := verifyCACertificateFile(paths.CACert, claims.CAFingerprint); errVerify != nil {
+			return errVerify
+		}
+		if errChmod := chmodCertificateFiles(paths); errChmod != nil {
+			return errChmod
+		}
+		return nil
+	}
+	if errMkdir := os.MkdirAll(paths.Dir, 0o700); errMkdir != nil {
+		return errMkdir
+	}
+	key, errKey := loadOrCreateClientKey(paths.ClientKey)
+	if errKey != nil {
+		return errKey
+	}
+	csrPEM, errCSR := createClientCSR(claims.CertificateID, key)
+	if errCSR != nil {
+		return errCSR
+	}
+	response, errRequest := requestClientCertificate(ctx, claims, csrPEM)
+	if errRequest != nil {
+		return errRequest
+	}
+	if strings.TrimSpace(response.Certificate) == "" || strings.TrimSpace(response.CA) == "" {
+		return fmt.Errorf("home certificate response is incomplete")
+	}
+	if errVerify := verifyCACertificatePEM([]byte(response.CA), claims.CAFingerprint); errVerify != nil {
+		return errVerify
+	}
+	if errWrite := writeFile0600(paths.ClientCert, []byte(response.Certificate)); errWrite != nil {
+		return errWrite
+	}
+	if errWrite := writeFile0600(paths.CACert, []byte(response.CA)); errWrite != nil {
+		return errWrite
+	}
+	return nil
+}
+
+func verifyCACertificateFile(path string, expectedFingerprint string) error {
+	raw, errRead := os.ReadFile(path)
+	if errRead != nil {
+		return errRead
+	}
+	return verifyCACertificatePEM(raw, expectedFingerprint)
+}
+
+func verifyCACertificatePEM(raw []byte, expectedFingerprint string) error {
+	actual, errFingerprint := certificateFingerprintPEM(raw)
+	if errFingerprint != nil {
+		return errFingerprint
+	}
+	expected := normalizeFingerprint(expectedFingerprint)
+	if expected == "" {
+		return fmt.Errorf("home ca fingerprint is required")
+	}
+	if actual != expected {
+		return fmt.Errorf("home ca fingerprint mismatch")
+	}
+	return nil
+}
+
+func certificateFingerprintPEM(raw []byte) (string, error) {
+	block, _ := pem.Decode(raw)
+	if block == nil || block.Type != "CERTIFICATE" {
+		return "", fmt.Errorf("home ca certificate pem is invalid")
+	}
+	cert, errParse := x509.ParseCertificate(block.Bytes)
+	if errParse != nil {
+		return "", errParse
+	}
+	sum := sha256.Sum256(cert.Raw)
+	return hex.EncodeToString(sum[:]), nil
+}
+
+func normalizeFingerprint(fingerprint string) string {
+	fingerprint = strings.TrimSpace(strings.ToLower(fingerprint))
+	fingerprint = strings.ReplaceAll(fingerprint, ":", "")
+	fingerprint = strings.ReplaceAll(fingerprint, " ", "")
+	return fingerprint
+}
+
func loadOrCreateClientKey(path string) (*rsa.PrivateKey, error) {
|
||||
if fileExists(path) {
|
||||
raw, errRead := os.ReadFile(path)
|
||||
if errRead != nil {
|
||||
return nil, errRead
|
||||
}
|
||||
key, errParse := parseRSAPrivateKeyPEM(raw)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
if errChmod := os.Chmod(path, 0o600); errChmod != nil {
|
||||
return nil, errChmod
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
key, errKey := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if errKey != nil {
|
||||
return nil, errKey
|
||||
}
|
||||
raw := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||
if errWrite := writeFile0600(path, raw); errWrite != nil {
|
||||
return nil, errWrite
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func writeFile0600(path string, raw []byte) error {
|
||||
if errWrite := os.WriteFile(path, raw, 0o600); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return os.Chmod(path, 0o600)
|
||||
}
|
||||
|
||||
func chmodCertificateFiles(paths certificatePaths) error {
|
||||
for _, path := range []string{paths.ClientCert, paths.ClientKey, paths.CACert} {
|
||||
if errChmod := os.Chmod(path, 0o600); errChmod != nil {
|
||||
return errChmod
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRSAPrivateKeyPEM(raw []byte) (*rsa.PrivateKey, error) {
|
||||
block, _ := pem.Decode(raw)
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("client key pem is invalid")
|
||||
}
|
||||
switch block.Type {
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
case "PRIVATE KEY":
|
||||
key, errParse := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("client key is not rsa")
|
||||
}
|
||||
return rsaKey, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("client key pem type %q is unsupported", block.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func createClientCSR(certificateID string, key *rsa.PrivateKey) ([]byte, error) {
|
||||
certificateID = strings.TrimSpace(certificateID)
|
||||
if certificateID == "" {
|
||||
return nil, fmt.Errorf("certificate id is required")
|
||||
}
|
||||
template := &x509.CertificateRequest{
|
||||
Subject: pkix.Name{
|
||||
CommonName: certificateID,
|
||||
},
|
||||
}
|
||||
der, errCreate := x509.CreateCertificateRequest(rand.Reader, template, key)
|
||||
if errCreate != nil {
|
||||
return nil, errCreate
|
||||
}
|
||||
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: der}), nil
|
||||
}

func requestClientCertificate(ctx context.Context, claims homeJWTClaims, csrPEM []byte) (certificateRequestResponse, error) {
	var response certificateRequestResponse
	if ctx == nil {
		ctx = context.Background()
	}
	dialCtx, cancel := context.WithTimeout(ctx, homeCertificateRequestTimeout)
	defer cancel()
	addr := net.JoinHostPort(strings.TrimSpace(claims.IP), strconv.Itoa(claims.Port))
	conn, errDial := (&net.Dialer{}).DialContext(dialCtx, "tcp", addr)
	if errDial != nil {
		return response, errDial
	}
	defer func() {
		_ = conn.Close()
	}()
	if deadline, ok := dialCtx.Deadline(); ok {
		_ = conn.SetDeadline(deadline)
	}
	if _, errWrite := conn.Write(encodeRESPArray("CERTIFICATE", "REQUEST", claims.CertificateID, claims.EnrollmentSecret, string(csrPEM))); errWrite != nil {
		return response, errWrite
	}
	raw, errRead := readRESPBulk(bufio.NewReader(conn))
	if errRead != nil {
		return response, errRead
	}
	if errUnmarshal := json.Unmarshal(raw, &response); errUnmarshal != nil {
		return response, errUnmarshal
	}
	if !response.OK {
		return response, fmt.Errorf("home certificate request failed")
	}
	return response, nil
}

func encodeRESPArray(args ...string) []byte {
	var buf bytes.Buffer
	buf.WriteString("*")
	buf.WriteString(strconv.Itoa(len(args)))
	buf.WriteString("\r\n")
	for _, arg := range args {
		buf.WriteString("$")
		buf.WriteString(strconv.Itoa(len(arg)))
		buf.WriteString("\r\n")
		buf.WriteString(arg)
		buf.WriteString("\r\n")
	}
	return buf.Bytes()
}

func readRESPBulk(reader *bufio.Reader) ([]byte, error) {
	prefix, errRead := reader.ReadByte()
	if errRead != nil {
		return nil, errRead
	}
	switch prefix {
	case '$':
		line, errLine := reader.ReadString('\n')
		if errLine != nil {
			return nil, errLine
		}
		size, errSize := strconv.Atoi(strings.TrimSpace(line))
		if errSize != nil {
			return nil, errSize
		}
		if size < 0 {
			return nil, fmt.Errorf("home certificate request returned nil")
		}
		payload := make([]byte, size+2)
		if _, errFull := io.ReadFull(reader, payload); errFull != nil {
			return nil, errFull
		}
		return payload[:size], nil
	case '-':
		line, errLine := reader.ReadString('\n')
		if errLine != nil {
			return nil, errLine
		}
		return nil, fmt.Errorf("%s", strings.TrimSpace(line))
	default:
		return nil, fmt.Errorf("home certificate request returned unsupported resp prefix %q", prefix)
	}
}
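
The two RESP helpers above are easiest to follow with the wire bytes in view. The sketch below copies both helpers (with a shortened error message in the default branch) and runs them against in-memory strings instead of a TCP connection; `decodeReply` and the sample payloads are assumptions made for the illustration.

```go
package main

import (
	"bufio"
	"bytes"
	"fmt"
	"io"
	"strconv"
	"strings"
)

// encodeRESPArray writes a RESP array of bulk strings: "*<n>" followed by
// "$<len>\r\n<arg>\r\n" for every argument.
func encodeRESPArray(args ...string) []byte {
	var buf bytes.Buffer
	buf.WriteString("*")
	buf.WriteString(strconv.Itoa(len(args)))
	buf.WriteString("\r\n")
	for _, arg := range args {
		buf.WriteString("$")
		buf.WriteString(strconv.Itoa(len(arg)))
		buf.WriteString("\r\n")
		buf.WriteString(arg)
		buf.WriteString("\r\n")
	}
	return buf.Bytes()
}

// readRESPBulk accepts a bulk string ('$') or an error reply ('-').
func readRESPBulk(reader *bufio.Reader) ([]byte, error) {
	prefix, err := reader.ReadByte()
	if err != nil {
		return nil, err
	}
	switch prefix {
	case '$':
		line, err := reader.ReadString('\n')
		if err != nil {
			return nil, err
		}
		size, err := strconv.Atoi(strings.TrimSpace(line))
		if err != nil {
			return nil, err
		}
		if size < 0 {
			return nil, fmt.Errorf("nil bulk reply")
		}
		// Read the payload plus its trailing CRLF, then drop the CRLF.
		payload := make([]byte, size+2)
		if _, err := io.ReadFull(reader, payload); err != nil {
			return nil, err
		}
		return payload[:size], nil
	case '-':
		line, err := reader.ReadString('\n')
		if err != nil {
			return nil, err
		}
		return nil, fmt.Errorf("%s", strings.TrimSpace(line))
	default:
		return nil, fmt.Errorf("unsupported resp prefix %q", prefix)
	}
}

// decodeReply is a hypothetical wrapper that feeds a literal wire string
// through readRESPBulk.
func decodeReply(wire string) (string, error) {
	raw, err := readRESPBulk(bufio.NewReader(strings.NewReader(wire)))
	return string(raw), err
}

func main() {
	fmt.Printf("%q\n", encodeRESPArray("CERTIFICATE", "REQUEST"))
	body, err := decodeReply("$11\r\n{\"ok\":true}\r\n")
	fmt.Printf("%q %v\n", body, err)
	_, err = decodeReply("-ERR denied\r\n")
	fmt.Println(err)
}
```

Because only `$` and `-` replies are handled, any other reply type from the server surfaces as an explicit error rather than being silently misread.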

func fileExists(path string) bool {
	info, errStat := os.Stat(path)
	return errStat == nil && !info.IsDir()
}
+99
-17
@@ -31,6 +31,8 @@ const (

	homeReconnectInterval          = time.Second
	homeReconnectFailoverThreshold = 3
	homeRedisOperationTimeout      = 3 * time.Second
	homeSubscriptionReceiveTimeout = 3 * time.Second
	redisChannelCluster            = "cluster"
)

@@ -172,21 +174,30 @@ func (c *Client) ensureClients() error {
}

func (c *Client) redisOptionsLocked(addr string) (*redis.Options, error) {
	tlsConfig, errTLS := c.homeTLSConfigLocked()
	tlsConfig, errTLS := c.homeTLSConfigLocked(addr)
	if errTLS != nil {
		return nil, errTLS
	}
	return &redis.Options{
		Addr:      addr,
		Password:  c.homeCfg.Password,
		TLSConfig: tlsConfig,
		Addr:                  addr,
		TLSConfig:             tlsConfig,
		DialTimeout:           homeRedisOperationTimeout,
		ReadTimeout:           homeRedisOperationTimeout,
		WriteTimeout:          homeRedisOperationTimeout,
		MaxRetries:            -1,
		DialerRetries:         1,
		ContextTimeoutEnabled: true,
	}, nil
}

func (c *Client) homeTLSConfigLocked() (*tls.Config, error) {
func (c *Client) homeTLSConfigLocked(addr string) (*tls.Config, error) {
	serverName := strings.TrimSpace(c.homeCfg.TLS.ServerName)
	if serverName == "" {
		serverName = strings.TrimSpace(c.seedHost)
		if c.homeCfg.TLS.UseTargetServerName {
			serverName = hostFromAddress(addr)
		} else {
			serverName = strings.TrimSpace(c.seedHost)
		}
	}
	if serverName == "" {
		serverName = strings.TrimSpace(c.homeCfg.Host)
@@ -194,6 +205,14 @@ func (c *Client) homeTLSConfigLocked() (*tls.Config, error) {
	return newHomeTLSConfig(c.homeCfg.TLS, serverName)
}

func hostFromAddress(addr string) string {
	host, _, errSplit := net.SplitHostPort(strings.TrimSpace(addr))
	if errSplit == nil {
		return strings.TrimSpace(host)
	}
	return strings.TrimSpace(addr)
}
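
When `UseTargetServerName` is set, the TLS server name comes from the address of the node being dialed, so `hostFromAddress` has to cope with `host:port`, bracketed IPv6, and bare hosts. A small sketch of how the helper behaves on those shapes follows; the sample addresses are made up for the illustration.

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// hostFromAddress is copied from the diff: it strips the port when the input
// parses as host:port and otherwise returns the trimmed input unchanged.
func hostFromAddress(addr string) string {
	host, _, errSplit := net.SplitHostPort(strings.TrimSpace(addr))
	if errSplit == nil {
		return strings.TrimSpace(host)
	}
	return strings.TrimSpace(addr)
}

func main() {
	samples := []string{
		"10.0.0.5:6379",        // IPv4 with port
		"[::1]:6379",           // bracketed IPv6 with port
		" home.internal:6380 ", // surrounding whitespace is trimmed first
		"home.internal",        // no port: SplitHostPort fails, input is kept
	}
	for _, addr := range samples {
		fmt.Printf("%q -> %q\n", addr, hostFromAddress(addr))
	}
}
```

Falling back to the raw input when no port is present means a bare hostname still yields a usable server name instead of an empty string.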

func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls.Config, error) {
	if !cfg.Enable {
		return nil, nil
@@ -210,6 +229,19 @@ func newHomeTLSConfig(cfg config.HomeTLSConfig, fallbackServerName string) (*tls
		InsecureSkipVerify: cfg.InsecureSkipVerify,
	}

	clientCertPath := strings.TrimSpace(cfg.ClientCert)
	clientKeyPath := strings.TrimSpace(cfg.ClientKey)
	if clientCertPath != "" || clientKeyPath != "" {
		if clientCertPath == "" || clientKeyPath == "" {
			return nil, fmt.Errorf("home tls: client certificate and key must be set together")
		}
		certPair, errLoad := tls.LoadX509KeyPair(clientCertPath, clientKeyPath)
		if errLoad != nil {
			return nil, fmt.Errorf("home tls: load client certificate: %w", errLoad)
		}
		tlsConfig.Certificates = []tls.Certificate{certPair}
	}

	caCertPath := strings.TrimSpace(cfg.CACert)
	if caCertPath == "" {
		return tlsConfig, nil
@@ -404,6 +436,25 @@ func (c *Client) failoverAfterReconnectFailure() (bool, string) {
	}
	c.reconnectFailures = 0

	return c.switchToNextNodeLocked()
}

func (c *Client) failoverAfterSubscriptionTimeout() (bool, string) {
	if c == nil {
		return false, ""
	}
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.clusterDiscoveryEnabledLocked() {
		c.reconnectFailures = 0
		return false, ""
	}
	c.reconnectFailures = 0
	return c.switchToNextNodeLocked()
}

func (c *Client) switchToNextNodeLocked() (bool, string) {
	currentHost := strings.TrimSpace(c.homeCfg.Host)
	currentPort := c.homeCfg.Port
	candidates := append([]clusterNode(nil), c.clusterNodes...)
@@ -426,6 +477,13 @@ func (c *Client) failoverAfterReconnectFailure() (bool, string) {
	return false, ""
}

func (c *Client) markSubscriptionTimeout() {
	switched, addr := c.failoverAfterSubscriptionTimeout()
	if switched {
		log.Warnf("home subscription heartbeat timeout; switching to %s", addr)
	}
}

func (c *Client) resetReconnectFailures() {
	if c == nil {
		return
@@ -683,7 +741,7 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte
		}

		// Ensure the subscription is established before marking heartbeat OK.
		if _, errReceive := pubsub.Receive(ctx); errReceive != nil {
		if _, errReceive := pubsub.ReceiveTimeout(ctx, homeSubscriptionReceiveTimeout); errReceive != nil {
			_ = pubsub.Close()
			c.markReconnectFailure("subscribe")
			sleepWithContext(ctx, homeReconnectInterval)
@@ -694,28 +752,52 @@ func (c *Client) StartConfigSubscriber(ctx context.Context, onConfig func([]byte
		c.heartbeatOK.Store(true)

		for {
			msg, errMsg := pubsub.ReceiveMessage(ctx)
			event, errMsg := pubsub.ReceiveTimeout(ctx, homeSubscriptionReceiveTimeout)
			if errMsg != nil {
				_ = pubsub.Close()
				c.heartbeatOK.Store(false)
				c.markReconnectFailure("subscription")
				if isTimeoutError(errMsg) {
					c.markSubscriptionTimeout()
				} else {
					c.markReconnectFailure("subscription")
				}
				sleepWithContext(ctx, homeReconnectInterval)
				break
			}
			if msg == nil {
				continue
			}
			if errApply := c.handleSubscriptionPayload(msg.Channel, msg.Payload, onConfig); errApply != nil {
				if strings.EqualFold(strings.TrimSpace(msg.Channel), redisChannelCluster) {
					log.Warn("failed to apply cluster update from home control center, ignoring")
				} else {
					log.Warn("failed to apply config update from home control center, ignoring")
			switch msg := event.(type) {
			case *redis.Message:
				if msg == nil {
					continue
				}
				if errApply := c.handleSubscriptionPayload(msg.Channel, msg.Payload, onConfig); errApply != nil {
					if strings.EqualFold(strings.TrimSpace(msg.Channel), redisChannelCluster) {
						log.Warn("failed to apply cluster update from home control center, ignoring")
					} else {
						log.Warn("failed to apply config update from home control center, ignoring")
					}
				}
			case *redis.Pong:
				c.resetReconnectFailures()
			case *redis.Subscription:
				continue
			default:
				log.Debugf("home subscription returned unsupported message type %T", event)
			}
		}
	}
}

func isTimeoutError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return true
	}
	var netErr net.Error
	return errors.As(err, &netErr) && netErr.Timeout()
}
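
The subscriber loop treats a receive timeout differently from other errors, so the classification done by `isTimeoutError` decides whether the client fails over immediately or only counts a reconnect failure. The sketch below exercises the helper against a few representative errors; `fakeTimeout` and `sampleResults` are hypothetical and exist only for this demonstration.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"net"
)

// isTimeoutError is copied from the diff.
func isTimeoutError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return true
	}
	var netErr net.Error
	return errors.As(err, &netErr) && netErr.Timeout()
}

// fakeTimeout is a stand-in net.Error that always reports a timeout.
type fakeTimeout struct{}

func (fakeTimeout) Error() string   { return "i/o timeout" }
func (fakeTimeout) Timeout() bool   { return true }
func (fakeTimeout) Temporary() bool { return false }

// sampleResults classifies a handful of errors, including wrapped ones, to
// show that both errors.Is and errors.As look through fmt.Errorf("%w") chains.
func sampleResults() map[string]bool {
	return map[string]bool{
		"nil":                 isTimeoutError(nil),
		"plain":               isTimeoutError(errors.New("connection reset")),
		"context deadline":    isTimeoutError(fmt.Errorf("receive: %w", context.DeadlineExceeded)),
		"wrapped net timeout": isTimeoutError(fmt.Errorf("receive: %w", fakeTimeout{})),
	}
}

func main() {
	fmt.Println(sampleResults())
}
```

Checking both the context deadline and `net.Error.Timeout()` covers timeouts raised by the context as well as those raised by the socket read deadline.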

func sleepWithContext(ctx context.Context, d time.Duration) {
	if d <= 0 {
		return

@@ -37,10 +37,9 @@ func TestAuthDispatchRequestDefaultsCountToOne(t *testing.T) {

func TestRedisOptionsHomeTLSDisabled(t *testing.T) {
	client := New(config.HomeConfig{
		Enabled:  true,
		Host:     "127.0.0.1",
		Port:     6379,
		Password: "secret",
		Enabled: true,
		Host:    "127.0.0.1",
		Port:    6379,
	})

	client.mu.Lock()
@@ -53,8 +52,8 @@ func TestRedisOptionsHomeTLSDisabled(t *testing.T) {
	if options.TLSConfig != nil {
		t.Fatalf("TLSConfig = %#v, want nil", options.TLSConfig)
	}
	if options.Password != "secret" {
		t.Fatalf("Password = %q, want secret", options.Password)
	if options.Password != "" {
		t.Fatalf("Password = %q, want empty", options.Password)
	}
}

@@ -2,16 +2,24 @@ package logging

import (
	"context"
	"net/http"
	"sync"
	"sync/atomic"
)

type endpointKey struct{}
type responseStatusKey struct{}
type responseHeadersKey struct{}

type responseStatusHolder struct {
	status atomic.Int32
}

type responseHeadersHolder struct {
	mu      sync.RWMutex
	headers http.Header
}

func WithEndpoint(ctx context.Context, endpoint string) context.Context {
	if ctx == nil {
		ctx = context.Background()
@@ -39,6 +47,16 @@ func WithResponseStatusHolder(ctx context.Context) context.Context {
	return context.WithValue(ctx, responseStatusKey{}, &responseStatusHolder{})
}

func WithResponseHeadersHolder(ctx context.Context) context.Context {
	if ctx == nil {
		ctx = context.Background()
	}
	if holder, ok := ctx.Value(responseHeadersKey{}).(*responseHeadersHolder); ok && holder != nil {
		return ctx
	}
	return context.WithValue(ctx, responseHeadersKey{}, &responseHeadersHolder{})
}

func SetResponseStatus(ctx context.Context, status int) {
	if ctx == nil || status <= 0 {
		return
@@ -50,6 +68,19 @@ func SetResponseStatus(ctx context.Context, status int) {
	holder.status.Store(int32(status))
}

func SetResponseHeaders(ctx context.Context, headers http.Header) {
	if ctx == nil {
		return
	}
	holder, ok := ctx.Value(responseHeadersKey{}).(*responseHeadersHolder)
	if !ok || holder == nil {
		return
	}
	holder.mu.Lock()
	defer holder.mu.Unlock()
	holder.headers = cloneHTTPHeader(headers)
}

func GetResponseStatus(ctx context.Context) int {
	if ctx == nil {
		return 0
@@ -60,3 +91,27 @@ func GetResponseStatus(ctx context.Context) int {
	}
	return int(holder.status.Load())
}

func GetResponseHeaders(ctx context.Context) http.Header {
	if ctx == nil {
		return nil
	}
	holder, ok := ctx.Value(responseHeadersKey{}).(*responseHeadersHolder)
	if !ok || holder == nil {
		return nil
	}
	holder.mu.RLock()
	defer holder.mu.RUnlock()
	return cloneHTTPHeader(holder.headers)
}

func cloneHTTPHeader(src http.Header) http.Header {
	if len(src) == 0 {
		return nil
	}
	dst := make(http.Header, len(src))
	for key, values := range src {
		dst[key] = append([]string(nil), values...)
	}
	return dst
}
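
Both the setter and the getter clone the header map, which is what lets a usage record keep the headers it saw even if the holder is updated later. A minimal sketch of why the value slices are copied too, not just the map; `snapshotSurvivesMutation` and the header values are illustrative.

```go
package main

import (
	"fmt"
	"net/http"
)

// cloneHTTPHeader is copied from the diff: it copies the map and every value
// slice, and returns nil for an empty or nil input.
func cloneHTTPHeader(src http.Header) http.Header {
	if len(src) == 0 {
		return nil
	}
	dst := make(http.Header, len(src))
	for key, values := range src {
		dst[key] = append([]string(nil), values...)
	}
	return dst
}

// snapshotSurvivesMutation takes a snapshot, then mutates the live header in
// place. A copy that shared value slices would see the change as well.
func snapshotSurvivesMutation() (snapshot string, live string) {
	headers := http.Header{}
	headers.Set("Retry-After", "30")
	copied := cloneHTTPHeader(headers)
	headers["Retry-After"][0] = "999" // in-place write to the original slice
	return copied.Get("Retry-After"), headers.Get("Retry-After")
}

func main() {
	snapshot, live := snapshotSurvivesMutation()
	fmt.Println("snapshot:", snapshot, "live:", live)
}
```

The standard library's `http.Header.Clone` offers similar deep-copy semantics; the hand-written helper additionally collapses empty input to nil, which pairs with the `omitempty` tag on `response_headers`.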
@@ -3,6 +3,7 @@ package redisqueue
import (
	"context"
	"encoding/json"
	"net/http"
	"strings"
	"time"

@@ -47,6 +48,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
	}
	apiKey := strings.TrimSpace(record.APIKey)
	requestID := strings.TrimSpace(internallogging.GetRequestID(ctx))
	reasoningEffort := strings.TrimSpace(record.ReasoningEffort)
	if reasoningEffort == "" {
		reasoningEffort = coreusage.ReasoningEffortFromContext(ctx)
	}

	tokens := tokenStats{
		InputTokens: record.Detail.InputTokens,
@@ -71,24 +76,26 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
	fail := resolveFail(ctx, record, failed)

	detail := requestDetail{
		Timestamp: timestamp,
		LatencyMs: record.Latency.Milliseconds(),
		Source:    record.Source,
		AuthIndex: record.AuthIndex,
		Tokens:    tokens,
		Failed:    failed,
		Fail:      fail,
		Timestamp:       timestamp,
		LatencyMs:       record.Latency.Milliseconds(),
		Source:          record.Source,
		AuthIndex:       record.AuthIndex,
		Tokens:          tokens,
		Failed:          failed,
		Fail:            fail,
		ResponseHeaders: record.ResponseHeaders,
	}

	payload, err := json.Marshal(queuedUsageDetail{
		requestDetail: detail,
		Provider:      provider,
		Model:         modelName,
		Alias:         aliasName,
		Endpoint:      resolveEndpoint(ctx),
		AuthType:      authType,
		APIKey:        apiKey,
		RequestID:     requestID,
		requestDetail:   detail,
		Provider:        provider,
		Model:           modelName,
		Alias:           aliasName,
		Endpoint:        resolveEndpoint(ctx),
		AuthType:        authType,
		APIKey:          apiKey,
		RequestID:       requestID,
		ReasoningEffort: reasoningEffort,
	})
	if err != nil {
		return
@@ -98,23 +105,25 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec

type queuedUsageDetail struct {
	requestDetail
	Provider  string `json:"provider"`
	Model     string `json:"model"`
	Alias     string `json:"alias"`
	Endpoint  string `json:"endpoint"`
	AuthType  string `json:"auth_type"`
	APIKey    string `json:"api_key"`
	RequestID string `json:"request_id"`
	Provider        string `json:"provider"`
	Model           string `json:"model"`
	Alias           string `json:"alias"`
	Endpoint        string `json:"endpoint"`
	AuthType        string `json:"auth_type"`
	APIKey          string `json:"api_key"`
	RequestID       string `json:"request_id"`
	ReasoningEffort string `json:"reasoning_effort"`
}

type requestDetail struct {
	Timestamp time.Time  `json:"timestamp"`
	LatencyMs int64      `json:"latency_ms"`
	Source    string     `json:"source"`
	AuthIndex string     `json:"auth_index"`
	Tokens    tokenStats `json:"tokens"`
	Failed    bool       `json:"failed"`
	Fail      failDetail `json:"fail"`
	Timestamp       time.Time   `json:"timestamp"`
	LatencyMs       int64       `json:"latency_ms"`
	Source          string      `json:"source"`
	AuthIndex       string      `json:"auth_index"`
	Tokens          tokenStats  `json:"tokens"`
	Failed          bool        `json:"failed"`
	Fail            failDetail  `json:"fail"`
	ResponseHeaders http.Header `json:"response_headers,omitempty"`
}

type tokenStats struct {

@@ -19,9 +19,69 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
		ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
		ctx = internallogging.WithResponseStatusHolder(ctx)
		internallogging.SetResponseStatus(ctx, http.StatusOK)
		responseHeaders := http.Header{}
		responseHeaders.Add("X-Upstream-Request-Id", "upstream-req-1")
		responseHeaders.Add("Retry-After", "30")

		plugin := &usageQueuePlugin{}
		plugin.HandleUsage(ctx, coreusage.Record{
			Provider:        "openai",
			Model:           "gpt-5.4",
			Alias:           "client-gpt",
			APIKey:          "test-key",
			AuthIndex:       "0",
			AuthType:        "apikey",
			Source:          "user@example.com",
			ReasoningEffort: "medium",
			RequestedAt:     time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC),
			Latency:         1500 * time.Millisecond,
			Detail: coreusage.Detail{
				InputTokens:  10,
				OutputTokens: 20,
				TotalTokens:  30,
			},
			ResponseHeaders: responseHeaders.Clone(),
		})
		responseHeaders.Set("Retry-After", "999")

		payload := popSinglePayload(t)
		requireStringField(t, payload, "provider", "openai")
		requireStringField(t, payload, "model", "gpt-5.4")
		requireStringField(t, payload, "alias", "client-gpt")
		requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
		requireStringField(t, payload, "auth_type", "apikey")
		requireMissingField(t, payload, "user_api_key")
		requireStringField(t, payload, "request_id", "ctx-request-id")
		requireStringField(t, payload, "reasoning_effort", "medium")
		requireHeaderField(t, payload, "response_headers", "X-Upstream-Request-Id", []string{"upstream-req-1"})
		requireHeaderField(t, payload, "response_headers", "Retry-After", []string{"30"})
		requireBoolField(t, payload, "failed", false)
		requireFailField(t, payload, http.StatusOK, "")
	})
}

func TestUsageQueuePluginAsyncUsesRecordResponseHeaders(t *testing.T) {
	withEnabledQueue(t, func() {
		ctx := internallogging.WithRequestID(context.Background(), "ctx-request-id")
		ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
		ctx = internallogging.WithResponseStatusHolder(ctx)
		ctx = internallogging.WithResponseHeadersHolder(ctx)
		internallogging.SetResponseStatus(ctx, http.StatusOK)
		initialHeaders := http.Header{}
		initialHeaders.Set("X-Upstream-Request-Id", "upstream-req-1")
		internallogging.SetResponseHeaders(ctx, initialHeaders)

		mgr := coreusage.NewManager(16)
		defer mgr.Stop()

		mgr.Register(pluginFunc(func(ctx context.Context, _ coreusage.Record) {
			nextHeaders := http.Header{}
			nextHeaders.Set("X-Upstream-Request-Id", "upstream-req-2")
			internallogging.SetResponseHeaders(ctx, nextHeaders)
		}))
		mgr.Register(&usageQueuePlugin{})

		mgr.Publish(ctx, coreusage.Record{
			Provider: "openai",
			Model:    "gpt-5.4",
			Alias:    "client-gpt",
@@ -36,18 +96,11 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
				OutputTokens: 20,
				TotalTokens:  30,
			},
			ResponseHeaders: internallogging.GetResponseHeaders(ctx),
		})

		payload := popSinglePayload(t)
		requireStringField(t, payload, "provider", "openai")
		requireStringField(t, payload, "model", "gpt-5.4")
		requireStringField(t, payload, "alias", "client-gpt")
		requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
		requireStringField(t, payload, "auth_type", "apikey")
		requireMissingField(t, payload, "user_api_key")
		requireStringField(t, payload, "request_id", "ctx-request-id")
		requireBoolField(t, payload, "failed", false)
		requireFailField(t, payload, http.StatusOK, "")
		payload := waitForSinglePayload(t, 2*time.Second)
		requireHeaderField(t, payload, "response_headers", "X-Upstream-Request-Id", []string{"upstream-req-1"})
	})
}

@@ -276,3 +329,28 @@ func requireFailField(t *testing.T, payload map[string]json.RawMessage, wantStat
		t.Fatalf("fail = {status_code:%d body:%q}, want {status_code:%d body:%q}", got.StatusCode, got.Body, wantStatus, wantBody)
	}
}

func requireHeaderField(t *testing.T, payload map[string]json.RawMessage, field, key string, want []string) {
	t.Helper()

	raw, ok := payload[field]
	if !ok {
		t.Fatalf("payload missing %q", field)
	}
	var headers map[string][]string
	if err := json.Unmarshal(raw, &headers); err != nil {
		t.Fatalf("unmarshal %q: %v", field, err)
	}
	got, ok := headers[key]
	if !ok {
		t.Fatalf("%s missing header %q", field, key)
	}
	if len(got) != len(want) {
		t.Fatalf("%s[%q] = %v, want %v", field, key, got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("%s[%q] = %v, want %v", field, key, got, want)
		}
	}
}

@@ -1,146 +0,0 @@
package registry

import "testing"

func TestCodexFreeModelsExcludeGPT55(t *testing.T) {
	model := findModelInfo(GetCodexFreeModels(), "gpt-5.5")
	if model != nil {
		t.Fatal("expected codex free tier to NOT include gpt-5.5")
	}
}

func TestCodexStaticModelsIncludeGPT55(t *testing.T) {
	tierModels := map[string][]*ModelInfo{
		"team": GetCodexTeamModels(),
		"plus": GetCodexPlusModels(),
		"pro":  GetCodexProModels(),
	}

	for tier, models := range tierModels {
		t.Run(tier, func(t *testing.T) {
			model := findModelInfo(models, "gpt-5.5")
			if model == nil {
				t.Fatalf("expected codex %s tier to include gpt-5.5", tier)
			}
			assertGPT55ModelInfo(t, tier, model)
		})
	}

	model := LookupStaticModelInfo("gpt-5.5")
	if model == nil {
		t.Fatal("expected LookupStaticModelInfo to find gpt-5.5")
	}
	assertGPT55ModelInfo(t, "lookup", model)
}

func TestWithXAIBuiltinsAddsVideoModel(t *testing.T) {
	models := WithXAIBuiltins(nil)
	found := false
	for _, model := range models {
		if model != nil && model.ID == xaiBuiltinVideoModelID {
			found = true
			if model.OwnedBy != "xai" {
				t.Fatalf("OwnedBy = %q, want xai", model.OwnedBy)
			}
		}
	}
	if !found {
		t.Fatalf("expected %s builtin model", xaiBuiltinVideoModelID)
	}
}

func TestValidateModelsCatalogAllowsMissingSections(t *testing.T) {
	data := validTestModelsCatalog()
	data.XAI = nil

	if err := validateModelsCatalog(data); err != nil {
		t.Fatalf("validateModelsCatalog() error = %v", err)
	}
}

func TestValidateModelsCatalogRejectsInvalidDefinitions(t *testing.T) {
	data := validTestModelsCatalog()
	data.Claude = []*ModelInfo{{ID: ""}}

	if err := validateModelsCatalog(data); err == nil {
		t.Fatal("expected invalid model definition error")
	}
}

func validTestModelsCatalog() *staticModelsJSON {
	models := []*ModelInfo{{ID: "test-model"}}
	return &staticModelsJSON{
		Claude:      models,
		Gemini:      models,
		Vertex:      models,
		GeminiCLI:   models,
		AIStudio:    models,
		CodexFree:   models,
		CodexTeam:   models,
		CodexPlus:   models,
		CodexPro:    models,
		Kimi:        models,
		Antigravity: models,
		XAI:         models,
	}
}

func findModelInfo(models []*ModelInfo, id string) *ModelInfo {
	for _, model := range models {
		if model != nil && model.ID == id {
			return model
		}
	}
	return nil
}

func assertGPT55ModelInfo(t *testing.T, source string, model *ModelInfo) {
	t.Helper()

	if model.ID != "gpt-5.5" {
		t.Fatalf("%s id mismatch: got %q", source, model.ID)
	}
	if model.Object != "model" {
		t.Fatalf("%s object mismatch: got %q", source, model.Object)
	}
	if model.Created != 1776902400 {
		t.Fatalf("%s created timestamp mismatch: got %d", source, model.Created)
	}
	if model.OwnedBy != "openai" {
		t.Fatalf("%s owned_by mismatch: got %q", source, model.OwnedBy)
	}
	if model.Type != "openai" {
		t.Fatalf("%s type mismatch: got %q", source, model.Type)
	}
	if model.DisplayName != "GPT 5.5" {
		t.Fatalf("%s display name mismatch: got %q", source, model.DisplayName)
	}
	if model.Version != "gpt-5.5" {
		t.Fatalf("%s version mismatch: got %q", source, model.Version)
	}
	if model.Description != "Frontier model for complex coding, research, and real-world work." {
		t.Fatalf("%s description mismatch: got %q", source, model.Description)
	}
	if model.ContextLength != 272000 {
		t.Fatalf("%s context length mismatch: got %d", source, model.ContextLength)
	}
	if model.MaxCompletionTokens != 128000 {
		t.Fatalf("%s max completion tokens mismatch: got %d", source, model.MaxCompletionTokens)
	}
	if len(model.SupportedParameters) != 1 || model.SupportedParameters[0] != "tools" {
		t.Fatalf("%s supported parameters mismatch: got %v", source, model.SupportedParameters)
	}
	if model.Thinking == nil {
		t.Fatalf("%s missing thinking support", source)
	}

	want := []string{"low", "medium", "high", "xhigh"}
	if len(model.Thinking.Levels) != len(want) {
		t.Fatalf("%s thinking level count mismatch: got %d, want %d", source, len(model.Thinking.Levels), len(want))
	}
	for i, level := range want {
		if model.Thinking.Levels[i] != level {
			t.Fatalf("%s thinking level %d mismatch: got %q, want %q", source, i, model.Thinking.Levels[i], level)
		}
	}
}
@@ -15,6 +15,9 @@ import (
	log "github.com/sirupsen/logrus"
)

// OpenAIImageModelType marks models that are callable through OpenAI-compatible image endpoints.
const OpenAIImageModelType = "openai-image"

// ModelInfo represents information about an available model
type ModelInfo struct {
	// ID is the unique identifier for the model

@@ -421,6 +421,36 @@
          "high"
        ]
      }
    },
    {
      "id": "gemini-3.5-flash",
      "object": "model",
      "created": 1779235200,
      "owned_by": "google",
      "type": "gemini",
      "display_name": "Gemini 3.5 Flash",
      "name": "models/gemini-3.5-flash",
      "version": "3.5",
      "description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
      "inputTokenLimit": 1048576,
      "outputTokenLimit": 65536,
      "supportedGenerationMethods": [
        "generateContent",
        "countTokens",
        "createCachedContent",
        "batchGenerateContent"
      ],
      "thinking": {
        "min": 128,
        "max": 32768,
        "dynamic_allowed": true,
        "levels": [
          "minimal",
          "low",
          "medium",
          "high"
        ]
      }
    }
  ],
  "vertex": [
@@ -762,6 +792,36 @@
      "supportedGenerationMethods": [
        "predict"
      ]
    },
    {
      "id": "gemini-3.5-flash",
      "object": "model",
      "created": 1779235200,
      "owned_by": "google",
      "type": "gemini",
      "display_name": "Gemini 3.5 Flash",
      "name": "models/gemini-3.5-flash",
      "version": "3.5",
      "description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
      "inputTokenLimit": 1048576,
      "outputTokenLimit": 65536,
      "supportedGenerationMethods": [
        "generateContent",
        "countTokens",
        "createCachedContent",
        "batchGenerateContent"
      ],
      "thinking": {
        "min": 128,
        "max": 32768,
        "dynamic_allowed": true,
        "levels": [
          "minimal",
          "low",
          "medium",
          "high"
        ]
      }
    }
  ],
  "gemini-cli": [
@@ -1221,6 +1281,36 @@
        "createCachedContent",
        "batchGenerateContent"
      ]
    },
    {
      "id": "gemini-3.5-flash",
      "object": "model",
      "created": 1779235200,
      "owned_by": "google",
      "type": "gemini",
      "display_name": "Gemini 3.5 Flash",
      "name": "models/gemini-3.5-flash",
      "version": "3.5",
      "description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
      "inputTokenLimit": 1048576,
      "outputTokenLimit": 65536,
      "supportedGenerationMethods": [
        "generateContent",
        "countTokens",
        "createCachedContent",
        "batchGenerateContent"
      ],
      "thinking": {
        "min": 128,
        "max": 32768,
        "dynamic_allowed": true,
        "levels": [
          "minimal",
          "low",
          "medium",
          "high"
        ]
      }
    }
  ],
  "codex-free": [
@@ -1954,6 +2044,28 @@
        ]
      }
    },
    {
      "id": "gemini-3-flash-agent",
      "object": "model",
      "owned_by": "antigravity",
      "type": "antigravity",
      "display_name": "Gemini 3.5 Flash",
      "name": "gemini-3-flash-agent",
      "description": "Gemini 3.5 Flash",
      "context_length": 1048576,
      "max_completion_tokens": 65536,
      "thinking": {
        "min": 128,
        "max": 32768,
        "dynamic_allowed": true,
        "levels": [
          "minimal",
          "low",
          "medium",
          "high"
        ]
      }
    },
    {
      "id": "gemini-3-pro-high",
      "object": "model",
@@ -2087,9 +2199,52 @@
          "high"
        ]
      }
    },
    {
      "id": "gemini-3.5-flash-low",
      "object": "model",
      "owned_by": "antigravity",
      "type": "antigravity",
      "display_name": "Gemini 3.5 Flash (Low)",
      "name": "gemini-3.5-flash-low",
      "description": "Gemini 3.5 Flash (Low)",
      "context_length": 1048576,
      "max_completion_tokens": 65535,
      "thinking": {
        "min": 1,
        "max": 65535,
        "dynamic_allowed": true,
        "levels": [
          "low",
          "medium",
          "high"
        ]
      }
    }

  ],
  "xai": [
    {
      "id": "grok-build-0.1",
      "object": "model",
      "created": 1779321600,
      "owned_by": "xai",
      "type": "xai",
      "display_name": "Grok Build 0.1",
      "name": "grok-build-0.1",
      "description": "Grok Build 0.1 is xAI’s fast coding model trained specifically for agentic software engineering workflows.",
      "context_length": 256000,
      "max_completion_tokens": 256000,
      "thinking": {
        "zero_allowed": true,
        "levels": [
          "none",
          "low",
          "medium",
          "high"
        ]
      }
    },
    {
      "id": "grok-4.3",
      "object": "model",

@@ -1415,6 +1415,41 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au
 	return updated, nil
 }
 
+func (e *AntigravityExecutor) ShouldPrepareRequestAuth(auth *cliproxyauth.Auth) bool {
+	return antigravityProjectIDFromAuth(auth) == ""
+}
+
+func (e *AntigravityExecutor) PrepareRequestAuth(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
+	if auth == nil || !e.ShouldPrepareRequestAuth(auth) {
+		return nil, nil
+	}
+
+	updated := auth.Clone()
+	token, refreshedAuth, errToken := e.ensureAccessToken(ctx, updated)
+	if errToken != nil {
+		return nil, errToken
+	}
+	if refreshedAuth != nil {
+		updated = refreshedAuth
+	}
+	if antigravityProjectIDFromAuth(updated) != "" {
+		return updated, nil
+	}
+
+	projectID, errProject := e.fetchAntigravityProjectID(ctx, updated, token)
+	if errProject != nil {
+		return nil, missingAntigravityProjectIDError(errProject)
+	}
+	if projectID == "" {
+		return nil, missingAntigravityProjectIDError(nil)
+	}
+	if updated.Metadata == nil {
+		updated.Metadata = make(map[string]any)
+	}
+	updated.Metadata["project_id"] = projectID
+	return updated, nil
+}
+
 // CountTokens counts tokens for the given request using the Antigravity API.
 func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
 	baseModel := thinking.ParseSuffix(req.Model).ModelName
@@ -1752,34 +1787,67 @@ func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, au
 		return nil
 	}
 
-	if auth.Metadata["project_id"] != nil {
+	if antigravityProjectIDFromAuth(auth) != "" {
 		return nil
 	}
 
-	token := strings.TrimSpace(accessToken)
-	if token == "" {
-		token = metaStringValue(auth.Metadata, "access_token")
-	}
-	if token == "" {
-		return nil
-	}
-
-	httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
-	projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
+	projectID, errFetch := e.fetchAntigravityProjectID(ctx, auth, accessToken)
 	if errFetch != nil {
 		return errFetch
 	}
-	if strings.TrimSpace(projectID) == "" {
+	if projectID == "" {
 		return nil
 	}
 	if auth.Metadata == nil {
 		auth.Metadata = make(map[string]any)
 	}
-	auth.Metadata["project_id"] = strings.TrimSpace(projectID)
+	auth.Metadata["project_id"] = projectID
 
 	return nil
 }
 
+func (e *AntigravityExecutor) fetchAntigravityProjectID(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) (string, error) {
+	token := strings.TrimSpace(accessToken)
+	if token == "" {
+		token = metaStringValue(auth.Metadata, "access_token")
+	}
+	if token == "" {
+		return "", nil
+	}
+
+	httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
+	projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
+	if errFetch != nil {
+		return "", errFetch
+	}
+	return strings.TrimSpace(projectID), nil
+}
+
+func (e *AntigravityExecutor) projectIDForRequest(_ context.Context, auth *cliproxyauth.Auth, _ string) (string, error) {
+	if projectID := antigravityProjectIDFromAuth(auth); projectID != "" {
+		return projectID, nil
+	}
+	return "", missingAntigravityProjectIDError(nil)
+}
+
+func antigravityProjectIDFromAuth(auth *cliproxyauth.Auth) string {
+	if auth == nil || auth.Metadata == nil {
+		return ""
+	}
+	if pid, ok := auth.Metadata["project_id"].(string); ok {
+		return strings.TrimSpace(pid)
+	}
+	return ""
+}
+
+func missingAntigravityProjectIDError(cause error) statusErr {
+	msg := "antigravity auth missing project_id"
+	if cause != nil {
+		msg = fmt.Sprintf("%s: %v", msg, cause)
+	}
+	return statusErr{code: http.StatusBadRequest, msg: msg}
+}
+
 func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) {
 	if auth == nil || strings.TrimSpace(auth.ID) == "" {
 		return
@@ -1792,19 +1860,17 @@ func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Contex
 		return
 	}
 
-	userAgent := resolveLoadCodeAssistUserAgent(auth)
+	userAgent := resolveUserAgent(auth)
 	loadReqBody, errMarshal := json.Marshal(map[string]any{
 		"metadata": map[string]string{
-			"ide_type": "ANTIGRAVITY",
-			"ide_version": misc.AntigravityVersionFromUserAgent(userAgent),
-			"ide_name": "antigravity",
+			"ideType": "ANTIGRAVITY",
 		},
 	})
 	if errMarshal != nil {
 		log.Debugf("antigravity executor: marshal loadCodeAssist request error: %v", errMarshal)
 		return
 	}
-	baseURL := buildBaseURL(auth)
+	baseURL := antigravityLoadCodeAssistBaseURL(auth)
 	endpointURL := strings.TrimSuffix(baseURL, "/") + "/v1internal:loadCodeAssist"
 	httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, endpointURL, bytes.NewReader(loadReqBody))
 	if errReq != nil {
@@ -1812,9 +1878,9 @@ func (e *AntigravityExecutor) updateAntigravityCreditsBalance(ctx context.Contex
 		return
 	}
 	httpReq.Header.Set("Authorization", "Bearer "+token)
+	httpReq.Header.Set("Accept", "*/*")
 	httpReq.Header.Set("Content-Type", "application/json")
 	httpReq.Header.Set("User-Agent", userAgent)
-	httpReq.Header.Set("X-Goog-Api-Client", misc.AntigravityGoogAPIClientUA)
 
 	httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
 	httpResp, errDo := httpClient.Do(httpReq)
@@ -1909,12 +1975,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
 		requestURL.WriteString(url.QueryEscape(alt))
 	}
 
-	// Extract project_id from auth metadata if available
-	projectID := ""
-	if auth != nil && auth.Metadata != nil {
-		if pid, ok := auth.Metadata["project_id"].(string); ok {
-			projectID = strings.TrimSpace(pid)
-		}
+	projectID, errProject := e.projectIDForRequest(ctx, auth, token)
+	if errProject != nil {
+		return nil, errProject
 	}
 	payload = geminiToAntigravity(modelName, payload, projectID)
 	payload, _ = sjson.SetBytes(payload, "model", modelName)
@@ -2100,6 +2163,13 @@ func buildBaseURL(auth *cliproxyauth.Auth) string {
 	return antigravityBaseURLDaily
 }
 
+func antigravityLoadCodeAssistBaseURL(auth *cliproxyauth.Auth) string {
+	if base := resolveCustomAntigravityBaseURL(auth); base != "" {
+		return base
+	}
+	return antigravityBaseURLProd
+}
+
 func resolveHost(base string) string {
 	parsed, errParse := url.Parse(base)
 	if errParse != nil {
@@ -2338,11 +2408,10 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
 	}
 	template, _ = sjson.SetBytes(template, "requestType", reqType)
 
-	// Use real project ID from auth if available, otherwise generate random (legacy fallback)
 	if projectID != "" {
 		template, _ = sjson.SetBytes(template, "project", projectID)
 	} else {
-		template, _ = sjson.SetBytes(template, "project", generateProjectID())
+		template, _ = sjson.DeleteBytes(template, "project")
 	}
 
 	if isImageModel {
@@ -2391,14 +2460,3 @@ func generateStableSessionID(payload []byte) string {
 	}
 	return generateSessionID()
 }
-
-func generateProjectID() string {
-	adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
-	nouns := []string{"fuze", "wave", "spark", "flow", "core"}
-	randSourceMutex.Lock()
-	adj := adjectives[randSource.Intn(len(adjectives))]
-	noun := nouns[randSource.Intn(len(nouns))]
-	randSourceMutex.Unlock()
-	randomPart := strings.ToLower(uuid.NewString())[:5]
-	return adj + "-" + noun + "-" + randomPart
-}
@@ -4,7 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
)
|
||||
@@ -90,6 +93,82 @@ func TestAntigravityBuildRequest_SkipsSchemaSanitizationWithEmptyToolsArray(t *t
|
||||
assertNonSchemaRequestPreserved(t, body)
|
||||
}
|
||||
|
||||
func TestAntigravityBuildRequest_UsesAuthProjectID(t *testing.T) {
|
||||
body := buildRequestBodyFromRawPayload(t, "gemini-3.1-pro", []byte(`{
|
||||
"request": {
|
||||
"contents": [
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"text": "hello"}]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`))
|
||||
|
||||
if got, ok := body["project"].(string); !ok || got != "project-1" {
|
||||
t.Fatalf("project should come from auth metadata, got=%v", body["project"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityPrepareRequestAuth_FetchesMissingProjectID(t *testing.T) {
|
||||
executor := &AntigravityExecutor{}
|
||||
auth := &cliproxyauth.Auth{Metadata: map[string]any{
|
||||
"access_token": "token",
|
||||
"expired": time.Now().Add(1 * time.Hour).Format(time.RFC3339),
|
||||
}}
|
||||
ctx := context.WithValue(context.Background(), "cliproxy.roundtripper", roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" {
|
||||
t.Fatalf("unexpected project discovery request: %s", req.URL.String())
|
||||
}
|
||||
if got := req.Header.Get("X-Goog-Api-Client"); got != "" {
|
||||
t.Fatalf("X-Goog-Api-Client = %q, want empty", got)
|
||||
}
|
||||
raw, errRead := io.ReadAll(req.Body)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read discovery body: %v", errRead)
|
||||
}
|
||||
if !strings.Contains(string(raw), `"ideType":"ANTIGRAVITY"`) {
|
||||
t.Fatalf("unexpected discovery body: %s", string(raw))
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(`{"cloudaicompanionProject":"fetched-project"}`)),
|
||||
}, nil
|
||||
}))
|
||||
|
||||
updated, err := executor.PrepareRequestAuth(ctx, auth)
|
||||
if err != nil {
|
||||
t.Fatalf("PrepareRequestAuth error: %v", err)
|
||||
}
|
||||
if updated == nil {
|
||||
t.Fatalf("PrepareRequestAuth returned nil auth")
|
||||
}
|
||||
if _, ok := auth.Metadata["project_id"]; ok {
|
||||
t.Fatalf("original auth metadata should not be mutated")
|
||||
}
|
||||
if got, ok := updated.Metadata["project_id"].(string); !ok || got != "fetched-project" {
|
||||
t.Fatalf("updated auth metadata project_id = %v, want fetched-project", updated.Metadata["project_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityBuildRequest_RejectsMissingProjectID(t *testing.T) {
|
||||
executor := &AntigravityExecutor{}
|
||||
auth := &cliproxyauth.Auth{Metadata: map[string]any{}}
|
||||
|
||||
_, err := executor.buildRequest(context.Background(), auth, "token", "gemini-3.1-pro", []byte(`{"request":{}}`), false, "", "https://example.com")
|
||||
if err == nil {
|
||||
t.Fatalf("buildRequest should fail when auth has no project_id")
|
||||
}
|
||||
status, ok := err.(interface{ StatusCode() int })
|
||||
if !ok {
|
||||
t.Fatalf("error should expose status code, got %T", err)
|
||||
}
|
||||
if got := status.StatusCode(); got != http.StatusBadRequest {
|
||||
t.Fatalf("status code = %d, want %d", got, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func assertNonSchemaRequestPreserved(t *testing.T, body map[string]any) {
|
||||
t.Helper()
|
||||
|
||||
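Review note: the new `PrepareRequestAuth` test above asserts that the caller's auth is left untouched and that the fetched project ID only appears on the returned copy. A minimal sketch of that clone-then-mutate pattern follows. `Auth`, `Clone`, `projectIDFromAuth` and `withProjectID` are hypothetical stand-ins written for illustration, not the repository's types, and the token refresh step is omitted.

```go
package main

import (
	"errors"
	"strings"
)

// Auth is a hypothetical stand-in for cliproxyauth.Auth; only Metadata is modelled.
type Auth struct {
	Metadata map[string]any
}

// Clone returns a copy whose Metadata map is independent of the receiver's map.
func (a *Auth) Clone() *Auth {
	if a == nil {
		return nil
	}
	out := &Auth{}
	if a.Metadata != nil {
		out.Metadata = make(map[string]any, len(a.Metadata))
		for k, v := range a.Metadata {
			out.Metadata[k] = v
		}
	}
	return out
}

// projectIDFromAuth returns the trimmed project_id string, or "" when it is absent
// or not a string.
func projectIDFromAuth(a *Auth) string {
	if a == nil || a.Metadata == nil {
		return ""
	}
	if pid, ok := a.Metadata["project_id"].(string); ok {
		return strings.TrimSpace(pid)
	}
	return ""
}

// withProjectID returns an updated copy carrying the fetched project ID.
// It returns (nil, nil) when the auth is nil or already has a project ID,
// and never writes to the auth it was given.
func withProjectID(a *Auth, fetch func() (string, error)) (*Auth, error) {
	if a == nil || projectIDFromAuth(a) != "" {
		return nil, nil
	}
	pid, err := fetch()
	if err != nil {
		return nil, err
	}
	pid = strings.TrimSpace(pid)
	if pid == "" {
		return nil, errors.New("auth missing project_id")
	}
	updated := a.Clone()
	if updated.Metadata == nil {
		updated.Metadata = make(map[string]any)
	}
	updated.Metadata["project_id"] = pid
	return updated, nil
}
```

Returning a copy keeps the write out of any shared auth object, which is presumably why the test checks the original's metadata after the call.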
@@ -172,13 +251,19 @@ func buildRequestBodyFromRawPayload(t *testing.T, modelName string, payload []by
 	t.Helper()
 
 	executor := &AntigravityExecutor{}
-	auth := &cliproxyauth.Auth{}
+	auth := &cliproxyauth.Auth{Metadata: map[string]any{"project_id": "project-1"}}
 
 	req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com")
 	if err != nil {
 		t.Fatalf("buildRequest error: %v", err)
 	}
 
+	return requestBody(t, req)
+}
+
+func requestBody(t *testing.T, req *http.Request) map[string]any {
+	t.Helper()
+
 	raw, err := io.ReadAll(req.Body)
 	if err != nil {
 		t.Fatalf("read request body error: %v", err)
@@ -444,24 +444,25 @@ func TestUpdateAntigravityCreditsBalance_LoadCodeAssistUserAgent(t *testing.T) {
 	t.Cleanup(resetAntigravityCreditsRetryState)
 
 	exec := NewAntigravityExecutor(&config.Config{})
-	const userAgent = "antigravity/1.23.2 windows/amd64 google-api-nodejs-client/10.3.0"
+	const configuredUserAgent = "antigravity/1.23.2 windows/amd64 google-api-nodejs-client/10.3.0"
+	const loadCodeAssistUserAgent = "antigravity/1.23.2 windows/amd64"
 	auth := &cliproxyauth.Auth{
 		ID: "auth-load-code-assist-ua",
-		Attributes: map[string]string{"user_agent": userAgent},
+		Attributes: map[string]string{"user_agent": configuredUserAgent},
 	}
 	ctx := context.WithValue(context.Background(), "cliproxy.roundtripper", roundTripperFunc(func(req *http.Request) (*http.Response, error) {
 		if req.URL.String() != "https://cloudcode-pa.googleapis.com/v1internal:loadCodeAssist" {
 			t.Fatalf("unexpected request url %s", req.URL.String())
 		}
-		if got := req.Header.Get("User-Agent"); got != userAgent {
-			t.Fatalf("User-Agent = %q, want %q", got, userAgent)
+		if got := req.Header.Get("User-Agent"); got != loadCodeAssistUserAgent {
+			t.Fatalf("User-Agent = %q, want %q", got, loadCodeAssistUserAgent)
 		}
-		if got := req.Header.Get("X-Goog-Api-Client"); got != "gl-node/22.21.1" {
-			t.Fatalf("X-Goog-Api-Client = %q, want %q", got, "gl-node/22.21.1")
+		if got := req.Header.Get("X-Goog-Api-Client"); got != "" {
+			t.Fatalf("X-Goog-Api-Client = %q, want empty", got)
 		}
 		body, _ := io.ReadAll(req.Body)
 		_ = req.Body.Close()
-		if string(body) != `{"metadata":{"ide_name":"antigravity","ide_type":"ANTIGRAVITY","ide_version":"1.23.2"}}` {
+		if string(body) != `{"metadata":{"ideType":"ANTIGRAVITY"}}` {
 			t.Fatalf("loadCodeAssist body = %s", string(body))
 		}
 		return &http.Response{
@@ -100,6 +100,103 @@ func patchCodexCompletedOutput(eventData []byte, outputItemsByIndex map[int64][]
|
||||
return completedDataPatched
|
||||
}
|
||||
|
||||
func codexTerminalStreamContextLengthErr(eventData []byte) (statusErr, bool) {
|
||||
eventType := gjson.GetBytes(eventData, "type").String()
|
||||
var body []byte
|
||||
switch eventType {
|
||||
case "error":
|
||||
body = codexTerminalErrorBody(eventData, "error")
|
||||
if len(body) == 0 {
|
||||
body = codexTerminalTopLevelErrorBody(eventData)
|
||||
}
|
||||
case "response.failed":
|
||||
body = codexTerminalErrorBody(eventData, "response.error")
|
||||
if len(body) == 0 {
|
||||
body = codexTerminalErrorBody(eventData, "error")
|
||||
}
|
||||
default:
|
||||
return statusErr{}, false
|
||||
}
|
||||
if len(body) == 0 {
|
||||
return statusErr{}, false
|
||||
}
|
||||
if !codexTerminalErrorIsContextLength(body) {
|
||||
return statusErr{}, false
|
||||
}
|
||||
return newCodexStatusErr(http.StatusBadRequest, body), true
|
||||
}
|
||||
|
||||
func codexTerminalErrorBody(eventData []byte, path string) []byte {
|
||||
errorResult := gjson.GetBytes(eventData, path)
|
||||
if !errorResult.Exists() {
|
||||
return nil
|
||||
}
|
||||
body := []byte(`{"error":{}}`)
|
||||
if errorResult.Type == gjson.JSON {
|
||||
body, _ = sjson.SetRawBytes(body, "error", []byte(errorResult.Raw))
|
||||
} else if message := strings.TrimSpace(errorResult.String()); message != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.message", message)
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(body, "error.message").String()) == "" {
|
||||
if message := strings.TrimSpace(gjson.GetBytes(eventData, "response.error.message").String()); message != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.message", message)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(body, "error.message").String()) == "" {
|
||||
if code := strings.TrimSpace(gjson.GetBytes(body, "error.code").String()); code != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.message", code)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(body, "error.message").String()) == "" {
|
||||
if errorType := strings.TrimSpace(gjson.GetBytes(body, "error.type").String()); errorType != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.message", errorType)
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func codexTerminalTopLevelErrorBody(eventData []byte) []byte {
|
||||
message := strings.TrimSpace(gjson.GetBytes(eventData, "message").String())
|
||||
code := strings.TrimSpace(gjson.GetBytes(eventData, "code").String())
|
||||
errorType := strings.TrimSpace(gjson.GetBytes(eventData, "error_type").String())
|
||||
param := strings.TrimSpace(gjson.GetBytes(eventData, "param").String())
|
||||
if message == "" && code == "" && errorType == "" && param == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
body := []byte(`{"error":{}}`)
|
||||
if message != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.message", message)
|
||||
}
|
||||
if code != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.code", code)
|
||||
}
|
||||
if errorType != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.type", errorType)
|
||||
}
|
||||
if param != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.param", param)
|
||||
}
|
||||
if strings.TrimSpace(gjson.GetBytes(body, "error.message").String()) == "" {
|
||||
if code != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.message", code)
|
||||
} else if errorType != "" {
|
||||
body, _ = sjson.SetBytes(body, "error.message", errorType)
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func codexTerminalErrorIsContextLength(body []byte) bool {
|
||||
errorCode := strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "error.code").String()))
|
||||
message := strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "error.message").String()))
|
||||
return errorCode == "context_length_exceeded" ||
|
||||
errorCode == "context_too_large" ||
|
||||
strings.Contains(message, "context window") ||
|
||||
strings.Contains(message, "context length") ||
|
||||
strings.Contains(message, "too many tokens")
|
||||
}
|
||||
|
||||
// CodexExecutor is a stateless executor for Codex (OpenAI Responses API entrypoint).
|
||||
// If api_key is unavailable on auth, it falls back to legacy via ClientAdapter.
|
||||
type CodexExecutor struct {
|
||||
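Review note: the hunk above treats an `error` or `response.failed` stream event as a context-length failure when the error code is `context_length_exceeded` or `context_too_large`, or when the message mentions "context window", "context length" or "too many tokens". The sketch below restates that rule with the standard library only. It is a simplification: it assumes the `error` field is a JSON object (the diff also accepts a plain string via gjson) and it skips the message fallbacks. `terminalEvent`, `eventError` and `isContextLengthEvent` are names invented here.

```go
package main

import (
	"encoding/json"
	"strings"
)

// eventError holds the two fields the classification looks at.
type eventError struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}

// terminalEvent models only the parts of a stream event this sketch inspects.
type terminalEvent struct {
	Type     string      `json:"type"`
	Code     string      `json:"code"`
	Message  string      `json:"message"`
	Error    *eventError `json:"error"`
	Response struct {
		Error *eventError `json:"error"`
	} `json:"response"`
}

// isContextLengthEvent reports whether data is a terminal stream event that
// describes a context-length failure. Other terminal errors return false.
func isContextLengthEvent(data []byte) bool {
	var ev terminalEvent
	if err := json.Unmarshal(data, &ev); err != nil {
		return false
	}
	var code, msg string
	switch ev.Type {
	case "error":
		if ev.Error != nil {
			code, msg = ev.Error.Code, ev.Error.Message
		} else {
			// Some events carry code and message at the top level.
			code, msg = ev.Code, ev.Message
		}
	case "response.failed":
		if ev.Response.Error != nil {
			code, msg = ev.Response.Error.Code, ev.Response.Error.Message
		} else if ev.Error != nil {
			code, msg = ev.Error.Code, ev.Error.Message
		}
	default:
		return false
	}
	code = strings.ToLower(strings.TrimSpace(code))
	msg = strings.ToLower(strings.TrimSpace(msg))
	return code == "context_length_exceeded" ||
		code == "context_too_large" ||
		strings.Contains(msg, "context window") ||
		strings.Contains(msg, "context length") ||
		strings.Contains(msg, "too many tokens")
}
```

Matching on the message as well as the code makes the check tolerant of upstream events that omit a code, at the cost of depending on the wording of the message.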
@@ -147,6 +244,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
if opts.Alt == "responses/compact" {
|
||||
return e.executeCompact(ctx, auth, req, opts)
|
||||
}
|
||||
if isCodexOpenAIImageRequest(opts) {
|
||||
return e.executeOpenAIImage(ctx, auth, req, opts)
|
||||
}
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
apiKey, baseURL := codexCreds(auth)
|
||||
@@ -246,6 +346,11 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
eventData := bytes.TrimSpace(line[5:])
|
||||
eventType := gjson.GetBytes(eventData, "type").String()
|
||||
|
||||
if streamErr, ok := codexTerminalStreamContextLengthErr(eventData); ok {
|
||||
err = streamErr
|
||||
return resp, err
|
||||
}
|
||||
|
||||
if eventType == "response.output_item.done" {
|
||||
itemResult := gjson.GetBytes(eventData, "item")
|
||||
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
|
||||
@@ -397,6 +502,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if opts.Alt == "responses/compact" {
|
||||
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
|
||||
}
|
||||
if isCodexOpenAIImageRequest(opts) {
|
||||
return e.executeOpenAIImageStream(ctx, auth, req, opts)
|
||||
}
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
apiKey, baseURL := codexCreds(auth)
|
||||
@@ -500,6 +608,15 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
data := bytes.TrimSpace(line[5:])
|
||||
if streamErr, ok := codexTerminalStreamContextLengthErr(data); ok {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, streamErr)
|
||||
reporter.PublishFailure(ctx, streamErr)
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: streamErr}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
switch gjson.GetBytes(data, "type").String() {
|
||||
case "response.output_item.done":
|
||||
collectCodexOutputItemDone(data, outputItemsByIndex, &outputItemsFallback)
|
||||
|
||||
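Review note: both `Execute` and `ExecuteStream` above inspect each `data:` line of the event stream before any other handling. A small standalone sketch of pulling those payloads out of a server-sent-events body is shown below. `dataPayloads` and `dataPayloadsFromString` are hypothetical helpers, and the 1 MiB line ceiling is an assumption made for the example, not a value taken from the repository.

```go
package main

import (
	"bufio"
	"bytes"
	"io"
	"strings"
)

// dataPayloads returns the trimmed payload of every "data:" line in an SSE body.
// Lines such as "event: ..." and blank separators are skipped.
func dataPayloads(r io.Reader) ([][]byte, error) {
	prefix := []byte("data:")
	var out [][]byte
	sc := bufio.NewScanner(r)
	// Raise the default 64 KiB token limit; the ceiling here is an assumption.
	sc.Buffer(make([]byte, 0, 64*1024), 1<<20)
	for sc.Scan() {
		line := sc.Bytes()
		if !bytes.HasPrefix(line, prefix) {
			continue
		}
		payload := bytes.TrimSpace(line[len(prefix):])
		// Copy the payload because the Scanner reuses its internal buffer.
		out = append(out, append([]byte(nil), payload...))
	}
	return out, sc.Err()
}

// dataPayloadsFromString is a convenience wrapper for in-memory input.
func dataPayloadsFromString(s string) ([][]byte, error) {
	return dataPayloads(strings.NewReader(s))
}
```

Each returned payload can then be handed to a classifier before the regular event switch, which mirrors the order of checks in the diff.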
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
@@ -46,6 +47,128 @@ func TestCodexExecutorExecute_EmptyStreamCompletionOutputUsesOutputItemDone(t *t
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorExecuteSurfacesTerminalStreamError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: response.created\n"))
|
||||
_, _ = w.Write([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.5"}}` + "\n\n"))
|
||||
_, _ = w.Write([]byte("event: error\n"))
|
||||
_, _ = w.Write([]byte(`data: {"type":"error","error":{"type":"invalid_request_error","code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again.","param":"input"},"sequence_number":2}` + "\n\n"))
|
||||
_, _ = w.Write([]byte("event: response.failed\n"))
|
||||
_, _ = w.Write([]byte(`data: {"type":"response.failed","response":{"id":"resp_1","status":"failed","error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."}}}` + "\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewCodexExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"api_key": "test",
|
||||
}}
|
||||
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gpt-5.5",
|
||||
Payload: []byte(`{"model":"gpt-5.5","input":"hello"}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||
Stream: false,
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected terminal stream error, got nil")
|
||||
}
|
||||
if got := statusCodeFromTestError(t, err); got != http.StatusBadRequest {
|
||||
t.Fatalf("status code = %d, want %d; err=%v", got, http.StatusBadRequest, err)
|
||||
}
|
||||
assertCodexErrorCode(t, err.Error(), "invalid_request_error", "context_too_large")
|
||||
if !strings.Contains(err.Error(), "Your input exceeds the context window") {
|
||||
t.Fatalf("error message missing upstream context text: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorExecuteStreamSurfacesTerminalStreamError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: response.created\n"))
|
||||
_, _ = w.Write([]byte(`data: {"type":"response.created","response":{"id":"resp_1","model":"gpt-5.5"}}` + "\n\n"))
|
||||
_, _ = w.Write([]byte("event: error\n"))
|
||||
_, _ = w.Write([]byte(`data: {"type":"error","error":{"type":"invalid_request_error","code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again.","param":"input"},"sequence_number":2}` + "\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewCodexExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"api_key": "test",
|
||||
}}
|
||||
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "gpt-5.5",
|
||||
Payload: []byte(`{"model":"gpt-5.5","input":"hello"}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-response"),
|
||||
Stream: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
|
||||
var streamErr error
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
streamErr = chunk.Err
|
||||
break
|
||||
}
|
||||
}
|
||||
if streamErr == nil {
|
||||
t.Fatal("missing stream terminal error")
|
||||
}
|
||||
if got := statusCodeFromTestError(t, streamErr); got != http.StatusBadRequest {
|
||||
t.Fatalf("status code = %d, want %d; err=%v", got, http.StatusBadRequest, streamErr)
|
||||
}
|
||||
assertCodexErrorCode(t, streamErr.Error(), "invalid_request_error", "context_too_large")
|
||||
}
|
||||
|
||||
func TestCodexTerminalStreamContextLengthErrFromResponseFailed(t *testing.T) {
|
||||
err, ok := codexTerminalStreamContextLengthErr([]byte(`{"type":"response.failed","response":{"id":"resp_1","status":"failed","error":{"code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again."}}}`))
|
||||
if !ok {
|
||||
t.Fatal("expected context length terminal error")
|
||||
}
|
||||
if got := statusCodeFromTestError(t, err); got != http.StatusBadRequest {
|
||||
t.Fatalf("status code = %d, want %d; err=%v", got, http.StatusBadRequest, err)
|
||||
}
|
||||
assertCodexErrorCode(t, err.Error(), "invalid_request_error", "context_too_large")
|
||||
}
|
||||
|
||||
func TestCodexTerminalStreamContextLengthErrFromTopLevelError(t *testing.T) {
|
||||
err, ok := codexTerminalStreamContextLengthErr([]byte(`{"type":"error","code":"context_length_exceeded","message":"Your input exceeds the context window of this model. Please adjust your input and try again.","sequence_number":2}`))
|
||||
if !ok {
|
||||
t.Fatal("expected top-level context length terminal error")
|
||||
}
|
||||
if got := statusCodeFromTestError(t, err); got != http.StatusBadRequest {
|
||||
t.Fatalf("status code = %d, want %d; err=%v", got, http.StatusBadRequest, err)
|
||||
}
|
||||
assertCodexErrorCode(t, err.Error(), "invalid_request_error", "context_too_large")
|
||||
if !strings.Contains(err.Error(), "Your input exceeds the context window") {
|
||||
t.Fatalf("error message missing upstream context text: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexTerminalStreamContextLengthErrIgnoresOtherTerminalErrors(t *testing.T) {
|
||||
_, ok := codexTerminalStreamContextLengthErr([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"Rate limit reached."}}`))
|
||||
if ok {
|
||||
t.Fatal("rate limit terminal error should not be handled by context length fix")
|
||||
}
|
||||
}
|
||||
|
||||
func statusCodeFromTestError(t *testing.T, err error) int {
|
||||
t.Helper()
|
||||
|
||||
statusErr, ok := err.(interface{ StatusCode() int })
|
||||
if !ok {
|
||||
t.Fatalf("error %T does not expose StatusCode(): %v", err, err)
|
||||
}
|
||||
return statusErr.StatusCode()
|
||||
}
|
||||
|
||||
func TestCodexExecutorExecuteStream_EmptyStreamCompletionOutputUsesOutputItemDone(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
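Review note: `statusCodeFromTestError` above relies on the returned error exposing a `StatusCode() int` method. The sketch below shows one way such an error type and lookup could fit together. `statusError`, `badRequest` and `statusCodeOf` are hypothetical names, and using `errors.As` (which also searches wrapped errors) is a choice made for this example rather than what the tests do.

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
)

// statusError is a hypothetical error type that carries an HTTP status code.
type statusError struct {
	code int
	msg  string
}

func (e statusError) Error() string { return e.msg }

func (e statusError) StatusCode() int { return e.code }

// badRequest builds a 400 error and appends the cause to the message when present.
func badRequest(msg string, cause error) error {
	if cause != nil {
		msg = fmt.Sprintf("%s: %v", msg, cause)
	}
	return statusError{code: http.StatusBadRequest, msg: msg}
}

// statusCodeOf returns the status code of the first error in the chain that
// exposes StatusCode(), and false when none does.
func statusCodeOf(err error) (int, bool) {
	var sc interface{ StatusCode() int }
	if errors.As(err, &sc) {
		return sc.StatusCode(), true
	}
	return 0, false
}
```

Asserting on a small interface instead of the concrete type lets callers and tests read the status without importing the executor's unexported error type.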
@@ -0,0 +1,678 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
codexOpenAIImageSourceFormat = "openai-image"
|
||||
codexImagesGenerationsPath = "/v1/images/generations"
|
||||
codexImagesEditsPath = "/v1/images/edits"
|
||||
codexOpenAIImagesMainModel = "gpt-5.4-mini"
|
||||
)
|
||||
|
||||
type codexOpenAIImagePreparedRequest struct {
|
||||
Body []byte
|
||||
ResponseFormat string
|
||||
StreamPrefix string
|
||||
}
|
||||
|
||||
type codexImageCallResult struct {
|
||||
Result string
|
||||
RevisedPrompt string
|
||||
OutputFormat string
|
||||
Size string
|
||||
Background string
|
||||
Quality string
|
||||
}
|
||||
|
||||
func isCodexOpenAIImageRequest(opts cliproxyexecutor.Options) bool {
|
||||
if !strings.EqualFold(strings.TrimSpace(opts.SourceFormat.String()), codexOpenAIImageSourceFormat) {
|
||||
return false
|
||||
}
|
||||
return codexIsImagesEndpointPath(helps.PayloadRequestPath(opts))
|
||||
}
|
||||
|
||||
func codexIsImagesEndpointPath(path string) bool {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == codexImagesGenerationsPath || path == codexImagesEditsPath {
|
||||
return true
|
||||
}
|
||||
return strings.HasSuffix(path, codexImagesGenerationsPath) || strings.HasSuffix(path, codexImagesEditsPath)
|
||||
}
|
||||
|
func (e *CodexExecutor) executeOpenAIImage(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
	prepared, errPrepare := codexPrepareOpenAIImageRequest(req, opts)
	if errPrepare != nil {
		return resp, errPrepare
	}

	apiKey, baseURL := codexCreds(auth)
	if baseURL == "" {
		baseURL = "https://chatgpt.com/backend-api/codex"
	}

	reporter := helps.NewUsageReporter(ctx, e.Identifier(), codexOpenAIImagesMainModel, auth)
	defer reporter.TrackFailure(ctx, &err)

	body, errBuild := e.prepareCodexOpenAIImageBody(prepared.Body, req, opts)
	if errBuild != nil {
		return resp, errBuild
	}

	url := strings.TrimSuffix(baseURL, "/") + "/responses"
	httpReq, errCache := e.cacheHelper(ctx, sdktranslator.FromString(codexOpenAIImageSourceFormat), url, req, body)
	if errCache != nil {
		return resp, errCache
	}
	applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
	recordCodexOpenAIImageRequest(ctx, e.cfg, e.Identifier(), auth, url, httpReq.Header.Clone(), body)

	httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, errDo := httpClient.Do(httpReq)
	if errDo != nil {
		helps.RecordAPIResponseError(ctx, e.cfg, errDo)
		return resp, errDo
	}
	defer func() {
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("codex executor: close response body error: %v", errClose)
		}
	}()

	helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	data, errRead := io.ReadAll(httpResp.Body)
	if errRead != nil {
		helps.RecordAPIResponseError(ctx, e.cfg, errRead)
		return resp, errRead
	}
	helps.AppendAPIResponseChunk(ctx, e.cfg, data)
	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
		err = newCodexStatusErr(httpResp.StatusCode, data)
		return resp, err
	}

	outputItemsByIndex := make(map[int64][]byte)
	var outputItemsFallback [][]byte
	for _, line := range bytes.Split(data, []byte("\n")) {
		if !bytes.HasPrefix(line, dataTag) {
			continue
		}
		eventData := bytes.TrimSpace(line[len(dataTag):])
		switch gjson.GetBytes(eventData, "type").String() {
		case "response.output_item.done":
			collectCodexOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
		case "response.completed":
			if detail, ok := helps.ParseCodexUsage(eventData); ok {
				reporter.Publish(ctx, detail)
			}
			publishCodexImageToolUsage(ctx, reporter, body, eventData)
			completedData := patchCodexCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
			results, createdAt, usageRaw, firstMeta, errExtract := codexExtractImagesFromResponsesCompleted(completedData)
			if errExtract != nil {
				return resp, errExtract
			}
			if len(results) == 0 {
				return resp, statusErr{code: http.StatusBadGateway, msg: "upstream did not return image output"}
			}
			out, errOutput := codexBuildImagesAPIResponse(results, createdAt, usageRaw, firstMeta, prepared.ResponseFormat)
			if errOutput != nil {
				return resp, errOutput
			}
			return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
		}
	}

	err = statusErr{code: http.StatusGatewayTimeout, msg: "stream error: stream disconnected before completion"}
	return resp, err
}

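The non-streaming path above reads the whole upstream body and then walks it line by line, keeping only the SSE `data:` lines. A minimal sketch of that scan follows; it assumes `dataTag` is the byte string `data:` (its definition is outside this excerpt), and `sseDataPayloads` is a hypothetical helper name, not a function from the repository.

```go
package main

import (
	"bytes"
	"fmt"
)

// dataTag is assumed to match the repository's definition, which is not
// shown in the diff.
var dataTag = []byte("data:")

// sseDataPayloads returns the trimmed payload of every line in body
// that starts with the "data:" prefix; other lines are skipped.
func sseDataPayloads(body []byte) [][]byte {
	var out [][]byte
	for _, line := range bytes.Split(body, []byte("\n")) {
		if !bytes.HasPrefix(line, dataTag) {
			continue
		}
		out = append(out, bytes.TrimSpace(line[len(dataTag):]))
	}
	return out
}

func main() {
	body := []byte("event: response.completed\ndata: {\"type\":\"response.completed\"}\n\n: keep-alive\n")
	for _, p := range sseDataPayloads(body) {
		fmt.Println(string(p))
	}
}
```

`bytes.TrimSpace` also removes a trailing `\r`, so bodies with CRLF line endings yield the same payloads.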
func (e *CodexExecutor) executeOpenAIImageStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
	prepared, errPrepare := codexPrepareOpenAIImageRequest(req, opts)
	if errPrepare != nil {
		return nil, errPrepare
	}

	apiKey, baseURL := codexCreds(auth)
	if baseURL == "" {
		baseURL = "https://chatgpt.com/backend-api/codex"
	}

	reporter := helps.NewUsageReporter(ctx, e.Identifier(), codexOpenAIImagesMainModel, auth)
	defer reporter.TrackFailure(ctx, &err)

	body, errBuild := e.prepareCodexOpenAIImageBody(prepared.Body, req, opts)
	if errBuild != nil {
		return nil, errBuild
	}

	url := strings.TrimSuffix(baseURL, "/") + "/responses"
	httpReq, errCache := e.cacheHelper(ctx, sdktranslator.FromString(codexOpenAIImageSourceFormat), url, req, body)
	if errCache != nil {
		return nil, errCache
	}
	applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
	recordCodexOpenAIImageRequest(ctx, e.cfg, e.Identifier(), auth, url, httpReq.Header.Clone(), body)

	httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, errDo := httpClient.Do(httpReq)
	if errDo != nil {
		helps.RecordAPIResponseError(ctx, e.cfg, errDo)
		return nil, errDo
	}
	helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		data, errRead := io.ReadAll(httpResp.Body)
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("codex executor: close response body error: %v", errClose)
		}
		if errRead != nil {
			helps.RecordAPIResponseError(ctx, e.cfg, errRead)
			return nil, errRead
		}
		helps.AppendAPIResponseChunk(ctx, e.cfg, data)
		helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
		err = newCodexStatusErr(httpResp.StatusCode, data)
		return nil, err
	}

	out := make(chan cliproxyexecutor.StreamChunk)
	go func() {
		defer close(out)
		defer func() {
			if errClose := httpResp.Body.Close(); errClose != nil {
				log.Errorf("codex executor: close response body error: %v", errClose)
			}
		}()

		sendPayload := func(payload []byte) bool {
			select {
			case out <- cliproxyexecutor.StreamChunk{Payload: payload}:
				return true
			case <-ctx.Done():
				return false
			}
		}
		sendError := func(errSend error) bool {
			select {
			case out <- cliproxyexecutor.StreamChunk{Err: errSend}:
				return true
			case <-ctx.Done():
				return false
			}
		}

		scanner := bufio.NewScanner(httpResp.Body)
		scanner.Buffer(nil, 52_428_800) // 50MB
		outputItemsByIndex := make(map[int64][]byte)
		var outputItemsFallback [][]byte
		for scanner.Scan() {
			line := scanner.Bytes()
			helps.AppendAPIResponseChunk(ctx, e.cfg, line)
			if !bytes.HasPrefix(line, dataTag) {
				continue
			}
			eventData := bytes.TrimSpace(line[len(dataTag):])
			switch gjson.GetBytes(eventData, "type").String() {
			case "response.output_item.done":
				collectCodexOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
			case "response.image_generation_call.partial_image":
				frame := codexBuildImagePartialFrame(eventData, prepared.ResponseFormat, prepared.StreamPrefix)
				if len(frame) > 0 && !sendPayload(frame) {
					return
				}
			case "response.completed":
				if detail, ok := helps.ParseCodexUsage(eventData); ok {
					reporter.Publish(ctx, detail)
				}
				publishCodexImageToolUsage(ctx, reporter, body, eventData)
				completedData := patchCodexCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
				results, _, usageRaw, _, errExtract := codexExtractImagesFromResponsesCompleted(completedData)
				if errExtract != nil {
					sendError(errExtract)
					return
				}
				if len(results) == 0 {
					sendError(statusErr{code: http.StatusBadGateway, msg: "upstream did not return image output"})
					return
				}
				for _, img := range results {
					frame := codexBuildImageCompletedFrame(img, usageRaw, prepared.ResponseFormat, prepared.StreamPrefix)
					if len(frame) > 0 && !sendPayload(frame) {
						return
					}
				}
				return
			}
		}
		if errScan := scanner.Err(); errScan != nil {
			helps.RecordAPIResponseError(ctx, e.cfg, errScan)
			reporter.PublishFailure(ctx, errScan)
			sendError(errScan)
		}
	}()
	return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}

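The `sendPayload` and `sendError` closures in the streaming function both use a `select` on the output channel and `ctx.Done()`, so the producer goroutine stops instead of blocking forever once the consumer goes away. Below is a minimal sketch of that pattern; `chunk`, `send`, `produce`, and `collect` are hypothetical names used only for illustration, not types or functions from the repository.

```go
package main

import (
	"context"
	"fmt"
)

// chunk is a simplified stand-in for a stream chunk type.
type chunk struct {
	Payload []byte
	Err     error
}

// send delivers c on out unless ctx is cancelled first. A false result
// tells the producer to stop.
func send(ctx context.Context, out chan<- chunk, c chunk) bool {
	select {
	case out <- c:
		return true
	case <-ctx.Done():
		return false
	}
}

// produce emits frames on an unbuffered channel and closes it when done
// or when ctx is cancelled.
func produce(ctx context.Context, frames [][]byte) <-chan chunk {
	out := make(chan chunk)
	go func() {
		defer close(out)
		for _, f := range frames {
			if !send(ctx, out, chunk{Payload: f}) {
				return
			}
		}
	}()
	return out
}

// collect drains a producer under a background context.
func collect(frames [][]byte) []string {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	var got []string
	for c := range produce(ctx, frames) {
		got = append(got, string(c.Payload))
	}
	return got
}

func main() {
	fmt.Println(collect([][]byte{[]byte("a"), []byte("b")}))
}
```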
func (e *CodexExecutor) prepareCodexOpenAIImageBody(body []byte, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) ([]byte, error) {
	out := body
	var errThinking error
	out, errThinking = thinking.ApplyThinking(out, codexOpenAIImagesMainModel, codexOpenAIImageSourceFormat, "codex", e.Identifier())
	if errThinking != nil {
		return nil, errThinking
	}

	requestedModel := helps.PayloadRequestedModel(opts, req.Model)
	requestPath := helps.PayloadRequestPath(opts)
	out = helps.ApplyPayloadConfigWithRequest(e.cfg, codexOpenAIImagesMainModel, "codex", codexOpenAIImageSourceFormat, "", out, body, requestedModel, requestPath, opts.Headers)
	out, _ = sjson.SetBytes(out, "model", codexOpenAIImagesMainModel)
	out, _ = sjson.SetBytes(out, "stream", true)
	out, _ = sjson.DeleteBytes(out, "previous_response_id")
	out, _ = sjson.DeleteBytes(out, "prompt_cache_retention")
	out, _ = sjson.DeleteBytes(out, "safety_identifier")
	out, _ = sjson.DeleteBytes(out, "stream_options")
	return normalizeCodexInstructions(out), nil
}

func recordCodexOpenAIImageRequest(ctx context.Context, cfg *config.Config, provider string, auth *cliproxyauth.Auth, url string, headers http.Header, body []byte) {
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	helps.RecordAPIRequest(ctx, cfg, helps.UpstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   headers,
		Body:      body,
		Provider:  provider,
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})
}

func codexPrepareOpenAIImageRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (codexOpenAIImagePreparedRequest, error) {
	path := helps.PayloadRequestPath(opts)
	if strings.HasSuffix(path, codexImagesGenerationsPath) {
		return codexPrepareOpenAIImageGenerationJSON(req.Payload, req.Model)
	}
	if !strings.HasSuffix(path, codexImagesEditsPath) {
		return codexOpenAIImagePreparedRequest{}, fmt.Errorf("unsupported OpenAI image endpoint path %q", path)
	}

	contentType := codexImageContentType(opts.Headers)
	mediaType, _, _ := mime.ParseMediaType(contentType)
	if strings.HasPrefix(strings.ToLower(mediaType), "multipart/") {
		return codexPrepareOpenAIImageEditMultipart(req.Payload, req.Model, contentType)
	}
	return codexPrepareOpenAIImageEditJSON(req.Payload, req.Model)
}

func codexPrepareOpenAIImageGenerationJSON(rawJSON []byte, routeModel string) (codexOpenAIImagePreparedRequest, error) {
	if !json.Valid(rawJSON) {
		return codexOpenAIImagePreparedRequest{}, fmt.Errorf("invalid OpenAI image generation request JSON")
	}
	prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())
	tool := codexBuildOpenAIImageTool(rawJSON, routeModel, "generate", []string{"size", "quality", "background", "output_format", "moderation"}, []string{"output_compression", "partial_images"})
	body := codexBuildImagesResponsesRequest(prompt, nil, tool)
	return codexOpenAIImagePreparedRequest{
		Body:           body,
		ResponseFormat: codexOpenAIImageResponseFormatFromJSON(rawJSON),
		StreamPrefix:   "image_generation",
	}, nil
}

func codexPrepareOpenAIImageEditJSON(rawJSON []byte, routeModel string) (codexOpenAIImagePreparedRequest, error) {
	if !json.Valid(rawJSON) {
		return codexOpenAIImagePreparedRequest{}, fmt.Errorf("invalid OpenAI image edit request JSON")
	}
	prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())
	images := make([]string, 0)
	if imagesResult := gjson.GetBytes(rawJSON, "images"); imagesResult.IsArray() {
		for _, img := range imagesResult.Array() {
			url := strings.TrimSpace(img.Get("image_url").String())
			if url != "" {
				images = append(images, url)
			}
		}
	}
	tool := codexBuildOpenAIImageTool(rawJSON, routeModel, "edit", []string{"size", "quality", "background", "output_format", "input_fidelity", "moderation"}, []string{"output_compression", "partial_images"})
	if mask := strings.TrimSpace(gjson.GetBytes(rawJSON, "mask.image_url").String()); mask != "" {
		tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", mask)
	}
	body := codexBuildImagesResponsesRequest(prompt, images, tool)
	return codexOpenAIImagePreparedRequest{
		Body:           body,
		ResponseFormat: codexOpenAIImageResponseFormatFromJSON(rawJSON),
		StreamPrefix:   "image_edit",
	}, nil
}

func codexPrepareOpenAIImageEditMultipart(rawBody []byte, routeModel string, contentType string) (codexOpenAIImagePreparedRequest, error) {
	_, params, errMedia := mime.ParseMediaType(contentType)
	if errMedia != nil {
		return codexOpenAIImagePreparedRequest{}, fmt.Errorf("parse multipart content type failed: %w", errMedia)
	}
	boundary := strings.TrimSpace(params["boundary"])
	if boundary == "" {
		return codexOpenAIImagePreparedRequest{}, fmt.Errorf("multipart boundary is required")
	}
	reader := multipart.NewReader(bytes.NewReader(rawBody), boundary)
	form, errForm := reader.ReadForm(32 << 20)
	if errForm != nil {
		return codexOpenAIImagePreparedRequest{}, fmt.Errorf("parse multipart form failed: %w", errForm)
	}
	defer func() {
		if errRemove := form.RemoveAll(); errRemove != nil {
			log.Errorf("codex openai images: remove multipart temp files error: %v", errRemove)
		}
	}()

	prompt := strings.TrimSpace(codexFormValue(form, "prompt"))
	responseFormat := codexNormalizeImageResponseFormat(codexFormValue(form, "response_format"))
	tool := []byte(`{"type":"image_generation","action":"edit"}`)
	tool, _ = sjson.SetBytes(tool, "model", codexOpenAIImageToolModel(codexFormValue(form, "model"), routeModel))
	for _, field := range []string{"size", "quality", "background", "output_format", "input_fidelity", "moderation"} {
		if value := strings.TrimSpace(codexFormValue(form, field)); value != "" {
			tool, _ = sjson.SetBytes(tool, field, value)
		}
	}
	for _, field := range []string{"output_compression", "partial_images"} {
		if value := strings.TrimSpace(codexFormValue(form, field)); value != "" {
			if parsed, errParse := strconv.ParseInt(value, 10, 64); errParse == nil {
				tool, _ = sjson.SetBytes(tool, field, parsed)
			}
		}
	}

	images := make([]string, 0)
	for _, fh := range codexMultipartImageFiles(form) {
		dataURL, errData := codexMultipartFileToDataURL(fh)
		if errData != nil {
			return codexOpenAIImagePreparedRequest{}, errData
		}
		images = append(images, dataURL)
	}
	if maskFiles := form.File["mask"]; len(maskFiles) > 0 && maskFiles[0] != nil {
		dataURL, errData := codexMultipartFileToDataURL(maskFiles[0])
		if errData != nil {
			return codexOpenAIImagePreparedRequest{}, errData
		}
		tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", dataURL)
	}

	body := codexBuildImagesResponsesRequest(prompt, images, tool)
	return codexOpenAIImagePreparedRequest{
		Body:           body,
		ResponseFormat: responseFormat,
		StreamPrefix:   "image_edit",
	}, nil
}

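The multipart branch above re-parses an already buffered request body, taking the boundary from the `Content-Type` header rather than from an `*http.Request`. A minimal standard-library sketch of that step follows; `formField` and `buildForm` are hypothetical helpers for illustration, and the 32 MiB memory limit simply mirrors the value used in the diff.

```go
package main

import (
	"bytes"
	"fmt"
	"mime"
	"mime/multipart"
)

// formField parses rawBody as multipart data using the boundary found
// in contentType and returns the first value stored under key.
func formField(rawBody []byte, contentType, key string) (string, error) {
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return "", fmt.Errorf("parse content type: %w", err)
	}
	boundary := params["boundary"]
	if boundary == "" {
		return "", fmt.Errorf("multipart boundary is required")
	}
	form, err := multipart.NewReader(bytes.NewReader(rawBody), boundary).ReadForm(32 << 20)
	if err != nil {
		return "", fmt.Errorf("parse form: %w", err)
	}
	// Remove any temp files ReadForm may have spilled to disk.
	defer func() { _ = form.RemoveAll() }()
	if vals := form.Value[key]; len(vals) > 0 {
		return vals[0], nil
	}
	return "", nil
}

// buildForm writes a single-field multipart body, standing in for a
// client upload, and returns it with its Content-Type header value.
func buildForm(key, value string) ([]byte, string) {
	var buf bytes.Buffer
	w := multipart.NewWriter(&buf)
	_ = w.WriteField(key, value)
	_ = w.Close()
	return buf.Bytes(), w.FormDataContentType()
}

func main() {
	body, contentType := buildForm("prompt", "a red bicycle")
	prompt, err := formField(body, contentType, "prompt")
	fmt.Println(prompt, err)
}
```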
func codexImageContentType(headers http.Header) string {
	if headers == nil {
		return ""
	}
	return strings.TrimSpace(headers.Get("Content-Type"))
}

func codexOpenAIImageResponseFormatFromJSON(rawJSON []byte) string {
	return codexNormalizeImageResponseFormat(gjson.GetBytes(rawJSON, "response_format").String())
}

func codexNormalizeImageResponseFormat(responseFormat string) string {
	if strings.EqualFold(strings.TrimSpace(responseFormat), "url") {
		return "url"
	}
	return "b64_json"
}

func codexOpenAIImageToolModel(requestModel string, routeModel string) string {
	model := strings.TrimSpace(requestModel)
	if model == "" {
		model = strings.TrimSpace(routeModel)
	}
	if model == "" {
		model = codexDefaultImageToolModel
	}
	return model
}

func codexBuildOpenAIImageTool(rawJSON []byte, routeModel string, action string, stringFields []string, numberFields []string) []byte {
	tool := []byte(`{"type":"image_generation","action":""}`)
	tool, _ = sjson.SetBytes(tool, "action", action)
	tool, _ = sjson.SetBytes(tool, "model", codexOpenAIImageToolModel(gjson.GetBytes(rawJSON, "model").String(), routeModel))
	for _, field := range stringFields {
		if value := strings.TrimSpace(gjson.GetBytes(rawJSON, field).String()); value != "" {
			tool, _ = sjson.SetBytes(tool, field, value)
		}
	}
	for _, field := range numberFields {
		if value := gjson.GetBytes(rawJSON, field); value.Exists() && value.Type == gjson.Number {
			tool, _ = sjson.SetBytes(tool, field, value.Int())
		}
	}
	return tool
}

func codexBuildImagesResponsesRequest(prompt string, images []string, toolJSON []byte) []byte {
	req := []byte(`{"instructions":"","stream":true,"reasoning":{"effort":"medium","summary":"auto"},"parallel_tool_calls":true,"include":["reasoning.encrypted_content"],"model":"","store":false,"tool_choice":{"type":"image_generation"}}`)
	req, _ = sjson.SetBytes(req, "model", codexOpenAIImagesMainModel)

	input := []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`)
	input, _ = sjson.SetBytes(input, "0.content.0.text", prompt)
	contentIndex := 1
	for _, img := range images {
		if strings.TrimSpace(img) == "" {
			continue
		}
		part := []byte(`{"type":"input_image","image_url":""}`)
		part, _ = sjson.SetBytes(part, "image_url", img)
		input, _ = sjson.SetRawBytes(input, fmt.Sprintf("0.content.%d", contentIndex), part)
		contentIndex++
	}
	req, _ = sjson.SetRawBytes(req, "input", input)

	req, _ = sjson.SetRawBytes(req, "tools", []byte(`[]`))
	if len(toolJSON) > 0 && json.Valid(toolJSON) {
		req, _ = sjson.SetRawBytes(req, "tools.-1", toolJSON)
	}
	return req
}

func codexFormValue(form *multipart.Form, key string) string {
	if form == nil || len(form.Value[key]) == 0 {
		return ""
	}
	return strings.TrimSpace(form.Value[key][0])
}

func codexMultipartImageFiles(form *multipart.Form) []*multipart.FileHeader {
	if form == nil {
		return nil
	}
	if files := form.File["image[]"]; len(files) > 0 {
		return files
	}
	return form.File["image"]
}

func codexMultipartFileToDataURL(fileHeader *multipart.FileHeader) (string, error) {
	if fileHeader == nil {
		return "", fmt.Errorf("upload file is nil")
	}
	f, errOpen := fileHeader.Open()
	if errOpen != nil {
		return "", fmt.Errorf("open upload file failed: %w", errOpen)
	}
	defer func() {
		if errClose := f.Close(); errClose != nil {
			log.Errorf("codex openai images: close upload file error: %v", errClose)
		}
	}()

	data, errRead := io.ReadAll(f)
	if errRead != nil {
		return "", fmt.Errorf("read upload file failed: %w", errRead)
	}
	mediaType := strings.TrimSpace(fileHeader.Header.Get("Content-Type"))
	if mediaType == "" {
		mediaType = http.DetectContentType(data)
	}
	return "data:" + mediaType + ";base64," + base64.StdEncoding.EncodeToString(data), nil
}

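The data URL construction above prefers the part's declared `Content-Type` and falls back to content sniffing. A minimal sketch of just that decision, with the file handling left out; `toDataURL` is a hypothetical helper name, and the PNG input is only the 8-byte file signature, not a full image.

```go
package main

import (
	"encoding/base64"
	"fmt"
	"net/http"
	"strings"
)

// toDataURL encodes data as a base64 data URL. When declaredType is
// blank, the media type is sniffed from the bytes.
func toDataURL(data []byte, declaredType string) string {
	mediaType := strings.TrimSpace(declaredType)
	if mediaType == "" {
		mediaType = http.DetectContentType(data)
	}
	return "data:" + mediaType + ";base64," + base64.StdEncoding.EncodeToString(data)
}

func main() {
	pngSignature := []byte("\x89PNG\r\n\x1a\n")
	fmt.Println(toDataURL(pngSignature, ""))
	fmt.Println(toDataURL([]byte("hi"), "text/plain"))
}
```

`http.DetectContentType` never returns an empty string; unrecognized input yields `application/octet-stream`, so the resulting URL always carries some media type.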
func codexExtractImagesFromResponsesCompleted(payload []byte) (results []codexImageCallResult, createdAt int64, usageRaw []byte, firstMeta codexImageCallResult, err error) {
	if gjson.GetBytes(payload, "type").String() != "response.completed" {
		return nil, 0, nil, codexImageCallResult{}, fmt.Errorf("unexpected event type")
	}
	createdAt = gjson.GetBytes(payload, "response.created_at").Int()
	if createdAt <= 0 {
		createdAt = time.Now().Unix()
	}
	output := gjson.GetBytes(payload, "response.output")
	if output.IsArray() {
		for _, item := range output.Array() {
			if item.Get("type").String() != "image_generation_call" {
				continue
			}
			res := strings.TrimSpace(item.Get("result").String())
			if res == "" {
				continue
			}
			entry := codexImageCallResult{
				Result:        res,
				RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()),
				OutputFormat:  strings.TrimSpace(item.Get("output_format").String()),
				Size:          strings.TrimSpace(item.Get("size").String()),
				Background:    strings.TrimSpace(item.Get("background").String()),
				Quality:       strings.TrimSpace(item.Get("quality").String()),
			}
			if len(results) == 0 {
				firstMeta = entry
			}
			results = append(results, entry)
		}
	}
	if usage := gjson.GetBytes(payload, "response.tool_usage.image_gen"); usage.Exists() && usage.IsObject() {
		usageRaw = []byte(usage.Raw)
	}
	return results, createdAt, usageRaw, firstMeta, nil
}

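The extraction step above keeps only `image_generation_call` output items that carry a non-empty `result`. The same filter can be sketched with `encoding/json` instead of `gjson`; this is a swapped-in technique for illustration, and `completedEvent` and `imageResults` are hypothetical names. The event shape is inferred from the field paths read in the diff, not from upstream documentation.

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

// completedEvent models only the fields the filter needs.
type completedEvent struct {
	Type     string `json:"type"`
	Response struct {
		Output []struct {
			Type   string `json:"type"`
			Result string `json:"result"`
		} `json:"output"`
	} `json:"response"`
}

// imageResults returns the base64 results of all image generation calls
// in a response.completed event, skipping items with an empty result.
func imageResults(payload []byte) ([]string, error) {
	var ev completedEvent
	if err := json.Unmarshal(payload, &ev); err != nil {
		return nil, err
	}
	if ev.Type != "response.completed" {
		return nil, errors.New("unexpected event type")
	}
	var out []string
	for _, item := range ev.Response.Output {
		if item.Type != "image_generation_call" {
			continue
		}
		if res := strings.TrimSpace(item.Result); res != "" {
			out = append(out, res)
		}
	}
	return out, nil
}

func main() {
	payload := []byte(`{"type":"response.completed","response":{"output":[{"type":"message"},{"type":"image_generation_call","result":"QUJD"}]}}`)
	fmt.Println(imageResults(payload))
}
```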
func codexBuildImagesAPIResponse(results []codexImageCallResult, createdAt int64, usageRaw []byte, firstMeta codexImageCallResult, responseFormat string) ([]byte, error) {
	out := []byte(`{"created":0,"data":[]}`)
	out, _ = sjson.SetBytes(out, "created", createdAt)
	responseFormat = codexNormalizeImageResponseFormat(responseFormat)
	for _, img := range results {
		item := []byte(`{}`)
		if responseFormat == "url" {
			item, _ = sjson.SetBytes(item, "url", "data:"+codexMimeTypeFromOutputFormat(img.OutputFormat)+";base64,"+img.Result)
		} else {
			item, _ = sjson.SetBytes(item, "b64_json", img.Result)
		}
		if img.RevisedPrompt != "" {
			item, _ = sjson.SetBytes(item, "revised_prompt", img.RevisedPrompt)
		}
		out, _ = sjson.SetRawBytes(out, "data.-1", item)
	}
	if firstMeta.Background != "" {
		out, _ = sjson.SetBytes(out, "background", firstMeta.Background)
	}
	if firstMeta.OutputFormat != "" {
		out, _ = sjson.SetBytes(out, "output_format", firstMeta.OutputFormat)
	}
	if firstMeta.Quality != "" {
		out, _ = sjson.SetBytes(out, "quality", firstMeta.Quality)
	}
	if firstMeta.Size != "" {
		out, _ = sjson.SetBytes(out, "size", firstMeta.Size)
	}
	if len(usageRaw) > 0 && json.Valid(usageRaw) {
		out, _ = sjson.SetRawBytes(out, "usage", usageRaw)
	}
	return out, nil
}

func codexBuildImagePartialFrame(payload []byte, responseFormat string, streamPrefix string) []byte {
	b64 := strings.TrimSpace(gjson.GetBytes(payload, "partial_image_b64").String())
	if b64 == "" {
		return nil
	}
	outputFormat := strings.TrimSpace(gjson.GetBytes(payload, "output_format").String())
	eventName := strings.TrimSpace(streamPrefix) + ".partial_image"
	data := []byte(`{"type":"","partial_image_index":0}`)
	data, _ = sjson.SetBytes(data, "type", eventName)
	data, _ = sjson.SetBytes(data, "partial_image_index", gjson.GetBytes(payload, "partial_image_index").Int())
	if codexNormalizeImageResponseFormat(responseFormat) == "url" {
		data, _ = sjson.SetBytes(data, "url", "data:"+codexMimeTypeFromOutputFormat(outputFormat)+";base64,"+b64)
	} else {
		data, _ = sjson.SetBytes(data, "b64_json", b64)
	}
	return codexBuildSSEFrame(eventName, data)
}

func codexBuildImageCompletedFrame(img codexImageCallResult, usageRaw []byte, responseFormat string, streamPrefix string) []byte {
	eventName := strings.TrimSpace(streamPrefix) + ".completed"
	data := []byte(`{"type":""}`)
	data, _ = sjson.SetBytes(data, "type", eventName)
	if codexNormalizeImageResponseFormat(responseFormat) == "url" {
		data, _ = sjson.SetBytes(data, "url", "data:"+codexMimeTypeFromOutputFormat(img.OutputFormat)+";base64,"+img.Result)
	} else {
		data, _ = sjson.SetBytes(data, "b64_json", img.Result)
	}
	if len(usageRaw) > 0 && json.Valid(usageRaw) {
		data, _ = sjson.SetRawBytes(data, "usage", usageRaw)
	}
	return codexBuildSSEFrame(eventName, data)
}

func codexBuildSSEFrame(eventName string, data []byte) []byte {
	var buf bytes.Buffer
	if strings.TrimSpace(eventName) != "" {
		buf.WriteString("event: ")
		buf.WriteString(eventName)
		buf.WriteString("\n")
	}
	buf.WriteString("data: ")
	buf.Write(data)
	buf.WriteString("\n\n")
	return buf.Bytes()
}

func codexMimeTypeFromOutputFormat(outputFormat string) string {
	switch strings.ToLower(strings.TrimSpace(outputFormat)) {
	case "jpg", "jpeg":
		return "image/jpeg"
	case "webp":
		return "image/webp"
	default:
		return "image/png"
	}
}
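The frame builder above emits the standard server-sent events layout: an optional `event:` line, a `data:` line, and a blank line that terminates the frame. A minimal sketch showing the exact bytes produced; `sseFrame` is a hypothetical stand-in that follows the same steps as `codexBuildSSEFrame`, except that it checks the event name for emptiness without trimming it first.

```go
package main

import (
	"bytes"
	"fmt"
)

// sseFrame renders one server-sent event. The event line is omitted
// when eventName is empty; the trailing blank line ends the frame.
func sseFrame(eventName string, data []byte) []byte {
	var buf bytes.Buffer
	if eventName != "" {
		buf.WriteString("event: ")
		buf.WriteString(eventName)
		buf.WriteString("\n")
	}
	buf.WriteString("data: ")
	buf.Write(data)
	buf.WriteString("\n\n")
	return buf.Bytes()
}

func main() {
	frame := sseFrame("image_generation.completed", []byte(`{"type":"image_generation.completed"}`))
	fmt.Printf("%q\n", frame)
}
```

This layout relies on the JSON payload being a single line; a payload containing a raw newline would need one `data:` line per line of content.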
@@ -13,6 +13,7 @@ import (
	"strings"

	"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
	"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
	"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps"
	"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
	"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
@@ -135,6 +136,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
	requestPath := helps.PayloadRequestPath(opts)
	body = helps.ApplyPayloadConfigWithRequest(e.cfg, baseModel, to.String(), from.String(), "", body, originalTranslated, requestedModel, requestPath, opts.Headers)
	body, _ = sjson.SetBytes(body, "model", baseModel)
	body = capGeminiMaxOutputTokens(body, baseModel)

	action := "generateContent"
	if req.Metadata != nil {
@@ -243,6 +245,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
	requestPath := helps.PayloadRequestPath(opts)
	body = helps.ApplyPayloadConfigWithRequest(e.cfg, baseModel, to.String(), from.String(), "", body, originalTranslated, requestedModel, requestPath, opts.Headers)
	body, _ = sjson.SetBytes(body, "model", baseModel)
	body = capGeminiMaxOutputTokens(body, baseModel)

	baseURL := resolveGeminiBaseURL(auth)
	url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, "streamGenerateContent")
@@ -527,6 +530,26 @@ func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
	util.ApplyCustomHeadersFromAttrs(req, attrs)
}

func capGeminiMaxOutputTokens(body []byte, modelName string) []byte {
	maxOut := gjson.GetBytes(body, "generationConfig.maxOutputTokens")
	if !maxOut.Exists() || maxOut.Type != gjson.Number {
		return body
	}
	modelInfo := registry.LookupModelInfo(modelName, "gemini")
	if modelInfo == nil {
		return body
	}
	limit := modelInfo.OutputTokenLimit
	if limit <= 0 {
		limit = modelInfo.MaxCompletionTokens
	}
	if limit <= 0 || maxOut.Int() <= int64(limit) {
		return body
	}
	body, _ = sjson.SetBytes(body, "generationConfig.maxOutputTokens", limit)
	return body
}

func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
	if modelName == "gemini-2.5-flash-image-preview" {
		aspectRatioResult := gjson.GetBytes(rawJSON, "generationConfig.imageConfig.aspectRatio")

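`capGeminiMaxOutputTokens` in the hunk above lowers `generationConfig.maxOutputTokens` only when a limit is known and the requested value exceeds it, and otherwise returns the body untouched. The sketch below reproduces that rule with `encoding/json` rather than `gjson`/`sjson`, and takes the limit as a parameter instead of consulting the model registry. `capMaxOutputTokens` is a hypothetical name, and unlike the original this version re-serializes the object, so key order in the output is alphabetical.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// capMaxOutputTokens clamps generationConfig.maxOutputTokens to limit.
// The body is returned unchanged when limit is unknown (<= 0), the
// field is absent or not an integer, or the value is already in range.
func capMaxOutputTokens(body []byte, limit int64) []byte {
	if limit <= 0 {
		return body
	}
	var doc map[string]json.RawMessage
	if err := json.Unmarshal(body, &doc); err != nil {
		return body
	}
	var gen map[string]json.RawMessage
	if err := json.Unmarshal(doc["generationConfig"], &gen); err != nil {
		return body
	}
	var requested int64
	if err := json.Unmarshal(gen["maxOutputTokens"], &requested); err != nil || requested <= limit {
		return body
	}
	gen["maxOutputTokens"] = json.RawMessage(fmt.Sprint(limit))
	genRaw, err := json.Marshal(gen)
	if err != nil {
		return body
	}
	doc["generationConfig"] = genRaw
	out, err := json.Marshal(doc)
	if err != nil {
		return body
	}
	return out
}

func main() {
	body := []byte(`{"generationConfig":{"maxOutputTokens":500000,"temperature":0.2}}`)
	fmt.Println(string(capMaxOutputTokens(body, 65536)))
}
```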
@@ -0,0 +1,90 @@
package executor

import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
	sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
	"github.com/tidwall/gjson"
)

func TestCapGeminiMaxOutputTokensUsesOutputTokenLimit(t *testing.T) {
	body := []byte(`{"generationConfig":{"maxOutputTokens":500000,"temperature":0.2},"contents":[]}`)

	out := capGeminiMaxOutputTokens(body, "gemini-3.1-pro-preview")

	if got := gjson.GetBytes(out, "generationConfig.maxOutputTokens").Int(); got != 65536 {
		t.Fatalf("maxOutputTokens = %d, want 65536", got)
	}
	if got := gjson.GetBytes(out, "generationConfig.temperature").Float(); got != 0.2 {
		t.Fatalf("temperature = %v, want 0.2", got)
	}
}

func TestCapGeminiMaxOutputTokensLeavesAllowedOrUnknown(t *testing.T) {
	tests := []struct {
		name  string
		model string
		body  []byte
		want  int64
	}{
		{
			name:  "allowed value",
			model: "gemini-3.1-pro-preview",
			body:  []byte(`{"generationConfig":{"maxOutputTokens":64000}}`),
			want:  64000,
		},
		{
			name:  "unknown model",
			model: "custom-gemini-model",
			body:  []byte(`{"generationConfig":{"maxOutputTokens":500000}}`),
			want:  500000,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			out := capGeminiMaxOutputTokens(tt.body, tt.model)
			if got := gjson.GetBytes(out, "generationConfig.maxOutputTokens").Int(); got != tt.want {
				t.Fatalf("maxOutputTokens = %d, want %d", got, tt.want)
			}
		})
	}
}

func TestGeminiExecutorExecuteCapsMaxOutputTokensBeforeUpstream(t *testing.T) {
	var upstreamMaxOutputTokens int64
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Fatalf("read request body: %v", err)
		}
		upstreamMaxOutputTokens = gjson.GetBytes(body, "generationConfig.maxOutputTokens").Int()
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"candidates":[{"content":{"role":"model","parts":[{"text":"ok"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}`))
	}))
	defer server.Close()

	exec := NewGeminiExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "test-key",
		"base_url": server.URL,
	}}
	req := cliproxyexecutor.Request{
		Model:   "gemini-3.1-pro-preview",
		Payload: []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"maxOutputTokens":500000}}`),
	}

	if _, err := exec.Execute(context.Background(), auth, req, cliproxyexecutor.Options{SourceFormat: sdktranslator.FormatGemini}); err != nil {
		t.Fatalf("Execute() error = %v", err)
	}
	if upstreamMaxOutputTokens != 65536 {
		t.Fatalf("upstream maxOutputTokens = %d, want 65536", upstreamMaxOutputTokens)
	}
}
@@ -102,6 +102,7 @@ func RecordAPIRequest(ctx context.Context, cfg *config.Config, info UpstreamRequ

// RecordAPIResponseMetadata captures upstream response status/header information for the latest attempt.
func RecordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
	logging.SetResponseHeaders(ctx, headers)
	if cfg == nil || !cfg.RequestLog {
		return
	}
@@ -227,6 +228,7 @@ func RecordAPIWebsocketRequest(ctx context.Context, cfg *config.Config, info Ups

// RecordAPIWebsocketHandshake stores the upstream websocket handshake response metadata.
func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
	logging.SetResponseHeaders(ctx, headers)
	if cfg == nil || !cfg.RequestLog {
		return
	}
@@ -250,6 +252,7 @@ func RecordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, status

// RecordAPIWebsocketUpgradeRejection stores a rejected websocket upgrade as an HTTP attempt.
func RecordAPIWebsocketUpgradeRejection(ctx context.Context, cfg *config.Config, info UpstreamRequestLog, status int, headers http.Header, body []byte) {
	logging.SetResponseHeaders(ctx, headers)
	if cfg == nil || !cfg.RequestLog {
		return
	}

||||
@@ -0,0 +1,24 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
|
||||
)
|
||||
|
||||
func TestRecordAPIResponseMetadataStoresHeadersWhenRequestLogDisabled(t *testing.T) {
|
||||
ctx := logging.WithResponseHeadersHolder(context.Background())
|
||||
headers := http.Header{}
|
||||
headers.Add("X-Upstream-Request-Id", "upstream-req-1")
|
||||
|
||||
RecordAPIResponseMetadata(ctx, &config.Config{}, http.StatusOK, headers)
|
||||
headers.Set("X-Upstream-Request-Id", "mutated")
|
||||
|
||||
got := logging.GetResponseHeaders(ctx)
|
||||
if got.Get("X-Upstream-Request-Id") != "upstream-req-1" {
|
||||
t.Fatalf("response header = %q, want %q", got.Get("X-Upstream-Request-Id"), "upstream-req-1")
|
||||
}
|
||||
}
|
||||
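The test above mutates `headers` after recording them and still expects the original value back, which implies the stored copy is a snapshot. The `logging` package implementation is not shown in this diff, so the following is a minimal sketch under that assumption. `headersHolder`, `withHeadersHolder`, `setHeaders`, `getHeaders`, and `demo` are hypothetical names, not the repository's API.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"sync"
)

// headersKey is a private context key type, so no other package can collide with it.
type headersKey struct{}

// headersHolder is a mutable slot placed in the context once per request.
type headersHolder struct {
	mu      sync.RWMutex
	headers http.Header
}

// withHeadersHolder attaches an empty holder to the context.
func withHeadersHolder(ctx context.Context) context.Context {
	return context.WithValue(ctx, headersKey{}, &headersHolder{})
}

// setHeaders stores a clone, so later mutation by the caller is not visible.
func setHeaders(ctx context.Context, h http.Header) {
	holder, ok := ctx.Value(headersKey{}).(*headersHolder)
	if !ok {
		return // no holder installed, silently ignore
	}
	holder.mu.Lock()
	defer holder.mu.Unlock()
	holder.headers = h.Clone()
}

// getHeaders returns a clone, so readers cannot mutate the stored copy either.
func getHeaders(ctx context.Context) http.Header {
	holder, ok := ctx.Value(headersKey{}).(*headersHolder)
	if !ok {
		return nil
	}
	holder.mu.RLock()
	defer holder.mu.RUnlock()
	return holder.headers.Clone()
}

// demo mirrors the test: record, mutate the original, then read back.
func demo() string {
	ctx := withHeadersHolder(context.Background())
	headers := http.Header{}
	headers.Add("X-Upstream-Request-Id", "upstream-req-1")
	setHeaders(ctx, headers)
	headers.Set("X-Upstream-Request-Id", "mutated")
	return getHeaders(ctx).Get("X-Upstream-Request-Id")
}

func main() {
	fmt.Println(demo())
}
```

Storing a pointer to a holder in the context, rather than the headers themselves, lets code deeper in the call chain publish headers that the original caller can read later without replacing the context.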
@@ -50,7 +50,7 @@ func NewProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip
|
||||
return httpClient
|
||||
}
|
||||
// If proxy setup failed, log and fall through to context RoundTripper
|
||||
- log.Debugf("failed to setup proxy from URL: %s, falling back to context transport", proxyURL)
+ log.Debugf("failed to setup proxy from URL: %s, falling back to context transport", proxyutil.Redact(proxyURL))
|
||||
}
|
||||
|
||||
// Priority 3: Use RoundTripper from context (typically from RoundTripperFor)
|
||||
|
||||
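The hunk above stops logging the raw proxy URL, which may embed credentials. The body of `proxyutil.Redact` is not shown in this diff. As a rough sketch of the idea only, the hypothetical `redactProxyURL` below relies on the standard library's `(*url.URL).Redacted`, which replaces any password with `xxxxx`.

```go
package main

import (
	"fmt"
	"net/url"
)

// redactProxyURL returns a log-safe form of a proxy URL.
// Hypothetical stand-in for proxyutil.Redact, whose real behavior may differ.
func redactProxyURL(raw string) string {
	u, err := url.Parse(raw)
	if err != nil {
		// Never echo the raw input on failure: it may still contain a password.
		return "<unparseable proxy URL>"
	}
	return u.Redacted()
}

func main() {
	fmt.Println(redactProxyURL("http://user:secret@proxy.example:8080"))
	fmt.Println(redactProxyURL("socks5://127.0.0.1:1080"))
}
```

Note that this keeps the username visible; a stricter variant could drop the user info entirely by setting `u.User = nil` before formatting.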
@@ -8,4 +8,5 @@ import (
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/geminicli"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/kimi"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/openai"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/xai"
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internallogging "github.com/router-for-me/CLIProxyAPI/v7/internal/logging"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -25,6 +26,7 @@ type UsageReporter struct {
|
||||
authType string
|
||||
apiKey string
|
||||
source string
|
||||
reasoning string
|
||||
requestedAt time.Time
|
||||
once sync.Once
|
||||
}
|
||||
@@ -43,6 +45,7 @@ func NewUsageReporter(ctx context.Context, provider, model string, auth *cliprox
|
||||
apiKey: apiKey,
|
||||
source: resolveUsageSource(auth, apiKey),
|
||||
authType: resolveUsageAuthType(auth),
|
||||
reasoning: usage.ReasoningEffortFromContext(ctx),
|
||||
}
|
||||
if auth != nil {
|
||||
reporter.authID = auth.ID
|
||||
@@ -60,7 +63,7 @@ func (r *UsageReporter) PublishAdditionalModel(ctx context.Context, model string
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
- usage.PublishRecord(ctx, record)
+ r.publishRecord(ctx, record)
|
||||
}
|
||||
|
||||
func (r *UsageReporter) buildAdditionalModelRecord(model string, detail usage.Detail) (usage.Record, bool) {
|
||||
@@ -97,7 +100,7 @@ func (r *UsageReporter) publishWithOutcome(ctx context.Context, detail usage.Det
|
||||
}
|
||||
detail = normalizeUsageDetailTotal(detail)
|
||||
r.once.Do(func() {
|
||||
- usage.PublishRecord(ctx, r.buildRecord(detail, failed, fail))
+ r.publishRecord(ctx, r.buildRecord(detail, failed, fail))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -130,10 +133,15 @@ func (r *UsageReporter) EnsurePublished(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
r.once.Do(func() {
|
||||
- usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false, usage.Failure{}))
+ r.publishRecord(ctx, r.buildRecord(usage.Detail{}, false, usage.Failure{}))
|
||||
})
|
||||
}
|
||||
|
||||
func (r *UsageReporter) publishRecord(ctx context.Context, record usage.Record) {
|
||||
record.ResponseHeaders = internallogging.GetResponseHeaders(ctx)
|
||||
usage.PublishRecord(ctx, record)
|
||||
}
|
||||
|
||||
func (r *UsageReporter) buildRecord(detail usage.Detail, failed bool, failures ...usage.Failure) usage.Record {
|
||||
var fail usage.Failure
|
||||
if len(failures) > 0 {
|
||||
@@ -150,19 +158,20 @@ func (r *UsageReporter) buildRecordForModel(model string, detail usage.Detail, f
|
||||
return usage.Record{Model: model, Detail: detail, Failed: failed, Fail: fail}
|
||||
}
|
||||
return usage.Record{
|
||||
- Provider:    r.provider,
- Model:       model,
- Alias:       r.alias,
- Source:      r.source,
- APIKey:      r.apiKey,
- AuthID:      r.authID,
- AuthIndex:   r.authIndex,
- AuthType:    r.authType,
- RequestedAt: r.requestedAt,
- Latency:     r.latency(),
- Failed:      failed,
- Fail:        fail,
- Detail:      detail,
+ Provider:        r.provider,
+ Model:           model,
+ Alias:           r.alias,
+ Source:          r.source,
+ APIKey:          r.apiKey,
+ AuthID:          r.authID,
+ AuthIndex:       r.authIndex,
+ AuthType:        r.authType,
+ ReasoningEffort: r.reasoning,
+ RequestedAt:     r.requestedAt,
+ Latency:         r.latency(),
+ Failed:          failed,
+ Fail:            fail,
+ Detail:          detail,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -159,6 +159,16 @@ func TestUsageReporterBuildRecordIncludesRequestedModelAlias(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageReporterBuildRecordIncludesReasoningEffort(t *testing.T) {
|
||||
ctx := usage.WithReasoningEffort(context.Background(), "medium")
|
||||
reporter := NewUsageReporter(ctx, "openai", "gpt-5.4", nil)
|
||||
|
||||
record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
|
||||
if record.ReasoningEffort != "medium" {
|
||||
t.Fatalf("reasoning effort = %q, want %q", record.ReasoningEffort, "medium")
|
||||
}
|
||||
}
|
||||
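The test above relies on `usage.WithReasoningEffort` and `usage.ReasoningEffortFromContext`, whose bodies are not part of this diff. The following is a minimal sketch of how such a pair is commonly built on `context.WithValue`, assuming a plain string value. The lower-case function names and `demo` are hypothetical stand-ins, and the trimming of whitespace is an assumption, not confirmed behavior.

```go
package main

import (
	"context"
	"fmt"
	"strings"
)

// reasoningEffortKey is an unexported key type, which keeps the value private
// to this package's accessor functions.
type reasoningEffortKey struct{}

// withReasoningEffort returns a context carrying the effort level.
// Blank values leave the context unchanged (assumed behavior).
func withReasoningEffort(ctx context.Context, effort string) context.Context {
	effort = strings.TrimSpace(effort)
	if effort == "" {
		return ctx
	}
	return context.WithValue(ctx, reasoningEffortKey{}, effort)
}

// reasoningEffortFromContext returns the stored effort, or "" when unset.
func reasoningEffortFromContext(ctx context.Context) string {
	effort, _ := ctx.Value(reasoningEffortKey{}).(string)
	return effort
}

// demo round-trips a value the same way the test does.
func demo() string {
	ctx := withReasoningEffort(context.Background(), "medium")
	return reasoningEffortFromContext(ctx)
}

func main() {
	fmt.Printf("%q\n", demo())
	fmt.Printf("%q\n", reasoningEffortFromContext(context.Background()))
}
```

Reading the value once in the constructor, as `NewUsageReporter` does in the hunk above, means the reporter does not depend on the context that is later passed to its publish methods.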
|
||||
func TestUsageReporterBuildAdditionalModelRecordSkipsZeroTokens(t *testing.T) {
|
||||
reporter := &UsageReporter{
|
||||
provider: "codex",
|
||||
|
||||
@@ -30,7 +30,7 @@ func newUtlsRoundTripper(proxyURL string) *utlsRoundTripper {
|
||||
if proxyURL != "" {
|
||||
proxyDialer, mode, errBuild := proxyutil.BuildDialer(proxyURL)
|
||||
if errBuild != nil {
|
||||
- log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyURL, errBuild)
+ log.Errorf("utls: failed to configure proxy dialer for %q: %v", proxyutil.Redact(proxyURL), errBuild)
|
||||
} else if mode != proxyutil.ModeInherit && proxyDialer != nil {
|
||||
dialer = proxyDialer
|
||||
}
|
||||
|
||||
@@ -4,9 +4,13 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +25,14 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
openAICompatImageHandlerType = "openai-image"
|
||||
openAICompatImagesGenerationsPath = "/images/generations"
|
||||
openAICompatImagesEditsPath = "/images/edits"
|
||||
openAICompatDefaultImageEndpoint = openAICompatImagesGenerationsPath
|
||||
openAICompatMultipartMemory int64 = 32 << 20
|
||||
)
|
||||
|
||||
// OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers.
|
||||
// It performs request/response translation and executes against the provider base URL
|
||||
// using per-auth credentials (API key) and per-auth HTTP transport (proxy) from context.
|
||||
@@ -71,6 +83,10 @@ func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
if endpointPath := openAICompatImageEndpointPath(opts); endpointPath != "" {
|
||||
return e.executeImages(ctx, auth, req, opts, endpointPath)
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -179,7 +195,98 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) executeImages(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, endpointPath string) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
baseURL, apiKey := e.resolveCredentials(auth)
|
||||
if baseURL == "" {
|
||||
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
payload, contentType, errPrepare := prepareOpenAICompatImagesPayload(req.Payload, baseModel, opts.Headers.Get("Content-Type"), false)
|
||||
if errPrepare != nil {
|
||||
err = errPrepare
|
||||
return resp, err
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + endpointPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", contentType)
|
||||
if apiKey != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat")
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
|
||||
body, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
|
||||
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), body))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(body)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
reporter.Publish(ctx, helps.ParseOpenAIUsage(body))
|
||||
reporter.EnsurePublished(ctx)
|
||||
resp = cliproxyexecutor.Response{Payload: body, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if endpointPath := openAICompatImageEndpointPath(opts); endpointPath != "" {
|
||||
return e.executeImagesStream(ctx, auth, req, opts, endpointPath)
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -342,6 +449,121 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) executeImagesStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, endpointPath string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
baseURL, apiKey := e.resolveCredentials(auth)
|
||||
if baseURL == "" {
|
||||
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload, contentType, errPrepare := prepareOpenAICompatImagesPayload(req.Payload, baseModel, opts.Headers.Get("Content-Type"), true)
|
||||
if errPrepare != nil {
|
||||
err = errPrepare
|
||||
return nil, err
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + endpointPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", contentType)
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
if apiKey != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat")
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
body, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return nil, errRead
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), body))
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: string(body)}
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||
}
|
||||
reporter.EnsurePublished(ctx)
|
||||
}()
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
n, errRead := httpResp.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
chunk := bytes.Clone(buffer[:n])
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, chunk)
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if errRead != nil {
|
||||
if errRead != io.EOF {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
reporter.PublishFailure(ctx, errRead)
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errRead}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
@@ -380,6 +602,124 @@ func (e *OpenAICompatExecutor) Refresh(ctx context.Context, auth *cliproxyauth.A
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func openAICompatImageEndpointPath(opts cliproxyexecutor.Options) string {
|
||||
if opts.SourceFormat.String() != openAICompatImageHandlerType {
|
||||
return ""
|
||||
}
|
||||
path := helps.PayloadRequestPath(opts)
|
||||
if strings.HasSuffix(path, "/images/edits") {
|
||||
return openAICompatImagesEditsPath
|
||||
}
|
||||
if strings.HasSuffix(path, "/images/generations") {
|
||||
return openAICompatImagesGenerationsPath
|
||||
}
|
||||
return openAICompatDefaultImageEndpoint
|
||||
}
|
||||
|
||||
func prepareOpenAICompatImagesPayload(payload []byte, model string, contentType string, stream bool) ([]byte, string, error) {
|
||||
model = strings.TrimSpace(model)
|
||||
contentType = strings.TrimSpace(contentType)
|
||||
if json.Valid(payload) {
|
||||
if model != "" {
|
||||
payload, _ = sjson.SetBytes(payload, "model", model)
|
||||
}
|
||||
if stream {
|
||||
payload, _ = sjson.SetBytes(payload, "stream", true)
|
||||
} else {
|
||||
payload, _ = sjson.DeleteBytes(payload, "stream")
|
||||
}
|
||||
return payload, "application/json", nil
|
||||
}
|
||||
|
||||
mediaType, params, errParse := mime.ParseMediaType(contentType)
|
||||
if errParse != nil || !strings.HasPrefix(strings.ToLower(strings.TrimSpace(mediaType)), "multipart/") {
|
||||
return payload, contentType, nil
|
||||
}
|
||||
boundary := strings.TrimSpace(params["boundary"])
|
||||
if boundary == "" {
|
||||
return nil, "", fmt.Errorf("multipart boundary is missing")
|
||||
}
|
||||
return rewriteOpenAICompatImagesMultipartPayload(payload, model, boundary, stream)
|
||||
}
|
||||
|
||||
func cloneOpenAICompatMIMEHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
|
||||
dst := make(textproto.MIMEHeader, len(src))
|
||||
for key, values := range src {
|
||||
dst[key] = append([]string(nil), values...)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func rewriteOpenAICompatImagesMultipartPayload(payload []byte, model string, boundary string, stream bool) ([]byte, string, error) {
|
||||
reader := multipart.NewReader(bytes.NewReader(payload), boundary)
|
||||
form, errRead := reader.ReadForm(openAICompatMultipartMemory)
|
||||
if errRead != nil {
|
||||
return nil, "", fmt.Errorf("read multipart form failed: %w", errRead)
|
||||
}
|
||||
defer func() {
|
||||
if errRemove := form.RemoveAll(); errRemove != nil {
|
||||
log.Errorf("openai compat executor: remove multipart form files error: %v", errRemove)
|
||||
}
|
||||
}()
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
if model != "" {
|
||||
if errWrite := writer.WriteField("model", model); errWrite != nil {
|
||||
return nil, "", fmt.Errorf("write model field failed: %w", errWrite)
|
||||
}
|
||||
}
|
||||
if stream {
|
||||
if errWrite := writer.WriteField("stream", "true"); errWrite != nil {
|
||||
return nil, "", fmt.Errorf("write stream field failed: %w", errWrite)
|
||||
}
|
||||
}
|
||||
for key, values := range form.Value {
|
||||
if key == "model" || key == "stream" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
if errWrite := writer.WriteField(key, value); errWrite != nil {
|
||||
return nil, "", fmt.Errorf("write form field %s failed: %w", key, errWrite)
|
||||
}
|
||||
}
|
||||
}
|
||||
for key, files := range form.File {
|
||||
for _, fileHeader := range files {
|
||||
if fileHeader == nil {
|
||||
continue
|
||||
}
|
||||
header := cloneOpenAICompatMIMEHeader(fileHeader.Header)
|
||||
header.Set("Content-Disposition", multipart.FileContentDisposition(key, fileHeader.Filename))
|
||||
if header.Get("Content-Type") == "" {
|
||||
header.Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
part, errCreate := writer.CreatePart(header)
|
||||
if errCreate != nil {
|
||||
return nil, "", fmt.Errorf("create file field %s failed: %w", key, errCreate)
|
||||
}
|
||||
src, errOpen := fileHeader.Open()
|
||||
if errOpen != nil {
|
||||
return nil, "", fmt.Errorf("open upload file failed: %w", errOpen)
|
||||
}
|
||||
_, errCopy := io.Copy(part, src)
|
||||
if errClose := src.Close(); errClose != nil {
|
||||
log.Errorf("openai compat executor: close upload file error: %v", errClose)
|
||||
if errCopy == nil {
|
||||
errCopy = errClose
|
||||
}
|
||||
}
|
||||
if errCopy != nil {
|
||||
return nil, "", fmt.Errorf("copy upload file failed: %w", errCopy)
|
||||
}
|
||||
}
|
||||
}
|
||||
if errClose := writer.Close(); errClose != nil {
|
||||
return nil, "", fmt.Errorf("close multipart writer failed: %w", errClose)
|
||||
}
|
||||
return body.Bytes(), writer.FormDataContentType(), nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) resolveCredentials(auth *cliproxyauth.Auth) (baseURL, apiKey string) {
|
||||
if auth == nil {
|
||||
return "", ""
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -102,6 +106,265 @@ func TestOpenAICompatExecutorPayloadOverrideWinsOverThinkingSuffix(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatExecutorImagesGenerationsPassthrough(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotBody []byte
|
||||
var gotContentType string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotContentType = r.Header.Get("Content-Type")
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody = body
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"created":123,"data":[{"b64_json":"AA=="}],"usage":{"total_tokens":1}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL + "/v1",
|
||||
"api_key": "test",
|
||||
}}
|
||||
resp, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "upstream-image",
|
||||
Payload: []byte(`{"model":"compat-image","prompt":"draw"}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-image"),
|
||||
Stream: false,
|
||||
Headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.RequestPathMetadataKey: "/v1/images/generations",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
if gotPath != "/v1/images/generations" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/images/generations")
|
||||
}
|
||||
if gotContentType != "application/json" {
|
||||
t.Fatalf("content type = %q, want application/json", gotContentType)
|
||||
}
|
||||
if got := gjson.GetBytes(gotBody, "model").String(); got != "upstream-image" {
|
||||
t.Fatalf("model = %q, want upstream-image; body=%s", got, string(gotBody))
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "data.0.b64_json").String(); got != "AA==" {
|
||||
t.Fatalf("response payload = %s", string(resp.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatExecutorImagesGenerationsStreamsUpstream(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotBody []byte
|
||||
var gotAccept string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAccept = r.Header.Get("Accept")
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody = body
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("event: image_generation.partial\ndata: {\"type\":\"image_generation.partial\"}\n\n"))
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL + "/v1",
|
||||
"api_key": "test",
|
||||
}}
|
||||
streamResult, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "upstream-image",
|
||||
Payload: []byte(`{"model":"compat-image","prompt":"draw","stream":true}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-image"),
|
||||
Stream: true,
|
||||
Headers: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.RequestPathMetadataKey: "/v1/images/generations",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
var streamed bytes.Buffer
|
||||
for chunk := range streamResult.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("stream chunk error: %v", chunk.Err)
|
||||
}
|
||||
streamed.Write(chunk.Payload)
|
||||
}
|
||||
if gotPath != "/v1/images/generations" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/images/generations")
|
||||
}
|
||||
if gotAccept != "text/event-stream" {
|
||||
t.Fatalf("accept = %q, want text/event-stream", gotAccept)
|
||||
}
|
||||
if got := gjson.GetBytes(gotBody, "model").String(); got != "upstream-image" {
|
||||
t.Fatalf("model = %q, want upstream-image; body=%s", got, string(gotBody))
|
||||
}
|
||||
if !gjson.GetBytes(gotBody, "stream").Bool() {
|
||||
t.Fatalf("stream flag missing from upstream body: %s", string(gotBody))
|
||||
}
|
||||
if !strings.Contains(streamed.String(), "event: image_generation.partial") || !strings.Contains(streamed.String(), "data: [DONE]") {
|
||||
t.Fatalf("streamed body = %q", streamed.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatExecutorImagesEditsMultipartRewritesModel(t *testing.T) {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
if errWrite := writer.WriteField("model", "compat-image"); errWrite != nil {
|
||||
t.Fatalf("write model field: %v", errWrite)
|
||||
}
|
||||
if errWrite := writer.WriteField("prompt", "edit"); errWrite != nil {
|
||||
t.Fatalf("write prompt field: %v", errWrite)
|
||||
}
|
||||
header := make(textproto.MIMEHeader)
|
||||
header.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.png"))
|
||||
header.Set("Content-Type", "image/png")
|
||||
part, errCreate := writer.CreatePart(header)
|
||||
if errCreate != nil {
|
||||
t.Fatalf("create image field: %v", errCreate)
|
||||
}
|
||||
if _, errWrite := part.Write([]byte("png-data")); errWrite != nil {
|
||||
t.Fatalf("write image field: %v", errWrite)
|
||||
}
|
||||
if errClose := writer.Close(); errClose != nil {
|
||||
t.Fatalf("close multipart writer: %v", errClose)
|
||||
}
|
||||
contentType := writer.FormDataContentType()
|
||||
|
||||
var gotPath string
|
||||
var gotModel string
|
||||
var gotPrompt string
|
||||
var gotFile string
|
||||
var gotFileContentType string
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
if errParse := r.ParseMultipartForm(32 << 20); errParse != nil {
|
||||
t.Fatalf("parse multipart form: %v", errParse)
|
||||
}
|
||||
gotModel = r.FormValue("model")
|
||||
gotPrompt = r.FormValue("prompt")
|
||||
file, fileHeader, errFile := r.FormFile("image")
|
||||
if errFile != nil {
|
||||
t.Fatalf("read image file: %v", errFile)
|
||||
}
|
||||
gotFileContentType = fileHeader.Header.Get("Content-Type")
|
||||
data, errRead := io.ReadAll(file)
|
||||
if errClose := file.Close(); errClose != nil {
|
||||
t.Fatalf("close image file: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
t.Fatalf("read image file: %v", errRead)
|
||||
}
|
||||
gotFile = string(data)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"created":123,"data":[{"b64_json":"AA=="}]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL + "/v1",
|
||||
"api_key": "test",
|
||||
}}
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "upstream-image",
|
||||
Payload: body.Bytes(),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai-image"),
|
||||
Stream: false,
|
||||
Headers: http.Header{
|
||||
"Content-Type": []string{contentType},
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.RequestPathMetadataKey: "/v1/images/edits",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
if gotPath != "/v1/images/edits" {
|
||||
t.Fatalf("path = %q, want %q", gotPath, "/v1/images/edits")
|
||||
}
|
||||
if gotModel != "upstream-image" {
|
||||
t.Fatalf("model = %q, want upstream-image", gotModel)
|
||||
}
|
||||
if gotPrompt != "edit" {
|
||||
t.Fatalf("prompt = %q, want edit", gotPrompt)
|
||||
}
|
||||
if gotFile != "png-data" {
|
||||
t.Fatalf("file = %q, want png-data", gotFile)
|
||||
}
|
||||
if gotFileContentType != "image/png" {
|
||||
t.Fatalf("file content type = %q, want image/png", gotFileContentType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteOpenAICompatImagesMultipartPayloadPreservesStreamAndFileContentType(t *testing.T) {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
if errWrite := writer.WriteField("model", "compat-image"); errWrite != nil {
|
||||
t.Fatalf("write model field: %v", errWrite)
|
||||
}
|
||||
if errWrite := writer.WriteField("stream", "false"); errWrite != nil {
|
||||
t.Fatalf("write stream field: %v", errWrite)
|
||||
}
|
||||
header := make(textproto.MIMEHeader)
|
||||
header.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.webp"))
|
||||
header.Set("Content-Type", "image/webp")
|
||||
part, errCreate := writer.CreatePart(header)
|
||||
if errCreate != nil {
|
||||
t.Fatalf("create image field: %v", errCreate)
|
||||
}
|
||||
if _, errWrite := part.Write([]byte("webp-data")); errWrite != nil {
|
||||
t.Fatalf("write image field: %v", errWrite)
|
||||
}
|
||||
if errClose := writer.Close(); errClose != nil {
|
||||
t.Fatalf("close multipart writer: %v", errClose)
|
||||
}
|
||||
|
||||
out, contentType, err := prepareOpenAICompatImagesPayload(body.Bytes(), "upstream-image", writer.FormDataContentType(), true)
|
||||
if err != nil {
|
||||
t.Fatalf("prepareOpenAICompatImagesPayload error: %v", err)
|
||||
}
|
||||
mediaType, params, errParse := mime.ParseMediaType(contentType)
|
||||
if errParse != nil {
|
||||
t.Fatalf("parse content type: %v", errParse)
|
||||
}
|
||||
if mediaType != "multipart/form-data" {
|
||||
t.Fatalf("media type = %q, want multipart/form-data", mediaType)
|
||||
}
|
||||
reader := multipart.NewReader(bytes.NewReader(out), params["boundary"])
|
||||
form, errRead := reader.ReadForm(32 << 20)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read rewritten form: %v", errRead)
|
||||
}
|
||||
defer func() {
|
||||
if errRemove := form.RemoveAll(); errRemove != nil {
|
||||
t.Fatalf("remove form files: %v", errRemove)
|
||||
}
|
||||
}()
|
||||
if got := form.Value["model"]; len(got) != 1 || got[0] != "upstream-image" {
|
||||
t.Fatalf("model values = %#v, want upstream-image", got)
|
||||
}
|
||||
if got := form.Value["stream"]; len(got) != 1 || got[0] != "true" {
|
||||
t.Fatalf("stream values = %#v, want true", got)
|
||||
}
|
||||
if got := form.File["image"]; len(got) != 1 || got[0].Header.Get("Content-Type") != "image/webp" {
|
||||
t.Fatalf("image headers = %#v, want image/webp", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatExecutorStreamRejectsPlainJSONAfterBlankLines(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
@@ -487,7 +487,7 @@ func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxye
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
|
||||
var err error
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), e.Identifier(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -196,6 +196,48 @@ func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestXAIExecutorAppliesThinkingSuffix(t *testing.T) {
|
||||
var gotBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var errRead error
|
||||
gotBody, errRead = io.ReadAll(r.Body)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read body: %v", errRead)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4.3\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}]}}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewXAIExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
Metadata: map[string]any{"access_token": "xai-token"},
|
||||
}
|
||||
|
||||
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "grok-4.3(low)",
|
||||
Payload: []byte(`{"model":"grok-4.3","input":"hello"}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||
Stream: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(gotBody, "model").String(); got != "grok-4.3" {
|
||||
t.Fatalf("model = %q, want grok-4.3; body=%s", got, string(gotBody))
|
||||
}
|
||||
if got := gjson.GetBytes(gotBody, "reasoning.effort").String(); got != "low" {
|
||||
t.Fatalf("reasoning.effort = %q, want low; body=%s", got, string(gotBody))
|
||||
}
|
||||
}
|
||||
|
||||
func TestXAIExecutorExecuteStreamFiltersToolSearchTool(t *testing.T) {
|
||||
var gotBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -18,6 +18,7 @@ var providerAppliers = map[string]ProviderApplier{
|
||||
"codex": nil,
|
||||
"antigravity": nil,
|
||||
"kimi": nil,
|
||||
"xai": nil,
|
||||
}
|
||||
|
||||
// GetProviderApplier returns the ProviderApplier for the given provider name.
|
||||
@@ -62,7 +63,7 @@ func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
|
||||
// - body: Original request body JSON
|
||||
// - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
|
||||
// - fromFormat: Source request format (e.g., openai, codex, gemini)
|
||||
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, kimi)
|
||||
// - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, kimi, xai)
|
||||
// - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai)
|
||||
//
|
||||
// Returns:
|
||||
@@ -324,7 +325,7 @@ func extractThinkingConfig(body []byte, provider string) ThinkingConfig {
|
||||
return extractGeminiConfig(body, provider)
|
||||
case "openai":
|
||||
return extractOpenAIConfig(body)
|
||||
case "codex":
|
||||
case "codex", "xai":
|
||||
return extractCodexConfig(body)
|
||||
case "kimi":
|
||||
// Kimi uses OpenAI-compatible reasoning_effort format
|
||||
@@ -338,6 +339,56 @@ func hasThinkingConfig(config ThinkingConfig) bool {
|
||||
return config.Mode != ModeBudget || config.Budget != 0 || config.Level != ""
|
||||
}
|
||||
|
||||
// ExtractReasoningEffort returns the request's thinking setting as a canonical
|
||||
// reasoning_effort label for usage logging. Model suffixes have the same
|
||||
// priority as ApplyThinking: a valid suffix overrides body fields.
|
||||
func ExtractReasoningEffort(body []byte, provider, model string) string {
|
||||
if effort := reasoningEffortFromSuffix(ParseSuffix(model)); effort != "" {
|
||||
return effort
|
||||
}
|
||||
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
config := extractThinkingConfig(body, provider)
|
||||
if !hasThinkingConfig(config) {
|
||||
switch provider {
|
||||
case "openai-response":
|
||||
config = extractCodexConfig(body)
|
||||
case "openai":
|
||||
config = extractCodexConfig(body)
|
||||
}
|
||||
}
|
||||
return reasoningEffortFromConfig(config)
|
||||
}
|
||||
|
||||
func reasoningEffortFromSuffix(suffix SuffixResult) string {
|
||||
if !suffix.HasSuffix {
|
||||
return ""
|
||||
}
|
||||
return reasoningEffortFromConfig(parseSuffixToConfig(suffix.RawSuffix, "", suffix.ModelName))
|
||||
}
|
||||
|
||||
func reasoningEffortFromConfig(config ThinkingConfig) string {
|
||||
if !hasThinkingConfig(config) {
|
||||
return ""
|
||||
}
|
||||
switch config.Mode {
|
||||
case ModeNone:
|
||||
return string(LevelNone)
|
||||
case ModeAuto:
|
||||
return string(LevelAuto)
|
||||
case ModeLevel:
|
||||
return strings.ToLower(strings.TrimSpace(string(config.Level)))
|
||||
case ModeBudget:
|
||||
level, ok := ConvertBudgetToLevel(config.Budget)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return level
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// extractClaudeConfig extracts thinking configuration from Claude format request body.
|
||||
//
|
||||
// Claude API format:
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
// Package xai implements thinking configuration for xAI Grok Responses API models.
|
||||
//
|
||||
// xAI models use the OpenAI Responses API compatible reasoning.effort format
|
||||
// with discrete levels.
|
||||
package xai
|
||||
|
||||
import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking/provider/codex"
|
||||
)
|
||||
|
||||
// Applier implements thinking.ProviderApplier for xAI models.
|
||||
type Applier struct {
|
||||
codex.Applier
|
||||
}
|
||||
|
||||
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||
|
||||
// NewApplier creates a new xAI thinking applier.
|
||||
func NewApplier() *Applier {
|
||||
return &Applier{}
|
||||
}
|
||||
|
||||
func init() {
|
||||
thinking.RegisterProvider("xai", NewApplier())
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestApplySetsReasoningEffort(t *testing.T) {
|
||||
applier := NewApplier()
|
||||
modelInfo := &registry.ModelInfo{
|
||||
ID: "grok-4.3",
|
||||
Thinking: &registry.ThinkingSupport{
|
||||
ZeroAllowed: true,
|
||||
Levels: []string{"none", "low", "medium", "high"},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := applier.Apply([]byte(`{"input":"hello"}`), thinking.ThinkingConfig{
|
||||
Mode: thinking.ModeLevel,
|
||||
Level: thinking.LevelHigh,
|
||||
}, modelInfo)
|
||||
if err != nil {
|
||||
t.Fatalf("Apply() error = %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "reasoning.effort").String(); got != "high" {
|
||||
t.Fatalf("reasoning.effort = %q, want high; body=%s", got, string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyNoneFallsBackToLowestLevelWhenDisableUnsupported(t *testing.T) {
|
||||
applier := NewApplier()
|
||||
modelInfo := &registry.ModelInfo{
|
||||
ID: "grok-3-mini",
|
||||
Thinking: &registry.ThinkingSupport{
|
||||
Levels: []string{"low", "medium", "high"},
|
||||
},
|
||||
}
|
||||
|
||||
out, err := applier.Apply([]byte(`{"input":"hello"}`), thinking.ThinkingConfig{
|
||||
Mode: thinking.ModeNone,
|
||||
}, modelInfo)
|
||||
if err != nil {
|
||||
t.Fatalf("Apply() error = %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "reasoning.effort").String(); got != "low" {
|
||||
t.Fatalf("reasoning.effort = %q, want low; body=%s", got, string(out))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package thinking
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExtractReasoningEffortUsesSuffixOverBody(t *testing.T) {
|
||||
got := ExtractReasoningEffort([]byte(`{"reasoning_effort":"low"}`), "openai", "gpt-5.4(high)")
|
||||
if got != "high" {
|
||||
t.Fatalf("ExtractReasoningEffort() = %q, want %q", got, "high")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractReasoningEffortConvertsBudgetToLevel(t *testing.T) {
|
||||
got := ExtractReasoningEffort([]byte(`{"thinking":{"type":"enabled","budget_tokens":8192}}`), "claude", "claude-sonnet-4-5")
|
||||
if got != "medium" {
|
||||
t.Fatalf("ExtractReasoningEffort() = %q, want %q", got, "medium")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractReasoningEffortSupportsOpenAIResponses(t *testing.T) {
|
||||
got := ExtractReasoningEffort([]byte(`{"reasoning":{"effort":"medium"}}`), "openai-response", "gpt-5.4")
|
||||
if got != "medium" {
|
||||
t.Fatalf("ExtractReasoningEffort() = %q, want %q", got, "medium")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractReasoningEffortMissingConfigIsEmpty(t *testing.T) {
|
||||
got := ExtractReasoningEffort([]byte(`{"messages":[{"role":"user","content":"hi"}]}`), "openai", "gpt-5.4")
|
||||
if got != "" {
|
||||
t.Fatalf("ExtractReasoningEffort() = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
@@ -42,7 +42,7 @@ func StripThinkingConfig(body []byte, provider string) []byte {
|
||||
"reasoning_effort",
|
||||
"thinking",
|
||||
}
|
||||
case "codex":
|
||||
case "codex", "xai":
|
||||
paths = []string{"reasoning.effort"}
|
||||
default:
|
||||
return body
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Package thinking provides unified thinking configuration processing.
|
||||
//
|
||||
// This package offers a unified interface for parsing, validating, and applying
|
||||
// thinking configurations across various AI providers (Claude, Gemini, OpenAI, Codex, Antigravity, Kimi).
|
||||
// thinking configurations across various AI providers (Claude, Gemini, OpenAI, Codex, Antigravity, Kimi, xAI).
|
||||
package thinking
|
||||
|
||||
import "github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
|
||||
@@ -357,7 +357,7 @@ func isGeminiFamily(provider string) bool {
|
||||
|
||||
func isOpenAIFamily(provider string) bool {
|
||||
switch provider {
|
||||
case "openai", "openai-response", "codex":
|
||||
case "openai", "openai-response", "codex", "xai":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
||||
@@ -101,7 +101,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
if strings.HasPrefix(systemPrompt, "x-anthropic-billing-header:") {
|
||||
if util.IsClaudeCodeAttributionSystemText(systemPrompt) {
|
||||
continue
|
||||
}
|
||||
partJSON := []byte(`{}`)
|
||||
@@ -112,7 +112,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
} else if systemResult.Type == gjson.String {
|
||||
} else if systemResult.Type == gjson.String && !util.IsClaudeCodeAttributionSystemText(systemResult.String()) {
|
||||
systemInstructionJSON = []byte(`{"role":"user","parts":[{"text":""}]}`)
|
||||
systemInstructionJSON, _ = sjson.SetBytes(systemInstructionJSON, "parts.0.text", systemResult.String())
|
||||
hasSystemInstruction = true
|
||||
|
||||
@@ -70,6 +70,28 @@ func uint64Ptr(v uint64) *uint64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_StripsClaudeCodeAttribution(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"system": [
|
||||
{"type": "text", "text": "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;"},
|
||||
{"type": "text", "text": "Antigravity system prompt"}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
parts := gjson.Get(outputStr, "request.systemInstruction.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 system part after attribution strip, got %d: %s", len(parts), gjson.Get(outputStr, "request.systemInstruction.parts").Raw)
|
||||
}
|
||||
if got := parts[0].Get("text").String(); got != "Antigravity system prompt" {
|
||||
t.Fatalf("Unexpected system part: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func testNonAnthropicRawSignature(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -99,35 +99,19 @@ func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
|
||||
// Gemini-specific handling for non-Claude models:
|
||||
// - Add skip_thought_signature_validator to functionCall parts so upstream can bypass signature validation.
|
||||
// - Also mark thinking parts with the same sentinel when present (we keep the parts; we only annotate them).
|
||||
if !strings.Contains(modelName, "claude") {
|
||||
// - Replace client-provided thoughtSignature values with the skip sentinel.
|
||||
// - Add the same sentinel to functionCall and thinking parts so upstream can bypass signature validation.
|
||||
if !strings.Contains(strings.ToLower(modelName), "claude") {
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
// First pass: collect indices of thinking parts to mark with skip sentinel
|
||||
var thinkingIndicesToSkipSignature []int64
|
||||
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
||||
// Collect indices of thinking blocks to mark with skip sentinel
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingIndicesToSkipSignature = append(thinkingIndicesToSkipSignature, partIdx.Int())
|
||||
}
|
||||
// Add skip sentinel to functionCall parts
|
||||
if part.Get("functionCall").Exists() {
|
||||
existingSig := part.Get("thoughtSignature").String()
|
||||
if existingSig == "" || len(existingSig) < 50 {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
||||
}
|
||||
if part.Get("functionCall").Exists() || part.Get("thought").Exists() || part.Get("thoughtSignature").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Add skip_thought_signature_validator sentinel to thinking blocks in reverse order to preserve indices
|
||||
for i := len(thinkingIndicesToSkipSignature) - 1; i >= 0; i-- {
|
||||
idx := thinkingIndicesToSkipSignature[i]
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), idx), skipSentinel)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_PreserveValidSignature(t *testing.T) {
|
||||
// Valid signature on functionCall should be preserved
|
||||
func TestConvertGeminiRequestToAntigravity_ReplacesClientSignatureOnFunctionCall(t *testing.T) {
|
||||
// Client signatures on Gemini function calls are not portable to Antigravity.
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(fmt.Sprintf(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
@@ -25,15 +25,83 @@ func TestConvertGeminiRequestToAntigravity_PreserveValidSignature(t *testing.T)
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check that valid thoughtSignature is preserved
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part, got %d", len(parts))
|
||||
}
|
||||
|
||||
sig := parts[0].Get("thoughtSignature").String()
|
||||
if sig != validSignature {
|
||||
t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, sig)
|
||||
expectedSig := "skip_thought_signature_validator"
|
||||
if sig != expectedSig {
|
||||
t.Errorf("Expected thoughtSignature '%s', got '%s'", expectedSig, sig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_ReplacesClientSignatureOnTextPart(t *testing.T) {
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(fmt.Sprintf(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"text": "previous answer", "thoughtSignature": "%s"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`, validSignature))
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
sig := gjson.Get(outputStr, "request.contents.0.parts.0.thoughtSignature").String()
|
||||
expectedSig := "skip_thought_signature_validator"
|
||||
if sig != expectedSig {
|
||||
t.Errorf("Expected thoughtSignature '%s', got '%s'", expectedSig, sig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_AddsSkipSentinelToStringThoughtPart(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"thought": "internal reasoning"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
sig := gjson.Get(outputStr, "request.contents.0.parts.0.thoughtSignature").String()
|
||||
expectedSig := "skip_thought_signature_validator"
|
||||
if sig != expectedSig {
|
||||
t.Errorf("Expected thoughtSignature '%s', got '%s'", expectedSig, sig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_SkipsUppercaseClaudeModel(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "Claude-Test",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "test_tool", "args": {}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("Claude-Test", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
if sig := gjson.Get(outputStr, "request.contents.0.parts.0.thoughtSignature"); sig.Exists() {
|
||||
t.Fatalf("Expected no thoughtSignature for Claude model, got %s", sig.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,12 +6,15 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -50,7 +53,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
contentIndex := 0
|
||||
|
||||
appendSystemText := func(text string) {
|
||||
if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
|
||||
if text == "" || util.IsClaudeCodeAttributionSystemText(text) {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,6 +87,9 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
for i := 0; i < len(messageResults); i++ {
|
||||
messageResult := messageResults[i]
|
||||
messageRole := messageResult.Get("role").String()
|
||||
if messageRole == "system" {
|
||||
messageRole = "developer"
|
||||
}
|
||||
|
||||
newMessage := func() []byte {
|
||||
msg := []byte(`{"type":"message","role":"","content":[]}`)
|
||||
@@ -172,7 +178,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
case "tool_use":
|
||||
flushMessage()
|
||||
functionCallMessage := []byte(`{"type":"function_call"}`)
|
||||
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "call_id", messageContentResult.Get("id").String())
|
||||
functionCallMessage, _ = sjson.SetBytes(functionCallMessage, "call_id", shortenCodexCallIDIfNeeded(messageContentResult.Get("id").String()))
|
||||
{
|
||||
name := messageContentResult.Get("name").String()
|
||||
if short, ok := toolNameMap[name]; ok {
|
||||
@@ -187,7 +193,7 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
case "tool_result":
|
||||
flushMessage()
|
||||
functionCallOutputMessage := []byte(`{"type":"function_call_output"}`)
|
||||
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
|
||||
functionCallOutputMessage, _ = sjson.SetBytes(functionCallOutputMessage, "call_id", shortenCodexCallIDIfNeeded(messageContentResult.Get("tool_use_id").String()))
|
||||
|
||||
contentResult := messageContentResult.Get("content")
|
||||
if contentResult.IsArray() {
|
||||
@@ -361,6 +367,23 @@ func isFernetLikeReasoningSignature(signature string) bool {
|
||||
return ciphertextLen > 0 && ciphertextLen%aesBlockSize == 0
|
||||
}
|
||||
|
||||
// shortenCodexCallIDIfNeeded keeps Claude tool IDs within the OpenAI Responses
|
||||
// API call_id limit while preserving a stable, low-collision mapping.
|
||||
func shortenCodexCallIDIfNeeded(id string) string {
|
||||
const limit = 64
|
||||
if len(id) <= limit {
|
||||
return id
|
||||
}
|
||||
|
||||
sum := sha256.Sum256([]byte(id))
|
||||
suffix := "_" + hex.EncodeToString(sum[:8])
|
||||
prefixLen := limit - len(suffix)
|
||||
if prefixLen <= 0 {
|
||||
return suffix[len(suffix)-limit:]
|
||||
}
|
||||
return id[:prefixLen] + suffix
|
||||
}
|
||||
|
||||
func isClaudeWebSearchToolType(toolType string) bool {
|
||||
return toolType == "web_search_20250305" || toolType == "web_search_20260209"
|
||||
}
|
||||
|
||||
@@ -42,6 +42,18 @@ func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
|
||||
wantHasDeveloper: true,
|
||||
wantTexts: []string{"Be helpful"},
|
||||
},
|
||||
{
|
||||
name: "System role in messages",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{"role": "system", "content": "Follow the project instructions"},
|
||||
{"role": "user", "content": "hello"}
|
||||
]
|
||||
}`,
|
||||
wantHasDeveloper: true,
|
||||
wantTexts: []string{"Follow the project instructions"},
|
||||
},
|
||||
{
|
||||
name: "Array system field with filtered billing header",
|
||||
inputJSON: `{
|
||||
@@ -136,6 +148,56 @@ func TestConvertClaudeRequestToCodex_ParallelToolCalls(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToCodex_ShortenLongToolUseIDs(t *testing.T) {
|
||||
longID := "toolu_" + strings.Repeat("a", 62)
|
||||
if len(longID) <= 64 {
|
||||
t.Fatalf("test setup error: longID length = %d, want > 64", len(longID))
|
||||
}
|
||||
|
||||
inputJSON := `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type":"text","text":"run pwd"}]},
|
||||
{"role": "assistant", "content": [
|
||||
{"type":"tool_use","id":"` + longID + `","name":"Bash","input":{"cmd":"pwd"}}
|
||||
]},
|
||||
{"role": "user", "content": [
|
||||
{"type":"tool_result","tool_use_id":"` + longID + `","content":"ok"}
|
||||
]}
|
||||
]
|
||||
}`
|
||||
|
||||
result := ConvertClaudeRequestToCodex("test-model", []byte(inputJSON), false)
|
||||
inputs := gjson.GetBytes(result, "input").Array()
|
||||
|
||||
var callID string
|
||||
var outputCallID string
|
||||
for _, item := range inputs {
|
||||
switch item.Get("type").String() {
|
||||
case "function_call":
|
||||
callID = item.Get("call_id").String()
|
||||
case "function_call_output":
|
||||
outputCallID = item.Get("call_id").String()
|
||||
}
|
||||
}
|
||||
|
||||
if callID == "" {
|
||||
t.Fatalf("missing function_call item. Output: %s", string(result))
|
||||
}
|
||||
if outputCallID == "" {
|
||||
t.Fatalf("missing function_call_output item. Output: %s", string(result))
|
||||
}
|
||||
if callID != outputCallID {
|
||||
t.Fatalf("call_id mismatch: function_call=%q function_call_output=%q. Output: %s", callID, outputCallID, string(result))
|
||||
}
|
||||
if len(callID) > 64 {
|
||||
t.Fatalf("call_id length = %d, want <= 64: %q", len(callID), callID)
|
||||
}
|
||||
if callID == longID {
|
||||
t.Fatalf("long call_id was not shortened: %q", callID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToCodex_ToolChoiceModeMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -140,7 +140,7 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa
|
||||
params.HasReceivedArgumentsDelta = false
|
||||
template = []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
|
||||
template, _ = sjson.SetBytes(template, "index", params.BlockIndex)
|
||||
template, _ = sjson.SetBytes(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
|
||||
template, _ = sjson.SetBytes(template, "content_block.id", shortenCodexCallIDIfNeeded(util.SanitizeClaudeToolID(itemResult.Get("call_id").String())))
|
||||
{
|
||||
name := itemResult.Get("name").String()
|
||||
rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
|
||||
@@ -350,7 +350,7 @@ func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, original
|
||||
}
|
||||
|
||||
toolBlock := []byte(`{"type":"tool_use","id":"","name":"","input":{}}`)
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "id", shortenCodexCallIDIfNeeded(util.SanitizeClaudeToolID(item.Get("call_id").String())))
|
||||
toolBlock, _ = sjson.SetBytes(toolBlock, "name", name)
|
||||
inputRaw := "{}"
|
||||
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
|
||||
|
||||
@@ -459,6 +459,70 @@ func TestConvertCodexResponseToClaude_StreamEmptyOutputUsesOutputItemDoneMessage
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToClaude_ShortensLongToolUseIDs(t *testing.T) {
|
||||
longCallID := "call_" + strings.Repeat("a", 62)
|
||||
if len(longCallID) <= 64 {
|
||||
t.Fatalf("test setup error: longCallID length = %d, want > 64", len(longCallID))
|
||||
}
|
||||
|
||||
t.Run("stream", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
originalRequest := []byte(`{"tools":[{"name":"lookup","input_schema":{"type":"object","properties":{}}}]}`)
|
||||
var param any
|
||||
|
||||
outputs := ConvertCodexResponseToClaude(ctx, "", originalRequest, nil, []byte(`data: {"type":"response.output_item.added","item":{"type":"function_call","call_id":"`+longCallID+`","name":"lookup"}}`), &param)
|
||||
|
||||
toolID := ""
|
||||
for _, out := range outputs {
|
||||
for _, line := range strings.Split(string(out), "\n") {
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
data := gjson.Parse(strings.TrimPrefix(line, "data: "))
|
||||
if data.Get("type").String() == "content_block_start" && data.Get("content_block.type").String() == "tool_use" {
|
||||
toolID = data.Get("content_block.id").String()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if toolID == "" {
|
||||
t.Fatalf("missing stream tool_use block. Outputs=%q", outputs)
|
||||
}
|
||||
if len(toolID) > 64 {
|
||||
t.Fatalf("stream tool_use id length = %d, want <= 64: %q", len(toolID), toolID)
|
||||
}
|
||||
if toolID == longCallID {
|
||||
t.Fatalf("stream tool_use id was not shortened: %q", toolID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nonstream", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
originalRequest := []byte(`{"tools":[{"name":"lookup","input_schema":{"type":"object","properties":{}}}]}`)
|
||||
response := []byte(`{
|
||||
"type":"response.completed",
|
||||
"response":{
|
||||
"id":"resp_1",
|
||||
"model":"gpt-5",
|
||||
"usage":{"input_tokens":1,"output_tokens":1},
|
||||
"output":[{"type":"function_call","call_id":"` + longCallID + `","name":"lookup","arguments":"{}"}]
|
||||
}
|
||||
}`)
|
||||
|
||||
out := ConvertCodexResponseToClaudeNonStream(ctx, "", originalRequest, nil, response, nil)
|
||||
toolID := gjson.GetBytes(out, "content.0.id").String()
|
||||
if toolID == "" {
|
||||
t.Fatalf("missing nonstream tool_use id. Output: %s", string(out))
|
||||
}
|
||||
if len(toolID) > 64 {
|
||||
t.Fatalf("nonstream tool_use id length = %d, want <= 64: %q", len(toolID), toolID)
|
||||
}
|
||||
if toolID == longCallID {
|
||||
t.Fatalf("nonstream tool_use id was not shortened: %q", toolID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConvertCodexResponseToClaude_StreamStopReasonMapping(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -49,6 +49,9 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
if systemPromptResult.Get("type").String() == "text" {
|
||||
textResult := systemPromptResult.Get("text")
|
||||
if textResult.Type == gjson.String {
|
||||
if util.IsClaudeCodeAttributionSystemText(textResult.String()) {
|
||||
return true
|
||||
}
|
||||
part := []byte(`{"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", textResult.String())
|
||||
systemInstruction, _ = sjson.SetRawBytes(systemInstruction, "parts.-1", part)
|
||||
@@ -60,7 +63,7 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
if hasSystemParts {
|
||||
out, _ = sjson.SetRawBytes(out, "request.systemInstruction", systemInstruction)
|
||||
}
|
||||
- } else if systemResult.Type == gjson.String {
+ } else if systemResult.Type == gjson.String && !util.IsClaudeCodeAttributionSystemText(systemResult.String()) {
|
||||
out, _ = sjson.SetBytes(out, "request.systemInstruction.parts.-1.text", systemResult.String())
|
||||
}
|
||||
|
||||
|
||||
@@ -40,3 +40,24 @@ func TestConvertClaudeRequestToCLI_ToolChoice_SpecificTool(t *testing.T) {
|
||||
t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.GetBytes(output, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToCLI_StripsClaudeCodeAttribution(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"system": [
|
||||
{"type": "text", "text": "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;"},
|
||||
{"type": "text", "text": "User system prompt"}
|
||||
],
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToCLI("gemini-3-flash-preview", inputJSON, false)
|
||||
|
||||
parts := gjson.GetBytes(output, "request.systemInstruction.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 system part after attribution strip, got %d: %s", len(parts), gjson.GetBytes(output, "request.systemInstruction.parts").Raw)
|
||||
}
|
||||
if got := parts[0].Get("text").String(); got != "User system prompt" {
|
||||
t.Fatalf("Unexpected system part: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +43,9 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
if systemPromptResult.Get("type").String() == "text" {
|
||||
textResult := systemPromptResult.Get("text")
|
||||
if textResult.Type == gjson.String {
|
||||
if util.IsClaudeCodeAttributionSystemText(textResult.String()) {
|
||||
return true
|
||||
}
|
||||
part := []byte(`{"text":""}`)
|
||||
part, _ = sjson.SetBytes(part, "text", textResult.String())
|
||||
systemInstruction, _ = sjson.SetRawBytes(systemInstruction, "parts.-1", part)
|
||||
@@ -54,7 +57,7 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
if hasSystemParts {
|
||||
out, _ = sjson.SetRawBytes(out, "system_instruction", systemInstruction)
|
||||
}
|
||||
- } else if systemResult.Type == gjson.String {
+ } else if systemResult.Type == gjson.String && !util.IsClaudeCodeAttributionSystemText(systemResult.String()) {
|
||||
out, _ = sjson.SetBytes(out, "system_instruction.parts.-1.text", systemResult.String())
|
||||
}
|
||||
|
||||
@@ -78,8 +81,12 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
|
||||
switch contentResult.Get("type").String() {
|
||||
case "text":
|
||||
text := contentResult.Get("text").String()
|
||||
if text == "" {
|
||||
return true
|
||||
}
|
||||
part := []byte(`{"text":""}`)
|
||||
- part, _ = sjson.SetBytes(part, "text", contentResult.Get("text").String())
+ part, _ = sjson.SetBytes(part, "text", text)
|
||||
contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part)
|
||||
|
||||
case "tool_use":
|
||||
|
||||
@@ -78,3 +78,57 @@ func TestConvertClaudeRequestToGemini_ImageContent(t *testing.T) {
|
||||
t.Fatalf("Expected image data 'aGVsbG8=', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToGemini_StripsClaudeCodeAttribution(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"system": [
|
||||
{"type": "text", "text": "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;"},
|
||||
{"type": "text", "text": "You are a Claude agent, built on Anthropic's Claude Agent SDK."},
|
||||
{"type": "text", "text": "User system prompt"}
|
||||
],
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToGemini("gemini-3-flash-preview", inputJSON, false)
|
||||
|
||||
parts := gjson.GetBytes(output, "system_instruction.parts").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("Expected 2 system parts after attribution strip, got %d: %s", len(parts), gjson.GetBytes(output, "system_instruction.parts").Raw)
|
||||
}
|
||||
if got := parts[0].Get("text").String(); got != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
|
||||
t.Fatalf("Unexpected first system part: %q", got)
|
||||
}
|
||||
if got := parts[1].Get("text").String(); got != "User system prompt" {
|
||||
t.Fatalf("Unexpected second system part: %q", got)
|
||||
}
|
||||
if gjson.GetBytes(output, `system_instruction.parts.#(text%"x-anthropic-billing-header:*")`).Exists() {
|
||||
t.Fatalf("Claude Code attribution block was forwarded: %s", gjson.GetBytes(output, "system_instruction.parts").Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToGemini_SkipsEmptyTextParts(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": ""},
|
||||
{"type": "text", "text": "hello"},
|
||||
{"type": "text", "text": ""}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToGemini("gemini-3-flash-preview", inputJSON, false)
|
||||
|
||||
parts := gjson.GetBytes(output, "contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part after skipping empty text, got %d: %s", len(parts), output)
|
||||
}
|
||||
if got := parts[0].Get("text").String(); got != "hello" {
|
||||
t.Fatalf("Expected part text 'hello', got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -103,7 +104,7 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
hasSystemContent := false
|
||||
if system := root.Get("system"); system.Exists() {
|
||||
if system.Type == gjson.String {
|
||||
- if system.String() != "" {
+ if system.String() != "" && !util.IsClaudeCodeAttributionSystemText(system.String()) {
|
||||
oldSystem := []byte(`{"type":"text","text":""}`)
|
||||
oldSystem, _ = sjson.SetBytes(oldSystem, "text", system.String())
|
||||
systemMsgJSON, _ = sjson.SetRawBytes(systemMsgJSON, "content.-1", oldSystem)
|
||||
@@ -334,7 +335,7 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
|
||||
switch partType {
|
||||
case "text":
|
||||
text := part.Get("text").String()
|
||||
- if strings.TrimSpace(text) == "" {
+ if strings.TrimSpace(text) == "" || util.IsClaudeCodeAttributionSystemText(text) {
|
||||
return "", false
|
||||
}
|
||||
textContent := []byte(`{"type":"text","text":""}`)
|
||||
|
||||
@@ -696,3 +696,28 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t
|
||||
t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToOpenAI_StripsClaudeCodeAttribution(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"system": [
|
||||
{"type": "text", "text": "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;"},
|
||||
{"type": "text", "text": "User system prompt"}
|
||||
],
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToOpenAI("gpt-5", inputJSON, false)
|
||||
messages := gjson.GetBytes(output, "messages").Array()
|
||||
if len(messages) == 0 || messages[0].Get("role").String() != "system" {
|
||||
t.Fatalf("Expected first message to be system, got: %s", gjson.GetBytes(output, "messages").Raw)
|
||||
}
|
||||
|
||||
content := messages[0].Get("content").Array()
|
||||
if len(content) != 1 {
|
||||
t.Fatalf("Expected 1 system content item after attribution strip, got %d: %s", len(content), messages[0].Get("content").Raw)
|
||||
}
|
||||
if got := content[0].Get("text").String(); got != "User system prompt" {
|
||||
t.Fatalf("Unexpected system content: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
translatorcommon "github.com/router-for-me/CLIProxyAPI/v7/internal/translator/common"
|
||||
@@ -26,6 +27,9 @@ type ConvertOpenAIResponseToAnthropicParams struct {
|
||||
Model string
|
||||
CreatedAt int64
|
||||
ToolNameMap map[string]string
|
||||
// SawToolCall is true once at least one tool_use content_block_start has
|
||||
// been emitted on the wire. Using raw upstream tool_calls presence here
|
||||
// can produce stop_reason=tool_use with zero announced tool blocks.
|
||||
SawToolCall bool
|
||||
// Content accumulator for streaming
|
||||
ContentAccumulator strings.Builder
|
||||
@@ -60,6 +64,9 @@ type ToolCallAccumulator struct {
|
||||
ID string
|
||||
Name string
|
||||
Arguments strings.Builder
|
||||
// StartEmitted tracks whether content_block_start has already been sent
|
||||
// for this tool index.
|
||||
StartEmitted bool
|
||||
}
|
||||
|
||||
// ConvertOpenAIResponseToClaude converts OpenAI streaming response format to Anthropic API format.
|
||||
@@ -218,9 +225,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
}
|
||||
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
param.SawToolCall = true
|
||||
index := int(toolCall.Get("index").Int())
|
||||
blockIndex := param.toolContentBlockIndex(index)
|
||||
|
||||
// Initialize accumulator if needed
|
||||
if _, exists := param.ToolCallsAccumulator[index]; !exists {
|
||||
@@ -229,27 +234,25 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
|
||||
accumulator := param.ToolCallsAccumulator[index]
|
||||
|
||||
- // Handle tool call ID
- if id := toolCall.Get("id"); id.Exists() {
- accumulator.ID = id.String()
+ // Handle tool call ID. Only accept JSON-string, non-empty
+ // values so malformed upstream fields do not overwrite a
+ // valid ID or coerce into a content_block.id.
+ if id := toolCall.Get("id"); id.Exists() && id.Type == gjson.String {
+ if idStr := id.String(); idStr != "" {
+ accumulator.ID = idStr
+ }
|
||||
}
|
||||
|
||||
- // Handle function name
+ // Handle function name and arguments
|
||||
if function := toolCall.Get("function"); function.Exists() {
|
||||
- if name := function.Get("name"); name.Exists() && name.String() != "" {
- accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
-
- stopThinkingContentBlock(param, &results)
-
- stopTextContentBlock(param, &results)
-
- // Send content_block_start for tool_use
- contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
- contentBlockStartJSONBytes := []byte(contentBlockStartJSON)
- contentBlockStartJSONBytes, _ = sjson.SetBytes(contentBlockStartJSONBytes, "index", blockIndex)
- contentBlockStartJSONBytes, _ = sjson.SetBytes(contentBlockStartJSONBytes, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
- contentBlockStartJSONBytes, _ = sjson.SetBytes(contentBlockStartJSONBytes, "content_block.name", accumulator.Name)
- results = append(results, translatorcommon.AppendSSEEventBytes(nil, "content_block_start", contentBlockStartJSONBytes, 2))
+ // Only record the name until content_block_start has been
+ // emitted. Some upstreams send "name": "" or repeat the
+ // field across chunks; reassigning after start could drift
+ // from what was already announced.
+ if !accumulator.StartEmitted {
+ if name := function.Get("name"); name.Exists() && name.Type == gjson.String && name.String() != "" {
+ accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Handle function arguments
|
||||
@@ -261,6 +264,13 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
}
|
||||
}
|
||||
|
||||
// Re-check on every chunk, not only chunks with a function
|
||||
// object. Some upstreams split function.name and id across
|
||||
// separate deltas.
|
||||
if !accumulator.StartEmitted && accumulator.Name != "" && accumulator.ID != "" && !param.ContentBlocksStopped {
|
||||
emitToolUseStart(param, index, accumulator, &results)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
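The deferred-start rule in the hunk above (announce a tool_use block only once both a name and an id are known, and freeze the name afterwards) can be sketched in isolation. This is a minimal illustration, not the project's code: `toolCall`, `delta`, and `apply` are hypothetical simplifications of `ToolCallAccumulator` and the per-chunk handling, and the belated synthetic-id path at finish time is left out.

```go
package main

import "fmt"

// toolCall is a hypothetical, simplified stand-in for ToolCallAccumulator.
type toolCall struct {
	ID, Name     string
	StartEmitted bool
}

// delta is a hypothetical view of one streamed tool_call fragment;
// empty strings mean "field absent or unusable in this chunk".
type delta struct {
	ID, Name string
}

// apply merges a fragment and reports whether content_block_start
// should be emitted for this fragment.
func (tc *toolCall) apply(d delta) bool {
	if d.ID != "" {
		tc.ID = d.ID
	}
	// The name is frozen once the block has been announced.
	if !tc.StartEmitted && d.Name != "" {
		tc.Name = d.Name
	}
	// Re-check on every fragment, since name and id may arrive separately.
	if !tc.StartEmitted && tc.Name != "" && tc.ID != "" {
		tc.StartEmitted = true
		return true
	}
	return false
}

func main() {
	tc := &toolCall{}
	fragments := []delta{{Name: "do_it"}, {ID: "call_real"}, {Name: "other"}}
	for i, d := range fragments {
		fmt.Printf("fragment %d: emit=%v\n", i, tc.apply(d))
	}
	fmt.Printf("announced id=%q name=%q\n", tc.ID, tc.Name)
}
```

Only the second fragment triggers a start, and the later `"other"` name is ignored, which matches what `TestStreamingTool_IDInDeltaWithoutFunction` and `TestStreamingTool_RepeatedName` assert further down.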
|
||||
@@ -269,9 +279,12 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
// Handle finish_reason (but don't send message_delta/message_stop yet)
|
||||
if finishReason := root.Get("choices.0.finish_reason"); finishReason.Exists() && finishReason.String() != "" {
|
||||
reason := finishReason.String()
|
||||
- if param.SawToolCall {
+ switch {
+ case param.SawToolCall:
param.FinishReason = "tool_calls"
- } else {
+ case reason == "tool_calls":
+ param.FinishReason = "stop"
+ default:
|
||||
param.FinishReason = reason
|
||||
}
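The three-way decision in this hunk can be read as a small pure function. The sketch below assumes nothing beyond what the hunk shows; `resolveFinishReason` is a hypothetical name used only for illustration, and the later translation of the stored value into Claude's `stop_reason` is not modelled here.

```go
package main

import "fmt"

// resolveFinishReason is a hypothetical helper mirroring the switch above.
// sawToolCall means at least one tool_use content_block_start was emitted.
func resolveFinishReason(sawToolCall bool, upstream string) string {
	switch {
	case sawToolCall:
		// A tool block was announced, so the turn ends as a tool call.
		return "tool_calls"
	case upstream == "tool_calls":
		// Upstream claimed tool calls but none were announced: downgrade.
		return "stop"
	default:
		return upstream
	}
}

func main() {
	fmt.Println(resolveFinishReason(true, "stop"))
	fmt.Println(resolveFinishReason(false, "tool_calls"))
	fmt.Println(resolveFinishReason(false, "length"))
}
```

The middle case is the one `TestStreamingTool_EmptyNameThroughout` exercises: every tool start was suppressed, so the client must not see a tool-use stop reason.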
|
||||
|
||||
@@ -289,8 +302,17 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
|
||||
// Send content_block_stop for any tool calls
|
||||
if !param.ContentBlocksStopped {
|
||||
- for index := range param.ToolCallsAccumulator {
+ for _, index := range toolCallAccumulatorIndexes(param.ToolCallsAccumulator) {
|
||||
accumulator := param.ToolCallsAccumulator[index]
|
||||
if !accumulator.StartEmitted {
|
||||
// Belated emit for streams that supplied a valid name but
|
||||
// never sent an id. SanitizeClaudeToolID("") produces the
|
||||
// expected stable synthetic toolu_<nanos>_<n> ID shape.
|
||||
if accumulator.Name == "" {
|
||||
continue
|
||||
}
|
||||
emitToolUseStart(param, index, accumulator, &results)
|
||||
}
|
||||
blockIndex := param.toolContentBlockIndex(index)
|
||||
|
||||
// Send complete input_json_delta with all accumulated arguments
|
||||
@@ -353,8 +375,16 @@ func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams)
|
||||
stopTextContentBlock(param, &results)
|
||||
|
||||
if !param.ContentBlocksStopped {
|
||||
- for index := range param.ToolCallsAccumulator {
+ for _, index := range toolCallAccumulatorIndexes(param.ToolCallsAccumulator) {
|
||||
accumulator := param.ToolCallsAccumulator[index]
|
||||
if !accumulator.StartEmitted {
|
||||
// Belated emit at [DONE]; same behavior as the finish_reason
|
||||
// path for name-but-no-id streams.
|
||||
if accumulator.Name == "" {
|
||||
continue
|
||||
}
|
||||
emitToolUseStart(param, index, accumulator, &results)
|
||||
}
|
||||
blockIndex := param.toolContentBlockIndex(index)
|
||||
|
||||
if accumulator.Arguments.Len() > 0 {
|
||||
@@ -547,6 +577,29 @@ func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results
|
||||
param.TextContentBlockIndex = -1
|
||||
}
|
||||
|
||||
func emitToolUseStart(param *ConvertOpenAIResponseToAnthropicParams, openAIToolIndex int, accumulator *ToolCallAccumulator, results *[][]byte) {
|
||||
stopThinkingContentBlock(param, results)
|
||||
stopTextContentBlock(param, results)
|
||||
|
||||
blockIndex := param.toolContentBlockIndex(openAIToolIndex)
|
||||
contentBlockStartJSON := []byte(`{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`)
|
||||
contentBlockStartJSON, _ = sjson.SetBytes(contentBlockStartJSON, "index", blockIndex)
|
||||
contentBlockStartJSON, _ = sjson.SetBytes(contentBlockStartJSON, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
|
||||
contentBlockStartJSON, _ = sjson.SetBytes(contentBlockStartJSON, "content_block.name", accumulator.Name)
|
||||
*results = append(*results, translatorcommon.AppendSSEEventBytes(nil, "content_block_start", contentBlockStartJSON, 2))
|
||||
accumulator.StartEmitted = true
|
||||
param.SawToolCall = true
|
||||
}
|
||||
|
||||
func toolCallAccumulatorIndexes(accumulators map[int]*ToolCallAccumulator) []int {
|
||||
indexes := make([]int, 0, len(accumulators))
|
||||
for index := range accumulators {
|
||||
indexes = append(indexes, index)
|
||||
}
|
||||
sort.Ints(indexes)
|
||||
return indexes
|
||||
}
|
||||
|
||||
// ConvertOpenAIResponseToClaudeNonStream converts a non-streaming OpenAI response to a non-streaming Anthropic response.
|
||||
//
|
||||
// Parameters:
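`toolCallAccumulatorIndexes` exists because Go does not define the iteration order of a map, so ranging over `ToolCallsAccumulator` directly could emit belated tool starts in a different order on each run. A minimal sketch of the same pattern on a plain map follows; `sortedKeys` and the sample data are illustrative assumptions, not project code.

```go
package main

import (
	"fmt"
	"sort"
)

// sortedKeys returns the keys of m in ascending order so callers can
// iterate deterministically.
func sortedKeys(m map[int]string) []int {
	keys := make([]int, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Ints(keys)
	return keys
}

func main() {
	// Tool names keyed by OpenAI tool index, inserted out of order.
	tools := map[int]string{2: "third_tool", 0: "first_tool", 1: "second_tool"}
	for _, k := range sortedKeys(tools) {
		fmt.Println(k, tools[k])
	}
}
```

`TestStreamingTool_BelatedStartsUseOpenAIToolIndexOrder` relies on this ordering when it expects `first_tool`, `second_tool`, `third_tool` at block indexes 0, 1, 2.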
|
||||
|
||||
@@ -3,11 +3,108 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
type sseEvent struct {
|
||||
Type string
|
||||
Payload string
|
||||
}
|
||||
|
||||
func runStream(t *testing.T, originalReq string, chunks ...string) []sseEvent {
|
||||
t.Helper()
|
||||
|
||||
var paramAny any
|
||||
var emitted [][]byte
|
||||
for _, chunk := range chunks {
|
||||
emitted = append(emitted, ConvertOpenAIResponseToClaude(
|
||||
context.Background(),
|
||||
"",
|
||||
[]byte(originalReq),
|
||||
nil,
|
||||
[]byte("data: "+chunk),
|
||||
&paramAny,
|
||||
)...)
|
||||
}
|
||||
emitted = append(emitted, ConvertOpenAIResponseToClaude(
|
||||
context.Background(),
|
||||
"",
|
||||
[]byte(originalReq),
|
||||
nil,
|
||||
[]byte("data: [DONE]"),
|
||||
&paramAny,
|
||||
)...)
|
||||
|
||||
var events []sseEvent
|
||||
for _, raw := range emitted {
|
||||
s := string(raw)
|
||||
if !strings.HasPrefix(s, "event: ") {
|
||||
continue
|
||||
}
|
||||
nl := strings.Index(s, "\n")
|
||||
if nl < 0 {
|
||||
continue
|
||||
}
|
||||
typ := strings.TrimPrefix(s[:nl], "event: ")
|
||||
rest := s[nl+1:]
|
||||
if !strings.HasPrefix(rest, "data: ") {
|
||||
continue
|
||||
}
|
||||
payload := strings.TrimRight(strings.TrimPrefix(rest, "data: "), "\n")
|
||||
events = append(events, sseEvent{Type: typ, Payload: payload})
|
||||
}
|
||||
return events
|
||||
}
|
||||
|
||||
func countByType(events []sseEvent, typ string) int {
|
||||
n := 0
|
||||
for _, e := range events {
|
||||
if e.Type == typ {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func toolUseStarts(events []sseEvent) []sseEvent {
|
||||
var out []sseEvent
|
||||
for _, e := range events {
|
||||
if e.Type != "content_block_start" {
|
||||
continue
|
||||
}
|
||||
if gjson.Get(e.Payload, "content_block.type").String() == "tool_use" {
|
||||
out = append(out, e)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func blockIndices(events []sseEvent) []int64 {
|
||||
var idx []int64
|
||||
for _, e := range events {
|
||||
if e.Type == "content_block_start" {
|
||||
idx = append(idx, gjson.Get(e.Payload, "index").Int())
|
||||
}
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
func lastStopReason(events []sseEvent) string {
|
||||
for i := len(events) - 1; i >= 0; i-- {
|
||||
if events[i].Type == "message_delta" {
|
||||
return gjson.Get(events[i].Payload, "delta.stop_reason").String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
const streamReq = `{"stream":true}`
|
||||
|
||||
func TestConvertOpenAIResponseToClaude_StreamIgnoresNullToolNameDelta(t *testing.T) {
|
||||
- originalRequest := []byte(`{"stream":true}`)
+ originalRequest := []byte(streamReq)
|
||||
var param any
|
||||
|
||||
firstChunks := ConvertOpenAIResponseToClaude(
|
||||
@@ -39,3 +136,231 @@ func TestConvertOpenAIResponseToClaude_StreamIgnoresNullToolNameDelta(t *testing
|
||||
t.Fatalf("did not expect null tool name delta to emit an empty tool name, got %s", string(secondOutput))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_EmptyNameThroughout(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_a","function":{"name":"","arguments":""}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":"","arguments":"{\"x\":1}"}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
)
|
||||
|
||||
if got := len(toolUseStarts(events)); got != 0 {
|
||||
t.Fatalf("expected zero tool_use content_block_start, got %d (events=%+v)", got, events)
|
||||
}
|
||||
if got := countByType(events, "content_block_delta"); got != 0 {
|
||||
t.Fatalf("expected zero content_block_delta when start was suppressed, got %d", got)
|
||||
}
|
||||
if got := countByType(events, "content_block_stop"); got != 0 {
|
||||
t.Fatalf("expected zero content_block_stop when start was suppressed, got %d", got)
|
||||
}
|
||||
if got := lastStopReason(events); got == "tool_use" {
|
||||
t.Fatalf("stop_reason must not be tool_use when zero tool_use blocks were emitted; got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_NullName(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_a","function":{"name":null,"arguments":""}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
)
|
||||
if got := len(toolUseStarts(events)); got != 0 {
|
||||
t.Fatalf("null name must not produce a tool_use start; got %d", got)
|
||||
}
|
||||
if got := countByType(events, "content_block_stop"); got != 0 {
|
||||
t.Fatalf("null name must not produce content_block_stop; got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_NonStringName(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_a","function":{"name":123,"arguments":""}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
)
|
||||
if got := len(toolUseStarts(events)); got != 0 {
|
||||
t.Fatalf("non-string name must not produce a tool_use start; got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_RepeatedName(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_a","function":{"name":"do_it","arguments":""}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":"do_it","arguments":"{\"x\""}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":"do_it","arguments":":1}"}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
)
|
||||
|
||||
starts := toolUseStarts(events)
|
||||
if len(starts) != 1 {
|
||||
t.Fatalf("expected exactly one tool_use start, got %d", len(starts))
|
||||
}
|
||||
if name := gjson.Get(starts[0].Payload, "content_block.name").String(); name != "do_it" {
|
||||
t.Fatalf("announced tool name = %q, want %q", name, "do_it")
|
||||
}
|
||||
if got := countByType(events, "content_block_stop"); got != 1 {
|
||||
t.Fatalf("expected exactly one content_block_stop, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_MixedSuppressedAndValid(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[
|
||||
{"index":0,"id":"call_skip","function":{"name":"","arguments":""}},
|
||||
{"index":1,"id":"call_real","function":{"name":"do_it","arguments":""}}
|
||||
]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[
|
||||
{"index":1,"function":{"arguments":"{}"}}
|
||||
]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
)
|
||||
|
||||
starts := toolUseStarts(events)
|
||||
if len(starts) != 1 {
|
||||
t.Fatalf("expected exactly one tool_use start, got %d", len(starts))
|
||||
}
|
||||
if got := countByType(events, "content_block_stop"); got != 1 {
|
||||
t.Fatalf("expected exactly one content_block_stop, got %d", got)
|
||||
}
|
||||
|
||||
indices := blockIndices(events)
|
||||
if len(indices) == 0 || indices[0] != 0 {
|
||||
t.Fatalf("first content_block_start index must be 0, got %v", indices)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_EmptyIDDeferStart(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"","function":{"name":"do_it","arguments":""}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_real","function":{"arguments":"{}"}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
)
|
||||
|
||||
starts := toolUseStarts(events)
|
||||
if len(starts) != 1 {
|
||||
t.Fatalf("expected exactly one tool_use start once id arrived, got %d", len(starts))
|
||||
}
|
||||
if id := gjson.Get(starts[0].Payload, "content_block.id").String(); id != "call_real" {
|
||||
t.Fatalf("announced tool id = %q, want %q", id, "call_real")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_IDInDeltaWithoutFunction(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"function":{"name":"do_it"}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_real"}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{}"}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
)
|
||||
|
||||
starts := toolUseStarts(events)
|
||||
if len(starts) != 1 {
|
||||
t.Fatalf("expected exactly one tool_use start when id arrives in a function-less delta, got %d", len(starts))
|
||||
}
|
||||
if id := gjson.Get(starts[0].Payload, "content_block.id").String(); id != "call_real" {
|
||||
t.Fatalf("announced tool id = %q, want %q", id, "call_real")
|
||||
}
|
||||
if name := gjson.Get(starts[0].Payload, "content_block.name").String(); name != "do_it" {
|
||||
t.Fatalf("announced tool name = %q, want %q", name, "do_it")
|
||||
}
|
||||
if got := countByType(events, "content_block_stop"); got != 1 {
|
||||
t.Fatalf("expected exactly one content_block_stop, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_StopReasonWithEmittedTool(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_a","function":{"name":"do_it","arguments":"{}"}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`,
|
||||
)
|
||||
if got := lastStopReason(events); got != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want %q", got, "tool_use")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_StopReasonWhenIDNeverArrives(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"function":{"name":"do_it","arguments":""}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{}"}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
)
|
||||
|
||||
starts := toolUseStarts(events)
|
||||
if len(starts) != 1 {
|
||||
t.Fatalf("expected one belated tool_use start with synthetic id, got %d", len(starts))
|
||||
}
|
||||
id := gjson.Get(starts[0].Payload, "content_block.id").String()
|
||||
if !strings.HasPrefix(id, "toolu_") {
|
||||
t.Fatalf("synthetic id should match toolu_<nanos>_<n>, got %q", id)
|
||||
}
|
||||
if name := gjson.Get(starts[0].Payload, "content_block.name").String(); name != "do_it" {
|
||||
t.Fatalf("announced tool name = %q, want %q", name, "do_it")
|
||||
}
|
||||
if got := lastStopReason(events); got != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want %q", got, "tool_use")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_BelatedStartsUseOpenAIToolIndexOrder(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[
|
||||
{"index":2,"function":{"name":"third_tool","arguments":"{}"}},
|
||||
{"index":0,"function":{"name":"first_tool","arguments":"{}"}},
|
||||
{"index":1,"function":{"name":"second_tool","arguments":"{}"}}
|
||||
]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
)
|
||||
|
||||
starts := toolUseStarts(events)
|
||||
if len(starts) != 3 {
|
||||
t.Fatalf("expected three belated tool_use starts, got %d", len(starts))
|
||||
}
|
||||
|
||||
wantNames := []string{"first_tool", "second_tool", "third_tool"}
|
||||
for i, wantName := range wantNames {
|
||||
if name := gjson.Get(starts[i].Payload, "content_block.name").String(); name != wantName {
|
||||
t.Fatalf("tool_use start %d name = %q, want %q (starts=%+v)", i, name, wantName, starts)
|
||||
}
|
||||
if blockIndex := gjson.Get(starts[i].Payload, "index").Int(); blockIndex != int64(i) {
|
||||
t.Fatalf("tool_use start %d block index = %d, want %d", i, blockIndex, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_LateIDAfterFinalization(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"function":{"name":"do_it"}}]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":1,"completion_tokens":1}}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_late"}]}}]}`,
|
||||
)
|
||||
|
||||
starts := toolUseStarts(events)
|
||||
if len(starts) != 1 {
|
||||
t.Fatalf("expected one belated tool_use start, got %d", len(starts))
|
||||
}
|
||||
|
||||
var sawMessageStop bool
|
||||
for _, e := range events {
|
||||
if e.Type == "message_stop" {
|
||||
sawMessageStop = true
|
||||
continue
|
||||
}
|
||||
if sawMessageStop {
|
||||
switch e.Type {
|
||||
case "content_block_start", "content_block_delta", "content_block_stop":
|
||||
t.Fatalf("event %q emitted after message_stop (events=%+v)", e.Type, events)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamingTool_StopReasonMixedSuppressedAndValid(t *testing.T) {
|
||||
events := runStream(t, streamReq,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[
|
||||
{"index":0,"id":"call_skip","function":{"name":"","arguments":""}},
|
||||
{"index":1,"id":"call_real","function":{"name":"do_it","arguments":"{}"}}
|
||||
]}}]}`,
|
||||
`{"id":"c1","model":"m","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`,
|
||||
)
|
||||
if got := lastStopReason(events); got != "tool_use" {
|
||||
t.Fatalf("stop_reason = %q, want %q", got, "tool_use")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
const claudeCodeAttributionSystemPrefix = "x-anthropic-billing-header:"
|
||||
|
||||
// IsClaudeCodeAttributionSystemText reports whether text is the Claude Code
|
||||
// attribution block that carries per-request billing and prompt fingerprint data.
|
||||
func IsClaudeCodeAttributionSystemText(text string) bool {
|
||||
text = strings.TrimLeftFunc(text, unicode.IsSpace)
|
||||
return strings.HasPrefix(text, claudeCodeAttributionSystemPrefix)
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsClaudeCodeAttributionSystemText(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
text string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "Claude Code attribution block",
|
||||
text: "x-anthropic-billing-header: cc_version=2.1.63.abc; cc_entrypoint=cli; cch=12345;",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "leading whitespace",
|
||||
text: "\n\t x-anthropic-billing-header: cc_version=2.1.63.abc; cch=12345;",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "regular system prompt",
|
||||
text: "You are helpful.",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty text",
|
||||
text: "",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsClaudeCodeAttributionSystemText(tt.text); got != tt.want {
|
||||
t.Fatalf("IsClaudeCodeAttributionSystemText(%q) = %v, want %v", tt.text, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
@@ -20,7 +21,7 @@ func ComputeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
- out(strings.ToLower(name) + "|" + strings.ToLower(alias))
+ out(strings.ToLower(name) + "|" + strings.ToLower(alias) + "|" + fmt.Sprintf("image=%t", model.Image))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
|
||||
@@ -25,6 +25,17 @@ func TestComputeOpenAICompatModelsHash_Deterministic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_IncludesImageFlag(t *testing.T) {
|
||||
textModel := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: "gpt-image", Alias: "image"}})
|
||||
imageModel := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: "gpt-image", Alias: "image", Image: true}})
|
||||
if textModel == "" || imageModel == "" {
|
||||
t.Fatal("hashes should not be empty")
|
||||
}
|
||||
if textModel == imageModel {
|
||||
t.Fatal("hash should change when image flag changes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_NormalizesAndDedups(t *testing.T) {
|
||||
a := []config.OpenAICompatibilityModel{
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
|
||||
@@ -153,7 +153,7 @@ func openAICompatSignature(entry config.OpenAICompatibility) string {
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
- models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias))
+ models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias)+"|"+fmt.Sprintf("image=%t", model.Image))
|
||||
}
|
||||
if len(models) > 0 {
|
||||
sort.Strings(models)
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v7/internal/constant"
|
||||
@@ -257,6 +259,15 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
if errMsg, okPendingErr := pendingClaudeStreamError(errChan); okPendingErr {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Stream closed without data? Send DONE or just headers.
|
||||
setSSEHeaders()
|
||||
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
|
||||
@@ -282,6 +293,21 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
}
|
||||
}
|
||||
|
||||
func pendingClaudeStreamError(errs <-chan *interfaces.ErrorMessage) (*interfaces.ErrorMessage, bool) {
|
||||
if errs == nil {
|
||||
return nil, false
|
||||
}
|
||||
select {
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return errMsg, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
@@ -317,11 +343,135 @@ type claudeErrorResponse struct {
|
||||
}
|
||||
|
||||
func (h *ClaudeCodeAPIHandler) toClaudeError(msg *interfaces.ErrorMessage) claudeErrorResponse {
|
||||
status := http.StatusInternalServerError
|
||||
errText := http.StatusText(status)
|
||||
if msg != nil {
|
||||
if msg.StatusCode > 0 {
|
||||
status = msg.StatusCode
|
||||
errText = http.StatusText(status)
|
||||
}
|
||||
if msg.Error != nil {
|
||||
if v := strings.TrimSpace(msg.Error.Error()); v != "" {
|
||||
errText = v
|
||||
}
|
||||
}
|
||||
}
|
||||
errType, message := claudeErrorDetailFromText(status, errText)
|
||||
return claudeErrorResponse{
|
||||
Type: "error",
|
||||
Error: claudeErrorDetail{
|
||||
- Type: "api_error",
- Message: msg.Error.Error(),
+ Type: errType,
+ Message: message,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ClaudeCodeAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
||||
status := http.StatusInternalServerError
|
||||
if msg != nil && msg.StatusCode > 0 {
|
||||
status = msg.StatusCode
|
||||
}
|
||||
if msg != nil && msg.Addon != nil && handlers.PassthroughHeadersEnabled(h.Cfg) {
|
||||
for key, values := range msg.Addon {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
c.Writer.Header().Del(key)
|
||||
for _, value := range values {
|
||||
c.Writer.Header().Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(h.toClaudeError(msg))
|
||||
if err != nil {
|
||||
body = []byte(`{"type":"error","error":{"type":"api_error","message":"Internal Server Error"}}`)
|
||||
}
|
||||
appendClaudeAPIResponse(c, body)
|
||||
if !c.Writer.Written() {
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
c.Status(status)
|
||||
_, _ = c.Writer.Write(body)
|
||||
}
|
||||
|
||||
func claudeErrorDetailFromText(status int, errText string) (string, string) {
|
||||
message := strings.TrimSpace(errText)
|
||||
if message == "" {
|
||||
message = http.StatusText(status)
|
||||
}
|
||||
errType := claudeErrorTypeFromStatus(status)
|
||||
|
||||
var payload map[string]any
|
||||
if json.Valid([]byte(message)) {
|
||||
if err := json.Unmarshal([]byte(message), &payload); err == nil {
|
||||
if e, ok := payload["error"].(map[string]any); ok {
|
||||
if t, ok := e["type"].(string); ok && strings.TrimSpace(t) != "" {
|
||||
errType = strings.TrimSpace(t)
|
||||
}
|
||||
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||
message = strings.TrimSpace(m)
|
||||
} else if c, ok := e["code"].(string); ok && strings.TrimSpace(c) != "" {
|
||||
message = strings.TrimSpace(c)
|
||||
}
|
||||
} else {
|
||||
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) != "" && strings.TrimSpace(t) != "error" {
|
||||
errType = strings.TrimSpace(t)
|
||||
}
|
||||
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||
message = strings.TrimSpace(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errType, message
|
||||
}
|
||||
|
||||
func claudeErrorTypeFromStatus(status int) string {
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
return "authentication_error"
|
||||
case http.StatusPaymentRequired:
|
||||
return "billing_error"
|
||||
case http.StatusForbidden:
|
||||
return "permission_error"
|
||||
case http.StatusNotFound:
|
||||
return "not_found_error"
|
||||
case http.StatusRequestEntityTooLarge:
|
||||
return "request_too_large"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_error"
|
||||
case http.StatusGatewayTimeout:
|
||||
return "timeout_error"
|
||||
case 529:
|
||||
return "overloaded_error"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
return "api_error"
|
||||
}
|
||||
return "invalid_request_error"
|
||||
}
|
||||
}
|
||||
|
||||
func appendClaudeAPIResponse(c *gin.Context, data []byte) {
|
||||
if c == nil || len(data) == 0 {
|
||||
return
|
||||
}
|
||||
if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); !exists {
|
||||
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||
}
|
||||
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||
combined := make([]byte, 0, len(existingBytes)+len(data)+1)
|
||||
combined = append(combined, existingBytes...)
|
||||
if existingBytes[len(existingBytes)-1] != '\n' {
|
||||
combined = append(combined, '\n')
|
||||
}
|
||||
combined = append(combined, data...)
|
||||
c.Set("API_RESPONSE", combined)
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Set("API_RESPONSE", bytes.Clone(data))
|
||||
}
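The accumulation rule in `appendClaudeAPIResponse` (join logged bodies with exactly one newline, copy the first body) can be tried without a gin context. The sketch below is an assumption-laden simplification: `appendLogged` is a hypothetical helper that takes the previous value as a byte slice instead of reading it from `c.Get("API_RESPONSE")`.

```go
package main

import (
	"bytes"
	"fmt"
)

// appendLogged is a hypothetical, context-free version of the accumulation step.
// It joins an existing logged body and a new chunk with exactly one newline.
func appendLogged(existing, data []byte) []byte {
	if len(data) == 0 {
		return existing
	}
	if len(existing) == 0 {
		// First body: store a copy so later writes to data cannot alter the log.
		return bytes.Clone(data)
	}
	combined := make([]byte, 0, len(existing)+len(data)+1)
	combined = append(combined, existing...)
	if existing[len(existing)-1] != '\n' {
		combined = append(combined, '\n')
	}
	return append(combined, data...)
}

func main() {
	out := appendLogged([]byte(`{"a":1}`), []byte(`{"b":2}`))
	fmt.Printf("%q\n", out)
}
```

`bytes.Clone` needs Go 1.20 or newer, which matches its use in the diff.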
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestClaudeErrorExtractsOpenAIStyleUpstreamJSON(t *testing.T) {
|
||||
handler := &ClaudeCodeAPIHandler{}
|
||||
msg := &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: errors.New(`{"error":{"message":"Your input exceeds the context window of this model. Please adjust your input and try again.","type":"invalid_request_error","code":"context_too_large"}}`),
|
||||
}
|
||||
|
||||
got := handler.toClaudeError(msg)
|
||||
|
||||
if got.Type != "error" {
|
||||
t.Fatalf("type = %q, want error", got.Type)
|
||||
}
|
||||
if got.Error.Type != "invalid_request_error" {
|
||||
t.Fatalf("error.type = %q, want invalid_request_error", got.Error.Type)
|
||||
}
|
||||
if got.Error.Message != "Your input exceeds the context window of this model. Please adjust your input and try again." {
|
||||
t.Fatalf("error.message = %q", got.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeErrorExtractsClaudeStyleUpstreamJSON(t *testing.T) {
|
||||
handler := &ClaudeCodeAPIHandler{}
|
||||
msg := &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Error: errors.New(`{"type":"error","error":{"type":"rate_limit_error","message":"This request would exceed your account's rate limit. Please try again later."},"request_id":"req_123"}`),
|
||||
}
|
||||
|
||||
got := handler.toClaudeError(msg)
|
||||
|
||||
if got.Error.Type != "rate_limit_error" {
|
||||
t.Fatalf("error.type = %q, want rate_limit_error", got.Error.Type)
|
||||
}
|
||||
if got.Error.Message != "This request would exceed your account's rate limit. Please try again later." {
|
||||
t.Fatalf("error.message = %q", got.Error.Message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteClaudeErrorResponseUsesClaudeEnvelope(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
handler := &ClaudeCodeAPIHandler{}
|
||||
msg := &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: errors.New(`{"error":{"message":"Your input exceeds the context window of this model. Please adjust your input and try again.","type":"invalid_request_error","code":"context_too_large"}}`),
|
||||
}
|
||||
|
||||
handler.WriteErrorResponse(c, msg)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusBadRequest)
|
||||
}
|
||||
body := recorder.Body.Bytes()
|
||||
if got := gjson.GetBytes(body, "type").String(); got != "error" {
|
||||
t.Fatalf("type = %q, want error; body=%s", got, body)
|
||||
}
|
||||
if got := gjson.GetBytes(body, "error.type").String(); got != "invalid_request_error" {
|
||||
t.Fatalf("error.type = %q, want invalid_request_error; body=%s", got, body)
|
||||
}
|
||||
if got := gjson.GetBytes(body, "error.message").String(); got != "Your input exceeds the context window of this model. Please adjust your input and try again." {
|
||||
t.Fatalf("error.message = %q; body=%s", got, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPendingClaudeStreamErrorUsesBufferedError(t *testing.T) {
|
||||
wantErr := &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: errors.New(`{"error":{"message":"Your input exceeds the context window of this model. Please adjust your input and try again.","type":"invalid_request_error","code":"context_too_large"}}`),
|
||||
}
|
||||
errs := make(chan *interfaces.ErrorMessage, 1)
|
||||
errs <- wantErr
|
||||
close(errs)
|
||||
|
||||
gotErr, ok := pendingClaudeStreamError(errs)
|
||||
if !ok {
|
||||
t.Fatal("expected pending stream error")
|
||||
}
|
||||
if gotErr != wantErr {
|
||||
t.Fatalf("pending error = %p, want %p", gotErr, wantErr)
|
||||
}
|
||||
}
|
||||
@@ -231,6 +231,17 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
return meta
|
||||
}
|
||||
|
||||
func setReasoningEffortMetadata(meta map[string]any, handlerType, model string, rawJSON []byte) {
|
||||
if meta == nil {
|
||||
return
|
||||
}
|
||||
effort := thinking.ExtractReasoningEffort(rawJSON, handlerType, model)
|
||||
if effort == "" {
|
||||
return
|
||||
}
|
||||
meta[coreexecutor.ReasoningEffortMetadataKey] = effort
|
||||
}
|
||||
|
||||
// headersFromContext extracts the original HTTP request headers from the gin context
|
||||
// embedded in the provided context. This allows session affinity selectors to read
|
||||
// client headers like X-Amp-Thread-Id.
|
||||
@@ -400,6 +411,7 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
newCtx = logging.WithEndpoint(newCtx, endpoint)
|
||||
}
|
||||
newCtx = logging.WithResponseStatusHolder(newCtx)
|
||||
newCtx = logging.WithResponseHeadersHolder(newCtx)
|
||||
|
||||
cancelCtx := newCtx
|
||||
if requestCtx != nil && requestCtx != parentCtx {
|
||||
@@ -534,12 +546,22 @@ func appendAPIResponse(c *gin.Context, data []byte) {
|
||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
|
||||
- providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
+ return h.executeWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, false)
+ }
+
+ // ExecuteImageWithAuthManager executes an OpenAI-compatible image endpoint request.
+ func (h *BaseAPIHandler) ExecuteImageWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
+ return h.executeWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, true)
+ }
+
+ func (h *BaseAPIHandler) executeWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string, allowImageModel bool) ([]byte, http.Header, *interfaces.ErrorMessage) {
+ providers, normalizedModel, errMsg := h.getRequestDetailsWithOptions(modelName, allowImageModel)
|
||||
if errMsg != nil {
|
||||
return nil, nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
|
||||
setReasoningEffortMetadata(reqMeta, handlerType, normalizedModel, rawJSON)
|
||||
payload := rawJSON
|
||||
if len(payload) == 0 {
|
||||
payload = nil
|
||||
@@ -588,6 +610,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
|
||||
setReasoningEffortMetadata(reqMeta, handlerType, normalizedModel, rawJSON)
|
||||
payload := rawJSON
|
||||
if len(payload) == 0 {
|
||||
payload = nil
|
||||
@@ -631,7 +654,16 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
// This path is the only supported execution route.
|
||||
// The returned http.Header carries upstream response headers captured before streaming begins.
|
||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
|
||||
- providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
+ return h.executeStreamWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, false)
+ }
+
+ // ExecuteImageStreamWithAuthManager executes a streaming OpenAI-compatible image endpoint request.
+ func (h *BaseAPIHandler) ExecuteImageStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
+ return h.executeStreamWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, true)
+ }
+
+ func (h *BaseAPIHandler) executeStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string, allowImageModel bool) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
+ providers, normalizedModel, errMsg := h.getRequestDetailsWithOptions(modelName, allowImageModel)
|
||||
if errMsg != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
errChan <- errMsg
|
||||
@@ -640,6 +672,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
|
||||
setReasoningEffortMetadata(reqMeta, handlerType, normalizedModel, rawJSON)
|
||||
payload := rawJSON
|
||||
if len(payload) == 0 {
|
||||
payload = nil
|
||||
@@ -847,6 +880,10 @@ func statusFromError(err error) int {
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
|
||||
return h.getRequestDetailsWithOptions(modelName, false)
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) getRequestDetailsWithOptions(modelName string, allowImageModel bool) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
|
||||
resolvedModelName := modelName
|
||||
initialSuffix := thinking.ParseSuffix(modelName)
|
||||
if initialSuffix.ModelName == "auto" {
|
||||
@@ -871,10 +908,10 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
parsed := thinking.ParseSuffix(resolvedModelName)
|
||||
baseModel := strings.TrimSpace(parsed.ModelName)
|
||||
|
||||
- if strings.EqualFold(baseModel, "gpt-image-2") {
+ if strings.EqualFold(routeModelBaseName(baseModel), "gpt-image-2") && !allowImageModel {
|
||||
return nil, "", &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
- Error: fmt.Errorf("model %s is only supported on /v1/images/generations and /v1/images/edits", baseModel),
+ Error: fmt.Errorf("model %s is only supported on /v1/images/generations and /v1/images/edits", routeModelBaseName(baseModel)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -901,6 +938,14 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
return providers, resolvedModelName, nil
|
||||
}
|
||||
|
||||
func routeModelBaseName(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if idx := strings.LastIndex(model, "/"); idx >= 0 && idx < len(model)-1 {
|
||||
return strings.TrimSpace(model[idx+1:])
|
||||
}
|
||||
return model
|
||||
}
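To see what `routeModelBaseName` does to prefixed model names, the helper can be run on a few inputs. The function body below is copied from the hunk above; the sample inputs, including the `openai/` and `team/` prefixes, are hypothetical and chosen only to exercise the edge cases.

```go
package main

import (
	"fmt"
	"strings"
)

// routeModelBaseName returns the part of model after the last "/",
// or the trimmed model unchanged when there is no usable suffix.
func routeModelBaseName(model string) string {
	model = strings.TrimSpace(model)
	if idx := strings.LastIndex(model, "/"); idx >= 0 && idx < len(model)-1 {
		return strings.TrimSpace(model[idx+1:])
	}
	return model
}

func main() {
	// Hypothetical inputs: bare name, one prefix, nested prefix with padding, trailing slash.
	for _, m := range []string{"gpt-image-2", "openai/gpt-image-2", " team/openai/gpt-image-2 ", "openai/"} {
		fmt.Printf("%q -> %q\n", m, routeModelBaseName(m))
	}
}
```

A trailing slash leaves the name unchanged because the `idx < len(model)-1` guard rejects an empty suffix, so such a name never matches `gpt-image-2`.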
|
||||
|
||||
func cloneBytes(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -18,3 +18,23 @@ func TestRequestExecutionMetadataIncludesExecutionSessionWithoutIdempotencyKey(t
|
||||
t.Fatalf("unexpected idempotency key in metadata: %v", meta[idempotencyKeyMetadataKey])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetReasoningEffortMetadataUsesSuffixOverBody(t *testing.T) {
|
||||
meta := make(map[string]any)
|
||||
|
||||
setReasoningEffortMetadata(meta, "openai", "gpt-5.4(high)", []byte(`{"reasoning_effort":"low"}`))
|
||||
|
||||
if got := meta[coreexecutor.ReasoningEffortMetadataKey]; got != "high" {
|
||||
t.Fatalf("ReasoningEffortMetadataKey = %v, want %q", got, "high")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetReasoningEffortMetadataSupportsOpenAIResponses(t *testing.T) {
|
||||
meta := make(map[string]any)
|
||||
|
||||
setReasoningEffortMetadata(meta, "openai-response", "gpt-5.4", []byte(`{"reasoning":{"effort":"medium"}}`))
|
||||
|
||||
if got := meta[coreexecutor.ReasoningEffortMetadataKey]; got != "medium" {
|
||||
t.Fatalf("ReasoningEffortMetadataKey = %v, want %q", got, "medium")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,14 @@ var (
|
||||
codexClientModelTemplatesErr error
|
||||
)
|
||||
|
||||
var codexClientAllowedReasoningLevels = map[string]struct{}{
|
||||
"none": {},
|
||||
"low": {},
|
||||
"medium": {},
|
||||
"high": {},
|
||||
"xhigh": {},
|
||||
}
|
||||
|
||||
func (h *OpenAIAPIHandler) codexClientModelsResponse() map[string]any {
|
||||
return CodexClientModelsResponse(h.Models())
|
||||
}
|
||||
@@ -45,6 +53,7 @@ func buildCodexClientModels(models []map[string]any) []map[string]any {
|
||||
|
||||
if template, ok := templates[id]; ok {
|
||||
entry := cloneCodexClientModelMap(template)
|
||||
sanitizeCodexClientReasoningMetadata(entry)
|
||||
applyCodexClientVisibilityOverride(entry, id)
|
||||
result = append(result, entry)
|
||||
continue
|
||||
@@ -52,6 +61,7 @@ func buildCodexClientModels(models []map[string]any) []map[string]any {
|
||||
|
||||
entry := cloneCodexClientModelMap(defaultTemplate)
|
||||
applyCodexClientModelMetadata(entry, id, model)
|
||||
sanitizeCodexClientReasoningMetadata(entry)
|
||||
applyCodexClientVisibilityOverride(entry, id)
|
||||
result = append(result, entry)
|
||||
}
|
||||
@@ -104,6 +114,9 @@ func applyCodexClientModelMetadata(entry map[string]any, id string, model map[st
|
||||
if info.ContextLength > 0 {
|
||||
contextWindow = info.ContextLength
|
||||
}
|
||||
if info.Type == registry.OpenAIImageModelType {
|
||||
entry["visibility"] = "hide"
|
||||
}
|
||||
applyCodexClientThinkingMetadata(entry, info.Thinking)
|
||||
}
|
||||
|
||||
@@ -150,12 +163,16 @@ func applyCodexClientThinkingMetadata(entry map[string]any, thinking *registry.T
|
||||
|
||||
levels := make([]any, 0, len(thinking.Levels))
|
||||
defaultLevel := ""
|
||||
firstLevel := ""
|
||||
for _, rawLevel := range thinking.Levels {
|
||||
- level := strings.ToLower(strings.TrimSpace(rawLevel))
- if level == "" || level == "none" {
+ level := normalizeCodexClientReasoningLevel(rawLevel)
+ if level == "" {
|
||||
continue
|
||||
}
|
||||
- if defaultLevel == "" || level == "medium" {
+ if firstLevel == "" {
+ firstLevel = level
+ }
+ if (defaultLevel == "" && level != "none") || level == "medium" {
|
||||
defaultLevel = level
|
||||
}
|
||||
levels = append(levels, map[string]any{
|
||||
@@ -166,15 +183,64 @@ func applyCodexClientThinkingMetadata(entry map[string]any, thinking *registry.T
|
||||
if len(levels) == 0 {
|
||||
return
|
||||
}
|
||||
if defaultLevel == "" {
|
||||
defaultLevel = firstLevel
|
||||
}
|
||||
|
||||
entry["supported_reasoning_levels"] = levels
|
||||
entry["default_reasoning_level"] = defaultLevel
|
||||
}
|
||||
|
||||
func sanitizeCodexClientReasoningMetadata(entry map[string]any) {
|
||||
rawLevels, ok := entry["supported_reasoning_levels"].([]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
levels := make([]any, 0, len(rawLevels))
|
||||
allowedDefaults := make(map[string]struct{}, len(rawLevels))
|
||||
for _, rawLevelEntry := range rawLevels {
|
||||
levelEntry, ok := rawLevelEntry.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
level := normalizeCodexClientReasoningLevel(stringModelValue(levelEntry, "effort"))
|
||||
if level == "" {
|
||||
continue
|
||||
}
|
||||
clonedEntry := cloneCodexClientModelMap(levelEntry)
|
||||
clonedEntry["effort"] = level
|
||||
levels = append(levels, clonedEntry)
|
||||
allowedDefaults[level] = struct{}{}
|
||||
}
|
||||
|
||||
if len(levels) == 0 {
|
||||
delete(entry, "supported_reasoning_levels")
|
||||
delete(entry, "default_reasoning_level")
|
||||
return
|
||||
}
|
||||
|
||||
defaultLevel := normalizeCodexClientReasoningLevel(stringModelValue(entry, "default_reasoning_level"))
|
||||
if _, ok := allowedDefaults[defaultLevel]; !ok {
|
||||
defaultLevel = stringModelValue(levels[0].(map[string]any), "effort")
|
||||
}
|
||||
|
||||
entry["supported_reasoning_levels"] = levels
|
||||
entry["default_reasoning_level"] = defaultLevel
|
||||
}
|
||||
|
||||
func normalizeCodexClientReasoningLevel(rawLevel string) string {
|
||||
level := strings.ToLower(strings.TrimSpace(rawLevel))
|
||||
if _, ok := codexClientAllowedReasoningLevels[level]; !ok {
|
||||
return ""
|
||||
}
|
||||
return level
|
||||
}
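The allow-list normalization above can be exercised on its own. This is a sketch under stated assumptions: `allowedLevels` and `normalizeLevel` are shortened, hypothetical names for `codexClientAllowedReasoningLevels` and `normalizeCodexClientReasoningLevel`, with the same five levels as in the diff.

```go
package main

import (
	"fmt"
	"strings"
)

// allowedLevels mirrors the five reasoning levels listed in the diff.
var allowedLevels = map[string]struct{}{
	"none": {}, "low": {}, "medium": {}, "high": {}, "xhigh": {},
}

// normalizeLevel lowercases and trims raw, returning "" for anything
// that is not on the allow-list.
func normalizeLevel(raw string) string {
	level := strings.ToLower(strings.TrimSpace(raw))
	if _, ok := allowedLevels[level]; !ok {
		return ""
	}
	return level
}

func main() {
	for _, raw := range []string{" HIGH ", "none", "minimal", ""} {
		fmt.Printf("%q -> %q\n", raw, normalizeLevel(raw))
	}
}
```

Since `minimal` is not on the list, it normalizes to the empty string and callers that skip empty levels drop it, even though `codexClientReasoningDescription` still has a `minimal` case.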
|
||||
|
||||
func codexClientReasoningDescription(level string) string {
|
||||
switch level {
|
||||
case "minimal":
|
||||
return "Fastest responses with minimal reasoning"
|
||||
case "none":
|
||||
return "No reasoning"
|
||||
case "low":
|
||||
return "Fast responses with lighter reasoning"
|
||||
case "medium":
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -143,7 +145,20 @@ func isSupportedImagesModel(model string) bool {
|
||||
if baseModel == defaultImagesToolModel {
|
||||
return true
|
||||
}
|
||||
- return isXAIImagesModel(model)
+ return isXAIImagesModel(model) || isOpenAICompatImagesModel(model)
|
||||
}
|
||||
|
||||
func isDefaultImagesToolModel(model string) bool {
|
||||
return imagesModelBase(model) == defaultImagesToolModel
|
||||
}
|
||||
|
||||
func isOpenAICompatImagesModel(model string) bool {
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
return false
|
||||
}
|
||||
info := registry.LookupModelInfo(model)
|
||||
return info != nil && info.Type == registry.OpenAIImageModelType
|
||||
}
|
||||
|
||||
func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
|
||||
@@ -153,7 +168,7 @@ func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
|
||||
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
- Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, or %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel),
+ Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, %s, or a configured openai-compatibility image model.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
@@ -376,6 +391,90 @@ func multipartFileToDataURL(fileHeader *multipart.FileHeader) (string, error) {
|
||||
return "data:" + mediaType + ";base64," + b64, nil
|
||||
}
|
||||
|
||||
func buildOpenAICompatImagesJSONRequest(rawJSON []byte, imageModel string, stream bool) []byte {
|
||||
payload := rawJSON
|
||||
if model := strings.TrimSpace(imageModel); model != "" {
|
||||
payload, _ = sjson.SetBytes(payload, "model", model)
|
||||
}
|
||||
if stream {
|
||||
payload, _ = sjson.SetBytes(payload, "stream", true)
|
||||
} else {
|
||||
payload, _ = sjson.DeleteBytes(payload, "stream")
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
func cloneMIMEHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
|
||||
dst := make(textproto.MIMEHeader, len(src))
|
||||
for key, values := range src {
|
||||
dst[key] = append([]string(nil), values...)
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func buildOpenAICompatImagesMultipartRequest(form *multipart.Form, imageModel string, stream bool) ([]byte, string, error) {
|
||||
if form == nil {
|
||||
return nil, "", fmt.Errorf("multipart form is nil")
|
||||
}
|
||||
	var body bytes.Buffer
	writer := multipart.NewWriter(&body)

	if errWrite := writer.WriteField("model", imageModel); errWrite != nil {
		return nil, "", fmt.Errorf("write model field failed: %w", errWrite)
	}
	if stream {
		if errWrite := writer.WriteField("stream", "true"); errWrite != nil {
			return nil, "", fmt.Errorf("write stream field failed: %w", errWrite)
		}
	}
	for key, values := range form.Value {
		if key == "model" || key == "stream" {
			continue
		}
		for _, value := range values {
			if errWrite := writer.WriteField(key, value); errWrite != nil {
				return nil, "", fmt.Errorf("write form field %s failed: %w", key, errWrite)
			}
		}
	}

	for key, files := range form.File {
		for _, fileHeader := range files {
			if fileHeader == nil {
				continue
			}
			header := cloneMIMEHeader(fileHeader.Header)
			header.Set("Content-Disposition", multipart.FileContentDisposition(key, fileHeader.Filename))
			if header.Get("Content-Type") == "" {
				header.Set("Content-Type", "application/octet-stream")
			}
			part, errCreate := writer.CreatePart(header)
			if errCreate != nil {
				return nil, "", fmt.Errorf("create file field %s failed: %w", key, errCreate)
			}
			src, errOpen := fileHeader.Open()
			if errOpen != nil {
				return nil, "", fmt.Errorf("open upload file failed: %w", errOpen)
			}
			_, errCopy := io.Copy(part, src)
			if errClose := src.Close(); errClose != nil {
				log.Errorf("openai images: close upload file error: %v", errClose)
				if errCopy == nil {
					errCopy = errClose
				}
			}
			if errCopy != nil {
				return nil, "", fmt.Errorf("copy upload file failed: %w", errCopy)
			}
		}
	}

	if errClose := writer.Close(); errClose != nil {
		return nil, "", fmt.Errorf("close multipart writer failed: %w", errClose)
	}
	return body.Bytes(), writer.FormDataContentType(), nil
}

func parseIntField(raw string, fallback int64) int64 {
	raw = strings.TrimSpace(raw)
	if raw == "" {
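
The multipart rebuilding in the hunk above (override `model`, copy the other fields, re-emit file parts with their headers) can be tried in isolation. The following is a minimal sketch using only the standard library; `rewriteModelField`, `parseForm`, and `roundTrip` are hypothetical names for illustration and are not part of this change. It reuses each part's parsed header via `CreatePart` instead of the `cloneMIMEHeader` helper from the diff.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"mime"
	"mime/multipart"
)

// rewriteModelField copies src into a new multipart body, replacing the
// "model" field with upstreamModel. Hypothetical helper for illustration.
func rewriteModelField(src *multipart.Form, upstreamModel string) ([]byte, string, error) {
	var body bytes.Buffer
	w := multipart.NewWriter(&body)
	if err := w.WriteField("model", upstreamModel); err != nil {
		return nil, "", fmt.Errorf("write model field: %w", err)
	}
	for key, values := range src.Value {
		if key == "model" {
			continue // already written with the overridden value
		}
		for _, v := range values {
			if err := w.WriteField(key, v); err != nil {
				return nil, "", fmt.Errorf("write field %s: %w", key, err)
			}
		}
	}
	for key, files := range src.File {
		for _, fh := range files {
			// fh.Header carries the part's Content-Disposition and Content-Type.
			part, err := w.CreatePart(fh.Header)
			if err != nil {
				return nil, "", fmt.Errorf("create part %s: %w", key, err)
			}
			f, err := fh.Open()
			if err != nil {
				return nil, "", fmt.Errorf("open %s: %w", key, err)
			}
			_, errCopy := io.Copy(part, f)
			f.Close()
			if errCopy != nil {
				return nil, "", fmt.Errorf("copy %s: %w", key, errCopy)
			}
		}
	}
	if err := w.Close(); err != nil {
		return nil, "", fmt.Errorf("close writer: %w", err)
	}
	return body.Bytes(), w.FormDataContentType(), nil
}

// parseForm reads a multipart body using the boundary from contentType.
func parseForm(body []byte, contentType string) (*multipart.Form, error) {
	_, params, err := mime.ParseMediaType(contentType)
	if err != nil {
		return nil, err
	}
	return multipart.NewReader(bytes.NewReader(body), params["boundary"]).ReadForm(1 << 20)
}

// roundTrip builds a small form, rewrites it, and reports the resulting
// model value and the file part's Content-Type.
func roundTrip(upstreamModel string) (string, string, error) {
	var buf bytes.Buffer
	w := multipart.NewWriter(&buf)
	_ = w.WriteField("model", "client-model")
	part, err := w.CreateFormFile("image", "image.png")
	if err != nil {
		return "", "", err
	}
	_, _ = part.Write([]byte("png-data"))
	_ = w.Close()

	src, err := parseForm(buf.Bytes(), w.FormDataContentType())
	if err != nil {
		return "", "", err
	}
	out, contentType, err := rewriteModelField(src, upstreamModel)
	if err != nil {
		return "", "", err
	}
	dst, err := parseForm(out, contentType)
	if err != nil {
		return "", "", err
	}
	return dst.Value["model"][0], dst.File["image"][0].Header.Get("Content-Type"), nil
}

func main() {
	model, fileType, err := roundTrip("upstream-model")
	fmt.Println(model, fileType, err)
}
```

Writing the overridden field first and skipping it while copying keeps the rewritten body from carrying two `model` values, which is the same reason the hunk skips `model` and `stream` in its copy loop.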
@@ -454,11 +553,21 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
	}
	stream := gjson.GetBytes(rawJSON, "stream").Bool()

	if isDefaultImagesToolModel(imageModel) {
		imageReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
		h.handleRoutedImages(c, imageReq, imageModel, stream)
		return
	}
	if isXAIImagesModel(imageModel) {
		xaiReq := buildXAIImagesGenerationsRequest(rawJSON, imageModel, responseFormat)
		h.handleXAIImages(c, xaiReq, responseFormat, "image_generation", stream)
		return
	}
	if isOpenAICompatImagesModel(imageModel) {
		compatReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
		h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_generation", stream)
		return
	}

	tool := []byte(`{"type":"image_generation","action":"generate"}`)
	tool, _ = sjson.SetBytes(tool, "model", imageModel)
@@ -589,6 +698,21 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
	}
	stream := parseBoolField(c.PostForm("stream"), false)

	if isDefaultImagesToolModel(imageModel) {
		imageReq, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, imageModel, stream)
		if errBuild != nil {
			c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
				Error: handlers.ErrorDetail{
					Message: fmt.Sprintf("Invalid request: %v", errBuild),
					Type:    "invalid_request_error",
				},
			})
			return
		}
		c.Request.Header.Set("Content-Type", contentType)
		h.handleRoutedImages(c, imageReq, imageModel, stream)
		return
	}
	if isXAIImagesModel(imageModel) {
		aspectRatio := xaiImagesAspectRatio(c.PostForm("aspect_ratio"), "")
		aspectRatio = xaiImagesAspectRatioFromSize(c.PostForm("size"), aspectRatio)
@@ -598,6 +722,21 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
		h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
		return
	}
	if isOpenAICompatImagesModel(imageModel) {
		compatReq, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, imageModel, stream)
		if errBuild != nil {
			c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
				Error: handlers.ErrorDetail{
					Message: fmt.Sprintf("Invalid request: %v", errBuild),
					Type:    "invalid_request_error",
				},
			})
			return
		}
		c.Request.Header.Set("Content-Type", contentType)
		h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_edit", stream)
		return
	}

	var maskDataURL *string
	if maskFiles := form.File["mask"]; len(maskFiles) > 0 && maskFiles[0] != nil {
@@ -701,6 +840,11 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
	}
	stream := gjson.GetBytes(rawJSON, "stream").Bool()

	if isDefaultImagesToolModel(imageModel) {
		imageReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
		h.handleRoutedImages(c, imageReq, imageModel, stream)
		return
	}
	if isXAIImagesModel(imageModel) {
		images := collectXAIImagesFromJSON(rawJSON)
		if len(images) == 0 {
@@ -717,6 +861,11 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
		h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
		return
	}
	if isOpenAICompatImagesModel(imageModel) {
		compatReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
		h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_edit", stream)
		return
	}

	var images []string
	imagesResult := gjson.GetBytes(rawJSON, "images")
@@ -904,14 +1053,247 @@ func (h *OpenAIAPIHandler) handleXAIImages(c *gin.Context, xaiReq []byte, respon
	h.collectXAIImages(c, xaiReq, responseFormat)
}

func (h *OpenAIAPIHandler) handleOpenAICompatImages(c *gin.Context, compatReq []byte, imageModel string, responseFormat string, streamPrefix string, stream bool) {
	if stream {
		h.streamOpenAICompatImages(c, compatReq, imageModel)
		return
	}
	h.collectImagesWithModel(c, compatReq, imageModel, responseFormat)
}

func (h *OpenAIAPIHandler) handleRoutedImages(c *gin.Context, imageReq []byte, imageModel string, stream bool) {
	if stream {
		h.streamRoutedImages(c, imageReq, imageModel)
		return
	}
	h.collectRoutedImages(c, imageReq, imageModel)
}

func (h *OpenAIAPIHandler) collectRoutedImages(c *gin.Context, imageReq []byte, imageModel string) {
	c.Header("Content-Type", "application/json")

	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	cliCtx = handlers.WithDisallowFreeAuth(cliCtx)
	stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)

	model := strings.TrimSpace(imageModel)
	resp, upstreamHeaders, errMsg := h.ExecuteImageWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
	stopKeepAlive()
	if errMsg != nil {
		h.WriteErrorResponse(c, errMsg)
		if errMsg.Error != nil {
			cliCancel(errMsg.Error)
		} else {
			cliCancel(nil)
		}
		return
	}

	handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
	_, _ = c.Writer.Write(resp)
	cliCancel(nil)
}

func (h *OpenAIAPIHandler) streamRoutedImages(c *gin.Context, imageReq []byte, imageModel string) {
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	cliCtx = handlers.WithDisallowFreeAuth(cliCtx)
	model := strings.TrimSpace(imageModel)
	dataChan, upstreamHeaders, errChan := h.ExecuteImageStreamWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")

	setSSEHeaders := func() {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
	}

	for {
		select {
		case <-c.Request.Context().Done():
			cliCancel(c.Request.Context().Err())
			return
		case errMsg, ok := <-errChan:
			if !ok {
				errChan = nil
				continue
			}
			h.WriteErrorResponse(c, errMsg)
			if errMsg != nil {
				cliCancel(errMsg.Error)
			} else {
				cliCancel(nil)
			}
			return
		case chunk, ok := <-dataChan:
			if !ok {
				setSSEHeaders()
				handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
				_, _ = c.Writer.Write([]byte("\n"))
				flusher.Flush()
				cliCancel(nil)
				return
			}

			setSSEHeaders()
			handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
			_, _ = c.Writer.Write(chunk)
			flusher.Flush()
			h.forwardRawImageStream(cliCtx, c, func(err error) { cliCancel(err) }, dataChan, errChan)
			return
		}
	}
}

func (h *OpenAIAPIHandler) forwardRawImageStream(ctx context.Context, c *gin.Context, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
	emitError := func(errMsg *interfaces.ErrorMessage) {
		if errMsg == nil {
			return
		}
		status := http.StatusInternalServerError
		if errMsg.StatusCode > 0 {
			status = errMsg.StatusCode
		}
		errText := http.StatusText(status)
		if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
			errText = errMsg.Error.Error()
		}
		body := handlers.BuildErrorResponseBody(status, errText)
		_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
		if flusher, ok := c.Writer.(http.Flusher); ok {
			flusher.Flush()
		}
	}

	for {
		select {
		case <-c.Request.Context().Done():
			cancel(c.Request.Context().Err())
			return
		case <-ctx.Done():
			cancel(ctx.Err())
			return
		case errMsg, ok := <-errs:
			if ok && errMsg != nil {
				emitError(errMsg)
				cancel(errMsg.Error)
				return
			}
			errs = nil
		case chunk, ok := <-data:
			if !ok {
				cancel(nil)
				return
			}
			_, _ = c.Writer.Write(chunk)
			if flusher, ok := c.Writer.(http.Flusher); ok {
				flusher.Flush()
			}
		}
	}
}

func (h *OpenAIAPIHandler) streamOpenAICompatImages(c *gin.Context, compatReq []byte, imageModel string) {
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	model := strings.TrimSpace(imageModel)
	dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, xaiImagesHandlerType, model, compatReq, "")

	setSSEHeaders := func() {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
	}

	for {
		select {
		case <-c.Request.Context().Done():
			cliCancel(c.Request.Context().Err())
			return
		case errMsg, ok := <-errChan:
			if !ok {
				errChan = nil
				continue
			}
			h.WriteErrorResponse(c, errMsg)
			if errMsg != nil {
				cliCancel(errMsg.Error)
			} else {
				cliCancel(nil)
			}
			return
		case chunk, ok := <-dataChan:
			if !ok {
				setSSEHeaders()
				handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
				flusher.Flush()
				cliCancel(nil)
				return
			}

			setSSEHeaders()
			handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
			_, _ = c.Writer.Write(chunk)
			flusher.Flush()
			h.ForwardStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, handlers.StreamForwardOptions{
				WriteChunk: func(next []byte) {
					_, _ = c.Writer.Write(next)
				},
				WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
					if errMsg == nil {
						return
					}
					status := http.StatusInternalServerError
					if errMsg.StatusCode > 0 {
						status = errMsg.StatusCode
					}
					errText := http.StatusText(status)
					if errMsg.Error != nil && errMsg.Error.Error() != "" {
						errText = errMsg.Error.Error()
					}
					body := handlers.BuildErrorResponseBody(status, errText)
					_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
				},
			})
			return
		}
	}
}

func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, responseFormat string) {
	model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
	h.collectImagesWithModel(c, xaiReq, model, responseFormat)
}

func (h *OpenAIAPIHandler) collectImagesWithModel(c *gin.Context, imageReq []byte, model string, responseFormat string) {
	c.Header("Content-Type", "application/json")

	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)

-	model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
-	resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "")
+	model = strings.TrimSpace(model)
+	resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
	stopKeepAlive()
	if errMsg != nil {
		h.WriteErrorResponse(c, errMsg)
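
The forwarding loop in `forwardRawImageStream` above relies on a common Go idiom: when the error channel closes without delivering an error, it is set to `nil` so that its `select` case can never fire again, and the loop keeps draining data until the data channel closes. A minimal sketch of that idiom follows, with plain `error` values and a `strings.Builder` standing in for the handler's `*interfaces.ErrorMessage` and `gin` writer; `forward`, `demoDrain`, and `demoError` are hypothetical names for illustration.

```go
package main

import (
	"errors"
	"fmt"
	"strings"
)

// forward copies chunks from data into out until data closes (returns nil)
// or a non-nil error arrives on errs (returns that error).
func forward(data <-chan []byte, errs <-chan error, out *strings.Builder) error {
	for {
		select {
		case err, ok := <-errs:
			if ok && err != nil {
				return err
			}
			// Closed (or nil value): a receive on a nil channel blocks
			// forever, so this case is effectively disabled from now on.
			errs = nil
		case chunk, ok := <-data:
			if !ok {
				return nil
			}
			out.Write(chunk)
		}
	}
}

// demoDrain shows the normal path: errs closes early, data is fully drained.
func demoDrain() (string, error) {
	data := make(chan []byte, 2)
	data <- []byte("a")
	data <- []byte("b")
	close(data)
	errs := make(chan error)
	close(errs)
	var out strings.Builder
	err := forward(data, errs, &out)
	return out.String(), err
}

// demoError shows the failure path: an error arrives while data is idle.
func demoError() error {
	data := make(chan []byte) // never written to
	errs := make(chan error, 1)
	errs <- errors.New("upstream failed")
	var out strings.Builder
	return forward(data, errs, &out)
}

func main() {
	s, err := demoDrain()
	fmt.Println(s, err)
	fmt.Println(demoError())
}
```

Without the `errs = nil` assignment, a closed error channel would be ready on every iteration and the loop would spin instead of blocking on the data channel.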
@@ -937,6 +1319,11 @@ func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, respo
}

func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, responseFormat string, streamPrefix string) {
	model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
	h.streamImagesWithModel(c, xaiReq, model, responseFormat, streamPrefix)
}

func (h *OpenAIAPIHandler) streamImagesWithModel(c *gin.Context, imageReq []byte, model string, responseFormat string, streamPrefix string) {
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
@@ -949,8 +1336,8 @@ func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, respon
	}

	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
-	model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
-	resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "")
+	model = strings.TrimSpace(model)
+	resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
	if errMsg != nil {
		h.WriteErrorResponse(c, errMsg)
		if errMsg.Error != nil {
@@ -3,14 +3,17 @@ package openai
import (
	"bytes"
	"io"
	"mime"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"net/textproto"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
	internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
	"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
	"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
	sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
	"github.com/tidwall/gjson"
@@ -40,7 +43,7 @@ func assertUnsupportedImagesModelResponse(t *testing.T, resp *httptest.ResponseR
	}

	message := gjson.GetBytes(resp.Body.Bytes(), "error.message").String()
-	expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", or " + xaiImagesQualityModel + "."
+	expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", " + xaiImagesQualityModel + ", or a configured openai-compatibility image model."
	if message != expectedMessage {
		t.Fatalf("error message = %q, want %q", message, expectedMessage)
	}
@@ -63,6 +66,25 @@ func TestImagesModelValidationAllowsGPTImage2AndXAIModels(t *testing.T) {
	}
}

func TestImagesModelValidationAllowsOpenAICompatImageModels(t *testing.T) {
	modelRegistry := registry.GetGlobalRegistry()
	clientID := "test-openai-compat-image-model-validation"
	modelRegistry.RegisterClient(clientID, "openai-compatibility", []*registry.ModelInfo{
		{ID: "compat-image-model", Object: "model", OwnedBy: "compat", Type: registry.OpenAIImageModelType},
		{ID: "compat-chat-model", Object: "model", OwnedBy: "compat", Type: "openai-compatibility"},
	})
	t.Cleanup(func() {
		modelRegistry.UnregisterClient(clientID)
	})

	if !isSupportedImagesModel("compat-image-model") {
		t.Fatal("expected configured openai-compatibility image model to be supported")
	}
	if isSupportedImagesModel("compat-chat-model") {
		t.Fatal("expected non-image openai-compatibility model to be rejected")
	}
}

func TestBuildXAIImagesGenerationsRequest(t *testing.T) {
	rawJSON := []byte(`{"model":"xai/grok-imagine-image-quality","prompt":"abstract art","aspect_ratio":"landscape","resolution":"2k","n":2,"response_format":"url"}`)
@@ -122,6 +144,100 @@ func TestBuildXAIImagesEditRequestSingleImage(t *testing.T) {
	}
}

func TestBuildOpenAICompatImagesJSONRequestPreservesStreamForStreaming(t *testing.T) {
	req := buildOpenAICompatImagesJSONRequest([]byte(`{"model":"compat-image","prompt":"draw","stream":false}`), "upstream-image", true)

	if got := gjson.GetBytes(req, "model").String(); got != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image; body=%s", got, string(req))
	}
	if !gjson.GetBytes(req, "stream").Bool() {
		t.Fatalf("stream flag missing: %s", string(req))
	}
}

func TestBuildOpenAICompatImagesJSONRequestDropsStreamForNonStreaming(t *testing.T) {
	req := buildOpenAICompatImagesJSONRequest([]byte(`{"model":"compat-image","prompt":"draw","stream":true}`), "upstream-image", false)

	if got := gjson.GetBytes(req, "model").String(); got != "upstream-image" {
		t.Fatalf("model = %q, want upstream-image; body=%s", got, string(req))
	}
	if gjson.GetBytes(req, "stream").Exists() {
		t.Fatalf("stream flag should be removed from non-streaming request: %s", string(req))
	}
}

func TestBuildOpenAICompatImagesMultipartRequestPreservesStreamAndFileContentType(t *testing.T) {
	var body bytes.Buffer
	writer := multipart.NewWriter(&body)
	if errWrite := writer.WriteField("model", "compat-image"); errWrite != nil {
		t.Fatalf("write model field: %v", errWrite)
	}
	if errWrite := writer.WriteField("stream", "false"); errWrite != nil {
		t.Fatalf("write stream field: %v", errWrite)
	}
	if errWrite := writer.WriteField("prompt", "edit"); errWrite != nil {
		t.Fatalf("write prompt field: %v", errWrite)
	}
	header := make(textproto.MIMEHeader)
	header.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.png"))
	header.Set("Content-Type", "image/png")
	part, errCreate := writer.CreatePart(header)
	if errCreate != nil {
		t.Fatalf("create image field: %v", errCreate)
	}
	if _, errWrite := part.Write([]byte("png-data")); errWrite != nil {
		t.Fatalf("write image field: %v", errWrite)
	}
	if errClose := writer.Close(); errClose != nil {
		t.Fatalf("close multipart writer: %v", errClose)
	}

	reader := multipart.NewReader(bytes.NewReader(body.Bytes()), writer.Boundary())
	form, errRead := reader.ReadForm(32 << 20)
	if errRead != nil {
		t.Fatalf("read source form: %v", errRead)
	}
	defer func() {
		if errRemove := form.RemoveAll(); errRemove != nil {
			t.Fatalf("remove source form files: %v", errRemove)
		}
	}()

	out, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, "upstream-image", true)
	if errBuild != nil {
		t.Fatalf("buildOpenAICompatImagesMultipartRequest error: %v", errBuild)
	}
	mediaType, params, errParse := mime.ParseMediaType(contentType)
	if errParse != nil {
		t.Fatalf("parse content type: %v", errParse)
	}
	if mediaType != "multipart/form-data" {
		t.Fatalf("media type = %q, want multipart/form-data", mediaType)
	}
	rewrittenReader := multipart.NewReader(bytes.NewReader(out), params["boundary"])
	rewrittenForm, errRead := rewrittenReader.ReadForm(32 << 20)
	if errRead != nil {
		t.Fatalf("read rewritten form: %v", errRead)
	}
	defer func() {
		if errRemove := rewrittenForm.RemoveAll(); errRemove != nil {
			t.Fatalf("remove rewritten form files: %v", errRemove)
		}
	}()
	if got := rewrittenForm.Value["model"]; len(got) != 1 || got[0] != "upstream-image" {
		t.Fatalf("model values = %#v, want upstream-image", got)
	}
	if got := rewrittenForm.Value["stream"]; len(got) != 1 || got[0] != "true" {
		t.Fatalf("stream values = %#v, want true", got)
	}
	if got := rewrittenForm.Value["prompt"]; len(got) != 1 || got[0] != "edit" {
		t.Fatalf("prompt values = %#v, want edit", got)
	}
	if got := rewrittenForm.File["image"]; len(got) != 1 || got[0].Header.Get("Content-Type") != "image/png" {
		t.Fatalf("image headers = %#v, want image/png", got)
	}
}

func TestBuildImagesAPIResponseFromXAI(t *testing.T) {
	payload := []byte(`{"created":123,"data":[{"b64_json":"AA==","revised_prompt":"refined","mime_type":"image/png"}],"usage":{"total_tokens":0}}`)

@@ -177,12 +177,15 @@ waitForCallback:
	if accessToken != "" {
		fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
		if errProject != nil {
-			log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
+			return nil, fmt.Errorf("antigravity: failed to fetch project ID: %w", errProject)
		} else {
			projectID = fetchedProjectID
-			log.Infof("antigravity: obtained project ID %s", projectID)
+			log.Infof("antigravity: obtained project ID %s", util.HideAPIKey(projectID))
		}
	}
+	if strings.TrimSpace(projectID) == "" {
+		return nil, fmt.Errorf("antigravity: project ID discovery returned empty project")
+	}

	now := time.Now()
	metadata := map[string]any{
@@ -208,7 +211,7 @@ waitForCallback:

	fmt.Println("Antigravity authentication successful")
	if projectID != "" {
-		fmt.Printf("Using GCP project: %s\n", projectID)
+		fmt.Printf("Using GCP project: %s\n", util.HideAPIKey(projectID))
	}
	return &coreauth.Auth{
		ID: fileName,

@@ -4,12 +4,14 @@ import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"testing"
	"time"

	internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
	"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
	log "github.com/sirupsen/logrus"
)

type antigravityCreditsFallbackExecutor struct {
@@ -48,6 +50,43 @@ func (e *antigravityCreditsFallbackExecutor) HttpRequest(context.Context, *Auth,
	return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
}

type codexOnlyFailureExecutor struct{}

func (codexOnlyFailureExecutor) Identifier() string { return "codex" }

func (codexOnlyFailureExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusTooManyRequests, Message: "codex quota exhausted"}
}

func (codexOnlyFailureExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, &Error{HTTPStatus: http.StatusTooManyRequests, Message: "codex quota exhausted"}
}

func (codexOnlyFailureExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

func (codexOnlyFailureExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusTooManyRequests, Message: "codex quota exhausted"}
}

func (codexOnlyFailureExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
	return nil, &Error{HTTPStatus: http.StatusTooManyRequests, Message: "codex quota exhausted"}
}

type captureLogHook struct {
	messages []string
}

func (h *captureLogHook) Levels() []log.Level {
	return log.AllLevels
}

func (h *captureLogHook) Fire(entry *log.Entry) error {
	h.messages = append(h.messages, entry.Message)
	return nil
}

func TestManagerExecuteStream_AntigravityCreditsFallbackAfterBootstrap429(t *testing.T) {
	const model = "claude-opus-4-6-thinking"
	executor := &antigravityCreditsFallbackExecutor{}
@@ -88,6 +127,51 @@ func TestManagerExecuteStream_AntigravityCreditsFallbackAfterBootstrap429(t *tes
	}
}

func TestManagerExecuteStream_CodexOnlyDoesNotEnterAntigravityCreditsFallback(t *testing.T) {
	const model = "gpt-5.5"
	logger := log.StandardLogger()
	oldLevel := logger.GetLevel()
	oldHooks := logger.ReplaceHooks(make(log.LevelHooks))
	hook := &captureLogHook{}
	logger.SetLevel(log.DebugLevel)
	logger.AddHook(hook)
	t.Cleanup(func() {
		logger.SetLevel(oldLevel)
		logger.ReplaceHooks(oldHooks)
	})

	manager := NewManager(nil, nil, nil)
	manager.SetConfig(&internalconfig.Config{
		QuotaExceeded: internalconfig.QuotaExceeded{AntigravityCredits: true},
	})
	manager.RegisterExecutor(codexOnlyFailureExecutor{})
	manager.RegisterExecutor(&antigravityCreditsFallbackExecutor{})
	reg := registry.GetGlobalRegistry()
	reg.RegisterClient("codex-only", "codex", []*registry.ModelInfo{{ID: model}})
	reg.RegisterClient("ag-unrelated", "antigravity", []*registry.ModelInfo{{ID: "gemini-3-flash"}})
	t.Cleanup(func() {
		reg.UnregisterClient("codex-only")
		reg.UnregisterClient("ag-unrelated")
	})
	if _, errRegister := manager.Register(context.Background(), &Auth{ID: "codex-only", Provider: "codex"}); errRegister != nil {
		t.Fatalf("register codex auth: %v", errRegister)
	}
	if _, errRegister := manager.Register(context.Background(), &Auth{ID: "ag-unrelated", Provider: "antigravity"}); errRegister != nil {
		t.Fatalf("register antigravity auth: %v", errRegister)
	}

	_, errExecute := manager.ExecuteStream(context.Background(), []string{"codex"}, cliproxyexecutor.Request{Model: model}, cliproxyexecutor.Options{})
	if errExecute == nil {
		t.Fatal("expected codex execution failure")
	}

	for _, message := range hook.messages {
		if strings.Contains(message, "shouldAttemptAntigravityCreditsFallback") {
			t.Fatalf("codex-only request entered antigravity credits fallback gate; messages=%v", hook.messages)
		}
	}
}

func TestStatusCodeFromError_UnwrapsStreamBootstrap429(t *testing.T) {
	bootstrapErr := newStreamBootstrapError(&Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota exhausted"}, nil)
	wrappedErr := fmt.Errorf("conductor stream failed: %w", bootstrapErr)

+142
-15
@@ -45,6 +45,13 @@ type ProviderExecutor interface {
	HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
}

// RequestAuthPreparer lets an executor update missing auth metadata immediately
// before a request. Manager serializes and persists returned updates.
type RequestAuthPreparer interface {
	ShouldPrepareRequestAuth(auth *Auth) bool
	PrepareRequestAuth(ctx context.Context, auth *Auth) (*Auth, error)
}

// ExecutionSessionCloser allows executors to release per-session runtime resources.
type ExecutionSessionCloser interface {
	CloseExecutionSession(sessionID string)
@@ -182,6 +189,8 @@ type Manager struct {
	// Auto refresh state
	refreshCancel context.CancelFunc
	refreshLoop   *authAutoRefreshLoop

	requestPrepareLocks sync.Map
}

// NewManager constructs a manager with optional custom selector and hook.
@@ -1238,7 +1247,7 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
		}
	}
	if lastErr != nil {
-		if shouldAttemptAntigravityCreditsFallback(m, lastErr, normalized) {
+		if hasAntigravityProvider(normalized) && shouldAttemptAntigravityCreditsFallback(m, lastErr, normalized) {
			if resp, ok := m.tryAntigravityCreditsExecute(ctx, req, opts); ok {
				return resp, nil
			}
@@ -1304,7 +1313,7 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
		}
	}
	if lastErr != nil {
-		if shouldAttemptAntigravityCreditsFallback(m, lastErr, normalized) {
+		if hasAntigravityProvider(normalized) && shouldAttemptAntigravityCreditsFallback(m, lastErr, normalized) {
			if result, ok := m.tryAntigravityCreditsExecuteStream(ctx, req, opts); ok {
				return result, nil
			}
@@ -1365,6 +1374,17 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
			continue
		}
		attempted[auth.ID] = struct{}{}
		var errPrepare error
		auth, errPrepare = m.prepareRequestAuth(execCtx, executor, auth)
		if errPrepare != nil {
			result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: &Error{Message: errPrepare.Error()}}
			if se, ok := errors.AsType[cliproxyexecutor.StatusError](errPrepare); ok && se != nil {
				result.Error.HTTPStatus = se.StatusCode()
			}
			m.MarkResult(execCtx, result)
			lastErr = errPrepare
			continue
		}
		var authErr error
		for _, upstreamModel := range models {
			resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled)
@@ -1453,6 +1473,17 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
			continue
		}
		attempted[auth.ID] = struct{}{}
		var errPrepare error
		auth, errPrepare = m.prepareRequestAuth(execCtx, executor, auth)
		if errPrepare != nil {
			result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: &Error{Message: errPrepare.Error()}}
			if se, ok := errors.AsType[cliproxyexecutor.StatusError](errPrepare); ok && se != nil {
				result.Error.HTTPStatus = se.StatusCode()
			}
			m.MarkResult(execCtx, result)
			lastErr = errPrepare
			continue
		}
		var authErr error
		for _, upstreamModel := range models {
			resultModel := m.stateModelForExecution(auth, routeModel, upstreamModel, pooled)
@@ -1539,6 +1570,17 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
			continue
		}
		attempted[auth.ID] = struct{}{}
		var errPrepare error
		auth, errPrepare = m.prepareRequestAuth(execCtx, executor, auth)
		if errPrepare != nil {
			result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: &Error{Message: errPrepare.Error()}}
			if se, ok := errors.AsType[cliproxyexecutor.StatusError](errPrepare); ok && se != nil {
				result.Error.HTTPStatus = se.StatusCode()
			}
			m.MarkResult(execCtx, result)
			lastErr = errPrepare
			continue
		}
		streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled)
		if errStream != nil {
			if errCtx := execCtx.Err(); errCtx != nil {
@@ -1630,9 +1672,69 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
|
||||
}
|
||||
}
|
||||
|
||||
type requestAuthPrepareLock struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *Manager) prepareRequestAuth(ctx context.Context, executor ProviderExecutor, auth *Auth) (*Auth, error) {
|
||||
if m == nil || executor == nil || auth == nil {
|
||||
return auth, nil
|
||||
}
|
||||
preparer, ok := executor.(RequestAuthPreparer)
|
||||
if !ok || preparer == nil || !preparer.ShouldPrepareRequestAuth(auth) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
id := strings.TrimSpace(auth.ID)
|
||||
if id == "" {
|
||||
return preparer.PrepareRequestAuth(ctx, auth.Clone())
|
||||
}
|
||||
|
||||
lockValue, _ := m.requestPrepareLocks.LoadOrStore(id, &requestAuthPrepareLock{})
|
||||
lock, ok := lockValue.(*requestAuthPrepareLock)
|
||||
if !ok || lock == nil {
|
||||
return preparer.PrepareRequestAuth(ctx, auth.Clone())
|
||||
}
|
||||
|
||||
lock.mu.Lock()
|
||||
defer lock.mu.Unlock()
|
||||
|
||||
target := auth.Clone()
|
||||
m.mu.RLock()
|
||||
if current := m.auths[id]; current != nil {
|
||||
target = current.Clone()
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !preparer.ShouldPrepareRequestAuth(target) {
|
||||
return target, nil
|
||||
}
|
||||
|
||||
updated, errPrepare := preparer.PrepareRequestAuth(ctx, target)
|
||||
if errPrepare != nil {
|
||||
return auth, errPrepare
|
||||
}
|
||||
if updated == nil {
|
||||
return target, nil
|
||||
}
|
||||
|
||||
saved, errUpdate := m.Update(ctx, updated)
|
||||
if errUpdate != nil {
|
||||
return updated, errUpdate
|
||||
}
|
||||
if saved != nil {
|
||||
return saved, nil
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
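The `prepareRequestAuth` hunk above takes a per-auth lock from a `sync.Map`, then re-reads the manager's current state before doing any work, so concurrent requests for the same auth ID prepare it only once. A minimal sketch of that pattern follows. The `keyedOnce` type, its fields, and `prepare` are hypothetical names used for illustration, not identifiers from this repository.

```go
package main

import (
	"fmt"
	"sync"
)

// keyedOnce is a hypothetical helper, not part of the diff. It serializes
// work per key and re-checks shared state once the per-key lock is held.
type keyedOnce struct {
	locks  sync.Map // key -> *sync.Mutex
	mu     sync.RWMutex
	values map[string]string
	calls  int // number of times compute actually ran
}

func newKeyedOnce() *keyedOnce {
	return &keyedOnce{values: make(map[string]string)}
}

// prepare returns the stored value for id, computing it at most once even
// when many goroutines ask for the same id concurrently.
func (k *keyedOnce) prepare(id string, compute func() string) string {
	lockValue, _ := k.locks.LoadOrStore(id, &sync.Mutex{})
	lock := lockValue.(*sync.Mutex)
	lock.Lock()
	defer lock.Unlock()

	// Re-read shared state under the per-key lock: another goroutine may
	// have finished the work while this one was waiting.
	k.mu.RLock()
	current, ok := k.values[id]
	k.mu.RUnlock()
	if ok {
		return current
	}

	value := compute()
	k.mu.Lock()
	k.values[id] = value
	k.calls++
	k.mu.Unlock()
	return value
}

func main() {
	k := newKeyedOnce()
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			k.prepare("auth-1", func() string { return "prepared-project" })
		}()
	}
	wg.Wait()
	fmt.Println(k.calls) // prints 1
}
```

The per-key mutex keeps unrelated auth IDs from blocking each other, which a single global lock would not.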
|
||||
func contextWithRequestedModelAlias(ctx context.Context, opts cliproxyexecutor.Options, fallback string) context.Context {
|
||||
alias := requestedModelAliasFromOptions(opts, fallback)
|
||||
return coreusage.WithRequestedModelAlias(ctx, alias)
|
||||
ctx = coreusage.WithRequestedModelAlias(ctx, alias)
|
||||
if effort := reasoningEffortFromOptions(opts); effort != "" {
|
||||
ctx = coreusage.WithReasoningEffort(ctx, effort)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func requestedModelAliasFromOptions(opts cliproxyexecutor.Options, fallback string) string {
|
||||
@@ -1660,6 +1762,24 @@ func requestedModelAliasFromOptions(opts cliproxyexecutor.Options, fallback stri
|
||||
}
|
||||
}
|
||||
|
||||
func reasoningEffortFromOptions(opts cliproxyexecutor.Options) string {
|
||||
if len(opts.Metadata) == 0 {
|
||||
return ""
|
||||
}
|
||||
raw, ok := opts.Metadata[cliproxyexecutor.ReasoningEffortMetadataKey]
|
||||
if !ok || raw == nil {
|
||||
return ""
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(value)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(value))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func pinnedAuthIDFromMetadata(meta map[string]any) string {
|
||||
if len(meta) == 0 {
|
||||
return ""
|
||||
@@ -3587,6 +3707,15 @@ type creditsCandidateEntry struct {
|
||||
provider string
|
||||
}
|
||||
|
||||
func hasAntigravityProvider(providers []string) bool {
|
||||
for _, p := range providers {
|
||||
if strings.EqualFold(strings.TrimSpace(p), "antigravity") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func shouldAttemptAntigravityCreditsFallback(m *Manager, lastErr error, providers []string) bool {
|
||||
status := statusCodeFromError(lastErr)
|
||||
log.WithFields(log.Fields{
|
||||
@@ -3597,18 +3726,6 @@ func shouldAttemptAntigravityCreditsFallback(m *Manager, lastErr error, provider
|
||||
if m == nil || lastErr == nil {
|
||||
return false
|
||||
}
|
||||
if len(providers) > 0 {
|
||||
hasAntigravity := false
|
||||
for _, p := range providers {
|
||||
if strings.EqualFold(strings.TrimSpace(p), "antigravity") {
|
||||
hasAntigravity = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasAntigravity {
|
||||
return false
|
||||
}
|
||||
}
|
||||
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
|
||||
if cfg == nil || !cfg.QuotaExceeded.AntigravityCredits {
|
||||
return false
|
||||
@@ -3645,6 +3762,11 @@ func (m *Manager) tryAntigravityCreditsExecute(ctx context.Context, req cliproxy
|
||||
}
|
||||
creditsOpts := ensureRequestedModelMetadata(opts, routeModel)
|
||||
creditsCtx = contextWithRequestedModelAlias(creditsCtx, creditsOpts, routeModel)
|
||||
preparedAuth, errPrepare := m.prepareRequestAuth(creditsCtx, c.executor, c.auth)
|
||||
if errPrepare != nil {
|
||||
continue
|
||||
}
|
||||
c.auth = preparedAuth
|
||||
publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID)
|
||||
models := m.executionModelCandidates(c.auth, routeModel)
|
||||
if len(models) == 0 {
|
||||
@@ -3687,6 +3809,11 @@ func (m *Manager) tryAntigravityCreditsExecuteStream(ctx context.Context, req cl
|
||||
creditsCtx = context.WithValue(creditsCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
creditsOpts := ensureRequestedModelMetadata(opts, routeModel)
|
||||
preparedAuth, errPrepare := m.prepareRequestAuth(creditsCtx, c.executor, c.auth)
|
||||
if errPrepare != nil {
|
||||
continue
|
||||
}
|
||||
c.auth = preparedAuth
|
||||
publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID)
|
||||
models := m.executionModelCandidates(c.auth, routeModel)
|
||||
if len(models) == 0 {
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
coreusage "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
func TestContextWithRequestedModelAliasIncludesReasoningEffort(t *testing.T) {
|
||||
ctx := contextWithRequestedModelAlias(context.Background(), cliproxyexecutor.Options{
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.RequestedModelMetadataKey: "client-model",
|
||||
cliproxyexecutor.ReasoningEffortMetadataKey: "medium",
|
||||
},
|
||||
}, "fallback-model")
|
||||
|
||||
if got := coreusage.RequestedModelAliasFromContext(ctx); got != "client-model" {
|
||||
t.Fatalf("requested model alias = %q, want %q", got, "client-model")
|
||||
}
|
||||
if got := coreusage.ReasoningEffortFromContext(ctx); got != "medium" {
|
||||
t.Fatalf("reasoning effort = %q, want %q", got, "medium")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
type requestPrepareStore struct {
|
||||
saveCount atomic.Int32
|
||||
mu sync.Mutex
|
||||
last *Auth
|
||||
}
|
||||
|
||||
func (s *requestPrepareStore) List(context.Context) ([]*Auth, error) { return nil, nil }
|
||||
|
||||
func (s *requestPrepareStore) Save(_ context.Context, auth *Auth) (string, error) {
|
||||
s.saveCount.Add(1)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.last = auth.Clone()
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (s *requestPrepareStore) Delete(context.Context, string) error { return nil }
|
||||
|
||||
func (s *requestPrepareStore) lastAuth() *Auth {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.last.Clone()
|
||||
}
|
||||
|
||||
type requestPrepareExecutor struct {
|
||||
prepareCalls atomic.Int32
|
||||
executeCalls atomic.Int32
|
||||
}
|
||||
|
||||
func (e *requestPrepareExecutor) Identifier() string { return "antigravity" }
|
||||
|
||||
func (e *requestPrepareExecutor) ShouldPrepareRequestAuth(auth *Auth) bool {
|
||||
return auth == nil || auth.Metadata == nil || testStringValue(auth.Metadata["project_id"]) == ""
|
||||
}
|
||||
|
||||
func (e *requestPrepareExecutor) PrepareRequestAuth(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
e.prepareCalls.Add(1)
|
||||
updated := auth.Clone()
|
||||
if updated.Metadata == nil {
|
||||
updated.Metadata = make(map[string]any)
|
||||
}
|
||||
updated.Metadata["project_id"] = "prepared-project"
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
func (e *requestPrepareExecutor) Execute(_ context.Context, auth *Auth, _ cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
e.executeCalls.Add(1)
|
||||
if got := testStringValue(auth.Metadata["project_id"]); got != "prepared-project" {
|
||||
return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusBadRequest, Message: "missing prepared project"}
|
||||
}
|
||||
return cliproxyexecutor.Response{Payload: []byte("ok")}, nil
|
||||
}
|
||||
|
||||
func (e *requestPrepareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
|
||||
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "stream not implemented"}
|
||||
}
|
||||
|
||||
func (e *requestPrepareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *requestPrepareExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, &Error{HTTPStatus: http.StatusNotImplemented, Message: "count not implemented"}
|
||||
}
|
||||
|
||||
func (e *requestPrepareExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "http not implemented"}
|
||||
}
|
||||
|
||||
func TestManagerExecute_PreparesAndPersistsMissingRequestAuthMetadata(t *testing.T) {
|
||||
const model = "gemini-3.1-pro"
|
||||
store := &requestPrepareStore{}
|
||||
executor := &requestPrepareExecutor{}
|
||||
manager := NewManager(store, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth := &Auth{
|
||||
ID: "auth-request-prepare",
|
||||
Provider: "antigravity",
|
||||
Metadata: map[string]any{"access_token": "token"},
|
||||
}
|
||||
if _, errRegister := manager.Register(WithSkipPersist(context.Background()), auth); errRegister != nil {
|
||||
t.Fatalf("register auth: %v", errRegister)
|
||||
}
|
||||
registry.GetGlobalRegistry().RegisterClient(auth.ID, "antigravity", []*registry.ModelInfo{{ID: model}})
|
||||
t.Cleanup(func() { registry.GetGlobalRegistry().UnregisterClient(auth.ID) })
|
||||
|
||||
resp, errExecute := manager.Execute(context.Background(), []string{"antigravity"}, cliproxyexecutor.Request{Model: model}, cliproxyexecutor.Options{})
|
||||
if errExecute != nil {
|
||||
t.Fatalf("Execute error: %v", errExecute)
|
||||
}
|
||||
if string(resp.Payload) != "ok" {
|
||||
t.Fatalf("payload = %q, want ok", string(resp.Payload))
|
||||
}
|
||||
if got := executor.prepareCalls.Load(); got != 1 {
|
||||
t.Fatalf("prepare calls = %d, want 1", got)
|
||||
}
|
||||
if got := store.saveCount.Load(); got < 1 {
|
||||
t.Fatalf("save count = %d, want at least 1", got)
|
||||
}
|
||||
if got := testStringValue(store.lastAuth().Metadata["project_id"]); got != "prepared-project" {
|
||||
t.Fatalf("persisted project_id = %q, want prepared-project", got)
|
||||
}
|
||||
current, ok := manager.GetByID(auth.ID)
|
||||
if !ok {
|
||||
t.Fatal("expected auth in manager")
|
||||
}
|
||||
if got := testStringValue(current.Metadata["project_id"]); got != "prepared-project" {
|
||||
t.Fatalf("manager project_id = %q, want prepared-project", got)
|
||||
}
|
||||
|
||||
if _, errExecute = manager.Execute(context.Background(), []string{"antigravity"}, cliproxyexecutor.Request{Model: model}, cliproxyexecutor.Options{}); errExecute != nil {
|
||||
t.Fatalf("second Execute error: %v", errExecute)
|
||||
}
|
||||
if got := executor.prepareCalls.Load(); got != 1 {
|
||||
t.Fatalf("prepare calls after second execute = %d, want 1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func testStringValue(value any) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(typed))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,9 @@ const RequestPathMetadataKey = "request_path"
|
||||
// DisallowFreeAuthMetadataKey instructs auth selection to skip known free-tier credentials.
|
||||
const DisallowFreeAuthMetadataKey = "disallow_free_auth"
|
||||
|
||||
// ReasoningEffortMetadataKey stores the client-requested reasoning effort for usage logs.
|
||||
const ReasoningEffortMetadataKey = "reasoning_effort"
|
||||
|
||||
const (
|
||||
// PinnedAuthMetadataKey locks execution to a specific auth ID.
|
||||
PinnedAuthMetadataKey = "pinned_auth_id"
|
||||
|
||||
+38
-24
@@ -1208,30 +1208,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
}
|
||||
if strings.EqualFold(compat.Name, compatName) {
|
||||
isCompatAuth = true
|
||||
// Convert compatibility models to registry models
|
||||
ms := make([]*ModelInfo, 0, len(compat.Models))
|
||||
for j := range compat.Models {
|
||||
m := compat.Models[j]
|
||||
// Use alias as model ID, fallback to name if alias is empty
|
||||
modelID := m.Alias
|
||||
if modelID == "" {
|
||||
modelID = m.Name
|
||||
}
|
||||
thinking := m.Thinking
|
||||
if thinking == nil {
|
||||
thinking = &registry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}
|
||||
}
|
||||
ms = append(ms, &ModelInfo{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
Created: time.Now().Unix(),
|
||||
OwnedBy: compat.Name,
|
||||
Type: "openai-compatibility",
|
||||
DisplayName: modelID,
|
||||
UserDefined: false,
|
||||
Thinking: thinking,
|
||||
})
|
||||
}
|
||||
ms := buildOpenAICompatibilityConfigModels(compat)
|
||||
// Register and return
|
||||
if len(ms) > 0 {
|
||||
if providerKey == "" {
|
||||
@@ -1578,6 +1555,43 @@ type modelEntry interface {
|
||||
GetAlias() string
|
||||
}
|
||||
|
||||
func buildOpenAICompatibilityConfigModels(compat *config.OpenAICompatibility) []*ModelInfo {
|
||||
if compat == nil || len(compat.Models) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
models := make([]*ModelInfo, 0, len(compat.Models))
|
||||
for i := range compat.Models {
|
||||
model := compat.Models[i]
|
||||
modelID := strings.TrimSpace(model.Alias)
|
||||
if modelID == "" {
|
||||
modelID = strings.TrimSpace(model.Name)
|
||||
}
|
||||
if modelID == "" {
|
||||
continue
|
||||
}
|
||||
modelType := "openai-compatibility"
|
||||
if model.Image {
|
||||
modelType = registry.OpenAIImageModelType
|
||||
}
|
||||
thinking := model.Thinking
|
||||
if thinking == nil && !model.Image {
|
||||
thinking = &registry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}
|
||||
}
|
||||
models = append(models, &ModelInfo{
|
||||
ID: modelID,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: compat.Name,
|
||||
Type: modelType,
|
||||
DisplayName: modelID,
|
||||
UserDefined: false,
|
||||
Thinking: thinking,
|
||||
})
|
||||
}
|
||||
return models
|
||||
}
|
||||
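The extracted `buildOpenAICompatibilityConfigModels` helper above applies three rules: the alias wins over the name, entries with no usable ID are skipped, and image models get a different type and no default thinking levels. The sketch below restates those rules on simplified stand-in types. `compatModel`, `modelInfo`, `buildModels`, and the value of `imageModelType` are assumptions made for illustration. The real value of `registry.OpenAIImageModelType` is not shown in this diff.

```go
package main

import (
	"fmt"
	"strings"
)

// imageModelType is a placeholder; the real constant's value is not shown here.
const imageModelType = "image-model-placeholder"

// compatModel and modelInfo are simplified stand-ins for the config and
// registry types used in the diff.
type compatModel struct {
	Name  string
	Alias string
	Image bool
}

type modelInfo struct {
	ID             string
	Type           string
	ThinkingLevels []string
}

func buildModels(entries []compatModel) []modelInfo {
	models := make([]modelInfo, 0, len(entries))
	for _, entry := range entries {
		// Prefer the alias, fall back to the upstream name, skip blanks.
		id := strings.TrimSpace(entry.Alias)
		if id == "" {
			id = strings.TrimSpace(entry.Name)
		}
		if id == "" {
			continue
		}
		info := modelInfo{ID: id, Type: "openai-compatibility"}
		if entry.Image {
			// Image models get their own type and no thinking levels.
			info.Type = imageModelType
		} else {
			info.ThinkingLevels = []string{"low", "medium", "high"}
		}
		models = append(models, info)
	}
	return models
}

func main() {
	models := buildModels([]compatModel{
		{Name: "upstream-image", Alias: "compat-image", Image: true},
		{Name: " upstream-chat "},
		{}, // skipped: neither alias nor name
	})
	for _, m := range models {
		fmt.Printf("%s %s %v\n", m.ID, m.Type, m.ThinkingLevels)
	}
}
```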
|
||||
func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
internalregistry "github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
|
||||
)
|
||||
@@ -63,3 +64,71 @@ func TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute(t *testing.T
|
||||
t.Fatal("expected global excluded model to be present when attribute override is set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterModelsForAuth_OpenAICompatibilityImageModelType(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: &config.Config{
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "images",
|
||||
BaseURL: "https://example.com/v1",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "upstream-image", Alias: "compat-image", Image: true},
|
||||
{Name: "upstream-chat", Alias: "compat-chat"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "auth-openai-compat-image",
|
||||
Provider: "openai-compatibility",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "api_key",
|
||||
"compat_name": "images",
|
||||
"provider_key": "images",
|
||||
},
|
||||
}
|
||||
|
||||
modelRegistry := internalregistry.GetGlobalRegistry()
|
||||
modelRegistry.UnregisterClient(auth.ID)
|
||||
t.Cleanup(func() {
|
||||
modelRegistry.UnregisterClient(auth.ID)
|
||||
})
|
||||
|
||||
service.registerModelsForAuth(auth)
|
||||
|
||||
models := modelRegistry.GetModelsForClient(auth.ID)
|
||||
var imageModel *internalregistry.ModelInfo
|
||||
var chatModel *internalregistry.ModelInfo
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
switch strings.TrimSpace(model.ID) {
|
||||
case "compat-image":
|
||||
imageModel = model
|
||||
case "compat-chat":
|
||||
chatModel = model
|
||||
}
|
||||
}
|
||||
if imageModel == nil {
|
||||
t.Fatal("expected compat-image to be registered")
|
||||
}
|
||||
if imageModel.Type != internalregistry.OpenAIImageModelType {
|
||||
t.Fatalf("image model type = %q, want %q", imageModel.Type, internalregistry.OpenAIImageModelType)
|
||||
}
|
||||
if imageModel.Thinking != nil {
|
||||
t.Fatalf("image model thinking = %+v, want nil", imageModel.Thinking)
|
||||
}
|
||||
if chatModel == nil {
|
||||
t.Fatal("expected compat-chat to be registered")
|
||||
}
|
||||
if chatModel.Type != "openai-compatibility" {
|
||||
t.Fatalf("chat model type = %q, want openai-compatibility", chatModel.Type)
|
||||
}
|
||||
if chatModel.Thinking == nil {
|
||||
t.Fatal("expected chat model to keep default thinking support")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package usage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -11,19 +12,23 @@ import (
|
||||
|
||||
// Record contains the usage statistics captured for a single provider request.
|
||||
type Record struct {
|
||||
Provider string
|
||||
Model string
|
||||
Alias string
|
||||
APIKey string
|
||||
AuthID string
|
||||
AuthIndex string
|
||||
AuthType string
|
||||
Source string
|
||||
RequestedAt time.Time
|
||||
Latency time.Duration
|
||||
Failed bool
|
||||
Fail Failure
|
||||
Detail Detail
|
||||
Provider string
|
||||
Model string
|
||||
Alias string
|
||||
APIKey string
|
||||
AuthID string
|
||||
AuthIndex string
|
||||
AuthType string
|
||||
Source string
|
||||
// ReasoningEffort stores the client-requested thinking level for request event logs.
|
||||
ReasoningEffort string
|
||||
RequestedAt time.Time
|
||||
Latency time.Duration
|
||||
Failed bool
|
||||
Fail Failure
|
||||
Detail Detail
|
||||
// ResponseHeaders stores a snapshot of upstream response headers for usage sinks.
|
||||
ResponseHeaders http.Header
|
||||
}
|
||||
|
||||
// Failure holds HTTP failure metadata for an upstream request attempt.
|
||||
@@ -44,6 +49,7 @@ type Detail struct {
|
||||
}
|
||||
|
||||
type requestedModelAliasContextKey struct{}
|
||||
type reasoningEffortContextKey struct{}
|
||||
|
||||
// WithRequestedModelAlias stores the client-requested model name for usage sinks.
|
||||
func WithRequestedModelAlias(ctx context.Context, alias string) context.Context {
|
||||
@@ -73,6 +79,34 @@ func RequestedModelAliasFromContext(ctx context.Context) string {
|
||||
}
|
||||
}
|
||||
|
||||
// WithReasoningEffort stores the client-requested reasoning effort for usage sinks.
|
||||
func WithReasoningEffort(ctx context.Context, effort string) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
effort = strings.TrimSpace(effort)
|
||||
if effort == "" {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, reasoningEffortContextKey{}, effort)
|
||||
}
|
||||
|
||||
// ReasoningEffortFromContext returns the client-requested reasoning effort stored in ctx.
|
||||
func ReasoningEffortFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
raw := ctx.Value(reasoningEffortContextKey{})
|
||||
switch value := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(value)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(value))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
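`WithReasoningEffort` and `ReasoningEffortFromContext` above follow the usual Go pattern of an unexported struct type as the context key, with blank values never stored. A minimal standalone sketch of that pattern is shown below. `effortKey`, `withEffort`, `effortFrom`, and `roundTrip` are hypothetical names, not the package's API.

```go
package main

import (
	"context"
	"fmt"
	"strings"
)

// effortKey is an unexported, zero-size key type, so no other package can
// collide with or overwrite this context value.
type effortKey struct{}

// withEffort stores a trimmed effort value, leaving ctx unchanged when blank.
func withEffort(ctx context.Context, effort string) context.Context {
	effort = strings.TrimSpace(effort)
	if effort == "" {
		return ctx
	}
	return context.WithValue(ctx, effortKey{}, effort)
}

// effortFrom returns the stored effort, or "" when none was set.
func effortFrom(ctx context.Context) string {
	value, _ := ctx.Value(effortKey{}).(string)
	return value
}

// roundTrip is a small helper that stores and reads back a value.
func roundTrip(raw string) string {
	return effortFrom(withEffort(context.Background(), raw))
}

func main() {
	fmt.Printf("%q\n", roundTrip("  medium ")) // prints "medium"
	fmt.Printf("%q\n", roundTrip("   "))       // prints ""
}
```

Returning the original context for blank input means a reader cannot tell "unset" from "empty", which matches how the getter in the diff treats both as `""`.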
|
||||
// Plugin consumes usage records emitted by the proxy runtime.
|
||||
type Plugin interface {
|
||||
HandleUsage(ctx context.Context, record Record)
|
||||
|
||||
+122
-1
@@ -1,7 +1,10 @@
|
||||
package proxyutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@@ -50,7 +53,7 @@ func Parse(raw string) (Setting, error) {
|
||||
parsedURL, errParse := url.Parse(trimmed)
|
||||
if errParse != nil {
|
||||
setting.Mode = ModeInvalid
|
||||
return setting, fmt.Errorf("parse proxy URL failed: %w", errParse)
|
||||
return setting, fmt.Errorf("parse proxy URL failed")
|
||||
}
|
||||
if parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
setting.Mode = ModeInvalid
|
||||
@@ -134,6 +137,9 @@ func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
|
||||
case ModeDirect:
|
||||
return proxy.Direct, setting.Mode, nil
|
||||
case ModeProxy:
|
||||
if setting.URL.Scheme == "http" || setting.URL.Scheme == "https" {
|
||||
return &httpConnectDialer{proxyURL: setting.URL, dialer: proxy.Direct}, setting.Mode, nil
|
||||
}
|
||||
dialer, errDialer := proxy.FromURL(setting.URL, proxy.Direct)
|
||||
if errDialer != nil {
|
||||
return nil, setting.Mode, fmt.Errorf("create proxy dialer failed: %w", errDialer)
|
||||
@@ -143,3 +149,118 @@ func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
|
||||
return nil, setting.Mode, nil
|
||||
}
|
||||
}
|
||||
|
||||
type httpConnectDialer struct {
|
||||
proxyURL *url.URL
|
||||
dialer proxy.Dialer
|
||||
}
|
||||
|
||||
func (d *httpConnectDialer) Dial(network, addr string) (net.Conn, error) {
|
||||
proxyConn, errDial := d.dialer.Dial(network, proxyDialAddr(d.proxyURL))
|
||||
if errDial != nil {
|
||||
return nil, fmt.Errorf("dial HTTP proxy failed: %w", errDial)
|
||||
}
|
||||
|
||||
conn := proxyConn
|
||||
if d.proxyURL.Scheme == "https" {
|
||||
tlsConn := tls.Client(conn, &tls.Config{ServerName: d.proxyURL.Hostname()})
|
||||
if errHandshake := tlsConn.Handshake(); errHandshake != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
return nil, fmt.Errorf("HTTPS proxy TLS handshake failed: %w; close failed: %v", errHandshake, errClose)
|
||||
}
|
||||
return nil, fmt.Errorf("HTTPS proxy TLS handshake failed: %w", errHandshake)
|
||||
}
|
||||
conn = tlsConn
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{Host: addr},
|
||||
Host: addr,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
if d.proxyURL.User != nil {
|
||||
req.Header.Set("Proxy-Authorization", proxyAuthorization(d.proxyURL.User))
|
||||
}
|
||||
if errWrite := req.Write(conn); errWrite != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
return nil, fmt.Errorf("write CONNECT request failed: %w; close failed: %v", errWrite, errClose)
|
||||
}
|
||||
return nil, fmt.Errorf("write CONNECT request failed: %w", errWrite)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
resp, errRead := http.ReadResponse(reader, req)
|
||||
if errRead != nil {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
return nil, fmt.Errorf("read CONNECT response failed: %w; close failed: %v", errRead, errClose)
|
||||
}
|
||||
return nil, fmt.Errorf("read CONNECT response failed: %w", errRead)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
return nil, fmt.Errorf("proxy CONNECT returned status %s; close failed: %v", resp.Status, errClose)
|
||||
}
|
||||
return nil, fmt.Errorf("proxy CONNECT returned status %s", resp.Status)
|
||||
}
|
||||
|
||||
if reader.Buffered() > 0 {
|
||||
return &bufferedConn{Conn: conn, reader: reader}, nil
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func proxyDialAddr(proxyURL *url.URL) string {
|
||||
port := proxyURL.Port()
|
||||
if port == "" {
|
||||
port = "80"
|
||||
if proxyURL.Scheme == "https" {
|
||||
port = "443"
|
||||
}
|
||||
}
|
||||
return net.JoinHostPort(proxyURL.Hostname(), port)
|
||||
}
|
||||
|
||||
func proxyAuthorization(user *url.Userinfo) string {
|
||||
username := user.Username()
|
||||
password, _ := user.Password()
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
|
||||
return "Basic " + encoded
|
||||
}
|
||||
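The `httpConnectDialer.Dial` hunk above builds a `CONNECT` request by hand and writes it to the proxy connection. To see what actually goes on the wire, the sketch below serializes the same kind of request into a buffer instead of a socket. `connectPreamble`, `requestLine`, and `hasHeaderLine` are hypothetical helpers written for this example, and the host and credentials are the ones used in the test further down this diff.

```go
package main

import (
	"bytes"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"strings"
)

// connectPreamble returns the bytes a CONNECT request for addr would put on
// the wire, with optional Basic proxy credentials.
func connectPreamble(addr, username, password string) (string, error) {
	req := &http.Request{
		Method: http.MethodConnect,
		URL:    &url.URL{Host: addr},
		Host:   addr,
		Header: make(http.Header),
	}
	if username != "" {
		token := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
		req.Header.Set("Proxy-Authorization", "Basic "+token)
	}
	var buf bytes.Buffer
	if err := req.Write(&buf); err != nil {
		return "", err
	}
	return buf.String(), nil
}

// requestLine returns the first CRLF-terminated line of the preamble.
func requestLine(preamble string) string {
	return strings.SplitN(preamble, "\r\n", 2)[0]
}

// hasHeaderLine reports whether the preamble contains the exact header line.
func hasHeaderLine(preamble, line string) bool {
	return strings.Contains(preamble, "\r\n"+line+"\r\n")
}

func main() {
	out, err := connectPreamble("target.example.com:443", "user", "pass")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(requestLine(out))
	fmt.Println(hasHeaderLine(out, "Proxy-Authorization: Basic dXNlcjpwYXNz"))
}
```

Because the credentials are only Base64-encoded, they are readable by anyone who can see the connection unless the proxy itself is reached over TLS, as the `https` branch in the diff does.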
|
||||
// Redact returns a log-safe proxy URL with credentials and path-like data removed.
|
||||
func Redact(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parsedURL, errParse := url.Parse(trimmed)
|
||||
if errParse != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
return "<invalid proxy URL>"
|
||||
}
|
||||
|
||||
redacted := &url.URL{
|
||||
Scheme: parsedURL.Scheme,
|
||||
Host: parsedURL.Host,
|
||||
}
|
||||
if parsedURL.User != nil {
|
||||
redacted.User = url.User("redacted")
|
||||
}
|
||||
return redacted.String()
|
||||
}
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
if c.reader.Buffered() > 0 {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
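`bufferedConn` above exists because the `bufio.Reader` used to parse the proxy's `CONNECT` response may already have pulled in the first bytes of the tunneled stream. The sketch below shows that effect with an in-memory reader instead of a socket. `leftoverAfterConnect` is a hypothetical helper, and the wire string mirrors the one the test in this diff sends.

```go
package main

import (
	"bufio"
	"fmt"
	"io"
	"net/http"
	"strings"
)

// leftoverAfterConnect parses a CONNECT response from wire and returns any
// bytes that follow the response headers.
func leftoverAfterConnect(wire string) (string, error) {
	reader := bufio.NewReader(strings.NewReader(wire))
	req := &http.Request{Method: http.MethodConnect}
	resp, err := http.ReadResponse(reader, req)
	if err != nil {
		return "", err
	}
	_ = resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("proxy CONNECT returned status %s", resp.Status)
	}
	// Whatever the bufio.Reader read past the headers is tunnel payload. It
	// must be read through the same reader or it is lost.
	rest, err := io.ReadAll(reader)
	if err != nil {
		return "", err
	}
	return string(rest), nil
}

func main() {
	rest, err := leftoverAfterConnect("HTTP/1.1 200 Connection Established\r\n\r\nok")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Printf("%q\n", rest)
}
```

Returning the raw connection in this situation would silently drop those bytes, which is why the diff wraps the connection only when `reader.Buffered() > 0`.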
|
||||
@@ -1,8 +1,15 @@
|
||||
package proxyutil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func mustDefaultTransport(t *testing.T) *http.Transport {
|
||||
@@ -159,3 +166,157 @@ func TestBuildHTTPTransportSOCKS5HProxy(t *testing.T) {
|
||||
t.Fatal("expected SOCKS5H transport to have custom DialContext")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildDialerHTTPProxyCONNECT(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listener, errListen := net.Listen("tcp", "127.0.0.1:0")
|
||||
if errListen != nil {
|
||||
t.Fatalf("net.Listen returned error: %v", errListen)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := listener.Close(); errClose != nil {
|
||||
t.Errorf("listener.Close returned error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
conn, errAccept := listener.Accept()
|
||||
if errAccept != nil {
|
||||
done <- errAccept
|
||||
return
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
if errDeadline := conn.SetDeadline(time.Now().Add(5 * time.Second)); errDeadline != nil {
|
||||
done <- errDeadline
|
||||
return
|
||||
}
|
||||
|
||||
req, errRead := http.ReadRequest(bufio.NewReader(conn))
|
||||
if errRead != nil {
|
||||
done <- fmt.Errorf("read CONNECT request failed: %w", errRead)
|
||||
return
|
||||
}
|
||||
if req.Method != http.MethodConnect {
|
||||
done <- fmt.Errorf("method = %s, want CONNECT", req.Method)
|
||||
return
|
||||
}
|
||||
if req.Host != "target.example.com:443" {
|
||||
done <- fmt.Errorf("host = %s, want target.example.com:443", req.Host)
|
||||
return
|
||||
}
|
||||
wantAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("user:pass"))
|
||||
if gotAuth := req.Header.Get("Proxy-Authorization"); gotAuth != wantAuth {
|
||||
done <- fmt.Errorf("Proxy-Authorization = %q, want %q", gotAuth, wantAuth)
|
||||
return
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(conn, "HTTP/1.1 200 Connection Established\r\n\r\nok"); errWrite != nil {
|
||||
done <- fmt.Errorf("write CONNECT response failed: %w", errWrite)
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, 4)
|
||||
n, errReadTunnel := io.ReadFull(conn, buf)
|
||||
if errReadTunnel != nil {
|
||||
done <- fmt.Errorf("read tunneled payload failed after %d bytes: %w", n, errReadTunnel)
|
||||
return
|
||||
}
|
||||
if string(buf) != "ping" {
|
||||
done <- fmt.Errorf("tunneled payload = %q, want ping", string(buf))
|
||||
return
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
dialer, mode, errBuild := BuildDialer("http://user:pass@" + listener.Addr().String())
|
||||
if errBuild != nil {
|
||||
t.Fatalf("BuildDialer returned error: %v", errBuild)
|
||||
}
|
||||
if mode != ModeProxy {
|
||||
t.Fatalf("mode = %d, want %d", mode, ModeProxy)
|
||||
}
|
||||
if dialer == nil {
|
||||
t.Fatal("expected dialer, got nil")
|
||||
}
|
||||
|
||||
conn, errDial := dialer.Dial("tcp", "target.example.com:443")
|
||||
if errDial != nil {
|
||||
t.Fatalf("dialer.Dial returned error: %v", errDial)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
t.Errorf("conn.Close returned error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 2)
|
||||
n, errRead := io.ReadFull(conn, buf)
|
||||
if errRead != nil {
|
||||
t.Fatalf("conn.Read returned error after %d bytes: %v", n, errRead)
|
||||
}
|
||||
if string(buf) != "ok" {
|
||||
t.Fatalf("buffered tunnel payload = %q, want ok", string(buf))
|
||||
}
|
||||
|
||||
if _, errWrite := conn.Write([]byte("ping")); errWrite != nil {
|
||||
t.Fatalf("conn.Write returned error: %v", errWrite)
|
||||
}
|
||||
|
||||
if errServer := <-done; errServer != nil {
|
||||
t.Fatalf("proxy server returned error: %v", errServer)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedactProxyURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "with credentials",
|
||||
input: "http://user:pass@proxy.example.com:8080/path?token=secret",
|
||||
want: "http://redacted@proxy.example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "without credentials",
|
||||
input: "socks5://proxy.example.com:1080",
|
||||
want: "socks5://proxy.example.com:1080",
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
input: "bad-value",
|
||||
want: "<invalid proxy URL>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := Redact(tt.input); got != tt.want {
|
||||
t.Fatalf("Redact() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseErrorDoesNotExposeProxyCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
input := "http://user:secret%@proxy.example.com:8080"
|
||||
_, errParse := Parse(input)
|
||||
if errParse == nil {
|
||||
t.Fatal("expected Parse to return an error")
|
||||
}
|
||||
if strings.Contains(errParse.Error(), input) ||
|
||||
strings.Contains(errParse.Error(), "user") ||
|
||||
strings.Contains(errParse.Error(), "secret") {
|
||||
t.Fatalf("parse error exposes proxy credentials: %q", errParse.Error())
|
||||
}
|
||||
}
|
||||
|
||||