Merge branch 'router-for-me:main' into my-fix
This commit is contained in:
@@ -0,0 +1,107 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
type apiKeyUsageEntry struct {
|
||||
Success int64 `json:"success"`
|
||||
Failed int64 `json:"failed"`
|
||||
RecentRequests []coreauth.RecentRequestBucket `json:"recent_requests"`
|
||||
}
|
||||
|
||||
func mergeRecentRequestBuckets(dst, src []coreauth.RecentRequestBucket) []coreauth.RecentRequestBucket {
|
||||
if len(dst) == 0 {
|
||||
return src
|
||||
}
|
||||
if len(src) == 0 {
|
||||
return dst
|
||||
}
|
||||
if len(dst) != len(src) {
|
||||
n := len(dst)
|
||||
if len(src) < n {
|
||||
n = len(src)
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
dst[i].Success += src[i].Success
|
||||
dst[i].Failed += src[i].Failed
|
||||
}
|
||||
return dst
|
||||
}
|
||||
for i := range dst {
|
||||
dst[i].Success += src[i].Success
|
||||
dst[i].Failed += src[i].Failed
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// GetAPIKeyUsage returns recent request buckets for all in-memory api_key auths,
|
||||
// grouped by provider and keyed by "base_url|api_key".
|
||||
func (h *Handler) GetAPIKeyUsage(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler not initialized"})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
manager := h.authManager
|
||||
h.mu.Unlock()
|
||||
if manager == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
out := make(map[string]map[string]apiKeyUsageEntry)
|
||||
for _, auth := range manager.List() {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
kind, apiKey := auth.AccountInfo()
|
||||
if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
|
||||
continue
|
||||
}
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
if apiKey == "" {
|
||||
continue
|
||||
}
|
||||
baseURL := ""
|
||||
if auth.Attributes != nil {
|
||||
baseURL = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
if baseURL == "" {
|
||||
baseURL = strings.TrimSpace(auth.Attributes["base-url"])
|
||||
}
|
||||
}
|
||||
compositeKey := baseURL + "|" + apiKey
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if provider == "" {
|
||||
provider = "unknown"
|
||||
}
|
||||
|
||||
recent := auth.RecentRequestsSnapshot(now)
|
||||
providerBucket, ok := out[provider]
|
||||
if !ok {
|
||||
providerBucket = make(map[string]apiKeyUsageEntry)
|
||||
out[provider] = providerBucket
|
||||
}
|
||||
if existing, exists := providerBucket[compositeKey]; exists {
|
||||
existing.Success += auth.Success
|
||||
existing.Failed += auth.Failed
|
||||
existing.RecentRequests = mergeRecentRequestBuckets(existing.RecentRequests, recent)
|
||||
providerBucket[compositeKey] = existing
|
||||
continue
|
||||
}
|
||||
providerBucket[compositeKey] = apiKeyUsageEntry{
|
||||
Success: auth.Success,
|
||||
Failed: auth.Failed,
|
||||
RecentRequests: recent,
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, out)
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func sumRecentRequestBuckets(buckets []coreauth.RecentRequestBucket) (int64, int64) {
|
||||
var success int64
|
||||
var failed int64
|
||||
for _, bucket := range buckets {
|
||||
success += bucket.Success
|
||||
failed += bucket.Failed
|
||||
}
|
||||
return success, failed
|
||||
}
|
||||
|
||||
func TestGetAPIKeyUsage_GroupsByProviderAndAPIKey(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
if _, err := manager.Register(context.Background(), &coreauth.Auth{
|
||||
ID: "codex-auth",
|
||||
Provider: "codex",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "codex-key",
|
||||
"base_url": "https://codex.example.com",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("register codex auth: %v", err)
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), &coreauth.Auth{
|
||||
ID: "claude-auth",
|
||||
Provider: "claude",
|
||||
Attributes: map[string]string{
|
||||
"api_key": "claude-key",
|
||||
"base_url": "https://claude.example.com",
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatalf("register claude auth: %v", err)
|
||||
}
|
||||
|
||||
manager.MarkResult(context.Background(), coreauth.Result{AuthID: "codex-auth", Provider: "codex", Model: "gpt-5", Success: true})
|
||||
manager.MarkResult(context.Background(), coreauth.Result{AuthID: "codex-auth", Provider: "codex", Model: "gpt-5", Success: false})
|
||||
manager.MarkResult(context.Background(), coreauth.Result{AuthID: "claude-auth", Provider: "claude", Model: "claude-4", Success: true})
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/api-key-usage", nil)
|
||||
ginCtx.Request = req
|
||||
h.GetAPIKeyUsage(ginCtx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var payload map[string]map[string]apiKeyUsageEntry
|
||||
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
|
||||
t.Fatalf("decode payload: %v", err)
|
||||
}
|
||||
|
||||
codexEntry := payload["codex"]["https://codex.example.com|codex-key"]
|
||||
if codexEntry.Success != 1 || codexEntry.Failed != 1 {
|
||||
t.Fatalf("codex totals = %d/%d, want 1/1", codexEntry.Success, codexEntry.Failed)
|
||||
}
|
||||
if len(codexEntry.RecentRequests) != 20 {
|
||||
t.Fatalf("codex buckets len = %d, want 20", len(codexEntry.RecentRequests))
|
||||
}
|
||||
codexSuccess, codexFailed := sumRecentRequestBuckets(codexEntry.RecentRequests)
|
||||
if codexSuccess != 1 || codexFailed != 1 {
|
||||
t.Fatalf("codex totals = %d/%d, want 1/1", codexSuccess, codexFailed)
|
||||
}
|
||||
|
||||
claudeEntry := payload["claude"]["https://claude.example.com|claude-key"]
|
||||
if claudeEntry.Success != 1 || claudeEntry.Failed != 0 {
|
||||
t.Fatalf("claude totals = %d/%d, want 1/0", claudeEntry.Success, claudeEntry.Failed)
|
||||
}
|
||||
if len(claudeEntry.RecentRequests) != 20 {
|
||||
t.Fatalf("claude buckets len = %d, want 20", len(claudeEntry.RecentRequests))
|
||||
}
|
||||
claudeSuccess, claudeFailed := sumRecentRequestBuckets(claudeEntry.RecentRequests)
|
||||
if claudeSuccess != 1 || claudeFailed != 0 {
|
||||
t.Fatalf("claude totals = %d/%d, want 1/0", claudeSuccess, claudeFailed)
|
||||
}
|
||||
}
|
||||
@@ -388,6 +388,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
"source": "memory",
|
||||
"size": int64(0),
|
||||
}
|
||||
entry["success"] = auth.Success
|
||||
entry["failed"] = auth.Failed
|
||||
entry["recent_requests"] = auth.RecentRequestsSnapshot(time.Now())
|
||||
if email := authEmail(auth); email != "" {
|
||||
entry["email"] = email
|
||||
}
|
||||
@@ -2395,23 +2398,10 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
finalProjectID := projectID
|
||||
if responseProjectID != "" {
|
||||
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
||||
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
|
||||
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
|
||||
strings.EqualFold(tierID, "FREE") ||
|
||||
strings.EqualFold(tierID, "LEGACY")
|
||||
|
||||
if isFreeUser {
|
||||
// For free users, use backend project ID for preview model access
|
||||
log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID)
|
||||
log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID)
|
||||
finalProjectID = responseProjectID
|
||||
} else {
|
||||
// Pro users: keep requested project ID (original behavior)
|
||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||
}
|
||||
} else {
|
||||
finalProjectID = responseProjectID
|
||||
log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID)
|
||||
log.Infof("Using backend project ID: %s", responseProjectID)
|
||||
}
|
||||
finalProjectID = responseProjectID
|
||||
}
|
||||
|
||||
storage.ProjectID = strings.TrimSpace(finalProjectID)
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestListAuthFiles_IncludesRecentRequestsBuckets(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
record := &coreauth.Auth{
|
||||
ID: "runtime-only-auth-1",
|
||||
Provider: "codex",
|
||||
Attributes: map[string]string{
|
||||
"runtime_only": "true",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"type": "codex",
|
||||
},
|
||||
}
|
||||
if _, errRegister := manager.Register(context.Background(), record); errRegister != nil {
|
||||
t.Fatalf("failed to register auth record: %v", errRegister)
|
||||
}
|
||||
|
||||
h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: t.TempDir()}, manager)
|
||||
h.tokenStore = &memoryAuthStore{}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
req := httptest.NewRequest(http.MethodGet, "/v0/management/auth-files", nil)
|
||||
ginCtx.Request = req
|
||||
|
||||
h.ListAuthFiles(ginCtx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
var payload map[string]any
|
||||
if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil {
|
||||
t.Fatalf("failed to decode list payload: %v", errUnmarshal)
|
||||
}
|
||||
filesRaw, ok := payload["files"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected files array, payload: %#v", payload)
|
||||
}
|
||||
if len(filesRaw) != 1 {
|
||||
t.Fatalf("expected 1 auth entry, got %d", len(filesRaw))
|
||||
}
|
||||
|
||||
fileEntry, ok := filesRaw[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected file entry object, got %#v", filesRaw[0])
|
||||
}
|
||||
|
||||
if _, ok := fileEntry["success"].(float64); !ok {
|
||||
t.Fatalf("expected success number, got %#v", fileEntry["success"])
|
||||
}
|
||||
if _, ok := fileEntry["failed"].(float64); !ok {
|
||||
t.Fatalf("expected failed number, got %#v", fileEntry["failed"])
|
||||
}
|
||||
|
||||
recentRaw, ok := fileEntry["recent_requests"].([]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected recent_requests array, got %#v", fileEntry["recent_requests"])
|
||||
}
|
||||
if len(recentRaw) != 20 {
|
||||
t.Fatalf("expected 20 recent_requests buckets, got %d", len(recentRaw))
|
||||
}
|
||||
for idx, item := range recentRaw {
|
||||
bucket, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected bucket object at %d, got %#v", idx, item)
|
||||
}
|
||||
if _, ok := bucket["time"].(string); !ok {
|
||||
t.Fatalf("expected bucket time string at %d, got %#v", idx, bucket["time"])
|
||||
}
|
||||
if _, ok := bucket["success"].(float64); !ok {
|
||||
t.Fatalf("expected bucket success number at %d, got %#v", idx, bucket["success"])
|
||||
}
|
||||
if _, ok := bucket["failed"].(float64); !ok {
|
||||
t.Fatalf("expected bucket failed number at %d, got %#v", idx, bucket["failed"])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@@ -41,7 +40,6 @@ type Handler struct {
|
||||
attemptsMu sync.Mutex
|
||||
failedAttempts map[string]*attemptInfo // keyed by client IP
|
||||
authManager *coreauth.Manager
|
||||
usageStats *usage.RequestStatistics
|
||||
tokenStore coreauth.Store
|
||||
localPassword string
|
||||
allowRemoteOverride bool
|
||||
@@ -60,7 +58,6 @@ func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Man
|
||||
configFilePath: configFilePath,
|
||||
failedAttempts: make(map[string]*attemptInfo),
|
||||
authManager: manager,
|
||||
usageStats: usage.GetRequestStatistics(),
|
||||
tokenStore: sdkAuth.GetTokenStore(),
|
||||
allowRemoteOverride: envSecret != "",
|
||||
envSecret: envSecret,
|
||||
@@ -124,9 +121,6 @@ func (h *Handler) SetAuthManager(manager *coreauth.Manager) {
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetUsageStatistics allows replacing the usage statistics reference.
|
||||
func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats }
|
||||
|
||||
// SetLocalPassword configures the runtime-local password accepted for localhost requests.
|
||||
func (h *Handler) SetLocalPassword(password string) { h.localPassword = password }
|
||||
|
||||
|
||||
@@ -2,78 +2,54 @@ package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
)
|
||||
|
||||
type usageExportPayload struct {
|
||||
Version int `json:"version"`
|
||||
ExportedAt time.Time `json:"exported_at"`
|
||||
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||
}
|
||||
type usageQueueRecord []byte
|
||||
|
||||
type usageImportPayload struct {
|
||||
Version int `json:"version"`
|
||||
Usage usage.StatisticsSnapshot `json:"usage"`
|
||||
}
|
||||
|
||||
// GetUsageStatistics returns the in-memory request statistics snapshot.
|
||||
func (h *Handler) GetUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
if h != nil && h.usageStats != nil {
|
||||
snapshot = h.usageStats.Snapshot()
|
||||
func (r usageQueueRecord) MarshalJSON() ([]byte, error) {
|
||||
if json.Valid(r) {
|
||||
return append([]byte(nil), r...), nil
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"usage": snapshot,
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
return json.Marshal(string(r))
|
||||
}
|
||||
|
||||
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
|
||||
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
|
||||
var snapshot usage.StatisticsSnapshot
|
||||
if h != nil && h.usageStats != nil {
|
||||
snapshot = h.usageStats.Snapshot()
|
||||
}
|
||||
c.JSON(http.StatusOK, usageExportPayload{
|
||||
Version: 1,
|
||||
ExportedAt: time.Now().UTC(),
|
||||
Usage: snapshot,
|
||||
})
|
||||
}
|
||||
|
||||
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
|
||||
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
|
||||
if h == nil || h.usageStats == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
|
||||
// GetUsageQueue pops queued usage records from the usage queue.
|
||||
func (h *Handler) GetUsageQueue(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
||||
count, errCount := parseUsageQueueCount(c.Query("count"))
|
||||
if errCount != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errCount.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var payload usageImportPayload
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
|
||||
return
|
||||
}
|
||||
if payload.Version != 0 && payload.Version != 1 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
|
||||
return
|
||||
items := redisqueue.PopOldest(count)
|
||||
records := make([]usageQueueRecord, 0, len(items))
|
||||
for _, item := range items {
|
||||
records = append(records, usageQueueRecord(append([]byte(nil), item...)))
|
||||
}
|
||||
|
||||
result := h.usageStats.MergeSnapshot(payload.Usage)
|
||||
snapshot := h.usageStats.Snapshot()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"added": result.Added,
|
||||
"skipped": result.Skipped,
|
||||
"total_requests": snapshot.TotalRequests,
|
||||
"failed_requests": snapshot.FailureCount,
|
||||
})
|
||||
c.JSON(http.StatusOK, records)
|
||||
}
|
||||
|
||||
func parseUsageQueueCount(value string) (int, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return 1, nil
|
||||
}
|
||||
count, errCount := strconv.Atoi(value)
|
||||
if errCount != nil || count <= 0 {
|
||||
return 0, errors.New("count must be a positive integer")
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
)
|
||||
|
||||
func TestGetUsageQueuePopsRequestedRecords(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
withManagementUsageQueue(t, func() {
|
||||
redisqueue.Enqueue([]byte(`{"id":1}`))
|
||||
redisqueue.Enqueue([]byte(`{"id":2}`))
|
||||
redisqueue.Enqueue([]byte(`{"id":3}`))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
|
||||
|
||||
h := &Handler{}
|
||||
h.GetUsageQueue(ginCtx)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusOK, rec.Body.String())
|
||||
}
|
||||
|
||||
var payload []json.RawMessage
|
||||
if errUnmarshal := json.Unmarshal(rec.Body.Bytes(), &payload); errUnmarshal != nil {
|
||||
t.Fatalf("unmarshal response: %v", errUnmarshal)
|
||||
}
|
||||
if len(payload) != 2 {
|
||||
t.Fatalf("response records = %d, want 2", len(payload))
|
||||
}
|
||||
requireRecordID(t, payload[0], 1)
|
||||
requireRecordID(t, payload[1], 2)
|
||||
|
||||
remaining := redisqueue.PopOldest(10)
|
||||
if len(remaining) != 1 || string(remaining[0]) != `{"id":3}` {
|
||||
t.Fatalf("remaining queue = %q, want third item only", remaining)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetUsageQueueInvalidCountDoesNotPop(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
withManagementUsageQueue(t, func() {
|
||||
redisqueue.Enqueue([]byte(`{"id":1}`))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
ginCtx, _ := gin.CreateTestContext(rec)
|
||||
ginCtx.Request = httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=0", nil)
|
||||
|
||||
h := &Handler{}
|
||||
h.GetUsageQueue(ginCtx)
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d body=%s", rec.Code, http.StatusBadRequest, rec.Body.String())
|
||||
}
|
||||
|
||||
remaining := redisqueue.PopOldest(10)
|
||||
if len(remaining) != 1 || string(remaining[0]) != `{"id":1}` {
|
||||
t.Fatalf("remaining queue = %q, want original item", remaining)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func withManagementUsageQueue(t *testing.T, fn func()) {
|
||||
t.Helper()
|
||||
|
||||
prevQueueEnabled := redisqueue.Enabled()
|
||||
redisqueue.SetEnabled(false)
|
||||
redisqueue.SetEnabled(true)
|
||||
|
||||
defer func() {
|
||||
redisqueue.SetEnabled(false)
|
||||
redisqueue.SetEnabled(prevQueueEnabled)
|
||||
}()
|
||||
|
||||
fn()
|
||||
}
|
||||
|
||||
func requireRecordID(t *testing.T, raw json.RawMessage, want int) {
|
||||
t.Helper()
|
||||
|
||||
var payload struct {
|
||||
ID int `json:"id"`
|
||||
}
|
||||
if errUnmarshal := json.Unmarshal(raw, &payload); errUnmarshal != nil {
|
||||
t.Fatalf("unmarshal record: %v", errUnmarshal)
|
||||
}
|
||||
if payload.ID != want {
|
||||
t.Fatalf("record id = %d, want %d", payload.ID, want)
|
||||
}
|
||||
}
|
||||
@@ -123,6 +123,52 @@ func (rw *ResponseRewriter) Flush() {
|
||||
|
||||
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
|
||||
|
||||
// ampCanonicalToolNames maps tool names to the exact casing expected by the
|
||||
// Amp mode tool whitelist (case-sensitive match).
|
||||
var ampCanonicalToolNames = map[string]string{
|
||||
"bash": "Bash",
|
||||
"read": "Read",
|
||||
"grep": "Grep",
|
||||
"glob": "glob",
|
||||
"task": "Task",
|
||||
"check": "Check",
|
||||
}
|
||||
|
||||
// normalizeAmpToolNames fixes tool_use block names to match Amp's canonical casing.
|
||||
// Some upstream models return lowercase tool names (e.g. "bash" instead of "Bash")
|
||||
// which causes Amp's case-sensitive mode whitelist to reject them.
|
||||
func normalizeAmpToolNames(data []byte) []byte {
|
||||
// Non-streaming: content[].name in tool_use blocks
|
||||
for index, block := range gjson.GetBytes(data, "content").Array() {
|
||||
if block.Get("type").String() != "tool_use" {
|
||||
continue
|
||||
}
|
||||
name := block.Get("name").String()
|
||||
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
|
||||
path := fmt.Sprintf("content.%d.name", index)
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, path, canonical)
|
||||
if err != nil {
|
||||
log.Warnf("Amp ResponseRewriter: failed to normalize tool name %q to %q: %v", name, canonical, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Streaming: content_block.name in content_block_start events
|
||||
if gjson.GetBytes(data, "content_block.type").String() == "tool_use" {
|
||||
name := gjson.GetBytes(data, "content_block.name").String()
|
||||
if canonical, ok := ampCanonicalToolNames[strings.ToLower(name)]; ok && name != canonical {
|
||||
var err error
|
||||
data, err = sjson.SetBytes(data, "content_block.name", canonical)
|
||||
if err != nil {
|
||||
log.Warnf("Amp ResponseRewriter: failed to normalize streaming tool name %q to %q: %v", name, canonical, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks
|
||||
// in API responses so that the Amp TUI does not crash on P.signature.length.
|
||||
func ensureAmpSignature(data []byte) []byte {
|
||||
@@ -179,6 +225,7 @@ func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte {
|
||||
|
||||
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
data = ensureAmpSignature(data)
|
||||
data = normalizeAmpToolNames(data)
|
||||
data = rw.suppressAmpThinking(data)
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
@@ -278,6 +325,9 @@ func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte {
|
||||
// Inject empty signature where needed
|
||||
data = ensureAmpSignature(data)
|
||||
|
||||
// Normalize tool names to canonical casing
|
||||
data = normalizeAmpToolNames(data)
|
||||
|
||||
// Rewrite model name
|
||||
if rw.originalModel != "" {
|
||||
for _, path := range modelFieldPaths {
|
||||
|
||||
@@ -175,6 +175,57 @@ func TestSanitizeAmpRequestBody_MixedInvalidThinkingAndToolUseSignature(t *testi
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_NonStreaming(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"bash","input":{"cmd":"ls"}},{"type":"tool_use","id":"toolu_02","name":"read","input":{"path":"/tmp"}},{"type":"text","text":"hello"}]}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
if !contains(result, []byte(`"name":"Bash"`)) {
|
||||
t.Errorf("expected bash->Bash, got %s", string(result))
|
||||
}
|
||||
if !contains(result, []byte(`"name":"Read"`)) {
|
||||
t.Errorf("expected read->Read, got %s", string(result))
|
||||
}
|
||||
if contains(result, []byte(`"name":"bash"`)) {
|
||||
t.Errorf("expected lowercase bash to be replaced, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_Streaming(t *testing.T) {
|
||||
input := []byte(`{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","name":"grep","id":"toolu_01","input":{}}}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
if !contains(result, []byte(`"name":"Grep"`)) {
|
||||
t.Errorf("expected grep->Grep in streaming, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_AlreadyCorrect(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification for correctly-cased tool, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_GlobPreserved(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"glob","input":{"pattern":"*.go"}}]}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected glob to remain lowercase, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAmpToolNames_UnknownToolUntouched(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"edit_file","input":{"path":"/tmp/x"}}]}`)
|
||||
result := normalizeAmpToolNames(input)
|
||||
|
||||
if string(result) != string(input) {
|
||||
t.Errorf("expected no modification for unknown tool, got %s", string(result))
|
||||
}
|
||||
}
|
||||
|
||||
func contains(data, substr []byte) bool {
|
||||
for i := 0; i <= len(data)-len(substr); i++ {
|
||||
if string(data[i:i+len(substr)]) == string(substr) {
|
||||
|
||||
@@ -31,7 +31,6 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
@@ -507,9 +506,6 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt := s.engine.Group("/v0/management")
|
||||
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
|
||||
{
|
||||
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
|
||||
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
|
||||
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
|
||||
mgmt.GET("/config", s.mgmt.GetConfig)
|
||||
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
|
||||
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
|
||||
@@ -554,6 +550,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
|
||||
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
|
||||
mgmt.DELETE("/api-keys", s.mgmt.DeleteAPIKeys)
|
||||
mgmt.GET("/api-key-usage", s.mgmt.GetAPIKeyUsage)
|
||||
mgmt.GET("/usage-queue", s.mgmt.GetUsageQueue)
|
||||
|
||||
mgmt.GET("/gemini-api-key", s.mgmt.GetGeminiKeys)
|
||||
mgmt.PUT("/gemini-api-key", s.mgmt.PutGeminiKeys)
|
||||
@@ -1000,7 +998,11 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
|
||||
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
redisqueue.SetUsageStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
}
|
||||
|
||||
if oldCfg == nil || oldCfg.RedisUsageQueueRetentionSeconds != cfg.RedisUsageQueueRetentionSeconds {
|
||||
redisqueue.SetRetentionSeconds(cfg.RedisUsageQueueRetentionSeconds)
|
||||
}
|
||||
|
||||
if s.requestLogger != nil && (oldCfg == nil || oldCfg.ErrorLogsMaxFiles != cfg.ErrorLogsMaxFiles) {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
gin "github.com/gin-gonic/gin"
|
||||
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/redisqueue"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
@@ -84,6 +85,68 @@ func TestHealthz(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestManagementUsageRequiresManagementAuthAndPopsArray(t *testing.T) {
|
||||
t.Setenv("MANAGEMENT_PASSWORD", "test-management-key")
|
||||
|
||||
prevQueueEnabled := redisqueue.Enabled()
|
||||
redisqueue.SetEnabled(false)
|
||||
t.Cleanup(func() {
|
||||
redisqueue.SetEnabled(false)
|
||||
redisqueue.SetEnabled(prevQueueEnabled)
|
||||
})
|
||||
|
||||
server := newTestServer(t)
|
||||
|
||||
redisqueue.Enqueue([]byte(`{"id":1}`))
|
||||
redisqueue.Enqueue([]byte(`{"id":2}`))
|
||||
|
||||
missingKeyReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
|
||||
missingKeyRR := httptest.NewRecorder()
|
||||
server.engine.ServeHTTP(missingKeyRR, missingKeyReq)
|
||||
if missingKeyRR.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("missing key status = %d, want %d body=%s", missingKeyRR.Code, http.StatusUnauthorized, missingKeyRR.Body.String())
|
||||
}
|
||||
|
||||
legacyReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage?count=2", nil)
|
||||
legacyReq.Header.Set("Authorization", "Bearer test-management-key")
|
||||
legacyRR := httptest.NewRecorder()
|
||||
server.engine.ServeHTTP(legacyRR, legacyReq)
|
||||
if legacyRR.Code != http.StatusNotFound {
|
||||
t.Fatalf("legacy usage status = %d, want %d body=%s", legacyRR.Code, http.StatusNotFound, legacyRR.Body.String())
|
||||
}
|
||||
|
||||
authReq := httptest.NewRequest(http.MethodGet, "/v0/management/usage-queue?count=2", nil)
|
||||
authReq.Header.Set("Authorization", "Bearer test-management-key")
|
||||
authRR := httptest.NewRecorder()
|
||||
server.engine.ServeHTTP(authRR, authReq)
|
||||
if authRR.Code != http.StatusOK {
|
||||
t.Fatalf("authenticated status = %d, want %d body=%s", authRR.Code, http.StatusOK, authRR.Body.String())
|
||||
}
|
||||
|
||||
var payload []json.RawMessage
|
||||
if errUnmarshal := json.Unmarshal(authRR.Body.Bytes(), &payload); errUnmarshal != nil {
|
||||
t.Fatalf("unmarshal response: %v body=%s", errUnmarshal, authRR.Body.String())
|
||||
}
|
||||
if len(payload) != 2 {
|
||||
t.Fatalf("response records = %d, want 2", len(payload))
|
||||
}
|
||||
for i, raw := range payload {
|
||||
var record struct {
|
||||
ID int `json:"id"`
|
||||
}
|
||||
if errUnmarshal := json.Unmarshal(raw, &record); errUnmarshal != nil {
|
||||
t.Fatalf("unmarshal record %d: %v", i, errUnmarshal)
|
||||
}
|
||||
if record.ID != i+1 {
|
||||
t.Fatalf("record %d id = %d, want %d", i, record.ID, i+1)
|
||||
}
|
||||
}
|
||||
|
||||
if remaining := redisqueue.PopOldest(1); len(remaining) != 0 {
|
||||
t.Fatalf("remaining queue = %q, want empty", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmpProviderModelRoutes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
||||
@@ -6,15 +6,18 @@ package claude
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// OAuth configuration constants for Claude/Anthropic
|
||||
@@ -23,8 +26,94 @@ const (
|
||||
TokenURL = "https://api.anthropic.com/v1/oauth/token"
|
||||
ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
RedirectURI = "http://localhost:54545/callback"
|
||||
|
||||
claudeRefreshMinBackoff = 5 * time.Second
|
||||
claudeRefreshMaxBackoff = 5 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
claudeRefreshGroup singleflight.Group
|
||||
claudeRefreshMu sync.Mutex
|
||||
claudeRefreshBlock = make(map[string]time.Time)
|
||||
)
|
||||
|
||||
type refreshHTTPError struct {
|
||||
status int
|
||||
message string
|
||||
retryable bool
|
||||
}
|
||||
|
||||
func (e *refreshHTTPError) Error() string {
|
||||
return fmt.Sprintf("token refresh failed with status %d: %s", e.status, e.message)
|
||||
}
|
||||
|
||||
func (e *refreshHTTPError) Retryable() bool {
|
||||
return e != nil && e.retryable
|
||||
}
|
||||
|
||||
func resetClaudeRefreshState() {
|
||||
claudeRefreshMu.Lock()
|
||||
defer claudeRefreshMu.Unlock()
|
||||
claudeRefreshBlock = make(map[string]time.Time)
|
||||
claudeRefreshGroup = singleflight.Group{}
|
||||
}
|
||||
|
||||
func claudeRefreshBlockedUntil(refreshToken string) time.Time {
|
||||
claudeRefreshMu.Lock()
|
||||
defer claudeRefreshMu.Unlock()
|
||||
return claudeRefreshBlock[refreshToken]
|
||||
}
|
||||
|
||||
func setClaudeRefreshBlockedUntil(refreshToken string, until time.Time) {
|
||||
claudeRefreshMu.Lock()
|
||||
defer claudeRefreshMu.Unlock()
|
||||
claudeRefreshBlock[refreshToken] = until
|
||||
}
|
||||
|
||||
func clearClaudeRefreshBlockedUntil(refreshToken string) {
|
||||
claudeRefreshMu.Lock()
|
||||
defer claudeRefreshMu.Unlock()
|
||||
delete(claudeRefreshBlock, refreshToken)
|
||||
}
|
||||
|
||||
func clampClaudeRefreshBackoff(d time.Duration) time.Duration {
|
||||
if d < claudeRefreshMinBackoff {
|
||||
return claudeRefreshMinBackoff
|
||||
}
|
||||
if d > claudeRefreshMaxBackoff {
|
||||
return claudeRefreshMaxBackoff
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func parseClaudeRetryAfter(resp *http.Response) time.Duration {
|
||||
if resp == nil {
|
||||
return claudeRefreshMinBackoff
|
||||
}
|
||||
if raw := strings.TrimSpace(resp.Header.Get("Retry-After")); raw != "" {
|
||||
if seconds, err := time.ParseDuration(raw + "s"); err == nil {
|
||||
return clampClaudeRefreshBackoff(seconds)
|
||||
}
|
||||
if when, err := http.ParseTime(raw); err == nil {
|
||||
return clampClaudeRefreshBackoff(time.Until(when))
|
||||
}
|
||||
}
|
||||
if raw := strings.TrimSpace(resp.Header.Get("Retry-After-Ms")); raw != "" {
|
||||
if ms, err := time.ParseDuration(raw + "ms"); err == nil {
|
||||
return clampClaudeRefreshBackoff(ms)
|
||||
}
|
||||
}
|
||||
return claudeRefreshMinBackoff
|
||||
}
|
||||
|
||||
func isClaudeRefreshRetryable(err error) bool {
|
||||
var httpErr *refreshHTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return httpErr.Retryable()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
|
||||
// It contains access token, refresh token, and associated user/organization information.
|
||||
type tokenResponse struct {
|
||||
@@ -242,6 +331,35 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
||||
if refreshToken == "" {
|
||||
return nil, fmt.Errorf("refresh token is required")
|
||||
}
|
||||
if blockedUntil := claudeRefreshBlockedUntil(refreshToken); blockedUntil.After(time.Now()) {
|
||||
return nil, &refreshHTTPError{
|
||||
status: http.StatusTooManyRequests,
|
||||
message: fmt.Sprintf("refresh temporarily blocked until %s", blockedUntil.Format(time.RFC3339)),
|
||||
retryable: false,
|
||||
}
|
||||
}
|
||||
|
||||
result, err, _ := claudeRefreshGroup.Do(refreshToken, func() (interface{}, error) {
|
||||
return o.refreshTokensSingleFlight(context.WithoutCancel(ctx), refreshToken)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenData, ok := result.(*ClaudeTokenData)
|
||||
if !ok || tokenData == nil {
|
||||
return nil, fmt.Errorf("token refresh failed: invalid single-flight result")
|
||||
}
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
func (o *ClaudeAuth) refreshTokensSingleFlight(ctx context.Context, refreshToken string) (*ClaudeTokenData, error) {
|
||||
if blockedUntil := claudeRefreshBlockedUntil(refreshToken); blockedUntil.After(time.Now()) {
|
||||
return nil, &refreshHTTPError{
|
||||
status: http.StatusTooManyRequests,
|
||||
message: fmt.Sprintf("refresh temporarily blocked until %s", blockedUntil.Format(time.RFC3339)),
|
||||
retryable: false,
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"client_id": ClientID,
|
||||
@@ -276,7 +394,17 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
|
||||
message := string(body)
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
retryAfter := parseClaudeRetryAfter(resp)
|
||||
setClaudeRefreshBlockedUntil(refreshToken, time.Now().Add(retryAfter))
|
||||
return nil, &refreshHTTPError{status: resp.StatusCode, message: message, retryable: false}
|
||||
}
|
||||
return nil, &refreshHTTPError{
|
||||
status: resp.StatusCode,
|
||||
message: message,
|
||||
retryable: resp.StatusCode >= http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
// log.Debugf("Token response: %s", string(body))
|
||||
@@ -287,6 +415,8 @@ func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*C
|
||||
}
|
||||
|
||||
// Create token data
|
||||
clearClaudeRefreshBlockedUntil(refreshToken)
|
||||
|
||||
return &ClaudeTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
@@ -348,6 +478,9 @@ func (o *ClaudeAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken st
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
if !isClaudeRefreshRetryable(err) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
|
||||
@@ -0,0 +1,123 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func TestRefreshTokensWithRetry_429BlocksImmediateReplay(t *testing.T) {
|
||||
resetClaudeRefreshState()
|
||||
defer resetClaudeRefreshState()
|
||||
|
||||
var calls int32
|
||||
auth := &ClaudeAuth{
|
||||
httpClient: &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":"rate_limited"}`)),
|
||||
Header: http.Header{"Retry-After": []string{"60"}},
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||
if err == nil {
|
||||
t.Fatalf("expected 429 refresh error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "status 429") {
|
||||
t.Fatalf("expected status 429 in error, got %v", err)
|
||||
}
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("expected 1 refresh attempt after 429, got %d", got)
|
||||
}
|
||||
|
||||
_, err = auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
|
||||
if err == nil {
|
||||
t.Fatalf("expected immediate blocked refresh error")
|
||||
}
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("expected blocked retry to avoid a second refresh call, got %d attempts", got)
|
||||
}
|
||||
if blockedUntil := claudeRefreshBlockedUntil("dummy_refresh_token"); !blockedUntil.After(time.Now()) {
|
||||
t.Fatalf("expected blocked-until timestamp to be set, got %v", blockedUntil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokens_DeduplicatesConcurrentRefresh(t *testing.T) {
|
||||
resetClaudeRefreshState()
|
||||
defer resetClaudeRefreshState()
|
||||
|
||||
var calls int32
|
||||
started := make(chan struct{})
|
||||
release := make(chan struct{})
|
||||
var once sync.Once
|
||||
|
||||
auth := &ClaudeAuth{
|
||||
httpClient: &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
once.Do(func() { close(started) })
|
||||
<-release
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader(`{
|
||||
"access_token":"new-access",
|
||||
"refresh_token":"new-refresh",
|
||||
"token_type":"Bearer",
|
||||
"expires_in":3600,
|
||||
"account":{"email_address":"shared@example.com"}
|
||||
}`)),
|
||||
Header: make(http.Header),
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
results := make(chan *ClaudeTokenData, 2)
|
||||
errs := make(chan error, 2)
|
||||
runRefresh := func() {
|
||||
td, err := auth.RefreshTokens(context.Background(), "shared-refresh-token")
|
||||
results <- td
|
||||
errs <- err
|
||||
}
|
||||
|
||||
go runRefresh()
|
||||
go runRefresh()
|
||||
|
||||
<-started
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("expected concurrent refresh to share a single upstream call, got %d", got)
|
||||
}
|
||||
close(release)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := <-errs; err != nil {
|
||||
t.Fatalf("expected refresh to succeed, got %v", err)
|
||||
}
|
||||
td := <-results
|
||||
if td == nil || td.AccessToken != "new-access" {
|
||||
t.Fatalf("expected refreshed access token, got %#v", td)
|
||||
}
|
||||
}
|
||||
if got := atomic.LoadInt32(&calls); got != 1 {
|
||||
t.Fatalf("expected exactly 1 upstream refresh call, got %d", got)
|
||||
}
|
||||
}
|
||||
+3
-35
@@ -333,42 +333,10 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage
|
||||
finalProjectID := projectID
|
||||
if responseProjectID != "" {
|
||||
if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
|
||||
// Check if this is a free user (gen-lang-client projects or free/legacy tier)
|
||||
isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
|
||||
strings.EqualFold(tierID, "FREE") ||
|
||||
strings.EqualFold(tierID, "LEGACY")
|
||||
|
||||
if isFreeUser {
|
||||
// Interactive prompt for free users
|
||||
fmt.Printf("\nGoogle returned a different project ID:\n")
|
||||
fmt.Printf(" Requested (frontend): %s\n", projectID)
|
||||
fmt.Printf(" Returned (backend): %s\n\n", responseProjectID)
|
||||
fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n")
|
||||
fmt.Printf(" This is normal for free tier users.\n\n")
|
||||
fmt.Printf("Which project ID would you like to use?\n")
|
||||
fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID)
|
||||
fmt.Printf(" [2] Frontend: %s\n\n", projectID)
|
||||
fmt.Printf("Enter choice [1]: ")
|
||||
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
choice, _ := reader.ReadString('\n')
|
||||
choice = strings.TrimSpace(choice)
|
||||
|
||||
if choice == "2" {
|
||||
log.Infof("Using frontend project ID: %s", projectID)
|
||||
fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
|
||||
finalProjectID = projectID
|
||||
} else {
|
||||
log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
|
||||
finalProjectID = responseProjectID
|
||||
}
|
||||
} else {
|
||||
// Pro users: keep requested project ID (original behavior)
|
||||
log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
|
||||
}
|
||||
} else {
|
||||
finalProjectID = responseProjectID
|
||||
log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID)
|
||||
log.Infof("Using backend project ID: %s", responseProjectID)
|
||||
}
|
||||
finalProjectID = responseProjectID
|
||||
}
|
||||
|
||||
storage.ProjectID = strings.TrimSpace(finalProjectID)
|
||||
|
||||
@@ -65,6 +65,11 @@ type Config struct {
|
||||
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
||||
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
||||
|
||||
// RedisUsageQueueRetentionSeconds controls how long (in seconds) usage queue items
|
||||
// are retained in memory for the Redis RESP interface (LPOP/RPOP).
|
||||
// Default: 60. Max: 3600.
|
||||
RedisUsageQueueRetentionSeconds int `yaml:"redis-usage-queue-retention-seconds" json:"redis-usage-queue-retention-seconds"`
|
||||
|
||||
// DisableCooling disables quota cooldown scheduling when true.
|
||||
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`
|
||||
|
||||
@@ -609,6 +614,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.RedisUsageQueueRetentionSeconds = 60
|
||||
cfg.DisableCooling = false
|
||||
cfg.DisableImageGeneration = DisableImageGenerationOff
|
||||
cfg.Pprof.Enable = false
|
||||
@@ -671,6 +677,13 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
}
|
||||
|
||||
if cfg.RedisUsageQueueRetentionSeconds <= 0 {
|
||||
cfg.RedisUsageQueueRetentionSeconds = 60
|
||||
} else if cfg.RedisUsageQueueRetentionSeconds > 3600 {
|
||||
log.WithField("value", cfg.RedisUsageQueueRetentionSeconds).Warn("redis-usage-queue-retention-seconds too large; clamping to 3600")
|
||||
cfg.RedisUsageQueueRetentionSeconds = 3600
|
||||
}
|
||||
|
||||
if cfg.MaxRetryCredentials < 0 {
|
||||
cfg.MaxRetryCredentials = 0
|
||||
}
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type endpointKey struct{}
|
||||
type responseStatusKey struct{}
|
||||
|
||||
type responseStatusHolder struct {
|
||||
status atomic.Int32
|
||||
}
|
||||
|
||||
func WithEndpoint(ctx context.Context, endpoint string) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, endpointKey{}, endpoint)
|
||||
}
|
||||
|
||||
func GetEndpoint(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if endpoint, ok := ctx.Value(endpointKey{}).(string); ok {
|
||||
return endpoint
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithResponseStatusHolder(ctx context.Context) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder); ok && holder != nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, responseStatusKey{}, &responseStatusHolder{})
|
||||
}
|
||||
|
||||
func SetResponseStatus(ctx context.Context, status int) {
|
||||
if ctx == nil || status <= 0 {
|
||||
return
|
||||
}
|
||||
holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder)
|
||||
if !ok || holder == nil {
|
||||
return
|
||||
}
|
||||
holder.status.Store(int32(status))
|
||||
}
|
||||
|
||||
func GetResponseStatus(ctx context.Context) int {
|
||||
if ctx == nil {
|
||||
return 0
|
||||
}
|
||||
holder, ok := ctx.Value(responseStatusKey{}).(*responseStatusHolder)
|
||||
if !ok || holder == nil {
|
||||
return 0
|
||||
}
|
||||
return int(holder.status.Load())
|
||||
}
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
const (
|
||||
// GeminiCLIVersion is the version string reported in the User-Agent for upstream requests.
|
||||
GeminiCLIVersion = "0.31.0"
|
||||
GeminiCLIVersion = "0.34.0"
|
||||
|
||||
// GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream.
|
||||
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
|
||||
@@ -46,7 +46,7 @@ func GeminiCLIUserAgent(model string) string {
|
||||
if model == "" {
|
||||
model = "unknown"
|
||||
}
|
||||
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
|
||||
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s; terminal)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
|
||||
}
|
||||
|
||||
// ScrubProxyAndFingerprintHeaders removes all headers that could reveal
|
||||
|
||||
@@ -3,13 +3,10 @@ package redisqueue
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
@@ -23,7 +20,7 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
if !Enabled() || !internalusage.StatisticsEnabled() {
|
||||
if !Enabled() || !UsageStatisticsEnabled() {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -36,6 +33,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
|
||||
if modelName == "" {
|
||||
modelName = "unknown"
|
||||
}
|
||||
aliasName := strings.TrimSpace(record.Alias)
|
||||
if aliasName == "" {
|
||||
aliasName = modelName
|
||||
}
|
||||
provider := strings.TrimSpace(record.Provider)
|
||||
if provider == "" {
|
||||
provider = "unknown"
|
||||
@@ -46,13 +47,8 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
|
||||
}
|
||||
apiKey := strings.TrimSpace(record.APIKey)
|
||||
requestID := strings.TrimSpace(internallogging.GetRequestID(ctx))
|
||||
if requestID == "" {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
|
||||
requestID = strings.TrimSpace(internallogging.GetGinRequestID(ginCtx))
|
||||
}
|
||||
}
|
||||
|
||||
tokens := internalusage.TokenStats{
|
||||
tokens := tokenStats{
|
||||
InputTokens: record.Detail.InputTokens,
|
||||
OutputTokens: record.Detail.OutputTokens,
|
||||
ReasoningTokens: record.Detail.ReasoningTokens,
|
||||
@@ -71,7 +67,7 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
|
||||
failed = !resolveSuccess(ctx)
|
||||
}
|
||||
|
||||
detail := internalusage.RequestDetail{
|
||||
detail := requestDetail{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: record.Latency.Milliseconds(),
|
||||
Source: record.Source,
|
||||
@@ -81,9 +77,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(queuedUsageDetail{
|
||||
RequestDetail: detail,
|
||||
requestDetail: detail,
|
||||
Provider: provider,
|
||||
Model: modelName,
|
||||
Alias: aliasName,
|
||||
Endpoint: resolveEndpoint(ctx),
|
||||
AuthType: authType,
|
||||
APIKey: apiKey,
|
||||
@@ -96,50 +93,43 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
|
||||
}
|
||||
|
||||
type queuedUsageDetail struct {
|
||||
internalusage.RequestDetail
|
||||
requestDetail
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Alias string `json:"alias"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
AuthType string `json:"auth_type"`
|
||||
APIKey string `json:"api_key"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
type requestDetail struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
LatencyMs int64 `json:"latency_ms"`
|
||||
Source string `json:"source"`
|
||||
AuthIndex string `json:"auth_index"`
|
||||
Tokens tokenStats `json:"tokens"`
|
||||
Failed bool `json:"failed"`
|
||||
}
|
||||
|
||||
type tokenStats struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
ReasoningTokens int64 `json:"reasoning_tokens"`
|
||||
CachedTokens int64 `json:"cached_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
}
|
||||
|
||||
func resolveSuccess(ctx context.Context) bool {
|
||||
if ctx == nil {
|
||||
return true
|
||||
}
|
||||
ginCtx, ok := ctx.Value("gin").(*gin.Context)
|
||||
if !ok || ginCtx == nil {
|
||||
return true
|
||||
}
|
||||
status := ginCtx.Writer.Status()
|
||||
status := internallogging.GetResponseStatus(ctx)
|
||||
if status == 0 {
|
||||
return true
|
||||
}
|
||||
return status < http.StatusBadRequest
|
||||
return status < httpStatusBadRequest
|
||||
}
|
||||
|
||||
func resolveEndpoint(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
ginCtx, ok := ctx.Value("gin").(*gin.Context)
|
||||
if !ok || ginCtx == nil || ginCtx.Request == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
path := strings.TrimSpace(ginCtx.FullPath())
|
||||
if path == "" && ginCtx.Request.URL != nil {
|
||||
path = strings.TrimSpace(ginCtx.Request.URL.Path)
|
||||
}
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
method := strings.TrimSpace(ginCtx.Request.Method)
|
||||
if method == "" {
|
||||
return path
|
||||
}
|
||||
return method + " " + path
|
||||
return strings.TrimSpace(internallogging.GetEndpoint(ctx))
|
||||
}
|
||||
|
||||
const httpStatusBadRequest = 400
|
||||
|
||||
@@ -10,20 +10,21 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
internalusage "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
|
||||
withEnabledQueue(t, func() {
|
||||
ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK)
|
||||
internallogging.SetGinRequestID(ginCtx, "gin-request-id-ignored")
|
||||
ctx := context.WithValue(internallogging.WithRequestID(context.Background(), "ctx-request-id"), "gin", ginCtx)
|
||||
ctx := internallogging.WithRequestID(context.Background(), "ctx-request-id")
|
||||
ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
|
||||
ctx = internallogging.WithResponseStatusHolder(ctx)
|
||||
internallogging.SetResponseStatus(ctx, http.StatusOK)
|
||||
|
||||
plugin := &usageQueuePlugin{}
|
||||
plugin.HandleUsage(ctx, coreusage.Record{
|
||||
Provider: "openai",
|
||||
Model: "gpt-5.4",
|
||||
Alias: "client-gpt",
|
||||
APIKey: "test-key",
|
||||
AuthIndex: "0",
|
||||
AuthType: "apikey",
|
||||
@@ -40,6 +41,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
|
||||
payload := popSinglePayload(t)
|
||||
requireStringField(t, payload, "provider", "openai")
|
||||
requireStringField(t, payload, "model", "gpt-5.4")
|
||||
requireStringField(t, payload, "alias", "client-gpt")
|
||||
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
|
||||
requireStringField(t, payload, "auth_type", "apikey")
|
||||
requireStringField(t, payload, "request_id", "ctx-request-id")
|
||||
@@ -49,14 +51,16 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
|
||||
|
||||
func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t *testing.T) {
|
||||
withEnabledQueue(t, func() {
|
||||
ginCtx := newTestGinContext(t, http.MethodGet, "/v1/responses", http.StatusInternalServerError)
|
||||
internallogging.SetGinRequestID(ginCtx, "gin-request-id")
|
||||
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
||||
ctx := internallogging.WithRequestID(context.Background(), "gin-request-id")
|
||||
ctx = internallogging.WithEndpoint(ctx, "GET /v1/responses")
|
||||
ctx = internallogging.WithResponseStatusHolder(ctx)
|
||||
internallogging.SetResponseStatus(ctx, http.StatusInternalServerError)
|
||||
|
||||
plugin := &usageQueuePlugin{}
|
||||
plugin.HandleUsage(ctx, coreusage.Record{
|
||||
Provider: "openai",
|
||||
Model: "gpt-5.4-mini",
|
||||
Alias: "client-mini",
|
||||
APIKey: "test-key",
|
||||
AuthIndex: "0",
|
||||
AuthType: "apikey",
|
||||
@@ -73,6 +77,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t
|
||||
payload := popSinglePayload(t)
|
||||
requireStringField(t, payload, "provider", "openai")
|
||||
requireStringField(t, payload, "model", "gpt-5.4-mini")
|
||||
requireStringField(t, payload, "alias", "client-mini")
|
||||
requireStringField(t, payload, "endpoint", "GET /v1/responses")
|
||||
requireStringField(t, payload, "auth_type", "apikey")
|
||||
requireStringField(t, payload, "request_id", "gin-request-id")
|
||||
@@ -80,20 +85,63 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsageQueuePluginAsyncIgnoresRecycledGinContext(t *testing.T) {
|
||||
withEnabledQueue(t, func() {
|
||||
ginCtx := newTestGinContext(t, http.MethodPost, "/v1/chat/completions", http.StatusOK)
|
||||
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
||||
ctx = internallogging.WithRequestID(ctx, "ctx-request-id")
|
||||
ctx = internallogging.WithEndpoint(ctx, "POST /v1/chat/completions")
|
||||
ctx = internallogging.WithResponseStatusHolder(ctx)
|
||||
internallogging.SetResponseStatus(ctx, http.StatusInternalServerError)
|
||||
|
||||
mgr := coreusage.NewManager(16)
|
||||
defer mgr.Stop()
|
||||
|
||||
mgr.Register(pluginFunc(func(_ context.Context, _ coreusage.Record) {
|
||||
ginCtx.Request = httptest.NewRequest(http.MethodGet, "http://example.com/v1/responses", nil)
|
||||
ginCtx.Status(http.StatusOK)
|
||||
}))
|
||||
mgr.Register(&usageQueuePlugin{})
|
||||
|
||||
mgr.Publish(ctx, coreusage.Record{
|
||||
Provider: "openai",
|
||||
Model: "gpt-5.4",
|
||||
Alias: "client-gpt",
|
||||
APIKey: "test-key",
|
||||
AuthIndex: "0",
|
||||
AuthType: "apikey",
|
||||
Source: "user@example.com",
|
||||
RequestedAt: time.Date(2026, 4, 25, 0, 0, 0, 0, time.UTC),
|
||||
Latency: 1500 * time.Millisecond,
|
||||
Detail: coreusage.Detail{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
})
|
||||
|
||||
payload := waitForSinglePayload(t, 2*time.Second)
|
||||
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
|
||||
requireStringField(t, payload, "alias", "client-gpt")
|
||||
requireStringField(t, payload, "request_id", "ctx-request-id")
|
||||
requireBoolField(t, payload, "failed", true)
|
||||
})
|
||||
}
|
||||
|
||||
func withEnabledQueue(t *testing.T, fn func()) {
|
||||
t.Helper()
|
||||
|
||||
prevQueueEnabled := Enabled()
|
||||
prevStatsEnabled := internalusage.StatisticsEnabled()
|
||||
prevUsageEnabled := UsageStatisticsEnabled()
|
||||
|
||||
SetEnabled(false)
|
||||
SetEnabled(true)
|
||||
internalusage.SetStatisticsEnabled(true)
|
||||
SetUsageStatisticsEnabled(true)
|
||||
|
||||
defer func() {
|
||||
SetEnabled(false)
|
||||
SetEnabled(prevQueueEnabled)
|
||||
internalusage.SetStatisticsEnabled(prevStatsEnabled)
|
||||
SetUsageStatisticsEnabled(prevUsageEnabled)
|
||||
}()
|
||||
|
||||
fn()
|
||||
@@ -127,6 +175,29 @@ func popSinglePayload(t *testing.T) map[string]json.RawMessage {
|
||||
return payload
|
||||
}
|
||||
|
||||
func waitForSinglePayload(t *testing.T, timeout time.Duration) map[string]json.RawMessage {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
items := PopOldest(10)
|
||||
if len(items) == 0 {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("PopOldest() items = %d, want 1", len(items))
|
||||
}
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(items[0], &payload); err != nil {
|
||||
t.Fatalf("unmarshal payload: %v", err)
|
||||
}
|
||||
return payload
|
||||
}
|
||||
t.Fatalf("timeout waiting for queued payload")
|
||||
return nil
|
||||
}
|
||||
|
||||
func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, want string) {
|
||||
t.Helper()
|
||||
|
||||
@@ -143,6 +214,12 @@ func requireStringField(t *testing.T, payload map[string]json.RawMessage, key, w
|
||||
}
|
||||
}
|
||||
|
||||
type pluginFunc func(context.Context, coreusage.Record)
|
||||
|
||||
func (fn pluginFunc) HandleUsage(ctx context.Context, record coreusage.Record) {
|
||||
fn(ctx, record)
|
||||
}
|
||||
|
||||
func requireBoolField(t *testing.T, payload map[string]json.RawMessage, key string, want bool) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -6,7 +6,10 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const retentionWindow = time.Minute
|
||||
const (
|
||||
defaultRetentionSeconds int64 = 60
|
||||
maxRetentionSeconds int64 = 3600
|
||||
)
|
||||
|
||||
type queueItem struct {
|
||||
enqueuedAt time.Time
|
||||
@@ -20,10 +23,15 @@ type queue struct {
|
||||
}
|
||||
|
||||
var (
|
||||
enabled atomic.Bool
|
||||
global queue
|
||||
enabled atomic.Bool
|
||||
retentionSeconds atomic.Int64
|
||||
global queue
|
||||
)
|
||||
|
||||
func init() {
|
||||
retentionSeconds.Store(defaultRetentionSeconds)
|
||||
}
|
||||
|
||||
func SetEnabled(value bool) {
|
||||
enabled.Store(value)
|
||||
if !value {
|
||||
@@ -35,6 +43,16 @@ func Enabled() bool {
|
||||
return enabled.Load()
|
||||
}
|
||||
|
||||
func SetRetentionSeconds(value int) {
|
||||
normalized := int64(value)
|
||||
if normalized <= 0 {
|
||||
normalized = defaultRetentionSeconds
|
||||
} else if normalized > maxRetentionSeconds {
|
||||
normalized = maxRetentionSeconds
|
||||
}
|
||||
retentionSeconds.Store(normalized)
|
||||
}
|
||||
|
||||
func Enqueue(payload []byte) {
|
||||
if !Enabled() {
|
||||
return
|
||||
@@ -110,7 +128,11 @@ func (q *queue) pruneLocked(now time.Time) {
|
||||
return
|
||||
}
|
||||
|
||||
cutoff := now.Add(-retentionWindow)
|
||||
windowSeconds := retentionSeconds.Load()
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = defaultRetentionSeconds
|
||||
}
|
||||
cutoff := now.Add(-time.Duration(windowSeconds) * time.Second)
|
||||
for q.head < len(q.items) && q.items[q.head].enqueuedAt.Before(cutoff) {
|
||||
q.head++
|
||||
}
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package redisqueue
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
var usageStatisticsEnabled atomic.Bool
|
||||
|
||||
func init() {
|
||||
usageStatisticsEnabled.Store(true)
|
||||
}
|
||||
|
||||
// SetUsageStatisticsEnabled toggles whether usage records are enqueued into the redisqueue payload buffer.
|
||||
// This is controlled by the config field `usage-statistics-enabled` and the corresponding management API.
|
||||
func SetUsageStatisticsEnabled(enabled bool) { usageStatisticsEnabled.Store(enabled) }
|
||||
|
||||
// UsageStatisticsEnabled reports whether the usage queue plugin should publish records.
|
||||
func UsageStatisticsEnabled() bool { return usageStatisticsEnabled.Load() }
|
||||
@@ -285,7 +285,10 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
if event.Err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return false
|
||||
}
|
||||
switch event.Type {
|
||||
@@ -303,7 +306,11 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}:
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -319,14 +326,21 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON(lines[i])}:
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
reporter.Publish(ctx, helps.ParseGeminiUsage(event.Payload))
|
||||
return false
|
||||
case wsrelay.MessageTypeError:
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, event.Err)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -1357,17 +1357,28 @@ attemptLoop:
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: tail[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
} else {
|
||||
reporter.EnsurePublished(ctx)
|
||||
}
|
||||
|
||||
@@ -65,14 +65,13 @@ var oauthToolRenameMap = map[string]string{
|
||||
"notebookedit": "NotebookEdit",
|
||||
}
|
||||
|
||||
// oauthToolRenameReverseMap is the inverse of oauthToolRenameMap for response decoding.
|
||||
var oauthToolRenameReverseMap = func() map[string]string {
|
||||
m := make(map[string]string, len(oauthToolRenameMap))
|
||||
for k, v := range oauthToolRenameMap {
|
||||
m[v] = k
|
||||
}
|
||||
return m
|
||||
}()
|
||||
// The reverse map is now computed per-request in remapOAuthToolNames so that
|
||||
// only names the client actually caused us to rewrite are restored on the
|
||||
// response. A global reverse map — as used previously — corrupted responses
|
||||
// for clients that sent mixed casing (e.g. Amp CLI sends `Bash` TitleCase
|
||||
// alongside `glob` lowercase; the request flagged renames via `glob→Glob`,
|
||||
// then the global reverse map incorrectly rewrote every `Bash` in the
|
||||
// response to `bash`, causing Amp to reject the tool_use as unknown).
|
||||
|
||||
// oauthToolsToRemove lists tool names that must be stripped from OAuth requests
|
||||
// even after remapping. Currently empty — all tools are mapped instead of removed.
|
||||
@@ -192,15 +191,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
oauthToken := isClaudeOAuthToken(apiKey)
|
||||
oauthToolNamesRemapped := false
|
||||
if oauthToken && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
// Remap third-party tool names to Claude Code equivalents and remove
|
||||
// tools without official counterparts. This prevents Anthropic from
|
||||
// fingerprinting the request as third-party via tool naming patterns.
|
||||
var oauthToolNamesReverseMap map[string]string
|
||||
if oauthToken {
|
||||
bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
|
||||
bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled())
|
||||
}
|
||||
// Enable cch signing by default for OAuth tokens (not just experimental flag).
|
||||
// Claude Code always computes cch; missing or invalid cch is a detectable fingerprint.
|
||||
@@ -285,6 +278,10 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if stream {
|
||||
if errValidate := validateClaudeStreamingResponse(data); errValidate != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errValidate)
|
||||
return resp, errValidate
|
||||
}
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
for _, line := range lines {
|
||||
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
|
||||
@@ -294,13 +291,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
} else {
|
||||
reporter.Publish(ctx, helps.ParseClaudeUsage(data))
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
|
||||
}
|
||||
// Reverse the OAuth tool name remap so the downstream client sees original names.
|
||||
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
|
||||
data = reverseRemapOAuthToolNames(data)
|
||||
}
|
||||
data = restoreClaudeOAuthToolNamesFromResponse(data, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(
|
||||
ctx,
|
||||
@@ -375,15 +366,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
bodyForTranslation := body
|
||||
bodyForUpstream := body
|
||||
oauthToken := isClaudeOAuthToken(apiKey)
|
||||
oauthToolNamesRemapped := false
|
||||
if oauthToken && !auth.ToolPrefixDisabled() {
|
||||
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
// Remap third-party tool names to Claude Code equivalents and remove
|
||||
// tools without official counterparts. This prevents Anthropic from
|
||||
// fingerprinting the request as third-party via tool naming patterns.
|
||||
var oauthToolNamesReverseMap map[string]string
|
||||
if oauthToken {
|
||||
bodyForUpstream, oauthToolNamesRemapped = remapOAuthToolNames(bodyForUpstream)
|
||||
bodyForUpstream, oauthToolNamesReverseMap = prepareClaudeOAuthToolNamesForUpstream(bodyForUpstream, claudeToolPrefix, auth.ToolPrefixDisabled())
|
||||
}
|
||||
// Enable cch signing by default for OAuth tokens (not just experimental flag).
|
||||
if oauthToken || experimentalCCHSigningEnabled(e.cfg, auth) {
|
||||
@@ -474,22 +459,24 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
|
||||
line = reverseRemapOAuthToolNamesFromStreamLine(line)
|
||||
}
|
||||
line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
|
||||
// Forward the line as-is to preserve SSE format
|
||||
cloned := make([]byte, len(line)+1)
|
||||
copy(cloned, line)
|
||||
cloned[len(line)] = '\n'
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: cloned}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: cloned}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -504,12 +491,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := helps.ParseClaudeStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
|
||||
}
|
||||
if isClaudeOAuthToken(apiKey) && oauthToolNamesRemapped {
|
||||
line = reverseRemapOAuthToolNamesFromStreamLine(line)
|
||||
}
|
||||
line = restoreClaudeOAuthToolNamesFromStreamLine(line, claudeToolPrefix, auth.ToolPrefixDisabled(), oauthToolNamesReverseMap)
|
||||
chunks := sdktranslator.TranslateStream(
|
||||
ctx,
|
||||
to,
|
||||
@@ -521,18 +503,83 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
¶m,
|
||||
)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func validateClaudeStreamingResponse(data []byte) error {
|
||||
scanner := bufio.NewScanner(bytes.NewReader(data))
|
||||
scanner.Buffer(nil, 52_428_800)
|
||||
|
||||
hasData := false
|
||||
hasMessageStart := false
|
||||
hasMessageDelta := false
|
||||
|
||||
for scanner.Scan() {
|
||||
line := bytes.TrimSpace(scanner.Bytes())
|
||||
if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
payload := bytes.TrimSpace(line[len("data:"):])
|
||||
if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
hasData = true
|
||||
if !gjson.ValidBytes(payload) {
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned malformed stream data"}
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(payload)
|
||||
switch root.Get("type").String() {
|
||||
case "error":
|
||||
message := strings.TrimSpace(root.Get("error.message").String())
|
||||
if message == "" {
|
||||
message = strings.TrimSpace(root.Get("error.type").String())
|
||||
}
|
||||
if message == "" {
|
||||
message = "unknown upstream error"
|
||||
}
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned error event: " + message}
|
||||
case "message_start":
|
||||
message := root.Get("message")
|
||||
if strings.TrimSpace(message.Get("id").String()) == "" || strings.TrimSpace(message.Get("model").String()) == "" {
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream message_start is missing id or model"}
|
||||
}
|
||||
hasMessageStart = true
|
||||
case "message_delta":
|
||||
hasMessageDelta = true
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
return errScan
|
||||
}
|
||||
if !hasData {
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream returned empty stream response"}
|
||||
}
|
||||
if !hasMessageStart {
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response is missing message_start"}
|
||||
}
|
||||
if !hasMessageDelta {
|
||||
return statusErr{code: http.StatusBadGateway, msg: "claude executor: upstream stream response ended before message completion"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
@@ -559,12 +606,8 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
// Extract betas from body and convert to header (for count_tokens too)
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
|
||||
body = applyClaudeToolPrefix(body, claudeToolPrefix)
|
||||
}
|
||||
// Remap tool names for OAuth token requests to avoid third-party fingerprinting.
|
||||
if isClaudeOAuthToken(apiKey) {
|
||||
body, _ = remapOAuthToolNames(body)
|
||||
body, _ = prepareClaudeOAuthToolNamesForUpstream(body, claudeToolPrefix, auth.ToolPrefixDisabled())
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
|
||||
@@ -661,7 +704,7 @@ func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (
|
||||
return auth, nil
|
||||
}
|
||||
svc := claudeauth.NewClaudeAuthWithProxyURL(e.cfg, auth.ProxyURL)
|
||||
td, err := svc.RefreshTokens(ctx, refreshToken)
|
||||
td, err := svc.RefreshTokensWithRetry(ctx, refreshToken, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1004,6 +1047,36 @@ func isClaudeOAuthToken(apiKey string) bool {
|
||||
return strings.Contains(apiKey, "sk-ant-oat")
|
||||
}
|
||||
|
||||
// prepareClaudeOAuthToolNamesForUpstream applies the Claude OAuth tool-name
|
||||
// transforms in the same order across request paths. Remap runs before prefixing
|
||||
// so any future non-empty prefix still composes correctly with the per-request
|
||||
// reverse map.
|
||||
func prepareClaudeOAuthToolNamesForUpstream(body []byte, prefix string, prefixDisabled bool) ([]byte, map[string]string) {
|
||||
body, reverseMap := remapOAuthToolNames(body)
|
||||
if !prefixDisabled {
|
||||
body = applyClaudeToolPrefix(body, prefix)
|
||||
}
|
||||
return body, reverseMap
|
||||
}
|
||||
|
||||
// restoreClaudeOAuthToolNamesFromResponse undoes the Claude OAuth tool-name
|
||||
// transforms for non-stream responses in reverse order.
|
||||
func restoreClaudeOAuthToolNamesFromResponse(body []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte {
|
||||
if !prefixDisabled {
|
||||
body = stripClaudeToolPrefixFromResponse(body, prefix)
|
||||
}
|
||||
return reverseRemapOAuthToolNames(body, reverseMap)
|
||||
}
|
||||
|
||||
// restoreClaudeOAuthToolNamesFromStreamLine undoes the Claude OAuth tool-name
|
||||
// transforms for SSE lines in reverse order.
|
||||
func restoreClaudeOAuthToolNamesFromStreamLine(line []byte, prefix string, prefixDisabled bool, reverseMap map[string]string) []byte {
|
||||
if !prefixDisabled {
|
||||
line = stripClaudeToolPrefixFromStreamLine(line, prefix)
|
||||
}
|
||||
return reverseRemapOAuthToolNamesFromStreamLine(line, reverseMap)
|
||||
}
|
||||
|
||||
// remapOAuthToolNames renames third-party tool names to Claude Code equivalents
|
||||
// and removes tools without an official counterpart. This prevents Anthropic from
|
||||
// fingerprinting the request as a third-party client via tool naming patterns.
|
||||
@@ -1011,8 +1084,25 @@ func isClaudeOAuthToken(apiKey string) bool {
|
||||
// It operates on: tools[].name, tool_choice.name, and all tool_use/tool_reference
|
||||
// references in messages. Removed tools' corresponding tool_result blocks are preserved
|
||||
// (they just become orphaned, which is safe for Claude).
|
||||
func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
||||
renamed := false
|
||||
//
|
||||
// The returned map is keyed on the upstream (TitleCase) name and maps to the
|
||||
// client-supplied original name. Callers MUST pass this map to the reverse
|
||||
// functions so only names the client actually caused us to rewrite are restored
|
||||
// on the response. A global reverse map (the previous implementation) incorrectly
|
||||
// rewrote names the client originally sent in TitleCase (e.g. Amp CLI's `Bash`)
|
||||
// when any OTHER tool in the same request triggered a forward rename (e.g.
|
||||
// Amp's `glob`→`Glob`), because the global reverse map contained `Bash`→`bash`
|
||||
// regardless of what the client originally sent.
|
||||
func remapOAuthToolNames(body []byte) ([]byte, map[string]string) {
|
||||
reverseMap := make(map[string]string, len(oauthToolRenameMap))
|
||||
recordRename := func(original, renamed string) {
|
||||
// Preserve the first-seen original name if the same upstream name is
|
||||
// produced from multiple call sites; they all map back identically.
|
||||
if _, exists := reverseMap[renamed]; !exists {
|
||||
reverseMap[renamed] = original
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Rewrite tools array in a single pass (if present).
|
||||
// IMPORTANT: do not mutate names first and then rebuild from an older gjson
|
||||
// snapshot. gjson results are snapshots of the original bytes; rebuilding from a
|
||||
@@ -1045,7 +1135,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
||||
updatedTool, err := sjson.Set(toolJSON, "name", newName)
|
||||
if err == nil {
|
||||
toolJSON = updatedTool
|
||||
renamed = true
|
||||
recordRename(name, newName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1070,7 +1160,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
||||
body, _ = sjson.DeleteBytes(body, "tool_choice")
|
||||
} else if newName, ok := oauthToolRenameMap[tcName]; ok && newName != tcName {
|
||||
body, _ = sjson.SetBytes(body, "tool_choice.name", newName)
|
||||
renamed = true
|
||||
recordRename(tcName, newName)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1090,14 +1180,14 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
||||
if newName, ok := oauthToolRenameMap[name]; ok && newName != name {
|
||||
path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, newName)
|
||||
renamed = true
|
||||
recordRename(name, newName)
|
||||
}
|
||||
case "tool_reference":
|
||||
toolName := part.Get("tool_name").String()
|
||||
if newName, ok := oauthToolRenameMap[toolName]; ok && newName != toolName {
|
||||
path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, path, newName)
|
||||
renamed = true
|
||||
recordRename(toolName, newName)
|
||||
}
|
||||
case "tool_result":
|
||||
// Handle nested tool_reference blocks inside tool_result.content[]
|
||||
@@ -1111,7 +1201,7 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
||||
if newName, ok := oauthToolRenameMap[nestedToolName]; ok && newName != nestedToolName {
|
||||
nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
|
||||
body, _ = sjson.SetBytes(body, nestedPath, newName)
|
||||
renamed = true
|
||||
recordRename(nestedToolName, newName)
|
||||
}
|
||||
}
|
||||
return true
|
||||
@@ -1124,13 +1214,16 @@ func remapOAuthToolNames(body []byte) ([]byte, bool) {
|
||||
})
|
||||
}
|
||||
|
||||
return body, renamed
|
||||
return body, reverseMap
|
||||
}
|
||||
|
||||
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses.
|
||||
// It maps Claude Code TitleCase names back to the original lowercase names so the
|
||||
// downstream client receives tool names it recognizes.
|
||||
func reverseRemapOAuthToolNames(body []byte) []byte {
|
||||
// reverseRemapOAuthToolNames reverses the tool name mapping for non-stream responses
|
||||
// using the per-request map produced by remapOAuthToolNames. Names the client sent
|
||||
// that were NOT forward-renamed are passed through unchanged.
|
||||
func reverseRemapOAuthToolNames(body []byte, reverseMap map[string]string) []byte {
|
||||
if len(reverseMap) == 0 {
|
||||
return body
|
||||
}
|
||||
content := gjson.GetBytes(body, "content")
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
return body
|
||||
@@ -1140,13 +1233,13 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
|
||||
switch partType {
|
||||
case "tool_use":
|
||||
name := part.Get("name").String()
|
||||
if origName, ok := oauthToolRenameReverseMap[name]; ok {
|
||||
if origName, ok := reverseMap[name]; ok {
|
||||
path := fmt.Sprintf("content.%d.name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, origName)
|
||||
}
|
||||
case "tool_reference":
|
||||
toolName := part.Get("tool_name").String()
|
||||
if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
|
||||
if origName, ok := reverseMap[toolName]; ok {
|
||||
path := fmt.Sprintf("content.%d.tool_name", index.Int())
|
||||
body, _ = sjson.SetBytes(body, path, origName)
|
||||
}
|
||||
@@ -1156,8 +1249,12 @@ func reverseRemapOAuthToolNames(body []byte) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE stream lines.
|
||||
func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
|
||||
// reverseRemapOAuthToolNamesFromStreamLine reverses the tool name mapping for SSE
|
||||
// stream lines, using the per-request reverseMap produced by remapOAuthToolNames.
|
||||
func reverseRemapOAuthToolNamesFromStreamLine(line []byte, reverseMap map[string]string) []byte {
|
||||
if len(reverseMap) == 0 {
|
||||
return line
|
||||
}
|
||||
payload := helps.JSONPayload(line)
|
||||
if len(payload) == 0 || !gjson.ValidBytes(payload) {
|
||||
return line
|
||||
@@ -1175,7 +1272,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
|
||||
switch blockType {
|
||||
case "tool_use":
|
||||
name := contentBlock.Get("name").String()
|
||||
if origName, ok := oauthToolRenameReverseMap[name]; ok {
|
||||
if origName, ok := reverseMap[name]; ok {
|
||||
updated, err = sjson.SetBytes(payload, "content_block.name", origName)
|
||||
if err != nil {
|
||||
return line
|
||||
@@ -1185,7 +1282,7 @@ func reverseRemapOAuthToolNamesFromStreamLine(line []byte) []byte {
|
||||
}
|
||||
case "tool_reference":
|
||||
toolName := contentBlock.Get("tool_name").String()
|
||||
if origName, ok := oauthToolRenameReverseMap[toolName]; ok {
|
||||
if origName, ok := reverseMap[toolName]; ok {
|
||||
updated, err = sjson.SetBytes(payload, "content_block.tool_name", origName)
|
||||
if err != nil {
|
||||
return line
|
||||
|
||||
@@ -936,6 +936,113 @@ func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsEmptyClaudeStream(t *testing.T) {
|
||||
_, err := executeOpenAIChatCompletionThroughClaude(t, "")
|
||||
if err == nil {
|
||||
t.Fatal("Execute error = nil, want empty stream error")
|
||||
}
|
||||
assertStatusErr(t, err, http.StatusBadGateway)
|
||||
if !strings.Contains(err.Error(), "empty stream response") {
|
||||
t.Fatalf("Execute error = %q, want empty stream response", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsClaudeErrorEvent(t *testing.T) {
|
||||
body := `data: {"type":"error","error":{"type":"overloaded_error","message":"upstream overloaded"}}` + "\n"
|
||||
_, err := executeOpenAIChatCompletionThroughClaude(t, body)
|
||||
if err == nil {
|
||||
t.Fatal("Execute error = nil, want upstream error event")
|
||||
}
|
||||
assertStatusErr(t, err, http.StatusBadGateway)
|
||||
if !strings.Contains(err.Error(), "upstream overloaded") {
|
||||
t.Fatalf("Execute error = %q, want upstream overloaded", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteOpenAINonStreamRejectsIncompleteClaudeStream(t *testing.T) {
|
||||
body := strings.Join([]string{
|
||||
`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`,
|
||||
`data: {"type":"message_stop"}`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
_, err := executeOpenAIChatCompletionThroughClaude(t, body)
|
||||
if err == nil {
|
||||
t.Fatal("Execute error = nil, want incomplete stream error")
|
||||
}
|
||||
assertStatusErr(t, err, http.StatusBadGateway)
|
||||
if !strings.Contains(err.Error(), "ended before message completion") {
|
||||
t.Fatalf("Execute error = %q, want incomplete stream error", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeExecutor_ExecuteOpenAINonStreamConvertsValidClaudeStream(t *testing.T) {
|
||||
body := strings.Join([]string{
|
||||
`event: message_start`,
|
||||
`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-3-5-sonnet-20241022"}}`,
|
||||
`event: content_block_delta`,
|
||||
`data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ok"}}`,
|
||||
`event: message_delta`,
|
||||
`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"input_tokens":2,"output_tokens":1}}`,
|
||||
`event: message_stop`,
|
||||
`data: {"type":"message_stop"}`,
|
||||
``,
|
||||
}, "\n")
|
||||
|
||||
resp, err := executeOpenAIChatCompletionThroughClaude(t, body)
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "id").String(); got != "msg_123" {
|
||||
t.Fatalf("response id = %q, want msg_123; payload=%s", got, string(resp.Payload))
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "model").String(); got != "claude-3-5-sonnet-20241022" {
|
||||
t.Fatalf("response model = %q, want claude-3-5-sonnet-20241022", got)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "choices.0.message.content").String(); got != "ok" {
|
||||
t.Fatalf("response content = %q, want ok", got)
|
||||
}
|
||||
if got := gjson.GetBytes(resp.Payload, "usage.total_tokens").Int(); got != 3 {
|
||||
t.Fatalf("usage.total_tokens = %d, want 3", got)
|
||||
}
|
||||
}
|
||||
|
||||
func executeOpenAIChatCompletionThroughClaude(t *testing.T, upstreamBody string) (cliproxyexecutor.Response, error) {
|
||||
t.Helper()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte(upstreamBody))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewClaudeExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"api_key": "key-123",
|
||||
"base_url": server.URL,
|
||||
}}
|
||||
payload := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"hi"}]}`)
|
||||
|
||||
return executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "claude-3-5-sonnet-20241022",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
})
|
||||
}
|
||||
|
||||
func assertStatusErr(t *testing.T, err error, want int) {
|
||||
t.Helper()
|
||||
|
||||
status, ok := err.(interface{ StatusCode() int })
|
||||
if !ok {
|
||||
t.Fatalf("error %T does not expose StatusCode", err)
|
||||
}
|
||||
if got := status.StatusCode(); got != want {
|
||||
t.Fatalf("StatusCode() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
|
||||
input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
|
||||
out := stripClaudeToolPrefixFromResponse(input, "proxy_")
|
||||
@@ -1989,19 +2096,16 @@ func TestNormalizeClaudeTemperatureForThinking_AfterForcedToolChoiceKeepsOrigina
|
||||
func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"Bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
out, renamed := remapOAuthToolNames(body)
|
||||
if renamed {
|
||||
t.Fatalf("renamed = true, want false")
|
||||
out, reverseMap := remapOAuthToolNames(body)
|
||||
if len(reverseMap) != 0 {
|
||||
t.Fatalf("reverseMap = %v, want empty", reverseMap)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
|
||||
}
|
||||
|
||||
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
||||
reversed := resp
|
||||
if renamed {
|
||||
reversed = reverseRemapOAuthToolNames(resp)
|
||||
}
|
||||
reversed := reverseRemapOAuthToolNames(resp, reverseMap)
|
||||
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
|
||||
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
|
||||
}
|
||||
@@ -2010,20 +2114,150 @@ func TestRemapOAuthToolNames_TitleCase_NoReverseNeeded(t *testing.T) {
|
||||
func TestRemapOAuthToolNames_Lowercase_ReverseApplied(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","description":"Run shell commands","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
|
||||
|
||||
out, renamed := remapOAuthToolNames(body)
|
||||
if !renamed {
|
||||
t.Fatalf("renamed = false, want true")
|
||||
out, reverseMap := remapOAuthToolNames(body)
|
||||
if reverseMap["Bash"] != "bash" {
|
||||
t.Fatalf("reverseMap = %v, want entry Bash->bash", reverseMap)
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "Bash")
|
||||
}
|
||||
|
||||
resp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
||||
reversed := resp
|
||||
if renamed {
|
||||
reversed = reverseRemapOAuthToolNames(resp)
|
||||
}
|
||||
reversed := reverseRemapOAuthToolNames(resp, reverseMap)
|
||||
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "bash" {
|
||||
t.Fatalf("content.0.name = %q, want %q", got, "bash")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed is the regression
|
||||
// test for a case where a single request contains both a TitleCase tool (which
|
||||
// must pass through unchanged) and a lowercase tool that we forward-rename.
|
||||
// Before the fix, triggering ANY forward rename caused the reverse pass to
|
||||
// lowercase every TitleCase tool in the response using a global reverse map,
|
||||
// corrupting tool names the client originally sent in TitleCase (notably Amp
|
||||
// CLI's `Bash`, which its registry lookup cannot find as `bash`).
|
||||
func TestRemapOAuthToolNames_MixedCase_OnlyRenamedToolsReversed(t *testing.T) {
|
||||
body := []byte(`{"tools":[` +
|
||||
`{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
|
||||
`{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
|
||||
`]}`)
|
||||
|
||||
out, reverseMap := remapOAuthToolNames(body)
|
||||
|
||||
// Forward: TitleCase `Bash` is not a forward-map key, must pass through.
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "Bash" {
|
||||
t.Fatalf("tools.0.name = %q, want %q (TitleCase tool must not be renamed)", got, "Bash")
|
||||
}
|
||||
// Forward: `glob` is a forward-map key, upstream sees `Glob`.
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "Glob" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "Glob")
|
||||
}
|
||||
|
||||
// Reverse map records ONLY the rename that happened.
|
||||
if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
|
||||
t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
|
||||
}
|
||||
|
||||
// Upstream responds with a `Bash` tool_use. Since we never renamed `Bash`,
|
||||
// reverseRemap MUST leave it alone.
|
||||
bashResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"cmd":"ls"}}]}`)
|
||||
reversed := reverseRemapOAuthToolNames(bashResp, reverseMap)
|
||||
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "Bash" {
|
||||
t.Fatalf("content.0.name = %q, want %q (Bash must be preserved; was never forward-renamed)", got, "Bash")
|
||||
}
|
||||
|
||||
// Upstream responds with a `Glob` tool_use. Since we renamed `glob`→`Glob`,
|
||||
// reverseRemap MUST restore the original `glob`.
|
||||
globResp := []byte(`{"content":[{"type":"tool_use","id":"toolu_02","name":"Glob","input":{"filePattern":"**/*.go"}}]}`)
|
||||
reversed = reverseRemapOAuthToolNames(globResp, reverseMap)
|
||||
if got := gjson.GetBytes(reversed, "content.0.name").String(); got != "glob" {
|
||||
t.Fatalf("content.0.name = %q, want %q (Glob must be restored to client's original `glob`)", got, "glob")
|
||||
}
|
||||
}
|
||||
|
||||
// TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap guards the
|
||||
// SSE streaming code path against the same mixed-case bug.
|
||||
func TestReverseRemapOAuthToolNamesFromStreamLine_HonorsPerRequestMap(t *testing.T) {
|
||||
reverseMap := map[string]string{"Glob": "glob"}
|
||||
|
||||
// Bash block was never renamed, must pass through as-is.
|
||||
bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}}}`)
|
||||
out := reverseRemapOAuthToolNamesFromStreamLine(bashLine, reverseMap)
|
||||
if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
|
||||
t.Fatalf("Bash should be preserved, got: %s", string(out))
|
||||
}
|
||||
if bytes.Contains(out, []byte(`"name":"bash"`)) {
|
||||
t.Fatalf("Bash must not be lowercased, got: %s", string(out))
|
||||
}
|
||||
|
||||
// Glob block IS in the reverseMap, must be restored to `glob`.
|
||||
globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"Glob","input":{}}}`)
|
||||
out = reverseRemapOAuthToolNamesFromStreamLine(globLine, reverseMap)
|
||||
if !bytes.Contains(out, []byte(`"name":"glob"`)) {
|
||||
t.Fatalf("Glob should be restored to glob, got: %s", string(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareClaudeOAuthToolNamesForUpstream_MixedCaseWithPrefix(t *testing.T) {
|
||||
body := []byte(`{"tools":[` +
|
||||
`{"name":"Bash","input_schema":{"type":"object","properties":{"cmd":{"type":"string"}}}},` +
|
||||
`{"name":"glob","input_schema":{"type":"object","properties":{"filePattern":{"type":"string"}}}}` +
|
||||
`],"messages":[{"role":"assistant","content":[` +
|
||||
`{"type":"tool_use","id":"toolu_01","name":"Bash","input":{}},` +
|
||||
`{"type":"tool_use","id":"toolu_02","name":"glob","input":{}}` +
|
||||
`]}]}`)
|
||||
|
||||
out, reverseMap := prepareClaudeOAuthToolNamesForUpstream(body, "proxy_", false)
|
||||
|
||||
if got := gjson.GetBytes(out, "tools.0.name").String(); got != "proxy_Bash" {
|
||||
t.Fatalf("tools.0.name = %q, want %q", got, "proxy_Bash")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "tools.1.name").String(); got != "proxy_Glob" {
|
||||
t.Fatalf("tools.1.name = %q, want %q", got, "proxy_Glob")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.0.name").String(); got != "proxy_Bash" {
|
||||
t.Fatalf("messages.0.content.0.name = %q, want %q", got, "proxy_Bash")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "messages.0.content.1.name").String(); got != "proxy_Glob" {
|
||||
t.Fatalf("messages.0.content.1.name = %q, want %q", got, "proxy_Glob")
|
||||
}
|
||||
if len(reverseMap) != 1 || reverseMap["Glob"] != "glob" {
|
||||
t.Fatalf("reverseMap = %v, want {Glob:glob}", reverseMap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreClaudeOAuthToolNamesFromResponse_MixedCaseWithPrefix(t *testing.T) {
|
||||
reverseMap := map[string]string{"Glob": "glob"}
|
||||
resp := []byte(`{"content":[` +
|
||||
`{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}},` +
|
||||
`{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}` +
|
||||
`]}`)
|
||||
|
||||
out := restoreClaudeOAuthToolNamesFromResponse(resp, "proxy_", false, reverseMap)
|
||||
|
||||
if got := gjson.GetBytes(out, "content.0.name").String(); got != "Bash" {
|
||||
t.Fatalf("content.0.name = %q, want %q", got, "Bash")
|
||||
}
|
||||
if got := gjson.GetBytes(out, "content.1.name").String(); got != "glob" {
|
||||
t.Fatalf("content.1.name = %q, want %q", got, "glob")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestoreClaudeOAuthToolNamesFromStreamLine_MixedCaseWithPrefix(t *testing.T) {
|
||||
reverseMap := map[string]string{"Glob": "glob"}
|
||||
|
||||
bashLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"proxy_Bash","input":{}}}`)
|
||||
out := restoreClaudeOAuthToolNamesFromStreamLine(bashLine, "proxy_", false, reverseMap)
|
||||
if !bytes.Contains(out, []byte(`"name":"Bash"`)) {
|
||||
t.Fatalf("Bash should be preserved, got: %s", string(out))
|
||||
}
|
||||
if bytes.Contains(out, []byte(`"name":"bash"`)) {
|
||||
t.Fatalf("Bash must not be lowercased, got: %s", string(out))
|
||||
}
|
||||
|
||||
globLine := []byte(`data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_02","name":"proxy_Glob","input":{}}}`)
|
||||
out = restoreClaudeOAuthToolNamesFromStreamLine(globLine, "proxy_", false, reverseMap)
|
||||
if !bytes.Contains(out, []byte(`"name":"glob"`)) {
|
||||
t.Fatalf("Glob should be restored to glob, got: %s", string(out))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,8 +30,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
codexUserAgent = "codex-tui/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9 (codex-tui; 0.118.0)"
|
||||
codexOriginator = "codex-tui"
|
||||
codexUserAgent = "codex_cli_rs/0.118.0 (Mac OS 26.3.1; arm64) iTerm.app/3.6.9"
|
||||
codexOriginator = "codex_cli_rs"
|
||||
codexDefaultImageToolModel = "gpt-image-2"
|
||||
)
|
||||
|
||||
@@ -515,13 +515,20 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, translatedLine, ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
|
||||
@@ -188,7 +188,6 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", true)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body = normalizeCodexInstructions(body)
|
||||
@@ -776,6 +775,11 @@ func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
|
||||
parsed.Scheme = "ws"
|
||||
case "https":
|
||||
parsed.Scheme = "wss"
|
||||
default:
|
||||
return "", fmt.Errorf("codex websockets executor: unsupported responses websocket URL scheme %q", parsed.Scheme)
|
||||
}
|
||||
if strings.TrimSpace(parsed.Host) == "" {
|
||||
return "", fmt.Errorf("codex websockets executor: responses websocket URL host is empty")
|
||||
}
|
||||
return parsed.String(), nil
|
||||
}
|
||||
@@ -809,6 +813,7 @@ func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecuto
|
||||
|
||||
if cache.ID != "" {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
|
||||
setHeaderCasePreserved(headers, "session_id", cache.ID)
|
||||
headers.Set("Conversation_id", cache.ID)
|
||||
}
|
||||
|
||||
@@ -828,13 +833,19 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
ginHeaders = ginCtx.Request.Header.Clone()
|
||||
}
|
||||
|
||||
_, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
|
||||
isAPIKey := codexAuthUsesAPIKey(auth)
|
||||
cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
|
||||
ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
|
||||
misc.EnsureHeader(headers, ginHeaders, "Version", "")
|
||||
if isAPIKey {
|
||||
ensureHeaderWithPriority(headers, ginHeaders, "User-Agent", "", "")
|
||||
} else {
|
||||
ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
|
||||
}
|
||||
|
||||
betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
|
||||
if betaHeader == "" && ginHeaders != nil {
|
||||
@@ -845,16 +856,9 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
}
|
||||
headers.Set("OpenAI-Beta", betaHeader)
|
||||
if strings.Contains(headers.Get("User-Agent"), "Mac OS") {
|
||||
misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
|
||||
}
|
||||
headers.Del("User-Agent")
|
||||
|
||||
isAPIKey := false
|
||||
if auth != nil && auth.Attributes != nil {
|
||||
if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
|
||||
isAPIKey = true
|
||||
}
|
||||
ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", uuid.NewString())
|
||||
}
|
||||
ensureHeaderCasePreserved(headers, ginHeaders, "session_id", "", "")
|
||||
if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" {
|
||||
headers.Set("Originator", originator)
|
||||
} else if !isAPIKey {
|
||||
@@ -864,7 +868,7 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
if auth != nil && auth.Metadata != nil {
|
||||
if accountID, ok := auth.Metadata["account_id"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(accountID); trimmed != "" {
|
||||
headers.Set("Chatgpt-Account-Id", trimmed)
|
||||
setHeaderCasePreserved(headers, "ChatGPT-Account-ID", trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -879,6 +883,77 @@ func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *
|
||||
return headers
|
||||
}
|
||||
|
||||
func codexAuthUsesAPIKey(auth *cliproxyauth.Auth) bool {
|
||||
if auth == nil || auth.Attributes == nil {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(auth.Attributes["api_key"]) != ""
|
||||
}
|
||||
|
||||
func ensureHeaderCasePreserved(target http.Header, source http.Header, key, configValue, fallbackValue string) {
|
||||
if target == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(headerValueCaseInsensitive(target, key)) != "" {
|
||||
return
|
||||
}
|
||||
if source != nil {
|
||||
if val := strings.TrimSpace(headerValueCaseInsensitive(source, key)); val != "" {
|
||||
setHeaderCasePreserved(target, key, val)
|
||||
return
|
||||
}
|
||||
}
|
||||
if val := strings.TrimSpace(configValue); val != "" {
|
||||
setHeaderCasePreserved(target, key, val)
|
||||
return
|
||||
}
|
||||
if val := strings.TrimSpace(fallbackValue); val != "" {
|
||||
setHeaderCasePreserved(target, key, val)
|
||||
}
|
||||
}
|
||||
|
||||
func setHeaderCasePreserved(headers http.Header, key string, value string) {
|
||||
if headers == nil {
|
||||
return
|
||||
}
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key == "" || value == "" {
|
||||
return
|
||||
}
|
||||
deleteHeaderCaseInsensitive(headers, key)
|
||||
headers[key] = []string{value}
|
||||
}
|
||||
|
||||
func headerValueCaseInsensitive(headers http.Header, key string) string {
|
||||
key = strings.TrimSpace(key)
|
||||
if headers == nil || key == "" {
|
||||
return ""
|
||||
}
|
||||
if val := strings.TrimSpace(headers.Get(key)); val != "" {
|
||||
return val
|
||||
}
|
||||
for existingKey, values := range headers {
|
||||
if !strings.EqualFold(existingKey, key) {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func deleteHeaderCaseInsensitive(headers http.Header, key string) {
|
||||
for existingKey := range headers {
|
||||
if strings.EqualFold(existingKey, key) {
|
||||
delete(headers, existingKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
|
||||
if cfg == nil || auth == nil {
|
||||
return "", ""
|
||||
@@ -962,25 +1037,55 @@ func parseCodexWebsocketError(payload []byte) (error, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
out := []byte(`{}`)
|
||||
if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
|
||||
raw := errNode.Raw
|
||||
if errNode.Type == gjson.String {
|
||||
raw = errNode.Raw
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "error", []byte(raw))
|
||||
} else {
|
||||
out, _ = sjson.SetBytes(out, "error.type", "server_error")
|
||||
out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
|
||||
}
|
||||
|
||||
out := buildCodexWebsocketErrorPayload(payload, status)
|
||||
headers := parseCodexWebsocketErrorHeaders(payload)
|
||||
statusError := statusErr{code: status, msg: string(out)}
|
||||
if retryAfter := parseCodexRetryAfter(status, out, time.Now()); retryAfter != nil {
|
||||
statusError.retryAfter = retryAfter
|
||||
} else if isCodexWebsocketConnectionLimitError(payload) {
|
||||
retryAfter := time.Duration(0)
|
||||
statusError.retryAfter = &retryAfter
|
||||
}
|
||||
return statusErrWithHeaders{
|
||||
statusErr: statusErr{code: status, msg: string(out)},
|
||||
statusErr: statusError,
|
||||
headers: headers,
|
||||
}, true
|
||||
}
|
||||
|
||||
func buildCodexWebsocketErrorPayload(payload []byte, status int) []byte {
|
||||
out := []byte(`{}`)
|
||||
out, _ = sjson.SetBytes(out, "status", status)
|
||||
|
||||
if bodyNode := gjson.GetBytes(payload, "body"); bodyNode.Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "body", []byte(bodyNode.Raw))
|
||||
if bodyErrorNode := bodyNode.Get("error"); bodyErrorNode.Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "error", []byte(bodyErrorNode.Raw))
|
||||
return out
|
||||
}
|
||||
}
|
||||
|
||||
if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() {
|
||||
out, _ = sjson.SetRawBytes(out, "error", []byte(errNode.Raw))
|
||||
return out
|
||||
}
|
||||
|
||||
out, _ = sjson.SetBytes(out, "error.type", "server_error")
|
||||
out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status))
|
||||
return out
|
||||
}
|
||||
|
||||
func isCodexWebsocketConnectionLimitError(payload []byte) bool {
|
||||
if len(payload) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, path := range []string{"error.code", "error.type", "body.error.code", "body.error.type", "code", "error"} {
|
||||
if strings.TrimSpace(gjson.GetBytes(payload, path).String()) == "websocket_connection_limit_reached" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func parseCodexWebsocketErrorHeaders(payload []byte) http.Header {
|
||||
headersNode := gjson.GetBytes(payload, "headers")
|
||||
if !headersNode.Exists() || !headersNode.IsObject() {
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
@@ -32,14 +38,80 @@ func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexWebsocketsExecutePreservesPreviousResponseIDUpstream(t *testing.T) {
|
||||
upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
capturedPayload := make(chan []byte, 1)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
t.Fatalf("request path = %s, want /responses", r.URL.Path)
|
||||
}
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("upgrade websocket: %v", err)
|
||||
}
|
||||
defer func() { _ = conn.Close() }()
|
||||
|
||||
msgType, payload, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Fatalf("read upstream websocket message: %v", err)
|
||||
}
|
||||
if msgType != websocket.TextMessage {
|
||||
t.Fatalf("message type = %d, want text", msgType)
|
||||
}
|
||||
capturedPayload <- bytes.Clone(payload)
|
||||
|
||||
completed := []byte(`{"type":"response.completed","response":{"id":"resp-2","output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`)
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, completed); errWrite != nil {
|
||||
t.Fatalf("write completed websocket message: %v", errWrite)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewCodexWebsocketsExecutor(&config.Config{SDKConfig: config.SDKConfig{DisableImageGeneration: config.DisableImageGenerationAll}})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{"api_key": "sk-test", "base_url": server.URL}}
|
||||
req := cliproxyexecutor.Request{
|
||||
Model: "gpt-5-codex",
|
||||
Payload: []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`),
|
||||
}
|
||||
opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("codex")}
|
||||
|
||||
if _, err := exec.Execute(context.Background(), auth, req, opts); err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case payload := <-capturedPayload:
|
||||
if got := gjson.GetBytes(payload, "type").String(); got != "response.create" {
|
||||
t.Fatalf("upstream type = %s, want response.create; payload=%s", got, payload)
|
||||
}
|
||||
if got := gjson.GetBytes(payload, "previous_response_id").String(); got != "resp-1" {
|
||||
t.Fatalf("upstream previous_response_id = %s, want resp-1; payload=%s", got, payload)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timed out waiting for upstream websocket payload")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
|
||||
|
||||
if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
|
||||
t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
|
||||
}
|
||||
if got := headers.Get("User-Agent"); got != "" {
|
||||
t.Fatalf("User-Agent = %s, want empty", got)
|
||||
if got := headers.Get("User-Agent"); got != codexUserAgent {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
|
||||
}
|
||||
if !strings.HasPrefix(codexUserAgent, codexOriginator+"/") {
|
||||
t.Fatalf("default Codex User-Agent = %s, want prefix %s/", codexUserAgent, codexOriginator)
|
||||
}
|
||||
if strings.HasPrefix(codexUserAgent, "codex-tui/") {
|
||||
t.Fatalf("default Codex User-Agent = %s, must not use stale codex-tui prefix", codexUserAgent)
|
||||
}
|
||||
if strings.Contains(codexUserAgent, "(codex-tui;") {
|
||||
t.Fatalf("default Codex User-Agent = %s, must not include stale codex-tui suffix", codexUserAgent)
|
||||
}
|
||||
if got := headers.Get("Originator"); got != codexOriginator {
|
||||
t.Fatalf("Originator = %s, want %s", got, codexOriginator)
|
||||
}
|
||||
if got := headers.Get("Version"); got != "" {
|
||||
t.Fatalf("Version = %q, want empty", got)
|
||||
@@ -62,9 +134,11 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
|
||||
}
|
||||
ctx := contextWithGinHeaders(map[string]string{
|
||||
"Originator": "Codex Desktop",
|
||||
"User-Agent": "codex_cli_rs/0.1.0",
|
||||
"Version": "0.115.0-alpha.27",
|
||||
"X-Codex-Turn-Metadata": `{"turn_id":"turn-1"}`,
|
||||
"X-Client-Request-Id": "019d2233-e240-7162-992d-38df0a2a0e0d",
|
||||
"session_id": "sess-client",
|
||||
})
|
||||
|
||||
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", nil)
|
||||
@@ -72,6 +146,9 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
|
||||
if got := headers.Get("Originator"); got != "Codex Desktop" {
|
||||
t.Fatalf("Originator = %s, want %s", got, "Codex Desktop")
|
||||
}
|
||||
if got := headers.Get("User-Agent"); got != "codex_cli_rs/0.1.0" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "codex_cli_rs/0.1.0")
|
||||
}
|
||||
if got := headers.Get("Version"); got != "0.115.0-alpha.27" {
|
||||
t.Fatalf("Version = %s, want %s", got, "0.115.0-alpha.27")
|
||||
}
|
||||
@@ -81,6 +158,12 @@ func TestApplyCodexWebsocketHeadersPassesThroughClientIdentityHeaders(t *testing
|
||||
if got := headers.Get("X-Client-Request-Id"); got != "019d2233-e240-7162-992d-38df0a2a0e0d" {
|
||||
t.Fatalf("X-Client-Request-Id = %s, want %s", got, "019d2233-e240-7162-992d-38df0a2a0e0d")
|
||||
}
|
||||
if got := headerValueCaseInsensitive(headers, "session_id"); got != "sess-client" {
|
||||
t.Fatalf("session_id = %s, want sess-client", got)
|
||||
}
|
||||
if _, ok := headers["session_id"]; !ok {
|
||||
t.Fatalf("expected lowercase session_id header key, got %#v", headers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
|
||||
@@ -97,8 +180,8 @@ func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
|
||||
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != "" {
|
||||
t.Fatalf("User-Agent = %s, want empty", got)
|
||||
if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
|
||||
@@ -129,8 +212,8 @@ func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *
|
||||
|
||||
got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
|
||||
|
||||
if gotVal := got.Get("User-Agent"); gotVal != "" {
|
||||
t.Fatalf("User-Agent = %s, want empty", gotVal)
|
||||
if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
|
||||
t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
|
||||
}
|
||||
if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
|
||||
@@ -155,8 +238,8 @@ func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testi
|
||||
|
||||
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != "" {
|
||||
t.Fatalf("User-Agent = %s, want empty", got)
|
||||
if got := headers.Get("User-Agent"); got != "config-ua" {
|
||||
t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
|
||||
}
|
||||
if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
|
||||
t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
|
||||
@@ -183,6 +266,131 @@ func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
|
||||
if got := headers.Get("x-codex-beta-features"); got != "" {
|
||||
t.Fatalf("x-codex-beta-features = %q, want empty", got)
|
||||
}
|
||||
if got := headers.Get("Originator"); got != "" {
|
||||
t.Fatalf("Originator = %s, want empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersPreservesExplicitAPIKeyUserAgent(t *testing.T) {
|
||||
auth := &cliproxyauth.Auth{Provider: "codex", Attributes: map[string]string{"api_key": "sk-test"}}
|
||||
ctx := contextWithGinHeaders(map[string]string{"User-Agent": "api-key-client/1.0", "Originator": "explicit-origin"})
|
||||
|
||||
headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "sk-test", nil)
|
||||
|
||||
if got := headers.Get("User-Agent"); got != "api-key-client/1.0" {
|
||||
t.Fatalf("User-Agent = %s, want api-key-client/1.0", got)
|
||||
}
|
||||
if got := headers.Get("Originator"); got != "explicit-origin" {
|
||||
t.Fatalf("Originator = %s, want explicit-origin", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexPromptCacheHeadersSetsLowercaseSessionAndLegacyConversation(t *testing.T) {
|
||||
req := cliproxyexecutor.Request{Model: "gpt-5-codex", Payload: []byte(`{"prompt_cache_key":"cache-1"}`)}
|
||||
|
||||
_, headers := applyCodexPromptCacheHeaders("openai-response", req, []byte(`{"model":"gpt-5-codex"}`))
|
||||
|
||||
if got := headerValueCaseInsensitive(headers, "session_id"); got != "cache-1" {
|
||||
t.Fatalf("session_id = %s, want cache-1", got)
|
||||
}
|
||||
if _, ok := headers["session_id"]; !ok {
|
||||
t.Fatalf("expected lowercase session_id key, got %#v", headers)
|
||||
}
|
||||
if got := headers.Get("Conversation_id"); got != "cache-1" {
|
||||
t.Fatalf("Conversation_id = %s, want cache-1", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexWebsocketHeadersUsesCanonicalAccountHeader(t *testing.T) {
|
||||
auth := &cliproxyauth.Auth{Provider: "codex", Metadata: map[string]any{"account_id": "acct-1"}}
|
||||
|
||||
headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", nil)
|
||||
|
||||
if got := headerValueCaseInsensitive(headers, "ChatGPT-Account-ID"); got != "acct-1" {
|
||||
t.Fatalf("ChatGPT-Account-ID = %s, want acct-1", got)
|
||||
}
|
||||
values, ok := headers["ChatGPT-Account-ID"]
|
||||
if !ok {
|
||||
t.Fatalf("expected exact ChatGPT-Account-ID key, got %#v", headers)
|
||||
}
|
||||
if len(values) != 1 || values[0] != "acct-1" {
|
||||
t.Fatalf("ChatGPT-Account-ID values = %#v, want [acct-1]", values)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCodexResponsesWebsocketURLRequiresHTTPURL(t *testing.T) {
|
||||
if got, err := buildCodexResponsesWebsocketURL("https://example.com/backend/responses"); err != nil || got != "wss://example.com/backend/responses" {
|
||||
t.Fatalf("https URL = %q, %v; want wss URL", got, err)
|
||||
}
|
||||
if _, err := buildCodexResponsesWebsocketURL("ftp://example.com/responses"); err == nil {
|
||||
t.Fatalf("expected unsupported scheme error")
|
||||
}
|
||||
if _, err := buildCodexResponsesWebsocketURL("https:///responses"); err == nil {
|
||||
t.Fatalf("expected empty host error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexWebsocketErrorMarksConnectionLimitRetryable(t *testing.T) {
|
||||
err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"error":{"code":"websocket_connection_limit_reached","message":"too many websockets"},"headers":{"retry-after":"1"}}`))
|
||||
if !ok {
|
||||
t.Fatalf("expected websocket error")
|
||||
}
|
||||
status, ok := err.(interface{ StatusCode() int })
|
||||
if !ok || status.StatusCode() != http.StatusTooManyRequests {
|
||||
t.Fatalf("status = %#v, want 429", err)
|
||||
}
|
||||
retryable, ok := err.(interface{ RetryAfter() *time.Duration })
|
||||
if !ok || retryable.RetryAfter() == nil {
|
||||
t.Fatalf("expected retryable websocket connection limit error")
|
||||
}
|
||||
if got := *retryable.RetryAfter(); got != 0 {
|
||||
t.Fatalf("retryAfter = %v, want connection-limit fallback 0", got)
|
||||
}
|
||||
withHeaders, ok := err.(interface{ Headers() http.Header })
|
||||
if !ok || withHeaders.Headers().Get("retry-after") != "1" {
|
||||
t.Fatalf("headers = %#v, want retry-after", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexWebsocketErrorUsesUsageLimitRetryMetadata(t *testing.T) {
|
||||
err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"body":{"error":{"type":"usage_limit_reached","message":"usage limit reached","resets_in_seconds":7}}}`))
|
||||
if !ok {
|
||||
t.Fatalf("expected websocket error")
|
||||
}
|
||||
|
||||
retryable, ok := err.(interface{ RetryAfter() *time.Duration })
|
||||
if !ok || retryable.RetryAfter() == nil {
|
||||
t.Fatalf("expected retryable usage limit websocket error")
|
||||
}
|
||||
if got := *retryable.RetryAfter(); got != 7*time.Second {
|
||||
t.Fatalf("retryAfter = %v, want 7s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCodexWebsocketErrorPreservesWrappedBodyAndHeaders(t *testing.T) {
|
||||
err, ok := parseCodexWebsocketError([]byte(`{"type":"error","status":429,"body":{"error":{"code":"websocket_connection_limit_reached","type":"server_error","message":"too many websocket connections"}},"headers":{"x-request-id":"req-1"}}`))
|
||||
if !ok {
|
||||
t.Fatalf("expected websocket error")
|
||||
}
|
||||
|
||||
parsed := gjson.Parse(err.Error())
|
||||
if got := parsed.Get("status").Int(); got != http.StatusTooManyRequests {
|
||||
t.Fatalf("wrapped status = %d, want 429; payload=%s", got, err.Error())
|
||||
}
|
||||
if got := parsed.Get("body.error.code").String(); got != "websocket_connection_limit_reached" {
|
||||
t.Fatalf("wrapped body error code = %s, want websocket_connection_limit_reached; payload=%s", got, err.Error())
|
||||
}
|
||||
if got := parsed.Get("error.code").String(); got != "websocket_connection_limit_reached" {
|
||||
t.Fatalf("surface error code = %s, want websocket_connection_limit_reached; payload=%s", got, err.Error())
|
||||
}
|
||||
retryable, ok := err.(interface{ RetryAfter() *time.Duration })
|
||||
if !ok || retryable.RetryAfter() == nil {
|
||||
t.Fatalf("expected body.error.code websocket connection limit to be retryable")
|
||||
}
|
||||
withHeaders, ok := err.(interface{ Headers() http.Header })
|
||||
if !ok || withHeaders.Headers().Get("x-request-id") != "req-1" {
|
||||
t.Fatalf("headers = %#v, want x-request-id", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
|
||||
|
||||
@@ -411,19 +411,30 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
reporter.EnsurePublished(ctx)
|
||||
@@ -434,7 +445,10 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errRead}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errRead}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
@@ -442,12 +456,20 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: segments[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}(httpResp, append([]byte(nil), payload...), attemptModel)
|
||||
|
||||
|
||||
@@ -324,17 +324,28 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
|
||||
@@ -338,6 +338,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
requestPath := helps.PayloadRequestPath(opts)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
|
||||
}
|
||||
|
||||
action := getVertexAction(baseModel, false)
|
||||
@@ -459,6 +460,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
requestPath := helps.PayloadRequestPath(opts)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
|
||||
|
||||
action := getVertexAction(baseModel, false)
|
||||
if req.Metadata != nil {
|
||||
@@ -570,6 +572,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
requestPath := helps.PayloadRequestPath(opts)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
|
||||
|
||||
action := getVertexAction(baseModel, true)
|
||||
baseURL := vertexBaseURL(location)
|
||||
@@ -656,17 +659,28 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
@@ -700,6 +714,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
requestPath := helps.PayloadRequestPath(opts)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body = helps.StripVertexOpenAIResponsesToolCallIDs(body, from.String())
|
||||
|
||||
action := getVertexAction(baseModel, true)
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
@@ -786,17 +801,28 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: lines[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
@@ -818,6 +844,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
|
||||
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
|
||||
translatedReq = helps.StripVertexOpenAIResponsesToolCallIDs(translatedReq, from.String())
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
@@ -907,6 +934,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
|
||||
translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
|
||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
|
||||
translatedReq = helps.StripVertexOpenAIResponsesToolCallIDs(translatedReq, from.String())
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
type UsageReporter struct {
|
||||
provider string
|
||||
model string
|
||||
alias string
|
||||
authID string
|
||||
authIndex string
|
||||
authType string
|
||||
@@ -29,9 +30,14 @@ type UsageReporter struct {
|
||||
|
||||
func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
|
||||
apiKey := APIKeyFromContext(ctx)
|
||||
alias := usage.RequestedModelAliasFromContext(ctx)
|
||||
if alias == "" {
|
||||
alias = model
|
||||
}
|
||||
reporter := &UsageReporter{
|
||||
provider: provider,
|
||||
model: model,
|
||||
alias: strings.TrimSpace(alias),
|
||||
requestedAt: time.Now(),
|
||||
apiKey: apiKey,
|
||||
source: resolveUsageSource(auth, apiKey),
|
||||
@@ -139,6 +145,7 @@ func (r *UsageReporter) buildRecordForModel(model string, detail usage.Detail, f
|
||||
return usage.Record{
|
||||
Provider: r.provider,
|
||||
Model: model,
|
||||
Alias: r.alias,
|
||||
Source: r.source,
|
||||
APIKey: r.apiKey,
|
||||
AuthID: r.authID,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -107,6 +108,19 @@ func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageReporterBuildRecordIncludesRequestedModelAlias(t *testing.T) {
|
||||
ctx := usage.WithRequestedModelAlias(context.Background(), "client-gpt")
|
||||
reporter := NewUsageReporter(ctx, "openai", "gpt-5.4", nil)
|
||||
|
||||
record := reporter.buildRecord(usage.Detail{TotalTokens: 3}, false)
|
||||
if record.Model != "gpt-5.4" {
|
||||
t.Fatalf("model = %q, want %q", record.Model, "gpt-5.4")
|
||||
}
|
||||
if record.Alias != "client-gpt" {
|
||||
t.Fatalf("alias = %q, want %q", record.Alias, "client-gpt")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageReporterBuildAdditionalModelRecordSkipsZeroTokens(t *testing.T) {
|
||||
reporter := &UsageReporter{
|
||||
provider: "codex",
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package helps
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// StripVertexOpenAIResponsesToolCallIDs removes OpenAI Responses call IDs that
|
||||
// Vertex rejects in Gemini functionCall/functionResponse payloads.
|
||||
func StripVertexOpenAIResponsesToolCallIDs(payload []byte, sourceFormat string) []byte {
|
||||
if !strings.EqualFold(strings.TrimSpace(sourceFormat), "openai-response") {
|
||||
return payload
|
||||
}
|
||||
|
||||
contents := gjson.GetBytes(payload, "contents")
|
||||
if !contents.IsArray() {
|
||||
return payload
|
||||
}
|
||||
|
||||
out := payload
|
||||
for contentIndex, content := range contents.Array() {
|
||||
parts := content.Get("parts")
|
||||
if !parts.IsArray() {
|
||||
continue
|
||||
}
|
||||
for partIndex, part := range parts.Array() {
|
||||
if part.Get("functionCall.id").Exists() {
|
||||
if updated, errDelete := sjson.DeleteBytes(out, fmt.Sprintf("contents.%d.parts.%d.functionCall.id", contentIndex, partIndex)); errDelete == nil {
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
if part.Get("functionResponse.id").Exists() {
|
||||
if updated, errDelete := sjson.DeleteBytes(out, fmt.Sprintf("contents.%d.parts.%d.functionResponse.id", contentIndex, partIndex)); errDelete == nil {
|
||||
out = updated
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -290,17 +290,28 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: doneChunks[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
@@ -322,7 +333,17 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
out := body
|
||||
msgs := messages.Array()
|
||||
out, dropped, err := filterKimiEmptyAssistantMessages(body, msgs)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
if dropped > 0 {
|
||||
log.WithField("dropped_assistant_messages", dropped).Debug("kimi executor: dropped empty assistant messages")
|
||||
}
|
||||
|
||||
messages = gjson.GetBytes(out, "messages")
|
||||
msgs = messages.Array()
|
||||
pending := make([]string, 0)
|
||||
patched := 0
|
||||
patchedReasoning := 0
|
||||
@@ -340,7 +361,6 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
|
||||
}
|
||||
}
|
||||
|
||||
msgs := messages.Array()
|
||||
for msgIdx := range msgs {
|
||||
msg := msgs[msgIdx]
|
||||
role := strings.TrimSpace(msg.Get("role").String())
|
||||
@@ -428,6 +448,96 @@ func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func filterKimiEmptyAssistantMessages(body []byte, msgs []gjson.Result) ([]byte, int, error) {
|
||||
kept := make([]string, 0, len(msgs))
|
||||
dropped := 0
|
||||
for _, msg := range msgs {
|
||||
if shouldDropKimiAssistantMessage(msg) {
|
||||
dropped++
|
||||
continue
|
||||
}
|
||||
kept = append(kept, msg.Raw)
|
||||
}
|
||||
if dropped == 0 {
|
||||
return body, 0, nil
|
||||
}
|
||||
|
||||
rawMessages := []byte("[" + strings.Join(kept, ",") + "]")
|
||||
out, err := sjson.SetRawBytes(body, "messages", rawMessages)
|
||||
if err != nil {
|
||||
return body, 0, fmt.Errorf("kimi executor: failed to drop empty assistant messages: %w", err)
|
||||
}
|
||||
return out, dropped, nil
|
||||
}
|
||||
|
||||
func shouldDropKimiAssistantMessage(msg gjson.Result) bool {
|
||||
if strings.TrimSpace(msg.Get("role").String()) != "assistant" {
|
||||
return false
|
||||
}
|
||||
if hasKimiToolCalls(msg) || hasKimiLegacyFunctionCall(msg) || hasKimiAssistantReasoning(msg) {
|
||||
return false
|
||||
}
|
||||
return isKimiAssistantContentEmpty(msg.Get("content"))
|
||||
}
|
||||
|
||||
func hasKimiToolCalls(msg gjson.Result) bool {
|
||||
toolCalls := msg.Get("tool_calls")
|
||||
return toolCalls.Exists() && toolCalls.IsArray() && len(toolCalls.Array()) > 0
|
||||
}
|
||||
|
||||
func hasKimiLegacyFunctionCall(msg gjson.Result) bool {
|
||||
functionCall := msg.Get("function_call")
|
||||
if !functionCall.Exists() || functionCall.Type == gjson.Null {
|
||||
return false
|
||||
}
|
||||
if functionCall.IsObject() && strings.TrimSpace(functionCall.Raw) == "{}" {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(functionCall.Raw) != ""
|
||||
}
|
||||
|
||||
func hasKimiAssistantReasoning(msg gjson.Result) bool {
|
||||
reasoning := msg.Get("reasoning_content")
|
||||
return reasoning.Exists() && strings.TrimSpace(reasoning.String()) != ""
|
||||
}
|
||||
|
||||
func isKimiAssistantContentEmpty(content gjson.Result) bool {
|
||||
if !content.Exists() || content.Type == gjson.Null {
|
||||
return true
|
||||
}
|
||||
if content.Type == gjson.String {
|
||||
return strings.TrimSpace(content.String()) == ""
|
||||
}
|
||||
if !content.IsArray() {
|
||||
return false
|
||||
}
|
||||
for _, part := range content.Array() {
|
||||
if !isKimiAssistantContentPartEmpty(part) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isKimiAssistantContentPartEmpty(part gjson.Result) bool {
|
||||
if !part.Exists() || part.Type == gjson.Null {
|
||||
return true
|
||||
}
|
||||
if part.Type == gjson.String {
|
||||
return strings.TrimSpace(part.String()) == ""
|
||||
}
|
||||
if !part.IsObject() {
|
||||
return false
|
||||
}
|
||||
if text := part.Get("text"); text.Exists() {
|
||||
return strings.TrimSpace(text.String()) == ""
|
||||
}
|
||||
if strings.TrimSpace(part.Get("type").String()) == "text" {
|
||||
return true
|
||||
}
|
||||
return strings.TrimSpace(part.Raw) == "{}"
|
||||
}
|
||||
|
||||
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
|
||||
if hasLatest && strings.TrimSpace(latest) != "" {
|
||||
return latest
|
||||
|
||||
@@ -203,3 +203,70 @@ func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing
|
||||
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "r1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_DropsEmptyAssistantWithoutToolLink(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"user","content":"start"},
|
||||
{"role":"assistant","content":""},
|
||||
{"role":"assistant","content":" "},
|
||||
{"role":"assistant","content":"","tool_calls":null},
|
||||
{"role":"assistant","content":[{"type":"text","text":" "}]},
|
||||
{"role":"assistant"},
|
||||
{"role":"assistant","content":"keep"},
|
||||
{"role":"user","content":"next"}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(out, "messages").Array()
|
||||
if len(messages) != 3 {
|
||||
t.Fatalf("messages length = %d, want 3, raw = %s", len(messages), gjson.GetBytes(out, "messages").Raw)
|
||||
}
|
||||
if got := messages[0].Get("content").String(); got != "start" {
|
||||
t.Fatalf("messages.0.content = %q, want %q", got, "start")
|
||||
}
|
||||
if got := messages[1].Get("content").String(); got != "keep" {
|
||||
t.Fatalf("messages.1.content = %q, want %q", got, "keep")
|
||||
}
|
||||
if got := messages[2].Get("content").String(); got != "next" {
|
||||
t.Fatalf("messages.2.content = %q, want %q", got, "next")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeKimiToolMessageLinks_PreservesAssistantWithToolLinkOrReasoning(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"messages":[
|
||||
{"role":"assistant","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
|
||||
{"role":"assistant","content":"","function_call":{"name":"legacy_call","arguments":"{}"}},
|
||||
{"role":"assistant","content":"","reasoning_content":"thought"},
|
||||
{"role":"assistant","content":[{"type":"text","text":" visible "}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
out, err := normalizeKimiToolMessageLinks(body)
|
||||
if err != nil {
|
||||
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
|
||||
}
|
||||
|
||||
messages := gjson.GetBytes(out, "messages").Array()
|
||||
if len(messages) != 4 {
|
||||
t.Fatalf("messages length = %d, want 4, raw = %s", len(messages), gjson.GetBytes(out, "messages").Raw)
|
||||
}
|
||||
if !messages[0].Get("tool_calls").Exists() {
|
||||
t.Fatalf("messages.0.tool_calls should exist")
|
||||
}
|
||||
if !messages[1].Get("function_call").Exists() {
|
||||
t.Fatalf("messages.1.function_call should exist")
|
||||
}
|
||||
if got := messages[2].Get("reasoning_content").String(); got != "thought" {
|
||||
t.Fatalf("messages.2.reasoning_content = %q, want %q", got, "thought")
|
||||
}
|
||||
if got := messages[3].Get("content.0.text").String(); got != " visible " {
|
||||
t.Fatalf("messages.3.content.0.text = %q, want %q", got, " visible ")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,6 +96,12 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
requestPath := helps.PayloadRequestPath(opts)
|
||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
|
||||
@@ -105,11 +111,6 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
}
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + endpoint
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
|
||||
if err != nil {
|
||||
@@ -199,15 +200,16 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
requestPath := helps.PayloadRequestPath(opts)
|
||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
requestPath := helps.PayloadRequestPath(opts)
|
||||
translated = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel, requestPath)
|
||||
|
||||
// Request usage data in the final streaming chunk so that token statistics
|
||||
// are captured even when the upstream is an OpenAI-compatible provider.
|
||||
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
|
||||
@@ -281,32 +283,57 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
if detail, ok := helps.ParseOpenAIStreamUsage(line); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
if len(line) == 0 {
|
||||
trimmedLine := bytes.TrimSpace(line)
|
||||
if len(trimmedLine) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
if !bytes.HasPrefix(trimmedLine, []byte("data:")) {
|
||||
if bytes.HasPrefix(trimmedLine, []byte(":")) || bytes.HasPrefix(trimmedLine, []byte("event:")) ||
|
||||
bytes.HasPrefix(trimmedLine, []byte("id:")) || bytes.HasPrefix(trimmedLine, []byte("retry:")) {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(trimmedLine, []byte("{")) || bytes.HasPrefix(trimmedLine, []byte("[")) {
|
||||
streamErr := statusErr{code: http.StatusBadGateway, msg: string(trimmedLine)}
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, streamErr)
|
||||
reporter.PublishFailure(ctx)
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: streamErr}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
|
||||
// Pass through translator; it yields one or more chunks for the target schema.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
// OpenAI-compatible streams must use SSE data lines.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(trimmedLine), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
} else {
|
||||
// In case the upstream close the stream without a terminal [DONE] marker.
|
||||
// Feed a synthetic done marker through the translator so pending
|
||||
// response.completed events are still emitted exactly once.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("data: [DONE]"), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
// Ensure we record the request if no usage chunk was ever seen
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
@@ -56,3 +57,125 @@ func TestOpenAICompatExecutorCompactPassthrough(t *testing.T) {
|
||||
t.Fatalf("payload = %s", string(resp.Payload))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatExecutorPayloadOverrideWinsOverThinkingSuffix(t *testing.T) {
|
||||
var gotBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
gotBody = body
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"id":"chatcmpl_1","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{
|
||||
Payload: config.PayloadConfig{
|
||||
Override: []config.PayloadRule{
|
||||
{
|
||||
Models: []config.PayloadModelRule{
|
||||
{Name: "custom-openai", Protocol: "openai"},
|
||||
},
|
||||
Params: map[string]any{
|
||||
"reasoning_effort": "low",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL + "/v1",
|
||||
"api_key": "test",
|
||||
}}
|
||||
payload := []byte(`{"model":"custom-openai(high)","messages":[{"role":"user","content":"hi"}]}`)
|
||||
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "custom-openai(high)",
|
||||
Payload: payload,
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
Stream: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute error: %v", err)
|
||||
}
|
||||
if got := gjson.GetBytes(gotBody, "reasoning_effort").String(); got != "low" {
|
||||
t.Fatalf("reasoning_effort = %q, want %q; body=%s", got, "low", string(gotBody))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatExecutorStreamRejectsPlainJSONAfterBlankLines(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("\n\n: openrouter processing\n\nevent: error\n"))
|
||||
_, _ = w.Write([]byte(`{"error":{"message":"upstream failed","type":"server_error"}}` + "\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL + "/v1",
|
||||
"api_key": "test",
|
||||
}}
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "openrouter-model",
|
||||
Payload: []byte(`{"model":"openrouter-model","messages":[{"role":"user","content":"hi"}],"stream":true}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
Stream: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
|
||||
var gotErr error
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
gotErr = chunk.Err
|
||||
break
|
||||
}
|
||||
}
|
||||
if gotErr == nil {
|
||||
t.Fatalf("expected plain JSON stream error")
|
||||
}
|
||||
if status, ok := gotErr.(interface{ StatusCode() int }); !ok || status.StatusCode() != http.StatusBadGateway {
|
||||
t.Fatalf("stream error status = %v, want %d", gotErr, http.StatusBadGateway)
|
||||
}
|
||||
if !strings.Contains(gotErr.Error(), "upstream failed") {
|
||||
t.Fatalf("stream error = %v", gotErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatExecutorStreamSkipsKeepAliveUntilDataLine(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("\n\n: openrouter processing\n\nevent: ping\nid: 1\nretry: 1000\n"))
|
||||
_, _ = w.Write([]byte(`data: {"id":"chatcmpl_1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"hello"},"finish_reason":null}]}` + "\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
executor := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
|
||||
auth := &cliproxyauth.Auth{Attributes: map[string]string{
|
||||
"base_url": server.URL + "/v1",
|
||||
"api_key": "test",
|
||||
}}
|
||||
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "openrouter-model",
|
||||
Payload: []byte(`{"model":"openrouter-model","messages":[{"role":"user","content":"hi"}],"stream":true}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FromString("openai"),
|
||||
Stream: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStream error: %v", err)
|
||||
}
|
||||
|
||||
var got strings.Builder
|
||||
for chunk := range result.Chunks {
|
||||
if chunk.Err != nil {
|
||||
t.Fatalf("unexpected stream error: %v", chunk.Err)
|
||||
}
|
||||
got.Write(chunk.Payload)
|
||||
}
|
||||
if gjson.Get(got.String(), "choices.0.delta.content").String() != "hello" {
|
||||
t.Fatalf("stream payload = %s", got.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,10 +25,19 @@ type ConvertAnthropicResponseToOpenAIParams struct {
|
||||
CreatedAt int64
|
||||
ResponseID string
|
||||
FinishReason string
|
||||
Usage claudeUsageTokens
|
||||
// Tool calls accumulator for streaming
|
||||
ToolCallsAccumulator map[int]*ToolCallAccumulator
|
||||
}
|
||||
|
||||
type claudeUsageTokens struct {
|
||||
InputTokens int64
|
||||
OutputTokens int64
|
||||
CacheCreationInputTokens int64
|
||||
CacheReadInputTokens int64
|
||||
HasUsage bool
|
||||
}
|
||||
|
||||
// ToolCallAccumulator holds the state for accumulating tool call data
|
||||
type ToolCallAccumulator struct {
|
||||
ID string
|
||||
@@ -36,15 +45,30 @@ type ToolCallAccumulator struct {
|
||||
Arguments strings.Builder
|
||||
}
|
||||
|
||||
func calculateClaudeUsageTokens(usage gjson.Result) (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
|
||||
inputTokens := usage.Get("input_tokens").Int()
|
||||
completionTokens = usage.Get("output_tokens").Int()
|
||||
cachedTokens = usage.Get("cache_read_input_tokens").Int()
|
||||
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
|
||||
func (u *claudeUsageTokens) Merge(usage gjson.Result) {
|
||||
if !usage.Exists() {
|
||||
return
|
||||
}
|
||||
u.HasUsage = true
|
||||
if inputTokens := usage.Get("input_tokens"); inputTokens.Exists() {
|
||||
u.InputTokens = inputTokens.Int()
|
||||
}
|
||||
if outputTokens := usage.Get("output_tokens"); outputTokens.Exists() {
|
||||
u.OutputTokens = outputTokens.Int()
|
||||
}
|
||||
if cacheCreationInputTokens := usage.Get("cache_creation_input_tokens"); cacheCreationInputTokens.Exists() {
|
||||
u.CacheCreationInputTokens = cacheCreationInputTokens.Int()
|
||||
}
|
||||
if cacheReadInputTokens := usage.Get("cache_read_input_tokens"); cacheReadInputTokens.Exists() {
|
||||
u.CacheReadInputTokens = cacheReadInputTokens.Int()
|
||||
}
|
||||
}
|
||||
|
||||
promptTokens = inputTokens + cacheCreationInputTokens + cachedTokens
|
||||
func (u claudeUsageTokens) OpenAIUsage() (promptTokens, completionTokens, totalTokens, cachedTokens int64) {
|
||||
cachedTokens = u.CacheReadInputTokens
|
||||
promptTokens = u.InputTokens + u.CacheCreationInputTokens + cachedTokens
|
||||
completionTokens = u.OutputTokens
|
||||
totalTokens = promptTokens + completionTokens
|
||||
|
||||
return promptTokens, completionTokens, totalTokens, cachedTokens
|
||||
}
|
||||
|
||||
@@ -112,6 +136,7 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
if (*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator == nil {
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
|
||||
}
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(message.Get("usage"))
|
||||
}
|
||||
return [][]byte{template}
|
||||
|
||||
@@ -215,7 +240,8 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
|
||||
|
||||
// Handle usage information for token counts
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
|
||||
(*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.Merge(usage)
|
||||
promptTokens, completionTokens, totalTokens, cachedTokens := (*param).(*ConvertAnthropicResponseToOpenAIParams).Usage.OpenAIUsage()
|
||||
template, _ = sjson.SetBytes(template, "usage.prompt_tokens", promptTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.completion_tokens", completionTokens)
|
||||
template, _ = sjson.SetBytes(template, "usage.total_tokens", totalTokens)
|
||||
@@ -296,6 +322,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
var stopReason string
|
||||
var contentParts []string
|
||||
var reasoningParts []string
|
||||
usageTokens := claudeUsageTokens{}
|
||||
toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
|
||||
|
||||
for _, chunk := range chunks {
|
||||
@@ -309,6 +336,7 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
messageID = message.Get("id").String()
|
||||
model = message.Get("model").String()
|
||||
createdAt = time.Now().Unix()
|
||||
usageTokens.Merge(message.Get("usage"))
|
||||
}
|
||||
|
||||
case "content_block_start":
|
||||
@@ -371,15 +399,19 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
}
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
promptTokens, completionTokens, totalTokens, cachedTokens := calculateClaudeUsageTokens(usage)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
|
||||
usageTokens.Merge(usage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if usageTokens.HasUsage {
|
||||
promptTokens, completionTokens, totalTokens, cachedTokens := usageTokens.OpenAIUsage()
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens", promptTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.completion_tokens", completionTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.total_tokens", totalTokens)
|
||||
out, _ = sjson.SetBytes(out, "usage.prompt_tokens_details.cached_tokens", cachedTokens)
|
||||
}
|
||||
|
||||
// Set basic response fields including message ID, creation time, and model
|
||||
out, _ = sjson.SetBytes(out, "id", messageID)
|
||||
out, _ = sjson.SetBytes(out, "created", createdAt)
|
||||
|
||||
@@ -37,6 +37,44 @@ func TestConvertClaudeResponseToOpenAI_StreamUsageIncludesCachedTokens(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeResponseToOpenAI_StreamUsageMergesMessageStartUsage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var param any
|
||||
|
||||
ConvertClaudeResponseToOpenAI(
|
||||
ctx,
|
||||
"claude-opus-4-6",
|
||||
nil,
|
||||
nil,
|
||||
[]byte(`data: {"type":"message_start","message":{"id":"msg_123","model":"claude-opus-4-6","usage":{"input_tokens":13,"output_tokens":1,"cache_read_input_tokens":22000,"cache_creation_input_tokens":31}}}`),
|
||||
¶m,
|
||||
)
|
||||
out := ConvertClaudeResponseToOpenAI(
|
||||
ctx,
|
||||
"claude-opus-4-6",
|
||||
nil,
|
||||
nil,
|
||||
[]byte(`data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":4}}`),
|
||||
¶m,
|
||||
)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 chunk, got %d", len(out))
|
||||
}
|
||||
|
||||
if gotPromptTokens := gjson.GetBytes(out[0], "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
|
||||
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
|
||||
}
|
||||
if gotCompletionTokens := gjson.GetBytes(out[0], "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
|
||||
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
|
||||
}
|
||||
if gotTotalTokens := gjson.GetBytes(out[0], "usage.total_tokens").Int(); gotTotalTokens != 22048 {
|
||||
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
|
||||
}
|
||||
if gotCachedTokens := gjson.GetBytes(out[0], "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
|
||||
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *testing.T) {
|
||||
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\"}}\n" +
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"input_tokens\":13,\"output_tokens\":4,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}\n")
|
||||
@@ -56,3 +94,23 @@ func TestConvertClaudeResponseToOpenAINonStream_UsageIncludesCachedTokens(t *tes
|
||||
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeResponseToOpenAINonStream_UsageMergesMessageStartUsage(t *testing.T) {
|
||||
rawJSON := []byte("data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_123\",\"model\":\"claude-opus-4-6\",\"usage\":{\"input_tokens\":13,\"output_tokens\":1,\"cache_read_input_tokens\":22000,\"cache_creation_input_tokens\":31}}}\n" +
|
||||
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"end_turn\"},\"usage\":{\"output_tokens\":4}}\n")
|
||||
|
||||
out := ConvertClaudeResponseToOpenAINonStream(context.Background(), "", nil, nil, rawJSON, nil)
|
||||
|
||||
if gotPromptTokens := gjson.GetBytes(out, "usage.prompt_tokens").Int(); gotPromptTokens != 22044 {
|
||||
t.Fatalf("expected prompt_tokens %d, got %d", 22044, gotPromptTokens)
|
||||
}
|
||||
if gotCompletionTokens := gjson.GetBytes(out, "usage.completion_tokens").Int(); gotCompletionTokens != 4 {
|
||||
t.Fatalf("expected completion_tokens %d, got %d", 4, gotCompletionTokens)
|
||||
}
|
||||
if gotTotalTokens := gjson.GetBytes(out, "usage.total_tokens").Int(); gotTotalTokens != 22048 {
|
||||
t.Fatalf("expected total_tokens %d, got %d", 22048, gotTotalTokens)
|
||||
}
|
||||
if gotCachedTokens := gjson.GetBytes(out, "usage.prompt_tokens_details.cached_tokens").Int(); gotCachedTokens != 22000 {
|
||||
t.Fatalf("expected cached_tokens %d, got %d", 22000, gotCachedTokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,25 +339,21 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
})
|
||||
}
|
||||
|
||||
includedToolNames := map[string]struct{}{}
|
||||
toolNameMap := map[string]string{}
|
||||
|
||||
// tools mapping: parameters -> input_schema
|
||||
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
|
||||
toolsJSON := []byte("[]")
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
|
||||
if n := tool.Get("name"); n.Exists() {
|
||||
tJSON, _ = sjson.SetBytes(tJSON, "name", n.String())
|
||||
convertedTools := convertResponsesToolToClaudeTools(tool, toolNameMap)
|
||||
for _, tJSON := range convertedTools {
|
||||
toolName := gjson.GetBytes(tJSON, "name").String()
|
||||
if toolName != "" {
|
||||
includedToolNames[toolName] = struct{}{}
|
||||
}
|
||||
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
|
||||
}
|
||||
if d := tool.Get("description"); d.Exists() {
|
||||
tJSON, _ = sjson.SetBytes(tJSON, "description", d.String())
|
||||
}
|
||||
|
||||
if params := tool.Get("parameters"); params.Exists() {
|
||||
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
|
||||
} else if params = tool.Get("parametersJsonSchema"); params.Exists() {
|
||||
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", []byte(params.Raw))
|
||||
}
|
||||
|
||||
toolsJSON, _ = sjson.SetRawBytes(toolsJSON, "-1", tJSON)
|
||||
return true
|
||||
})
|
||||
if parsedTools := gjson.ParseBytes(toolsJSON); parsedTools.IsArray() && len(parsedTools.Array()) > 0 {
|
||||
@@ -375,14 +371,24 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
case "none":
|
||||
// Leave unset; implies no tools
|
||||
case "required":
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
|
||||
if len(includedToolNames) > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", []byte(`{"type":"any"}`))
|
||||
}
|
||||
}
|
||||
case gjson.JSON:
|
||||
if toolChoice.Get("type").String() == "function" {
|
||||
fn := toolChoice.Get("function.name").String()
|
||||
toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
|
||||
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
|
||||
if fn == "" {
|
||||
fn = toolChoice.Get("name").String()
|
||||
}
|
||||
if mappedName := toolNameMap[fn]; mappedName != "" {
|
||||
fn = mappedName
|
||||
}
|
||||
if _, ok := includedToolNames[fn]; ok {
|
||||
toolChoiceJSON := []byte(`{"name":"","type":"tool"}`)
|
||||
toolChoiceJSON, _ = sjson.SetBytes(toolChoiceJSON, "name", fn)
|
||||
out, _ = sjson.SetRawBytes(out, "tool_choice", toolChoiceJSON)
|
||||
}
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -391,3 +397,167 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func convertResponsesToolToClaudeTools(tool gjson.Result, toolNameMap map[string]string) [][]byte {
|
||||
toolType := strings.TrimSpace(tool.Get("type").String())
|
||||
switch toolType {
|
||||
case "", "function":
|
||||
if tJSON, ok := convertResponsesFunctionToolToClaude(tool, ""); ok {
|
||||
return [][]byte{tJSON}
|
||||
}
|
||||
case "namespace":
|
||||
return convertResponsesNamespaceToolToClaude(tool, toolNameMap)
|
||||
case "web_search":
|
||||
if tJSON, ok := convertResponsesWebSearchToolToClaude(tool); ok {
|
||||
if name := gjson.GetBytes(tJSON, "name").String(); name != "" {
|
||||
toolNameMap[name] = name
|
||||
}
|
||||
return [][]byte{tJSON}
|
||||
}
|
||||
default:
|
||||
if isUnsupportedOpenAIBuiltinToolType(toolType) {
|
||||
return nil
|
||||
}
|
||||
if tool.Get("name").String() != "" {
|
||||
return [][]byte{[]byte(tool.Raw)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertResponsesNamespaceToolToClaude(tool gjson.Result, toolNameMap map[string]string) [][]byte {
|
||||
namespaceName := strings.TrimSpace(tool.Get("name").String())
|
||||
children := tool.Get("tools")
|
||||
if !children.Exists() || !children.IsArray() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var out [][]byte
|
||||
children.ForEach(func(_, child gjson.Result) bool {
|
||||
childName := responsesToolName(child)
|
||||
qualifiedName := qualifyResponsesNamespaceToolName(namespaceName, childName)
|
||||
if tJSON, ok := convertResponsesFunctionToolToClaude(child, qualifiedName); ok {
|
||||
out = append(out, tJSON)
|
||||
toolNameMap[qualifiedName] = qualifiedName
|
||||
if childName != "" {
|
||||
toolNameMap[childName] = qualifiedName
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func convertResponsesFunctionToolToClaude(tool gjson.Result, overrideName string) ([]byte, bool) {
|
||||
name := strings.TrimSpace(overrideName)
|
||||
if name == "" {
|
||||
name = responsesToolName(tool)
|
||||
}
|
||||
if name == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
tJSON := []byte(`{"name":"","description":"","input_schema":{}}`)
|
||||
tJSON, _ = sjson.SetBytes(tJSON, "name", name)
|
||||
if d := responsesToolDescription(tool); d != "" {
|
||||
tJSON, _ = sjson.SetBytes(tJSON, "description", d)
|
||||
}
|
||||
tJSON, _ = sjson.SetRawBytes(tJSON, "input_schema", normalizeClaudeToolInputSchema(responsesToolParameters(tool)))
|
||||
return tJSON, true
|
||||
}
|
||||
|
||||
func convertResponsesWebSearchToolToClaude(tool gjson.Result) ([]byte, bool) {
|
||||
if externalWebAccess := tool.Get("external_web_access"); externalWebAccess.Exists() && !externalWebAccess.Bool() {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(tool.Get("name").String())
|
||||
if name == "" {
|
||||
name = "web_search"
|
||||
}
|
||||
tJSON := []byte(`{"type":"web_search_20250305","name":""}`)
|
||||
tJSON, _ = sjson.SetBytes(tJSON, "name", name)
|
||||
if maxUses := tool.Get("max_uses"); maxUses.Exists() {
|
||||
tJSON, _ = sjson.SetBytes(tJSON, "max_uses", maxUses.Int())
|
||||
}
|
||||
if allowedDomains := tool.Get("filters.allowed_domains"); allowedDomains.Exists() && allowedDomains.IsArray() {
|
||||
tJSON, _ = sjson.SetRawBytes(tJSON, "allowed_domains", []byte(allowedDomains.Raw))
|
||||
}
|
||||
if userLocation := tool.Get("user_location"); userLocation.Exists() && userLocation.IsObject() {
|
||||
tJSON, _ = sjson.SetRawBytes(tJSON, "user_location", []byte(userLocation.Raw))
|
||||
}
|
||||
return tJSON, true
|
||||
}
|
||||
|
||||
func responsesToolName(tool gjson.Result) string {
|
||||
if name := strings.TrimSpace(tool.Get("name").String()); name != "" {
|
||||
return name
|
||||
}
|
||||
return strings.TrimSpace(tool.Get("function.name").String())
|
||||
}
|
||||
|
||||
func responsesToolDescription(tool gjson.Result) string {
|
||||
if description := tool.Get("description").String(); description != "" {
|
||||
return description
|
||||
}
|
||||
return tool.Get("function.description").String()
|
||||
}
|
||||
|
||||
func responsesToolParameters(tool gjson.Result) gjson.Result {
|
||||
for _, path := range []string{
|
||||
"parameters",
|
||||
"parametersJsonSchema",
|
||||
"input_schema",
|
||||
"function.parameters",
|
||||
"function.parametersJsonSchema",
|
||||
} {
|
||||
if parameters := tool.Get(path); parameters.Exists() {
|
||||
return parameters
|
||||
}
|
||||
}
|
||||
return gjson.Result{}
|
||||
}
|
||||
|
||||
func normalizeClaudeToolInputSchema(parameters gjson.Result) []byte {
|
||||
raw := strings.TrimSpace(parameters.Raw)
|
||||
if raw == "" || raw == "null" || !gjson.Valid(raw) {
|
||||
return []byte(`{"type":"object","properties":{}}`)
|
||||
}
|
||||
result := gjson.Parse(raw)
|
||||
if !result.IsObject() {
|
||||
return []byte(`{"type":"object","properties":{}}`)
|
||||
}
|
||||
schema := []byte(raw)
|
||||
schemaType := result.Get("type").String()
|
||||
if schemaType == "" {
|
||||
schema, _ = sjson.SetBytes(schema, "type", "object")
|
||||
schemaType = "object"
|
||||
}
|
||||
if schemaType == "object" && !result.Get("properties").Exists() {
|
||||
schema, _ = sjson.SetRawBytes(schema, "properties", []byte(`{}`))
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func qualifyResponsesNamespaceToolName(namespaceName, childName string) string {
|
||||
childName = strings.TrimSpace(childName)
|
||||
if childName == "" || namespaceName == "" || strings.HasPrefix(childName, "mcp__") {
|
||||
return childName
|
||||
}
|
||||
if strings.HasPrefix(childName, namespaceName) {
|
||||
return childName
|
||||
}
|
||||
if strings.HasSuffix(namespaceName, "__") {
|
||||
return namespaceName + childName
|
||||
}
|
||||
return namespaceName + "__" + childName
|
||||
}
|
||||
|
||||
func isUnsupportedOpenAIBuiltinToolType(toolType string) bool {
|
||||
switch toolType {
|
||||
case "image_generation", "file_search", "code_interpreter", "computer_use_preview":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,8 @@ type claudeToResponsesState struct {
|
||||
FuncNames map[int]string // index -> function name
|
||||
FuncCallIDs map[int]string // index -> call id
|
||||
// message text aggregation
|
||||
TextBuf strings.Builder
|
||||
TextBuf strings.Builder
|
||||
CurrentTextBuf strings.Builder
|
||||
// reasoning state
|
||||
ReasoningActive bool
|
||||
ReasoningItemID string
|
||||
@@ -80,6 +81,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
st.CreatedAt = time.Now().Unix()
|
||||
// Reset per-message aggregation state
|
||||
st.TextBuf.Reset()
|
||||
st.CurrentTextBuf.Reset()
|
||||
st.ReasoningBuf.Reset()
|
||||
st.ReasoningActive = false
|
||||
st.InTextBlock = false
|
||||
@@ -128,6 +130,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
if typ == "text" {
|
||||
// open message item + content part
|
||||
st.InTextBlock = true
|
||||
st.CurrentTextBuf.Reset()
|
||||
st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
|
||||
item := []byte(`{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`)
|
||||
item, _ = sjson.SetBytes(item, "sequence_number", nextSeq())
|
||||
@@ -189,6 +192,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
out = append(out, emitEvent("response.output_text.delta", msg))
|
||||
// aggregate text for response.output
|
||||
st.TextBuf.WriteString(t.String())
|
||||
st.CurrentTextBuf.WriteString(t.String())
|
||||
}
|
||||
} else if dt == "input_json_delta" {
|
||||
idx := int(root.Get("index").Int())
|
||||
@@ -220,17 +224,21 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
case "content_block_stop":
|
||||
idx := int(root.Get("index").Int())
|
||||
if st.InTextBlock {
|
||||
fullText := st.CurrentTextBuf.String()
|
||||
done := []byte(`{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`)
|
||||
done, _ = sjson.SetBytes(done, "sequence_number", nextSeq())
|
||||
done, _ = sjson.SetBytes(done, "item_id", st.CurrentMsgID)
|
||||
done, _ = sjson.SetBytes(done, "text", fullText)
|
||||
out = append(out, emitEvent("response.output_text.done", done))
|
||||
partDone := []byte(`{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`)
|
||||
partDone, _ = sjson.SetBytes(partDone, "sequence_number", nextSeq())
|
||||
partDone, _ = sjson.SetBytes(partDone, "item_id", st.CurrentMsgID)
|
||||
partDone, _ = sjson.SetBytes(partDone, "part.text", fullText)
|
||||
out = append(out, emitEvent("response.content_part.done", partDone))
|
||||
final := []byte(`{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`)
|
||||
final, _ = sjson.SetBytes(final, "sequence_number", nextSeq())
|
||||
final, _ = sjson.SetBytes(final, "item.id", st.CurrentMsgID)
|
||||
final, _ = sjson.SetBytes(final, "item.content.0.text", fullText)
|
||||
out = append(out, emitEvent("response.output_item.done", final))
|
||||
st.InTextBlock = false
|
||||
} else if st.InFuncBlock {
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertGeminiCLIRequestToCodex_PreservesSchemaPropertyNamedType(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"request": {
|
||||
"tools": [
|
||||
{
|
||||
"functionDeclarations": [
|
||||
{
|
||||
"name": "ask_user",
|
||||
"description": "Ask the user one or more questions.",
|
||||
"parametersJsonSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"questions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"header": {
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"default": "choice",
|
||||
"description": "Question type.",
|
||||
"enum": [
|
||||
"choice",
|
||||
"text",
|
||||
"yesno"
|
||||
],
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"question",
|
||||
"header",
|
||||
"type"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"questions"
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
|
||||
out := ConvertGeminiCLIRequestToCodex("gpt-5.2", input, true)
|
||||
tool := gjson.GetBytes(out, "tools.0")
|
||||
if got := tool.Get("type").String(); got != "function" {
|
||||
t.Fatalf("expected tool type %q, got %q; output=%s", "function", got, string(out))
|
||||
}
|
||||
|
||||
typeProperty := tool.Get("parameters.properties.questions.items.properties.type")
|
||||
if !typeProperty.IsObject() {
|
||||
t.Fatalf("expected schema property named type to stay an object; output=%s", string(out))
|
||||
}
|
||||
if got := typeProperty.Get("type").String(); got != "string" {
|
||||
t.Fatalf("expected schema property type %q, got %q; output=%s", "string", got, string(out))
|
||||
}
|
||||
if got := typeProperty.Get("default").String(); got != "choice" {
|
||||
t.Fatalf("expected default %q, got %q; output=%s", "choice", got, string(out))
|
||||
}
|
||||
if got := typeProperty.Get("enum.2").String(); got != "yesno" {
|
||||
t.Fatalf("expected enum value %q, got %q; output=%s", "yesno", got, string(out))
|
||||
}
|
||||
}
|
||||
@@ -284,7 +284,11 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
util.Walk(toolsResult, "", "type", &pathsToLower)
|
||||
for _, p := range pathsToLower {
|
||||
fullPath := fmt.Sprintf("tools.%s", p)
|
||||
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(gjson.GetBytes(out, fullPath).String()))
|
||||
typeValue := gjson.GetBytes(out, fullPath)
|
||||
if typeValue.Type != gjson.String {
|
||||
continue
|
||||
}
|
||||
out, _ = sjson.SetBytes(out, fullPath, strings.ToLower(typeValue.String()))
|
||||
}
|
||||
|
||||
return out
|
||||
|
||||
@@ -121,13 +121,13 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
case "tool":
|
||||
// Handle tool response messages as top-level function_call_output objects
|
||||
toolCallID := m.Get("tool_call_id").String()
|
||||
content := m.Get("content").String()
|
||||
content := m.Get("content")
|
||||
|
||||
// Create function_call_output object
|
||||
funcOutput := []byte(`{}`)
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "type", "function_call_output")
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "call_id", toolCallID)
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "output", content)
|
||||
funcOutput = setToolCallOutputContent(funcOutput, content)
|
||||
out, _ = sjson.SetRawBytes(out, "input.-1", funcOutput)
|
||||
|
||||
default:
|
||||
@@ -359,6 +359,91 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
return out
|
||||
}
|
||||
|
||||
func setToolCallOutputContent(funcOutput []byte, content gjson.Result) []byte {
|
||||
switch {
|
||||
case content.Type == gjson.String:
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "output", content.String())
|
||||
case content.IsArray():
|
||||
output := []byte(`[]`)
|
||||
for _, item := range content.Array() {
|
||||
output = appendToolOutputContentPart(output, item)
|
||||
}
|
||||
funcOutput, _ = sjson.SetRawBytes(funcOutput, "output", output)
|
||||
default:
|
||||
fallbackOutput := content.Raw
|
||||
if fallbackOutput == "" {
|
||||
fallbackOutput = content.String()
|
||||
}
|
||||
funcOutput, _ = sjson.SetBytes(funcOutput, "output", fallbackOutput)
|
||||
}
|
||||
return funcOutput
|
||||
}
|
||||
|
||||
func appendToolOutputContentPart(output []byte, item gjson.Result) []byte {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", "input_text")
|
||||
part, _ = sjson.SetBytes(part, "text", item.Get("text").String())
|
||||
output, _ = sjson.SetRawBytes(output, "-1", part)
|
||||
case "image_url":
|
||||
imageURL := item.Get("image_url.url").String()
|
||||
fileID := item.Get("image_url.file_id").String()
|
||||
if imageURL == "" && fileID == "" {
|
||||
return appendToolOutputFallbackPart(output, item)
|
||||
}
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", "input_image")
|
||||
if imageURL != "" {
|
||||
part, _ = sjson.SetBytes(part, "image_url", imageURL)
|
||||
}
|
||||
if fileID != "" {
|
||||
part, _ = sjson.SetBytes(part, "file_id", fileID)
|
||||
}
|
||||
if detail := item.Get("image_url.detail").String(); detail != "" {
|
||||
part, _ = sjson.SetBytes(part, "detail", detail)
|
||||
}
|
||||
output, _ = sjson.SetRawBytes(output, "-1", part)
|
||||
case "file":
|
||||
fileID := item.Get("file.file_id").String()
|
||||
fileData := item.Get("file.file_data").String()
|
||||
fileURL := item.Get("file.file_url").String()
|
||||
if fileID == "" && fileData == "" && fileURL == "" {
|
||||
return appendToolOutputFallbackPart(output, item)
|
||||
}
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", "input_file")
|
||||
if fileID != "" {
|
||||
part, _ = sjson.SetBytes(part, "file_id", fileID)
|
||||
}
|
||||
if fileData != "" {
|
||||
part, _ = sjson.SetBytes(part, "file_data", fileData)
|
||||
}
|
||||
if fileURL != "" {
|
||||
part, _ = sjson.SetBytes(part, "file_url", fileURL)
|
||||
}
|
||||
if filename := item.Get("file.filename").String(); filename != "" {
|
||||
part, _ = sjson.SetBytes(part, "filename", filename)
|
||||
}
|
||||
output, _ = sjson.SetRawBytes(output, "-1", part)
|
||||
default:
|
||||
output = appendToolOutputFallbackPart(output, item)
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
func appendToolOutputFallbackPart(output []byte, item gjson.Result) []byte {
|
||||
text := item.Raw
|
||||
if text == "" {
|
||||
text = item.String()
|
||||
}
|
||||
part := []byte(`{}`)
|
||||
part, _ = sjson.SetBytes(part, "type", "input_text")
|
||||
part, _ = sjson.SetBytes(part, "text", text)
|
||||
output, _ = sjson.SetRawBytes(output, "-1", part)
|
||||
return output
|
||||
}
|
||||
|
||||
// shortenNameIfNeeded applies the simple shortening rule for a single name.
|
||||
// If the name length exceeds 64, it will try to preserve the "mcp__" prefix and last segment.
|
||||
// Otherwise it truncates to 64 characters.
|
||||
|
||||
@@ -176,6 +176,182 @@ func TestToolCallWithContent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolCallOutputWithMultimodalContent(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Show me the generated result."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_output_1",
|
||||
"type": "function",
|
||||
"function": {"name": "render_output", "arguments": "{}"}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_output_1",
|
||||
"content": [
|
||||
{"type":"text","text":"Rendered result attached."},
|
||||
{"type":"image_url","image_url":{"url":"https://example.com/generated.png","detail":"high"}},
|
||||
{"type":"image_url","image_url":{"file_id":"file-img-123"}},
|
||||
{"type":"file","file":{"file_id":"file-doc-123","filename":"doc.pdf"}},
|
||||
{"type":"file","file":{"file_data":"SGVsbG8=","filename":"inline.txt"}},
|
||||
{"type":"file","file":{"file_url":"https://example.com/report.pdf","filename":"report.pdf"}}
|
||||
]
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "render_output", "description": "Render output", "parameters": {"type": "object", "properties": {}}}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
output := gjson.Get(result, "input.2.output")
|
||||
if !output.IsArray() {
|
||||
t.Fatalf("expected tool output to be an array, got: %s", output.Raw)
|
||||
}
|
||||
|
||||
parts := output.Array()
|
||||
if len(parts) != 6 {
|
||||
t.Fatalf("expected 6 output parts, got %d: %s", len(parts), output.Raw)
|
||||
}
|
||||
if parts[0].Get("type").String() != "input_text" || parts[0].Get("text").String() != "Rendered result attached." {
|
||||
t.Fatalf("part 0: expected input_text with rendered text, got %s", parts[0].Raw)
|
||||
}
|
||||
if parts[1].Get("type").String() != "input_image" {
|
||||
t.Fatalf("part 1: expected input_image, got %s", parts[1].Raw)
|
||||
}
|
||||
if parts[1].Get("image_url").String() != "https://example.com/generated.png" {
|
||||
t.Errorf("part 1: unexpected image_url %s", parts[1].Get("image_url").String())
|
||||
}
|
||||
if parts[1].Get("detail").String() != "high" {
|
||||
t.Errorf("part 1: unexpected detail %s", parts[1].Get("detail").String())
|
||||
}
|
||||
if parts[2].Get("type").String() != "input_image" || parts[2].Get("file_id").String() != "file-img-123" {
|
||||
t.Fatalf("part 2: expected file_id-backed input_image, got %s", parts[2].Raw)
|
||||
}
|
||||
if parts[3].Get("type").String() != "input_file" || parts[3].Get("file_id").String() != "file-doc-123" {
|
||||
t.Fatalf("part 3: expected file_id-backed input_file, got %s", parts[3].Raw)
|
||||
}
|
||||
if parts[3].Get("filename").String() != "doc.pdf" {
|
||||
t.Errorf("part 3: unexpected filename %s", parts[3].Get("filename").String())
|
||||
}
|
||||
if parts[4].Get("type").String() != "input_file" || parts[4].Get("file_data").String() != "SGVsbG8=" {
|
||||
t.Fatalf("part 4: expected file_data-backed input_file, got %s", parts[4].Raw)
|
||||
}
|
||||
if parts[5].Get("type").String() != "input_file" || parts[5].Get("file_url").String() != "https://example.com/report.pdf" {
|
||||
t.Fatalf("part 5: expected file_url-backed input_file, got %s", parts[5].Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolCallOutputFallsBackForInvalidStructuredParts(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Check tool output."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{"id": "call_invalid_parts", "type": "function", "function": {"name": "inspect", "arguments": "{}"}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_invalid_parts",
|
||||
"content": [
|
||||
{"type":"image_url","image_url":{"detail":"low"}},
|
||||
{"type":"file","file":{"filename":"orphan.txt"}},
|
||||
{"type":"unknown_type","foo":"bar","nested":{"a":1}}
|
||||
]
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{"type": "function", "function": {"name": "inspect", "description": "Inspect", "parameters": {"type": "object", "properties": {}}}}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
parts := gjson.Get(result, "input.2.output").Array()
|
||||
if len(parts) != 3 {
|
||||
t.Fatalf("expected 3 output parts, got %d: %s", len(parts), gjson.Get(result, "input.2.output").Raw)
|
||||
}
|
||||
|
||||
expectedFallbacks := []string{
|
||||
`{"type":"image_url","image_url":{"detail":"low"}}`,
|
||||
`{"type":"file","file":{"filename":"orphan.txt"}}`,
|
||||
`{"type":"unknown_type","foo":"bar","nested":{"a":1}}`,
|
||||
}
|
||||
for i, expectedFallback := range expectedFallbacks {
|
||||
if parts[i].Get("type").String() != "input_text" {
|
||||
t.Fatalf("part %d: expected input_text fallback, got %s", i, parts[i].Raw)
|
||||
}
|
||||
if parts[i].Get("text").String() != expectedFallback {
|
||||
t.Fatalf("part %d: expected fallback %s, got %s", i, expectedFallback, parts[i].Get("text").String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolCallOutputWithNonStringJSONContent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
expectedOutput string
|
||||
}{
|
||||
{name: "null", content: `null`, expectedOutput: `null`},
|
||||
{name: "object", content: `{"status":"ok","count":2}`, expectedOutput: `{"status":"ok","count":2}`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
input := []byte(`{
|
||||
"model": "gpt-4o",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Check tool output."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{"id": "call_json", "type": "function", "function": {"name": "inspect", "arguments": "{}"}}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_json",
|
||||
"content": ` + tt.content + `
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{"type": "function", "function": {"name": "inspect", "description": "Inspect", "parameters": {"type": "object", "properties": {}}}}
|
||||
]
|
||||
}`)
|
||||
|
||||
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
|
||||
result := string(out)
|
||||
|
||||
output := gjson.Get(result, "input.2.output")
|
||||
if !output.Exists() {
|
||||
t.Fatalf("expected output field to exist: %s", gjson.Get(result, "input.2").Raw)
|
||||
}
|
||||
if output.String() != tt.expectedOutput {
|
||||
t.Fatalf("expected output %s, got %s", tt.expectedOutput, output.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Parallel tool calls: assistant invokes 3 tools at once, all call_ids
|
||||
// and outputs must be translated and paired correctly.
|
||||
func TestMultipleToolCalls(t *testing.T) {
|
||||
|
||||
@@ -236,7 +236,7 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
|
||||
// Handle function name
|
||||
if function := toolCall.Get("function"); function.Exists() {
|
||||
if name := function.Get("name"); name.Exists() {
|
||||
if name := function.Get("name"); name.Exists() && name.String() != "" {
|
||||
accumulator.Name = util.MapToolName(param.ToolNameMap, name.String())
|
||||
|
||||
stopThinkingContentBlock(param, &results)
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConvertOpenAIResponseToClaude_StreamIgnoresNullToolNameDelta(t *testing.T) {
|
||||
originalRequest := []byte(`{"stream":true}`)
|
||||
var param any
|
||||
|
||||
firstChunks := ConvertOpenAIResponseToClaude(
|
||||
context.Background(),
|
||||
"test-model",
|
||||
originalRequest,
|
||||
nil,
|
||||
[]byte(`data: {"id":"chatcmpl_1","model":"test-model","created":1,"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"read_file","arguments":""}}]},"finish_reason":null}]}`),
|
||||
¶m,
|
||||
)
|
||||
firstOutput := bytes.Join(firstChunks, nil)
|
||||
if !bytes.Contains(firstOutput, []byte(`"name":"read_file"`)) {
|
||||
t.Fatalf("expected first chunk to start read_file tool block, got %s", string(firstOutput))
|
||||
}
|
||||
|
||||
secondChunks := ConvertOpenAIResponseToClaude(
|
||||
context.Background(),
|
||||
"test-model",
|
||||
originalRequest,
|
||||
nil,
|
||||
[]byte(`data: {"id":"chatcmpl_1","model":"test-model","created":1,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"name":null,"arguments":"{\"path\":\"/tmp/a\"}"}}]},"finish_reason":null}]}`),
|
||||
¶m,
|
||||
)
|
||||
secondOutput := bytes.Join(secondChunks, nil)
|
||||
if bytes.Contains(secondOutput, []byte(`content_block_start`)) {
|
||||
t.Fatalf("did not expect null tool name delta to start a new content block, got %s", string(secondOutput))
|
||||
}
|
||||
if bytes.Contains(secondOutput, []byte(`"name":""`)) {
|
||||
t.Fatalf("did not expect null tool name delta to emit an empty tool name, got %s", string(secondOutput))
|
||||
}
|
||||
}
|
||||
+4
-18
@@ -18,7 +18,6 @@ const (
|
||||
tabAuthFiles
|
||||
tabAPIKeys
|
||||
tabOAuth
|
||||
tabUsage
|
||||
tabLogs
|
||||
)
|
||||
|
||||
@@ -40,7 +39,6 @@ type App struct {
|
||||
auth authTabModel
|
||||
keys keysTabModel
|
||||
oauth oauthTabModel
|
||||
usage usageTabModel
|
||||
logs logsTabModel
|
||||
|
||||
client *Client
|
||||
@@ -50,7 +48,7 @@ type App struct {
|
||||
ready bool
|
||||
|
||||
// Track which tabs have been initialized (fetched data)
|
||||
initialized [7]bool
|
||||
initialized [6]bool
|
||||
}
|
||||
|
||||
type authConnectMsg struct {
|
||||
@@ -81,10 +79,9 @@ func NewApp(port int, secretKey string, hook *LogHook) App {
|
||||
auth: newAuthTabModel(client),
|
||||
keys: newKeysTabModel(client),
|
||||
oauth: newOAuthTabModel(client),
|
||||
usage: newUsageTabModel(client),
|
||||
logs: newLogsTabModel(client, hook),
|
||||
client: client,
|
||||
initialized: [7]bool{
|
||||
initialized: [6]bool{
|
||||
tabDashboard: true,
|
||||
tabLogs: true,
|
||||
},
|
||||
@@ -92,7 +89,7 @@ func NewApp(port int, secretKey string, hook *LogHook) App {
|
||||
|
||||
app.refreshTabs()
|
||||
if authRequired {
|
||||
app.initialized = [7]bool{}
|
||||
app.initialized = [6]bool{}
|
||||
}
|
||||
app.setAuthInputPrompt()
|
||||
return app
|
||||
@@ -128,7 +125,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
a.auth.SetSize(contentW, contentH)
|
||||
a.keys.SetSize(contentW, contentH)
|
||||
a.oauth.SetSize(contentW, contentH)
|
||||
a.usage.SetSize(contentW, contentH)
|
||||
a.logs.SetSize(contentW, contentH)
|
||||
return a, nil
|
||||
|
||||
@@ -142,7 +138,7 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
a.authenticated = true
|
||||
a.logsEnabled = a.standalone || isLogsEnabledFromConfig(msg.cfg)
|
||||
a.refreshTabs()
|
||||
a.initialized = [7]bool{}
|
||||
a.initialized = [6]bool{}
|
||||
a.initialized[tabDashboard] = true
|
||||
cmds := []tea.Cmd{a.dashboard.Init()}
|
||||
if a.logsEnabled {
|
||||
@@ -258,8 +254,6 @@ func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
a.keys, cmd = a.keys.Update(msg)
|
||||
case tabOAuth:
|
||||
a.oauth, cmd = a.oauth.Update(msg)
|
||||
case tabUsage:
|
||||
a.usage, cmd = a.usage.Update(msg)
|
||||
case tabLogs:
|
||||
a.logs, cmd = a.logs.Update(msg)
|
||||
}
|
||||
@@ -322,8 +316,6 @@ func (a *App) initTabIfNeeded(_ int) tea.Cmd {
|
||||
return a.keys.Init()
|
||||
case tabOAuth:
|
||||
return a.oauth.Init()
|
||||
case tabUsage:
|
||||
return a.usage.Init()
|
||||
case tabLogs:
|
||||
if !a.logsEnabled {
|
||||
return nil
|
||||
@@ -360,8 +352,6 @@ func (a App) View() string {
|
||||
sb.WriteString(a.keys.View())
|
||||
case tabOAuth:
|
||||
sb.WriteString(a.oauth.View())
|
||||
case tabUsage:
|
||||
sb.WriteString(a.usage.View())
|
||||
case tabLogs:
|
||||
if a.logsEnabled {
|
||||
sb.WriteString(a.logs.View())
|
||||
@@ -529,10 +519,6 @@ func (a App) broadcastToAllTabs(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
a.usage, cmd = a.usage.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
a.logs, cmd = a.logs.Update(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
|
||||
@@ -140,11 +140,6 @@ func (c *Client) PutConfigYAML(yamlContent string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// GetUsage fetches usage statistics.
|
||||
func (c *Client) GetUsage() (map[string]any, error) {
|
||||
return c.getJSON("/v0/management/usage")
|
||||
}
|
||||
|
||||
// GetAuthFiles lists auth credential files.
|
||||
// API returns {"files": [...]}.
|
||||
func (c *Client) GetAuthFiles() ([]map[string]any, error) {
|
||||
|
||||
@@ -22,14 +22,12 @@ type dashboardModel struct {
|
||||
|
||||
// Cached data for re-rendering on locale change
|
||||
lastConfig map[string]any
|
||||
lastUsage map[string]any
|
||||
lastAuthFiles []map[string]any
|
||||
lastAPIKeys []string
|
||||
}
|
||||
|
||||
type dashboardDataMsg struct {
|
||||
config map[string]any
|
||||
usage map[string]any
|
||||
authFiles []map[string]any
|
||||
apiKeys []string
|
||||
err error
|
||||
@@ -47,25 +45,24 @@ func (m dashboardModel) Init() tea.Cmd {
|
||||
|
||||
func (m dashboardModel) fetchData() tea.Msg {
|
||||
cfg, cfgErr := m.client.GetConfig()
|
||||
usage, usageErr := m.client.GetUsage()
|
||||
authFiles, authErr := m.client.GetAuthFiles()
|
||||
apiKeys, keysErr := m.client.GetAPIKeys()
|
||||
|
||||
var err error
|
||||
for _, e := range []error{cfgErr, usageErr, authErr, keysErr} {
|
||||
for _, e := range []error{cfgErr, authErr, keysErr} {
|
||||
if e != nil {
|
||||
err = e
|
||||
break
|
||||
}
|
||||
}
|
||||
return dashboardDataMsg{config: cfg, usage: usage, authFiles: authFiles, apiKeys: apiKeys, err: err}
|
||||
return dashboardDataMsg{config: cfg, authFiles: authFiles, apiKeys: apiKeys, err: err}
|
||||
}
|
||||
|
||||
func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case localeChangedMsg:
|
||||
// Re-render immediately with cached data using new locale
|
||||
m.content = m.renderDashboard(m.lastConfig, m.lastUsage, m.lastAuthFiles, m.lastAPIKeys)
|
||||
m.content = m.renderDashboard(m.lastConfig, m.lastAuthFiles, m.lastAPIKeys)
|
||||
m.viewport.SetContent(m.content)
|
||||
// Also fetch fresh data in background
|
||||
return m, m.fetchData
|
||||
@@ -78,11 +75,10 @@ func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
|
||||
m.err = nil
|
||||
// Cache data for locale switching
|
||||
m.lastConfig = msg.config
|
||||
m.lastUsage = msg.usage
|
||||
m.lastAuthFiles = msg.authFiles
|
||||
m.lastAPIKeys = msg.apiKeys
|
||||
|
||||
m.content = m.renderDashboard(msg.config, msg.usage, msg.authFiles, msg.apiKeys)
|
||||
m.content = m.renderDashboard(msg.config, msg.authFiles, msg.apiKeys)
|
||||
}
|
||||
m.viewport.SetContent(m.content)
|
||||
return m, nil
|
||||
@@ -121,7 +117,7 @@ func (m dashboardModel) View() string {
|
||||
return m.viewport.View()
|
||||
}
|
||||
|
||||
func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []map[string]any, apiKeys []string) string {
|
||||
func (m dashboardModel) renderDashboard(cfg map[string]any, authFiles []map[string]any, apiKeys []string) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render(T("dashboard_title")))
|
||||
@@ -138,7 +134,7 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
|
||||
// ━━━ Stats Cards ━━━
|
||||
cardWidth := 25
|
||||
if m.width > 0 {
|
||||
cardWidth = (m.width - 6) / 4
|
||||
cardWidth = (m.width - 2) / 2
|
||||
if cardWidth < 18 {
|
||||
cardWidth = 18
|
||||
}
|
||||
@@ -173,34 +169,7 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (%d %s)", T("auth_files_label"), activeAuth, T("active_suffix"))),
|
||||
))
|
||||
|
||||
// Card 3: Total Requests
|
||||
totalReqs := int64(0)
|
||||
successReqs := int64(0)
|
||||
failedReqs := int64(0)
|
||||
totalTokens := int64(0)
|
||||
if usage != nil {
|
||||
if usageMap, ok := usage["usage"].(map[string]any); ok {
|
||||
totalReqs = int64(getFloat(usageMap, "total_requests"))
|
||||
successReqs = int64(getFloat(usageMap, "success_count"))
|
||||
failedReqs = int64(getFloat(usageMap, "failure_count"))
|
||||
totalTokens = int64(getFloat(usageMap, "total_tokens"))
|
||||
}
|
||||
}
|
||||
card3 := cardStyle.Render(fmt.Sprintf(
|
||||
"%s\n%s",
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(fmt.Sprintf("📈 %d", totalReqs)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (✓%d ✗%d)", T("total_requests"), successReqs, failedReqs)),
|
||||
))
|
||||
|
||||
// Card 4: Total Tokens
|
||||
tokenStr := formatLargeNumber(totalTokens)
|
||||
card4 := cardStyle.Render(fmt.Sprintf(
|
||||
"%s\n%s",
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("🔤 %s", tokenStr)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("total_tokens")),
|
||||
))
|
||||
|
||||
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
|
||||
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// ━━━ Current Config ━━━
|
||||
@@ -258,38 +227,6 @@ func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []m
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// ━━━ Per-Model Usage ━━━
|
||||
if usage != nil {
|
||||
if usageMap, ok := usage["usage"].(map[string]any); ok {
|
||||
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("model_stats")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
|
||||
sb.WriteString("\n")
|
||||
|
||||
header := fmt.Sprintf(" %-40s %10s %12s", T("model"), T("requests"), T("tokens"))
|
||||
sb.WriteString(tableHeaderStyle.Render(header))
|
||||
sb.WriteString("\n")
|
||||
|
||||
for _, apiSnap := range apis {
|
||||
if apiMap, ok := apiSnap.(map[string]any); ok {
|
||||
if models, ok := apiMap["models"].(map[string]any); ok {
|
||||
for model, v := range models {
|
||||
if stats, ok := v.(map[string]any); ok {
|
||||
reqs := int64(getFloat(stats, "total_requests"))
|
||||
toks := int64(getFloat(stats, "total_tokens"))
|
||||
row := fmt.Sprintf(" %-40s %10d %12s", truncate(model, 40), reqs, formatLargeNumber(toks))
|
||||
sb.WriteString(tableCellStyle.Render(row))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
|
||||
@@ -50,8 +50,8 @@ var locales = map[string]map[string]string{
|
||||
// ──────────────────────────────────────────
|
||||
// Tab names
|
||||
// ──────────────────────────────────────────
|
||||
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "使用统计", "日志"}
|
||||
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Usage", "Logs"}
|
||||
var zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "日志"}
|
||||
var enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Logs"}
|
||||
|
||||
// TabNames returns tab names in the current locale.
|
||||
func TabNames() []string {
|
||||
|
||||
@@ -1,418 +0,0 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/viewport"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// usageTabModel displays usage statistics with charts and breakdowns.
|
||||
type usageTabModel struct {
|
||||
client *Client
|
||||
viewport viewport.Model
|
||||
usage map[string]any
|
||||
err error
|
||||
width int
|
||||
height int
|
||||
ready bool
|
||||
}
|
||||
|
||||
type usageDataMsg struct {
|
||||
usage map[string]any
|
||||
err error
|
||||
}
|
||||
|
||||
func newUsageTabModel(client *Client) usageTabModel {
|
||||
return usageTabModel{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (m usageTabModel) Init() tea.Cmd {
|
||||
return m.fetchData
|
||||
}
|
||||
|
||||
func (m usageTabModel) fetchData() tea.Msg {
|
||||
usage, err := m.client.GetUsage()
|
||||
return usageDataMsg{usage: usage, err: err}
|
||||
}
|
||||
|
||||
func (m usageTabModel) Update(msg tea.Msg) (usageTabModel, tea.Cmd) {
|
||||
switch msg := msg.(type) {
|
||||
case localeChangedMsg:
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
case usageDataMsg:
|
||||
if msg.err != nil {
|
||||
m.err = msg.err
|
||||
} else {
|
||||
m.err = nil
|
||||
m.usage = msg.usage
|
||||
}
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
return m, nil
|
||||
|
||||
case tea.KeyMsg:
|
||||
if msg.String() == "r" {
|
||||
return m, m.fetchData
|
||||
}
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.viewport, cmd = m.viewport.Update(msg)
|
||||
return m, cmd
|
||||
}
|
||||
|
||||
func (m *usageTabModel) SetSize(w, h int) {
|
||||
m.width = w
|
||||
m.height = h
|
||||
if !m.ready {
|
||||
m.viewport = viewport.New(w, h)
|
||||
m.viewport.SetContent(m.renderContent())
|
||||
m.ready = true
|
||||
} else {
|
||||
m.viewport.Width = w
|
||||
m.viewport.Height = h
|
||||
}
|
||||
}
|
||||
|
||||
func (m usageTabModel) View() string {
|
||||
if !m.ready {
|
||||
return T("loading")
|
||||
}
|
||||
return m.viewport.View()
|
||||
}
|
||||
|
||||
func (m usageTabModel) renderContent() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render(T("usage_title")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpStyle.Render(T("usage_help")))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if m.err != nil {
|
||||
sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
if m.usage == nil {
|
||||
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
usageMap, _ := m.usage["usage"].(map[string]any)
|
||||
if usageMap == nil {
|
||||
sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
totalReqs := int64(getFloat(usageMap, "total_requests"))
|
||||
successCnt := int64(getFloat(usageMap, "success_count"))
|
||||
failureCnt := int64(getFloat(usageMap, "failure_count"))
|
||||
totalTokens := int64(getFloat(usageMap, "total_tokens"))
|
||||
|
||||
// ━━━ Overview Cards ━━━
|
||||
cardWidth := 20
|
||||
if m.width > 0 {
|
||||
cardWidth = (m.width - 6) / 4
|
||||
if cardWidth < 16 {
|
||||
cardWidth = 16
|
||||
}
|
||||
}
|
||||
cardStyle := lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(lipgloss.Color("240")).
|
||||
Padding(0, 1).
|
||||
Width(cardWidth).
|
||||
Height(3)
|
||||
|
||||
// Total Requests
|
||||
card1 := cardStyle.Copy().BorderForeground(lipgloss.Color("111")).Render(fmt.Sprintf(
|
||||
"%s\n%s\n%s",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_reqs")),
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("%d", totalReqs)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("● %s: %d ● %s: %d", T("usage_success"), successCnt, T("usage_failure"), failureCnt)),
|
||||
))
|
||||
|
||||
// Total Tokens
|
||||
card2 := cardStyle.Copy().BorderForeground(lipgloss.Color("214")).Render(fmt.Sprintf(
|
||||
"%s\n%s\n%s",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_tokens")),
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(formatLargeNumber(totalTokens)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_token_l"), formatLargeNumber(totalTokens))),
|
||||
))
|
||||
|
||||
// RPM
|
||||
rpm := float64(0)
|
||||
if totalReqs > 0 {
|
||||
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
|
||||
rpm = float64(totalReqs) / float64(len(rByH)) / 60.0
|
||||
}
|
||||
}
|
||||
card3 := cardStyle.Copy().BorderForeground(lipgloss.Color("76")).Render(fmt.Sprintf(
|
||||
"%s\n%s\n%s",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_rpm")),
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("%.2f", rpm)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %d", T("usage_total_reqs"), totalReqs)),
|
||||
))
|
||||
|
||||
// TPM
|
||||
tpm := float64(0)
|
||||
if totalTokens > 0 {
|
||||
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
|
||||
tpm = float64(totalTokens) / float64(len(tByH)) / 60.0
|
||||
}
|
||||
}
|
||||
card4 := cardStyle.Copy().BorderForeground(lipgloss.Color("170")).Render(fmt.Sprintf(
|
||||
"%s\n%s\n%s",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_tpm")),
|
||||
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("%.2f", tpm)),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_tokens"), formatLargeNumber(totalTokens))),
|
||||
))
|
||||
|
||||
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// ━━━ Requests by Hour (ASCII bar chart) ━━━
|
||||
if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_hour")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(renderBarChart(rByH, m.width-6, lipgloss.Color("111")))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// ━━━ Tokens by Hour ━━━
|
||||
if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_tok_by_hour")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(renderBarChart(tByH, m.width-6, lipgloss.Color("214")))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// ━━━ Requests by Day ━━━
|
||||
if rByD, ok := usageMap["requests_by_day"].(map[string]any); ok && len(rByD) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_day")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(renderBarChart(rByD, m.width-6, lipgloss.Color("76")))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
// ━━━ API Detail Stats ━━━
|
||||
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_api_detail")))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(strings.Repeat("─", minInt(m.width, 80)))
|
||||
sb.WriteString("\n")
|
||||
|
||||
header := fmt.Sprintf(" %-30s %10s %12s", "API", T("requests"), T("tokens"))
|
||||
sb.WriteString(tableHeaderStyle.Render(header))
|
||||
sb.WriteString("\n")
|
||||
|
||||
for apiName, apiSnap := range apis {
|
||||
if apiMap, ok := apiSnap.(map[string]any); ok {
|
||||
apiReqs := int64(getFloat(apiMap, "total_requests"))
|
||||
apiToks := int64(getFloat(apiMap, "total_tokens"))
|
||||
|
||||
row := fmt.Sprintf(" %-30s %10d %12s",
|
||||
truncate(maskKey(apiName), 30), apiReqs, formatLargeNumber(apiToks))
|
||||
sb.WriteString(lipgloss.NewStyle().Bold(true).Render(row))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Per-model breakdown
|
||||
if models, ok := apiMap["models"].(map[string]any); ok {
|
||||
for model, v := range models {
|
||||
if stats, ok := v.(map[string]any); ok {
|
||||
mReqs := int64(getFloat(stats, "total_requests"))
|
||||
mToks := int64(getFloat(stats, "total_tokens"))
|
||||
mRow := fmt.Sprintf(" ├─ %-28s %10d %12s",
|
||||
truncate(model, 28), mReqs, formatLargeNumber(mToks))
|
||||
sb.WriteString(tableCellStyle.Render(mRow))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Token type breakdown from details
|
||||
sb.WriteString(m.renderTokenBreakdown(stats))
|
||||
|
||||
// Latency breakdown from details
|
||||
sb.WriteString(m.renderLatencyBreakdown(stats))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// renderTokenBreakdown aggregates input/output/cached/reasoning tokens from model details.
|
||||
func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
|
||||
details, ok := modelStats["details"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
detailList, ok := details.([]any)
|
||||
if !ok || len(detailList) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var inputTotal, outputTotal, cachedTotal, reasoningTotal int64
|
||||
for _, d := range detailList {
|
||||
dm, ok := d.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
tokens, ok := dm["tokens"].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
inputTotal += int64(getFloat(tokens, "input_tokens"))
|
||||
outputTotal += int64(getFloat(tokens, "output_tokens"))
|
||||
cachedTotal += int64(getFloat(tokens, "cached_tokens"))
|
||||
reasoningTotal += int64(getFloat(tokens, "reasoning_tokens"))
|
||||
}
|
||||
|
||||
if inputTotal == 0 && outputTotal == 0 && cachedTotal == 0 && reasoningTotal == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := []string{}
|
||||
if inputTotal > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_input"), formatLargeNumber(inputTotal)))
|
||||
}
|
||||
if outputTotal > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_output"), formatLargeNumber(outputTotal)))
|
||||
}
|
||||
if cachedTotal > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_cached"), formatLargeNumber(cachedTotal)))
|
||||
}
|
||||
if reasoningTotal > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%s:%s", T("usage_reasoning"), formatLargeNumber(reasoningTotal)))
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" │ %s\n",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
|
||||
}
|
||||
|
||||
// renderLatencyBreakdown aggregates latency_ms from model details and displays avg/min/max.
|
||||
func (m usageTabModel) renderLatencyBreakdown(modelStats map[string]any) string {
|
||||
details, ok := modelStats["details"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
detailList, ok := details.([]any)
|
||||
if !ok || len(detailList) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var totalLatency int64
|
||||
var count int
|
||||
var minLatency, maxLatency int64
|
||||
first := true
|
||||
|
||||
for _, d := range detailList {
|
||||
dm, ok := d.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
latencyMs := int64(getFloat(dm, "latency_ms"))
|
||||
if latencyMs <= 0 {
|
||||
continue
|
||||
}
|
||||
totalLatency += latencyMs
|
||||
count++
|
||||
if first {
|
||||
minLatency = latencyMs
|
||||
maxLatency = latencyMs
|
||||
first = false
|
||||
} else {
|
||||
if latencyMs < minLatency {
|
||||
minLatency = latencyMs
|
||||
}
|
||||
if latencyMs > maxLatency {
|
||||
maxLatency = latencyMs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
avgLatency := totalLatency / int64(count)
|
||||
return fmt.Sprintf(" │ %s: avg %dms min %dms max %dms\n",
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_time")),
|
||||
avgLatency, minLatency, maxLatency)
|
||||
}
|
||||
|
||||
// renderBarChart renders a simple ASCII horizontal bar chart.
|
||||
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
|
||||
if maxBarWidth < 10 {
|
||||
maxBarWidth = 10
|
||||
}
|
||||
|
||||
// Sort keys
|
||||
keys := make([]string, 0, len(data))
|
||||
for k := range data {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
// Find max value
|
||||
maxVal := float64(0)
|
||||
for _, k := range keys {
|
||||
v := getFloat(data, k)
|
||||
if v > maxVal {
|
||||
maxVal = v
|
||||
}
|
||||
}
|
||||
if maxVal == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
barStyle := lipgloss.NewStyle().Foreground(barColor)
|
||||
var sb strings.Builder
|
||||
|
||||
labelWidth := 12
|
||||
barAvail := maxBarWidth - labelWidth - 12
|
||||
if barAvail < 5 {
|
||||
barAvail = 5
|
||||
}
|
||||
|
||||
for _, k := range keys {
|
||||
v := getFloat(data, k)
|
||||
barLen := int(v / maxVal * float64(barAvail))
|
||||
if barLen < 1 && v > 0 {
|
||||
barLen = 1
|
||||
}
|
||||
bar := strings.Repeat("█", barLen)
|
||||
label := k
|
||||
if len(label) > labelWidth {
|
||||
label = label[:labelWidth]
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %-*s %s %s\n",
|
||||
labelWidth, label,
|
||||
barStyle.Render(bar),
|
||||
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%.0f", v)),
|
||||
))
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@@ -1,134 +0,0 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRenderLatencyBreakdown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelStats map[string]any
|
||||
wantEmpty bool
|
||||
wantContains string
|
||||
}{
|
||||
{
|
||||
name: "no details",
|
||||
modelStats: map[string]any{},
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty details",
|
||||
modelStats: map[string]any{
|
||||
"details": []any{},
|
||||
},
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "details with zero latency",
|
||||
modelStats: map[string]any{
|
||||
"details": []any{
|
||||
map[string]any{
|
||||
"latency_ms": float64(0),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "single request with latency",
|
||||
modelStats: map[string]any{
|
||||
"details": []any{
|
||||
map[string]any{
|
||||
"latency_ms": float64(1500),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantEmpty: false,
|
||||
wantContains: "avg 1500ms min 1500ms max 1500ms",
|
||||
},
|
||||
{
|
||||
name: "multiple requests with varying latency",
|
||||
modelStats: map[string]any{
|
||||
"details": []any{
|
||||
map[string]any{
|
||||
"latency_ms": float64(100),
|
||||
},
|
||||
map[string]any{
|
||||
"latency_ms": float64(200),
|
||||
},
|
||||
map[string]any{
|
||||
"latency_ms": float64(300),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantEmpty: false,
|
||||
wantContains: "avg 200ms min 100ms max 300ms",
|
||||
},
|
||||
{
|
||||
name: "mixed valid and invalid latency values",
|
||||
modelStats: map[string]any{
|
||||
"details": []any{
|
||||
map[string]any{
|
||||
"latency_ms": float64(500),
|
||||
},
|
||||
map[string]any{
|
||||
"latency_ms": float64(0),
|
||||
},
|
||||
map[string]any{
|
||||
"latency_ms": float64(1500),
|
||||
},
|
||||
},
|
||||
},
|
||||
wantEmpty: false,
|
||||
wantContains: "avg 1000ms min 500ms max 1500ms",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m := usageTabModel{}
|
||||
result := m.renderLatencyBreakdown(tt.modelStats)
|
||||
|
||||
if tt.wantEmpty {
|
||||
if result != "" {
|
||||
t.Errorf("renderLatencyBreakdown() = %q, want empty string", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result == "" {
|
||||
t.Errorf("renderLatencyBreakdown() = empty, want non-empty string")
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantContains != "" && !strings.Contains(result, tt.wantContains) {
|
||||
t.Errorf("renderLatencyBreakdown() = %q, want to contain %q", result, tt.wantContains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsageTimeTranslations(t *testing.T) {
|
||||
prevLocale := CurrentLocale()
|
||||
t.Cleanup(func() {
|
||||
SetLocale(prevLocale)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
locale string
|
||||
want string
|
||||
}{
|
||||
{locale: "en", want: "Time"},
|
||||
{locale: "zh", want: "时间"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.locale, func(t *testing.T) {
|
||||
SetLocale(tt.locale)
|
||||
if got := T("usage_time"); got != tt.want {
|
||||
t.Fatalf("T(usage_time) = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,484 +0,0 @@
|
||||
// Package usage provides usage tracking and logging functionality for the CLI Proxy API server.
|
||||
// It includes plugins for monitoring API usage, token consumption, and other metrics
|
||||
// to help with observability and billing purposes.
|
||||
package usage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
var statisticsEnabled atomic.Bool
|
||||
|
||||
func init() {
|
||||
statisticsEnabled.Store(true)
|
||||
coreusage.RegisterPlugin(NewLoggerPlugin())
|
||||
}
|
||||
|
||||
// LoggerPlugin collects in-memory request statistics for usage analysis.
|
||||
// It implements coreusage.Plugin to receive usage records emitted by the runtime.
|
||||
type LoggerPlugin struct {
|
||||
stats *RequestStatistics
|
||||
}
|
||||
|
||||
// NewLoggerPlugin constructs a new logger plugin instance.
|
||||
//
|
||||
// Returns:
|
||||
// - *LoggerPlugin: A new logger plugin instance wired to the shared statistics store.
|
||||
func NewLoggerPlugin() *LoggerPlugin { return &LoggerPlugin{stats: defaultRequestStatistics} }
|
||||
|
||||
// HandleUsage implements coreusage.Plugin.
|
||||
// It updates the in-memory statistics store whenever a usage record is received.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context for the usage record
|
||||
// - record: The usage record to aggregate
|
||||
func (p *LoggerPlugin) HandleUsage(ctx context.Context, record coreusage.Record) {
|
||||
if !statisticsEnabled.Load() {
|
||||
return
|
||||
}
|
||||
if p == nil || p.stats == nil {
|
||||
return
|
||||
}
|
||||
p.stats.Record(ctx, record)
|
||||
}
|
||||
|
||||
// SetStatisticsEnabled toggles whether in-memory statistics are recorded.
|
||||
func SetStatisticsEnabled(enabled bool) { statisticsEnabled.Store(enabled) }
|
||||
|
||||
// StatisticsEnabled reports the current recording state.
|
||||
func StatisticsEnabled() bool { return statisticsEnabled.Load() }
|
||||
|
||||
// RequestStatistics maintains aggregated request metrics in memory.
|
||||
type RequestStatistics struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
totalRequests int64
|
||||
successCount int64
|
||||
failureCount int64
|
||||
totalTokens int64
|
||||
|
||||
apis map[string]*apiStats
|
||||
|
||||
requestsByDay map[string]int64
|
||||
requestsByHour map[int]int64
|
||||
tokensByDay map[string]int64
|
||||
tokensByHour map[int]int64
|
||||
}
|
||||
|
||||
// apiStats holds aggregated metrics for a single API key.
|
||||
type apiStats struct {
|
||||
TotalRequests int64
|
||||
TotalTokens int64
|
||||
Models map[string]*modelStats
|
||||
}
|
||||
|
||||
// modelStats holds aggregated metrics for a specific model within an API.
|
||||
type modelStats struct {
|
||||
TotalRequests int64
|
||||
TotalTokens int64
|
||||
Details []RequestDetail
|
||||
}
|
||||
|
||||
// RequestDetail stores the timestamp, latency, and token usage for a single request.
|
||||
type RequestDetail struct {
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
LatencyMs int64 `json:"latency_ms"`
|
||||
Source string `json:"source"`
|
||||
AuthIndex string `json:"auth_index"`
|
||||
Tokens TokenStats `json:"tokens"`
|
||||
Failed bool `json:"failed"`
|
||||
}
|
||||
|
||||
// TokenStats captures the token usage breakdown for a request.
|
||||
type TokenStats struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
ReasoningTokens int64 `json:"reasoning_tokens"`
|
||||
CachedTokens int64 `json:"cached_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// StatisticsSnapshot represents an immutable view of the aggregated metrics.
|
||||
type StatisticsSnapshot struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
SuccessCount int64 `json:"success_count"`
|
||||
FailureCount int64 `json:"failure_count"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
|
||||
APIs map[string]APISnapshot `json:"apis"`
|
||||
|
||||
RequestsByDay map[string]int64 `json:"requests_by_day"`
|
||||
RequestsByHour map[string]int64 `json:"requests_by_hour"`
|
||||
TokensByDay map[string]int64 `json:"tokens_by_day"`
|
||||
TokensByHour map[string]int64 `json:"tokens_by_hour"`
|
||||
}
|
||||
|
||||
// APISnapshot summarises metrics for a single API key.
|
||||
type APISnapshot struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Models map[string]ModelSnapshot `json:"models"`
|
||||
}
|
||||
|
||||
// ModelSnapshot summarises metrics for a specific model.
|
||||
type ModelSnapshot struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
Details []RequestDetail `json:"details"`
|
||||
}
|
||||
|
||||
var defaultRequestStatistics = NewRequestStatistics()
|
||||
|
||||
// GetRequestStatistics returns the shared statistics store.
|
||||
func GetRequestStatistics() *RequestStatistics { return defaultRequestStatistics }
|
||||
|
||||
// NewRequestStatistics constructs an empty statistics store.
|
||||
func NewRequestStatistics() *RequestStatistics {
|
||||
return &RequestStatistics{
|
||||
apis: make(map[string]*apiStats),
|
||||
requestsByDay: make(map[string]int64),
|
||||
requestsByHour: make(map[int]int64),
|
||||
tokensByDay: make(map[string]int64),
|
||||
tokensByHour: make(map[int]int64),
|
||||
}
|
||||
}
|
||||
|
||||
// Record ingests a new usage record and updates the aggregates.
|
||||
func (s *RequestStatistics) Record(ctx context.Context, record coreusage.Record) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if !statisticsEnabled.Load() {
|
||||
return
|
||||
}
|
||||
timestamp := record.RequestedAt
|
||||
if timestamp.IsZero() {
|
||||
timestamp = time.Now()
|
||||
}
|
||||
detail := normaliseDetail(record.Detail)
|
||||
totalTokens := detail.TotalTokens
|
||||
statsKey := record.APIKey
|
||||
if statsKey == "" {
|
||||
statsKey = resolveAPIIdentifier(ctx, record)
|
||||
}
|
||||
failed := record.Failed
|
||||
if !failed {
|
||||
failed = !resolveSuccess(ctx)
|
||||
}
|
||||
success := !failed
|
||||
modelName := record.Model
|
||||
if modelName == "" {
|
||||
modelName = "unknown"
|
||||
}
|
||||
dayKey := timestamp.Format("2006-01-02")
|
||||
hourKey := timestamp.Hour()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.totalRequests++
|
||||
if success {
|
||||
s.successCount++
|
||||
} else {
|
||||
s.failureCount++
|
||||
}
|
||||
s.totalTokens += totalTokens
|
||||
|
||||
stats, ok := s.apis[statsKey]
|
||||
if !ok {
|
||||
stats = &apiStats{Models: make(map[string]*modelStats)}
|
||||
s.apis[statsKey] = stats
|
||||
}
|
||||
s.updateAPIStats(stats, modelName, RequestDetail{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: normaliseLatency(record.Latency),
|
||||
Source: record.Source,
|
||||
AuthIndex: record.AuthIndex,
|
||||
Tokens: detail,
|
||||
Failed: failed,
|
||||
})
|
||||
|
||||
s.requestsByDay[dayKey]++
|
||||
s.requestsByHour[hourKey]++
|
||||
s.tokensByDay[dayKey] += totalTokens
|
||||
s.tokensByHour[hourKey] += totalTokens
|
||||
}
|
||||
|
||||
func (s *RequestStatistics) updateAPIStats(stats *apiStats, model string, detail RequestDetail) {
|
||||
stats.TotalRequests++
|
||||
stats.TotalTokens += detail.Tokens.TotalTokens
|
||||
modelStatsValue, ok := stats.Models[model]
|
||||
if !ok {
|
||||
modelStatsValue = &modelStats{}
|
||||
stats.Models[model] = modelStatsValue
|
||||
}
|
||||
modelStatsValue.TotalRequests++
|
||||
modelStatsValue.TotalTokens += detail.Tokens.TotalTokens
|
||||
modelStatsValue.Details = append(modelStatsValue.Details, detail)
|
||||
}
|
||||
|
||||
// Snapshot returns a copy of the aggregated metrics for external consumption.
|
||||
func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
|
||||
result := StatisticsSnapshot{}
|
||||
if s == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
result.TotalRequests = s.totalRequests
|
||||
result.SuccessCount = s.successCount
|
||||
result.FailureCount = s.failureCount
|
||||
result.TotalTokens = s.totalTokens
|
||||
|
||||
result.APIs = make(map[string]APISnapshot, len(s.apis))
|
||||
for apiName, stats := range s.apis {
|
||||
apiSnapshot := APISnapshot{
|
||||
TotalRequests: stats.TotalRequests,
|
||||
TotalTokens: stats.TotalTokens,
|
||||
Models: make(map[string]ModelSnapshot, len(stats.Models)),
|
||||
}
|
||||
for modelName, modelStatsValue := range stats.Models {
|
||||
requestDetails := make([]RequestDetail, len(modelStatsValue.Details))
|
||||
copy(requestDetails, modelStatsValue.Details)
|
||||
apiSnapshot.Models[modelName] = ModelSnapshot{
|
||||
TotalRequests: modelStatsValue.TotalRequests,
|
||||
TotalTokens: modelStatsValue.TotalTokens,
|
||||
Details: requestDetails,
|
||||
}
|
||||
}
|
||||
result.APIs[apiName] = apiSnapshot
|
||||
}
|
||||
|
||||
result.RequestsByDay = make(map[string]int64, len(s.requestsByDay))
|
||||
for k, v := range s.requestsByDay {
|
||||
result.RequestsByDay[k] = v
|
||||
}
|
||||
|
||||
result.RequestsByHour = make(map[string]int64, len(s.requestsByHour))
|
||||
for hour, v := range s.requestsByHour {
|
||||
key := formatHour(hour)
|
||||
result.RequestsByHour[key] = v
|
||||
}
|
||||
|
||||
result.TokensByDay = make(map[string]int64, len(s.tokensByDay))
|
||||
for k, v := range s.tokensByDay {
|
||||
result.TokensByDay[k] = v
|
||||
}
|
||||
|
||||
result.TokensByHour = make(map[string]int64, len(s.tokensByHour))
|
||||
for hour, v := range s.tokensByHour {
|
||||
key := formatHour(hour)
|
||||
result.TokensByHour[key] = v
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type MergeResult struct {
|
||||
Added int64 `json:"added"`
|
||||
Skipped int64 `json:"skipped"`
|
||||
}
|
||||
|
||||
// MergeSnapshot merges an exported statistics snapshot into the current store.
|
||||
// Existing data is preserved and duplicate request details are skipped.
|
||||
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
|
||||
result := MergeResult{}
|
||||
if s == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
for apiName, stats := range s.apis {
|
||||
if stats == nil {
|
||||
continue
|
||||
}
|
||||
for modelName, modelStatsValue := range stats.Models {
|
||||
if modelStatsValue == nil {
|
||||
continue
|
||||
}
|
||||
for _, detail := range modelStatsValue.Details {
|
||||
seen[dedupKey(apiName, modelName, detail)] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for apiName, apiSnapshot := range snapshot.APIs {
|
||||
apiName = strings.TrimSpace(apiName)
|
||||
if apiName == "" {
|
||||
continue
|
||||
}
|
||||
stats, ok := s.apis[apiName]
|
||||
if !ok || stats == nil {
|
||||
stats = &apiStats{Models: make(map[string]*modelStats)}
|
||||
s.apis[apiName] = stats
|
||||
} else if stats.Models == nil {
|
||||
stats.Models = make(map[string]*modelStats)
|
||||
}
|
||||
for modelName, modelSnapshot := range apiSnapshot.Models {
|
||||
modelName = strings.TrimSpace(modelName)
|
||||
if modelName == "" {
|
||||
modelName = "unknown"
|
||||
}
|
||||
for _, detail := range modelSnapshot.Details {
|
||||
detail.Tokens = normaliseTokenStats(detail.Tokens)
|
||||
if detail.LatencyMs < 0 {
|
||||
detail.LatencyMs = 0
|
||||
}
|
||||
if detail.Timestamp.IsZero() {
|
||||
detail.Timestamp = time.Now()
|
||||
}
|
||||
key := dedupKey(apiName, modelName, detail)
|
||||
if _, exists := seen[key]; exists {
|
||||
result.Skipped++
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
s.recordImported(apiName, modelName, stats, detail)
|
||||
result.Added++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
|
||||
totalTokens := detail.Tokens.TotalTokens
|
||||
if totalTokens < 0 {
|
||||
totalTokens = 0
|
||||
}
|
||||
|
||||
s.totalRequests++
|
||||
if detail.Failed {
|
||||
s.failureCount++
|
||||
} else {
|
||||
s.successCount++
|
||||
}
|
||||
s.totalTokens += totalTokens
|
||||
|
||||
s.updateAPIStats(stats, modelName, detail)
|
||||
|
||||
dayKey := detail.Timestamp.Format("2006-01-02")
|
||||
hourKey := detail.Timestamp.Hour()
|
||||
|
||||
s.requestsByDay[dayKey]++
|
||||
s.requestsByHour[hourKey]++
|
||||
s.tokensByDay[dayKey] += totalTokens
|
||||
s.tokensByHour[hourKey] += totalTokens
|
||||
}
|
||||
|
||||
func dedupKey(apiName, modelName string, detail RequestDetail) string {
|
||||
timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
|
||||
tokens := normaliseTokenStats(detail.Tokens)
|
||||
return fmt.Sprintf(
|
||||
"%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
|
||||
apiName,
|
||||
modelName,
|
||||
timestamp,
|
||||
detail.Source,
|
||||
detail.AuthIndex,
|
||||
detail.Failed,
|
||||
tokens.InputTokens,
|
||||
tokens.OutputTokens,
|
||||
tokens.ReasoningTokens,
|
||||
tokens.CachedTokens,
|
||||
tokens.TotalTokens,
|
||||
)
|
||||
}
|
||||
|
||||
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
|
||||
if ctx != nil {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
|
||||
path := ginCtx.FullPath()
|
||||
if path == "" && ginCtx.Request != nil {
|
||||
path = ginCtx.Request.URL.Path
|
||||
}
|
||||
method := ""
|
||||
if ginCtx.Request != nil {
|
||||
method = ginCtx.Request.Method
|
||||
}
|
||||
if path != "" {
|
||||
if method != "" {
|
||||
return method + " " + path
|
||||
}
|
||||
return path
|
||||
}
|
||||
}
|
||||
}
|
||||
if record.Provider != "" {
|
||||
return record.Provider
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func resolveSuccess(ctx context.Context) bool {
|
||||
if ctx == nil {
|
||||
return true
|
||||
}
|
||||
ginCtx, ok := ctx.Value("gin").(*gin.Context)
|
||||
if !ok || ginCtx == nil {
|
||||
return true
|
||||
}
|
||||
status := ginCtx.Writer.Status()
|
||||
if status == 0 {
|
||||
return true
|
||||
}
|
||||
return status < httpStatusBadRequest
|
||||
}
|
||||
|
||||
const httpStatusBadRequest = 400
|
||||
|
||||
func normaliseDetail(detail coreusage.Detail) TokenStats {
|
||||
tokens := TokenStats{
|
||||
InputTokens: detail.InputTokens,
|
||||
OutputTokens: detail.OutputTokens,
|
||||
ReasoningTokens: detail.ReasoningTokens,
|
||||
CachedTokens: detail.CachedTokens,
|
||||
TotalTokens: detail.TotalTokens,
|
||||
}
|
||||
if tokens.TotalTokens == 0 {
|
||||
tokens.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
if tokens.TotalTokens == 0 {
|
||||
tokens.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens + detail.CachedTokens
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func normaliseTokenStats(tokens TokenStats) TokenStats {
|
||||
if tokens.TotalTokens == 0 {
|
||||
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
|
||||
}
|
||||
if tokens.TotalTokens == 0 {
|
||||
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func normaliseLatency(latency time.Duration) int64 {
|
||||
if latency <= 0 {
|
||||
return 0
|
||||
}
|
||||
return latency.Milliseconds()
|
||||
}
|
||||
|
||||
func formatHour(hour int) string {
|
||||
if hour < 0 {
|
||||
hour = 0
|
||||
}
|
||||
hour = hour % 24
|
||||
return fmt.Sprintf("%02d", hour)
|
||||
}
|
||||
@@ -1,96 +0,0 @@
|
||||
package usage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
)
|
||||
|
||||
func TestRequestStatisticsRecordIncludesLatency(t *testing.T) {
|
||||
stats := NewRequestStatistics()
|
||||
stats.Record(context.Background(), coreusage.Record{
|
||||
APIKey: "test-key",
|
||||
Model: "gpt-5.4",
|
||||
RequestedAt: time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC),
|
||||
Latency: 1500 * time.Millisecond,
|
||||
Detail: coreusage.Detail{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
})
|
||||
|
||||
snapshot := stats.Snapshot()
|
||||
details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
|
||||
if len(details) != 1 {
|
||||
t.Fatalf("details len = %d, want 1", len(details))
|
||||
}
|
||||
if details[0].LatencyMs != 1500 {
|
||||
t.Fatalf("latency_ms = %d, want 1500", details[0].LatencyMs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestStatisticsMergeSnapshotDedupIgnoresLatency(t *testing.T) {
|
||||
stats := NewRequestStatistics()
|
||||
timestamp := time.Date(2026, 3, 20, 12, 0, 0, 0, time.UTC)
|
||||
first := StatisticsSnapshot{
|
||||
APIs: map[string]APISnapshot{
|
||||
"test-key": {
|
||||
Models: map[string]ModelSnapshot{
|
||||
"gpt-5.4": {
|
||||
Details: []RequestDetail{{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: 0,
|
||||
Source: "user@example.com",
|
||||
AuthIndex: "0",
|
||||
Tokens: TokenStats{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
second := StatisticsSnapshot{
|
||||
APIs: map[string]APISnapshot{
|
||||
"test-key": {
|
||||
Models: map[string]ModelSnapshot{
|
||||
"gpt-5.4": {
|
||||
Details: []RequestDetail{{
|
||||
Timestamp: timestamp,
|
||||
LatencyMs: 2500,
|
||||
Source: "user@example.com",
|
||||
AuthIndex: "0",
|
||||
Tokens: TokenStats{
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalTokens: 30,
|
||||
},
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := stats.MergeSnapshot(first)
|
||||
if result.Added != 1 || result.Skipped != 0 {
|
||||
t.Fatalf("first merge = %+v, want added=1 skipped=0", result)
|
||||
}
|
||||
|
||||
result = stats.MergeSnapshot(second)
|
||||
if result.Added != 0 || result.Skipped != 1 {
|
||||
t.Fatalf("second merge = %+v, want added=0 skipped=1", result)
|
||||
}
|
||||
|
||||
snapshot := stats.Snapshot()
|
||||
details := snapshot.APIs["test-key"].Models["gpt-5.4"].Details
|
||||
if len(details) != 1 {
|
||||
t.Fatalf("details len = %d, want 1", len(details))
|
||||
}
|
||||
}
|
||||
@@ -39,6 +39,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled {
|
||||
changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled))
|
||||
}
|
||||
if oldCfg.RedisUsageQueueRetentionSeconds != newCfg.RedisUsageQueueRetentionSeconds {
|
||||
changes = append(changes, fmt.Sprintf("redis-usage-queue-retention-seconds: %d -> %d", oldCfg.RedisUsageQueueRetentionSeconds, newCfg.RedisUsageQueueRetentionSeconds))
|
||||
}
|
||||
if oldCfg.DisableCooling != newCfg.DisableCooling {
|
||||
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user