Merge pull request #163 from router-for-me/nb

fix(gemini): map responseModalities to uppercase IMAGE/TEXT
This commit is contained in:
Luis Pater
2025-10-26 22:41:18 +08:00
committed by GitHub
6 changed files with 63 additions and 15 deletions
+53 -5
View File
@@ -3,6 +3,7 @@ package executor
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@@ -72,7 +73,7 @@ func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
AuthValue: authValue, AuthValue: authValue,
}) })
wsResp, err := e.relay.RoundTrip(ctx, e.provider, wsReq) wsResp, err := e.relay.NonStream(ctx, e.provider, wsReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) recordAPIResponseError(ctx, e.cfg, err)
return resp, err return resp, err
@@ -87,7 +88,7 @@ func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
reporter.publish(ctx, parseGeminiUsage(wsResp.Body)) reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), &param) out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)} resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
return resp, nil return resp, nil
} }
@@ -156,7 +157,7 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
} }
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), &param) lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
} }
break break
} }
@@ -172,7 +173,7 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
} }
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), &param) lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), &param)
for i := range lines { for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
} }
reporter.publish(ctx, parseGeminiUsage(event.Payload)) reporter.publish(ctx, parseGeminiUsage(event.Payload))
return return
@@ -220,7 +221,7 @@ func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
AuthType: authType, AuthType: authType,
AuthValue: authValue, AuthValue: authValue,
}) })
resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq) resp, err := e.relay.NonStream(ctx, e.provider, wsReq)
if err != nil { if err != nil {
recordAPIResponseError(ctx, e.cfg, err) recordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err return cliproxyexecutor.Response{}, err
@@ -346,3 +347,50 @@ func stripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
} }
return cleaned, true return cleaned, true
} }
// ensureColonSpacedJSON normalizes a JSON document so that every key/value
// colon is followed by exactly one space while the payload is otherwise
// compact (no other insignificant whitespace).
//
// The document is validated and compacted with json.Compact rather than being
// decoded and re-encoded, so object key order, number literals (including
// integers that do not fit in a float64) and string contents are preserved
// byte-for-byte.
//
// Inputs that are empty, whitespace-only, or not a single valid JSON value
// are returned unchanged.
func ensureColonSpacedJSON(payload []byte) []byte {
	trimmed := bytes.TrimSpace(payload)
	if len(trimmed) == 0 {
		return payload
	}

	// json.Compact both validates the input and strips insignificant
	// whitespace without touching tokens.
	var compact bytes.Buffer
	compact.Grow(len(trimmed))
	if err := json.Compact(&compact, trimmed); err != nil {
		return payload
	}
	src := compact.Bytes()

	// Every ':' can add at most one byte, so this capacity is an upper bound.
	out := make([]byte, 0, len(src)+bytes.Count(src, []byte{':'}))
	inString := false
	escaped := false
	for _, ch := range src {
		out = append(out, ch)
		switch {
		case escaped:
			// The byte after a backslash is part of an escape sequence and can
			// never terminate the string, even if it is '"' or another '\\'.
			escaped = false
		case inString:
			switch ch {
			case '\\':
				escaped = true
			case '"':
				inString = false
			}
		case ch == '"':
			inString = true
		case ch == ':':
			// Outside of strings a colon is always a key/value separator.
			out = append(out, ' ')
		}
	}
	return out
}
@@ -703,7 +703,7 @@ func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
} }
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson)) rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(newPartsJson))
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["Image", "Text"]`)) rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
} }
} }
rawJSON, _ = sjson.DeleteBytes(rawJSON, "request.generationConfig.imageConfig") rawJSON, _ = sjson.DeleteBytes(rawJSON, "request.generationConfig.imageConfig")
+1 -1
View File
@@ -494,7 +494,7 @@ func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
} }
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson)) rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson))
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["Image", "Text"]`)) rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
} }
} }
rawJSON, _ = sjson.DeleteBytes(rawJSON, "generationConfig.imageConfig") rawJSON, _ = sjson.DeleteBytes(rawJSON, "generationConfig.imageConfig")
@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
} }
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities // Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
// e.g. "modalities": ["image", "text"] -> ["Image", "Text"] // e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() { if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
var responseMods []string var responseMods []string
for _, m := range mods.Array() { for _, m := range mods.Array() {
switch strings.ToLower(m.String()) { switch strings.ToLower(m.String()) {
case "text": case "text":
responseMods = append(responseMods, "Text") responseMods = append(responseMods, "TEXT")
case "image": case "image":
responseMods = append(responseMods, "Image") responseMods = append(responseMods, "IMAGE")
} }
} }
if len(responseMods) > 0 { if len(responseMods) > 0 {
@@ -66,15 +66,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} }
// Map OpenAI modalities -> Gemini generationConfig.responseModalities // Map OpenAI modalities -> Gemini generationConfig.responseModalities
// e.g. "modalities": ["image", "text"] -> ["Image", "Text"] // e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() { if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
var responseMods []string var responseMods []string
for _, m := range mods.Array() { for _, m := range mods.Array() {
switch strings.ToLower(m.String()) { switch strings.ToLower(m.String()) {
case "text": case "text":
responseMods = append(responseMods, "Text") responseMods = append(responseMods, "TEXT")
case "image": case "image":
responseMods = append(responseMods, "Image") responseMods = append(responseMods, "IMAGE")
} }
} }
if len(responseMods) > 0 { if len(responseMods) > 0 {
+2 -2
View File
@@ -35,8 +35,8 @@ type StreamEvent struct {
Err error Err error
} }
// RoundTrip executes a non-streaming HTTP request using the websocket provider. // NonStream executes a non-streaming HTTP request using the websocket provider.
func (m *Manager) RoundTrip(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) { func (m *Manager) NonStream(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) {
if req == nil { if req == nil {
return nil, fmt.Errorf("wsrelay: request is nil") return nil, fmt.Errorf("wsrelay: request is nil")
} }