Compare commits

34 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 9b7d7021af | |
| | e41c22ef44 | |
| | 55271403fb | |
| | 36fba66619 | |
| | b9b127a7ea | |
| | 2741e7b7b3 | |
| | 1767a56d4f | |
| | 779e6c2d2f | |
| | 73c831747b | |
| | 10b824fcac | |
| | e5d3541b5a | |
| | 79755e76ea | |
| | 35f158d526 | |
| | 6962e09dd9 | |
| | 4c4cbd44da | |
| | 26eca8b6ba | |
| | 62b17f40a1 | |
| | 511b8a992e | |
| | 7dccc7ba2f | |
| | 70c90687fd | |
| | 8144ffd5c8 | |
| | 0ab977c236 | |
| | 224f0de353 | |
| | 6b45d311ec | |
| | d54de441d3 | |
| | 1821bf7051 | |
| | 754f3bcbc3 | |
| | 36973d4a6f | |
| | c89d19b300 | |
| | cc32f5ff61 | |
| | fbff68b9e0 | |
| | 7e1a543b79 | |
| | 74b862d8b8 | |
| | 5c817a9b42 | |
16 README.md
@@ -30,6 +30,14 @@ Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via <a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
<td>Huge thanks to BmoPlus for sponsoring this project! BmoPlus is a highly reliable AI account provider built strictly for heavy AI users and developers. They offer rock-solid, ready-to-use accounts and official top-up services for ChatGPT Plus / ChatGPT Pro (Full Warranty) / Claude Pro / Super Grok / Gemini Pro. By registering and ordering through <a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus - Premium AI Accounts & Top-ups</a>, users can unlock the mind-blowing rate of <b>10% of the official GPT subscription price (90% OFF)</b>!</td>
</tr>
<tr>
<td width="180"><a href="https://www.lingtrue.com/register"><img src="./assets/lingtrue.png" alt="LingtrueAPI" width="150"></a></td>
<td>Thanks to LingtrueAPI for sponsoring this project! LingtrueAPI is a global large-model API relay service platform that provides API calling services for top models such as Claude Code, Codex, and Gemini, and is committed to letting users connect to global AI capabilities at low cost and with high stability. LingtrueAPI offers a special discount to users of this software: register via <a href="https://www.lingtrue.com/register">this link</a> and enter the promo code "LingtrueAPI" on your first top-up to enjoy 10% off.</td>
</tr>
</tbody>
</table>

@@ -74,6 +82,14 @@ CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and A
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
- Security-first design with localhost-only management endpoints

When you need the request/response shape of a specific backend family, use the provider-specific paths instead of the merged `/v1/...` endpoints:

- Use `/api/provider/{provider}/v1/messages` for messages-style backends.
- Use `/api/provider/{provider}/v1beta/models/...` for model-scoped generate endpoints.
- Use `/api/provider/{provider}/v1/chat/completions` for chat-completions backends.

These routes help you select the protocol surface, but they do not by themselves guarantee a unique inference executor when the same client-visible model name is reused across multiple backends. Inference routing is still resolved from the request model/alias. For strict backend pinning, use unique aliases, prefixes, or otherwise avoid overlapping client-visible model names.
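As a hedged illustration of how the provider-scoped paths compose (the endpoint segments come from the list above; the `providerPath` helper itself is hypothetical and not part of CLIProxyAPI):

```go
package main

import "fmt"

// providerPath prefixes a protocol surface with the provider-scoped
// routing segment described above. Illustrative helper only — clients
// normally just configure the full base URL in their SDK.
func providerPath(provider, surface string) string {
	return fmt.Sprintf("/api/provider/%s%s", provider, surface)
}
```

For example, `providerPath("claude", "/v1/messages")` yields the messages-style path for the `claude` provider; remember that the inference backend is still picked from the model/alias in the request body.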
**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/agent-client/amp-cli.html)**

## SDK Docs
16 README_CN.md
@@ -30,6 +30,14 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">此链接</a>注册的用户,可享受首充8折,企业客户最高可享 7.5 折!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
<td>感谢 BmoPlus 赞助了本项目!BmoPlus 是一家专为AI订阅重度用户打造的可靠 AI 账号代充服务商,提供稳定的 ChatGPT Plus / ChatGPT Pro(全程质保) / Claude Pro / Super Grok / Gemini Pro 的官方代充&成品账号。 通过<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AI成品号专卖/代充</a>注册下单的用户,可享GPT <b>官网订阅一折</b> 的震撼价格!</td>
</tr>
<tr>
<td width="180"><a href="https://www.lingtrue.com/register"><img src="./assets/lingtrue.png" alt="LingtrueAPI" width="150"></a></td>
<td>感谢 LingtrueAPI 对本项目的赞助!LingtrueAPI 是一家全球大模型API中转服务平台,提供Claude Code、Codex、Gemini 等多种顶级模型API调用服务,致力于让用户以低成本、高稳定性链接全球AI能力。LingtrueAPI为本软件用户提供了特别优惠:使用<a href="https://www.lingtrue.com/register">此链接</a>注册,并在首次充值时输入 "LingtrueAPI" 优惠码即可享受9折优惠。</td>
</tr>
</tbody>
</table>

@@ -73,6 +81,14 @@ CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支
- 智能模型回退与自动路由
- 以安全为先的设计,管理端点仅限 localhost

当你需要某一类后端的请求/响应协议形态时,优先使用 provider-specific 路径,而不是合并后的 `/v1/...` 端点:

- 对于 messages 风格的后端,使用 `/api/provider/{provider}/v1/messages`。
- 对于按模型路径暴露生成接口的后端,使用 `/api/provider/{provider}/v1beta/models/...`。
- 对于 chat-completions 风格的后端,使用 `/api/provider/{provider}/v1/chat/completions`。

这些路径有助于选择协议表面,但当多个后端复用同一个客户端可见模型名时,它们本身并不能保证唯一的推理执行器。实际的推理路由仍然根据请求里的 model/alias 解析。若要严格钉住某个后端,请使用唯一 alias、前缀,或避免让多个后端暴露相同的客户端模型名。

**→ [Amp CLI 完整集成指南](https://help.router-for.me/cn/agent-client/amp-cli.html)**

## SDK 文档
16 README_JA.md
@@ -30,6 +30,14 @@ GLM CODING PLANを10%割引で取得:https://z.ai/subscribe?ic=8JVLJQFSKB
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>AICodeMirrorのスポンサーシップに感謝します!AICodeMirrorはClaude Code / Codex / Gemini CLI向けの公式高安定性リレーサービスを提供しており、エンタープライズグレードの同時接続、迅速な請求書発行、24時間365日の専任技術サポートを備えています。Claude Code / Codex / Geminiの公式チャネルが元の価格の38% / 2% / 9%で利用でき、チャージ時にはさらに割引があります!CLIProxyAPIユーザー向けの特別特典:<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">こちらのリンク</a>から登録すると、初回チャージが20%割引になり、エンタープライズのお客様は最大25%割引を受けられます!</td>
</tr>
<tr>
<td width="180"><a href="https://shop.bmoplus.com/?utm_source=github"><img src="./assets/bmoplus.png" alt="BmoPlus" width="150"></a></td>
<td>本プロジェクトにご支援いただいた BmoPlus に感謝いたします!BmoPlusは、AIサブスクリプションのヘビーユーザー向けに特化した信頼性の高いAIアカウントサービスプロバイダーであり、安定した ChatGPT Plus / ChatGPT Pro (完全保証) / Claude Pro / Super Grok / Gemini Pro の公式代行チャージおよび即納アカウントを提供しています。こちらの<a href="https://shop.bmoplus.com/?utm_source=github">BmoPlus AIアカウント専門店/代行チャージ</a>経由でご登録・ご注文いただいたユーザー様は、GPTを <b>公式サイト価格の約1割(90% OFF)</b> という驚異的な価格でご利用いただけます!</td>
</tr>
<tr>
<td width="180"><a href="https://www.lingtrue.com/register"><img src="./assets/lingtrue.png" alt="LingtrueAPI" width="150"></a></td>
<td>LingtrueAPIのスポンサーシップに感謝します!LingtrueAPIはグローバルな大規模モデルAPIリレーサービスプラットフォームで、Claude Code、Codex、GeminiなどのトップモデルAPI呼び出しサービスを提供し、ユーザーが低コストかつ高い安定性で世界中のAI能力に接続できるよう支援しています。LingtrueAPIは本ソフトウェアのユーザーに特別割引を提供しています:<a href="https://www.lingtrue.com/register">こちらのリンク</a>から登録し、初回チャージ時にプロモーションコード「LingtrueAPI」を入力すると10%割引になります。</td>
</tr>
</tbody>
</table>

@@ -74,6 +82,14 @@ CLIProxyAPIは[Amp CLI](https://ampcode.com)およびAmp IDE拡張機能の統
- 利用できないモデルを代替モデルにルーティングする**モデルマッピング**(例:`claude-opus-4.5` → `claude-sonnet-4`)
- localhostのみの管理エンドポイントによるセキュリティファーストの設計

特定のバックエンド系統のリクエスト/レスポンス形状が必要な場合は、統合された `/v1/...` エンドポイントよりも provider-specific のパスを優先してください。

- messages 系のバックエンドには `/api/provider/{provider}/v1/messages`
- モデル単位の generate 系エンドポイントには `/api/provider/{provider}/v1beta/models/...`
- chat-completions 系のバックエンドには `/api/provider/{provider}/v1/chat/completions`

これらのパスはプロトコル面の選択には役立ちますが、同じクライアント向けモデル名が複数バックエンドで再利用されている場合、それだけで推論実行系が一意に固定されるわけではありません。実際の推論ルーティングは、引き続きリクエスト内の model/alias 解決に従います。厳密にバックエンドを固定したい場合は、一意な alias や prefix を使うか、クライアント向けモデル名の重複自体を避けてください。

**→ [Amp CLI統合ガイドの完全版](https://help.router-for.me/agent-client/amp-cli.html)**

## SDKドキュメント
BIN assets/bmoplus.png (new binary file, not shown; size: 28 KiB)
BIN assets/lingtrue.png (new binary file, not shown; size: 129 KiB)
@@ -281,6 +281,10 @@ nonstream-keepalive-interval: 0
# These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
# you select the protocol surface, but inference backend selection can still follow the resolved
# model/alias. For strict backend pinning, use unique aliases/prefixes or avoid overlapping names.
# You can repeat the same name with different aliases to expose multiple client model names.
# oauth-model-alias:
#   gemini-cli:
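A sketch of what a filled-in alias block might look like, following the commented keys above. The entry shape (`name`/`alias` pairs) and the model names here are assumptions for illustration, not shipped defaults — check the full example config for the authoritative schema:

```yaml
# Hypothetical sketch only — verify field names against the reference config.
oauth-model-alias:
  gemini-cli:
    - name: gemini-2.5-pro      # upstream model ID on this channel
      alias: team-gemini-pro    # unique client-visible name (avoids overlap)
    - name: gemini-2.5-pro
      alias: gemini-pro-backup  # same upstream model, second alias
```

Giving each channel its own non-overlapping aliases is what makes backend pinning deterministic, per the NOTE above.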
@@ -541,10 +541,23 @@ func isRuntimeOnlyAuth(auth *coreauth.Auth) bool {
 	return strings.EqualFold(strings.TrimSpace(auth.Attributes["runtime_only"]), "true")
 }

 func isUnsafeAuthFileName(name string) bool {
 	if strings.TrimSpace(name) == "" {
 		return true
 	}
 	if strings.ContainsAny(name, "/\\") {
 		return true
 	}
 	if filepath.VolumeName(name) != "" {
 		return true
 	}
 	return false
 }

 // Download single auth file by name
 func (h *Handler) DownloadAuthFile(c *gin.Context) {
-	name := c.Query("name")
-	if name == "" || strings.Contains(name, string(os.PathSeparator)) {
+	name := strings.TrimSpace(c.Query("name"))
+	if isUnsafeAuthFileName(name) {
 		c.JSON(400, gin.H{"error": "invalid name"})
 		return
 	}
@@ -626,8 +639,8 @@ func (h *Handler) UploadAuthFile(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": "no files uploaded"})
 		return
 	}
-	name := c.Query("name")
-	if name == "" || strings.Contains(name, string(os.PathSeparator)) {
+	name := strings.TrimSpace(c.Query("name"))
+	if isUnsafeAuthFileName(name) {
 		c.JSON(400, gin.H{"error": "invalid name"})
 		return
 	}
@@ -860,7 +873,7 @@ func uniqueAuthFileNames(names []string) []string {

 func (h *Handler) deleteAuthFileByName(ctx context.Context, name string) (string, int, error) {
 	name = strings.TrimSpace(name)
-	if name == "" || strings.Contains(name, string(os.PathSeparator)) {
+	if isUnsafeAuthFileName(name) {
 		return "", http.StatusBadRequest, fmt.Errorf("invalid name")
 	}
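The hardening above matters because the old check only rejected `os.PathSeparator`, which is `/` on Linux but `\` on Windows — so a backslash-based traversal name could slip through on one platform. A self-contained copy of the new validator (behavior as in the diff; the sample names below are illustrative):

```go
package main

import (
	"path/filepath"
	"strings"
)

// isUnsafeAuthFileName mirrors the helper added in the diff above:
// it rejects blank names, names containing either slash style, and
// (on Windows) names carrying a volume prefix such as "C:".
func isUnsafeAuthFileName(name string) bool {
	if strings.TrimSpace(name) == "" {
		return true
	}
	if strings.ContainsAny(name, "/\\") {
		return true
	}
	if filepath.VolumeName(name) != "" {
		return true
	}
	return false
}
```

Checking both separator characters unconditionally keeps the rule identical on every OS, which is exactly what the new Windows traversal test exercises.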
62 internal/api/handlers/management/auth_files_download_test.go Normal file
@@ -0,0 +1,62 @@
package management

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)

func TestDownloadAuthFile_ReturnsFile(t *testing.T) {
	t.Setenv("MANAGEMENT_PASSWORD", "")
	gin.SetMode(gin.TestMode)

	authDir := t.TempDir()
	fileName := "download-user.json"
	expected := []byte(`{"type":"codex"}`)
	if err := os.WriteFile(filepath.Join(authDir, fileName), expected, 0o600); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}

	h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(fileName), nil)
	h.DownloadAuthFile(ctx)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected download status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
	}
	if got := rec.Body.Bytes(); string(got) != string(expected) {
		t.Fatalf("unexpected download content: %q", string(got))
	}
}

func TestDownloadAuthFile_RejectsPathSeparators(t *testing.T) {
	t.Setenv("MANAGEMENT_PASSWORD", "")
	gin.SetMode(gin.TestMode)

	h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, nil)

	for _, name := range []string{
		"../external/secret.json",
		`..\\external\\secret.json`,
		"nested/secret.json",
		`nested\\secret.json`,
	} {
		rec := httptest.NewRecorder()
		ctx, _ := gin.CreateTestContext(rec)
		ctx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/auth-files/download?name="+url.QueryEscape(name), nil)
		h.DownloadAuthFile(ctx)

		if rec.Code != http.StatusBadRequest {
			t.Fatalf("expected %d for name %q, got %d with body %s", http.StatusBadRequest, name, rec.Code, rec.Body.String())
		}
	}
}
@@ -0,0 +1,51 @@
//go:build windows

package management

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"path/filepath"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)

func TestDownloadAuthFile_PreventsWindowsSlashTraversal(t *testing.T) {
	t.Setenv("MANAGEMENT_PASSWORD", "")
	gin.SetMode(gin.TestMode)

	tempDir := t.TempDir()
	authDir := filepath.Join(tempDir, "auth")
	externalDir := filepath.Join(tempDir, "external")
	if err := os.MkdirAll(authDir, 0o700); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	if err := os.MkdirAll(externalDir, 0o700); err != nil {
		t.Fatalf("failed to create external dir: %v", err)
	}

	secretName := "secret.json"
	secretPath := filepath.Join(externalDir, secretName)
	if err := os.WriteFile(secretPath, []byte(`{"secret":true}`), 0o600); err != nil {
		t.Fatalf("failed to write external file: %v", err)
	}

	h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, nil)

	rec := httptest.NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)
	ctx.Request = httptest.NewRequest(
		http.MethodGet,
		"/v0/management/auth-files/download?name="+url.QueryEscape("../external/"+secretName),
		nil,
	)
	h.DownloadAuthFile(ctx)

	if rec.Code != http.StatusBadRequest {
		t.Fatalf("expected status %d, got %d with body %s", http.StatusBadRequest, rec.Code, rec.Body.String())
	}
}
125 internal/runtime/executor/codex_continuity.go Normal file
@@ -0,0 +1,125 @@
package executor

import (
	"context"
	"fmt"
	"net/http"
	"strings"

	"github.com/google/uuid"
	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
	log "github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

type codexContinuity struct {
	Key    string
	Source string
}

func metadataString(meta map[string]any, key string) string {
	if len(meta) == 0 {
		return ""
	}
	raw, ok := meta[key]
	if !ok || raw == nil {
		return ""
	}
	switch v := raw.(type) {
	case string:
		return strings.TrimSpace(v)
	case []byte:
		return strings.TrimSpace(string(v))
	default:
		return ""
	}
}

func principalString(raw any) string {
	switch v := raw.(type) {
	case string:
		return strings.TrimSpace(v)
	case fmt.Stringer:
		return strings.TrimSpace(v.String())
	default:
		return strings.TrimSpace(fmt.Sprintf("%v", raw))
	}
}

func resolveCodexContinuity(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) codexContinuity {
	if promptCacheKey := strings.TrimSpace(gjson.GetBytes(req.Payload, "prompt_cache_key").String()); promptCacheKey != "" {
		return codexContinuity{Key: promptCacheKey, Source: "prompt_cache_key"}
	}
	if executionSession := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); executionSession != "" {
		return codexContinuity{Key: executionSession, Source: "execution_session"}
	}
	if ginCtx := ginContextFrom(ctx); ginCtx != nil {
		if ginCtx.Request != nil {
			if v := strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key")); v != "" {
				return codexContinuity{Key: v, Source: "idempotency_key"}
			}
		}
		if v, exists := ginCtx.Get("apiKey"); exists && v != nil {
			if trimmed := principalString(v); trimmed != "" {
				return codexContinuity{Key: uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+trimmed)).String(), Source: "client_principal"}
			}
		}
	}
	if auth != nil {
		if authID := strings.TrimSpace(auth.ID); authID != "" {
			return codexContinuity{Key: uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:auth:"+authID)).String(), Source: "auth_id"}
		}
	}
	return codexContinuity{}
}

func applyCodexContinuityBody(rawJSON []byte, continuity codexContinuity) []byte {
	if continuity.Key == "" {
		return rawJSON
	}
	rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", continuity.Key)
	return rawJSON
}

func applyCodexContinuityHeaders(headers http.Header, continuity codexContinuity) {
	if headers == nil || continuity.Key == "" {
		return
	}
	headers.Set("session_id", continuity.Key)
}

func logCodexRequestDiagnostics(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, headers http.Header, body []byte, continuity codexContinuity) {
	if !log.IsLevelEnabled(log.DebugLevel) {
		return
	}
	entry := logWithRequestID(ctx)
	authID := ""
	authFile := ""
	if auth != nil {
		authID = strings.TrimSpace(auth.ID)
		authFile = strings.TrimSpace(auth.FileName)
	}
	selectedAuthID := metadataString(opts.Metadata, cliproxyexecutor.SelectedAuthMetadataKey)
	executionSessionID := metadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey)
	entry.Debugf(
		"codex request diagnostics auth_id=%s selected_auth_id=%s auth_file=%s exec_session=%s continuity_source=%s session_id=%s prompt_cache_key=%s prompt_cache_retention=%s store=%t has_instructions=%t reasoning_effort=%s reasoning_summary=%s chatgpt_account_id=%t originator=%s model=%s source_format=%s",
		authID,
		selectedAuthID,
		authFile,
		executionSessionID,
		continuity.Source,
		strings.TrimSpace(headers.Get("session_id")),
		gjson.GetBytes(body, "prompt_cache_key").String(),
		gjson.GetBytes(body, "prompt_cache_retention").String(),
		gjson.GetBytes(body, "store").Bool(),
		gjson.GetBytes(body, "instructions").Exists(),
		gjson.GetBytes(body, "reasoning.effort").String(),
		gjson.GetBytes(body, "reasoning.summary").String(),
		strings.TrimSpace(headers.Get("Chatgpt-Account-Id")) != "",
		strings.TrimSpace(headers.Get("Originator")),
		req.Model,
		opts.SourceFormat.String(),
	)
}
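The fallback chain in `resolveCodexContinuity` above can be modelled without the gin and auth plumbing. This is a simplified sketch, not the actual function: the client-principal step and the `uuid.NewSHA1` hashing of derived keys are elided for brevity.

```go
package main

import "strings"

// continuityKey models the precedence used by resolveCodexContinuity:
// an explicit prompt_cache_key wins, then the execution-session metadata,
// then an Idempotency-Key header, and finally a key derived from the
// auth identity (hashing omitted in this sketch).
func continuityKey(promptCacheKey, execSession, idempotencyKey, authID string) (key, source string) {
	if v := strings.TrimSpace(promptCacheKey); v != "" {
		return v, "prompt_cache_key"
	}
	if v := strings.TrimSpace(execSession); v != "" {
		return v, "execution_session"
	}
	if v := strings.TrimSpace(idempotencyKey); v != "" {
		return v, "idempotency_key"
	}
	if v := strings.TrimSpace(authID); v != "" {
		return v, "auth_id"
	}
	return "", ""
}
```

The ordering is the design point: the most request-specific signal wins, and the auth-derived key is only a last resort so unrelated clients sharing one auth do not collide unless nothing better is available.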
@@ -111,18 +111,19 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
 	body, _ = sjson.SetBytes(body, "model", baseModel)
 	body, _ = sjson.SetBytes(body, "stream", true)
 	body, _ = sjson.DeleteBytes(body, "previous_response_id")
 	body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
 	body, _ = sjson.DeleteBytes(body, "safety_identifier")
 	body, _ = sjson.DeleteBytes(body, "stream_options")
 	if !gjson.GetBytes(body, "instructions").Exists() {
 		body, _ = sjson.SetBytes(body, "instructions", "")
 	}

 	url := strings.TrimSuffix(baseURL, "/") + "/responses"
-	httpReq, err := e.cacheHelper(ctx, from, url, req, body)
+	httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
 	if err != nil {
 		return resp, err
 	}
 	applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
+	logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
 	var authID, authLabel, authType, authValue string
 	if auth != nil {
 		authID = auth.ID
@@ -222,11 +223,12 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
 	body, _ = sjson.DeleteBytes(body, "stream")

 	url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
-	httpReq, err := e.cacheHelper(ctx, from, url, req, body)
+	httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
 	if err != nil {
 		return resp, err
 	}
 	applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
+	logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
 	var authID, authLabel, authType, authValue string
 	if auth != nil {
 		authID = auth.ID
@@ -309,19 +311,20 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
 	requestedModel := payloadRequestedModel(opts, req.Model)
 	body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
 	body, _ = sjson.DeleteBytes(body, "previous_response_id")
 	body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
 	body, _ = sjson.DeleteBytes(body, "safety_identifier")
 	body, _ = sjson.DeleteBytes(body, "stream_options")
 	body, _ = sjson.SetBytes(body, "model", baseModel)
 	if !gjson.GetBytes(body, "instructions").Exists() {
 		body, _ = sjson.SetBytes(body, "instructions", "")
 	}

 	url := strings.TrimSuffix(baseURL, "/") + "/responses"
-	httpReq, err := e.cacheHelper(ctx, from, url, req, body)
+	httpReq, continuity, err := e.cacheHelper(ctx, auth, from, url, req, opts, body)
 	if err != nil {
 		return nil, err
 	}
 	applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
+	logCodexRequestDiagnostics(ctx, auth, req, opts, httpReq.Header, body, continuity)
 	var authID, authLabel, authType, authValue string
 	if auth != nil {
 		authID = auth.ID
@@ -415,6 +418,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
 	body, _ = sjson.DeleteBytes(body, "previous_response_id")
 	body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
 	body, _ = sjson.DeleteBytes(body, "safety_identifier")
 	body, _ = sjson.DeleteBytes(body, "stream_options")
 	body, _ = sjson.SetBytes(body, "stream", false)
 	if !gjson.GetBytes(body, "instructions").Exists() {
 		body, _ = sjson.SetBytes(body, "instructions", "")
@@ -596,8 +600,9 @@ func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*
 	return auth, nil
 }

-func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
+func (e *CodexExecutor) cacheHelper(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, url string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) (*http.Request, codexContinuity, error) {
 	var cache codexCache
+	continuity := codexContinuity{}
 	if from == "claude" {
 		userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
 		if userIDResult.Exists() {
@@ -610,30 +615,26 @@ func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Form
 			}
 			setCodexCache(key, cache)
 		}
+		continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
 		}
 	} else if from == "openai-response" {
 		promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key")
 		if promptCacheKey.Exists() {
 			cache.ID = promptCacheKey.String()
+			continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
 		}
 	} else if from == "openai" {
-		if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
-			cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
-		}
+		continuity = resolveCodexContinuity(ctx, auth, req, opts)
+		cache.ID = continuity.Key
 	}

-	if cache.ID != "" {
-		rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
-	}
+	rawJSON = applyCodexContinuityBody(rawJSON, continuity)
 	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rawJSON))
 	if err != nil {
-		return nil, err
+		return nil, continuity, err
 	}
-	if cache.ID != "" {
-		httpReq.Header.Set("Conversation_id", cache.ID)
-		httpReq.Header.Set("Session_id", cache.ID)
-	}
+	applyCodexContinuityHeaders(httpReq.Header, continuity)
-	return httpReq, nil
+	return httpReq, continuity, nil
 }

 func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
@@ -646,7 +647,7 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
 	}

 	misc.EnsureHeader(r.Header, ginHeaders, "Version", "")
-	misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
+	misc.EnsureHeader(r.Header, ginHeaders, "session_id", uuid.NewString())
 	misc.EnsureHeader(r.Header, ginHeaders, "X-Codex-Turn-Metadata", "")
 	misc.EnsureHeader(r.Header, ginHeaders, "X-Client-Request-Id", "")
 	cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
@@ -685,13 +686,39 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
 }

 func newCodexStatusErr(statusCode int, body []byte) statusErr {
-	err := statusErr{code: statusCode, msg: string(body)}
-	if retryAfter := parseCodexRetryAfter(statusCode, body, time.Now()); retryAfter != nil {
+	errCode := statusCode
+	if isCodexModelCapacityError(body) {
+		errCode = http.StatusTooManyRequests
+	}
+	err := statusErr{code: errCode, msg: string(body)}
+	if retryAfter := parseCodexRetryAfter(errCode, body, time.Now()); retryAfter != nil {
 		err.retryAfter = retryAfter
 	}
 	return err
 }

 func isCodexModelCapacityError(errorBody []byte) bool {
 	if len(errorBody) == 0 {
 		return false
 	}
 	candidates := []string{
 		gjson.GetBytes(errorBody, "error.message").String(),
 		gjson.GetBytes(errorBody, "message").String(),
 		string(errorBody),
 	}
 	for _, candidate := range candidates {
 		lower := strings.ToLower(strings.TrimSpace(candidate))
 		if lower == "" {
 			continue
 		}
 		if strings.Contains(lower, "selected model is at capacity") ||
 			strings.Contains(lower, "model is at capacity. please try a different model") {
 			return true
 		}
 	}
 	return false
 }

 func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
 	if statusCode != http.StatusTooManyRequests || len(errorBody) == 0 {
 		return nil
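The capacity handling above exists because upstream can report "model is at capacity" under a status other than 429; remapping to `http.StatusTooManyRequests` lets the existing retry-after path engage. A stdlib-only sketch of that intent (the real code extracts the message via gjson, which is skipped here):

```go
package main

import (
	"net/http"
	"strings"
)

// remapCapacityStatus mirrors the intent of newCodexStatusErr above:
// when the upstream error text indicates the model is at capacity,
// the status code is rewritten to 429 so retry-after parsing applies.
func remapCapacityStatus(statusCode int, body string) int {
	lower := strings.ToLower(body)
	if strings.Contains(lower, "selected model is at capacity") ||
		strings.Contains(lower, "model is at capacity. please try a different model") {
		return http.StatusTooManyRequests
	}
	return statusCode
}
```

Note the substring match runs on the lowercased body, so casing differences in upstream messages do not defeat the check.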
@@ -8,6 +8,7 @@ import (

 	"github.com/gin-gonic/gin"
 	"github.com/google/uuid"
 	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
 	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
 	sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
 	"github.com/tidwall/gjson"
@@ -27,7 +28,7 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
 	}
 	url := "https://example.com/responses"

-	httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
+	httpReq, _, err := executor.cacheHelper(ctx, nil, sdktranslator.FromString("openai"), url, req, cliproxyexecutor.Options{}, rawJSON)
 	if err != nil {
 		t.Fatalf("cacheHelper error: %v", err)
 	}
@@ -42,14 +43,14 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
 	if gotKey != expectedKey {
 		t.Fatalf("prompt_cache_key = %q, want %q", gotKey, expectedKey)
 	}
-	if gotConversation := httpReq.Header.Get("Conversation_id"); gotConversation != expectedKey {
-		t.Fatalf("Conversation_id = %q, want %q", gotConversation, expectedKey)
+	if gotSession := httpReq.Header.Get("session_id"); gotSession != expectedKey {
+		t.Fatalf("session_id = %q, want %q", gotSession, expectedKey)
 	}
-	if gotSession := httpReq.Header.Get("Session_id"); gotSession != expectedKey {
-		t.Fatalf("Session_id = %q, want %q", gotSession, expectedKey)
+	if got := httpReq.Header.Get("Conversation_id"); got != "" {
+		t.Fatalf("Conversation_id = %q, want empty", got)
 	}

-	httpReq2, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
+	httpReq2, _, err := executor.cacheHelper(ctx, nil, sdktranslator.FromString("openai"), url, req, cliproxyexecutor.Options{}, rawJSON)
 	if err != nil {
 		t.Fatalf("cacheHelper error (second call): %v", err)
 	}
@@ -62,3 +63,118 @@ func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFrom
 		t.Fatalf("prompt_cache_key (second call) = %q, want %q", gotKey2, expectedKey)
 	}
 }

 func TestCodexExecutorCacheHelper_OpenAIResponses_PreservesPromptCacheRetention(t *testing.T) {
 	executor := &CodexExecutor{}
 	url := "https://example.com/responses"
 	req := cliproxyexecutor.Request{
 		Model:   "gpt-5.3-codex",
 		Payload: []byte(`{"model":"gpt-5.3-codex","prompt_cache_key":"cache-key-1","prompt_cache_retention":"persistent"}`),
 	}
 	rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true,"prompt_cache_retention":"persistent"}`)

 	httpReq, _, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("openai-response"), url, req, cliproxyexecutor.Options{}, rawJSON)
 	if err != nil {
 		t.Fatalf("cacheHelper error: %v", err)
 	}

 	body, err := io.ReadAll(httpReq.Body)
 	if err != nil {
 		t.Fatalf("read request body: %v", err)
 	}

 	if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != "cache-key-1" {
 		t.Fatalf("prompt_cache_key = %q, want %q", got, "cache-key-1")
 	}
 	if got := gjson.GetBytes(body, "prompt_cache_retention").String(); got != "persistent" {
 		t.Fatalf("prompt_cache_retention = %q, want %q", got, "persistent")
 	}
 	if got := httpReq.Header.Get("session_id"); got != "cache-key-1" {
 		t.Fatalf("session_id = %q, want %q", got, "cache-key-1")
 	}
 	if got := httpReq.Header.Get("Conversation_id"); got != "" {
 		t.Fatalf("Conversation_id = %q, want empty", got)
 	}
 }

 func TestCodexExecutorCacheHelper_OpenAIChatCompletions_UsesExecutionSessionForContinuity(t *testing.T) {
 	executor := &CodexExecutor{}
 	rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
 	req := cliproxyexecutor.Request{
 		Model:   "gpt-5.4",
 		Payload: []byte(`{"model":"gpt-5.4"}`),
 	}
 	opts := cliproxyexecutor.Options{Metadata: map[string]any{cliproxyexecutor.ExecutionSessionMetadataKey: "exec-session-1"}}
|
||||
|
||||
httpReq, _, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("openai"), "https://example.com/responses", req, opts, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpReq.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
|
||||
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != "exec-session-1" {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, "exec-session-1")
|
||||
}
|
||||
if got := httpReq.Header.Get("session_id"); got != "exec-session-1" {
|
||||
t.Fatalf("session_id = %q, want %q", got, "exec-session-1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_FallsBackToStableAuthID(t *testing.T) {
|
||||
executor := &CodexExecutor{}
|
||||
rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5.4",
|
||||
Payload: []byte(`{"model":"gpt-5.4"}`),
|
||||
}
|
||||
auth := &cliproxyauth.Auth{ID: "codex-auth-1", Provider: "codex"}
|
||||
|
||||
httpReq, _, err := executor.cacheHelper(context.Background(), auth, sdktranslator.FromString("openai"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("cacheHelper error: %v", err)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(httpReq.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read request body: %v", err)
|
||||
}
|
||||
|
||||
expected := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:auth:codex-auth-1")).String()
|
||||
if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != expected {
|
||||
t.Fatalf("prompt_cache_key = %q, want %q", got, expected)
|
||||
}
|
||||
if got := httpReq.Header.Get("session_id"); got != expected {
|
||||
t.Fatalf("session_id = %q, want %q", got, expected)
|
||||
}
|
||||
}
|
||||
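The fallback key asserted in the test above is a deterministic name-based UUID. A minimal stdlib sketch of what `uuid.NewSHA1(uuid.NameSpaceOID, name)` computes (the helper name `stableCacheKey` is illustrative; the project itself uses github.com/google/uuid):

```go
package main

import (
	"crypto/sha1"
	"fmt"
)

// nameSpaceOID is the RFC 4122 OID namespace, the same constant the
// google/uuid package exposes as uuid.NameSpaceOID.
var nameSpaceOID = [16]byte{
	0x6b, 0xa7, 0xb8, 0x12, 0x9d, 0xad, 0x11, 0xd1,
	0x80, 0xb4, 0x00, 0xc0, 0x4f, 0xd4, 0x30, 0xc8,
}

// stableCacheKey derives a deterministic version-5-style UUID from an auth ID:
// SHA-1 over namespace||name, truncated to 16 bytes with the version and
// variant bits set, which is exactly what uuid.NewSHA1 does internally.
func stableCacheKey(authID string) string {
	name := "cli-proxy-api:codex:prompt-cache:auth:" + authID
	h := sha1.New()
	h.Write(nameSpaceOID[:])
	h.Write([]byte(name))
	sum := h.Sum(nil)

	var u [16]byte
	copy(u[:], sum)
	u[6] = (u[6] & 0x0f) | 0x50 // version 5 (name-based, SHA-1)
	u[8] = (u[8] & 0x3f) | 0x80 // RFC 4122 variant
	return fmt.Sprintf("%x-%x-%x-%x-%x", u[0:4], u[4:6], u[6:8], u[8:10], u[10:16])
}

func main() {
	// The same auth ID always yields the same key, so retries and repeated
	// requests keep addressing the same upstream prompt cache entry.
	fmt.Println(stableCacheKey("codex-auth-1"))
	fmt.Println(stableCacheKey("codex-auth-1") == stableCacheKey("codex-auth-2"))
}
```

Because the key is a pure function of the auth ID, no state has to be persisted for cache continuity across process restarts.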

func TestCodexExecutorCacheHelper_ClaudePreservesCacheContinuity(t *testing.T) {
	executor := &CodexExecutor{}
	req := cliproxyexecutor.Request{
		Model: "claude-3-7-sonnet",
		Payload: []byte(`{"metadata":{"user_id":"user-1"}}`),
	}
	rawJSON := []byte(`{"model":"gpt-5.4","stream":true}`)

	httpReq, continuity, err := executor.cacheHelper(context.Background(), nil, sdktranslator.FromString("claude"), "https://example.com/responses", req, cliproxyexecutor.Options{}, rawJSON)
	if err != nil {
		t.Fatalf("cacheHelper error: %v", err)
	}
	if continuity.Key == "" {
		t.Fatal("continuity.Key = empty, want non-empty")
	}
	body, err := io.ReadAll(httpReq.Body)
	if err != nil {
		t.Fatalf("read request body: %v", err)
	}
	if got := gjson.GetBytes(body, "prompt_cache_key").String(); got != continuity.Key {
		t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key)
	}
	if got := httpReq.Header.Get("session_id"); got != continuity.Key {
		t.Fatalf("session_id = %q, want %q", got, continuity.Key)
	}
}

@@ -60,6 +60,19 @@ func TestParseCodexRetryAfter(t *testing.T) {
	})
}

func TestNewCodexStatusErrTreatsCapacityAsRetryableRateLimit(t *testing.T) {
	body := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)

	err := newCodexStatusErr(http.StatusBadRequest, body)

	if got := err.StatusCode(); got != http.StatusTooManyRequests {
		t.Fatalf("status code = %d, want %d", got, http.StatusTooManyRequests)
	}
	if err.RetryAfter() != nil {
		t.Fatalf("expected nil explicit retryAfter for capacity fallback, got %v", *err.RetryAfter())
	}
}
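The capacity test above encodes a simple policy: a 400 whose error message says the selected model is at capacity is reclassified as a retryable 429 with no explicit retry-after. A standalone sketch of that mapping (the function name and substring match are assumptions for illustration, not the actual `newCodexStatusErr` internals):

```go
package main

import (
	"fmt"
	"net/http"
	"strings"
)

// classifyCodexStatus reclassifies an upstream "model at capacity" rejection,
// which arrives as a 400 Bad Request, into 429 Too Many Requests so the
// caller's retry machinery treats it as transient rather than a hard error.
func classifyCodexStatus(status int, body []byte) int {
	if status == http.StatusBadRequest &&
		strings.Contains(string(body), "at capacity") {
		return http.StatusTooManyRequests
	}
	return status
}

func main() {
	capacity := []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model."}}`)
	fmt.Println(classifyCodexStatus(http.StatusBadRequest, capacity)) // remapped to 429
	fmt.Println(classifyCodexStatus(http.StatusBadRequest, []byte(`{"error":{"message":"bad schema"}}`)))
}
```

Leaving the retry-after unset (as the test asserts) lets the caller fall back to its own backoff schedule instead of honoring a fabricated delay.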

func itoa(v int64) string {
	return strconv.FormatInt(v, 10)
}

@@ -178,7 +178,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
	body, _ = sjson.SetBytes(body, "model", baseModel)
	body, _ = sjson.SetBytes(body, "stream", true)
	body, _ = sjson.DeleteBytes(body, "previous_response_id")
	body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
	body, _ = sjson.DeleteBytes(body, "safety_identifier")
	if !gjson.GetBytes(body, "instructions").Exists() {
		body, _ = sjson.SetBytes(body, "instructions", "")

@@ -190,7 +189,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
		return resp, err
	}

	body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
	body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
	wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)

	var authID, authLabel, authType, authValue string

@@ -209,6 +208,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
	}

	wsReqBody := buildCodexWebsocketRequestBody(body)
	logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL: wsURL,
		Method: "WEBSOCKET",

@@ -385,7 +385,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
		return nil, err
	}

	body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
	body, wsHeaders, continuity := applyCodexPromptCacheHeaders(ctx, auth, from, req, opts, body)
	wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)

	var authID, authLabel, authType, authValue string

@@ -403,6 +403,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr
	}

	wsReqBody := buildCodexWebsocketRequestBody(body)
	logCodexRequestDiagnostics(ctx, auth, req, opts, wsHeaders, body, continuity)
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL: wsURL,
		Method: "WEBSOCKET",

@@ -761,13 +762,14 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
	return parsed.String(), nil
}

func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecutor.Request, rawJSON []byte) ([]byte, http.Header) {
func applyCodexPromptCacheHeaders(ctx context.Context, auth *cliproxyauth.Auth, from sdktranslator.Format, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, rawJSON []byte) ([]byte, http.Header, codexContinuity) {
	headers := http.Header{}
	if len(rawJSON) == 0 {
		return rawJSON, headers
		return rawJSON, headers, codexContinuity{}
	}

	var cache codexCache
	continuity := codexContinuity{}
	if from == "claude" {
		userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id")
		if userIDResult.Exists() {

@@ -781,20 +783,22 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
			}
			setCodexCache(key, cache)
		}
		continuity = codexContinuity{Key: cache.ID, Source: "claude_user_cache"}
		}
	} else if from == "openai-response" {
		if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
			cache.ID = promptCacheKey.String()
			continuity = codexContinuity{Key: cache.ID, Source: "prompt_cache_key"}
		}
	} else if from == "openai" {
		continuity = resolveCodexContinuity(ctx, auth, req, opts)
		cache.ID = continuity.Key
	}

	if cache.ID != "" {
		rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
		headers.Set("Conversation_id", cache.ID)
		headers.Set("Session_id", cache.ID)
	}
	rawJSON = applyCodexContinuityBody(rawJSON, continuity)
	applyCodexContinuityHeaders(headers, continuity)

	return rawJSON, headers
	return rawJSON, headers, continuity
}

func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {

@@ -826,7 +830,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
		betaHeader = codexResponsesWebsocketBetaHeaderValue
	}
	headers.Set("OpenAI-Beta", betaHeader)
	misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
	misc.EnsureHeader(headers, ginHeaders, "session_id", uuid.NewString())
	ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)

	isAPIKey := false

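The `Session_id` to `session_id` rename above is cosmetic at the HTTP layer: Go's `http.Header` canonicalizes keys through MIME header rules, so both spellings address the same stored `Session_id` entry, which is why the tests can read the header back with either casing. A quick demonstration (`promptCacheHeaders` is an illustrative helper, not project code):

```go
package main

import (
	"fmt"
	"net/http"
)

// promptCacheHeaders sets the prompt-cache session header the way the
// executor does after the rename. http.Header.Set canonicalizes the key,
// so "session_id" is stored under "Session_id".
func promptCacheHeaders(key string) http.Header {
	h := http.Header{}
	h.Set("session_id", key)
	return h
}

func main() {
	h := promptCacheHeaders("cache-key-1")
	// Get canonicalizes the same way, so both spellings see the value.
	fmt.Println(h.Get("session_id"))
	fmt.Println(h.Get("Session_id"))
	for k := range h {
		fmt.Println(k) // the single canonical key: Session_id
	}
}
```

The rename therefore only changes the spelling at the call sites, not which header goes over the wire.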
@@ -9,7 +9,9 @@ import (
	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
	sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
	sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
	"github.com/tidwall/gjson"
)

@@ -32,6 +34,49 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
	}
}

func TestApplyCodexPromptCacheHeaders_PreservesPromptCacheRetention(t *testing.T) {
	req := cliproxyexecutor.Request{
		Model: "gpt-5-codex",
		Payload: []byte(`{"prompt_cache_key":"cache-key-1","prompt_cache_retention":"persistent"}`),
	}
	body := []byte(`{"model":"gpt-5-codex","stream":true,"prompt_cache_retention":"persistent"}`)

	updatedBody, headers, _ := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("openai-response"), req, cliproxyexecutor.Options{}, body)

	if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != "cache-key-1" {
		t.Fatalf("prompt_cache_key = %q, want %q", got, "cache-key-1")
	}
	if got := gjson.GetBytes(updatedBody, "prompt_cache_retention").String(); got != "persistent" {
		t.Fatalf("prompt_cache_retention = %q, want %q", got, "persistent")
	}
	if got := headers.Get("session_id"); got != "cache-key-1" {
		t.Fatalf("session_id = %q, want %q", got, "cache-key-1")
	}
	if got := headers.Get("Conversation_id"); got != "" {
		t.Fatalf("Conversation_id = %q, want empty", got)
	}
}

func TestApplyCodexPromptCacheHeaders_ClaudePreservesContinuity(t *testing.T) {
	req := cliproxyexecutor.Request{
		Model: "claude-3-7-sonnet",
		Payload: []byte(`{"metadata":{"user_id":"user-1"}}`),
	}
	body := []byte(`{"model":"gpt-5.4","stream":true}`)

	updatedBody, headers, continuity := applyCodexPromptCacheHeaders(context.Background(), nil, sdktranslator.FromString("claude"), req, cliproxyexecutor.Options{}, body)

	if continuity.Key == "" {
		t.Fatal("continuity.Key = empty, want non-empty")
	}
	if got := gjson.GetBytes(updatedBody, "prompt_cache_key").String(); got != continuity.Key {
		t.Fatalf("prompt_cache_key = %q, want %q", got, continuity.Key)
	}
	if got := headers.Get("session_id"); got != continuity.Key {
		t.Fatalf("session_id = %q, want %q", got, continuity.Key)
	}
}

func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
	headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)

@@ -4,6 +4,7 @@ import (
	"bytes"
	"context"
	"fmt"
	"sort"
	"strings"
	"sync/atomic"
	"time"

@@ -16,6 +17,7 @@ import (
type oaiToResponsesStateReasoning struct {
	ReasoningID string
	ReasoningData string
	OutputIndex int
}
type oaiToResponsesState struct {
	Seq int

@@ -29,16 +31,19 @@ type oaiToResponsesState struct {
	MsgTextBuf map[int]*strings.Builder
	ReasoningBuf strings.Builder
	Reasonings []oaiToResponsesStateReasoning
	FuncArgsBuf map[int]*strings.Builder // index -> args
	FuncNames map[int]string // index -> name
	FuncCallIDs map[int]string // index -> call_id
	FuncArgsBuf map[string]*strings.Builder
	FuncNames map[string]string
	FuncCallIDs map[string]string
	FuncOutputIx map[string]int
	MsgOutputIx map[int]int
	NextOutputIx int
	// message item state per output index
	MsgItemAdded map[int]bool // whether response.output_item.added emitted for message
	MsgContentAdded map[int]bool // whether response.content_part.added emitted for message
	MsgItemDone map[int]bool // whether message done events were emitted
	// function item done state
	FuncArgsDone map[int]bool
	FuncItemDone map[int]bool
	FuncArgsDone map[string]bool
	FuncItemDone map[string]bool
	// usage aggregation
	PromptTokens int64
	CachedTokens int64

@@ -60,15 +65,17 @@ func emitRespEvent(event string, payload []byte) []byte {
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) [][]byte {
	if *param == nil {
		*param = &oaiToResponsesState{
			FuncArgsBuf: make(map[int]*strings.Builder),
			FuncNames: make(map[int]string),
			FuncCallIDs: make(map[int]string),
			FuncArgsBuf: make(map[string]*strings.Builder),
			FuncNames: make(map[string]string),
			FuncCallIDs: make(map[string]string),
			FuncOutputIx: make(map[string]int),
			MsgOutputIx: make(map[int]int),
			MsgTextBuf: make(map[int]*strings.Builder),
			MsgItemAdded: make(map[int]bool),
			MsgContentAdded: make(map[int]bool),
			MsgItemDone: make(map[int]bool),
			FuncArgsDone: make(map[int]bool),
			FuncItemDone: make(map[int]bool),
			FuncArgsDone: make(map[string]bool),
			FuncItemDone: make(map[string]bool),
			Reasonings: make([]oaiToResponsesStateReasoning, 0),
		}
	}

@@ -125,6 +132,12 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
	}

	nextSeq := func() int { st.Seq++; return st.Seq }
	allocOutputIndex := func() int {
		ix := st.NextOutputIx
		st.NextOutputIx++
		return ix
	}
	toolStateKey := func(outputIndex, toolIndex int) string { return fmt.Sprintf("%d:%d", outputIndex, toolIndex) }
	var out [][]byte

	if !st.Started {
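The two helpers introduced in this hunk, `toolStateKey` and `allocOutputIndex`, replace per-choice int keys with composite "choice:tool" string keys and a single monotonically increasing output-index counter. A self-contained sketch of that bookkeeping (type and method names are illustrative; the real state lives in `oaiToResponsesState`):

```go
package main

import "fmt"

// streamState mirrors the two pieces of state the diff adds: a global
// output_index counter and a map from composite tool-call key to the
// output_index allocated for that call.
type streamState struct {
	nextOutputIx int
	funcOutputIx map[string]int
}

// toolStateKey builds a composite key so two tool calls with the same tool
// index in different choices (or vice versa) never share state.
func toolStateKey(outputIndex, toolIndex int) string {
	return fmt.Sprintf("%d:%d", outputIndex, toolIndex)
}

func (s *streamState) allocOutputIndex() int {
	ix := s.nextOutputIx
	s.nextOutputIx++
	return ix
}

// outputIndexFor returns a stable output_index for a (choice, tool) pair,
// allocating a fresh one the first time the pair is seen.
func (s *streamState) outputIndexFor(choice, tool int) int {
	key := toolStateKey(choice, tool)
	if ix, ok := s.funcOutputIx[key]; ok {
		return ix
	}
	ix := s.allocOutputIndex()
	s.funcOutputIx[key] = ix
	return ix
}

func main() {
	s := &streamState{funcOutputIx: map[string]int{}}
	fmt.Println(s.outputIndexFor(0, 0)) // 0
	fmt.Println(s.outputIndexFor(0, 1)) // 1: second tool call in the same choice
	fmt.Println(s.outputIndexFor(0, 0)) // 0: later chunks reuse the index
}
```

A single counter shared by messages, reasoning items, and tool calls guarantees globally ordered, collision-free `output_index` values across all emitted Responses events.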
@@ -135,14 +148,17 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
		st.ReasoningBuf.Reset()
		st.ReasoningID = ""
		st.ReasoningIndex = 0
		st.FuncArgsBuf = make(map[int]*strings.Builder)
		st.FuncNames = make(map[int]string)
		st.FuncCallIDs = make(map[int]string)
		st.FuncArgsBuf = make(map[string]*strings.Builder)
		st.FuncNames = make(map[string]string)
		st.FuncCallIDs = make(map[string]string)
		st.FuncOutputIx = make(map[string]int)
		st.MsgOutputIx = make(map[int]int)
		st.NextOutputIx = 0
		st.MsgItemAdded = make(map[int]bool)
		st.MsgContentAdded = make(map[int]bool)
		st.MsgItemDone = make(map[int]bool)
		st.FuncArgsDone = make(map[int]bool)
		st.FuncItemDone = make(map[int]bool)
		st.FuncArgsDone = make(map[string]bool)
		st.FuncItemDone = make(map[string]bool)
		st.PromptTokens = 0
		st.CachedTokens = 0
		st.CompletionTokens = 0

@@ -185,7 +201,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
		outputItemDone, _ = sjson.SetBytes(outputItemDone, "item.summary.text", text)
		out = append(out, emitRespEvent("response.output_item.done", outputItemDone))

		st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text})
		st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text, OutputIndex: st.ReasoningIndex})
		st.ReasoningID = ""
	}

@@ -201,10 +217,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
			stopReasoning(st.ReasoningBuf.String())
			st.ReasoningBuf.Reset()
		}
		if _, exists := st.MsgOutputIx[idx]; !exists {
			st.MsgOutputIx[idx] = allocOutputIndex()
		}
		msgOutputIndex := st.MsgOutputIx[idx]
		if !st.MsgItemAdded[idx] {
			item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
			item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
			item, _ = sjson.SetBytes(item, "output_index", idx)
			item, _ = sjson.SetBytes(item, "output_index", msgOutputIndex)
			item, _ = sjson.SetBytes(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
			out = append(out, emitRespEvent("response.output_item.added", item))
			st.MsgItemAdded[idx] = true

@@ -213,7 +233,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
			part := []byte(`{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
			part, _ = sjson.SetBytes(part, "sequence_number", nextSeq())
			part, _ = sjson.SetBytes(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
			part, _ = sjson.SetBytes(part, "output_index", idx)
			part, _ = sjson.SetBytes(part, "output_index", msgOutputIndex)
			part, _ = sjson.SetBytes(part, "content_index", 0)
			out = append(out, emitRespEvent("response.content_part.added", part))
			st.MsgContentAdded[idx] = true

@@ -222,7 +242,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
			msg := []byte(`{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`)
			msg, _ = sjson.SetBytes(msg, "sequence_number", nextSeq())
			msg, _ = sjson.SetBytes(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
			msg, _ = sjson.SetBytes(msg, "output_index", idx)
			msg, _ = sjson.SetBytes(msg, "output_index", msgOutputIndex)
			msg, _ = sjson.SetBytes(msg, "content_index", 0)
			msg, _ = sjson.SetBytes(msg, "delta", c.String())
			out = append(out, emitRespEvent("response.output_text.delta", msg))

@@ -238,10 +258,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
			// On first appearance, add reasoning item and part
			if st.ReasoningID == "" {
				st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
				st.ReasoningIndex = idx
				st.ReasoningIndex = allocOutputIndex()
				item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`)
				item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
				item, _ = sjson.SetBytes(item, "output_index", idx)
				item, _ = sjson.SetBytes(item, "output_index", st.ReasoningIndex)
				item, _ = sjson.SetBytes(item, "item.id", st.ReasoningID)
				out = append(out, emitRespEvent("response.output_item.added", item))
				part := []byte(`{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`)

@@ -269,6 +289,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
			// Before emitting any function events, if a message is open for this index,
			// close its text/content to match Codex expected ordering.
			if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] {
				msgOutputIndex := st.MsgOutputIx[idx]
				fullText := ""
				if b := st.MsgTextBuf[idx]; b != nil {
					fullText = b.String()

@@ -276,7 +297,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
				done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
				done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
				done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
				done, _ = sjson.SetBytes(done, "output_index", idx)
				done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
				done, _ = sjson.SetBytes(done, "content_index", 0)
				done, _ = sjson.SetBytes(done, "text", fullText)
				out = append(out, emitRespEvent("response.output_text.done", done))

@@ -284,69 +305,72 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
				partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
				partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
				partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
				partDone, _ = sjson.SetBytes(partDone, "output_index", idx)
				partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
				partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
				partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
				out = append(out, emitRespEvent("response.content_part.done", partDone))

				itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
				itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
				itemDone, _ = sjson.SetBytes(itemDone, "output_index", idx)
				itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
				itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
				itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
				out = append(out, emitRespEvent("response.output_item.done", itemDone))
				st.MsgItemDone[idx] = true
			}

			// Only emit item.added once per tool call and preserve call_id across chunks.
			newCallID := tcs.Get("0.id").String()
			nameChunk := tcs.Get("0.function.name").String()
			if nameChunk != "" {
				st.FuncNames[idx] = nameChunk
			}
			existingCallID := st.FuncCallIDs[idx]
			effectiveCallID := existingCallID
			shouldEmitItem := false
			if existingCallID == "" && newCallID != "" {
				// First time seeing a valid call_id for this index
				effectiveCallID = newCallID
				st.FuncCallIDs[idx] = newCallID
				shouldEmitItem = true
			}

			if shouldEmitItem && effectiveCallID != "" {
				o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
				o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
				o, _ = sjson.SetBytes(o, "output_index", idx)
				o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
				o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
				name := st.FuncNames[idx]
				o, _ = sjson.SetBytes(o, "item.name", name)
				out = append(out, emitRespEvent("response.output_item.added", o))
			}

			// Ensure args buffer exists for this index
			if st.FuncArgsBuf[idx] == nil {
				st.FuncArgsBuf[idx] = &strings.Builder{}
			}

			// Append arguments delta if available and we have a valid call_id to reference
			if args := tcs.Get("0.function.arguments"); args.Exists() && args.String() != "" {
				// Prefer an already known call_id; fall back to newCallID if first time
				refCallID := st.FuncCallIDs[idx]
				if refCallID == "" {
					refCallID = newCallID
			tcs.ForEach(func(_, tc gjson.Result) bool {
				toolIndex := int(tc.Get("index").Int())
				key := toolStateKey(idx, toolIndex)
				newCallID := tc.Get("id").String()
				nameChunk := tc.Get("function.name").String()
				if nameChunk != "" {
					st.FuncNames[key] = nameChunk
				}
				if refCallID != "" {
					ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
					ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
					ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
					ad, _ = sjson.SetBytes(ad, "output_index", idx)
					ad, _ = sjson.SetBytes(ad, "delta", args.String())
					out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))

				existingCallID := st.FuncCallIDs[key]
				effectiveCallID := existingCallID
				shouldEmitItem := false
				if existingCallID == "" && newCallID != "" {
					effectiveCallID = newCallID
					st.FuncCallIDs[key] = newCallID
					st.FuncOutputIx[key] = allocOutputIndex()
					shouldEmitItem = true
				}
				st.FuncArgsBuf[idx].WriteString(args.String())
			}

				if shouldEmitItem && effectiveCallID != "" {
					outputIndex := st.FuncOutputIx[key]
					o := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`)
					o, _ = sjson.SetBytes(o, "sequence_number", nextSeq())
					o, _ = sjson.SetBytes(o, "output_index", outputIndex)
					o, _ = sjson.SetBytes(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
					o, _ = sjson.SetBytes(o, "item.call_id", effectiveCallID)
					o, _ = sjson.SetBytes(o, "item.name", st.FuncNames[key])
					out = append(out, emitRespEvent("response.output_item.added", o))
				}

				if st.FuncArgsBuf[key] == nil {
					st.FuncArgsBuf[key] = &strings.Builder{}
				}

				if args := tc.Get("function.arguments"); args.Exists() && args.String() != "" {
					refCallID := st.FuncCallIDs[key]
					if refCallID == "" {
						refCallID = newCallID
					}
					if refCallID != "" {
						outputIndex := st.FuncOutputIx[key]
						ad := []byte(`{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`)
						ad, _ = sjson.SetBytes(ad, "sequence_number", nextSeq())
						ad, _ = sjson.SetBytes(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
						ad, _ = sjson.SetBytes(ad, "output_index", outputIndex)
						ad, _ = sjson.SetBytes(ad, "delta", args.String())
						out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
					}
					st.FuncArgsBuf[key].WriteString(args.String())
				}
				return true
			})
		}
	}

@@ -360,15 +384,10 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
		for i := range st.MsgItemAdded {
			idxs = append(idxs, i)
		}
		for i := 0; i < len(idxs); i++ {
			for j := i + 1; j < len(idxs); j++ {
				if idxs[j] < idxs[i] {
					idxs[i], idxs[j] = idxs[j], idxs[i]
				}
			}
		}
		sort.Slice(idxs, func(i, j int) bool { return st.MsgOutputIx[idxs[i]] < st.MsgOutputIx[idxs[j]] })
		for _, i := range idxs {
			if st.MsgItemAdded[i] && !st.MsgItemDone[i] {
				msgOutputIndex := st.MsgOutputIx[i]
				fullText := ""
				if b := st.MsgTextBuf[i]; b != nil {
					fullText = b.String()
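The hunk above swaps a hand-rolled bubble sort for `sort.Slice`, and changes the ordering criterion: message indices are now emitted in the order of their allocated `output_index` rather than by the raw choice index. The pattern in isolation (`orderByOutputIndex` and the sample map are illustrative stand-ins for `st.MsgOutputIx`):

```go
package main

import (
	"fmt"
	"sort"
)

// orderByOutputIndex sorts choice indices by the output_index each one was
// allocated during streaming, so done events replay in emission order.
func orderByOutputIndex(idxs []int, msgOutputIx map[int]int) []int {
	sort.Slice(idxs, func(i, j int) bool {
		return msgOutputIx[idxs[i]] < msgOutputIx[idxs[j]]
	})
	return idxs
}

func main() {
	// choice index -> output_index allocated while the stream was live
	msgOutputIx := map[int]int{0: 2, 1: 0, 2: 1}
	fmt.Println(orderByOutputIndex([]int{0, 1, 2}, msgOutputIx)) // [1 2 0]
}
```

Sorting by the allocated index matters once messages, reasoning items, and tool calls share one counter: choice 0's message may legitimately have been assigned a later `output_index` than choice 1's.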
@@ -376,7 +395,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.SetBytes(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
done, _ = sjson.SetBytes(done, "output_index", i)
|
||||
done, _ = sjson.SetBytes(done, "output_index", msgOutputIndex)
|
||||
done, _ = sjson.SetBytes(done, "content_index", 0)
|
||||
done, _ = sjson.SetBytes(done, "text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_text.done", done))
|
||||
@@ -384,14 +403,14 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", i)
|
||||
partDone, _ = sjson.SetBytes(partDone, "output_index", msgOutputIndex)
|
||||
partDone, _ = sjson.SetBytes(partDone, "content_index", 0)
|
||||
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
||||
out = append(out, emitRespEvent("response.content_part.done", partDone))
|
||||
|
||||
itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "output_index", msgOutputIndex)
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
|
||||
itemDone, _ = sjson.SetBytes(itemDone, "item.content.0.text", fullText)
|
||||
out = append(out, emitRespEvent("response.output_item.done", itemDone))
|
||||
@@ -407,43 +426,42 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
 
 	// Emit function call done events for any active function calls
 	if len(st.FuncCallIDs) > 0 {
-		idxs := make([]int, 0, len(st.FuncCallIDs))
-		for i := range st.FuncCallIDs {
-			idxs = append(idxs, i)
+		keys := make([]string, 0, len(st.FuncCallIDs))
+		for key := range st.FuncCallIDs {
+			keys = append(keys, key)
 		}
-		for i := 0; i < len(idxs); i++ {
-			for j := i + 1; j < len(idxs); j++ {
-				if idxs[j] < idxs[i] {
-					idxs[i], idxs[j] = idxs[j], idxs[i]
-				}
-			}
-		}
-		for _, i := range idxs {
-			callID := st.FuncCallIDs[i]
-			if callID == "" || st.FuncItemDone[i] {
+		sort.Slice(keys, func(i, j int) bool {
+			left := st.FuncOutputIx[keys[i]]
+			right := st.FuncOutputIx[keys[j]]
+			return left < right || (left == right && keys[i] < keys[j])
+		})
+		for _, key := range keys {
+			callID := st.FuncCallIDs[key]
+			if callID == "" || st.FuncItemDone[key] {
 				continue
 			}
+			outputIndex := st.FuncOutputIx[key]
 			args := "{}"
-			if b := st.FuncArgsBuf[i]; b != nil && b.Len() > 0 {
+			if b := st.FuncArgsBuf[key]; b != nil && b.Len() > 0 {
 				args = b.String()
 			}
 			fcDone := []byte(`{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`)
 			fcDone, _ = sjson.SetBytes(fcDone, "sequence_number", nextSeq())
 			fcDone, _ = sjson.SetBytes(fcDone, "item_id", fmt.Sprintf("fc_%s", callID))
-			fcDone, _ = sjson.SetBytes(fcDone, "output_index", i)
+			fcDone, _ = sjson.SetBytes(fcDone, "output_index", outputIndex)
 			fcDone, _ = sjson.SetBytes(fcDone, "arguments", args)
 			out = append(out, emitRespEvent("response.function_call_arguments.done", fcDone))
 
 			itemDone := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`)
 			itemDone, _ = sjson.SetBytes(itemDone, "sequence_number", nextSeq())
-			itemDone, _ = sjson.SetBytes(itemDone, "output_index", i)
+			itemDone, _ = sjson.SetBytes(itemDone, "output_index", outputIndex)
 			itemDone, _ = sjson.SetBytes(itemDone, "item.id", fmt.Sprintf("fc_%s", callID))
 			itemDone, _ = sjson.SetBytes(itemDone, "item.arguments", args)
 			itemDone, _ = sjson.SetBytes(itemDone, "item.call_id", callID)
-			itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[i])
+			itemDone, _ = sjson.SetBytes(itemDone, "item.name", st.FuncNames[key])
 			out = append(out, emitRespEvent("response.output_item.done", itemDone))
-			st.FuncItemDone[i] = true
-			st.FuncArgsDone[i] = true
+			st.FuncItemDone[key] = true
+			st.FuncArgsDone[key] = true
 		}
 	}
 	completed := []byte(`{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`)
@@ -516,28 +534,21 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
 	}
 	// Build response.output using aggregated buffers
 	outputsWrapper := []byte(`{"arr":[]}`)
+	type completedOutputItem struct {
+		index int
+		raw   []byte
+	}
+	outputItems := make([]completedOutputItem, 0, len(st.Reasonings)+len(st.MsgItemAdded)+len(st.FuncArgsBuf))
 	if len(st.Reasonings) > 0 {
 		for _, r := range st.Reasonings {
 			item := []byte(`{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`)
 			item, _ = sjson.SetBytes(item, "id", r.ReasoningID)
 			item, _ = sjson.SetBytes(item, "summary.0.text", r.ReasoningData)
-			outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
+			outputItems = append(outputItems, completedOutputItem{index: r.OutputIndex, raw: item})
 		}
 	}
-	// Append message items in ascending index order
 	if len(st.MsgItemAdded) > 0 {
-		midxs := make([]int, 0, len(st.MsgItemAdded))
 		for i := range st.MsgItemAdded {
-			midxs = append(midxs, i)
-		}
-		for i := 0; i < len(midxs); i++ {
-			for j := i + 1; j < len(midxs); j++ {
-				if midxs[j] < midxs[i] {
-					midxs[i], midxs[j] = midxs[j], midxs[i]
-				}
-			}
-		}
-		for _, i := range midxs {
 			txt := ""
 			if b := st.MsgTextBuf[i]; b != nil {
 				txt = b.String()
@@ -545,37 +556,29 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
 			item := []byte(`{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`)
 			item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
 			item, _ = sjson.SetBytes(item, "content.0.text", txt)
-			outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
+			outputItems = append(outputItems, completedOutputItem{index: st.MsgOutputIx[i], raw: item})
 		}
 	}
 	if len(st.FuncArgsBuf) > 0 {
-		idxs := make([]int, 0, len(st.FuncArgsBuf))
-		for i := range st.FuncArgsBuf {
-			idxs = append(idxs, i)
-		}
-		// small-N sort without extra imports
-		for i := 0; i < len(idxs); i++ {
-			for j := i + 1; j < len(idxs); j++ {
-				if idxs[j] < idxs[i] {
-					idxs[i], idxs[j] = idxs[j], idxs[i]
-				}
-			}
-		}
-		for _, i := range idxs {
+		for key := range st.FuncArgsBuf {
 			args := ""
-			if b := st.FuncArgsBuf[i]; b != nil {
+			if b := st.FuncArgsBuf[key]; b != nil {
 				args = b.String()
 			}
-			callID := st.FuncCallIDs[i]
-			name := st.FuncNames[i]
+			callID := st.FuncCallIDs[key]
+			name := st.FuncNames[key]
 			item := []byte(`{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`)
 			item, _ = sjson.SetBytes(item, "id", fmt.Sprintf("fc_%s", callID))
 			item, _ = sjson.SetBytes(item, "arguments", args)
 			item, _ = sjson.SetBytes(item, "call_id", callID)
 			item, _ = sjson.SetBytes(item, "name", name)
-			outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item)
+			outputItems = append(outputItems, completedOutputItem{index: st.FuncOutputIx[key], raw: item})
 		}
 	}
+	sort.Slice(outputItems, func(i, j int) bool { return outputItems[i].index < outputItems[j].index })
+	for _, item := range outputItems {
+		outputsWrapper, _ = sjson.SetRawBytes(outputsWrapper, "arr.-1", item.raw)
+	}
 	if gjson.GetBytes(outputsWrapper, "arr.#").Int() > 0 {
 		completed, _ = sjson.SetRawBytes(completed, "response.output", []byte(gjson.GetBytes(outputsWrapper, "arr").Raw))
 	}
@@ -0,0 +1,305 @@
package responses

import (
	"context"
	"strings"
	"testing"

	"github.com/tidwall/gjson"
)

func parseOpenAIResponsesSSEEvent(t *testing.T, chunk []byte) (string, gjson.Result) {
	t.Helper()

	lines := strings.Split(string(chunk), "\n")
	if len(lines) < 2 {
		t.Fatalf("unexpected SSE chunk: %q", chunk)
	}

	event := strings.TrimSpace(strings.TrimPrefix(lines[0], "event:"))
	dataLine := strings.TrimSpace(strings.TrimPrefix(lines[1], "data:"))
	if !gjson.Valid(dataLine) {
		t.Fatalf("invalid SSE data JSON: %q", dataLine)
	}
	return event, gjson.Parse(dataLine)
}

func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultipleToolCallsRemainSeparate(t *testing.T) {
	in := []string{
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\",\"limit\":400,\"offset\":1}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.{yml,yaml}\"}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_test","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
	}

	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)

	var param any
	var out [][]byte
	for _, line := range in {
		out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
	}

	addedNames := map[string]string{}
	doneArgs := map[string]string{}
	doneNames := map[string]string{}
	outputItems := map[string]gjson.Result{}

	for _, chunk := range out {
		ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
		switch ev {
		case "response.output_item.added":
			if data.Get("item.type").String() != "function_call" {
				continue
			}
			addedNames[data.Get("item.call_id").String()] = data.Get("item.name").String()
		case "response.output_item.done":
			if data.Get("item.type").String() != "function_call" {
				continue
			}
			callID := data.Get("item.call_id").String()
			doneArgs[callID] = data.Get("item.arguments").String()
			doneNames[callID] = data.Get("item.name").String()
		case "response.completed":
			output := data.Get("response.output")
			for _, item := range output.Array() {
				if item.Get("type").String() == "function_call" {
					outputItems[item.Get("call_id").String()] = item
				}
			}
		}
	}

	if len(addedNames) != 2 {
		t.Fatalf("expected 2 function_call added events, got %d", len(addedNames))
	}
	if len(doneArgs) != 2 {
		t.Fatalf("expected 2 function_call done events, got %d", len(doneArgs))
	}

	if addedNames["call_read"] != "read" {
		t.Fatalf("unexpected added name for call_read: %q", addedNames["call_read"])
	}
	if addedNames["call_glob"] != "glob" {
		t.Fatalf("unexpected added name for call_glob: %q", addedNames["call_glob"])
	}

	if !gjson.Valid(doneArgs["call_read"]) {
		t.Fatalf("invalid JSON args for call_read: %q", doneArgs["call_read"])
	}
	if !gjson.Valid(doneArgs["call_glob"]) {
		t.Fatalf("invalid JSON args for call_glob: %q", doneArgs["call_glob"])
	}
	if strings.Contains(doneArgs["call_read"], "}{") {
		t.Fatalf("call_read args were concatenated: %q", doneArgs["call_read"])
	}
	if strings.Contains(doneArgs["call_glob"], "}{") {
		t.Fatalf("call_glob args were concatenated: %q", doneArgs["call_glob"])
	}

	if doneNames["call_read"] != "read" {
		t.Fatalf("unexpected done name for call_read: %q", doneNames["call_read"])
	}
	if doneNames["call_glob"] != "glob" {
		t.Fatalf("unexpected done name for call_glob: %q", doneNames["call_glob"])
	}

	if got := gjson.Get(doneArgs["call_read"], "filePath").String(); got != `C:\repo` {
		t.Fatalf("unexpected filePath for call_read: %q", got)
	}
	if got := gjson.Get(doneArgs["call_glob"], "path").String(); got != `C:\repo` {
		t.Fatalf("unexpected path for call_glob: %q", got)
	}
	if got := gjson.Get(doneArgs["call_glob"], "pattern").String(); got != "*.{yml,yaml}" {
		t.Fatalf("unexpected pattern for call_glob: %q", got)
	}

	if len(outputItems) != 2 {
		t.Fatalf("expected 2 function_call items in response.output, got %d", len(outputItems))
	}
	if outputItems["call_read"].Get("name").String() != "read" {
		t.Fatalf("unexpected response.output name for call_read: %q", outputItems["call_read"].Get("name").String())
	}
	if outputItems["call_glob"].Get("name").String() != "glob" {
		t.Fatalf("unexpected response.output name for call_glob: %q", outputItems["call_glob"].Get("name").String())
	}
}

func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MultiChoiceToolCallsUseDistinctOutputIndexes(t *testing.T) {
	in := []string{
		`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice0","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_multi_choice","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
	}

	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)

	var param any
	var out [][]byte
	for _, line := range in {
		out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
	}

	type fcEvent struct {
		outputIndex int64
		name        string
		arguments   string
	}

	added := map[string]fcEvent{}
	done := map[string]fcEvent{}

	for _, chunk := range out {
		ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
		switch ev {
		case "response.output_item.added":
			if data.Get("item.type").String() != "function_call" {
				continue
			}
			callID := data.Get("item.call_id").String()
			added[callID] = fcEvent{
				outputIndex: data.Get("output_index").Int(),
				name:        data.Get("item.name").String(),
			}
		case "response.output_item.done":
			if data.Get("item.type").String() != "function_call" {
				continue
			}
			callID := data.Get("item.call_id").String()
			done[callID] = fcEvent{
				outputIndex: data.Get("output_index").Int(),
				name:        data.Get("item.name").String(),
				arguments:   data.Get("item.arguments").String(),
			}
		}
	}

	if len(added) != 2 {
		t.Fatalf("expected 2 function_call added events, got %d", len(added))
	}
	if len(done) != 2 {
		t.Fatalf("expected 2 function_call done events, got %d", len(done))
	}

	if added["call_choice0"].name != "glob" {
		t.Fatalf("unexpected added name for call_choice0: %q", added["call_choice0"].name)
	}
	if added["call_choice1"].name != "read" {
		t.Fatalf("unexpected added name for call_choice1: %q", added["call_choice1"].name)
	}
	if added["call_choice0"].outputIndex == added["call_choice1"].outputIndex {
		t.Fatalf("expected distinct output indexes for different choices, both got %d", added["call_choice0"].outputIndex)
	}

	if !gjson.Valid(done["call_choice0"].arguments) {
		t.Fatalf("invalid JSON args for call_choice0: %q", done["call_choice0"].arguments)
	}
	if !gjson.Valid(done["call_choice1"].arguments) {
		t.Fatalf("invalid JSON args for call_choice1: %q", done["call_choice1"].arguments)
	}
	if done["call_choice0"].outputIndex == done["call_choice1"].outputIndex {
		t.Fatalf("expected distinct done output indexes for different choices, both got %d", done["call_choice0"].outputIndex)
	}
	if done["call_choice0"].name != "glob" {
		t.Fatalf("unexpected done name for call_choice0: %q", done["call_choice0"].name)
	}
	if done["call_choice1"].name != "read" {
		t.Fatalf("unexpected done name for call_choice1: %q", done["call_choice1"].name)
	}
}

func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_MixedMessageAndToolUseDistinctOutputIndexes(t *testing.T) {
	in := []string{
		`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello","reasoning_content":null,"tool_calls":null},"finish_reason":null},{"index":1,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_choice1","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_mixed","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"stop"},{"index":1,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
	}

	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)

	var param any
	var out [][]byte
	for _, line := range in {
		out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
	}

	var messageOutputIndex int64 = -1
	var toolOutputIndex int64 = -1

	for _, chunk := range out {
		ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
		if ev != "response.output_item.added" {
			continue
		}
		switch data.Get("item.type").String() {
		case "message":
			if data.Get("item.id").String() == "msg_resp_mixed_0" {
				messageOutputIndex = data.Get("output_index").Int()
			}
		case "function_call":
			if data.Get("item.call_id").String() == "call_choice1" {
				toolOutputIndex = data.Get("output_index").Int()
			}
		}
	}

	if messageOutputIndex < 0 {
		t.Fatal("did not find message output index")
	}
	if toolOutputIndex < 0 {
		t.Fatal("did not find tool output index")
	}
	if messageOutputIndex == toolOutputIndex {
		t.Fatalf("expected distinct output indexes for message and tool call, both got %d", messageOutputIndex)
	}
}

func TestConvertOpenAIChatCompletionsResponseToOpenAIResponses_FunctionCallDoneAndCompletedOutputStayAscending(t *testing.T) {
	in := []string{
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":0,"id":"call_glob","type":"function","function":{"name":"glob","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"C:\\\\repo\",\"pattern\":\"*.go\"}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"index":1,"id":"call_read","type":"function","function":{"name":"read","arguments":""}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"{\"filePath\":\"C:\\\\repo\\\\README.md\",\"limit\":20,\"offset\":1}"}}]},"finish_reason":null}]}`,
		`data: {"id":"resp_order","object":"chat.completion.chunk","created":1773896263,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":"tool_calls"}],"usage":{"completion_tokens":10,"total_tokens":20,"prompt_tokens":10}}`,
	}

	request := []byte(`{"model":"gpt-5.4","tool_choice":"auto","parallel_tool_calls":true}`)

	var param any
	var out [][]byte
	for _, line := range in {
		out = append(out, ConvertOpenAIChatCompletionsResponseToOpenAIResponses(context.Background(), "model", request, request, []byte(line), &param)...)
	}

	var doneIndexes []int64
	var completedOrder []string

	for _, chunk := range out {
		ev, data := parseOpenAIResponsesSSEEvent(t, chunk)
		switch ev {
		case "response.output_item.done":
			if data.Get("item.type").String() == "function_call" {
				doneIndexes = append(doneIndexes, data.Get("output_index").Int())
			}
		case "response.completed":
			for _, item := range data.Get("response.output").Array() {
				if item.Get("type").String() == "function_call" {
					completedOrder = append(completedOrder, item.Get("call_id").String())
				}
			}
		}
	}

	if len(doneIndexes) != 2 {
		t.Fatalf("expected 2 function_call done indexes, got %d", len(doneIndexes))
	}
	if doneIndexes[0] >= doneIndexes[1] {
		t.Fatalf("expected ascending done output indexes, got %v", doneIndexes)
	}
	if len(completedOrder) != 2 {
		t.Fatalf("expected 2 function_call items in completed output, got %d", len(completedOrder))
	}
	if completedOrder[0] != "call_glob" || completedOrder[1] != "call_read" {
		t.Fatalf("unexpected completed function_call order: %v", completedOrder)
	}
}
@@ -923,8 +923,10 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
 		auth.Index = existing.Index
 		auth.indexAssigned = existing.indexAssigned
 	}
-	if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
-		auth.ModelStates = existing.ModelStates
+	if !existing.Disabled && existing.Status != StatusDisabled && !auth.Disabled && auth.Status != StatusDisabled {
+		if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
+			auth.ModelStates = existing.ModelStates
+		}
 	}
 }
 auth.EnsureIndex()
@@ -47,3 +47,158 @@ func TestManager_Update_PreservesModelStates(t *testing.T) {
		t.Fatalf("expected BackoffLevel to be %d, got %d", backoffLevel, state.Quota.BackoffLevel)
	}
}

func TestManager_Update_DisabledExistingDoesNotInheritModelStates(t *testing.T) {
	m := NewManager(nil, nil, nil)

	// Register a disabled auth with existing ModelStates.
	if _, err := m.Register(context.Background(), &Auth{
		ID:       "auth-disabled",
		Provider: "claude",
		Disabled: true,
		Status:   StatusDisabled,
		ModelStates: map[string]*ModelState{
			"stale-model": {
				Quota: QuotaState{BackoffLevel: 5},
			},
		},
	}); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	// Update with empty ModelStates — should NOT inherit stale states.
	if _, err := m.Update(context.Background(), &Auth{
		ID:       "auth-disabled",
		Provider: "claude",
		Disabled: true,
		Status:   StatusDisabled,
	}); err != nil {
		t.Fatalf("update auth: %v", err)
	}

	updated, ok := m.GetByID("auth-disabled")
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	if len(updated.ModelStates) != 0 {
		t.Fatalf("expected disabled auth NOT to inherit ModelStates, got %d entries", len(updated.ModelStates))
	}
}

func TestManager_Update_ActiveToDisabledDoesNotInheritModelStates(t *testing.T) {
	m := NewManager(nil, nil, nil)

	// Register an active auth with ModelStates (simulates existing live auth).
	if _, err := m.Register(context.Background(), &Auth{
		ID:       "auth-a2d",
		Provider: "claude",
		Status:   StatusActive,
		ModelStates: map[string]*ModelState{
			"stale-model": {
				Quota: QuotaState{BackoffLevel: 9},
			},
		},
	}); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	// File watcher deletes config → synthesizes Disabled=true auth → Update.
	// Even though existing is active, incoming auth is disabled → skip inheritance.
	if _, err := m.Update(context.Background(), &Auth{
		ID:       "auth-a2d",
		Provider: "claude",
		Disabled: true,
		Status:   StatusDisabled,
	}); err != nil {
		t.Fatalf("update auth: %v", err)
	}

	updated, ok := m.GetByID("auth-a2d")
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	if len(updated.ModelStates) != 0 {
		t.Fatalf("expected active→disabled transition NOT to inherit ModelStates, got %d entries", len(updated.ModelStates))
	}
}

func TestManager_Update_DisabledToActiveDoesNotInheritStaleModelStates(t *testing.T) {
	m := NewManager(nil, nil, nil)

	// Register a disabled auth with stale ModelStates.
	if _, err := m.Register(context.Background(), &Auth{
		ID:       "auth-d2a",
		Provider: "claude",
		Disabled: true,
		Status:   StatusDisabled,
		ModelStates: map[string]*ModelState{
			"stale-model": {
				Quota: QuotaState{BackoffLevel: 4},
			},
		},
	}); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	// Re-enable: incoming auth is active, existing is disabled → skip inheritance.
	if _, err := m.Update(context.Background(), &Auth{
		ID:       "auth-d2a",
		Provider: "claude",
		Status:   StatusActive,
	}); err != nil {
		t.Fatalf("update auth: %v", err)
	}

	updated, ok := m.GetByID("auth-d2a")
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	if len(updated.ModelStates) != 0 {
		t.Fatalf("expected disabled→active transition NOT to inherit stale ModelStates, got %d entries", len(updated.ModelStates))
	}
}

func TestManager_Update_ActiveInheritsModelStates(t *testing.T) {
	m := NewManager(nil, nil, nil)

	model := "active-model"
	backoffLevel := 3

	// Register an active auth with ModelStates.
	if _, err := m.Register(context.Background(), &Auth{
		ID:       "auth-active",
		Provider: "claude",
		Status:   StatusActive,
		ModelStates: map[string]*ModelState{
			model: {
				Quota: QuotaState{BackoffLevel: backoffLevel},
			},
		},
	}); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	// Update with empty ModelStates — both sides active → SHOULD inherit.
	if _, err := m.Update(context.Background(), &Auth{
		ID:       "auth-active",
		Provider: "claude",
		Status:   StatusActive,
	}); err != nil {
		t.Fatalf("update auth: %v", err)
	}

	updated, ok := m.GetByID("auth-active")
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	if len(updated.ModelStates) == 0 {
		t.Fatalf("expected active auth to inherit ModelStates")
	}
	state := updated.ModelStates[model]
	if state == nil {
		t.Fatalf("expected model state to be present")
	}
	if state.Quota.BackoffLevel != backoffLevel {
		t.Fatalf("expected BackoffLevel to be %d, got %d", backoffLevel, state.Quota.BackoffLevel)
	}
}
@@ -286,10 +286,12 @@ func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.A
 	var err error
 	if existing, ok := s.coreManager.GetByID(auth.ID); ok {
 		auth.CreatedAt = existing.CreatedAt
-		auth.LastRefreshedAt = existing.LastRefreshedAt
-		auth.NextRefreshAfter = existing.NextRefreshAfter
-		if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
-			auth.ModelStates = existing.ModelStates
+		if !existing.Disabled && existing.Status != coreauth.StatusDisabled && !auth.Disabled && auth.Status != coreauth.StatusDisabled {
+			auth.LastRefreshedAt = existing.LastRefreshedAt
+			auth.NextRefreshAfter = existing.NextRefreshAfter
+			if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
+				auth.ModelStates = existing.ModelStates
+			}
 		}
 		op = "update"
 		_, err = s.coreManager.Update(ctx, auth)
sdk/cliproxy/service_stale_state_test.go (new file, +85)
@@ -0,0 +1,85 @@
|
||||
package cliproxy

import (
	"context"
	"testing"
	"time"

	"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
	coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)

func TestServiceApplyCoreAuthAddOrUpdate_DeleteReAddDoesNotInheritStaleRuntimeState(t *testing.T) {
	service := &Service{
		cfg:         &config.Config{},
		coreManager: coreauth.NewManager(nil, nil, nil),
	}

	authID := "service-stale-state-auth"
	modelID := "stale-model"
	lastRefreshedAt := time.Date(2026, time.March, 1, 8, 0, 0, 0, time.UTC)
	nextRefreshAfter := lastRefreshedAt.Add(30 * time.Minute)

	t.Cleanup(func() {
		GlobalModelRegistry().UnregisterClient(authID)
	})

	service.applyCoreAuthAddOrUpdate(context.Background(), &coreauth.Auth{
		ID:               authID,
		Provider:         "claude",
		Status:           coreauth.StatusActive,
		LastRefreshedAt:  lastRefreshedAt,
		NextRefreshAfter: nextRefreshAfter,
		ModelStates: map[string]*coreauth.ModelState{
			modelID: {
				Quota: coreauth.QuotaState{BackoffLevel: 7},
			},
		},
	})

	service.applyCoreAuthRemoval(context.Background(), authID)

	disabled, ok := service.coreManager.GetByID(authID)
	if !ok || disabled == nil {
		t.Fatalf("expected disabled auth after removal")
	}
	if !disabled.Disabled || disabled.Status != coreauth.StatusDisabled {
		t.Fatalf("expected disabled auth after removal, got disabled=%v status=%v", disabled.Disabled, disabled.Status)
	}
	if disabled.LastRefreshedAt.IsZero() {
		t.Fatalf("expected disabled auth to still carry prior LastRefreshedAt for regression setup")
	}
	if disabled.NextRefreshAfter.IsZero() {
		t.Fatalf("expected disabled auth to still carry prior NextRefreshAfter for regression setup")
	}
	if len(disabled.ModelStates) == 0 {
		t.Fatalf("expected disabled auth to still carry prior ModelStates for regression setup")
	}

	service.applyCoreAuthAddOrUpdate(context.Background(), &coreauth.Auth{
		ID:       authID,
		Provider: "claude",
		Status:   coreauth.StatusActive,
	})

	updated, ok := service.coreManager.GetByID(authID)
	if !ok || updated == nil {
		t.Fatalf("expected re-added auth to be present")
	}
	if updated.Disabled {
		t.Fatalf("expected re-added auth to be active")
	}
	if !updated.LastRefreshedAt.IsZero() {
		t.Fatalf("expected LastRefreshedAt to reset on delete -> re-add, got %v", updated.LastRefreshedAt)
	}
	if !updated.NextRefreshAfter.IsZero() {
		t.Fatalf("expected NextRefreshAfter to reset on delete -> re-add, got %v", updated.NextRefreshAfter)
	}
	if len(updated.ModelStates) != 0 {
		t.Fatalf("expected ModelStates to reset on delete -> re-add, got %d entries", len(updated.ModelStates))
	}
	if models := registry.GetGlobalRegistry().GetModelsForClient(authID); len(models) == 0 {
		t.Fatalf("expected re-added auth to re-register models in global registry")
	}
}
@@ -68,14 +68,18 @@ func Parse(raw string) (Setting, error) {
 	}
 }
 
+func cloneDefaultTransport() *http.Transport {
+	if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
+		return transport.Clone()
+	}
+	return &http.Transport{}
+}
+
 // NewDirectTransport returns a transport that bypasses environment proxies.
 func NewDirectTransport() *http.Transport {
-	if transport, ok := http.DefaultTransport.(*http.Transport); ok && transport != nil {
-		clone := transport.Clone()
-		clone.Proxy = nil
-		return clone
-	}
-	return &http.Transport{Proxy: nil}
+	clone := cloneDefaultTransport()
+	clone.Proxy = nil
+	return clone
 }
 
 // BuildHTTPTransport constructs an HTTP transport for the provided proxy setting.
@@ -102,14 +106,16 @@ func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
 		if errSOCKS5 != nil {
 			return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
 		}
-		return &http.Transport{
-			Proxy: nil,
-			DialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
-				return dialer.Dial(network, addr)
-			},
-		}, setting.Mode, nil
+		transport := cloneDefaultTransport()
+		transport.Proxy = nil
+		transport.DialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
+			return dialer.Dial(network, addr)
+		}
+		return transport, setting.Mode, nil
 	}
-	return &http.Transport{Proxy: http.ProxyURL(setting.URL)}, setting.Mode, nil
+	transport := cloneDefaultTransport()
+	transport.Proxy = http.ProxyURL(setting.URL)
+	return transport, setting.Mode, nil
 default:
 	return nil, setting.Mode, nil
 }
@@ -5,6 +5,16 @@ import (
 	"testing"
 )
 
+func mustDefaultTransport(t *testing.T) *http.Transport {
+	t.Helper()
+
+	transport, ok := http.DefaultTransport.(*http.Transport)
+	if !ok || transport == nil {
+		t.Fatal("http.DefaultTransport is not an *http.Transport")
+	}
+	return transport
+}
+
 func TestParse(t *testing.T) {
 	t.Parallel()
 
@@ -86,4 +96,44 @@ func TestBuildHTTPTransportHTTPProxy(t *testing.T) {
 	if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" {
 		t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", proxyURL)
 	}
+
+	defaultTransport := mustDefaultTransport(t)
+	if transport.ForceAttemptHTTP2 != defaultTransport.ForceAttemptHTTP2 {
+		t.Fatalf("ForceAttemptHTTP2 = %v, want %v", transport.ForceAttemptHTTP2, defaultTransport.ForceAttemptHTTP2)
+	}
+	if transport.IdleConnTimeout != defaultTransport.IdleConnTimeout {
+		t.Fatalf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaultTransport.IdleConnTimeout)
+	}
+	if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout {
+		t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
+	}
 }
+
+func TestBuildHTTPTransportSOCKS5ProxyInheritsDefaultTransportSettings(t *testing.T) {
+	t.Parallel()
+
+	transport, mode, errBuild := BuildHTTPTransport("socks5://proxy.example.com:1080")
+	if errBuild != nil {
+		t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
+	}
+	if mode != ModeProxy {
+		t.Fatalf("mode = %d, want %d", mode, ModeProxy)
+	}
+	if transport == nil {
+		t.Fatal("expected transport, got nil")
+	}
+	if transport.Proxy != nil {
+		t.Fatal("expected SOCKS5 transport to bypass http proxy function")
+	}
+
+	defaultTransport := mustDefaultTransport(t)
+	if transport.ForceAttemptHTTP2 != defaultTransport.ForceAttemptHTTP2 {
+		t.Fatalf("ForceAttemptHTTP2 = %v, want %v", transport.ForceAttemptHTTP2, defaultTransport.ForceAttemptHTTP2)
+	}
+	if transport.IdleConnTimeout != defaultTransport.IdleConnTimeout {
+		t.Fatalf("IdleConnTimeout = %v, want %v", transport.IdleConnTimeout, defaultTransport.IdleConnTimeout)
+	}
+	if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout {
+		t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout)
+	}
+}