feat(usage): add support for requested model alias handling

- Introduced methods for setting and retrieving model aliases in execution and usage contexts.
- Enhanced `UsageReporter` and related structures to include client-requested aliases.
- Updated tests to validate alias propagation and ensure correct usage reporting.
- Adjusted metadata handling in CLIProxyAPI executors to address alias integration.
This commit is contained in:
Luis Pater
2026-05-05 01:47:53 +08:00
parent 28b4b19e7e
commit ba5d8ca733
8 changed files with 125 additions and 6 deletions
+6
View File
@@ -33,6 +33,10 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
if modelName == "" { if modelName == "" {
modelName = "unknown" modelName = "unknown"
} }
aliasName := strings.TrimSpace(record.Alias)
if aliasName == "" {
aliasName = modelName
}
provider := strings.TrimSpace(record.Provider) provider := strings.TrimSpace(record.Provider)
if provider == "" { if provider == "" {
provider = "unknown" provider = "unknown"
@@ -76,6 +80,7 @@ func (p *usageQueuePlugin) HandleUsage(ctx context.Context, record coreusage.Rec
requestDetail: detail, requestDetail: detail,
Provider: provider, Provider: provider,
Model: modelName, Model: modelName,
Alias: aliasName,
Endpoint: resolveEndpoint(ctx), Endpoint: resolveEndpoint(ctx),
AuthType: authType, AuthType: authType,
APIKey: apiKey, APIKey: apiKey,
@@ -91,6 +96,7 @@ type queuedUsageDetail struct {
requestDetail requestDetail
Provider string `json:"provider"` Provider string `json:"provider"`
Model string `json:"model"` Model string `json:"model"`
Alias string `json:"alias"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
AuthType string `json:"auth_type"` AuthType string `json:"auth_type"`
APIKey string `json:"api_key"` APIKey string `json:"api_key"`
+6
View File
@@ -24,6 +24,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
plugin.HandleUsage(ctx, coreusage.Record{ plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai", Provider: "openai",
Model: "gpt-5.4", Model: "gpt-5.4",
Alias: "client-gpt",
APIKey: "test-key", APIKey: "test-key",
AuthIndex: "0", AuthIndex: "0",
AuthType: "apikey", AuthType: "apikey",
@@ -40,6 +41,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndSuccess(t *testing.T) {
payload := popSinglePayload(t) payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai") requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4") requireStringField(t, payload, "model", "gpt-5.4")
requireStringField(t, payload, "alias", "client-gpt")
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions") requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
requireStringField(t, payload, "auth_type", "apikey") requireStringField(t, payload, "auth_type", "apikey")
requireStringField(t, payload, "request_id", "ctx-request-id") requireStringField(t, payload, "request_id", "ctx-request-id")
@@ -58,6 +60,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t
plugin.HandleUsage(ctx, coreusage.Record{ plugin.HandleUsage(ctx, coreusage.Record{
Provider: "openai", Provider: "openai",
Model: "gpt-5.4-mini", Model: "gpt-5.4-mini",
Alias: "client-mini",
APIKey: "test-key", APIKey: "test-key",
AuthIndex: "0", AuthIndex: "0",
AuthType: "apikey", AuthType: "apikey",
@@ -74,6 +77,7 @@ func TestUsageQueuePluginPayloadIncludesStableFieldsAndFailureAndGinRequestID(t
payload := popSinglePayload(t) payload := popSinglePayload(t)
requireStringField(t, payload, "provider", "openai") requireStringField(t, payload, "provider", "openai")
requireStringField(t, payload, "model", "gpt-5.4-mini") requireStringField(t, payload, "model", "gpt-5.4-mini")
requireStringField(t, payload, "alias", "client-mini")
requireStringField(t, payload, "endpoint", "GET /v1/responses") requireStringField(t, payload, "endpoint", "GET /v1/responses")
requireStringField(t, payload, "auth_type", "apikey") requireStringField(t, payload, "auth_type", "apikey")
requireStringField(t, payload, "request_id", "gin-request-id") requireStringField(t, payload, "request_id", "gin-request-id")
@@ -102,6 +106,7 @@ func TestUsageQueuePluginAsyncIgnoresRecycledGinContext(t *testing.T) {
mgr.Publish(ctx, coreusage.Record{ mgr.Publish(ctx, coreusage.Record{
Provider: "openai", Provider: "openai",
Model: "gpt-5.4", Model: "gpt-5.4",
Alias: "client-gpt",
APIKey: "test-key", APIKey: "test-key",
AuthIndex: "0", AuthIndex: "0",
AuthType: "apikey", AuthType: "apikey",
@@ -117,6 +122,7 @@ func TestUsageQueuePluginAsyncIgnoresRecycledGinContext(t *testing.T) {
payload := waitForSinglePayload(t, 2*time.Second) payload := waitForSinglePayload(t, 2*time.Second)
requireStringField(t, payload, "endpoint", "POST /v1/chat/completions") requireStringField(t, payload, "endpoint", "POST /v1/chat/completions")
requireStringField(t, payload, "alias", "client-gpt")
requireStringField(t, payload, "request_id", "ctx-request-id") requireStringField(t, payload, "request_id", "ctx-request-id")
requireBoolField(t, payload, "failed", true) requireBoolField(t, payload, "failed", true)
}) })
@@ -18,6 +18,7 @@ import (
type UsageReporter struct { type UsageReporter struct {
provider string provider string
model string model string
alias string
authID string authID string
authIndex string authIndex string
authType string authType string
@@ -29,9 +30,14 @@ type UsageReporter struct {
func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter { func NewUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *UsageReporter {
apiKey := APIKeyFromContext(ctx) apiKey := APIKeyFromContext(ctx)
alias := usage.RequestedModelAliasFromContext(ctx)
if alias == "" {
alias = model
}
reporter := &UsageReporter{ reporter := &UsageReporter{
provider: provider, provider: provider,
model: model, model: model,
alias: strings.TrimSpace(alias),
requestedAt: time.Now(), requestedAt: time.Now(),
apiKey: apiKey, apiKey: apiKey,
source: resolveUsageSource(auth, apiKey), source: resolveUsageSource(auth, apiKey),
@@ -139,6 +145,7 @@ func (r *UsageReporter) buildRecordForModel(model string, detail usage.Detail, f
return usage.Record{ return usage.Record{
Provider: r.provider, Provider: r.provider,
Model: model, Model: model,
Alias: r.alias,
Source: r.source, Source: r.source,
APIKey: r.apiKey, APIKey: r.apiKey,
AuthID: r.authID, AuthID: r.authID,
@@ -1,6 +1,7 @@
package helps package helps
import ( import (
"context"
"testing" "testing"
"time" "time"
@@ -107,6 +108,19 @@ func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) {
} }
} }
// TestUsageReporterBuildRecordIncludesRequestedModelAlias checks that an alias
// placed on the context is copied into the built record's Alias field while
// the Model field keeps the model passed to NewUsageReporter.
func TestUsageReporterBuildRecordIncludesRequestedModelAlias(t *testing.T) {
	const (
		wantModel = "gpt-5.4"
		wantAlias = "client-gpt"
	)

	aliasCtx := usage.WithRequestedModelAlias(context.Background(), wantAlias)
	got := NewUsageReporter(aliasCtx, "openai", wantModel, nil).buildRecord(usage.Detail{TotalTokens: 3}, false)

	if got.Model != wantModel {
		t.Fatalf("model = %q, want %q", got.Model, wantModel)
	}
	if got.Alias != wantAlias {
		t.Fatalf("alias = %q, want %q", got.Alias, wantAlias)
	}
}
func TestUsageReporterBuildAdditionalModelRecordSkipsZeroTokens(t *testing.T) { func TestUsageReporterBuildAdditionalModelRecordSkipsZeroTokens(t *testing.T) {
reporter := &UsageReporter{ reporter := &UsageReporter{
provider: "codex", provider: "codex",
+3 -3
View File
@@ -539,7 +539,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
return nil, nil, errMsg return nil, nil, errMsg
} }
reqMeta := requestExecutionMetadata(ctx) reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
payload := rawJSON payload := rawJSON
if len(payload) == 0 { if len(payload) == 0 {
payload = nil payload = nil
@@ -587,7 +587,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
return nil, nil, errMsg return nil, nil, errMsg
} }
reqMeta := requestExecutionMetadata(ctx) reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
payload := rawJSON payload := rawJSON
if len(payload) == 0 { if len(payload) == 0 {
payload = nil payload = nil
@@ -639,7 +639,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return nil, nil, errChan return nil, nil, errChan
} }
reqMeta := requestExecutionMetadata(ctx) reqMeta := requestExecutionMetadata(ctx)
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel reqMeta[coreexecutor.RequestedModelMetadataKey] = modelName
payload := rawJSON payload := rawJSON
if len(payload) == 0 { if len(payload) == 0 {
payload = nil payload = nil
+35
View File
@@ -22,6 +22,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -827,6 +828,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi
if executor == nil { if executor == nil {
return nil, &Error{Code: "executor_not_found", Message: "executor not registered"} return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
} }
ctx = contextWithRequestedModelAlias(ctx, opts, routeModel)
var lastErr error var lastErr error
for idx, execModel := range execModels { for idx, execModel := range execModels {
resultModel := m.stateModelForExecution(auth, routeModel, execModel, pooled) resultModel := m.stateModelForExecution(auth, routeModel, execModel, pooled)
@@ -1319,6 +1321,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
} }
execCtx = contextWithRequestedModelAlias(execCtx, opts, routeModel)
models, pooled := m.preparedExecutionModels(auth, routeModel) models, pooled := m.preparedExecutionModels(auth, routeModel)
if len(models) == 0 { if len(models) == 0 {
@@ -1397,6 +1400,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
} }
execCtx = contextWithRequestedModelAlias(execCtx, opts, routeModel)
models, pooled := m.preparedExecutionModels(auth, routeModel) models, pooled := m.preparedExecutionModels(auth, routeModel)
if len(models) == 0 { if len(models) == 0 {
@@ -1534,6 +1538,36 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
} }
} }
// contextWithRequestedModelAlias resolves the requested model alias from opts
// (using fallback when opts carries none) and returns the context produced by
// coreusage.WithRequestedModelAlias for that alias.
func contextWithRequestedModelAlias(ctx context.Context, opts cliproxyexecutor.Options, fallback string) context.Context {
	return coreusage.WithRequestedModelAlias(ctx, requestedModelAliasFromOptions(opts, fallback))
}
// requestedModelAliasFromOptions returns the requested model name stored in
// opts.Metadata under cliproxyexecutor.RequestedModelMetadataKey.
//
// The metadata value may be a string or a []byte; it is trimmed of surrounding
// whitespace before use. The trimmed fallback is returned when the key is
// absent, the value is nil or of another type, or the value is empty after
// trimming. The empty check is applied after trimming for both supported
// types, so a whitespace-only []byte falls back just like a whitespace-only
// string instead of yielding an empty alias.
func requestedModelAliasFromOptions(opts cliproxyexecutor.Options, fallback string) string {
	fallback = strings.TrimSpace(fallback)

	// Indexing a nil or empty map is safe and simply reports ok == false.
	raw, ok := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey]
	if !ok {
		return fallback
	}

	var alias string
	switch value := raw.(type) {
	case string:
		alias = strings.TrimSpace(value)
	case []byte:
		alias = strings.TrimSpace(string(value))
	}
	// A nil value or unsupported type leaves alias empty and uses the fallback.
	if alias == "" {
		return fallback
	}
	return alias
}
func pinnedAuthIDFromMetadata(meta map[string]any) string { func pinnedAuthIDFromMetadata(meta map[string]any) string {
if len(meta) == 0 { if len(meta) == 0 {
return "" return ""
@@ -3096,6 +3130,7 @@ func (m *Manager) tryAntigravityCreditsExecute(ctx context.Context, req cliproxy
creditsCtx = context.WithValue(creditsCtx, "cliproxy.roundtripper", rt) creditsCtx = context.WithValue(creditsCtx, "cliproxy.roundtripper", rt)
} }
creditsOpts := ensureRequestedModelMetadata(opts, routeModel) creditsOpts := ensureRequestedModelMetadata(opts, routeModel)
creditsCtx = contextWithRequestedModelAlias(creditsCtx, creditsOpts, routeModel)
publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID) publishSelectedAuthMetadata(creditsOpts.Metadata, c.auth.ID)
models := m.executionModelCandidates(c.auth, routeModel) models := m.executionModelCandidates(c.auth, routeModel)
if len(models) == 0 { if len(models) == 0 {
@@ -10,6 +10,7 @@ import (
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
) )
type aliasRoutingExecutor struct { type aliasRoutingExecutor struct {
@@ -17,13 +18,15 @@ type aliasRoutingExecutor struct {
mu sync.Mutex mu sync.Mutex
executeModels []string executeModels []string
executeAliases []string
} }
func (e *aliasRoutingExecutor) Identifier() string { return e.id } func (e *aliasRoutingExecutor) Identifier() string { return e.id }
func (e *aliasRoutingExecutor) Execute(_ context.Context, _ *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (e *aliasRoutingExecutor) Execute(ctx context.Context, _ *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
e.mu.Lock() e.mu.Lock()
e.executeModels = append(e.executeModels, req.Model) e.executeModels = append(e.executeModels, req.Model)
e.executeAliases = append(e.executeAliases, coreusage.RequestedModelAliasFromContext(ctx))
e.mu.Unlock() e.mu.Unlock()
return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil
} }
@@ -52,6 +55,14 @@ func (e *aliasRoutingExecutor) ExecuteModels() []string {
return out return out
} }
// ExecuteAliases returns a copy of the aliases collected in e.executeAliases,
// taken while holding e.mu so callers never share the internal slice.
func (e *aliasRoutingExecutor) ExecuteAliases() []string {
	e.mu.Lock()
	defer e.mu.Unlock()

	snapshot := make([]string, len(e.executeAliases))
	copy(snapshot, e.executeAliases)
	return snapshot
}
func TestManagerExecute_OAuthAliasBypassesBlockedRouteModel(t *testing.T) { func TestManagerExecute_OAuthAliasBypassesBlockedRouteModel(t *testing.T) {
const ( const (
provider = "antigravity" provider = "antigravity"
@@ -108,4 +119,12 @@ func TestManagerExecute_OAuthAliasBypassesBlockedRouteModel(t *testing.T) {
if gotModels[0] != targetModel { if gotModels[0] != targetModel {
t.Fatalf("execute model = %q, want %q", gotModels[0], targetModel) t.Fatalf("execute model = %q, want %q", gotModels[0], targetModel)
} }
gotAliases := executor.ExecuteAliases()
if len(gotAliases) != 1 {
t.Fatalf("execute aliases len = %d, want 1", len(gotAliases))
}
if gotAliases[0] != routeModel {
t.Fatalf("execute alias = %q, want %q", gotAliases[0], routeModel)
}
} }
+32
View File
@@ -2,6 +2,7 @@ package usage
import ( import (
"context" "context"
"strings"
"sync" "sync"
"time" "time"
@@ -12,6 +13,7 @@ import (
type Record struct { type Record struct {
Provider string Provider string
Model string Model string
Alias string
APIKey string APIKey string
AuthID string AuthID string
AuthIndex string AuthIndex string
@@ -32,6 +34,36 @@ type Detail struct {
TotalTokens int64 TotalTokens int64
} }
// requestedModelAliasContextKey is the unexported context key under which the
// requested model alias is stored.
type requestedModelAliasContextKey struct{}

// WithRequestedModelAlias stores the client-requested model name for usage sinks.
//
// The alias is trimmed of surrounding whitespace. When it is empty after
// trimming, the context is returned without a new value attached. A nil ctx is
// replaced by context.Background().
func WithRequestedModelAlias(ctx context.Context, alias string) context.Context {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	if trimmed := strings.TrimSpace(alias); trimmed != "" {
		return context.WithValue(parent, requestedModelAliasContextKey{}, trimmed)
	}
	return parent
}
// RequestedModelAliasFromContext returns the client-requested model name stored in ctx.
//
// The stored value is returned with surrounding whitespace trimmed. An empty
// string is returned when ctx is nil, when no value is stored under the alias
// key, or when the stored value is neither a string nor a []byte.
func RequestedModelAliasFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	stored := ctx.Value(requestedModelAliasContextKey{})
	if text, isString := stored.(string); isString {
		return strings.TrimSpace(text)
	}
	if data, isBytes := stored.([]byte); isBytes {
		return strings.TrimSpace(string(data))
	}
	return ""
}
// Plugin consumes usage records emitted by the proxy runtime. // Plugin consumes usage records emitted by the proxy runtime.
type Plugin interface { type Plugin interface {
HandleUsage(ctx context.Context, record Record) HandleUsage(ctx context.Context, record Record)