Merge branch 'router-for-me:main' into my-fix

This commit is contained in:
AhDEV
2026-05-06 16:41:14 +08:00
committed by GitHub
81 changed files with 4470 additions and 1903 deletions
@@ -0,0 +1,107 @@
package management
import (
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
type apiKeyUsageEntry struct {
Success int64 `json:"success"`
Failed int64 `json:"failed"`
RecentRequests []coreauth.RecentRequestBucket `json:"recent_requests"`
}
func mergeRecentRequestBuckets(dst, src []coreauth.RecentRequestBucket) []coreauth.RecentRequestBucket {
if len(dst) == 0 {
return src
}
if len(src) == 0 {
return dst
}
if len(dst) != len(src) {
n := len(dst)
if len(src) < n {
n = len(src)
}
for i := 0; i < n; i++ {
dst[i].Success += src[i].Success
dst[i].Failed += src[i].Failed
}
return dst
}
for i := range dst {
dst[i].Success += src[i].Success
dst[i].Failed += src[i].Failed
}
return dst
}
// GetAPIKeyUsage returns recent request buckets for all in-memory api_key auths,
// grouped by provider and keyed by "base_url|api_key".
func (h *Handler) GetAPIKeyUsage(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler not initialized"})
return
}
h.mu.Lock()
manager := h.authManager
h.mu.Unlock()
if manager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
now := time.Now()
out := make(map[string]map[string]apiKeyUsageEntry)
for _, auth := range manager.List() {
if auth == nil {
continue
}
kind, apiKey := auth.AccountInfo()
if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
continue
}
apiKey = strings.TrimSpace(apiKey)
if apiKey == "" {
continue
}
baseURL := ""
if auth.Attributes != nil {
baseURL = strings.TrimSpace(auth.Attributes["base_url"])
if baseURL == "" {
baseURL = strings.TrimSpace(auth.Attributes["base-url"])
}
}
compositeKey := baseURL + "|" + apiKey
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
if provider == "" {
provider = "unknown"
}
recent := auth.RecentRequestsSnapshot(now)
providerBucket, ok := out[provider]
if !ok {
providerBucket = make(map[string]apiKeyUsageEntry)
out[provider] = providerBucket
}
if existing, exists := providerBucket[compositeKey]; exists {
existing.Success += auth.Success
existing.Failed += auth.Failed
existing.RecentRequests = mergeRecentRequestBuckets(existing.RecentRequests, recent)
providerBucket[compositeKey] = existing
continue
}
providerBucket[compositeKey] = apiKeyUsageEntry{
Success: auth.Success,
Failed: auth.Failed,
RecentRequests: recent,
}
}
c.JSON(http.StatusOK, out)
}
@@ -0,0 +1,95 @@
package management
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func sumRecentRequestBuckets(buckets []coreauth.RecentRequestBucket) (int64, int64) {
var success int64
var failed int64
for _, bucket := range buckets {
success += bucket.Success
failed += bucket.Failed
}
return success, failed
}
func TestGetAPIKeyUsage_GroupsByProviderAndAPIKey(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
manager := coreauth.NewManager(nil, nil, nil)
if _, err := manager.Register(context.Background(), &coreauth.Auth{
ID: "codex-auth",
Provider: "codex",
Attributes: map[string]string{
"api_key": "codex-key",
"base_url": "https://codex.example.com",
},
}); err != nil {
t.Fatalf("register codex auth: %v", err)
}
if _, err := manager.Register(context.Background(), &coreauth.Auth{
ID: "claude-auth",
Provider: "claude",
Attributes: map[string]string{
"api_key": "claude-key",
"base_url": "https://claude.example.com",
},
}); err != nil {
t.Fatalf("register claude auth: %v", err)
}
manager.MarkResult(context.Background(), coreauth.Result{AuthID: "codex-auth", Provider: "codex", Model: "gpt-5", Success: true})
manager.MarkResult(context.Background(), coreauth.Result{AuthID: "codex-auth", Provider: "codex", Model: "gpt-5", Success: false})
manager.MarkResult(context.Background(), coreauth.Result{AuthID: "claude-auth", Provider: "claude", Model: "claude-4", Success: true})
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodGet, "/v0/management/api-key-usage", nil)
ginCtx.Request = req
h.GetAPIKeyUsage(ginCtx)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var payload map[string]map[string]apiKeyUsageEntry
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("decode payload: %v", err)
}
codexEntry := payload["codex"]["https://codex.example.com|codex-key"]
if codexEntry.Success != 1 || codexEntry.Failed != 1 {
t.Fatalf("codex totals = %d/%d, want 1/1", codexEntry.Success, codexEntry.Failed)
}
if len(codexEntry.RecentRequests) != 20 {
t.Fatalf("codex buckets len = %d, want 20", len(codexEntry.RecentRequests))
}
codexSuccess, codexFailed := sumRecentRequestBuckets(codexEntry.RecentRequests)
if codexSuccess != 1 || codexFailed != 1 {
t.Fatalf("codex totals = %d/%d, want 1/1", codexSuccess, codexFailed)
}
claudeEntry := payload["claude"]["https://claude.example.com|claude-key"]
if claudeEntry.Success != 1 || claudeEntry.Failed != 0 {
t.Fatalf("claude totals = %d/%d, want 1/0", claudeEntry.Success, claudeEntry.Failed)
}
if len(claudeEntry.RecentRequests) != 20 {
t.Fatalf("claude buckets len = %d, want 20", len(claudeEntry.RecentRequests))
}
claudeSuccess, claudeFailed := sumRecentRequestBuckets(claudeEntry.RecentRequests)
if claudeSuccess != 1 || claudeFailed != 0 {
t.Fatalf("claude totals = %d/%d, want 1/0", claudeSuccess, claudeFailed)
}
}
+6 -16
View File
@@ -388,6 +388,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
"source": "memory",
"size": int64(0),
}
entry["success"] = auth.Success
entry["failed"] = auth.Failed
entry["recent_requests"] = auth.RecentRequestsSnapshot(time.Now())
if email := authEmail(auth); email != "" {
entry["email"] = email
}
@@ -2395,23 +2398,10 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
finalProjectID := projectID
if responseProjectID != "" {
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
strings.EqualFold(tierID, "FREE") ||
strings.EqualFold(tierID, "LEGACY")
if isFreeUser {
// For free users, use backend project ID for preview model access
log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID)
log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID)
finalProjectID = responseProjectID
} else {
// Pro users: keep requested project ID (original behavior)
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
}
} else {
finalProjectID = responseProjectID
log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID)
log.Infof("Using backend project ID: %s", responseProjectID)
}
finalProjectID = responseProjectID
}
storage.ProjectID = strings.TrimSpace(finalProjectID)
@@ -0,0 +1,94 @@
package management
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestListAuthFiles_IncludesRecentRequestsBuckets(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "")
gin.SetMode(gin.TestMode)
manager := coreauth.NewManager(nil, nil, nil)
record := &coreauth.Auth{
ID: "runtime-only-auth-1",
Provider: "codex",
Attributes: map[string]string{
"runtime_only": "true",
},
Metadata: map[string]any{
"type": "codex",
},
}
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
t.Fatalf("failed to register auth record: %v", errRegister)
}
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
h.tokenStore = &memoryAuthStore{}
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
req := httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil)
ginCtx.Request = req
h.ListAuthFiles(ginCtx)
if rec.Code != http.StatusOK {
t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
}
var payload map[string]any
if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil {
t.Fatalf("failed to decode list payload: %v", errUnmarshal)
}
filesRaw, ok := payload["files"].([]any)
if !ok {
t.Fatalf("expected files array, payload: %#v", payload)
}
if len(filesRaw) != 1 {
t.Fatalf("expected 1 auth entry, got %d", len(filesRaw))
}
fileEntry, ok := filesRaw[0].(map[string]any)
if !ok {
t.Fatalf("expected file entry object, got %#v", filesRaw[0])
}
if _, ok := fileEntry["success"].(float64); !ok {
t.Fatalf("expected success number, got %#v", fileEntry["success"])
}
if _, ok := fileEntry["failed"].(float64); !ok {
t.Fatalf("expected failed number, got %#v", fileEntry["failed"])
}
recentRaw, ok := fileEntry["recent_requests"].([]any)
if !ok {
t.Fatalf("expected recent_requests array, got %#v", fileEntry["recent_requests"])
}
if len(recentRaw) != 20 {
t.Fatalf("expected 20 recent_requests buckets, got %d", len(recentRaw))
}
for idx, item := range recentRaw {
bucket, ok := item.(map[string]any)
if !ok {
t.Fatalf("expected bucket object at %d, got %#v", idx, item)
}
if _, ok := bucket["time"].(string); !ok {
t.Fatalf("expected bucket time string at %d, got %#v", idx, bucket["time"])
}
if _, ok := bucket["success"].(float64); !ok {
t.Fatalf("expected bucket success number at %d, got %#v", idx, bucket["success"])
}
if _, ok := bucket["failed"].(float64); !ok {
t.Fatalf("expected bucket failed number at %d, got %#v", idx, bucket["failed"])
}
}
}
@@ -15,7 +15,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"golang.org/x/crypto/bcrypt"
@@ -41,7 +40,6 @@ type Handler struct {
attemptsMu sync.Mutex
failedAttempts map[string]*attemptInfo // keyed by client IP
authManager *coreauth.Manager
usageStats *usage.RequestStatistics
tokenStore coreauth.Store
localPassword string
allowRemoteOverride bool
@@ -60,7 +58,6 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
configFilePath: configFilePath,
failedAttempts: make(map[string]*attemptInfo),
authManager: manager,
usageStats: usage.GetRequestStatistics(),
tokenStore: sdkAuth.GetTokenStore(),
allowRemoteOverride: envSecret != "",
envSecret: envSecret,
@@ -124,9 +121,6 @@ func (h *Handler) SetAuthManager(manager *coreauth.Manager) {
h.mu.Unlock()
}
// SetUsageStatistics allows replacing the usage statistics reference.
func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats }
// SetLocalPassword configures the runtime-local password accepted for localhost requests.
func (h *Handler) SetLocalPassword(password string) { h.localPassword = password }
+33 -57
View File
@@ -2,78 +2,54 @@ package management
import (
"encoding/json"
"errors"
"net/http"
"time"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
)
type usageExportPayload struct {
Version int `json:"version"`
ExportedAt time.Time `json:"exported_at"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
type usageQueueRecord []byte
type usageImportPayload struct {
Version int `json:"version"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
// GetUsageStatistics returns the in-memory request statistics snapshot.
func (h *Handler) GetUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
func (r usageQueueRecord) MarshalJSON() ([]byte, error) {
if json.Valid(r) {
return append([]byte(nil), r...), nil
}
c.JSON(http.StatusOK, gin.H{
"usage": snapshot,
"failed_requests": snapshot.FailureCount,
})
return json.Marshal(string(r))
}
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
}
c.JSON(http.StatusOK, usageExportPayload{
Version: 1,
ExportedAt: time.Now().UTC(),
Usage: snapshot,
})
}
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
if h == nil || h.usageStats == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
// GetUsageQueue pops queued usage records from the usage queue.
func (h *Handler) GetUsageQueue(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
data, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
count, errCount := parseUsageQueueCount(c.Query("count"))
if errCount != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": errCount.Error()})
return
}
var payload usageImportPayload
if err := json.Unmarshal(data, &payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
return
}
if payload.Version != 0 && payload.Version != 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
return
items := redisqueue.PopOldest(count)
records := make([]usageQueueRecord, 0, len(items))
for _, item := range items {
records = append(records, usageQueueRecord(append([]byte(nil), item...)))
}
result := h.usageStats.MergeSnapshot(payload.Usage)
snapshot := h.usageStats.Snapshot()
c.JSON(http.StatusOK, gin.H{
"added": result.Added,
"skipped": result.Skipped,
"total_requests": snapshot.TotalRequests,
"failed_requests": snapshot.FailureCount,
})
c.JSON(http.StatusOK, records)
}
func parseUsageQueueCount(value string) (int, error) {
value = strings.TrimSpace(value)
if value == "" {
return 1, nil
}
count, errCount := strconv.Atoi(value)
if errCount != nil || count <= 0 {
return 0, errors.New("count must be a positive integer")
}
return count, nil
}
@@ -0,0 +1,98 @@
package management
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
)
func TestGetUsageQueuePopsRequestedRecords(t *testing.T) {
gin.SetMode(gin.TestMode)
withManagementUsageQueue(t, func() {
redisqueue.Enqueue([]byte(`{"id":1}`))
redisqueue.Enqueue([]byte(`{"id":2}`))
redisqueue.Enqueue([]byte(`{"id":3}`))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
h := &Handler{}
h.GetUsageQueue(ginCtx)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusOK, rec.Body.String())
}
var payload []json.RawMessage
if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil {
t.Fatalf("unmarshal response: %v", errUnmarshal)
}
if len(payload) != 2 {
t.Fatalf("response records = %d, want 2", len(payload))
}
requireRecordID(t, payload[0], 1)
requireRecordID(t, payload[1], 2)
remaining := redisqueue.PopOldest(10)
if len(remaining) != 1 || string(remaining[0]) != `{"id":3}` {
t.Fatalf("remaining queue = %q, want third item only", remaining)
}
})
}
func TestGetUsageQueueInvalidCountDoesNotPop(t *testing.T) {
gin.SetMode(gin.TestMode)
withManagementUsageQueue(t, func() {
redisqueue.Enqueue([]byte(`{"id":1}`))
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=0", nil)
h := &Handler{}
h.GetUsageQueue(ginCtx)
if rec.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
}
remaining := redisqueue.PopOldest(10)
if len(remaining) != 1 || string(remaining[0]) != `{"id":1}` {
t.Fatalf("remaining queue = %q, want original item", remaining)
}
})
}
func withManagementUsageQueue(t *testing.T, fn func()) {
t.Helper()
prevQueueEnabled := redisqueue.Enabled()
redisqueue.SetEnabled(false)
redisqueue.SetEnabled(true)
defer func() {
redisqueue.SetEnabled(false)
redisqueue.SetEnabled(prevQueueEnabled)
}()
fn()
}
func requireRecordID(t *testing.T, raw json.RawMessage, want int) {
t.Helper()
var payload struct {
ID int `json:"id"`
}
if errUnmarshal := json.Unmarshal(raw, &payload); errUnmarshal != nil {
t.Fatalf("unmarshal record: %v", errUnmarshal)
}
if payload.ID != want {
t.Fatalf("record id = %d, want %d", payload.ID, want)
}
}
@@ -123,6 +123,52 @@ func (rw *ResponseRewriter) Flush() {
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
// ampCanonicalToolNames maps tool names to the exact casing expected by the
// Amp mode tool whitelist (case-sensitive match).
var ampCanonicalToolNames = map[string]string{
"bash": "Bash",
"read": "Read",
"grep": "Grep",
"glob": "glob",
"task": "Task",
"check": "Check",
}
// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing.
// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash")
// which causes Amp's case-sensitive mode whitelist to reject them.
func normalizeAmpToolNames(data []byte) []byte {
// Non-streaming: content[].name in tool_use blocks
for index, block := range gjson.GetBytes(data, "content").Array() {
if block.Get("type").String() != "tool_use" {
continue
}
name := block.Get("name").String()
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
path := fmt.Sprintf("content.%d.name", index)
var err error
data, err = sjson.SetBytes(data, path, canonical)
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err)
}
}
}
// Streaming: content_block.name in content_block_start events
if gjson.GetBytes(data, "content_block.type").String() == "tool_use" {
name := gjson.GetBytes(data, "content_block.name").String()
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
var err error
data, err = sjson.SetBytes(data, "content_block.name", canonical)
if err != nil {
log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err)
}
}
}
return data
}
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
// in API responses so that the Amp TUI does not crash on P.signature.length.
func ensureAmpSignature(data []byte) []byte {
@@ -179,6 +225,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
data = ensureAmpSignature(data)
data = normalizeAmpToolNames(data)
data = rw.suppressAmpThinking(data)
if len(data) == 0 {
return data
@@ -278,6 +325,9 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
// Inject empty signature where needed
data = ensureAmpSignature(data)
// Normalize tool names to canonical casing
data = normalizeAmpToolNames(data)
// Rewrite model name
if rw.originalModel != "" {
for _, path := range modelFieldPaths {
@@ -175,6 +175,57 @@ func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testi
}
}
func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`)
result := normalizeAmpToolNames(input)
if !contains(result, []byte(`"name":"Bash"`)) {
t.Errorf("expected bash->Bash, got %s", string(result))
}
if !contains(result, []byte(`"name":"Read"`)) {
t.Errorf("expected read->Read, got %s", string(result))
}
if contains(result, []byte(`"name":"bash"`)) {
t.Errorf("expected lowercase bash to be replaced, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_Streaming(t *testing.T) {
input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`)
result := normalizeAmpToolNames(input)
if !contains(result, []byte(`"name":"Grep"`)) {
t.Errorf("expected grep->Grep in streaming, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
result := normalizeAmpToolNames(input)
if string(result) != string(input) {
t.Errorf("expected no modification for correctly-cased tool, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
result := normalizeAmpToolNames(input)
if string(result) != string(input) {
t.Errorf("expected glob to remain lowercase, got %s", string(result))
}
}
func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`)
result := normalizeAmpToolNames(input)
if string(result) != string(input) {
t.Errorf("expected no modification for unknown tool, got %s", string(result))
}
}
func contains(data, substr []byte) bool {
for i := 0; i <= len(data)-len(substr); i++ {
if string(data[i:i+len(substr)]) == string(substr) {
+7 -5
View File
@@ -31,7 +31,6 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
@@ -507,9 +506,6 @@ func (s *Server) registerManagementRoutes() {
mgmt := s.engine.Group("/v0/management")
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
{
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
@@ -554,6 +550,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
mgmt.DELETE("/api-keys", s.mgmt.DeleteAPIKeys)
mgmt.GET("/api-key-usage", s.mgmt.GetAPIKeyUsage)
mgmt.GET("/usage-queue", s.mgmt.GetUsageQueue)
mgmt.GET("/gemini-api-key", s.mgmt.GetGeminiKeys)
mgmt.PUT("/gemini-api-key", s.mgmt.PutGeminiKeys)
@@ -1000,7 +998,11 @@ func (s *Server) UpdateClients(cfg *config.Config) {
}
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
redisqueue.SetUsageStatisticsEnabled(cfg.UsageStatisticsEnabled)
}
if oldCfg == nil || oldCfg.RedisUsageQueueRetentionSeconds != cfg.RedisUsageQueueRetentionSeconds {
redisqueue.SetRetentionSeconds(cfg.RedisUsageQueueRetentionSeconds)
}
if s.requestLogger != nil && (oldCfg == nil || oldCfg.ErrorLogsMaxFiles != cfg.ErrorLogsMaxFiles) {
+63
View File
@@ -13,6 +13,7 @@ import (
gin "github.com/gin-gonic/gin"
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
@@ -84,6 +85,68 @@ func TestHealthz(t *testing.T) {
})
}
func TestManagementUsageRequiresManagementAuthAndPopsArray(t *testing.T) {
t.Setenv("MANAGEMENT_PASSWORD", "test-management-key")
prevQueueEnabled := redisqueue.Enabled()
redisqueue.SetEnabled(false)
t.Cleanup(func() {
redisqueue.SetEnabled(false)
redisqueue.SetEnabled(prevQueueEnabled)
})
server := newTestServer(t)
redisqueue.Enqueue([]byte(`{"id":1}`))
redisqueue.Enqueue([]byte(`{"id":2}`))
missingKeyReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
missingKeyRR := httptest.NewRecorder()
server.engine.ServeHTTP(missingKeyRR, missingKeyReq)
if missingKeyRR.Code != http.StatusUnauthorized {
t.Fatalf("missing key status = %d, want %d body=%s", missingKeyRR.Code, http.StatusUnauthorized, missingKeyRR.Body.String())
}
legacyReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage?count=2", nil)
legacyReq.Header.Set("Authorization", "Bearer test-management-key")
legacyRR := httptest.NewRecorder()
server.engine.ServeHTTP(legacyRR, legacyReq)
if legacyRR.Code != http.StatusNotFound {
t.Fatalf("legacy usage status = %d, want %d body=%s", legacyRR.Code, http.StatusNotFound, legacyRR.Body.String())
}
authReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
authReq.Header.Set("Authorization", "Bearer test-management-key")
authRR := httptest.NewRecorder()
server.engine.ServeHTTP(authRR, authReq)
if authRR.Code != http.StatusOK {
t.Fatalf("authenticated status = %d, want %d body=%s", authRR.Code, http.StatusOK, authRR.Body.String())
}
var payload []json.RawMessage
if errUnmarshal := json.Unmarshal(authRR.Body.Bytes(), &payload); errUnmarshal != nil {
t.Fatalf("unmarshal response: %v body=%s", errUnmarshal, authRR.Body.String())
}
if len(payload) != 2 {
t.Fatalf("response records = %d, want 2", len(payload))
}
for i, raw := range payload {
var record struct {
ID int `json:"id"`
}
if errUnmarshal := json.Unmarshal(raw, &record); errUnmarshal != nil {
t.Fatalf("unmarshal record %d: %v", i, errUnmarshal)
}
if record.ID != i+1 {
t.Fatalf("record %d id = %d, want %d", i, record.ID, i+1)
}
}
if remaining := redisqueue.PopOldest(1); len(remaining) != 0 {
t.Fatalf("remaining queue = %q, want empty", remaining)
}
}
func TestAmpProviderModelRoutes(t *testing.T) {
testCases := []struct {
name string
+134 -1
View File
@@ -6,15 +6,18 @@ package claude
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
)
// OAuth configuration constants for Claude/Anthropic
@@ -23,8 +26,94 @@ const (
TokenURL = "https://api.anthropic.com/v1/oauth/token"
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
RedirectURI = "http://localhost:54545/callback"
claudeRefreshMinBackoff = 5 * time.Second
claudeRefreshMaxBackoff = 5 * time.Minute
)
var (
claudeRefreshGroup singleflight.Group
claudeRefreshMu sync.Mutex
claudeRefreshBlock = make(map[string]time.Time)
)
type refreshHTTPError struct {
status int
message string
retryable bool
}
func (e *refreshHTTPError) Error() string {
return fmt.Sprintf("token refresh failed with status %d: %s", e.status, e.message)
}
func (e *refreshHTTPError) Retryable() bool {
return e != nil && e.retryable
}
func resetClaudeRefreshState() {
claudeRefreshMu.Lock()
defer claudeRefreshMu.Unlock()
claudeRefreshBlock = make(map[string]time.Time)
claudeRefreshGroup = singleflight.Group{}
}
func claudeRefreshBlockedUntil(refreshToken string) time.Time {
claudeRefreshMu.Lock()
defer claudeRefreshMu.Unlock()
return claudeRefreshBlock[refreshToken]
}
func setClaudeRefreshBlockedUntil(refreshToken string, until time.Time) {
claudeRefreshMu.Lock()
defer claudeRefreshMu.Unlock()
claudeRefreshBlock[refreshToken] = until
}
func clearClaudeRefreshBlockedUntil(refreshToken string) {
claudeRefreshMu.Lock()
defer claudeRefreshMu.Unlock()
delete(claudeRefreshBlock, refreshToken)
}
func clampClaudeRefreshBackoff(d time.Duration) time.Duration {
if d < claudeRefreshMinBackoff {
return claudeRefreshMinBackoff
}
if d > claudeRefreshMaxBackoff {
return claudeRefreshMaxBackoff
}
return d
}
func parseClaudeRetryAfter(resp *http.Response) time.Duration {
if resp == nil {
return claudeRefreshMinBackoff
}
if raw := strings.TrimSpace(resp.Header.Get("Retry-After")); raw != "" {
if seconds, err := time.ParseDuration(raw + "s"); err == nil {
return clampClaudeRefreshBackoff(seconds)
}
if when, err := http.ParseTime(raw); err == nil {
return clampClaudeRefreshBackoff(time.Until(when))
}
}
if raw := strings.TrimSpace(resp.Header.Get("Retry-After-Ms")); raw != "" {
if ms, err := time.ParseDuration(raw + "ms"); err == nil {
return clampClaudeRefreshBackoff(ms)
}
}
return claudeRefreshMinBackoff
}
func isClaudeRefreshRetryable(err error) bool {
var httpErr *refreshHTTPError
if errors.As(err, &httpErr) {
return httpErr.Retryable()
}
return true
}
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
// It contains access token, refresh token, and associated user/organization information.
type tokenResponse struct {
@@ -242,6 +331,35 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
if refreshToken == "" {
return nil, fmt.Errorf("refresh token is required")
}
if blockedUntil := claudeRefreshBlockedUntil(refreshToken); blockedUntil.After(time.Now()) {
return nil, &refreshHTTPError{
status: http.StatusTooManyRequests,
message: fmt.Sprintf("refresh temporarily blocked until %s", blockedUntil.Format(time.RFC3339)),
retryable: false,
}
}
result, err, _ := claudeRefreshGroup.Do(refreshToken, func() (interface{}, error) {
return o.refreshTokensSingleFlight(context.WithoutCancel(ctx), refreshToken)
})
if err != nil {
return nil, err
}
tokenData, ok := result.(*ClaudeTokenData)
if !ok || tokenData == nil {
return nil, fmt.Errorf("token refresh failed: invalid single-flight result")
}
return tokenData, nil
}
func (o *ClaudeAuth) refreshTokensSingleFlight(ctx context.Context, refreshToken string) (*ClaudeTokenData, error) {
if blockedUntil := claudeRefreshBlockedUntil(refreshToken); blockedUntil.After(time.Now()) {
return nil, &refreshHTTPError{
status: http.StatusTooManyRequests,
message: fmt.Sprintf("refresh temporarily blocked until %s", blockedUntil.Format(time.RFC3339)),
retryable: false,
}
}
reqBody := map[string]interface{}{
"client_id": ClientID,
@@ -276,7 +394,17 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
message := string(body)
if resp.StatusCode == http.StatusTooManyRequests {
retryAfter := parseClaudeRetryAfter(resp)
setClaudeRefreshBlockedUntil(refreshToken, time.Now().Add(retryAfter))
return nil, &refreshHTTPError{status: resp.StatusCode, message: message, retryable: false}
}
return nil, &refreshHTTPError{
status: resp.StatusCode,
message: message,
retryable: resp.StatusCode >= http.StatusInternalServerError,
}
}
// log.Debugf("Token response: %s", string(body))
@@ -287,6 +415,8 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
}
// Create token data
clearClaudeRefreshBlockedUntil(refreshToken)
return &ClaudeTokenData{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
@@ -348,6 +478,9 @@ func (o *ClaudeAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken st
lastErr = err
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
if !isClaudeRefreshRetryable(err) {
break
}
}
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
+123
View File
@@ -0,0 +1,123 @@
package claude
import (
"context"
"io"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
func TestRefreshTokensWithRetry_429BlocksImmediateReplay(t *testing.T) {
resetClaudeRefreshState()
defer resetClaudeRefreshState()
var calls int32
auth := &ClaudeAuth{
httpClient: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
return &http.Response{
StatusCode: http.StatusTooManyRequests,
Body: io.NopCloser(strings.NewReader(`{"error":"rate_limited"}`)),
Header: http.Header{"Retry-After": []string{"60"}},
Request: req,
}, nil
}),
},
}
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected 429 refresh error")
}
if !strings.Contains(err.Error(), "status 429") {
t.Fatalf("expected status 429 in error, got %v", err)
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected 1 refresh attempt after 429, got %d", got)
}
_, err = auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected immediate blocked refresh error")
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected blocked retry to avoid a second refresh call, got %d attempts", got)
}
if blockedUntil := claudeRefreshBlockedUntil("dummy_refresh_token"); !blockedUntil.After(time.Now()) {
t.Fatalf("expected blocked-until timestamp to be set, got %v", blockedUntil)
}
}
func TestRefreshTokens_DeduplicatesConcurrentRefresh(t *testing.T) {
resetClaudeRefreshState()
defer resetClaudeRefreshState()
var calls int32
started := make(chan struct{})
release := make(chan struct{})
var once sync.Once
auth := &ClaudeAuth{
httpClient: &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
once.Do(func() { close(started) })
<-release
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(`{
"access_token":"new-access",
"refresh_token":"new-refresh",
"token_type":"Bearer",
"expires_in":3600,
"account":{"email_address":"shared@example.com"}
}`)),
Header: make(http.Header),
Request: req,
}, nil
}),
},
}
results := make(chan *ClaudeTokenData, 2)
errs := make(chan error, 2)
runRefresh := func() {
td, err := auth.RefreshTokens(context.Background(), "shared-refresh-token")
results <- td
errs <- err
}
go runRefresh()
go runRefresh()
<-started
time.Sleep(20 * time.Millisecond)
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected concurrent refresh to share a single upstream call, got %d", got)
}
close(release)
for i := 0; i < 2; i++ {
if err := <-errs; err != nil {
t.Fatalf("expected refresh to succeed, got %v", err)
}
td := <-results
if td == nil || td.AccessToken != "new-access" {
t.Fatalf("expected refreshed access token, got %#v", td)
}
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected exactly 1 upstream refresh call, got %d", got)
}
}
+3 -35
View File
@@ -333,42 +333,10 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
finalProjectID := projectID
if responseProjectID != "" {
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
strings.EqualFold(tierID, "FREE") ||
strings.EqualFold(tierID, "LEGACY")
if isFreeUser {
// Interactive prompt for free users
fmt.Printf("\nGoogle returned a different project ID:\n")
fmt.Printf(" Requested (frontend): %s\n", projectID)
fmt.Printf(" Returned (backend): %s\n\n", responseProjectID)
fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n")
fmt.Printf(" This is normal for free tier users.\n\n")
fmt.Printf("Which project ID would you like to use?\n")
fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID)
fmt.Printf(" [2] Frontend: %s\n\n", projectID)
fmt.Printf("Enter choice [1]: ")
reader := bufio.NewReader(os.Stdin)
choice, _ := reader.ReadString('\n')
choice = strings.TrimSpace(choice)
if choice == "2" {
log.Infof("Using frontend project ID: %s", projectID)
fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
finalProjectID = projectID
} else {
log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
finalProjectID = responseProjectID
}
} else {
// Pro users: keep requested project ID (original behavior)
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
}
} else {
finalProjectID = responseProjectID
log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID)
log.Infof("Using backend project ID: %s", responseProjectID)
}
finalProjectID = responseProjectID
}
storage.ProjectID = strings.TrimSpace(finalProjectID)
+13
View File
@@ -65,6 +65,11 @@ type Config struct {
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
// RedisUsageQueueRetentionSeconds controls how long (in seconds) usage queue items
// are retained in memory for the Redis RESP interface (LPOP/RPOP).
// Default: 60. Max: 3600.
RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"`
// DisableCooling disables quota cooldown scheduling when true.
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`
@@ -609,6 +614,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.LogsMaxTotalSizeMB = 0
cfg.ErrorLogsMaxFiles = 10
cfg.UsageStatisticsEnabled = false
cfg.RedisUsageQueueRetentionSeconds = 60
cfg.DisableCooling = false
cfg.DisableImageGeneration = DisableImageGenerationOff
cfg.Pprof.Enable = false
@@ -671,6 +677,13 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.ErrorLogsMaxFiles = 10
}
if cfg.RedisUsageQueueRetentionSeconds <= 0 {
cfg.RedisUsageQueueRetentionSeconds = 60
} else if cfg.RedisUsageQueueRetentionSeconds > 3600 {
log.WithField("value", cfg.RedisUsageQueueRetentionSeconds).Warn("redis-usage-queue-retention-seconds too large; clamping to 3600")
cfg.RedisUsageQueueRetentionSeconds = 3600
}
if cfg.MaxRetryCredentials < 0 {
cfg.MaxRetryCredentials = 0
}
+62
View File
@@ -0,0 +1,62 @@
package logging
import (
"context"
"sync/atomic"
)
type endpointKey struct{}
type responseStatusKey struct{}
type responseStatusHolder struct {
status atomic.Int32
}
func WithEndpoint(ctx context.Context, endpoint string) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, endpointKey{}, endpoint)
}
func GetEndpoint(ctx context.Context) string {
if ctx == nil {
return ""
}
if endpoint, ok := ctx.Value(endpointKey{}).(string); ok {
return endpoint
}
return ""
}
func WithResponseStatusHolder(ctx context.Context) context.Context {
if ctx == nil {
ctx = context.Background()
}
if holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder); ok && holder != nil {
return ctx
}
return context.WithValue(ctx, responseStatusKey{}, &responseStatusHolder{})
}
func SetResponseStatus(ctx context.Context, status int) {
if ctx == nil || status <= 0 {
return
}
holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder)
if !ok || holder == nil {
return
}
holder.status.Store(int32(status))
}
func GetResponseStatus(ctx context.Context) int {
if ctx == nil {
return 0
}
holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder)
if !ok || holder == nil {
return 0
}
return int(holder.status.Load())
}
+2 -2
View File
@@ -12,7 +12,7 @@ import (
const (
// GeminiCLIVersion is the version string reported in the User-Agent for upstream requests.
GeminiCLIVersion = "0.31.0"
GeminiCLIVersion = "0.34.0"
// GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream.
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
@@ -46,7 +46,7 @@ func GeminiCLIUserAgent(model string) string {
if model == "" {
model = "unknown"
}
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s; terminal)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
}
// ScrubProxyAndFingerprintHeaders removes all headers that could reveal
+33 -43
View File
@@ -3,13 +3,10 @@ package redisqueue
import (
"context"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
@@ -23,7 +20,7 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
if p == nil {
return
}
if !Enabled() || !internalusage.StatisticsEnabled() {
if !Enabled() || !UsageStatisticsEnabled() {
return
}
@@ -36,6 +33,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
if modelName == "" {
modelName = "unknown"
}
aliasName := strings.TrimSpace(record.Alias)
if aliasName == "" {
aliasName = modelName
}
provider := strings.TrimSpace(record.Provider)
if provider == "" {
provider = "unknown"
@@ -46,13 +47,8 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
}
apiKey := strings.TrimSpace(record.APIKey)
requestID := strings.TrimSpace(internallogging.GetRequestID(ctx))
if requestID == "" {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
requestID = strings.TrimSpace(internallogging.GetGinRequestID(ginCtx))
}
}
tokens := internalusage.TokenStats{
tokens := tokenStats{
InputTokens: record.Detail.InputTokens,
OutputTokens: record.Detail.OutputTokens,
ReasoningTokens: record.Detail.ReasoningTokens,
@@ -71,7 +67,7 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
failed = !resolveSuccess(ctx)
}
detail := internalusage.RequestDetail{
detail := requestDetail{
Timestamp: timestamp,
LatencyMs: record.Latency.Milliseconds(),
Source: record.Source,
@@ -81,9 +77,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
}
payload, err := json.Marshal(queuedUsageDetail{
RequestDetail: detail,
requestDetail: detail,
Provider: provider,
Model: modelName,
Alias: aliasName,
Endpoint: resolveEndpoint(ctx),
AuthType: authType,
APIKey: apiKey,
@@ -96,50 +93,43 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
}
type queuedUsageDetail struct {
internalusage.RequestDetail
requestDetail
Provider string `json:"provider"`
Model string `json:"model"`
Alias string `json:"alias"`
Endpoint string `json:"endpoint"`
AuthType string `json:"auth_type"`
APIKey string `json:"api_key"`
RequestID string `json:"request_id"`
}
type requestDetail struct {
Timestamp time.Time `json:"timestamp"`
LatencyMs int64 `json:"latency_ms"`
Source string `json:"source"`
AuthIndex string `json:"auth_index"`
Tokens tokenStats `json:"tokens"`
Failed bool `json:"failed"`
}
type tokenStats struct {
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
ReasoningTokens int64 `json:"reasoning_tokens"`
CachedTokens int64 `json:"cached_tokens"`
TotalTokens int64 `json:"total_tokens"`
}
func resolveSuccess(ctx context.Context) bool {
if ctx == nil {
return true
}
ginCtx, ok := ctx.Value("gin").(*gin.Context)
if !ok || ginCtx == nil {
return true
}
status := ginCtx.Writer.Status()
status := internallogging.GetResponseStatus(ctx)
if status == 0 {
return true
}
return status < http.StatusBadRequest
return status < httpStatusBadRequest
}
func resolveEndpoint(ctx context.Context) string {
if ctx == nil {
return ""
}
ginCtx, ok := ctx.Value("gin").(*gin.Context)
if !ok || ginCtx == nil || ginCtx.Request == nil {
return ""
}
path := strings.TrimSpace(ginCtx.FullPath())
if path == "" && ginCtx.Request.URL != nil {
path = strings.TrimSpace(ginCtx.Request.URL.Path)
}
if path == "" {
return ""
}
method := strings.TrimSpace(ginCtx.Request.Method)
if method == "" {
return path
}
return method + " " + path
return strings.TrimSpace(internallogging.GetEndpoint(ctx))
}
const httpStatusBadRequest = 400
+87 -10
View File
@@ -10,20 +10,21 @@ import (
"github.com/gin-gonic/gin"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
withEnabledQueue(t, func() {
ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK)
internallogging.SetGinRequestID(ginCtx, "gin-request-id-ignored")
ctx := context.WithValue(internallogging.WithRequestID(context.Background(), "ctx-request-id"), "gin", ginCtx)
ctx := internallogging.WithRequestID(context.Background(), "ctx-request-id")
ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
ctx = internallogging.WithResponseStatusHolder(ctx)
internallogging.SetResponseStatus(ctx, http.StatusOK)
plugin := &usageQueuePlugin{}
plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4",
Alias: "client-gpt",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
@@ -40,6 +41,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4")
requireStringField(t, payload, "alias", "client-gpt")
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
requireStringField(t, payload, "auth_type", "apikey")
requireStringField(t, payload, "request_id", "ctx-request-id")
@@ -49,14 +51,16 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t *testing.T) {
withEnabledQueue(t, func() {
ginCtx := newTestGinContext(t, http.MethodGet, "/v1/responses", http.StatusInternalServerError)
internallogging.SetGinRequestID(ginCtx, "gin-request-id")
ctx := context.WithValue(context.Background(), "gin", ginCtx)
ctx := internallogging.WithRequestID(context.Background(), "gin-request-id")
ctx = internallogging.WithEndpoint(ctx, "GET /v1/responses")
ctx = internallogging.WithResponseStatusHolder(ctx)
internallogging.SetResponseStatus(ctx, http.StatusInternalServerError)
plugin := &usageQueuePlugin{}
plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4-mini",
Alias: "client-mini",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
@@ -73,6 +77,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t
payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4-mini")
requireStringField(t, payload, "alias", "client-mini")
requireStringField(t, payload, "endpoint", "GET /v1/responses")
requireStringField(t, payload, "auth_type", "apikey")
requireStringField(t, payload, "request_id", "gin-request-id")
@@ -80,20 +85,63 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t
})
}
func TestUsageQueuePluginAsyncIgnoresRecycledGinContext(t *testing.T) {
withEnabledQueue(t, func() {
ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK)
ctx := context.WithValue(context.Background(), "gin", ginCtx)
ctx = internallogging.WithRequestID(ctx, "ctx-request-id")
ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
ctx = internallogging.WithResponseStatusHolder(ctx)
internallogging.SetResponseStatus(ctx, http.StatusInternalServerError)
mgr := coreusage.NewManager(16)
defer mgr.Stop()
mgr.Register(pluginFunc(func(_ context.Context, _ coreusage.Record) {
ginCtx.Request = httptest.NewRequest(http.MethodGet, "http://example.com/v1/responses", nil)
ginCtx.Status(http.StatusOK)
}))
mgr.Register(&usageQueuePlugin{})
mgr.Publish(ctx, coreusage.Record{
Provider: "openai",
Model: "gpt-5.4",
Alias: "client-gpt",
APIKey: "test-key",
AuthIndex: "0",
AuthType: "apikey",
Source: "user@example.com",
RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC),
Latency: 1500 * time.Millisecond,
Detail: coreusage.Detail{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
})
payload := waitForSinglePayload(t, 2*time.Second)
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
requireStringField(t, payload, "alias", "client-gpt")
requireStringField(t, payload, "request_id", "ctx-request-id")
requireBoolField(t, payload, "failed", true)
})
}
func withEnabledQueue(t *testing.T, fn func()) {
t.Helper()
prevQueueEnabled := Enabled()
prevStatsEnabled := internalusage.StatisticsEnabled()
prevUsageEnabled := UsageStatisticsEnabled()
SetEnabled(false)
SetEnabled(true)
internalusage.SetStatisticsEnabled(true)
SetUsageStatisticsEnabled(true)
defer func() {
SetEnabled(false)
SetEnabled(prevQueueEnabled)
internalusage.SetStatisticsEnabled(prevStatsEnabled)
SetUsageStatisticsEnabled(prevUsageEnabled)
}()
fn()
@@ -127,6 +175,29 @@ func popSinglePayload(t *testing.T) map[string]json.RawMessage {
return payload
}
func waitForSinglePayload(t *testing.T, timeout time.Duration) map[string]json.RawMessage {
t.Helper()
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
items := PopOldest(10)
if len(items) == 0 {
time.Sleep(10 * time.Millisecond)
continue
}
if len(items) != 1 {
t.Fatalf("PopOldest() items = %d, want 1", len(items))
}
var payload map[string]json.RawMessage
if err := json.Unmarshal(items[0], &payload); err != nil {
t.Fatalf("unmarshal payload: %v", err)
}
return payload
}
t.Fatalf("timeout waiting for queued payload")
return nil
}
func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, want string) {
t.Helper()
@@ -143,6 +214,12 @@ func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, w
}
}
type pluginFunc func(context.Context, coreusage.Record)
func (fn pluginFunc) HandleUsage(ctx context.Context, record coreusage.Record) {
fn(ctx, record)
}
func requireBoolField(t *testing.T, payload map[string]json.RawMessage, key string, want bool) {
t.Helper()
+26 -4
View File
@@ -6,7 +6,10 @@ import (
"time"
)
const retentionWindow = time.Minute
const (
defaultRetentionSeconds int64 = 60
maxRetentionSeconds int64 = 3600
)
type queueItem struct {
enqueuedAt time.Time
@@ -20,10 +23,15 @@ type queue struct {
}
var (
enabled atomic.Bool
global queue
enabled atomic.Bool
retentionSeconds atomic.Int64
global queue
)
func init() {
retentionSeconds.Store(defaultRetentionSeconds)
}
func SetEnabled(value bool) {
enabled.Store(value)
if !value {
@@ -35,6 +43,16 @@ func Enabled() bool {
return enabled.Load()
}
func SetRetentionSeconds(value int) {
normalized := int64(value)
if normalized <= 0 {
normalized = defaultRetentionSeconds
} else if normalized > maxRetentionSeconds {
normalized = maxRetentionSeconds
}
retentionSeconds.Store(normalized)
}
func Enqueue(payload []byte) {
if !Enabled() {
return
@@ -110,7 +128,11 @@ func (q *queue) pruneLocked(now time.Time) {
return
}
cutoff := now.Add(-retentionWindow)
windowSeconds := retentionSeconds.Load()
if windowSeconds <= 0 {
windowSeconds = defaultRetentionSeconds
}
cutoff := now.Add(-time.Duration(windowSeconds) * time.Second)
for q.head < len(q.items) && q.items[q.head].enqueuedAt.Before(cutoff) {
q.head++
}
+16
View File
@@ -0,0 +1,16 @@
package redisqueue
import "sync/atomic"
var usageStatisticsEnabled atomic.Bool
func init() {
usageStatisticsEnabled.Store(true)
}
// SetUsageStatisticsEnabled toggles whether usage records are enqueued into the redisqueue payload buffer.
// This is controlled by the config field `usage-statistics-enabled` and the corresponding management API.
func SetUsageStatisticsEnabled(enabled bool) { usageStatisticsEnabled.Store(enabled) }
// UsageStatisticsEnabled reports whether the usage queue plugin should publish records.
func UsageStatisticsEnabled() bool { return usageStatisticsEnabled.Load() }
+18 -4
View File
@@ -285,7 +285,10 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
if event.Err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
select {
case out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}:
case <-ctx.Done():
}
return false
}
switch event.Type {
@@ -303,7 +306,11 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}:
case <-ctx.Done():
return false
}
}
break
}
@@ -319,14 +326,21 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}:
case <-ctx.Done():
return false
}
}
reporter.Publish(ctx, helps.ParseGeminiUsage(event.Payload))
return false
case wsrelay.MessageTypeError:
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
select {
case out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}:
case <-ctx.Done():
}
return false
}
return true
@@ -1357,17 +1357,28 @@ attemptLoop:
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), &param)
for i := range tail {
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
} else {
reporter.EnsurePublished(ctx)
}
+168 -71
View File
@@ -65,14 +65,13 @@ var oauthToolRenameMap = map[string]string{
"notebookedit": "NotebookEdit",
}
// oauthToolRenameReverseMap is the inverse of oauthToolRenameMap for response decoding.
var oauthToolRenameReverseMap = func() map[string]string {
m := make(map[string]string, len(oauthToolRenameMap))
for k, v := range oauthToolRenameMap {
m[v] = k
}
return m
}()
// The reverse map is now computed per-request in remapOAuthToolNames so that
// only names the client actually caused us to rewrite are restored on the
// response. A global reverse map — as used previously — corrupted responses
// for clients that sent mixed casing (e.g. Amp CLI sends `Bash` TitleCase
// alongside `glob` lowercase; the request flagged renames via `glob→Glob`,
// then the global reverse map incorrectly rewrote every `Bash` in the
// response to `bash`, causing Amp to reject the tool_use as unknown).
// oauthToolsToRemove lists tool names that must be stripped from OAuth requests
// even after remapping. Currently empty — all tools are mapped instead of removed.
@@ -192,15 +191,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
bodyForTranslation := body
bodyForUpstream := body
oauthToken := isClaudeOAuthToken(apiKey)
oauthToolNamesRemapped := false
if oauthToken && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
// Remap third-party tool names to Claude Code equivalents and remove
// tools without official counterparts. This prevents Anthropic from
// fingerprinting the request as third-party via tool naming patterns.
var oauthToolNamesReverseMap map[string]string
if oauthToken {
bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled())
}
// Enable cch signing by default for OAuth tokens (not just experimental flag).
// Claude Code always computes cch; missing or invalid cch is a detectable fingerprint.
@@ -285,6 +278,10 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
if stream {
if errValidate := validateClaudeStreamingResponse(data); errValidate != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errValidate)
return resp, errValidate
}
lines := bytes.Split(data, []byte("\n"))
for _, line := range lines {
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
@@ -294,13 +291,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
} else {
reporter.Publish(ctx, helps.ParseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
// Reverse the OAuth tool name remap so the downstream client sees original names.
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
data = reverseRemapOAuthToolNames(data)
}
data = restoreClaudeOAuthToolNamesFromResponse(data, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
var param any
out := sdktranslator.TranslateNonStream(
ctx,
@@ -375,15 +366,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
bodyForTranslation := body
bodyForUpstream := body
oauthToken := isClaudeOAuthToken(apiKey)
oauthToolNamesRemapped := false
if oauthToken && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
// Remap third-party tool names to Claude Code equivalents and remove
// tools without official counterparts. This prevents Anthropic from
// fingerprinting the request as third-party via tool naming patterns.
var oauthToolNamesReverseMap map[string]string
if oauthToken {
bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled())
}
// Enable cch signing by default for OAuth tokens (not just experimental flag).
if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) {
@@ -474,22 +459,24 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
line = reverseRemapOAuthToolNamesFromStreamLine(line)
}
line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
// Forward the line as-is to preserve SSE format
cloned := make([]byte, len(line)+1)
copy(cloned, line)
cloned[len(line)] = '\n'
out <- cliproxyexecutor.StreamChunk{Payload: cloned}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: cloned}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
return
}
@@ -504,12 +491,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
line = reverseRemapOAuthToolNamesFromStreamLine(line)
}
line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
chunks := sdktranslator.TranslateStream(
ctx,
to,
@@ -521,18 +503,83 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
&param,
)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func validateClaudeStreamingResponse(data []byte) error {
scanner := bufio.NewScanner(bytes.NewReader(data))
scanner.Buffer(nil, 52_428_800)
hasData := false
hasMessageStart := false
hasMessageDelta := false
for scanner.Scan() {
line := bytes.TrimSpace(scanner.Bytes())
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
continue
}
payload := bytes.TrimSpace(line[len("data:"):])
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
continue
}
hasData = true
if !gjson.ValidBytes(payload) {
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned malformed stream data"}
}
root := gjson.ParseBytes(payload)
switch root.Get("type").String() {
case "error":
message := strings.TrimSpace(root.Get("error.message").String())
if message == "" {
message = strings.TrimSpace(root.Get("error.type").String())
}
if message == "" {
message = "unknown upstream error"
}
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned error event: " + message}
case "message_start":
message := root.Get("message")
if strings.TrimSpace(message.Get("id").String()) == "" || strings.TrimSpace(message.Get("model").String()) == "" {
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream message_start is missing id or model"}
}
hasMessageStart = true
case "message_delta":
hasMessageDelta = true
}
}
if errScan := scanner.Err(); errScan != nil {
return errScan
}
if !hasData {
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned empty stream response"}
}
if !hasMessageStart {
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response is missing message_start"}
}
if !hasMessageDelta {
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response ended before message completion"}
}
return nil
}
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
@@ -559,12 +606,8 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
// Remap tool names for OAuth token requests to avoid third-party fingerprinting.
if isClaudeOAuthToken(apiKey) {
body, _ = remapOAuthToolNames(body)
body, _ = prepareClaudeOAuthToolNamesForUpstream(body, claudeToolPrefix, auth.ToolPrefixDisabled())
}
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
@@ -661,7 +704,7 @@ func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (
return auth, nil
}
svc := claudeauth.NewClaudeAuthWithProxyURL(e.cfg, auth.ProxyURL)
td, err := svc.RefreshTokens(ctx, refreshToken)
td, err := svc.RefreshTokensWithRetry(ctx, refreshToken, 3)
if err != nil {
return nil, err
}
@@ -1004,6 +1047,36 @@ func isClaudeOAuthToken(apiKey string) bool {
return strings.Contains(apiKey, "sk-ant-oat")
}
// prepareClaudeOAuthToolNamesForUpstream applies the Claude OAuth tool-name
// transforms in the same order across request paths. Remap runs before prefixing
// so any future non-empty prefix still composes correctly with the per-request
// reverse map.
func prepareClaudeOAuthToolNamesForUpstream(body []byte, prefix string, prefixDisabled bool) ([]byte, map[string]string) {
body, reverseMap := remapOAuthToolNames(body)
if !prefixDisabled {
body = applyClaudeToolPrefix(body, prefix)
}
return body, reverseMap
}
// restoreClaudeOAuthToolNamesFromResponse undoes the Claude OAuth tool-name
// transforms for non-stream responses in reverse order.
func restoreClaudeOAuthToolNamesFromResponse(body []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte {
if !prefixDisabled {
body = stripClaudeToolPrefixFromResponse(body, prefix)
}
return reverseRemapOAuthToolNames(body, reverseMap)
}
// restoreClaudeOAuthToolNamesFromStreamLine undoes the Claude OAuth tool-name
// transforms for SSE lines in reverse order.
func restoreClaudeOAuthToolNamesFromStreamLine(line []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte {
if !prefixDisabled {
line = stripClaudeToolPrefixFromStreamLine(line, prefix)
}
return reverseRemapOAuthToolNamesFromStreamLine(line, reverseMap)
}
// remapOAuthToolNames renames third-party tool names to Claude Code equivalents
// and removes tools without an official counterpart. This prevents Anthropic from
// fingerprinting the request as a third-party client via tool naming patterns.
@@ -1011,8 +1084,25 @@ func isClaudeOAuthToken(apiKey string) bool {
// It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference
// references in messages. Removed tools' corresponding tool_result blocks are preserved
// (they just become orphaned, which is safe for Claude).
func remapOAuthToolNames(body []byte) ([]byte, bool) {
renamed := false
//
// The returned map is keyed on the upstream (TitleCase) name and maps to the
// client-supplied original name. Callers MUST pass this map to the reverse
// functions so only names the client actually caused us to rewrite are restored
// on the response. A global reverse map (the previous implementation) incorrectly
// rewrote names the client originally sent in TitleCase (e.g. Amp CLI's `Bash`)
// when any OTHER tool in the same request triggered a forward rename (e.g.
// Amp's `glob`→`Glob`), because the global reverse map contained `Bash`→`bash`
// regardless of what the client originally sent.
func remapOAuthToolNames(body []byte) ([]byte, map[string]string) {
reverseMap := make(map[string]string, len(oauthToolRenameMap))
recordRename := func(original, renamed string) {
// Preserve the first-seen original name if the same upstream name is
// produced from multiple call sites; they all map back identically.
if _, exists := reverseMap[renamed]; !exists {
reverseMap[renamed] = original
}
}
// 1. Rewrite tools array in a single pass (if present).
// IMPORTANT: do not mutate names first and then rebuild from an older gjson
// snapshot. gjson results are snapshots of the original bytes; rebuilding from a
@@ -1045,7 +1135,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
updatedTool, err := sjson.Set(toolJSON, "name", newName)
if err == nil {
toolJSON = updatedTool
renamed = true
recordRename(name, newName)
}
}
@@ -1070,7 +1160,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
body, _ = sjson.DeleteBytes(body, "tool_choice")
} else if newName, ok := oauthToolRenameMap[tcName]; ok && newName != tcName {
body, _ = sjson.SetBytes(body, "tool_choice.name", newName)
renamed = true
recordRename(tcName, newName)
}
}
@@ -1090,14 +1180,14 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
if newName, ok := oauthToolRenameMap[name]; ok && newName != name {
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, newName)
renamed = true
recordRename(name, newName)
}
case "tool_reference":
toolName := part.Get("tool_name").String()
if newName, ok := oauthToolRenameMap[toolName]; ok && newName != toolName {
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
body, _ = sjson.SetBytes(body, path, newName)
renamed = true
recordRename(toolName, newName)
}
case "tool_result":
// Handle nested tool_reference blocks inside tool_result.content[]
@@ -1111,7 +1201,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
if newName, ok := oauthToolRenameMap[nestedToolName]; ok && newName != nestedToolName {
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
body, _ = sjson.SetBytes(body, nestedPath, newName)
renamed = true
recordRename(nestedToolName, newName)
}
}
return true
@@ -1124,13 +1214,16 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
})
}
return body, renamed
return body, reverseMap
}
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses.
// It maps Claude Code TitleCase names back to the original lowercase names so the
// downstream client receives tool names it recognizes.
func reverseRemapOAuthToolNames(body []byte) []byte {
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses
// using the per-request map produced by remapOAuthToolNames. Names the client sent
// that were NOT forward-renamed are passed through unchanged.
func reverseRemapOAuthToolNames(body []byte, reverseMap map[string]string) []byte {
if len(reverseMap) == 0 {
return body
}
content := gjson.GetBytes(body, "content")
if !content.Exists() || !content.IsArray() {
return body
@@ -1140,13 +1233,13 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
switch partType {
case "tool_use":
name := part.Get("name").String()
if origName, ok := oauthToolRenameReverseMap[name]; ok {
if origName, ok := reverseMap[name]; ok {
path := fmt.Sprintf("content.%d.name", index.Int())
body, _ = sjson.SetBytes(body, path, origName)
}
case "tool_reference":
toolName := part.Get("tool_name").String()
if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
if origName, ok := reverseMap[toolName]; ok {
path := fmt.Sprintf("content.%d.tool_name", index.Int())
body, _ = sjson.SetBytes(body, path, origName)
}
@@ -1156,8 +1249,12 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
return body
}
// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE stream lines.
func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE
// stream lines, using the per-request reverseMap produced by remapOAuthToolNames.
func reverseRemapOAuthToolNamesFromStreamLine(line []byte, reverseMap map[string]string) []byte {
if len(reverseMap) == 0 {
return line
}
payload := helps.JSONPayload(line)
if len(payload) == 0 || !gjson.ValidBytes(payload) {
return line
@@ -1175,7 +1272,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
switch blockType {
case "tool_use":
name := contentBlock.Get("name").String()
if origName, ok := oauthToolRenameReverseMap[name]; ok {
if origName, ok := reverseMap[name]; ok {
updated, err = sjson.SetBytes(payload, "content_block.name", origName)
if err != nil {
return line
@@ -1185,7 +1282,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
}
case "tool_reference":
toolName := contentBlock.Get("tool_name").String()
if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
if origName, ok := reverseMap[toolName]; ok {
updated, err = sjson.SetBytes(payload, "content_block.tool_name", origName)
if err != nil {
return line
+248 -14
View File
@@ -936,6 +936,113 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
}
}
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsEmptyClaudeStream(t *testing.T) {
_, err := executeOpenAIChatCompletionThroughClaude(t, "")
if err == nil {
t.Fatal("Execute error = nil, want empty stream error")
}
assertStatusErr(t, err, http.StatusBadGateway)
if !strings.Contains(err.Error(), "empty stream response") {
t.Fatalf("Execute error = %q, want empty stream response", err.Error())
}
}
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsClaudeErrorEvent(t *testing.T) {
body := `data: {"type":"error","error":{"type":"overloaded_error","message":"upstream overloaded"}}` + "\n"
_, err := executeOpenAIChatCompletionThroughClaude(t, body)
if err == nil {
t.Fatal("Execute error = nil, want upstream error event")
}
assertStatusErr(t, err, http.StatusBadGateway)
if !strings.Contains(err.Error(), "upstream overloaded") {
t.Fatalf("Execute error = %q, want upstream overloaded", err.Error())
}
}
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsIncompleteClaudeStream(t *testing.T) {
body := strings.Join([]string{
`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`,
`data: {"type":"message_stop"}`,
``,
}, "\n")
_, err := executeOpenAIChatCompletionThroughClaude(t, body)
if err == nil {
t.Fatal("Execute error = nil, want incomplete stream error")
}
assertStatusErr(t, err, http.StatusBadGateway)
if !strings.Contains(err.Error(), "ended before message completion") {
t.Fatalf("Execute error = %q, want incomplete stream error", err.Error())
}
}
func TestClaudeExecutor_ExecuteOpenAINonStreamConvertsValidClaudeStream(t *testing.T) {
body := strings.Join([]string{
`event: message_start`,
`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`,
`event: content_block_delta`,
`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ok"}}`,
`event: message_delta`,
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":2,"output_tokens":1}}`,
`event: message_stop`,
`data: {"type":"message_stop"}`,
``,
}, "\n")
resp, err := executeOpenAIChatCompletionThroughClaude(t, body)
if err != nil {
t.Fatalf("Execute error: %v", err)
}
if got := gjson.GetBytes(resp.Payload, "id").String(); got != "msg_123" {
t.Fatalf("response id = %q, want msg_123; payload=%s", got, string(resp.Payload))
}
if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-3-5-sonnet-20241022" {
t.Fatalf("response model = %q, want claude-3-5-sonnet-20241022", got)
}
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "ok" {
t.Fatalf("response content = %q, want ok", got)
}
if got := gjson.GetBytes(resp.Payload, "usage.total_tokens").Int(); got != 3 {
t.Fatalf("usage.total_tokens = %d, want 3", got)
}
}
func executeOpenAIChatCompletionThroughClaude(t *testing.T, upstreamBody string) (cliproxyexecutor.Response, error) {
t.Helper()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte(upstreamBody))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"hi"}]}`)
return executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
})
}
func assertStatusErr(t *testing.T, err error, want int) {
t.Helper()
status, ok := err.(interface{ StatusCode() int })
if !ok {
t.Fatalf("error %T does not expose StatusCode", err)
}
if got := status.StatusCode(); got != want {
t.Fatalf("StatusCode() = %d, want %d", got, want)
}
}
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
@@ -1989,19 +2096,16 @@ func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOrigina
func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
body := []byte(`{"tools":[{"name":"Bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
out, renamed := remapOAuthToolNames(body)
if renamed {
t.Fatalf("renamed = true, want false")
out, reverseMap := remapOAuthToolNames(body)
if len(reverseMap) != 0 {
t.Fatalf("reverseMap = %v, want empty", reverseMap)
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
}
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := resp
if renamed {
reversed = reverseRemapOAuthToolNames(resp)
}
reversed := reverseRemapOAuthToolNames(resp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
}
@@ -2010,20 +2114,150 @@ func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
func TestRemapOAuthToolNames_Lowercase_ReverseApplied(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
out, renamed := remapOAuthToolNames(body)
if !renamed {
t.Fatalf("renamed = false, want true")
out, reverseMap := remapOAuthToolNames(body)
if reverseMap["Bash"] != "bash" {
t.Fatalf("reverseMap = %v, want entry Bash->bash", reverseMap)
}
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
}
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := resp
if renamed {
reversed = reverseRemapOAuthToolNames(resp)
}
reversed := reverseRemapOAuthToolNames(resp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "bash" {
t.Fatalf("content.0.name = %q, want %q", got, "bash")
}
}
// TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed is the regression
// test for a case where a single request contains both a TitleCase tool (which
// must pass through unchanged) and a lowercase tool that we forward-rename.
// Before the fix, triggering ANY forward rename caused the reverse pass to
// lowercase every TitleCase tool in the response using a global reverse map,
// corrupting tool names the client originally sent in TitleCase (notably Amp
// CLI's `Bash`, which its registry lookup cannot find as `bash`).
func TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed(t *testing.T) {
body := []byte(`{"tools":[` +
`{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
`{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
`]}`)
out, reverseMap := remapOAuthToolNames(body)
// Forward: TitleCase `Bash` is not a forward-map key, must pass through.
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
t.Fatalf("tools.0.name = %q, want %q (TitleCase tool must not be renamed)", got, "Bash")
}
// Forward: `glob` is a forward-map key, upstream sees `Glob`.
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "Glob" {
t.Fatalf("tools.1.name = %q, want %q", got, "Glob")
}
// Reverse map records ONLY the rename that happened.
if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
}
// Upstream responds with a `Bash` tool_use. Since we never renamed `Bash`,
// reverseRemap MUST leave it alone.
bashResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
reversed := reverseRemapOAuthToolNames(bashResp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q (Bash must be preserved; was never forward-renamed)", got, "Bash")
}
// Upstream responds with a `Glob` tool_use. Since we renamed `glob`→`Glob`,
// reverseRemap MUST restore the original `glob`.
globResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_02","name":"Glob","input":{"filePattern":"**/*.go"}}]}`)
reversed = reverseRemapOAuthToolNames(globResp, reverseMap)
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "glob" {
t.Fatalf("content.0.name = %q, want %q (Glob must be restored to client's original `glob`)", got, "glob")
}
}
// TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap guards the
// SSE streaming code path against the same mixed-case bug.
func TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap(t *testing.T) {
reverseMap := map[string]string{"Glob": "glob"}
// Bash block was never renamed, must pass through as-is.
bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}}}`)
out := reverseRemapOAuthToolNamesFromStreamLine(bashLine, reverseMap)
if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
t.Fatalf("Bash should be preserved, got: %s", string(out))
}
if bytes.Contains(out, []byte(`"name":"bash"`)) {
t.Fatalf("Bash must not be lowercased, got: %s", string(out))
}
// Glob block IS in the reverseMap, must be restored to `glob`.
globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"Glob","input":{}}}`)
out = reverseRemapOAuthToolNamesFromStreamLine(globLine, reverseMap)
if !bytes.Contains(out, []byte(`"name":"glob"`)) {
t.Fatalf("Glob should be restored to glob, got: %s", string(out))
}
}
func TestPrepareClaudeOAuthToolNamesForUpstream_MixedCaseWithPrefix(t *testing.T) {
body := []byte(`{"tools":[` +
`{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
`{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
`],"messages":[{"role":"assistant","content":[` +
`{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}},` +
`{"type":"tool_use","id":"toolu_02","name":"glob","input":{}}` +
`]}]}`)
out, reverseMap := prepareClaudeOAuthToolNamesForUpstream(body, "proxy_", false)
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Bash" {
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Bash")
}
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Glob" {
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Glob")
}
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Bash" {
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Bash")
}
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Glob" {
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Glob")
}
if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
}
}
func TestRestoreClaudeOAuthToolNamesFromResponse_MixedCaseWithPrefix(t *testing.T) {
reverseMap := map[string]string{"Glob": "glob"}
resp := []byte(`{"content":[` +
`{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}},` +
`{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}` +
`]}`)
out := restoreClaudeOAuthToolNamesFromResponse(resp, "proxy_", false, reverseMap)
if got := gjson.GetBytes(out, "content.0.name").String(); got != "Bash" {
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
}
if got := gjson.GetBytes(out, "content.1.name").String(); got != "glob" {
t.Fatalf("content.1.name = %q, want %q", got, "glob")
}
}
func TestRestoreClaudeOAuthToolNamesFromStreamLine_MixedCaseWithPrefix(t *testing.T) {
reverseMap := map[string]string{"Glob": "glob"}
bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}}}`)
out := restoreClaudeOAuthToolNamesFromStreamLine(bashLine, "proxy_", false, reverseMap)
if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
t.Fatalf("Bash should be preserved, got: %s", string(out))
}
if bytes.Contains(out, []byte(`"name":"bash"`)) {
t.Fatalf("Bash must not be lowercased, got: %s", string(out))
}
globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}}`)
out = restoreClaudeOAuthToolNamesFromStreamLine(globLine, "proxy_", false, reverseMap)
if !bytes.Contains(out, []byte(`"name":"glob"`)) {
t.Fatalf("Glob should be restored to glob, got: %s", string(out))
}
}
+11 -4
View File
@@ -30,8 +30,8 @@ import (
)
const (
codexUserAgent = "codex-tui/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9 (codex-tui; 0.118.0)"
codexOriginator = "codex-tui"
codexUserAgent = "codex_cli_rs/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9"
codexOriginator = "codex_cli_rs"
codexDefaultImageToolModel = "gpt-image-2"
)
@@ -515,13 +515,20 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, translatedLine, &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -188,7 +188,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body = normalizeCodexInstructions(body)
@@ -776,6 +775,11 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
parsed.Scheme = "ws"
case "https":
parsed.Scheme = "wss"
default:
return "", fmt.Errorf("codex websockets executor: unsupported responses websocket URL scheme %q", parsed.Scheme)
}
if strings.TrimSpace(parsed.Host) == "" {
return "", fmt.Errorf("codex websockets executor: responses websocket URL host is empty")
}
return parsed.String(), nil
}
@@ -809,6 +813,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
if cache.ID != "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
setHeaderCasePreserved(headers, "session_id", cache.ID)
headers.Set("Conversation_id", cache.ID)
}
@@ -828,13 +833,19 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
ginHeaders = ginCtx.Request.Header.Clone()
}
_, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
isAPIKey := codexAuthUsesAPIKey(auth)
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
misc.EnsureHeader(headers, ginHeaders, "Version", "")
if isAPIKey {
ensureHeaderWithPriority(headers, ginHeaders, "User-Agent", "", "")
} else {
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
}
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
if betaHeader == "" && ginHeaders != nil {
@@ -845,16 +856,9 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
}
headers.Set("OpenAI-Beta", betaHeader)
if strings.Contains(headers.Get("User-Agent"), "Mac OS") {
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
}
headers.Del("User-Agent")
isAPIKey := false
if auth != nil && auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
isAPIKey = true
}
ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", uuid.NewString())
}
ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", "")
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
headers.Set("Originator", originator)
} else if !isAPIKey {
@@ -864,7 +868,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
if auth != nil && auth.Metadata != nil {
if accountID, ok := auth.Metadata["account_id"].(string); ok {
if trimmed := strings.TrimSpace(accountID); trimmed != "" {
headers.Set("Chatgpt-Account-Id", trimmed)
setHeaderCasePreserved(headers, "ChatGPT-Account-ID", trimmed)
}
}
}
@@ -879,6 +883,77 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
return headers
}
func codexAuthUsesAPIKey(auth *cliproxyauth.Auth) bool {
if auth == nil || auth.Attributes == nil {
return false
}
return strings.TrimSpace(auth.Attributes["api_key"]) != ""
}
func ensureHeaderCasePreserved(target http.Header, source http.Header, key, configValue, fallbackValue string) {
if target == nil {
return
}
if strings.TrimSpace(headerValueCaseInsensitive(target, key)) != "" {
return
}
if source != nil {
if val := strings.TrimSpace(headerValueCaseInsensitive(source, key)); val != "" {
setHeaderCasePreserved(target, key, val)
return
}
}
if val := strings.TrimSpace(configValue); val != "" {
setHeaderCasePreserved(target, key, val)
return
}
if val := strings.TrimSpace(fallbackValue); val != "" {
setHeaderCasePreserved(target, key, val)
}
}
func setHeaderCasePreserved(headers http.Header, key string, value string) {
if headers == nil {
return
}
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" || value == "" {
return
}
deleteHeaderCaseInsensitive(headers, key)
headers[key] = []string{value}
}
func headerValueCaseInsensitive(headers http.Header, key string) string {
key = strings.TrimSpace(key)
if headers == nil || key == "" {
return ""
}
if val := strings.TrimSpace(headers.Get(key)); val != "" {
return val
}
for existingKey, values := range headers {
if !strings.EqualFold(existingKey, key) {
continue
}
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
}
return ""
}
func deleteHeaderCaseInsensitive(headers http.Header, key string) {
for existingKey := range headers {
if strings.EqualFold(existingKey, key) {
delete(headers, existingKey)
}
}
}
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
if cfg == nil || auth == nil {
return "", ""
@@ -962,25 +1037,55 @@ func parseCodexWebsocketError(payload []byte) (error, bool) {
return nil, false
}
out := []byte(`{}`)
if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
raw := errNode.Raw
if errNode.Type == gjson.String {
raw = errNode.Raw
}
out, _ = sjson.SetRawBytes(out, "error", []byte(raw))
} else {
out, _ = sjson.SetBytes(out, "error.type", "server_error")
out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
}
out := buildCodexWebsocketErrorPayload(payload, status)
headers := parseCodexWebsocketErrorHeaders(payload)
statusError := statusErr{code: status, msg: string(out)}
if retryAfter := parseCodexRetryAfter(status, out, time.Now()); retryAfter != nil {
statusError.retryAfter = retryAfter
} else if isCodexWebsocketConnectionLimitError(payload) {
retryAfter := time.Duration(0)
statusError.retryAfter = &retryAfter
}
return statusErrWithHeaders{
statusErr: statusErr{code: status, msg: string(out)},
statusErr: statusError,
headers: headers,
}, true
}
func buildCodexWebsocketErrorPayload(payload []byte, status int) []byte {
out := []byte(`{}`)
out, _ = sjson.SetBytes(out, "status", status)
if bodyNode := gjson.GetBytes(payload, "body"); bodyNode.Exists() {
out, _ = sjson.SetRawBytes(out, "body", []byte(bodyNode.Raw))
if bodyErrorNode := bodyNode.Get("error"); bodyErrorNode.Exists() {
out, _ = sjson.SetRawBytes(out, "error", []byte(bodyErrorNode.Raw))
return out
}
}
if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
out, _ = sjson.SetRawBytes(out, "error", []byte(errNode.Raw))
return out
}
out, _ = sjson.SetBytes(out, "error.type", "server_error")
out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
return out
}
func isCodexWebsocketConnectionLimitError(payload []byte) bool {
if len(payload) == 0 {
return false
}
for _, path := range []string{"error.code", "error.type", "body.error.code", "body.error.type", "code", "error"} {
if strings.TrimSpace(gjson.GetBytes(payload, path).String()) == "websocket_connection_limit_reached" {
return true
}
}
return false
}
func parseCodexWebsocketErrorHeaders(payload []byte) http.Header {
headersNode := gjson.GetBytes(payload, "headers")
if !headersNode.Exists() || !headersNode.IsObject() {
@@ -1,15 +1,21 @@
package executor
import (
"bytes"
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
@@ -32,14 +38,80 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
}
}
func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T) {
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
capturedPayload := make(chan []byte, 1)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
t.Fatalf("request path = %s, want /responses", r.URL.Path)
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Fatalf("upgrade websocket: %v", err)
}
defer func() { _ = conn.Close() }()
msgType, payload, err := conn.ReadMessage()
if err != nil {
t.Fatalf("read upstream websocket message: %v", err)
}
if msgType != websocket.TextMessage {
t.Fatalf("message type = %d, want text", msgType)
}
capturedPayload <- bytes.Clone(payload)
completed := []byte(`{"type":"response.completed","response":{"id":"resp-2","output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil {
t.Fatalf("write completed websocket message: %v", errWrite)
}
}))
defer server.Close()
exec := NewCodexWebsocketsExecutor(&config.Config{SDKConfig: config.SDKConfig{DisableImageGeneration: config.DisableImageGenerationAll}})
auth := &cliproxyauth.Auth{Attributes: map[string]string{"api_key": "sk-test", "base_url": server.URL}}
req := cliproxyexecutor.Request{
Model: "gpt-5-codex",
Payload: []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`),
}
opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("codex")}
if _, err := exec.Execute(context.Background(), auth, req, opts); err != nil {
t.Fatalf("Execute() error = %v", err)
}
select {
case payload := <-capturedPayload:
if got := gjson.GetBytes(payload, "type").String(); got != "response.create" {
t.Fatalf("upstream type = %s, want response.create; payload=%s", got, payload)
}
if got := gjson.GetBytes(payload, "previous_response_id").String(); got != "resp-1" {
t.Fatalf("upstream previous_response_id = %s, want resp-1; payload=%s", got, payload)
}
case <-time.After(5 * time.Second):
t.Fatal("timed out waiting for upstream websocket payload")
}
}
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
}
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
if got := headers.Get("User-Agent"); got != codexUserAgent {
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
}
if !strings.HasPrefix(codexUserAgent, codexOriginator+"/") {
t.Fatalf("default Codex User-Agent = %s, want prefix %s/", codexUserAgent, codexOriginator)
}
if strings.HasPrefix(codexUserAgent, "codex-tui/") {
t.Fatalf("default Codex User-Agent = %s, must not use stale codex-tui prefix", codexUserAgent)
}
if strings.Contains(codexUserAgent, "(codex-tui;") {
t.Fatalf("default Codex User-Agent = %s, must not include stale codex-tui suffix", codexUserAgent)
}
if got := headers.Get("Originator"); got != codexOriginator {
t.Fatalf("Originator = %s, want %s", got, codexOriginator)
}
if got := headers.Get("Version"); got != "" {
t.Fatalf("Version = %q, want empty", got)
@@ -62,9 +134,11 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
}
ctx := contextWithGinHeaders(map[string]string{
"Originator": "Codex Desktop",
"User-Agent": "codex_cli_rs/0.1.0",
"Version": "0.115.0-alpha.27",
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
"session_id": "sess-client",
})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil)
@@ -72,6 +146,9 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
if got := headers.Get("Originator"); got != "Codex Desktop" {
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
}
if got := headers.Get("User-Agent"); got != "codex_cli_rs/0.1.0" {
t.Fatalf("User-Agent = %s, want %s", got, "codex_cli_rs/0.1.0")
}
if got := headers.Get("Version"); got != "0.115.0-alpha.27" {
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
}
@@ -81,6 +158,12 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
}
if got := headerValueCaseInsensitive(headers, "session_id"); got != "sess-client" {
t.Fatalf("session_id = %s, want sess-client", got)
}
if _, ok := headers["session_id"]; !ok {
t.Fatalf("expected lowercase session_id header key, got %#v", headers)
}
}
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
@@ -97,8 +180,8 @@ func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
}
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
@@ -129,8 +212,8 @@ func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
if gotVal := got.Get("User-Agent"); gotVal != "" {
t.Fatalf("User-Agent = %s, want empty", gotVal)
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
}
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
@@ -155,8 +238,8 @@ func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testi
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
if got := headers.Get("User-Agent"); got != "" {
t.Fatalf("User-Agent = %s, want empty", got)
if got := headers.Get("User-Agent"); got != "config-ua" {
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
}
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
@@ -183,6 +266,131 @@ func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
if got := headers.Get("x-codex-beta-features"); got != "" {
t.Fatalf("x-codex-beta-features = %q, want empty", got)
}
if got := headers.Get("Originator"); got != "" {
t.Fatalf("Originator = %s, want empty", got)
}
}
func TestApplyCodexWebsocketHeadersPreservesExplicitAPIKeyUserAgent(t *testing.T) {
auth := &cliproxyauth.Auth{Provider: "codex", Attributes: map[string]string{"api_key": "sk-test"}}
ctx := contextWithGinHeaders(map[string]string{"User-Agent": "api-key-client/1.0", "Originator": "explicit-origin"})
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "sk-test", nil)
if got := headers.Get("User-Agent"); got != "api-key-client/1.0" {
t.Fatalf("User-Agent = %s, want api-key-client/1.0", got)
}
if got := headers.Get("Originator"); got != "explicit-origin" {
t.Fatalf("Originator = %s, want explicit-origin", got)
}
}
func TestApplyCodexPromptCacheHeadersSetsLowercaseSessionAndLegacyConversation(t *testing.T) {
req := cliproxyexecutor.Request{Model: "gpt-5-codex", Payload: []byte(`{"prompt_cache_key":"cache-1"}`)}
_, headers := applyCodexPromptCacheHeaders("openai-response", req, []byte(`{"model":"gpt-5-codex"}`))
if got := headerValueCaseInsensitive(headers, "session_id"); got != "cache-1" {
t.Fatalf("session_id = %s, want cache-1", got)
}
if _, ok := headers["session_id"]; !ok {
t.Fatalf("expected lowercase session_id key, got %#v", headers)
}
if got := headers.Get("Conversation_id"); got != "cache-1" {
t.Fatalf("Conversation_id = %s, want cache-1", got)
}
}
func TestApplyCodexWebsocketHeadersUsesCanonicalAccountHeader(t *testing.T) {
auth := &cliproxyauth.Auth{Provider: "codex", Metadata: map[string]any{"account_id": "acct-1"}}
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", nil)
if got := headerValueCaseInsensitive(headers, "ChatGPT-Account-ID"); got != "acct-1" {
t.Fatalf("ChatGPT-Account-ID = %s, want acct-1", got)
}
values, ok := headers["ChatGPT-Account-ID"]
if !ok {
t.Fatalf("expected exact ChatGPT-Account-ID key, got %#v", headers)
}
if len(values) != 1 || values[0] != "acct-1" {
t.Fatalf("ChatGPT-Account-ID values = %#v, want [acct-1]", values)
}
}
func TestBuildCodexResponsesWebsocketURLRequiresHTTPURL(t *testing.T) {
if got, err := buildCodexResponsesWebsocketURL("https://example.com/backend/responses"); err != nil || got != "wss://example.com/backend/responses" {
t.Fatalf("https URL = %q, %v; want wss URL", got, err)
}
if _, err := buildCodexResponsesWebsocketURL("ftp://example.com/responses"); err == nil {
t.Fatalf("expected unsupported scheme error")
}
if _, err := buildCodexResponsesWebsocketURL("https:///responses"); err == nil {
t.Fatalf("expected empty host error")
}
}
func TestParseCodexWebsocketErrorMarksConnectionLimitRetryable(t *testing.T) {
err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"error":{"code":"websocket_connection_limit_reached","message":"too many websockets"},"headers":{"retry-after":"1"}}`))
if !ok {
t.Fatalf("expected websocket error")
}
status, ok := err.(interface{ StatusCode() int })
if !ok || status.StatusCode() != http.StatusTooManyRequests {
t.Fatalf("status = %#v, want 429", err)
}
retryable, ok := err.(interface{ RetryAfter() *time.Duration })
if !ok || retryable.RetryAfter() == nil {
t.Fatalf("expected retryable websocket connection limit error")
}
if got := *retryable.RetryAfter(); got != 0 {
t.Fatalf("retryAfter = %v, want connection-limit fallback 0", got)
}
withHeaders, ok := err.(interface{ Headers() http.Header })
if !ok || withHeaders.Headers().Get("retry-after") != "1" {
t.Fatalf("headers = %#v, want retry-after", err)
}
}
func TestParseCodexWebsocketErrorUsesUsageLimitRetryMetadata(t *testing.T) {
err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"body":{"error":{"type":"usage_limit_reached","message":"usage limit reached","resets_in_seconds":7}}}`))
if !ok {
t.Fatalf("expected websocket error")
}
retryable, ok := err.(interface{ RetryAfter() *time.Duration })
if !ok || retryable.RetryAfter() == nil {
t.Fatalf("expected retryable usage limit websocket error")
}
if got := *retryable.RetryAfter(); got != 7*time.Second {
t.Fatalf("retryAfter = %v, want 7s", got)
}
}
func TestParseCodexWebsocketErrorPreservesWrappedBodyAndHeaders(t *testing.T) {
err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"body":{"error":{"code":"websocket_connection_limit_reached","type":"server_error","message":"too many websocket connections"}},"headers":{"x-request-id":"req-1"}}`))
if !ok {
t.Fatalf("expected websocket error")
}
parsed := gjson.Parse(err.Error())
if got := parsed.Get("status").Int(); got != http.StatusTooManyRequests {
t.Fatalf("wrapped status = %d, want 429; payload=%s", got, err.Error())
}
if got := parsed.Get("body.error.code").String(); got != "websocket_connection_limit_reached" {
t.Fatalf("wrapped body error code = %s, want websocket_connection_limit_reached; payload=%s", got, err.Error())
}
if got := parsed.Get("error.code").String(); got != "websocket_connection_limit_reached" {
t.Fatalf("surface error code = %s, want websocket_connection_limit_reached; payload=%s", got, err.Error())
}
retryable, ok := err.(interface{ RetryAfter() *time.Duration })
if !ok || retryable.RetryAfter() == nil {
t.Fatalf("expected body.error.code websocket connection limit to be retryable")
}
withHeaders, ok := err.(interface{ Headers() http.Header })
if !ok || withHeaders.Headers().Get("x-request-id") != "req-1" {
t.Fatalf("headers = %#v, want x-request-id", err)
}
}
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
@@ -411,19 +411,30 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
case <-ctx.Done():
return
}
}
}
}
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
return
}
reporter.EnsurePublished(ctx)
@@ -434,7 +445,10 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errRead}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errRead}:
case <-ctx.Done():
}
return
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
@@ -442,12 +456,20 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
case <-ctx.Done():
return
}
}
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), &param)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
case <-ctx.Done():
return
}
}
}(httpResp, append([]byte(nil), payload...), attemptModel)
+14 -3
View File
@@ -324,17 +324,28 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -338,6 +338,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
}
action := getVertexAction(baseModel, false)
@@ -459,6 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
action := getVertexAction(baseModel, false)
if req.Metadata != nil {
@@ -570,6 +572,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
action := getVertexAction(baseModel, true)
baseURL := vertexBaseURL(location)
@@ -656,17 +659,28 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -700,6 +714,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
action := getVertexAction(baseModel, true)
// For API key auth, use simpler URL format without project/location
@@ -786,17 +801,28 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -818,6 +844,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
translatedReq = helps.StripVertexOpenAIResponsesToolCallIDs(translatedReq, from.String())
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -907,6 +934,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
translatedReq = helps.StripVertexOpenAIResponsesToolCallIDs(translatedReq, from.String())
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
@@ -18,6 +18,7 @@ import (
type UsageReporter struct {
provider string
model string
alias string
authID string
authIndex string
authType string
@@ -29,9 +30,14 @@ type UsageReporter struct {
func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
apiKey := APIKeyFromContext(ctx)
alias := usage.RequestedModelAliasFromContext(ctx)
if alias == "" {
alias = model
}
reporter := &UsageReporter{
provider: provider,
model: model,
alias: strings.TrimSpace(alias),
requestedAt: time.Now(),
apiKey: apiKey,
source: resolveUsageSource(auth, apiKey),
@@ -139,6 +145,7 @@ func (r *UsageReporter) buildRecordForModel(model string, detail usage.Detail, f
return usage.Record{
Provider: r.provider,
Model: model,
Alias: r.alias,
Source: r.source,
APIKey: r.apiKey,
AuthID: r.authID,
@@ -1,6 +1,7 @@
package helps
import (
"context"
"testing"
"time"
@@ -107,6 +108,19 @@ func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
}
}
func TestUsageReporterBuildRecordIncludesRequestedModelAlias(t *testing.T) {
ctx := usage.WithRequestedModelAlias(context.Background(), "client-gpt")
reporter := NewUsageReporter(ctx, "openai", "gpt-5.4", nil)
record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
if record.Model != "gpt-5.4" {
t.Fatalf("model = %q, want %q", record.Model, "gpt-5.4")
}
if record.Alias != "client-gpt" {
t.Fatalf("alias = %q, want %q", record.Alias, "client-gpt")
}
}
func TestUsageReporterBuildAdditionalModelRecordSkipsZeroTokens(t *testing.T) {
reporter := &UsageReporter{
provider: "codex",
@@ -0,0 +1,43 @@
package helps
import (
"fmt"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// StripVertexOpenAIResponsesToolCallIDs removes OpenAI Responses call IDs that
// Vertex rejects in Gemini functionCall/functionResponse payloads.
func StripVertexOpenAIResponsesToolCallIDs(payload []byte, sourceFormat string) []byte {
if !strings.EqualFold(strings.TrimSpace(sourceFormat), "openai-response") {
return payload
}
contents := gjson.GetBytes(payload, "contents")
if !contents.IsArray() {
return payload
}
out := payload
for contentIndex, content := range contents.Array() {
parts := content.Get("parts")
if !parts.IsArray() {
continue
}
for partIndex, part := range parts.Array() {
if part.Get("functionCall.id").Exists() {
if updated, errDelete := sjson.DeleteBytes(out, fmt.Sprintf("contents.%d.parts.%d.functionCall.id", contentIndex, partIndex)); errDelete == nil {
out = updated
}
}
if part.Get("functionResponse.id").Exists() {
if updated, errDelete := sjson.DeleteBytes(out, fmt.Sprintf("contents.%d.parts.%d.functionResponse.id", contentIndex, partIndex)); errDelete == nil {
out = updated
}
}
}
}
return out
}
+115 -5
View File
@@ -290,17 +290,28 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
for i := range doneChunks {
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}:
case <-ctx.Done():
return
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
@@ -322,7 +333,17 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
return body, nil
}
out := body
msgs := messages.Array()
out, dropped, err := filterKimiEmptyAssistantMessages(body, msgs)
if err != nil {
return body, err
}
if dropped > 0 {
log.WithField("dropped_assistant_messages", dropped).Debug("kimi executor: dropped empty assistant messages")
}
messages = gjson.GetBytes(out, "messages")
msgs = messages.Array()
pending := make([]string, 0)
patched := 0
patchedReasoning := 0
@@ -340,7 +361,6 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
}
}
msgs := messages.Array()
for msgIdx := range msgs {
msg := msgs[msgIdx]
role := strings.TrimSpace(msg.Get("role").String())
@@ -428,6 +448,96 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
return out, nil
}
func filterKimiEmptyAssistantMessages(body []byte, msgs []gjson.Result) ([]byte, int, error) {
kept := make([]string, 0, len(msgs))
dropped := 0
for _, msg := range msgs {
if shouldDropKimiAssistantMessage(msg) {
dropped++
continue
}
kept = append(kept, msg.Raw)
}
if dropped == 0 {
return body, 0, nil
}
rawMessages := []byte("[" + strings.Join(kept, ",") + "]")
out, err := sjson.SetRawBytes(body, "messages", rawMessages)
if err != nil {
return body, 0, fmt.Errorf("kimi executor: failed to drop empty assistant messages: %w", err)
}
return out, dropped, nil
}
func shouldDropKimiAssistantMessage(msg gjson.Result) bool {
if strings.TrimSpace(msg.Get("role").String()) != "assistant" {
return false
}
if hasKimiToolCalls(msg) || hasKimiLegacyFunctionCall(msg) || hasKimiAssistantReasoning(msg) {
return false
}
return isKimiAssistantContentEmpty(msg.Get("content"))
}
func hasKimiToolCalls(msg gjson.Result) bool {
toolCalls := msg.Get("tool_calls")
return toolCalls.Exists() && toolCalls.IsArray() && len(toolCalls.Array()) > 0
}
func hasKimiLegacyFunctionCall(msg gjson.Result) bool {
functionCall := msg.Get("function_call")
if !functionCall.Exists() || functionCall.Type == gjson.Null {
return false
}
if functionCall.IsObject() && strings.TrimSpace(functionCall.Raw) == "{}" {
return false
}
return strings.TrimSpace(functionCall.Raw) != ""
}
func hasKimiAssistantReasoning(msg gjson.Result) bool {
reasoning := msg.Get("reasoning_content")
return reasoning.Exists() && strings.TrimSpace(reasoning.String()) != ""
}
func isKimiAssistantContentEmpty(content gjson.Result) bool {
if !content.Exists() || content.Type == gjson.Null {
return true
}
if content.Type == gjson.String {
return strings.TrimSpace(content.String()) == ""
}
if !content.IsArray() {
return false
}
for _, part := range content.Array() {
if !isKimiAssistantContentPartEmpty(part) {
return false
}
}
return true
}
func isKimiAssistantContentPartEmpty(part gjson.Result) bool {
if !part.Exists() || part.Type == gjson.Null {
return true
}
if part.Type == gjson.String {
return strings.TrimSpace(part.String()) == ""
}
if !part.IsObject() {
return false
}
if text := part.Get("text"); text.Exists() {
return strings.TrimSpace(text.String()) == ""
}
if strings.TrimSpace(part.Get("type").String()) == "text" {
return true
}
return strings.TrimSpace(part.Raw) == "{}"
}
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
if hasLatest && strings.TrimSpace(latest) != "" {
return latest
@@ -203,3 +203,70 @@ func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "r1")
}
}
func TestNormalizeKimiToolMessageLinks_DropsEmptyAssistantWithoutToolLink(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"user","content":"start"},
{"role":"assistant","content":""},
{"role":"assistant","content":" "},
{"role":"assistant","content":"","tool_calls":null},
{"role":"assistant","content":[{"type":"text","text":" "}]},
{"role":"assistant"},
{"role":"assistant","content":"keep"},
{"role":"user","content":"next"}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
messages := gjson.GetBytes(out, "messages").Array()
if len(messages) != 3 {
t.Fatalf("messages length = %d, want 3, raw = %s", len(messages), gjson.GetBytes(out, "messages").Raw)
}
if got := messages[0].Get("content").String(); got != "start" {
t.Fatalf("messages.0.content = %q, want %q", got, "start")
}
if got := messages[1].Get("content").String(); got != "keep" {
t.Fatalf("messages.1.content = %q, want %q", got, "keep")
}
if got := messages[2].Get("content").String(); got != "next" {
t.Fatalf("messages.2.content = %q, want %q", got, "next")
}
}
func TestNormalizeKimiToolMessageLinks_PreservesAssistantWithToolLinkOrReasoning(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
{"role":"assistant","content":"","function_call":{"name":"legacy_call","arguments":"{}"}},
{"role":"assistant","content":"","reasoning_content":"thought"},
{"role":"assistant","content":[{"type":"text","text":" visible "}]}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
messages := gjson.GetBytes(out, "messages").Array()
if len(messages) != 4 {
t.Fatalf("messages length = %d, want 4, raw = %s", len(messages), gjson.GetBytes(out, "messages").Raw)
}
if !messages[0].Get("tool_calls").Exists() {
t.Fatalf("messages.0.tool_calls should exist")
}
if !messages[1].Get("function_call").Exists() {
t.Fatalf("messages.1.function_call should exist")
}
if got := messages[2].Get("reasoning_content").String(); got != "thought" {
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "thought")
}
if got := messages[3].Get("content.0.text").String(); got != " visible " {
t.Fatalf("messages.3.content.0.text = %q, want %q", got, " visible ")
}
}
@@ -96,6 +96,12 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
requestPath := helps.PayloadRequestPath(opts)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
@@ -105,11 +111,6 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
}
}
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
url := strings.TrimSuffix(baseURL, "/") + endpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
@@ -199,15 +200,16 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
requestPath := helps.PayloadRequestPath(opts)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
requestPath := helps.PayloadRequestPath(opts)
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
// Request usage data in the final streaming chunk so that token statistics
// are captured even when the upstream is an OpenAI-compatible provider.
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
@@ -281,32 +283,57 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
reporter.Publish(ctx, detail)
}
if len(line) == 0 {
trimmedLine := bytes.TrimSpace(line)
if len(trimmedLine) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
if !bytes.HasPrefix(trimmedLine, []byte("data:")) {
if bytes.HasPrefix(trimmedLine, []byte(":")) || bytes.HasPrefix(trimmedLine, []byte("event:")) ||
bytes.HasPrefix(trimmedLine, []byte("id:")) || bytes.HasPrefix(trimmedLine, []byte("retry:")) {
continue
}
if bytes.HasPrefix(trimmedLine, []byte("{")) || bytes.HasPrefix(trimmedLine, []byte("[")) {
streamErr := statusErr{code: http.StatusBadGateway, msg: string(trimmedLine)}
helps.RecordAPIResponseError(ctx, e.cfg, streamErr)
reporter.PublishFailure(ctx)
select {
case out <- cliproxyexecutor.StreamChunk{Err: streamErr}:
case <-ctx.Done():
}
return
}
continue
}
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
// Pass through translator; it yields one or more chunks for the target schema.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), &param)
// OpenAI-compatible streams must use SSE data lines.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(trimmedLine), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
} else {
// In case the upstream close the stream without a terminal [DONE] marker.
// Feed a synthetic done marker through the translator so pending
// response.completed events are still emitted exactly once.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), &param)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
// Ensure we record the request if no usage chunk was ever seen
@@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -56,3 +57,125 @@ func TestOpenAICompatExecutorCompactPassthrough(t *testing.T) {
t.Fatalf("payload = %s", string(resp.Payload))
}
}
func TestOpenAICompatExecutorPayloadOverrideWinsOverThinkingSuffix(t *testing.T) {
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
gotBody = body
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"chatcmpl_1","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`))
}))
defer server.Close()
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{
Payload: config.PayloadConfig{
Override: []config.PayloadRule{
{
Models: []config.PayloadModelRule{
{Name: "custom-openai", Protocol: "openai"},
},
Params: map[string]any{
"reasoning_effort": "low",
},
},
},
},
})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL + "/v1",
"api_key": "test",
}}
payload := []byte(`{"model":"custom-openai(high)","messages":[{"role":"user","content":"hi"}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "custom-openai(high)",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
Stream: false,
})
if err != nil {
t.Fatalf("Execute error: %v", err)
}
if got := gjson.GetBytes(gotBody, "reasoning_effort").String(); got != "low" {
t.Fatalf("reasoning_effort = %q, want %q; body=%s", got, "low", string(gotBody))
}
}
func TestOpenAICompatExecutorStreamRejectsPlainJSONAfterBlankLines(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("\n\n: openrouter processing\n\nevent: error\n"))
_, _ = w.Write([]byte(`{"error":{"message":"upstream failed","type":"server_error"}}` + "\n"))
}))
defer server.Close()
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL + "/v1",
"api_key": "test",
}}
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "openrouter-model",
Payload: []byte(`{"model":"openrouter-model","messages":[{"role":"user","content":"hi"}],"stream":true}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
Stream: true,
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
var gotErr error
for chunk := range result.Chunks {
if chunk.Err != nil {
gotErr = chunk.Err
break
}
}
if gotErr == nil {
t.Fatalf("expected plain JSON stream error")
}
if status, ok := gotErr.(interface{ StatusCode() int }); !ok || status.StatusCode() != http.StatusBadGateway {
t.Fatalf("stream error status = %v, want %d", gotErr, http.StatusBadGateway)
}
if !strings.Contains(gotErr.Error(), "upstream failed") {
t.Fatalf("stream error = %v", gotErr)
}
}
func TestOpenAICompatExecutorStreamSkipsKeepAliveUntilDataLine(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("\n\n: openrouter processing\n\nevent: ping\nid: 1\nretry: 1000\n"))
_, _ = w.Write([]byte(`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hello"},"finish_reason":null}]}` + "\n"))
}))
defer server.Close()
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"base_url": server.URL + "/v1",
"api_key": "test",
}}
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "openrouter-model",
Payload: []byte(`{"model":"openrouter-model","messages":[{"role":"user","content":"hi"}],"stream":true}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("openai"),
Stream: true,
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
var got strings.Builder
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected stream error: %v", chunk.Err)
}
got.Write(chunk.Payload)
}
if gjson.Get(got.String(), "choices.0.delta.content").String() != "hello" {
t.Fatalf("stream payload = %s", got.String())
}
}
@@ -25,10 +25,19 @@ type ConvertAnthropicResponseToOpenAIParams struct {
CreatedAt int64
ResponseID string
FinishReason string
Usage claudeUsageTokens
// Tool calls accumulator for streaming
ToolCallsAccumulator map[int]*ToolCallAccumulator
}
type claudeUsageTokens struct {
InputTokens int64
OutputTokens int64
CacheCreationInputTokens int64
CacheReadInputTokens int64
HasUsage bool
}
// ToolCallAccumulator holds the state for accumulating tool call data
type ToolCallAccumulator struct {
ID string
@@ -36,15 +45,30 @@ type ToolCallAccumulator struct {
Arguments strings.Builder
}
func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
inputTokens := usage.Get("input_tokens").Int()
completionTokens = usage.Get("output_tokens").Int()
cachedTokens = usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
func (u *claudeUsageTokens) Merge(usage gjson.Result) {
if !usage.Exists() {
return
}
u.HasUsage = true
if inputTokens := usage.Get("input_tokens"); inputTokens.Exists() {
u.InputTokens = inputTokens.Int()
}
if outputTokens := usage.Get("output_tokens"); outputTokens.Exists() {
u.OutputTokens = outputTokens.Int()
}
if cacheCreationInputTokens := usage.Get("cache_creation_input_tokens"); cacheCreationInputTokens.Exists() {
u.CacheCreationInputTokens = cacheCreationInputTokens.Int()
}
if cacheReadInputTokens := usage.Get("cache_read_input_tokens"); cacheReadInputTokens.Exists() {
u.CacheReadInputTokens = cacheReadInputTokens.Int()
}
}
promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
func (u claudeUsageTokens) OpenAIUsage() (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
cachedTokens = u.CacheReadInputTokens
promptTokens = u.InputTokens + u.CacheCreationInputTokens + cachedTokens
completionTokens = u.OutputTokens
totalTokens = promptTokens + completionTokens
return promptTokens, completionTokens, totalTokens, cachedTokens
}
@@ -112,6 +136,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
}
(*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(message.Get("usage"))
}
return [][]byte{template}
@@ -215,7 +240,8 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
// Handle usage information for token counts
if usage := root.Get("usage"); usage.Exists() {
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
(*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(usage)
promptTokens, completionTokens, totalTokens, cachedTokens := (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.OpenAIUsage()
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
@@ -296,6 +322,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
var stopReason string
var contentParts []string
var reasoningParts []string
usageTokens := claudeUsageTokens{}
toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
for _, chunk := range chunks {
@@ -309,6 +336,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
messageID = message.Get("id").String()
model = message.Get("model").String()
createdAt = time.Now().Unix()
usageTokens.Merge(message.Get("usage"))
}
case "content_block_start":
@@ -371,15 +399,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
}
if usage := root.Get("usage"); usage.Exists() {
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
usageTokens.Merge(usage)
}
}
}
if usageTokens.HasUsage {
promptTokens, completionTokens, totalTokens, cachedTokens := usageTokens.OpenAIUsage()
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
}
// Set basic response fields including message ID, creation time, and model
out, _ = sjson.SetBytes(out, "id", messageID)
out, _ = sjson.SetBytes(out, "created", createdAt)
@@ -37,6 +37,44 @@ func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testin
}
}
func TestConvertClaudeResponseToOpenAI_StreamUsageMergesMessageStartUsage(t *testing.T) {
ctx := context.Background()
var param any
ConvertClaudeResponseToOpenAI(
ctx,
"claude-opus-4-6",
nil,
nil,
[]byte(`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-opus-4-6","usage":{"input_tokens":13,"output_tokens":1,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}}`),
&param,
)
out := ConvertClaudeResponseToOpenAI(
ctx,
"claude-opus-4-6",
nil,
nil,
[]byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":4}}`),
&param,
)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
}
if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
}
if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
}
if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
@@ -56,3 +94,23 @@ func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *tes
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
func TestConvertClaudeResponseToOpenAINonStream_UsageMergesMessageStartUsage(t *testing.T) {
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\",\"usage\":{\"input_tokens\":13,\"output_tokens\":1,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}}\n" +
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":4}}\n")
out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
}
if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
}
if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
}
if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
}
}
@@ -339,25 +339,21 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
})
}
includedToolNames := map[string]struct{}{}
toolNameMap := map[string]string{}
// tools mapping: parameters -> input_schema
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
toolsJSON := []byte("[]")
tools.ForEach(func(_, tool gjson.Result) bool {
tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
if n := tool.Get("name"); n.Exists() {
tJSON, _ = sjson.SetBytes(tJSON, "name", n.String())
convertedTools := convertResponsesToolToClaudeTools(tool, toolNameMap)
for _, tJSON := range convertedTools {
toolName := gjson.GetBytes(tJSON, "name").String()
if toolName != "" {
includedToolNames[toolName] = struct{}{}
}
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
}
if d := tool.Get("description"); d.Exists() {
tJSON, _ = sjson.SetBytes(tJSON, "description", d.String())
}
if params := tool.Get("parameters"); params.Exists() {
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
} else if params = tool.Get("parametersJsonSchema"); params.Exists() {
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
}
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
return true
})
if parsedTools := gjson.ParseBytes(toolsJSON); parsedTools.IsArray() && len(parsedTools.Array()) > 0 {
@@ -375,14 +371,24 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
case "none":
// Leave unset; implies no tools
case "required":
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
if len(includedToolNames) > 0 {
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
}
}
case gjson.JSON:
if toolChoice.Get("type").String() == "function" {
fn := toolChoice.Get("function.name").String()
toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
if fn == "" {
fn = toolChoice.Get("name").String()
}
if mappedName := toolNameMap[fn]; mappedName != "" {
fn = mappedName
}
if _, ok := includedToolNames[fn]; ok {
toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
}
}
default:
@@ -391,3 +397,167 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
return out
}
func convertResponsesToolToClaudeTools(tool gjson.Result, toolNameMap map[string]string) [][]byte {
toolType := strings.TrimSpace(tool.Get("type").String())
switch toolType {
case "", "function":
if tJSON, ok := convertResponsesFunctionToolToClaude(tool, ""); ok {
return [][]byte{tJSON}
}
case "namespace":
return convertResponsesNamespaceToolToClaude(tool, toolNameMap)
case "web_search":
if tJSON, ok := convertResponsesWebSearchToolToClaude(tool); ok {
if name := gjson.GetBytes(tJSON, "name").String(); name != "" {
toolNameMap[name] = name
}
return [][]byte{tJSON}
}
default:
if isUnsupportedOpenAIBuiltinToolType(toolType) {
return nil
}
if tool.Get("name").String() != "" {
return [][]byte{[]byte(tool.Raw)}
}
}
return nil
}
func convertResponsesNamespaceToolToClaude(tool gjson.Result, toolNameMap map[string]string) [][]byte {
namespaceName := strings.TrimSpace(tool.Get("name").String())
children := tool.Get("tools")
if !children.Exists() || !children.IsArray() {
return nil
}
var out [][]byte
children.ForEach(func(_, child gjson.Result) bool {
childName := responsesToolName(child)
qualifiedName := qualifyResponsesNamespaceToolName(namespaceName, childName)
if tJSON, ok := convertResponsesFunctionToolToClaude(child, qualifiedName); ok {
out = append(out, tJSON)
toolNameMap[qualifiedName] = qualifiedName
if childName != "" {
toolNameMap[childName] = qualifiedName
}
}
return true
})
return out
}
func convertResponsesFunctionToolToClaude(tool gjson.Result, overrideName string) ([]byte, bool) {
name := strings.TrimSpace(overrideName)
if name == "" {
name = responsesToolName(tool)
}
if name == "" {
return nil, false
}
tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
tJSON, _ = sjson.SetBytes(tJSON, "name", name)
if d := responsesToolDescription(tool); d != "" {
tJSON, _ = sjson.SetBytes(tJSON, "description", d)
}
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", normalizeClaudeToolInputSchema(responsesToolParameters(tool)))
return tJSON, true
}
func convertResponsesWebSearchToolToClaude(tool gjson.Result) ([]byte, bool) {
if externalWebAccess := tool.Get("external_web_access"); externalWebAccess.Exists() && !externalWebAccess.Bool() {
return nil, false
}
name := strings.TrimSpace(tool.Get("name").String())
if name == "" {
name = "web_search"
}
tJSON := []byte(`{"type":"web_search_20250305","name":""}`)
tJSON, _ = sjson.SetBytes(tJSON, "name", name)
if maxUses := tool.Get("max_uses"); maxUses.Exists() {
tJSON, _ = sjson.SetBytes(tJSON, "max_uses", maxUses.Int())
}
if allowedDomains := tool.Get("filters.allowed_domains"); allowedDomains.Exists() && allowedDomains.IsArray() {
tJSON, _ = sjson.SetRawBytes(tJSON, "allowed_domains", []byte(allowedDomains.Raw))
}
if userLocation := tool.Get("user_location"); userLocation.Exists() && userLocation.IsObject() {
tJSON, _ = sjson.SetRawBytes(tJSON, "user_location", []byte(userLocation.Raw))
}
return tJSON, true
}
func responsesToolName(tool gjson.Result) string {
if name := strings.TrimSpace(tool.Get("name").String()); name != "" {
return name
}
return strings.TrimSpace(tool.Get("function.name").String())
}
func responsesToolDescription(tool gjson.Result) string {
if description := tool.Get("description").String(); description != "" {
return description
}
return tool.Get("function.description").String()
}
func responsesToolParameters(tool gjson.Result) gjson.Result {
for _, path := range []string{
"parameters",
"parametersJsonSchema",
"input_schema",
"function.parameters",
"function.parametersJsonSchema",
} {
if parameters := tool.Get(path); parameters.Exists() {
return parameters
}
}
return gjson.Result{}
}
func normalizeClaudeToolInputSchema(parameters gjson.Result) []byte {
raw := strings.TrimSpace(parameters.Raw)
if raw == "" || raw == "null" || !gjson.Valid(raw) {
return []byte(`{"type":"object","properties":{}}`)
}
result := gjson.Parse(raw)
if !result.IsObject() {
return []byte(`{"type":"object","properties":{}}`)
}
schema := []byte(raw)
schemaType := result.Get("type").String()
if schemaType == "" {
schema, _ = sjson.SetBytes(schema, "type", "object")
schemaType = "object"
}
if schemaType == "object" && !result.Get("properties").Exists() {
schema, _ = sjson.SetRawBytes(schema, "properties", []byte(`{}`))
}
return schema
}
func qualifyResponsesNamespaceToolName(namespaceName, childName string) string {
childName = strings.TrimSpace(childName)
if childName == "" || namespaceName == "" || strings.HasPrefix(childName, "mcp__") {
return childName
}
if strings.HasPrefix(childName, namespaceName) {
return childName
}
if strings.HasSuffix(namespaceName, "__") {
return namespaceName + childName
}
return namespaceName + "__" + childName
}
func isUnsupportedOpenAIBuiltinToolType(toolType string) bool {
switch toolType {
case "image_generation", "file_search", "code_interpreter", "computer_use_preview":
return true
default:
return false
}
}
@@ -26,7 +26,8 @@ type claudeToResponsesState struct {
FuncNames map[int]string // index -> function name
FuncCallIDs map[int]string // index -> call id
// message text aggregation
TextBuf strings.Builder
TextBuf strings.Builder
CurrentTextBuf strings.Builder
// reasoning state
ReasoningActive bool
ReasoningItemID string
@@ -80,6 +81,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
st.CreatedAt = time.Now().Unix()
// Reset per-message aggregation state
st.TextBuf.Reset()
st.CurrentTextBuf.Reset()
st.ReasoningBuf.Reset()
st.ReasoningActive = false
st.InTextBlock = false
@@ -128,6 +130,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
if typ == "text" {
// open message item + content part
st.InTextBlock = true
st.CurrentTextBuf.Reset()
st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
@@ -189,6 +192,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
out = append(out, emitEvent("response.output_text.delta", msg))
// aggregate text for response.output
st.TextBuf.WriteString(t.String())
st.CurrentTextBuf.WriteString(t.String())
}
} else if dt == "input_json_delta" {
idx := int(root.Get("index").Int())
@@ -220,17 +224,21 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
case "content_block_stop":
idx := int(root.Get("index").Int())
if st.InTextBlock {
fullText := st.CurrentTextBuf.String()
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
done, _ = sjson.SetBytes(done, "item_id", st.CurrentMsgID)
done, _ = sjson.SetBytes(done, "text", fullText)
out = append(out, emitEvent("response.output_text.done", done))
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.SetBytes(partDone, "item_id", st.CurrentMsgID)
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
out = append(out, emitEvent("response.content_part.done", partDone))
final := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`)
final, _ = sjson.SetBytes(final, "sequence_number", nextSeq())
final, _ = sjson.SetBytes(final, "item.id", st.CurrentMsgID)
final, _ = sjson.SetBytes(final, "item.content.0.text", fullText)
out = append(out, emitEvent("response.output_item.done", final))
st.InTextBlock = false
} else if st.InFuncBlock {
@@ -0,0 +1,78 @@
package geminiCLI
import (
"testing"
"github.com/tidwall/gjson"
)
func TestConvertGeminiCLIRequestToCodex_PreservesSchemaPropertyNamedType(t *testing.T) {
input := []byte(`{
"request": {
"tools": [
{
"functionDeclarations": [
{
"name": "ask_user",
"description": "Ask the user one or more questions.",
"parametersJsonSchema": {
"type": "object",
"properties": {
"questions": {
"type": "array",
"items": {
"type": "object",
"properties": {
"header": {
"type": "string"
},
"type": {
"default": "choice",
"description": "Question type.",
"enum": [
"choice",
"text",
"yesno"
],
"type": "string"
}
},
"required": [
"question",
"header",
"type"
]
}
}
},
"required": [
"questions"
]
}
}
]
}
]
}
}`)
out := ConvertGeminiCLIRequestToCodex("gpt-5.2", input, true)
tool := gjson.GetBytes(out, "tools.0")
if got := tool.Get("type").String(); got != "function" {
t.Fatalf("expected tool type %q, got %q; output=%s", "function", got, string(out))
}
typeProperty := tool.Get("parameters.properties.questions.items.properties.type")
if !typeProperty.IsObject() {
t.Fatalf("expected schema property named type to stay an object; output=%s", string(out))
}
if got := typeProperty.Get("type").String(); got != "string" {
t.Fatalf("expected schema property type %q, got %q; output=%s", "string", got, string(out))
}
if got := typeProperty.Get("default").String(); got != "choice" {
t.Fatalf("expected default %q, got %q; output=%s", "choice", got, string(out))
}
if got := typeProperty.Get("enum.2").String(); got != "yesno" {
t.Fatalf("expected enum value %q, got %q; output=%s", "yesno", got, string(out))
}
}
@@ -284,7 +284,11 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
util.Walk(toolsResult, "", "type", &pathsToLower)
for _, p := range pathsToLower {
fullPath := fmt.Sprintf("tools.%s", p)
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
typeValue := gjson.GetBytes(out, fullPath)
if typeValue.Type != gjson.String {
continue
}
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(typeValue.String()))
}
return out
@@ -121,13 +121,13 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
case "tool":
// Handle tool response messages as top-level function_call_output objects
toolCallID := m.Get("tool_call_id").String()
content := m.Get("content").String()
content := m.Get("content")
// Create function_call_output object
funcOutput := []byte(`{}`)
funcOutput, _ = sjson.SetBytes(funcOutput, "type", "function_call_output")
funcOutput, _ = sjson.SetBytes(funcOutput, "call_id", toolCallID)
funcOutput, _ = sjson.SetBytes(funcOutput, "output", content)
funcOutput = setToolCallOutputContent(funcOutput, content)
out, _ = sjson.SetRawBytes(out, "input.-1", funcOutput)
default:
@@ -359,6 +359,91 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
return out
}
func setToolCallOutputContent(funcOutput []byte, content gjson.Result) []byte {
switch {
case content.Type == gjson.String:
funcOutput, _ = sjson.SetBytes(funcOutput, "output", content.String())
case content.IsArray():
output := []byte(`[]`)
for _, item := range content.Array() {
output = appendToolOutputContentPart(output, item)
}
funcOutput, _ = sjson.SetRawBytes(funcOutput, "output", output)
default:
fallbackOutput := content.Raw
if fallbackOutput == "" {
fallbackOutput = content.String()
}
funcOutput, _ = sjson.SetBytes(funcOutput, "output", fallbackOutput)
}
return funcOutput
}
func appendToolOutputContentPart(output []byte, item gjson.Result) []byte {
switch item.Get("type").String() {
case "text":
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_text")
part, _ = sjson.SetBytes(part, "text", item.Get("text").String())
output, _ = sjson.SetRawBytes(output, "-1", part)
case "image_url":
imageURL := item.Get("image_url.url").String()
fileID := item.Get("image_url.file_id").String()
if imageURL == "" && fileID == "" {
return appendToolOutputFallbackPart(output, item)
}
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_image")
if imageURL != "" {
part, _ = sjson.SetBytes(part, "image_url", imageURL)
}
if fileID != "" {
part, _ = sjson.SetBytes(part, "file_id", fileID)
}
if detail := item.Get("image_url.detail").String(); detail != "" {
part, _ = sjson.SetBytes(part, "detail", detail)
}
output, _ = sjson.SetRawBytes(output, "-1", part)
case "file":
fileID := item.Get("file.file_id").String()
fileData := item.Get("file.file_data").String()
fileURL := item.Get("file.file_url").String()
if fileID == "" && fileData == "" && fileURL == "" {
return appendToolOutputFallbackPart(output, item)
}
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_file")
if fileID != "" {
part, _ = sjson.SetBytes(part, "file_id", fileID)
}
if fileData != "" {
part, _ = sjson.SetBytes(part, "file_data", fileData)
}
if fileURL != "" {
part, _ = sjson.SetBytes(part, "file_url", fileURL)
}
if filename := item.Get("file.filename").String(); filename != "" {
part, _ = sjson.SetBytes(part, "filename", filename)
}
output, _ = sjson.SetRawBytes(output, "-1", part)
default:
output = appendToolOutputFallbackPart(output, item)
}
return output
}
func appendToolOutputFallbackPart(output []byte, item gjson.Result) []byte {
text := item.Raw
if text == "" {
text = item.String()
}
part := []byte(`{}`)
part, _ = sjson.SetBytes(part, "type", "input_text")
part, _ = sjson.SetBytes(part, "text", text)
output, _ = sjson.SetRawBytes(output, "-1", part)
return output
}
// shortenNameIfNeeded applies the simple shortening rule for a single name.
// If the name length exceeds 64, it will try to preserve the "mcp__" prefix and last segment.
// Otherwise it truncates to 64 characters.
@@ -176,6 +176,182 @@ func TestToolCallWithContent(t *testing.T) {
}
}
func TestToolCallOutputWithMultimodalContent(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Show me the generated result."},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_output_1",
"type": "function",
"function": {"name": "render_output", "arguments": "{}"}
}
]
},
{
"role": "tool",
"tool_call_id": "call_output_1",
"content": [
{"type":"text","text":"Rendered result attached."},
{"type":"image_url","image_url":{"url":"https://example.com/generated.png","detail":"high"}},
{"type":"image_url","image_url":{"file_id":"file-img-123"}},
{"type":"file","file":{"file_id":"file-doc-123","filename":"doc.pdf"}},
{"type":"file","file":{"file_data":"SGVsbG8=","filename":"inline.txt"}},
{"type":"file","file":{"file_url":"https://example.com/report.pdf","filename":"report.pdf"}}
]
}
],
"tools": [
{
"type": "function",
"function": {"name": "render_output", "description": "Render output", "parameters": {"type": "object", "properties": {}}}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
output := gjson.Get(result, "input.2.output")
if !output.IsArray() {
t.Fatalf("expected tool output to be an array, got: %s", output.Raw)
}
parts := output.Array()
if len(parts) != 6 {
t.Fatalf("expected 6 output parts, got %d: %s", len(parts), output.Raw)
}
if parts[0].Get("type").String() != "input_text" || parts[0].Get("text").String() != "Rendered result attached." {
t.Fatalf("part 0: expected input_text with rendered text, got %s", parts[0].Raw)
}
if parts[1].Get("type").String() != "input_image" {
t.Fatalf("part 1: expected input_image, got %s", parts[1].Raw)
}
if parts[1].Get("image_url").String() != "https://example.com/generated.png" {
t.Errorf("part 1: unexpected image_url %s", parts[1].Get("image_url").String())
}
if parts[1].Get("detail").String() != "high" {
t.Errorf("part 1: unexpected detail %s", parts[1].Get("detail").String())
}
if parts[2].Get("type").String() != "input_image" || parts[2].Get("file_id").String() != "file-img-123" {
t.Fatalf("part 2: expected file_id-backed input_image, got %s", parts[2].Raw)
}
if parts[3].Get("type").String() != "input_file" || parts[3].Get("file_id").String() != "file-doc-123" {
t.Fatalf("part 3: expected file_id-backed input_file, got %s", parts[3].Raw)
}
if parts[3].Get("filename").String() != "doc.pdf" {
t.Errorf("part 3: unexpected filename %s", parts[3].Get("filename").String())
}
if parts[4].Get("type").String() != "input_file" || parts[4].Get("file_data").String() != "SGVsbG8=" {
t.Fatalf("part 4: expected file_data-backed input_file, got %s", parts[4].Raw)
}
if parts[5].Get("type").String() != "input_file" || parts[5].Get("file_url").String() != "https://example.com/report.pdf" {
t.Fatalf("part 5: expected file_url-backed input_file, got %s", parts[5].Raw)
}
}
func TestToolCallOutputFallsBackForInvalidStructuredParts(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Check tool output."},
{
"role": "assistant",
"content": null,
"tool_calls": [
{"id": "call_invalid_parts", "type": "function", "function": {"name": "inspect", "arguments": "{}"}}
]
},
{
"role": "tool",
"tool_call_id": "call_invalid_parts",
"content": [
{"type":"image_url","image_url":{"detail":"low"}},
{"type":"file","file":{"filename":"orphan.txt"}},
{"type":"unknown_type","foo":"bar","nested":{"a":1}}
]
}
],
"tools": [
{"type": "function", "function": {"name": "inspect", "description": "Inspect", "parameters": {"type": "object", "properties": {}}}}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
parts := gjson.Get(result, "input.2.output").Array()
if len(parts) != 3 {
t.Fatalf("expected 3 output parts, got %d: %s", len(parts), gjson.Get(result, "input.2.output").Raw)
}
expectedFallbacks := []string{
`{"type":"image_url","image_url":{"detail":"low"}}`,
`{"type":"file","file":{"filename":"orphan.txt"}}`,
`{"type":"unknown_type","foo":"bar","nested":{"a":1}}`,
}
for i, expectedFallback := range expectedFallbacks {
if parts[i].Get("type").String() != "input_text" {
t.Fatalf("part %d: expected input_text fallback, got %s", i, parts[i].Raw)
}
if parts[i].Get("text").String() != expectedFallback {
t.Fatalf("part %d: expected fallback %s, got %s", i, expectedFallback, parts[i].Get("text").String())
}
}
}
func TestToolCallOutputWithNonStringJSONContent(t *testing.T) {
tests := []struct {
name string
content string
expectedOutput string
}{
{name: "null", content: `null`, expectedOutput: `null`},
{name: "object", content: `{"status":"ok","count":2}`, expectedOutput: `{"status":"ok","count":2}`},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Check tool output."},
{
"role": "assistant",
"content": null,
"tool_calls": [
{"id": "call_json", "type": "function", "function": {"name": "inspect", "arguments": "{}"}}
]
},
{
"role": "tool",
"tool_call_id": "call_json",
"content": ` + tt.content + `
}
],
"tools": [
{"type": "function", "function": {"name": "inspect", "description": "Inspect", "parameters": {"type": "object", "properties": {}}}}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
output := gjson.Get(result, "input.2.output")
if !output.Exists() {
t.Fatalf("expected output field to exist: %s", gjson.Get(result, "input.2").Raw)
}
if output.String() != tt.expectedOutput {
t.Fatalf("expected output %s, got %s", tt.expectedOutput, output.String())
}
})
}
}
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
// and outputs must be translated and paired correctly.
func TestMultipleToolCalls(t *testing.T) {
@@ -236,7 +236,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
// Handle function name
if function := toolCall.Get("function"); function.Exists() {
if name := function.Get("name"); name.Exists() {
if name := function.Get("name"); name.Exists() && name.String() != "" {
accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
stopThinkingContentBlock(param, &results)
@@ -0,0 +1,41 @@
package claude
import (
"bytes"
"context"
"testing"
)
func TestConvertOpenAIResponseToClaude_StreamIgnoresNullToolNameDelta(t *testing.T) {
originalRequest := []byte(`{"stream":true}`)
var param any
firstChunks := ConvertOpenAIResponseToClaude(
context.Background(),
"test-model",
originalRequest,
nil,
[]byte(`data: {"id":"chatcmpl_1","model":"test-model","created":1,"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"read_file","arguments":""}}]},"finish_reason":null}]}`),
&param,
)
firstOutput := bytes.Join(firstChunks, nil)
if !bytes.Contains(firstOutput, []byte(`"name":"read_file"`)) {
t.Fatalf("expected first chunk to start read_file tool block, got %s", string(firstOutput))
}
secondChunks := ConvertOpenAIResponseToClaude(
context.Background(),
"test-model",
originalRequest,
nil,
[]byte(`data: {"id":"chatcmpl_1","model":"test-model","created":1,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":null,"arguments":"{\"path\":\"/tmp/a\"}"}}]},"finish_reason":null}]}`),
&param,
)
secondOutput := bytes.Join(secondChunks, nil)
if bytes.Contains(secondOutput, []byte(`content_block_start`)) {
t.Fatalf("did not expect null tool name delta to start a new content block, got %s", string(secondOutput))
}
if bytes.Contains(secondOutput, []byte(`"name":""`)) {
t.Fatalf("did not expect null tool name delta to emit an empty tool name, got %s", string(secondOutput))
}
}
+4 -18
View File
@@ -18,7 +18,6 @@ const (
tabAuthFiles
tabAPIKeys
tabOAuth
tabUsage
tabLogs
)
@@ -40,7 +39,6 @@ type App struct {
auth authTabModel
keys keysTabModel
oauth oauthTabModel
usage usageTabModel
logs logsTabModel
client *Client
@@ -50,7 +48,7 @@ type App struct {
ready bool
// Track which tabs have been initialized (fetched data)
initialized [7]bool
initialized [6]bool
}
type authConnectMsg struct {
@@ -81,10 +79,9 @@ func NewApp(port int, secretKey string, hook *LogHook) App {
auth: newAuthTabModel(client),
keys: newKeysTabModel(client),
oauth: newOAuthTabModel(client),
usage: newUsageTabModel(client),
logs: newLogsTabModel(client, hook),
client: client,
initialized: [7]bool{
initialized: [6]bool{
tabDashboard: true,
tabLogs: true,
},
@@ -92,7 +89,7 @@ func NewApp(port int, secretKey string, hook *LogHook) App {
app.refreshTabs()
if authRequired {
app.initialized = [7]bool{}
app.initialized = [6]bool{}
}
app.setAuthInputPrompt()
return app
@@ -128,7 +125,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.auth.SetSize(contentW, contentH)
a.keys.SetSize(contentW, contentH)
a.oauth.SetSize(contentW, contentH)
a.usage.SetSize(contentW, contentH)
a.logs.SetSize(contentW, contentH)
return a, nil
@@ -142,7 +138,7 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.authenticated = true
a.logsEnabled = a.standalone || isLogsEnabledFromConfig(msg.cfg)
a.refreshTabs()
a.initialized = [7]bool{}
a.initialized = [6]bool{}
a.initialized[tabDashboard] = true
cmds := []tea.Cmd{a.dashboard.Init()}
if a.logsEnabled {
@@ -258,8 +254,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
a.keys, cmd = a.keys.Update(msg)
case tabOAuth:
a.oauth, cmd = a.oauth.Update(msg)
case tabUsage:
a.usage, cmd = a.usage.Update(msg)
case tabLogs:
a.logs, cmd = a.logs.Update(msg)
}
@@ -322,8 +316,6 @@ func (a *App) initTabIfNeeded(_ int) tea.Cmd {
return a.keys.Init()
case tabOAuth:
return a.oauth.Init()
case tabUsage:
return a.usage.Init()
case tabLogs:
if !a.logsEnabled {
return nil
@@ -360,8 +352,6 @@ func (a App) View() string {
sb.WriteString(a.keys.View())
case tabOAuth:
sb.WriteString(a.oauth.View())
case tabUsage:
sb.WriteString(a.usage.View())
case tabLogs:
if a.logsEnabled {
sb.WriteString(a.logs.View())
@@ -529,10 +519,6 @@ func (a App) broadcastToAllTabs(msg tea.Msg) (tea.Model, tea.Cmd) {
if cmd != nil {
cmds = append(cmds, cmd)
}
a.usage, cmd = a.usage.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
a.logs, cmd = a.logs.Update(msg)
if cmd != nil {
cmds = append(cmds, cmd)
-5
View File
@@ -140,11 +140,6 @@ func (c *Client) PutConfigYAML(yamlContent string) error {
return err
}
// GetUsage fetches usage statistics.
func (c *Client) GetUsage() (map[string]any, error) {
return c.getJSON("/v0/management/usage")
}
// GetAuthFiles lists auth credential files.
// API returns {"files": [...]}.
func (c *Client) GetAuthFiles() ([]map[string]any, error) {
+7 -70
View File
@@ -22,14 +22,12 @@ type dashboardModel struct {
// Cached data for re-rendering on locale change
lastConfig map[string]any
lastUsage map[string]any
lastAuthFiles []map[string]any
lastAPIKeys []string
}
type dashboardDataMsg struct {
config map[string]any
usage map[string]any
authFiles []map[string]any
apiKeys []string
err error
@@ -47,25 +45,24 @@ func (m dashboardModel) Init() tea.Cmd {
func (m dashboardModel) fetchData() tea.Msg {
cfg, cfgErr := m.client.GetConfig()
usage, usageErr := m.client.GetUsage()
authFiles, authErr := m.client.GetAuthFiles()
apiKeys, keysErr := m.client.GetAPIKeys()
var err error
for _, e := range []error{cfgErr, usageErr, authErr, keysErr} {
for _, e := range []error{cfgErr, authErr, keysErr} {
if e != nil {
err = e
break
}
}
return dashboardDataMsg{config: cfg, usage: usage, authFiles: authFiles, apiKeys: apiKeys, err: err}
return dashboardDataMsg{config: cfg, authFiles: authFiles, apiKeys: apiKeys, err: err}
}
func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
// Re-render immediately with cached data using new locale
m.content = m.renderDashboard(m.lastConfig, m.lastUsage, m.lastAuthFiles, m.lastAPIKeys)
m.content = m.renderDashboard(m.lastConfig, m.lastAuthFiles, m.lastAPIKeys)
m.viewport.SetContent(m.content)
// Also fetch fresh data in background
return m, m.fetchData
@@ -78,11 +75,10 @@ func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
m.err = nil
// Cache data for locale switching
m.lastConfig = msg.config
m.lastUsage = msg.usage
m.lastAuthFiles = msg.authFiles
m.lastAPIKeys = msg.apiKeys
m.content = m.renderDashboard(msg.config, msg.usage, msg.authFiles, msg.apiKeys)
m.content = m.renderDashboard(msg.config, msg.authFiles, msg.apiKeys)
}
m.viewport.SetContent(m.content)
return m, nil
@@ -121,7 +117,7 @@ func (m dashboardModel) View() string {
return m.viewport.View()
}
func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []map[string]any, apiKeys []string) string {
func (m dashboardModel) renderDashboard(cfg map[string]any, authFiles []map[string]any, apiKeys []string) string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("dashboard_title")))
@@ -138,7 +134,7 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
// ━━━ Stats Cards ━━━
cardWidth := 25
if m.width > 0 {
cardWidth = (m.width - 6) / 4
cardWidth = (m.width - 2) / 2
if cardWidth < 18 {
cardWidth = 18
}
@@ -173,34 +169,7 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (%d %s)", T("auth_files_label"), activeAuth, T("active_suffix"))),
))
// Card 3: Total Requests
totalReqs := int64(0)
successReqs := int64(0)
failedReqs := int64(0)
totalTokens := int64(0)
if usage != nil {
if usageMap, ok := usage["usage"].(map[string]any); ok {
totalReqs = int64(getFloat(usageMap, "total_requests"))
successReqs = int64(getFloat(usageMap, "success_count"))
failedReqs = int64(getFloat(usageMap, "failure_count"))
totalTokens = int64(getFloat(usageMap, "total_tokens"))
}
}
card3 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(fmt.Sprintf("📈 %d", totalReqs)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (✓%d ✗%d)", T("total_requests"), successReqs, failedReqs)),
))
// Card 4: Total Tokens
tokenStr := formatLargeNumber(totalTokens)
card4 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("🔤 %s", tokenStr)),
lipgloss.NewStyle().Foreground(colorMuted).Render(T("total_tokens")),
))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2))
sb.WriteString("\n\n")
// ━━━ Current Config ━━━
@@ -258,38 +227,6 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
sb.WriteString("\n")
// ━━━ Per-Model Usage ━━━
if usage != nil {
if usageMap, ok := usage["usage"].(map[string]any); ok {
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("model_stats")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
header := fmt.Sprintf(" %-40s %10s %12s", T("model"), T("requests"), T("tokens"))
sb.WriteString(tableHeaderStyle.Render(header))
sb.WriteString("\n")
for _, apiSnap := range apis {
if apiMap, ok := apiSnap.(map[string]any); ok {
if models, ok := apiMap["models"].(map[string]any); ok {
for model, v := range models {
if stats, ok := v.(map[string]any); ok {
reqs := int64(getFloat(stats, "total_requests"))
toks := int64(getFloat(stats, "total_tokens"))
row := fmt.Sprintf(" %-40s %10d %12s", truncate(model, 40), reqs, formatLargeNumber(toks))
sb.WriteString(tableCellStyle.Render(row))
sb.WriteString("\n")
}
}
}
}
}
}
}
}
return sb.String()
}
+2 -2
View File
@@ -50,8 +50,8 @@ var locales = map[string]map[string]string{
// ──────────────────────────────────────────
// Tab names
// ──────────────────────────────────────────
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "使用统计", "日志"}
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Usage", "Logs"}
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "日志"}
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Logs"}
// TabNames returns tab names in the current locale.
func TabNames() []string {
-418
View File
@@ -1,418 +0,0 @@
package tui
import (
"fmt"
"sort"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// usageTabModel displays usage statistics with charts and breakdowns.
type usageTabModel struct {
client *Client
viewport viewport.Model
usage map[string]any
err error
width int
height int
ready bool
}
type usageDataMsg struct {
usage map[string]any
err error
}
func newUsageTabModel(client *Client) usageTabModel {
return usageTabModel{
client: client,
}
}
func (m usageTabModel) Init() tea.Cmd {
return m.fetchData
}
func (m usageTabModel) fetchData() tea.Msg {
usage, err := m.client.GetUsage()
return usageDataMsg{usage: usage, err: err}
}
func (m usageTabModel) Update(msg tea.Msg) (usageTabModel, tea.Cmd) {
switch msg := msg.(type) {
case localeChangedMsg:
m.viewport.SetContent(m.renderContent())
return m, nil
case usageDataMsg:
if msg.err != nil {
m.err = msg.err
} else {
m.err = nil
m.usage = msg.usage
}
m.viewport.SetContent(m.renderContent())
return m, nil
case tea.KeyMsg:
if msg.String() == "r" {
return m, m.fetchData
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
var cmd tea.Cmd
m.viewport, cmd = m.viewport.Update(msg)
return m, cmd
}
func (m *usageTabModel) SetSize(w, h int) {
m.width = w
m.height = h
if !m.ready {
m.viewport = viewport.New(w, h)
m.viewport.SetContent(m.renderContent())
m.ready = true
} else {
m.viewport.Width = w
m.viewport.Height = h
}
}
func (m usageTabModel) View() string {
if !m.ready {
return T("loading")
}
return m.viewport.View()
}
func (m usageTabModel) renderContent() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("usage_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("usage_help")))
sb.WriteString("\n\n")
if m.err != nil {
sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
sb.WriteString("\n")
return sb.String()
}
if m.usage == nil {
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
sb.WriteString("\n")
return sb.String()
}
usageMap, _ := m.usage["usage"].(map[string]any)
if usageMap == nil {
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
sb.WriteString("\n")
return sb.String()
}
totalReqs := int64(getFloat(usageMap, "total_requests"))
successCnt := int64(getFloat(usageMap, "success_count"))
failureCnt := int64(getFloat(usageMap, "failure_count"))
totalTokens := int64(getFloat(usageMap, "total_tokens"))
// ━━━ Overview Cards ━━━
cardWidth := 20
if m.width > 0 {
cardWidth = (m.width - 6) / 4
if cardWidth < 16 {
cardWidth = 16
}
}
cardStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("240")).
Padding(0, 1).
Width(cardWidth).
Height(3)
// Total Requests
card1 := cardStyle.Copy().BorderForeground(lipgloss.Color("111")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_reqs")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("%d", totalReqs)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("● %s: %d ● %s: %d", T("usage_success"), successCnt, T("usage_failure"), failureCnt)),
))
// Total Tokens
card2 := cardStyle.Copy().BorderForeground(lipgloss.Color("214")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_tokens")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(formatLargeNumber(totalTokens)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_token_l"), formatLargeNumber(totalTokens))),
))
// RPM
rpm := float64(0)
if totalReqs > 0 {
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
rpm = float64(totalReqs) / float64(len(rByH)) / 60.0
}
}
card3 := cardStyle.Copy().BorderForeground(lipgloss.Color("76")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_rpm")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("%.2f", rpm)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %d", T("usage_total_reqs"), totalReqs)),
))
// TPM
tpm := float64(0)
if totalTokens > 0 {
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
tpm = float64(totalTokens) / float64(len(tByH)) / 60.0
}
}
card4 := cardStyle.Copy().BorderForeground(lipgloss.Color("170")).Render(fmt.Sprintf(
"%s\n%s\n%s",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_tpm")),
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("%.2f", tpm)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_tokens"), formatLargeNumber(totalTokens))),
))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
sb.WriteString("\n\n")
// ━━━ Requests by Hour (ASCII bar chart) ━━━
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_hour")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(rByH, m.width-6, lipgloss.Color("111")))
sb.WriteString("\n")
}
// ━━━ Tokens by Hour ━━━
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_tok_by_hour")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(tByH, m.width-6, lipgloss.Color("214")))
sb.WriteString("\n")
}
// ━━━ Requests by Day ━━━
if rByD, ok := usageMap["requests_by_day"].(map[string]any); ok && len(rByD) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_day")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
sb.WriteString(renderBarChart(rByD, m.width-6, lipgloss.Color("76")))
sb.WriteString("\n")
}
// ━━━ API Detail Stats ━━━
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_api_detail")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 80)))
sb.WriteString("\n")
header := fmt.Sprintf(" %-30s %10s %12s", "API", T("requests"), T("tokens"))
sb.WriteString(tableHeaderStyle.Render(header))
sb.WriteString("\n")
for apiName, apiSnap := range apis {
if apiMap, ok := apiSnap.(map[string]any); ok {
apiReqs := int64(getFloat(apiMap, "total_requests"))
apiToks := int64(getFloat(apiMap, "total_tokens"))
row := fmt.Sprintf(" %-30s %10d %12s",
truncate(maskKey(apiName), 30), apiReqs, formatLargeNumber(apiToks))
sb.WriteString(lipgloss.NewStyle().Bold(true).Render(row))
sb.WriteString("\n")
// Per-model breakdown
if models, ok := apiMap["models"].(map[string]any); ok {
for model, v := range models {
if stats, ok := v.(map[string]any); ok {
mReqs := int64(getFloat(stats, "total_requests"))
mToks := int64(getFloat(stats, "total_tokens"))
mRow := fmt.Sprintf(" ├─ %-28s %10d %12s",
truncate(model, 28), mReqs, formatLargeNumber(mToks))
sb.WriteString(tableCellStyle.Render(mRow))
sb.WriteString("\n")
// Token type breakdown from details
sb.WriteString(m.renderTokenBreakdown(stats))
// Latency breakdown from details
sb.WriteString(m.renderLatencyBreakdown(stats))
}
}
}
}
}
}
sb.WriteString("\n")
return sb.String()
}
// renderTokenBreakdown aggregates input/output/cached/reasoning tokens from model details.
func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
details, ok := modelStats["details"]
if !ok {
return ""
}
detailList, ok := details.([]any)
if !ok || len(detailList) == 0 {
return ""
}
var inputTotal, outputTotal, cachedTotal, reasoningTotal int64
for _, d := range detailList {
dm, ok := d.(map[string]any)
if !ok {
continue
}
tokens, ok := dm["tokens"].(map[string]any)
if !ok {
continue
}
inputTotal += int64(getFloat(tokens, "input_tokens"))
outputTotal += int64(getFloat(tokens, "output_tokens"))
cachedTotal += int64(getFloat(tokens, "cached_tokens"))
reasoningTotal += int64(getFloat(tokens, "reasoning_tokens"))
}
if inputTotal == 0 && outputTotal == 0 && cachedTotal == 0 && reasoningTotal == 0 {
return ""
}
parts := []string{}
if inputTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_input"), formatLargeNumber(inputTotal)))
}
if outputTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_output"), formatLargeNumber(outputTotal)))
}
if cachedTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_cached"), formatLargeNumber(cachedTotal)))
}
if reasoningTotal > 0 {
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_reasoning"), formatLargeNumber(reasoningTotal)))
}
return fmt.Sprintf(" │ %s\n",
lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
}
// renderLatencyBreakdown aggregates latency_ms from model details and displays avg/min/max.
func (m usageTabModel) renderLatencyBreakdown(modelStats map[string]any) string {
details, ok := modelStats["details"]
if !ok {
return ""
}
detailList, ok := details.([]any)
if !ok || len(detailList) == 0 {
return ""
}
var totalLatency int64
var count int
var minLatency, maxLatency int64
first := true
for _, d := range detailList {
dm, ok := d.(map[string]any)
if !ok {
continue
}
latencyMs := int64(getFloat(dm, "latency_ms"))
if latencyMs <= 0 {
continue
}
totalLatency += latencyMs
count++
if first {
minLatency = latencyMs
maxLatency = latencyMs
first = false
} else {
if latencyMs < minLatency {
minLatency = latencyMs
}
if latencyMs > maxLatency {
maxLatency = latencyMs
}
}
}
if count == 0 {
return ""
}
avgLatency := totalLatency / int64(count)
return fmt.Sprintf(" │ %s: avg %dms min %dms max %dms\n",
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_time")),
avgLatency, minLatency, maxLatency)
}
// renderBarChart renders a simple ASCII horizontal bar chart.
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
if maxBarWidth < 10 {
maxBarWidth = 10
}
// Sort keys
keys := make([]string, 0, len(data))
for k := range data {
keys = append(keys, k)
}
sort.Strings(keys)
// Find max value
maxVal := float64(0)
for _, k := range keys {
v := getFloat(data, k)
if v > maxVal {
maxVal = v
}
}
if maxVal == 0 {
return ""
}
barStyle := lipgloss.NewStyle().Foreground(barColor)
var sb strings.Builder
labelWidth := 12
barAvail := maxBarWidth - labelWidth - 12
if barAvail < 5 {
barAvail = 5
}
for _, k := range keys {
v := getFloat(data, k)
barLen := int(v / maxVal * float64(barAvail))
if barLen < 1 && v > 0 {
barLen = 1
}
bar := strings.Repeat("█", barLen)
label := k
if len(label) > labelWidth {
label = label[:labelWidth]
}
sb.WriteString(fmt.Sprintf(" %-*s %s %s\n",
labelWidth, label,
barStyle.Render(bar),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%.0f", v)),
))
}
return sb.String()
}
-134
View File
@@ -1,134 +0,0 @@
package tui
import (
"strings"
"testing"
)
func TestRenderLatencyBreakdown(t *testing.T) {
tests := []struct {
name string
modelStats map[string]any
wantEmpty bool
wantContains string
}{
{
name: "no details",
modelStats: map[string]any{},
wantEmpty: true,
},
{
name: "empty details",
modelStats: map[string]any{
"details": []any{},
},
wantEmpty: true,
},
{
name: "details with zero latency",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(0),
},
},
},
wantEmpty: true,
},
{
name: "single request with latency",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(1500),
},
},
},
wantEmpty: false,
wantContains: "avg 1500ms min 1500ms max 1500ms",
},
{
name: "multiple requests with varying latency",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(100),
},
map[string]any{
"latency_ms": float64(200),
},
map[string]any{
"latency_ms": float64(300),
},
},
},
wantEmpty: false,
wantContains: "avg 200ms min 100ms max 300ms",
},
{
name: "mixed valid and invalid latency values",
modelStats: map[string]any{
"details": []any{
map[string]any{
"latency_ms": float64(500),
},
map[string]any{
"latency_ms": float64(0),
},
map[string]any{
"latency_ms": float64(1500),
},
},
},
wantEmpty: false,
wantContains: "avg 1000ms min 500ms max 1500ms",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := usageTabModel{}
result := m.renderLatencyBreakdown(tt.modelStats)
if tt.wantEmpty {
if result != "" {
t.Errorf("renderLatencyBreakdown() = %q, want empty string", result)
}
return
}
if result == "" {
t.Errorf("renderLatencyBreakdown() = empty, want non-empty string")
return
}
if tt.wantContains != "" && !strings.Contains(result, tt.wantContains) {
t.Errorf("renderLatencyBreakdown() = %q, want to contain %q", result, tt.wantContains)
}
})
}
}
func TestUsageTimeTranslations(t *testing.T) {
prevLocale := CurrentLocale()
t.Cleanup(func() {
SetLocale(prevLocale)
})
tests := []struct {
locale string
want string
}{
{locale: "en", want: "Time"},
{locale: "zh", want: "时间"},
}
for _, tt := range tests {
t.Run(tt.locale, func(t *testing.T) {
SetLocale(tt.locale)
if got := T("usage_time"); got != tt.want {
t.Fatalf("T(usage_time) = %q, want %q", got, tt.want)
}
})
}
}
-484
View File
@@ -1,484 +0,0 @@
// Package usage provides usage tracking and logging functionality for the CLI Proxy API server.
// It includes plugins for monitoring API usage, token consumption, and other metrics
// to help with observability and billing purposes.
package usage
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
var statisticsEnabled atomic.Bool
func init() {
statisticsEnabled.Store(true)
coreusage.RegisterPlugin(NewLoggerPlugin())
}
// LoggerPlugin collects in-memory request statistics for usage analysis.
// It implements coreusage.Plugin to receive usage records emitted by the runtime.
type LoggerPlugin struct {
stats *RequestStatistics
}
// NewLoggerPlugin constructs a new logger plugin instance.
//
// Returns:
// - *LoggerPlugin: A new logger plugin instance wired to the shared statistics store.
func NewLoggerPlugin() *LoggerPlugin { return &LoggerPlugin{stats: defaultRequestStatistics} }
// HandleUsage implements coreusage.Plugin.
// It updates the in-memory statistics store whenever a usage record is received.
//
// Parameters:
// - ctx: The context for the usage record
// - record: The usage record to aggregate
func (p *LoggerPlugin) HandleUsage(ctx context.Context, record coreusage.Record) {
if !statisticsEnabled.Load() {
return
}
if p == nil || p.stats == nil {
return
}
p.stats.Record(ctx, record)
}
// SetStatisticsEnabled toggles whether in-memory statistics are recorded.
func SetStatisticsEnabled(enabled bool) { statisticsEnabled.Store(enabled) }
// StatisticsEnabled reports the current recording state.
func StatisticsEnabled() bool { return statisticsEnabled.Load() }
// RequestStatistics maintains aggregated request metrics in memory.
type RequestStatistics struct {
mu sync.RWMutex
totalRequests int64
successCount int64
failureCount int64
totalTokens int64
apis map[string]*apiStats
requestsByDay map[string]int64
requestsByHour map[int]int64
tokensByDay map[string]int64
tokensByHour map[int]int64
}
// apiStats holds aggregated metrics for a single API key.
type apiStats struct {
TotalRequests int64
TotalTokens int64
Models map[string]*modelStats
}
// modelStats holds aggregated metrics for a specific model within an API.
type modelStats struct {
TotalRequests int64
TotalTokens int64
Details []RequestDetail
}
// RequestDetail stores the timestamp, latency, and token usage for a single request.
type RequestDetail struct {
Timestamp time.Time `json:"timestamp"`
LatencyMs int64 `json:"latency_ms"`
Source string `json:"source"`
AuthIndex string `json:"auth_index"`
Tokens TokenStats `json:"tokens"`
Failed bool `json:"failed"`
}
// TokenStats captures the token usage breakdown for a request.
type TokenStats struct {
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
ReasoningTokens int64 `json:"reasoning_tokens"`
CachedTokens int64 `json:"cached_tokens"`
TotalTokens int64 `json:"total_tokens"`
}
// StatisticsSnapshot represents an immutable view of the aggregated metrics.
type StatisticsSnapshot struct {
TotalRequests int64 `json:"total_requests"`
SuccessCount int64 `json:"success_count"`
FailureCount int64 `json:"failure_count"`
TotalTokens int64 `json:"total_tokens"`
APIs map[string]APISnapshot `json:"apis"`
RequestsByDay map[string]int64 `json:"requests_by_day"`
RequestsByHour map[string]int64 `json:"requests_by_hour"`
TokensByDay map[string]int64 `json:"tokens_by_day"`
TokensByHour map[string]int64 `json:"tokens_by_hour"`
}
// APISnapshot summarises metrics for a single API key.
type APISnapshot struct {
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
Models map[string]ModelSnapshot `json:"models"`
}
// ModelSnapshot summarises metrics for a specific model.
type ModelSnapshot struct {
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
Details []RequestDetail `json:"details"`
}
var defaultRequestStatistics = NewRequestStatistics()
// GetRequestStatistics returns the shared statistics store.
func GetRequestStatistics() *RequestStatistics { return defaultRequestStatistics }
// NewRequestStatistics constructs an empty statistics store.
func NewRequestStatistics() *RequestStatistics {
return &RequestStatistics{
apis: make(map[string]*apiStats),
requestsByDay: make(map[string]int64),
requestsByHour: make(map[int]int64),
tokensByDay: make(map[string]int64),
tokensByHour: make(map[int]int64),
}
}
// Record ingests a new usage record and updates the aggregates.
func (s *RequestStatistics) Record(ctx context.Context, record coreusage.Record) {
if s == nil {
return
}
if !statisticsEnabled.Load() {
return
}
timestamp := record.RequestedAt
if timestamp.IsZero() {
timestamp = time.Now()
}
detail := normaliseDetail(record.Detail)
totalTokens := detail.TotalTokens
statsKey := record.APIKey
if statsKey == "" {
statsKey = resolveAPIIdentifier(ctx, record)
}
failed := record.Failed
if !failed {
failed = !resolveSuccess(ctx)
}
success := !failed
modelName := record.Model
if modelName == "" {
modelName = "unknown"
}
dayKey := timestamp.Format("2006-01-02")
hourKey := timestamp.Hour()
s.mu.Lock()
defer s.mu.Unlock()
s.totalRequests++
if success {
s.successCount++
} else {
s.failureCount++
}
s.totalTokens += totalTokens
stats, ok := s.apis[statsKey]
if !ok {
stats = &apiStats{Models: make(map[string]*modelStats)}
s.apis[statsKey] = stats
}
s.updateAPIStats(stats, modelName, RequestDetail{
Timestamp: timestamp,
LatencyMs: normaliseLatency(record.Latency),
Source: record.Source,
AuthIndex: record.AuthIndex,
Tokens: detail,
Failed: failed,
})
s.requestsByDay[dayKey]++
s.requestsByHour[hourKey]++
s.tokensByDay[dayKey] += totalTokens
s.tokensByHour[hourKey] += totalTokens
}
func (s *RequestStatistics) updateAPIStats(stats *apiStats, model string, detail RequestDetail) {
stats.TotalRequests++
stats.TotalTokens += detail.Tokens.TotalTokens
modelStatsValue, ok := stats.Models[model]
if !ok {
modelStatsValue = &modelStats{}
stats.Models[model] = modelStatsValue
}
modelStatsValue.TotalRequests++
modelStatsValue.TotalTokens += detail.Tokens.TotalTokens
modelStatsValue.Details = append(modelStatsValue.Details, detail)
}
// Snapshot returns a copy of the aggregated metrics for external consumption.
func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
result := StatisticsSnapshot{}
if s == nil {
return result
}
s.mu.RLock()
defer s.mu.RUnlock()
result.TotalRequests = s.totalRequests
result.SuccessCount = s.successCount
result.FailureCount = s.failureCount
result.TotalTokens = s.totalTokens
result.APIs = make(map[string]APISnapshot, len(s.apis))
for apiName, stats := range s.apis {
apiSnapshot := APISnapshot{
TotalRequests: stats.TotalRequests,
TotalTokens: stats.TotalTokens,
Models: make(map[string]ModelSnapshot, len(stats.Models)),
}
for modelName, modelStatsValue := range stats.Models {
requestDetails := make([]RequestDetail, len(modelStatsValue.Details))
copy(requestDetails, modelStatsValue.Details)
apiSnapshot.Models[modelName] = ModelSnapshot{
TotalRequests: modelStatsValue.TotalRequests,
TotalTokens: modelStatsValue.TotalTokens,
Details: requestDetails,
}
}
result.APIs[apiName] = apiSnapshot
}
result.RequestsByDay = make(map[string]int64, len(s.requestsByDay))
for k, v := range s.requestsByDay {
result.RequestsByDay[k] = v
}
result.RequestsByHour = make(map[string]int64, len(s.requestsByHour))
for hour, v := range s.requestsByHour {
key := formatHour(hour)
result.RequestsByHour[key] = v
}
result.TokensByDay = make(map[string]int64, len(s.tokensByDay))
for k, v := range s.tokensByDay {
result.TokensByDay[k] = v
}
result.TokensByHour = make(map[string]int64, len(s.tokensByHour))
for hour, v := range s.tokensByHour {
key := formatHour(hour)
result.TokensByHour[key] = v
}
return result
}
type MergeResult struct {
Added int64 `json:"added"`
Skipped int64 `json:"skipped"`
}
// MergeSnapshot merges an exported statistics snapshot into the current store.
// Existing data is preserved and duplicate request details are skipped.
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
result := MergeResult{}
if s == nil {
return result
}
s.mu.Lock()
defer s.mu.Unlock()
seen := make(map[string]struct{})
for apiName, stats := range s.apis {
if stats == nil {
continue
}
for modelName, modelStatsValue := range stats.Models {
if modelStatsValue == nil {
continue
}
for _, detail := range modelStatsValue.Details {
seen[dedupKey(apiName, modelName, detail)] = struct{}{}
}
}
}
for apiName, apiSnapshot := range snapshot.APIs {
apiName = strings.TrimSpace(apiName)
if apiName == "" {
continue
}
stats, ok := s.apis[apiName]
if !ok || stats == nil {
stats = &apiStats{Models: make(map[string]*modelStats)}
s.apis[apiName] = stats
} else if stats.Models == nil {
stats.Models = make(map[string]*modelStats)
}
for modelName, modelSnapshot := range apiSnapshot.Models {
modelName = strings.TrimSpace(modelName)
if modelName == "" {
modelName = "unknown"
}
for _, detail := range modelSnapshot.Details {
detail.Tokens = normaliseTokenStats(detail.Tokens)
if detail.LatencyMs < 0 {
detail.LatencyMs = 0
}
if detail.Timestamp.IsZero() {
detail.Timestamp = time.Now()
}
key := dedupKey(apiName, modelName, detail)
if _, exists := seen[key]; exists {
result.Skipped++
continue
}
seen[key] = struct{}{}
s.recordImported(apiName, modelName, stats, detail)
result.Added++
}
}
}
return result
}
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
totalTokens := detail.Tokens.TotalTokens
if totalTokens < 0 {
totalTokens = 0
}
s.totalRequests++
if detail.Failed {
s.failureCount++
} else {
s.successCount++
}
s.totalTokens += totalTokens
s.updateAPIStats(stats, modelName, detail)
dayKey := detail.Timestamp.Format("2006-01-02")
hourKey := detail.Timestamp.Hour()
s.requestsByDay[dayKey]++
s.requestsByHour[hourKey]++
s.tokensByDay[dayKey] += totalTokens
s.tokensByHour[hourKey] += totalTokens
}
func dedupKey(apiName, modelName string, detail RequestDetail) string {
timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
tokens := normaliseTokenStats(detail.Tokens)
return fmt.Sprintf(
"%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
apiName,
modelName,
timestamp,
detail.Source,
detail.AuthIndex,
detail.Failed,
tokens.InputTokens,
tokens.OutputTokens,
tokens.ReasoningTokens,
tokens.CachedTokens,
tokens.TotalTokens,
)
}
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
path := ginCtx.FullPath()
if path == "" && ginCtx.Request != nil {
path = ginCtx.Request.URL.Path
}
method := ""
if ginCtx.Request != nil {
method = ginCtx.Request.Method
}
if path != "" {
if method != "" {
return method + " " + path
}
return path
}
}
}
if record.Provider != "" {
return record.Provider
}
return "unknown"
}
func resolveSuccess(ctx context.Context) bool {
if ctx == nil {
return true
}
ginCtx, ok := ctx.Value("gin").(*gin.Context)
if !ok || ginCtx == nil {
return true
}
status := ginCtx.Writer.Status()
if status == 0 {
return true
}
return status < httpStatusBadRequest
}
const httpStatusBadRequest = 400
func normaliseDetail(detail coreusage.Detail) TokenStats {
tokens := TokenStats{
InputTokens: detail.InputTokens,
OutputTokens: detail.OutputTokens,
ReasoningTokens: detail.ReasoningTokens,
CachedTokens: detail.CachedTokens,
TotalTokens: detail.TotalTokens,
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens + detail.CachedTokens
}
return tokens
}
func normaliseTokenStats(tokens TokenStats) TokenStats {
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
}
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
}
return tokens
}
func normaliseLatency(latency time.Duration) int64 {
if latency <= 0 {
return 0
}
return latency.Milliseconds()
}
func formatHour(hour int) string {
if hour < 0 {
hour = 0
}
hour = hour % 24
return fmt.Sprintf("%02d", hour)
}
-96
View File
@@ -1,96 +0,0 @@
package usage
import (
"context"
"testing"
"time"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
func TestRequestStatisticsRecordIncludesLatency(t *testing.T) {
stats := NewRequestStatistics()
stats.Record(context.Background(), coreusage.Record{
APIKey: "test-key",
Model: "gpt-5.4",
RequestedAt: time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC),
Latency: 1500 * time.Millisecond,
Detail: coreusage.Detail{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
})
snapshot := stats.Snapshot()
details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
if len(details) != 1 {
t.Fatalf("details len = %d, want 1", len(details))
}
if details[0].LatencyMs != 1500 {
t.Fatalf("latency_ms = %d, want 1500", details[0].LatencyMs)
}
}
func TestRequestStatisticsMergeSnapshotDedupIgnoresLatency(t *testing.T) {
stats := NewRequestStatistics()
timestamp := time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC)
first := StatisticsSnapshot{
APIs: map[string]APISnapshot{
"test-key": {
Models: map[string]ModelSnapshot{
"gpt-5.4": {
Details: []RequestDetail{{
Timestamp: timestamp,
LatencyMs: 0,
Source: "user@example.com",
AuthIndex: "0",
Tokens: TokenStats{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
}},
},
},
},
},
}
second := StatisticsSnapshot{
APIs: map[string]APISnapshot{
"test-key": {
Models: map[string]ModelSnapshot{
"gpt-5.4": {
Details: []RequestDetail{{
Timestamp: timestamp,
LatencyMs: 2500,
Source: "user@example.com",
AuthIndex: "0",
Tokens: TokenStats{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
},
}},
},
},
},
},
}
result := stats.MergeSnapshot(first)
if result.Added != 1 || result.Skipped != 0 {
t.Fatalf("first merge = %+v, want added=1 skipped=0", result)
}
result = stats.MergeSnapshot(second)
if result.Added != 0 || result.Skipped != 1 {
t.Fatalf("second merge = %+v, want added=0 skipped=1", result)
}
snapshot := stats.Snapshot()
details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
if len(details) != 1 {
t.Fatalf("details len = %d, want 1", len(details))
}
}
+3
View File
@@ -39,6 +39,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled {
changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled))
}
if oldCfg.RedisUsageQueueRetentionSeconds != newCfg.RedisUsageQueueRetentionSeconds {
changes = append(changes, fmt.Sprintf("redis-usage-queue-retention-seconds: %d -> %d", oldCfg.RedisUsageQueueRetentionSeconds, newCfg.RedisUsageQueueRetentionSeconds))
}
if oldCfg.DisableCooling != newCfg.DisableCooling {
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
}