Files
CLIProxyAPI/sdk/translator/registry.go
T
Longwu Ou e1e9fc43c1 fix: normalize model name in TranslateRequest fallback to prevent prefix leak
When no request translator is registered for a format pair (e.g.
        openai-response → openai-response), TranslateRequest returned the raw
        payload unchanged. This caused client-side model prefixes (e.g.
        "copilot/gpt-5-mini") to leak into upstream requests, resulting in
        "The requested model is not supported" errors from providers.

        The fallback path now updates the "model" field in the payload to
        match the resolved model name before returning.
2026-03-18 12:30:22 -04:00

153 lines
5.1 KiB
Go

package translator
import (
"context"
"sync"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Registry manages translation functions across schemas.
//
// Both tables are two-level maps indexed first by the from format and then by
// the to format that were passed to Register.
type Registry struct {
// mu guards requests and responses; the methods in this file acquire it
// before touching either map.
mu sync.RWMutex
// requests holds the request transforms stored by Register.
requests map[Format]map[Format]RequestTransform
// responses holds the response transforms stored by Register.
responses map[Format]map[Format]ResponseTransform
}
// NewRegistry returns a Registry with no transforms registered. Both lookup
// tables are allocated up front and start out empty.
func NewRegistry() *Registry {
	reg := new(Registry)
	reg.requests = map[Format]map[Format]RequestTransform{}
	reg.responses = map[Format]map[Format]ResponseTransform{}
	return reg
}
// Register stores the request and response transforms for the from/to format
// pair.
//
// A nil request leaves any request transform previously registered for the
// pair untouched. response is always stored, replacing any earlier entry for
// the pair.
//
// The lookup tables are created on demand, so Register also works on a
// zero-value Registry that was not built with NewRegistry, instead of
// panicking on a write to a nil map.
func (r *Registry) Register(from, to Format, request RequestTransform, response ResponseTransform) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Lazily create the top-level tables; a nil map accepts reads but not writes.
	if r.requests == nil {
		r.requests = make(map[Format]map[Format]RequestTransform)
	}
	if r.responses == nil {
		r.responses = make(map[Format]map[Format]ResponseTransform)
	}

	if r.requests[from] == nil {
		r.requests[from] = make(map[Format]RequestTransform)
	}
	if request != nil {
		r.requests[from][to] = request
	}

	if r.responses[from] == nil {
		r.responses[from] = make(map[Format]ResponseTransform)
	}
	r.responses[from][to] = response
}
// TranslateRequest converts a request payload from one schema to another.
//
// When a non-nil request transform is registered for the from/to pair, the
// transform's result is returned as-is. Otherwise the original payload is
// returned, except that its "model" field is rewritten to model when model is
// non-empty and differs from the payload's current value, so that client-side
// prefixes (e.g. "copilot/gpt-5-mini") are not leaked upstream.
//
// The transform is looked up under the read lock but invoked only after the
// lock has been released. sync.RWMutex prohibits recursive read locking, so a
// transform that re-entered the registry while a Register call was waiting
// could otherwise deadlock, and a slow transform would stall pending Register
// calls.
func (r *Registry) TranslateRequest(from, to Format, model string, rawJSON []byte, stream bool) []byte {
	r.mu.RLock()
	// Indexing a missing (nil) inner map yields the zero value, i.e. a nil transform.
	fn := r.requests[from][to]
	r.mu.RUnlock()

	if fn != nil {
		return fn(model, rawJSON, stream)
	}

	// Fallback: no translator registered. Leave the payload alone when there
	// is nothing to normalize.
	if model == "" || gjson.GetBytes(rawJSON, "model").String() == model {
		return rawJSON
	}

	// Best effort: if the payload cannot be rewritten, pass it through unchanged.
	// NOTE(review): sjson appears to create the key when it is absent, so a
	// payload without a "model" field would gain one here — confirm intended.
	updated, err := sjson.SetBytes(rawJSON, "model", model)
	if err != nil {
		return rawJSON
	}
	return updated
}
// HasResponseTransformer reports whether a response transform entry has been
// stored for the from/to pair, using the same from/to orientation as Register.
// Only the presence of the entry is checked, not whether any of its callbacks
// are set.
func (r *Registry) HasResponseTransformer(from, to Format) bool {
	r.mu.RLock()
	defer r.mu.RUnlock()

	// Indexing a missing inner map is safe and simply reports "not found".
	_, registered := r.responses[from][to]
	return registered
}
// TranslateStream applies the registered streaming response translator to
// rawJSON and returns the translated chunks.
//
// The transform is looked up as responses[to][from], i.e. the entry that was
// stored by Register(to, from, ...). If no entry exists, or its Stream callback
// is nil, the raw payload is returned as a single-element slice.
//
// The callback is invoked after the read lock has been released: sync.RWMutex
// prohibits recursive read locking, so a callback that re-entered the registry
// while a Register call was waiting could otherwise deadlock, and a slow
// callback would stall pending Register calls.
func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	r.mu.RLock()
	transform, found := r.responses[to][from]
	r.mu.RUnlock()

	if found && transform.Stream != nil {
		return transform.Stream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
	}
	return []string{string(rawJSON)}
}
// TranslateNonStream applies the registered non-stream response translator to
// rawJSON and returns the translated payload.
//
// The transform is looked up as responses[to][from], i.e. the entry that was
// stored by Register(to, from, ...). If no entry exists, or its NonStream
// callback is nil, the raw payload is returned unchanged as a string.
//
// The callback is invoked after the read lock has been released: sync.RWMutex
// prohibits recursive read locking, so a callback that re-entered the registry
// while a Register call was waiting could otherwise deadlock, and a slow
// callback would stall pending Register calls.
func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	r.mu.RLock()
	transform, found := r.responses[to][from]
	r.mu.RUnlock()

	if found && transform.NonStream != nil {
		return transform.NonStream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
	}
	return string(rawJSON)
}
// TranslateTokenCount applies the registered token-count response translator
// to count and returns its result.
//
// The transform is looked up as responses[to][from], i.e. the entry that was
// stored by Register(to, from, ...). If no entry exists, or its TokenCount
// callback is nil, rawJSON is returned unchanged as a string.
//
// The callback is invoked after the read lock has been released: sync.RWMutex
// prohibits recursive read locking, so a callback that re-entered the registry
// while a Register call was waiting could otherwise deadlock, and a slow
// callback would stall pending Register calls.
func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
	r.mu.RLock()
	transform, found := r.responses[to][from]
	r.mu.RUnlock()

	if found && transform.TokenCount != nil {
		return transform.TokenCount(ctx, count)
	}
	return string(rawJSON)
}
// defaultRegistry is the package-level registry, created once when the
// package is initialized.
var defaultRegistry = NewRegistry()
// Default exposes the package-level registry for shared use.
// It returns the defaultRegistry pointer itself, not a copy.
func Default() *Registry {
return defaultRegistry
}
// Register adds a request/response transform pair to the package-level
// default registry. It is shorthand for calling Register on Default().
func Register(from, to Format, request RequestTransform, response ResponseTransform) {
	Default().Register(from, to, request, response)
}
// TranslateRequest translates a request payload using the package-level
// default registry. See Registry.TranslateRequest for the fallback behaviour
// when no translator is registered.
func TranslateRequest(from, to Format, model string, rawJSON []byte, stream bool) []byte {
	reg := Default()
	return reg.TranslateRequest(from, to, model, rawJSON, stream)
}
// HasResponseTransformer reports whether the package-level default registry
// holds a response transform entry for the from/to pair.
func HasResponseTransformer(from, to Format) bool {
	reg := Default()
	return reg.HasResponseTransformer(from, to)
}
// TranslateStream translates a streaming response chunk using the
// package-level default registry.
func TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	reg := Default()
	return reg.TranslateStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
// TranslateNonStream translates a complete (non-streaming) response using the
// package-level default registry.
func TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	reg := Default()
	return reg.TranslateNonStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
// TranslateTokenCount translates a token-count response using the
// package-level default registry.
func TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
	reg := Default()
	return reg.TranslateTokenCount(ctx, from, to, count, rawJSON)
}