- Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket upgrade handlers on the Gin engine.
- Track registered WS paths via wsRoutes with wsRouteMu to prevent duplicate registrations; initialize in NewServer and import sync.
- Add Manager.UnregisterExecutor(provider) for clean executor lifecycle management.
- Add the github.com/gorilla/websocket v1.5.3 dependency and update go.sum.

Motivation: enable services to expose WS endpoints through the core server and allow removing auth executors dynamically while avoiding duplicate route setup. No breaking changes.
265 lines
9.2 KiB
Go
265 lines
9.2 KiB
Go
package executor
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
// AistudioExecutor routes AI Studio requests through a websocket-backed transport.
|
|
type AistudioExecutor struct {
|
|
provider string
|
|
relay *wsrelay.Manager
|
|
cfg *config.Config
|
|
}
|
|
|
|
// NewAistudioExecutor constructs a websocket executor for the provider name.
|
|
func NewAistudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AistudioExecutor {
|
|
return &AistudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
|
|
}
|
|
|
|
// Identifier returns the provider key served by this executor.
|
|
func (e *AistudioExecutor) Identifier() string { return e.provider }
|
|
|
|
// PrepareRequest is a no-op because websocket transport already injects headers.
|
|
func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
|
return nil
|
|
}
|
|
|
|
func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
translatedReq, body, err := e.translateRequest(req, opts, false)
|
|
if err != nil {
|
|
return cliproxyexecutor.Response{}, err
|
|
}
|
|
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
|
wsReq := &wsrelay.HTTPRequest{
|
|
Method: http.MethodPost,
|
|
URL: endpoint,
|
|
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
|
Body: body.payload,
|
|
}
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: endpoint,
|
|
Method: http.MethodPost,
|
|
Headers: wsReq.Headers.Clone(),
|
|
Body: bytes.Clone(body.payload),
|
|
Provider: e.provider,
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq)
|
|
if err != nil {
|
|
recordAPIResponseError(ctx, e.cfg, err)
|
|
return cliproxyexecutor.Response{}, err
|
|
}
|
|
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
|
if len(resp.Body) > 0 {
|
|
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
|
|
}
|
|
if resp.Status < 200 || resp.Status >= 300 {
|
|
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
|
|
}
|
|
var param any
|
|
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), ¶m)
|
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
|
}
|
|
|
|
func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
|
translatedReq, body, err := e.translateRequest(req, opts, true)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
|
wsReq := &wsrelay.HTTPRequest{
|
|
Method: http.MethodPost,
|
|
URL: endpoint,
|
|
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
|
Body: body.payload,
|
|
}
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: endpoint,
|
|
Method: http.MethodPost,
|
|
Headers: wsReq.Headers.Clone(),
|
|
Body: bytes.Clone(body.payload),
|
|
Provider: e.provider,
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
stream, err := e.relay.Stream(ctx, e.provider, wsReq)
|
|
if err != nil {
|
|
recordAPIResponseError(ctx, e.cfg, err)
|
|
return nil, err
|
|
}
|
|
out := make(chan cliproxyexecutor.StreamChunk)
|
|
go func() {
|
|
defer close(out)
|
|
var param any
|
|
metadataLogged := false
|
|
for event := range stream {
|
|
if event.Err != nil {
|
|
recordAPIResponseError(ctx, e.cfg, event.Err)
|
|
out <- cliproxyexecutor.StreamChunk{Err: event.Err}
|
|
return
|
|
}
|
|
switch event.Type {
|
|
case wsrelay.MessageTypeStreamStart:
|
|
if !metadataLogged && event.Status > 0 {
|
|
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
|
metadataLogged = true
|
|
}
|
|
case wsrelay.MessageTypeStreamChunk:
|
|
if len(event.Payload) > 0 {
|
|
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
|
}
|
|
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m)
|
|
for i := range lines {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
|
}
|
|
case wsrelay.MessageTypeStreamEnd:
|
|
return
|
|
case wsrelay.MessageTypeHTTPResp:
|
|
if !metadataLogged && event.Status > 0 {
|
|
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
|
metadataLogged = true
|
|
}
|
|
if len(event.Payload) > 0 {
|
|
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
|
}
|
|
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m)
|
|
for i := range lines {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
|
}
|
|
return
|
|
case wsrelay.MessageTypeError:
|
|
recordAPIResponseError(ctx, e.cfg, event.Err)
|
|
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
return out, nil
|
|
}
|
|
|
|
func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
translatedReq, body, err := e.translateRequest(req, opts, false)
|
|
if err != nil {
|
|
return cliproxyexecutor.Response{}, err
|
|
}
|
|
endpoint := e.buildEndpoint(req.Model, "countTokens", "")
|
|
wsReq := &wsrelay.HTTPRequest{
|
|
Method: http.MethodPost,
|
|
URL: endpoint,
|
|
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
|
Body: body.payload,
|
|
}
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: endpoint,
|
|
Method: http.MethodPost,
|
|
Headers: wsReq.Headers.Clone(),
|
|
Body: bytes.Clone(body.payload),
|
|
Provider: e.provider,
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq)
|
|
if err != nil {
|
|
recordAPIResponseError(ctx, e.cfg, err)
|
|
return cliproxyexecutor.Response{}, err
|
|
}
|
|
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
|
if len(resp.Body) > 0 {
|
|
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
|
|
}
|
|
if resp.Status < 200 || resp.Status >= 300 {
|
|
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
|
|
}
|
|
var param any
|
|
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), ¶m)
|
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
|
}
|
|
|
|
func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
|
_ = ctx
|
|
return auth, nil
|
|
}
|
|
|
|
type translatedPayload struct {
|
|
payload []byte
|
|
action string
|
|
toFormat sdktranslator.Format
|
|
}
|
|
|
|
func (e *AistudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini")
|
|
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
|
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok {
|
|
payload = util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
|
|
}
|
|
payload = disableGeminiThinkingConfig(payload, req.Model)
|
|
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
|
metadataAction := "generateContent"
|
|
if req.Metadata != nil {
|
|
if action, _ := req.Metadata["action"].(string); action == "countTokens" {
|
|
metadataAction = action
|
|
}
|
|
}
|
|
action := metadataAction
|
|
if stream && action != "countTokens" {
|
|
action = "streamGenerateContent"
|
|
}
|
|
payload, _ = sjson.DeleteBytes(payload, "session_id")
|
|
return payload, translatedPayload{payload: payload, action: action, toFormat: to}, nil
|
|
}
|
|
|
|
func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string {
|
|
base := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action)
|
|
if action == "streamGenerateContent" {
|
|
if alt == "" {
|
|
return base + "?alt=sse"
|
|
}
|
|
return base + "?$alt=" + url.QueryEscape(alt)
|
|
}
|
|
if alt != "" && action != "countTokens" {
|
|
return base + "?$alt=" + url.QueryEscape(alt)
|
|
}
|
|
return base
|
|
}
|