Merge pull request #488 from router-for-me/gemini
Unify the Gemini executor style
This commit is contained in:
@@ -1,3 +1,6 @@
|
|||||||
|
// Package executor provides runtime execution capabilities for various AI service providers.
|
||||||
|
// This file implements the AI Studio executor that routes requests through a websocket-backed
|
||||||
|
// transport for the AI Studio provider.
|
||||||
package executor
|
package executor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -26,19 +29,28 @@ type AIStudioExecutor struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAIStudioExecutor constructs a websocket executor for the provider name.
|
// NewAIStudioExecutor creates a new AI Studio executor instance.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
// - provider: The provider name
|
||||||
|
// - relay: The websocket relay manager
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *AIStudioExecutor: A new AI Studio executor instance
|
||||||
func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AIStudioExecutor {
|
func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AIStudioExecutor {
|
||||||
return &AIStudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
|
return &AIStudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Identifier returns the logical provider key for routing.
|
// Identifier returns the executor identifier.
|
||||||
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
|
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
|
||||||
|
|
||||||
// PrepareRequest is a no-op because websocket transport already injects headers.
|
// PrepareRequest prepares the HTTP request for execution (no-op for AI Studio).
|
||||||
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute performs a non-streaming request to the AI Studio API.
|
||||||
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
@@ -92,6 +104,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecuteStream performs a streaming request to the AI Studio API.
|
||||||
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
@@ -239,6 +252,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
|||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountTokens counts tokens for the given request using the AI Studio API.
|
||||||
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
_, body, err := e.translateRequest(req, opts, false)
|
_, body, err := e.translateRequest(req, opts, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -293,8 +307,8 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *AIStudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
// Refresh refreshes the authentication credentials (no-op for AI Studio).
|
||||||
_ = ctx
|
func (e *AIStudioExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Package executor provides runtime execution capabilities for various AI service providers.
|
||||||
|
// This file implements the Antigravity executor that proxies requests to the antigravity
|
||||||
|
// upstream using OAuth credentials.
|
||||||
package executor
|
package executor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -38,7 +41,6 @@ const (
|
|||||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||||
antigravityAuthType = "antigravity"
|
antigravityAuthType = "antigravity"
|
||||||
refreshSkew = 3000 * time.Second
|
refreshSkew = 3000 * time.Second
|
||||||
streamScannerBuffer int = 52_428_800 // 50MB
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
@@ -48,18 +50,24 @@ type AntigravityExecutor struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAntigravityExecutor constructs a new executor instance.
|
// NewAntigravityExecutor creates a new Antigravity executor instance.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *AntigravityExecutor: A new Antigravity executor instance
|
||||||
func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor {
|
func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor {
|
||||||
return &AntigravityExecutor{cfg: cfg}
|
return &AntigravityExecutor{cfg: cfg}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Identifier implements ProviderExecutor.
|
// Identifier returns the executor identifier.
|
||||||
func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType }
|
func (e *AntigravityExecutor) Identifier() string { return antigravityAuthType }
|
||||||
|
|
||||||
// PrepareRequest implements ProviderExecutor.
|
// PrepareRequest prepares the HTTP request for execution (no-op for Antigravity).
|
||||||
func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
||||||
|
|
||||||
// Execute handles non-streaming requests via the antigravity generate endpoint.
|
// Execute performs a non-streaming request to the Antigravity API.
|
||||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||||
if errToken != nil {
|
if errToken != nil {
|
||||||
@@ -152,7 +160,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream handles streaming requests via the antigravity upstream.
|
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
ctx = context.WithValue(ctx, "alt", "")
|
ctx = context.WithValue(ctx, "alt", "")
|
||||||
|
|
||||||
@@ -292,7 +300,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh refreshes the OAuth token using the refresh token.
|
// Refresh refreshes the authentication credentials using the refresh token.
|
||||||
func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
return auth, nil
|
return auth, nil
|
||||||
@@ -304,7 +312,7 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au
|
|||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens is not supported for the antigravity provider.
|
// CountTokens counts tokens for the given request (not supported for Antigravity).
|
||||||
func (e *AntigravityExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *AntigravityExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported"}
|
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
// Package executor provides runtime execution capabilities for various AI service providers.
|
||||||
|
// This file implements the Gemini CLI executor that talks to Cloud Code Assist endpoints
|
||||||
|
// using OAuth credentials from auth metadata.
|
||||||
package executor
|
package executor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -29,11 +32,11 @@ import (
|
|||||||
const (
|
const (
|
||||||
codeAssistEndpoint = "https://cloudcode-pa.googleapis.com"
|
codeAssistEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||||
codeAssistVersion = "v1internal"
|
codeAssistVersion = "v1internal"
|
||||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
geminiOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||||
geminiOauthClientSecret = "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"
|
geminiOAuthClientSecret = "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"
|
||||||
)
|
)
|
||||||
|
|
||||||
var geminiOauthScopes = []string{
|
var geminiOAuthScopes = []string{
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
"https://www.googleapis.com/auth/cloud-platform",
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
"https://www.googleapis.com/auth/userinfo.email",
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
"https://www.googleapis.com/auth/userinfo.profile",
|
||||||
@@ -44,14 +47,24 @@ type GeminiCLIExecutor struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewGeminiCLIExecutor creates a new Gemini CLI executor instance.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *GeminiCLIExecutor: A new Gemini CLI executor instance
|
||||||
func NewGeminiCLIExecutor(cfg *config.Config) *GeminiCLIExecutor {
|
func NewGeminiCLIExecutor(cfg *config.Config) *GeminiCLIExecutor {
|
||||||
return &GeminiCLIExecutor{cfg: cfg}
|
return &GeminiCLIExecutor{cfg: cfg}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Identifier returns the executor identifier.
|
||||||
func (e *GeminiCLIExecutor) Identifier() string { return "gemini-cli" }
|
func (e *GeminiCLIExecutor) Identifier() string { return "gemini-cli" }
|
||||||
|
|
||||||
|
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini CLI).
|
||||||
func (e *GeminiCLIExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
func (e *GeminiCLIExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
||||||
|
|
||||||
|
// Execute performs a non-streaming request to the Gemini CLI API.
|
||||||
func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -189,6 +202,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
|||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecuteStream performs a streaming request to the Gemini CLI API.
|
||||||
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -309,7 +323,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
}()
|
}()
|
||||||
if opts.Alt == "" {
|
if opts.Alt == "" {
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Buffer(nil, 52_428_800) // 50MB
|
scanner.Buffer(nil, streamScannerBuffer)
|
||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
@@ -371,6 +385,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountTokens counts tokens for the given request using the Gemini CLI API.
|
||||||
func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -471,9 +486,8 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
|||||||
return cliproxyexecutor.Response{}, newGeminiStatusErr(lastStatus, lastBody)
|
return cliproxyexecutor.Response{}, newGeminiStatusErr(lastStatus, lastBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *GeminiCLIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
// Refresh refreshes the authentication credentials (no-op for Gemini CLI).
|
||||||
log.Debugf("gemini cli executor: refresh called")
|
func (e *GeminiCLIExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
_ = ctx
|
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -515,9 +529,9 @@ func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *
|
|||||||
}
|
}
|
||||||
|
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: geminiOauthClientID,
|
ClientID: geminiOAuthClientID,
|
||||||
ClientSecret: geminiOauthClientSecret,
|
ClientSecret: geminiOAuthClientSecret,
|
||||||
Scopes: geminiOauthScopes,
|
Scopes: geminiOAuthScopes,
|
||||||
Endpoint: google.Endpoint,
|
Endpoint: google.Endpoint,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
@@ -21,8 +20,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"golang.org/x/oauth2/google"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -31,6 +28,9 @@ const (
|
|||||||
|
|
||||||
// glAPIVersion is the API version used for Gemini requests.
|
// glAPIVersion is the API version used for Gemini requests.
|
||||||
glAPIVersion = "v1beta"
|
glAPIVersion = "v1beta"
|
||||||
|
|
||||||
|
// streamScannerBuffer is the buffer size for SSE stream scanning.
|
||||||
|
streamScannerBuffer = 52_428_800
|
||||||
)
|
)
|
||||||
|
|
||||||
// GeminiExecutor is a stateless executor for the official Gemini API using API keys.
|
// GeminiExecutor is a stateless executor for the official Gemini API using API keys.
|
||||||
@@ -48,9 +48,11 @@ type GeminiExecutor struct {
|
|||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *GeminiExecutor: A new Gemini executor instance
|
// - *GeminiExecutor: A new Gemini executor instance
|
||||||
func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor { return &GeminiExecutor{cfg: cfg} }
|
func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor {
|
||||||
|
return &GeminiExecutor{cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
// Identifier returns the executor identifier for Gemini.
|
// Identifier returns the executor identifier.
|
||||||
func (e *GeminiExecutor) Identifier() string { return "gemini" }
|
func (e *GeminiExecutor) Identifier() string { return "gemini" }
|
||||||
|
|
||||||
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini).
|
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini).
|
||||||
@@ -164,6 +166,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecuteStream performs a streaming request to the Gemini API.
|
||||||
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
apiKey, bearer := geminiCreds(auth)
|
apiKey, bearer := geminiCreds(auth)
|
||||||
|
|
||||||
@@ -249,7 +252,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
scanner := bufio.NewScanner(httpResp.Body)
|
scanner := bufio.NewScanner(httpResp.Body)
|
||||||
scanner.Buffer(nil, 52_428_800) // 50MB
|
scanner.Buffer(nil, streamScannerBuffer)
|
||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
@@ -280,6 +283,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountTokens counts tokens for the given request using the Gemini API.
|
||||||
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
apiKey, bearer := geminiCreds(auth)
|
apiKey, bearer := geminiCreds(auth)
|
||||||
|
|
||||||
@@ -353,106 +357,8 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
|||||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *GeminiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
// Refresh refreshes the authentication credentials (no-op for Gemini API key).
|
||||||
log.Debugf("gemini executor: refresh called")
|
func (e *GeminiExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
// OAuth bearer token refresh for official Gemini API.
|
|
||||||
if auth == nil {
|
|
||||||
return nil, fmt.Errorf("gemini executor: auth is nil")
|
|
||||||
}
|
|
||||||
if auth.Metadata == nil {
|
|
||||||
return auth, nil
|
|
||||||
}
|
|
||||||
// Token data is typically nested under "token" map in Gemini files.
|
|
||||||
tokenMap, _ := auth.Metadata["token"].(map[string]any)
|
|
||||||
var refreshToken, accessToken, clientID, clientSecret, tokenURI, expiryStr string
|
|
||||||
if tokenMap != nil {
|
|
||||||
if v, ok := tokenMap["refresh_token"].(string); ok {
|
|
||||||
refreshToken = v
|
|
||||||
}
|
|
||||||
if v, ok := tokenMap["access_token"].(string); ok {
|
|
||||||
accessToken = v
|
|
||||||
}
|
|
||||||
if v, ok := tokenMap["client_id"].(string); ok {
|
|
||||||
clientID = v
|
|
||||||
}
|
|
||||||
if v, ok := tokenMap["client_secret"].(string); ok {
|
|
||||||
clientSecret = v
|
|
||||||
}
|
|
||||||
if v, ok := tokenMap["token_uri"].(string); ok {
|
|
||||||
tokenURI = v
|
|
||||||
}
|
|
||||||
if v, ok := tokenMap["expiry"].(string); ok {
|
|
||||||
expiryStr = v
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Fallback to top-level keys if present
|
|
||||||
if v, ok := auth.Metadata["refresh_token"].(string); ok {
|
|
||||||
refreshToken = v
|
|
||||||
}
|
|
||||||
if v, ok := auth.Metadata["access_token"].(string); ok {
|
|
||||||
accessToken = v
|
|
||||||
}
|
|
||||||
if v, ok := auth.Metadata["client_id"].(string); ok {
|
|
||||||
clientID = v
|
|
||||||
}
|
|
||||||
if v, ok := auth.Metadata["client_secret"].(string); ok {
|
|
||||||
clientSecret = v
|
|
||||||
}
|
|
||||||
if v, ok := auth.Metadata["token_uri"].(string); ok {
|
|
||||||
tokenURI = v
|
|
||||||
}
|
|
||||||
if v, ok := auth.Metadata["expiry"].(string); ok {
|
|
||||||
expiryStr = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if refreshToken == "" {
|
|
||||||
// Nothing to do for API key or cookie based entries
|
|
||||||
return auth, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare oauth2 config; default to Google endpoints
|
|
||||||
endpoint := google.Endpoint
|
|
||||||
if tokenURI != "" {
|
|
||||||
endpoint.TokenURL = tokenURI
|
|
||||||
}
|
|
||||||
conf := &oauth2.Config{ClientID: clientID, ClientSecret: clientSecret, Endpoint: endpoint}
|
|
||||||
|
|
||||||
// Ensure proxy-aware HTTP client for token refresh
|
|
||||||
httpClient := util.SetProxy(&e.cfg.SDKConfig, &http.Client{})
|
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
|
||||||
|
|
||||||
// Build base token
|
|
||||||
tok := &oauth2.Token{AccessToken: accessToken, RefreshToken: refreshToken}
|
|
||||||
if t, err := time.Parse(time.RFC3339, expiryStr); err == nil {
|
|
||||||
tok.Expiry = t
|
|
||||||
}
|
|
||||||
newTok, err := conf.TokenSource(ctx, tok).Token()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Persist back to metadata; prefer nested token map if present
|
|
||||||
if tokenMap == nil {
|
|
||||||
tokenMap = make(map[string]any)
|
|
||||||
}
|
|
||||||
tokenMap["access_token"] = newTok.AccessToken
|
|
||||||
tokenMap["refresh_token"] = newTok.RefreshToken
|
|
||||||
tokenMap["expiry"] = newTok.Expiry.Format(time.RFC3339)
|
|
||||||
if clientID != "" {
|
|
||||||
tokenMap["client_id"] = clientID
|
|
||||||
}
|
|
||||||
if clientSecret != "" {
|
|
||||||
tokenMap["client_secret"] = clientSecret
|
|
||||||
}
|
|
||||||
if tokenURI != "" {
|
|
||||||
tokenMap["token_uri"] = tokenURI
|
|
||||||
}
|
|
||||||
auth.Metadata["token"] = tokenMap
|
|
||||||
|
|
||||||
// Also mirror top-level access_token for compatibility if previously present
|
|
||||||
if _, ok := auth.Metadata["access_token"]; ok {
|
|
||||||
auth.Metadata["access_token"] = newTok.AccessToken
|
|
||||||
}
|
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// Package executor contains provider executors. This file implements the Vertex AI
|
// Package executor provides runtime execution capabilities for various AI service providers.
|
||||||
// Gemini executor that talks to Google Vertex AI endpoints using service account
|
// This file implements the Vertex AI Gemini executor that talks to Google Vertex AI
|
||||||
// credentials imported by the CLI.
|
// endpoints using service account credentials or API keys.
|
||||||
package executor
|
package executor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -36,20 +36,26 @@ type GeminiVertexExecutor struct {
|
|||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGeminiVertexExecutor constructs the Vertex executor.
|
// NewGeminiVertexExecutor creates a new Vertex AI Gemini executor instance.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *GeminiVertexExecutor: A new Vertex AI Gemini executor instance
|
||||||
func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
|
func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
|
||||||
return &GeminiVertexExecutor{cfg: cfg}
|
return &GeminiVertexExecutor{cfg: cfg}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Identifier returns provider key for manager routing.
|
// Identifier returns the executor identifier.
|
||||||
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
|
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
|
||||||
|
|
||||||
// PrepareRequest is a no-op for Vertex.
|
// PrepareRequest prepares the HTTP request for execution (no-op for Vertex).
|
||||||
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute handles non-streaming requests.
|
// Execute performs a non-streaming request to the Vertex AI API.
|
||||||
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
// Try API key authentication first
|
// Try API key authentication first
|
||||||
apiKey, baseURL := vertexAPICreds(auth)
|
apiKey, baseURL := vertexAPICreds(auth)
|
||||||
@@ -67,7 +73,7 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecuteStream handles SSE streaming for Vertex.
|
// ExecuteStream performs a streaming request to the Vertex AI API.
|
||||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
// Try API key authentication first
|
// Try API key authentication first
|
||||||
apiKey, baseURL := vertexAPICreds(auth)
|
apiKey, baseURL := vertexAPICreds(auth)
|
||||||
@@ -85,7 +91,7 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens calls Vertex countTokens endpoint.
|
// CountTokens counts tokens for the given request using the Vertex AI API.
|
||||||
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
// Try API key authentication first
|
// Try API key authentication first
|
||||||
apiKey, baseURL := vertexAPICreds(auth)
|
apiKey, baseURL := vertexAPICreds(auth)
|
||||||
@@ -103,185 +109,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
|||||||
return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// countTokensWithServiceAccount handles token counting using service account credentials.
|
// Refresh refreshes the authentication credentials (no-op for Vertex).
|
||||||
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
|
|
||||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
|
||||||
|
|
||||||
from := opts.SourceFormat
|
|
||||||
to := sdktranslator.FromString("gemini")
|
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
||||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
||||||
if budgetOverride != nil {
|
|
||||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
||||||
budgetOverride = &norm
|
|
||||||
}
|
|
||||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
|
||||||
}
|
|
||||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
|
||||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
|
||||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
|
||||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
|
||||||
|
|
||||||
baseURL := vertexBaseURL(location)
|
|
||||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
|
|
||||||
|
|
||||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
|
||||||
if errNewReq != nil {
|
|
||||||
return cliproxyexecutor.Response{}, errNewReq
|
|
||||||
}
|
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
|
||||||
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
|
||||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
|
||||||
} else if errTok != nil {
|
|
||||||
log.Errorf("vertex executor: access token error: %v", errTok)
|
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
|
||||||
}
|
|
||||||
applyGeminiHeaders(httpReq, auth)
|
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
|
||||||
if auth != nil {
|
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
|
||||||
authType, authValue = auth.AccountInfo()
|
|
||||||
}
|
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
||||||
URL: url,
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Headers: httpReq.Header.Clone(),
|
|
||||||
Body: translatedReq,
|
|
||||||
Provider: e.Identifier(),
|
|
||||||
AuthID: authID,
|
|
||||||
AuthLabel: authLabel,
|
|
||||||
AuthType: authType,
|
|
||||||
AuthValue: authValue,
|
|
||||||
})
|
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
|
||||||
if errDo != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
||||||
return cliproxyexecutor.Response{}, errDo
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
||||||
}
|
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
|
||||||
if errRead != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
||||||
return cliproxyexecutor.Response{}, errRead
|
|
||||||
}
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
|
||||||
}
|
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
|
||||||
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
|
|
||||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
|
||||||
|
|
||||||
from := opts.SourceFormat
|
|
||||||
to := sdktranslator.FromString("gemini")
|
|
||||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
||||||
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
||||||
if budgetOverride != nil {
|
|
||||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
||||||
budgetOverride = &norm
|
|
||||||
}
|
|
||||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
|
||||||
}
|
|
||||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
|
||||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
|
||||||
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
|
||||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
|
||||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
|
||||||
|
|
||||||
// For API key auth, use simpler URL format without project/location
|
|
||||||
if baseURL == "" {
|
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
|
||||||
}
|
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
|
|
||||||
|
|
||||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
|
||||||
if errNewReq != nil {
|
|
||||||
return cliproxyexecutor.Response{}, errNewReq
|
|
||||||
}
|
|
||||||
httpReq.Header.Set("Content-Type", "application/json")
|
|
||||||
if apiKey != "" {
|
|
||||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
|
||||||
}
|
|
||||||
applyGeminiHeaders(httpReq, auth)
|
|
||||||
|
|
||||||
var authID, authLabel, authType, authValue string
|
|
||||||
if auth != nil {
|
|
||||||
authID = auth.ID
|
|
||||||
authLabel = auth.Label
|
|
||||||
authType, authValue = auth.AccountInfo()
|
|
||||||
}
|
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
||||||
URL: url,
|
|
||||||
Method: http.MethodPost,
|
|
||||||
Headers: httpReq.Header.Clone(),
|
|
||||||
Body: translatedReq,
|
|
||||||
Provider: e.Identifier(),
|
|
||||||
AuthID: authID,
|
|
||||||
AuthLabel: authLabel,
|
|
||||||
AuthType: authType,
|
|
||||||
AuthValue: authValue,
|
|
||||||
})
|
|
||||||
|
|
||||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
||||||
httpResp, errDo := httpClient.Do(httpReq)
|
|
||||||
if errDo != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
||||||
return cliproxyexecutor.Response{}, errDo
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
||||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
||||||
b, _ := io.ReadAll(httpResp.Body)
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
||||||
}
|
|
||||||
data, errRead := io.ReadAll(httpResp.Body)
|
|
||||||
if errRead != nil {
|
|
||||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
||||||
return cliproxyexecutor.Response{}, errRead
|
|
||||||
}
|
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
||||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
||||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
|
||||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
|
||||||
}
|
|
||||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
|
||||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
|
||||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Refresh is a no-op for service account based credentials.
|
|
||||||
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
@@ -579,7 +407,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
scanner := bufio.NewScanner(httpResp.Body)
|
scanner := bufio.NewScanner(httpResp.Body)
|
||||||
scanner.Buffer(nil, 52_428_800) // 50MB
|
scanner.Buffer(nil, streamScannerBuffer)
|
||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
@@ -696,7 +524,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
scanner := bufio.NewScanner(httpResp.Body)
|
scanner := bufio.NewScanner(httpResp.Body)
|
||||||
scanner.Buffer(nil, 52_428_800) // 50MB
|
scanner.Buffer(nil, streamScannerBuffer)
|
||||||
var param any
|
var param any
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Bytes()
|
line := scanner.Bytes()
|
||||||
@@ -722,6 +550,184 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// countTokensWithServiceAccount counts tokens using service account credentials.
|
||||||
|
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
|
||||||
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||||
|
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
|
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
|
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||||
|
if budgetOverride != nil {
|
||||||
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||||
|
budgetOverride = &norm
|
||||||
|
}
|
||||||
|
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||||
|
}
|
||||||
|
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||||
|
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||||
|
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
||||||
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||||
|
|
||||||
|
baseURL := vertexBaseURL(location)
|
||||||
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
|
||||||
|
|
||||||
|
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||||
|
if errNewReq != nil {
|
||||||
|
return cliproxyexecutor.Response{}, errNewReq
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
||||||
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
} else if errTok != nil {
|
||||||
|
log.Errorf("vertex executor: access token error: %v", errTok)
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
||||||
|
}
|
||||||
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: translatedReq,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
|
return cliproxyexecutor.Response{}, errDo
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
|
}
|
||||||
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
return cliproxyexecutor.Response{}, errRead
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
|
}
|
||||||
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||||
|
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
|
||||||
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||||
|
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
|
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||||
|
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||||
|
if budgetOverride != nil {
|
||||||
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||||
|
budgetOverride = &norm
|
||||||
|
}
|
||||||
|
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||||
|
}
|
||||||
|
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||||
|
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||||
|
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
||||||
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||||
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||||
|
|
||||||
|
// For API key auth, use simpler URL format without project/location
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
|
}
|
||||||
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
|
||||||
|
|
||||||
|
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||||
|
if errNewReq != nil {
|
||||||
|
return cliproxyexecutor.Response{}, errNewReq
|
||||||
|
}
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
if apiKey != "" {
|
||||||
|
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||||
|
}
|
||||||
|
applyGeminiHeaders(httpReq, auth)
|
||||||
|
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: httpReq.Header.Clone(),
|
||||||
|
Body: translatedReq,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, errDo := httpClient.Do(httpReq)
|
||||||
|
if errDo != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||||
|
return cliproxyexecutor.Response{}, errDo
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
b, _ := io.ReadAll(httpResp.Body)
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||||
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||||
|
}
|
||||||
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
return cliproxyexecutor.Response{}, errRead
|
||||||
|
}
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
|
}
|
||||||
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||||
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||||
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||||
func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) {
|
func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) {
|
||||||
if a == nil || a.Metadata == nil {
|
if a == nil || a.Metadata == nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user