refactor(thinking): add Gemini family provider grouping for strict validation

This commit is contained in:
hkfires
2026-01-18 11:30:53 +08:00
parent c7e8830a56
commit 03005b5d29
4 changed files with 230 additions and 109 deletions
+15 -47
View File
@@ -38,31 +38,15 @@ func (r *Registry) Register(from, to Format, request RequestTransform, response
r.responses[from][to] = response
}
// formatAliases returns compatible aliases for a format, ordered by preference.
func formatAliases(format Format) []Format {
switch format {
case "codex":
return []Format{"codex", "openai-response"}
case "openai-response":
return []Format{"openai-response", "codex"}
default:
return []Format{format}
}
}
// TranslateRequest converts a payload between schemas, returning the original payload
// if no translator is registered.
func (r *Registry) TranslateRequest(from, to Format, model string, rawJSON []byte, stream bool) []byte {
r.mu.RLock()
defer r.mu.RUnlock()
for _, fromFormat := range formatAliases(from) {
if byTarget, ok := r.requests[fromFormat]; ok {
for _, toFormat := range formatAliases(to) {
if fn, isOk := byTarget[toFormat]; isOk && fn != nil {
return fn(model, rawJSON, stream)
}
}
if byTarget, ok := r.requests[from]; ok {
if fn, isOk := byTarget[to]; isOk && fn != nil {
return fn(model, rawJSON, stream)
}
}
return rawJSON
@@ -73,13 +57,9 @@ func (r *Registry) HasResponseTransformer(from, to Format) bool {
r.mu.RLock()
defer r.mu.RUnlock()
for _, toFormat := range formatAliases(to) {
if byTarget, ok := r.responses[toFormat]; ok {
for _, fromFormat := range formatAliases(from) {
if _, isOk := byTarget[fromFormat]; isOk {
return true
}
}
if byTarget, ok := r.responses[from]; ok {
if _, isOk := byTarget[to]; isOk {
return true
}
}
return false
@@ -90,13 +70,9 @@ func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model s
r.mu.RLock()
defer r.mu.RUnlock()
for _, toFormat := range formatAliases(to) {
if byTarget, ok := r.responses[toFormat]; ok {
for _, fromFormat := range formatAliases(from) {
if fn, isOk := byTarget[fromFormat]; isOk && fn.Stream != nil {
return fn.Stream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
}
if byTarget, ok := r.responses[to]; ok {
if fn, isOk := byTarget[from]; isOk && fn.Stream != nil {
return fn.Stream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
}
return []string{string(rawJSON)}
@@ -107,13 +83,9 @@ func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, mode
r.mu.RLock()
defer r.mu.RUnlock()
for _, toFormat := range formatAliases(to) {
if byTarget, ok := r.responses[toFormat]; ok {
for _, fromFormat := range formatAliases(from) {
if fn, isOk := byTarget[fromFormat]; isOk && fn.NonStream != nil {
return fn.NonStream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
}
if byTarget, ok := r.responses[to]; ok {
if fn, isOk := byTarget[from]; isOk && fn.NonStream != nil {
return fn.NonStream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
}
return string(rawJSON)
@@ -124,13 +96,9 @@ func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, cou
r.mu.RLock()
defer r.mu.RUnlock()
for _, toFormat := range formatAliases(to) {
if byTarget, ok := r.responses[toFormat]; ok {
for _, fromFormat := range formatAliases(from) {
if fn, isOk := byTarget[fromFormat]; isOk && fn.TokenCount != nil {
return fn.TokenCount(ctx, count)
}
}
if byTarget, ok := r.responses[to]; ok {
if fn, isOk := byTarget[from]; isOk && fn.TokenCount != nil {
return fn.TokenCount(ctx, count)
}
}
return string(rawJSON)