Compare commits
9 Commits
v6.6.102
...
v6.6.109-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7248f65c36 | ||
|
|
086eb3df7a | ||
|
|
5a7e5bd870 | ||
|
|
6f8a8f8136 | ||
|
|
b163f8ed9e | ||
|
|
a1da6ff5ac | ||
|
|
43652d044c | ||
|
|
b1b379ea18 | ||
|
|
21ac161b21 |
@@ -61,6 +61,7 @@ func main() {
|
|||||||
var iflowLogin bool
|
var iflowLogin bool
|
||||||
var iflowCookie bool
|
var iflowCookie bool
|
||||||
var noBrowser bool
|
var noBrowser bool
|
||||||
|
var oauthCallbackPort int
|
||||||
var antigravityLogin bool
|
var antigravityLogin bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var vertexImport string
|
var vertexImport string
|
||||||
@@ -75,6 +76,7 @@ func main() {
|
|||||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||||
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
|
||||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||||
|
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
@@ -425,7 +427,8 @@ func main() {
|
|||||||
|
|
||||||
// Create login options to be used in authentication flows.
|
// Create login options to be used in authentication flows.
|
||||||
options := &cmd.LoginOptions{
|
options := &cmd.LoginOptions{
|
||||||
NoBrowser: noBrowser,
|
NoBrowser: noBrowser,
|
||||||
|
CallbackPort: oauthCallbackPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register the shared token store once so all components use the same persistence backend.
|
// Register the shared token store once so all components use the same persistence backend.
|
||||||
|
|||||||
@@ -77,6 +77,9 @@ routing:
|
|||||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||||
ws-auth: false
|
ws-auth: false
|
||||||
|
|
||||||
|
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
|
||||||
|
nonstream-keepalive-interval: 0
|
||||||
|
|
||||||
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
||||||
# streaming:
|
# streaming:
|
||||||
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
||||||
|
|||||||
@@ -29,8 +29,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||||
geminiOauthClientSecret = "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"
|
geminiOauthClientSecret = "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"
|
||||||
|
geminiDefaultCallbackPort = 8085
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -49,8 +50,9 @@ type GeminiAuth struct {
|
|||||||
|
|
||||||
// WebLoginOptions customizes the interactive OAuth flow.
|
// WebLoginOptions customizes the interactive OAuth flow.
|
||||||
type WebLoginOptions struct {
|
type WebLoginOptions struct {
|
||||||
NoBrowser bool
|
NoBrowser bool
|
||||||
Prompt func(string) (string, error)
|
CallbackPort int
|
||||||
|
Prompt func(string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGeminiAuth creates a new instance of GeminiAuth.
|
// NewGeminiAuth creates a new instance of GeminiAuth.
|
||||||
@@ -72,6 +74,12 @@ func NewGeminiAuth() *GeminiAuth {
|
|||||||
// - *http.Client: An HTTP client configured with authentication
|
// - *http.Client: An HTTP client configured with authentication
|
||||||
// - error: An error if the client configuration fails, nil otherwise
|
// - error: An error if the client configuration fails, nil otherwise
|
||||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||||
|
callbackPort := geminiDefaultCallbackPort
|
||||||
|
if opts != nil && opts.CallbackPort > 0 {
|
||||||
|
callbackPort = opts.CallbackPort
|
||||||
|
}
|
||||||
|
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||||
|
|
||||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -106,7 +114,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
|||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: geminiOauthClientID,
|
ClientID: geminiOauthClientID,
|
||||||
ClientSecret: geminiOauthClientSecret,
|
ClientSecret: geminiOauthClientSecret,
|
||||||
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
RedirectURL: callbackURL, // This will be used by the local server.
|
||||||
Scopes: geminiOauthScopes,
|
Scopes: geminiOauthScopes,
|
||||||
Endpoint: google.Endpoint,
|
Endpoint: google.Endpoint,
|
||||||
}
|
}
|
||||||
@@ -218,14 +226,20 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
|||||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||||
// - error: An error if the token acquisition fails, nil otherwise
|
// - error: An error if the token acquisition fails, nil otherwise
|
||||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||||
|
callbackPort := geminiDefaultCallbackPort
|
||||||
|
if opts != nil && opts.CallbackPort > 0 {
|
||||||
|
callbackPort = opts.CallbackPort
|
||||||
|
}
|
||||||
|
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
|
||||||
|
|
||||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||||
codeChan := make(chan string, 1)
|
codeChan := make(chan string, 1)
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
// Create a new HTTP server with its own multiplexer.
|
// Create a new HTTP server with its own multiplexer.
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
server := &http.Server{Addr: ":8085", Handler: mux}
|
server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux}
|
||||||
config.RedirectURL = "http://localhost:8085/oauth2callback"
|
config.RedirectURL = callbackURL
|
||||||
|
|
||||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := r.URL.Query().Get("error"); err != "" {
|
if err := r.URL.Query().Get("error"); err != "" {
|
||||||
@@ -277,13 +291,13 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
|||||||
// Check if browser is available
|
// Check if browser is available
|
||||||
if !browser.IsAvailable() {
|
if !browser.IsAvailable() {
|
||||||
log.Warn("No browser available on this system")
|
log.Warn("No browser available on this system")
|
||||||
util.PrintSSHTunnelInstructions(8085)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||||
} else {
|
} else {
|
||||||
if err := browser.OpenURL(authURL); err != nil {
|
if err := browser.OpenURL(authURL); err != nil {
|
||||||
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
||||||
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
||||||
util.PrintSSHTunnelInstructions(8085)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||||
|
|
||||||
// Log platform info for debugging
|
// Log platform info for debugging
|
||||||
@@ -294,7 +308,7 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
util.PrintSSHTunnelInstructions(8085)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
|
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,9 +32,10 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
|||||||
manager := newAuthManager()
|
manager := newAuthManager()
|
||||||
|
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: promptFn,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||||
|
|||||||
@@ -22,9 +22,10 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
|
|||||||
|
|
||||||
manager := newAuthManager()
|
manager := newAuthManager()
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: promptFn,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)
|
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)
|
||||||
|
|||||||
@@ -24,9 +24,10 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: promptFn,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)
|
||||||
|
|||||||
@@ -67,10 +67,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginOpts := &sdkAuth.LoginOptions{
|
loginOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
ProjectID: trimmedProjectID,
|
ProjectID: trimmedProjectID,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: callbackPrompt,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: callbackPrompt,
|
||||||
}
|
}
|
||||||
|
|
||||||
authenticator := sdkAuth.NewGeminiAuthenticator()
|
authenticator := sdkAuth.NewGeminiAuthenticator()
|
||||||
@@ -88,8 +89,9 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
|||||||
|
|
||||||
geminiAuth := gemini.NewGeminiAuth()
|
geminiAuth := gemini.NewGeminiAuth()
|
||||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
|
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Prompt: callbackPrompt,
|
CallbackPort: options.CallbackPort,
|
||||||
|
Prompt: callbackPrompt,
|
||||||
})
|
})
|
||||||
if errClient != nil {
|
if errClient != nil {
|
||||||
log.Errorf("Gemini authentication failed: %v", errClient)
|
log.Errorf("Gemini authentication failed: %v", errClient)
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ type LoginOptions struct {
|
|||||||
// NoBrowser indicates whether to skip opening the browser automatically.
|
// NoBrowser indicates whether to skip opening the browser automatically.
|
||||||
NoBrowser bool
|
NoBrowser bool
|
||||||
|
|
||||||
|
// CallbackPort overrides the local OAuth callback port when set (>0).
|
||||||
|
CallbackPort int
|
||||||
|
|
||||||
// Prompt allows the caller to provide interactive input when needed.
|
// Prompt allows the caller to provide interactive input when needed.
|
||||||
Prompt func(prompt string) (string, error)
|
Prompt func(prompt string) (string, error)
|
||||||
}
|
}
|
||||||
@@ -43,9 +46,10 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
|||||||
manager := newAuthManager()
|
manager := newAuthManager()
|
||||||
|
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: promptFn,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||||
|
|||||||
@@ -36,9 +36,10 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
authOpts := &sdkAuth.LoginOptions{
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
NoBrowser: options.NoBrowser,
|
NoBrowser: options.NoBrowser,
|
||||||
Metadata: map[string]string{},
|
CallbackPort: options.CallbackPort,
|
||||||
Prompt: promptFn,
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
|
||||||
|
|||||||
@@ -242,6 +242,10 @@ type ClaudeKey struct {
|
|||||||
// APIKey is the authentication key for accessing Claude API services.
|
// APIKey is the authentication key for accessing Claude API services.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -280,6 +284,10 @@ type CodexKey struct {
|
|||||||
// APIKey is the authentication key for accessing Codex API services.
|
// APIKey is the authentication key for accessing Codex API services.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -318,6 +326,10 @@ type GeminiKey struct {
|
|||||||
// APIKey is the authentication key for accessing Gemini API services.
|
// APIKey is the authentication key for accessing Gemini API services.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -355,6 +367,10 @@ type OpenAICompatibility struct {
|
|||||||
// Name is the identifier for this OpenAI compatibility configuration.
|
// Name is the identifier for this OpenAI compatibility configuration.
|
||||||
Name string `yaml:"name" json:"name"`
|
Name string `yaml:"name" json:"name"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple providers or credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,10 @@ type SDKConfig struct {
|
|||||||
|
|
||||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||||
|
|
||||||
|
// NonStreamKeepAliveInterval controls how often blank lines are emitted for non-streaming responses.
|
||||||
|
// <= 0 disables keep-alives. Value is in seconds.
|
||||||
|
NonStreamKeepAliveInterval int `yaml:"nonstream-keepalive-interval,omitempty" json:"nonstream-keepalive-interval,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamingConfig holds server streaming behavior configuration.
|
// StreamingConfig holds server streaming behavior configuration.
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ type VertexCompatKey struct {
|
|||||||
// Maps to the x-goog-api-key header.
|
// Maps to the x-goog-api-key header.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
@@ -1104,12 +1105,49 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
|
|||||||
auth.Metadata["refresh_token"] = tokenResp.RefreshToken
|
auth.Metadata["refresh_token"] = tokenResp.RefreshToken
|
||||||
}
|
}
|
||||||
auth.Metadata["expires_in"] = tokenResp.ExpiresIn
|
auth.Metadata["expires_in"] = tokenResp.ExpiresIn
|
||||||
auth.Metadata["timestamp"] = time.Now().UnixMilli()
|
now := time.Now()
|
||||||
auth.Metadata["expired"] = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
|
auth.Metadata["timestamp"] = now.UnixMilli()
|
||||||
|
auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||||
auth.Metadata["type"] = antigravityAuthType
|
auth.Metadata["type"] = antigravityAuthType
|
||||||
|
if errProject := e.ensureAntigravityProjectID(ctx, auth, tokenResp.AccessToken); errProject != nil {
|
||||||
|
log.Warnf("antigravity executor: ensure project id failed: %v", errProject)
|
||||||
|
}
|
||||||
return auth, nil
|
return auth, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) error {
|
||||||
|
if auth == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if auth.Metadata["project_id"] != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
token := strings.TrimSpace(accessToken)
|
||||||
|
if token == "" {
|
||||||
|
token = metaStringValue(auth.Metadata, "access_token")
|
||||||
|
}
|
||||||
|
if token == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
|
||||||
|
if errFetch != nil {
|
||||||
|
return errFetch
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(projectID) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if auth.Metadata == nil {
|
||||||
|
auth.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
auth.Metadata["project_id"] = strings.TrimSpace(projectID)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) {
|
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) {
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||||
|
|||||||
@@ -251,6 +251,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
|||||||
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
|
||||||
itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
|
itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
|
||||||
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID)
|
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID)
|
||||||
|
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
|
||||||
out = append(out, emitEvent("response.output_item.done", itemDone))
|
out = append(out, emitEvent("response.output_item.done", itemDone))
|
||||||
st.InFuncBlock = false
|
st.InFuncBlock = false
|
||||||
} else if st.ReasoningActive {
|
} else if st.ReasoningActive {
|
||||||
|
|||||||
@@ -54,6 +54,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix {
|
if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix {
|
||||||
changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix))
|
changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix))
|
||||||
}
|
}
|
||||||
|
if oldCfg.NonStreamKeepAliveInterval != newCfg.NonStreamKeepAliveInterval {
|
||||||
|
changes = append(changes, fmt.Sprintf("nonstream-keepalive-interval: %d -> %d", oldCfg.NonStreamKeepAliveInterval, newCfg.NonStreamKeepAliveInterval))
|
||||||
|
}
|
||||||
|
|
||||||
// Quota-exceeded behavior
|
// Quota-exceeded behavior
|
||||||
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
||||||
|
|||||||
@@ -231,10 +231,11 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
|||||||
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
|
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
|
||||||
RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"},
|
RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"},
|
||||||
SDKConfig: sdkconfig.SDKConfig{
|
SDKConfig: sdkconfig.SDKConfig{
|
||||||
RequestLog: false,
|
RequestLog: false,
|
||||||
ProxyURL: "http://old-proxy",
|
ProxyURL: "http://old-proxy",
|
||||||
APIKeys: []string{"key-1"},
|
APIKeys: []string{"key-1"},
|
||||||
ForceModelPrefix: false,
|
ForceModelPrefix: false,
|
||||||
|
NonStreamKeepAliveInterval: 0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
newCfg := &config.Config{
|
newCfg := &config.Config{
|
||||||
@@ -267,10 +268,11 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
|||||||
SecretKey: "",
|
SecretKey: "",
|
||||||
},
|
},
|
||||||
SDKConfig: sdkconfig.SDKConfig{
|
SDKConfig: sdkconfig.SDKConfig{
|
||||||
RequestLog: true,
|
RequestLog: true,
|
||||||
ProxyURL: "http://new-proxy",
|
ProxyURL: "http://new-proxy",
|
||||||
APIKeys: []string{" key-1 ", "key-2"},
|
APIKeys: []string{" key-1 ", "key-2"},
|
||||||
ForceModelPrefix: true,
|
ForceModelPrefix: true,
|
||||||
|
NonStreamKeepAliveInterval: 5,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,6 +287,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
|||||||
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
|
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||||
expectContains(t, details, "ws-auth: false -> true")
|
expectContains(t, details, "ws-auth: false -> true")
|
||||||
expectContains(t, details, "force-model-prefix: false -> true")
|
expectContains(t, details, "force-model-prefix: false -> true")
|
||||||
|
expectContains(t, details, "nonstream-keepalive-interval: 0 -> 5")
|
||||||
expectContains(t, details, "quota-exceeded.switch-project: false -> true")
|
expectContains(t, details, "quota-exceeded.switch-project: false -> true")
|
||||||
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
|
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
|
||||||
expectContains(t, details, "api-keys count: 1 -> 2")
|
expectContains(t, details, "api-keys count: 1 -> 2")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package synthesizer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||||
@@ -59,6 +60,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
|
|||||||
"source": fmt.Sprintf("config:gemini[%s]", token),
|
"source": fmt.Sprintf("config:gemini[%s]", token),
|
||||||
"api_key": key,
|
"api_key": key,
|
||||||
}
|
}
|
||||||
|
if entry.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(entry.Priority)
|
||||||
|
}
|
||||||
if base != "" {
|
if base != "" {
|
||||||
attrs["base_url"] = base
|
attrs["base_url"] = base
|
||||||
}
|
}
|
||||||
@@ -103,6 +107,9 @@ func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*corea
|
|||||||
"source": fmt.Sprintf("config:claude[%s]", token),
|
"source": fmt.Sprintf("config:claude[%s]", token),
|
||||||
"api_key": key,
|
"api_key": key,
|
||||||
}
|
}
|
||||||
|
if ck.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(ck.Priority)
|
||||||
|
}
|
||||||
if base != "" {
|
if base != "" {
|
||||||
attrs["base_url"] = base
|
attrs["base_url"] = base
|
||||||
}
|
}
|
||||||
@@ -147,6 +154,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
|
|||||||
"source": fmt.Sprintf("config:codex[%s]", token),
|
"source": fmt.Sprintf("config:codex[%s]", token),
|
||||||
"api_key": key,
|
"api_key": key,
|
||||||
}
|
}
|
||||||
|
if ck.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(ck.Priority)
|
||||||
|
}
|
||||||
if ck.BaseURL != "" {
|
if ck.BaseURL != "" {
|
||||||
attrs["base_url"] = ck.BaseURL
|
attrs["base_url"] = ck.BaseURL
|
||||||
}
|
}
|
||||||
@@ -202,6 +212,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
|
|||||||
"compat_name": compat.Name,
|
"compat_name": compat.Name,
|
||||||
"provider_key": providerName,
|
"provider_key": providerName,
|
||||||
}
|
}
|
||||||
|
if compat.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||||
|
}
|
||||||
if key != "" {
|
if key != "" {
|
||||||
attrs["api_key"] = key
|
attrs["api_key"] = key
|
||||||
}
|
}
|
||||||
@@ -233,6 +246,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
|
|||||||
"compat_name": compat.Name,
|
"compat_name": compat.Name,
|
||||||
"provider_key": providerName,
|
"provider_key": providerName,
|
||||||
}
|
}
|
||||||
|
if compat.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||||
|
}
|
||||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||||
attrs["models_hash"] = hash
|
attrs["models_hash"] = hash
|
||||||
}
|
}
|
||||||
@@ -275,6 +291,9 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor
|
|||||||
"base_url": base,
|
"base_url": base,
|
||||||
"provider_key": providerName,
|
"provider_key": providerName,
|
||||||
}
|
}
|
||||||
|
if compat.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||||
|
}
|
||||||
if key != "" {
|
if key != "" {
|
||||||
attrs["api_key"] = key
|
attrs["api_key"] = key
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,10 +146,12 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
|||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
alt := h.GetAlt(c)
|
alt := h.GetAlt(c)
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
|
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
@@ -159,13 +161,18 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
|||||||
// Decompress gzipped responses - Claude API sometimes returns gzip without Content-Encoding header
|
// Decompress gzipped responses - Claude API sometimes returns gzip without Content-Encoding header
|
||||||
// This fixes title generation and other non-streaming responses that arrive compressed
|
// This fixes title generation and other non-streaming responses that arrive compressed
|
||||||
if len(resp) >= 2 && resp[0] == 0x1f && resp[1] == 0x8b {
|
if len(resp) >= 2 && resp[0] == 0x1f && resp[1] == 0x8b {
|
||||||
gzReader, err := gzip.NewReader(bytes.NewReader(resp))
|
gzReader, errGzip := gzip.NewReader(bytes.NewReader(resp))
|
||||||
if err != nil {
|
if errGzip != nil {
|
||||||
log.Warnf("failed to decompress gzipped Claude response: %v", err)
|
log.Warnf("failed to decompress gzipped Claude response: %v", errGzip)
|
||||||
} else {
|
} else {
|
||||||
defer gzReader.Close()
|
defer func() {
|
||||||
if decompressed, err := io.ReadAll(gzReader); err != nil {
|
if errClose := gzReader.Close(); errClose != nil {
|
||||||
log.Warnf("failed to read decompressed Claude response: %v", err)
|
log.Warnf("failed to close Claude gzip reader: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
decompressed, errRead := io.ReadAll(gzReader)
|
||||||
|
if errRead != nil {
|
||||||
|
log.Warnf("failed to read decompressed Claude response: %v", errRead)
|
||||||
} else {
|
} else {
|
||||||
resp = decompressed
|
resp = decompressed
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -336,7 +336,9 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
|||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
alt := h.GetAlt(c)
|
alt := h.GetAlt(c)
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||||
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -113,6 +114,19 @@ func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
|||||||
return time.Duration(seconds) * time.Second
|
return time.Duration(seconds) * time.Second
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NonStreamingKeepAliveInterval returns the keep-alive interval for non-streaming responses.
|
||||||
|
// Returning 0 disables keep-alives (default when unset).
|
||||||
|
func NonStreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
||||||
|
seconds := 0
|
||||||
|
if cfg != nil {
|
||||||
|
seconds = cfg.NonStreamKeepAliveInterval
|
||||||
|
}
|
||||||
|
if seconds <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return time.Duration(seconds) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
|
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
|
||||||
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
||||||
retries := defaultStreamingBootstrapRetries
|
retries := defaultStreamingBootstrapRetries
|
||||||
@@ -293,6 +307,53 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StartNonStreamingKeepAlive emits blank lines every 5 seconds while waiting for a non-streaming response.
|
||||||
|
// It returns a stop function that must be called before writing the final response.
|
||||||
|
func (h *BaseAPIHandler) StartNonStreamingKeepAlive(c *gin.Context, ctx context.Context) func() {
|
||||||
|
if h == nil || c == nil {
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
interval := NonStreamingKeepAliveInterval(h.Cfg)
|
||||||
|
if interval <= 0 {
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
stopChan := make(chan struct{})
|
||||||
|
var stopOnce sync.Once
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stopChan:
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
stopOnce.Do(func() {
|
||||||
|
close(stopChan)
|
||||||
|
})
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// appendAPIResponse preserves any previously captured API response and appends new data.
|
// appendAPIResponse preserves any previously captured API response and appends new data.
|
||||||
func appendAPIResponse(c *gin.Context, data []byte) {
|
func appendAPIResponse(c *gin.Context, data []byte) {
|
||||||
if c == nil || len(data) == 0 {
|
if c == nil || len(data) == 0 {
|
||||||
|
|||||||
@@ -56,6 +56,14 @@ func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, co
|
|||||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *failOnceStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, &coreauth.Error{
|
||||||
|
Code: "not_implemented",
|
||||||
|
Message: "HttpRequest not implemented",
|
||||||
|
HTTPStatus: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (e *failOnceStreamExecutor) Calls() int {
|
func (e *failOnceStreamExecutor) Calls() int {
|
||||||
e.mu.Lock()
|
e.mu.Lock()
|
||||||
defer e.mu.Unlock()
|
defer e.mu.Unlock()
|
||||||
|
|||||||
@@ -524,7 +524,9 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||||
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
cliCancel(errMsg.Error)
|
cliCancel(errMsg.Error)
|
||||||
|
|||||||
@@ -103,20 +103,17 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
|||||||
|
|
||||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
defer func() {
|
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
|
||||||
cliCancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||||
|
stopKeepAlive()
|
||||||
if errMsg != nil {
|
if errMsg != nil {
|
||||||
h.WriteErrorResponse(c, errMsg)
|
h.WriteErrorResponse(c, errMsg)
|
||||||
|
cliCancel(errMsg.Error)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, _ = c.Writer.Write(resp)
|
_, _ = c.Writer.Write(resp)
|
||||||
return
|
cliCancel()
|
||||||
|
|
||||||
// no legacy fallback
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleStreamingResponse handles streaming responses for Gemini models.
|
// handleStreamingResponse handles streaming responses for Gemini models.
|
||||||
|
|||||||
@@ -60,6 +60,11 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
|||||||
opts = &LoginOptions{}
|
opts = &LoginOptions{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
callbackPort := antigravityCallbackPort
|
||||||
|
if opts.CallbackPort > 0 {
|
||||||
|
callbackPort = opts.CallbackPort
|
||||||
|
}
|
||||||
|
|
||||||
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
|
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
|
||||||
|
|
||||||
state, err := misc.GenerateRandomState()
|
state, err := misc.GenerateRandomState()
|
||||||
@@ -67,7 +72,7 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
|||||||
return nil, fmt.Errorf("antigravity: failed to generate state: %w", err)
|
return nil, fmt.Errorf("antigravity: failed to generate state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
srv, port, cbChan, errServer := startAntigravityCallbackServer()
|
srv, port, cbChan, errServer := startAntigravityCallbackServer(callbackPort)
|
||||||
if errServer != nil {
|
if errServer != nil {
|
||||||
return nil, fmt.Errorf("antigravity: failed to start callback server: %w", errServer)
|
return nil, fmt.Errorf("antigravity: failed to start callback server: %w", errServer)
|
||||||
}
|
}
|
||||||
@@ -224,13 +229,16 @@ type callbackResult struct {
|
|||||||
State string
|
State string
|
||||||
}
|
}
|
||||||
|
|
||||||
func startAntigravityCallbackServer() (*http.Server, int, <-chan callbackResult, error) {
|
func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
|
||||||
addr := fmt.Sprintf(":%d", antigravityCallbackPort)
|
if port <= 0 {
|
||||||
|
port = antigravityCallbackPort
|
||||||
|
}
|
||||||
|
addr := fmt.Sprintf(":%d", port)
|
||||||
listener, err := net.Listen("tcp", addr)
|
listener, err := net.Listen("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, nil, err
|
return nil, 0, nil, err
|
||||||
}
|
}
|
||||||
port := listener.Addr().(*net.TCPAddr).Port
|
port = listener.Addr().(*net.TCPAddr).Port
|
||||||
resultCh := make(chan callbackResult, 1)
|
resultCh := make(chan callbackResult, 1)
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
@@ -374,7 +382,7 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie
|
|||||||
// Call loadCodeAssist to get the project
|
// Call loadCodeAssist to get the project
|
||||||
loadReqBody := map[string]any{
|
loadReqBody := map[string]any{
|
||||||
"metadata": map[string]string{
|
"metadata": map[string]string{
|
||||||
"ideType": "IDE_UNSPECIFIED",
|
"ideType": "ANTIGRAVITY",
|
||||||
"platform": "PLATFORM_UNSPECIFIED",
|
"platform": "PLATFORM_UNSPECIFIED",
|
||||||
"pluginType": "GEMINI",
|
"pluginType": "GEMINI",
|
||||||
},
|
},
|
||||||
@@ -434,8 +442,134 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie
|
|||||||
}
|
}
|
||||||
|
|
||||||
if projectID == "" {
|
if projectID == "" {
|
||||||
return "", fmt.Errorf("no cloudaicompanionProject in response")
|
tierID := "legacy-tier"
|
||||||
|
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
|
||||||
|
for _, rawTier := range tiers {
|
||||||
|
tier, okTier := rawTier.(map[string]any)
|
||||||
|
if !okTier {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
|
||||||
|
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
|
||||||
|
tierID = strings.TrimSpace(id)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return projectID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return projectID, nil
|
return projectID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion.
|
||||||
|
// It returns an empty string when the operation times out or completes without a project ID.
|
||||||
|
func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) {
|
||||||
|
if httpClient == nil {
|
||||||
|
httpClient = http.DefaultClient
|
||||||
|
}
|
||||||
|
fmt.Println("Antigravity: onboarding user...", tierID)
|
||||||
|
requestBody := map[string]any{
|
||||||
|
"tierId": tierID,
|
||||||
|
"metadata": map[string]string{
|
||||||
|
"ideType": "ANTIGRAVITY",
|
||||||
|
"platform": "PLATFORM_UNSPECIFIED",
|
||||||
|
"pluginType": "GEMINI",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rawBody, errMarshal := json.Marshal(requestBody)
|
||||||
|
if errMarshal != nil {
|
||||||
|
return "", fmt.Errorf("marshal request body: %w", errMarshal)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxAttempts := 5
|
||||||
|
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||||
|
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
|
||||||
|
|
||||||
|
reqCtx := ctx
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
if reqCtx == nil {
|
||||||
|
reqCtx = context.Background()
|
||||||
|
}
|
||||||
|
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
|
||||||
|
|
||||||
|
endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion)
|
||||||
|
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
|
||||||
|
if errRequest != nil {
|
||||||
|
cancel()
|
||||||
|
return "", fmt.Errorf("create request: %w", errRequest)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("User-Agent", antigravityAPIUserAgent)
|
||||||
|
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
|
||||||
|
req.Header.Set("Client-Metadata", antigravityClientMetadata)
|
||||||
|
|
||||||
|
resp, errDo := httpClient.Do(req)
|
||||||
|
if errDo != nil {
|
||||||
|
cancel()
|
||||||
|
return "", fmt.Errorf("execute request: %w", errDo)
|
||||||
|
}
|
||||||
|
|
||||||
|
bodyBytes, errRead := io.ReadAll(resp.Body)
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("close body error: %v", errClose)
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
if errRead != nil {
|
||||||
|
return "", fmt.Errorf("read response: %w", errRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
var data map[string]any
|
||||||
|
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
|
||||||
|
return "", fmt.Errorf("decode response: %w", errDecode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if done, okDone := data["done"].(bool); okDone && done {
|
||||||
|
projectID := ""
|
||||||
|
if responseData, okResp := data["response"].(map[string]any); okResp {
|
||||||
|
switch projectValue := responseData["cloudaicompanionProject"].(type) {
|
||||||
|
case map[string]any:
|
||||||
|
if id, okID := projectValue["id"].(string); okID {
|
||||||
|
projectID = strings.TrimSpace(id)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
projectID = strings.TrimSpace(projectValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if projectID != "" {
|
||||||
|
log.Infof("Successfully fetched project_id: %s", projectID)
|
||||||
|
return projectID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("no project_id in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
responsePreview := strings.TrimSpace(string(bodyBytes))
|
||||||
|
if len(responsePreview) > 500 {
|
||||||
|
responsePreview = responsePreview[:500]
|
||||||
|
}
|
||||||
|
|
||||||
|
responseErr := responsePreview
|
||||||
|
if len(responseErr) > 200 {
|
||||||
|
responseErr = responseErr[:200]
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -47,6 +47,11 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
|||||||
opts = &LoginOptions{}
|
opts = &LoginOptions{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
callbackPort := a.CallbackPort
|
||||||
|
if opts.CallbackPort > 0 {
|
||||||
|
callbackPort = opts.CallbackPort
|
||||||
|
}
|
||||||
|
|
||||||
pkceCodes, err := claude.GeneratePKCECodes()
|
pkceCodes, err := claude.GeneratePKCECodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("claude pkce generation failed: %w", err)
|
return nil, fmt.Errorf("claude pkce generation failed: %w", err)
|
||||||
@@ -57,7 +62,7 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
|||||||
return nil, fmt.Errorf("claude state generation failed: %w", err)
|
return nil, fmt.Errorf("claude state generation failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthServer := claude.NewOAuthServer(a.CallbackPort)
|
oauthServer := claude.NewOAuthServer(callbackPort)
|
||||||
if err = oauthServer.Start(); err != nil {
|
if err = oauthServer.Start(); err != nil {
|
||||||
if strings.Contains(err.Error(), "already in use") {
|
if strings.Contains(err.Error(), "already in use") {
|
||||||
return nil, claude.NewAuthenticationError(claude.ErrPortInUse, err)
|
return nil, claude.NewAuthenticationError(claude.ErrPortInUse, err)
|
||||||
@@ -84,15 +89,15 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
|||||||
fmt.Println("Opening browser for Claude authentication")
|
fmt.Println("Opening browser for Claude authentication")
|
||||||
if !browser.IsAvailable() {
|
if !browser.IsAvailable() {
|
||||||
log.Warn("No browser available; please open the URL manually")
|
log.Warn("No browser available; please open the URL manually")
|
||||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
} else if err = browser.OpenURL(authURL); err != nil {
|
} else if err = browser.OpenURL(authURL); err != nil {
|
||||||
log.Warnf("Failed to open browser automatically: %v", err)
|
log.Warnf("Failed to open browser automatically: %v", err)
|
||||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,11 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
opts = &LoginOptions{}
|
opts = &LoginOptions{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
callbackPort := a.CallbackPort
|
||||||
|
if opts.CallbackPort > 0 {
|
||||||
|
callbackPort = opts.CallbackPort
|
||||||
|
}
|
||||||
|
|
||||||
pkceCodes, err := codex.GeneratePKCECodes()
|
pkceCodes, err := codex.GeneratePKCECodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("codex pkce generation failed: %w", err)
|
return nil, fmt.Errorf("codex pkce generation failed: %w", err)
|
||||||
@@ -57,7 +62,7 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
return nil, fmt.Errorf("codex state generation failed: %w", err)
|
return nil, fmt.Errorf("codex state generation failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthServer := codex.NewOAuthServer(a.CallbackPort)
|
oauthServer := codex.NewOAuthServer(callbackPort)
|
||||||
if err = oauthServer.Start(); err != nil {
|
if err = oauthServer.Start(); err != nil {
|
||||||
if strings.Contains(err.Error(), "already in use") {
|
if strings.Contains(err.Error(), "already in use") {
|
||||||
return nil, codex.NewAuthenticationError(codex.ErrPortInUse, err)
|
return nil, codex.NewAuthenticationError(codex.ErrPortInUse, err)
|
||||||
@@ -83,15 +88,15 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
fmt.Println("Opening browser for Codex authentication")
|
fmt.Println("Opening browser for Codex authentication")
|
||||||
if !browser.IsAvailable() {
|
if !browser.IsAvailable() {
|
||||||
log.Warn("No browser available; please open the URL manually")
|
log.Warn("No browser available; please open the URL manually")
|
||||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
} else if err = browser.OpenURL(authURL); err != nil {
|
} else if err = browser.OpenURL(authURL); err != nil {
|
||||||
log.Warnf("Failed to open browser automatically: %v", err)
|
log.Warnf("Failed to open browser automatically: %v", err)
|
||||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
util.PrintSSHTunnelInstructions(a.CallbackPort)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,10 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -77,15 +79,23 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
|||||||
if metadataEqualIgnoringTimestamps(existing, raw) {
|
if metadataEqualIgnoringTimestamps(existing, raw) {
|
||||||
return path, nil
|
return path, nil
|
||||||
}
|
}
|
||||||
} else if errRead != nil && !os.IsNotExist(errRead) {
|
file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600)
|
||||||
|
if errOpen != nil {
|
||||||
|
return "", fmt.Errorf("auth filestore: open existing failed: %w", errOpen)
|
||||||
|
}
|
||||||
|
if _, errWrite := file.Write(raw); errWrite != nil {
|
||||||
|
_ = file.Close()
|
||||||
|
return "", fmt.Errorf("auth filestore: write existing failed: %w", errWrite)
|
||||||
|
}
|
||||||
|
if errClose := file.Close(); errClose != nil {
|
||||||
|
return "", fmt.Errorf("auth filestore: close existing failed: %w", errClose)
|
||||||
|
}
|
||||||
|
return path, nil
|
||||||
|
} else if !os.IsNotExist(errRead) {
|
||||||
return "", fmt.Errorf("auth filestore: read existing failed: %w", errRead)
|
return "", fmt.Errorf("auth filestore: read existing failed: %w", errRead)
|
||||||
}
|
}
|
||||||
tmp := path + ".tmp"
|
if errWrite := os.WriteFile(path, raw, 0o600); errWrite != nil {
|
||||||
if errWrite := os.WriteFile(tmp, raw, 0o600); errWrite != nil {
|
return "", fmt.Errorf("auth filestore: write file failed: %w", errWrite)
|
||||||
return "", fmt.Errorf("auth filestore: write temp failed: %w", errWrite)
|
|
||||||
}
|
|
||||||
if errRename := os.Rename(tmp, path); errRename != nil {
|
|
||||||
return "", fmt.Errorf("auth filestore: rename failed: %w", errRename)
|
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("auth filestore: nothing to persist for %s", auth.ID)
|
return "", fmt.Errorf("auth filestore: nothing to persist for %s", auth.ID)
|
||||||
@@ -178,6 +188,30 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
|
|||||||
if provider == "" {
|
if provider == "" {
|
||||||
provider = "unknown"
|
provider = "unknown"
|
||||||
}
|
}
|
||||||
|
if provider == "antigravity" {
|
||||||
|
projectID := ""
|
||||||
|
if pid, ok := metadata["project_id"].(string); ok {
|
||||||
|
projectID = strings.TrimSpace(pid)
|
||||||
|
}
|
||||||
|
if projectID == "" {
|
||||||
|
accessToken := ""
|
||||||
|
if token, ok := metadata["access_token"].(string); ok {
|
||||||
|
accessToken = strings.TrimSpace(token)
|
||||||
|
}
|
||||||
|
if accessToken != "" {
|
||||||
|
fetchedProjectID, errFetch := FetchAntigravityProjectID(context.Background(), accessToken, http.DefaultClient)
|
||||||
|
if errFetch == nil && strings.TrimSpace(fetchedProjectID) != "" {
|
||||||
|
metadata["project_id"] = strings.TrimSpace(fetchedProjectID)
|
||||||
|
if raw, errMarshal := json.Marshal(metadata); errMarshal == nil {
|
||||||
|
if file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600); errOpen == nil {
|
||||||
|
_, _ = file.Write(raw)
|
||||||
|
_ = file.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
info, err := os.Stat(path)
|
info, err := os.Stat(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("stat file: %w", err)
|
return nil, fmt.Errorf("stat file: %w", err)
|
||||||
@@ -266,92 +300,28 @@ func (s *FileTokenStore) baseDirSnapshot() string {
|
|||||||
return s.baseDir
|
return s.baseDir
|
||||||
}
|
}
|
||||||
|
|
||||||
// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata.
|
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs, ignoring volatile fields that
|
||||||
// This function is kept for backward compatibility but can cause refresh loops.
|
// change on every refresh but don't affect authentication logic.
|
||||||
func jsonEqual(a, b []byte) bool {
|
|
||||||
var objA any
|
|
||||||
var objB any
|
|
||||||
if err := json.Unmarshal(a, &objA); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(b, &objB); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return deepEqualJSON(objA, objB)
|
|
||||||
}
|
|
||||||
|
|
||||||
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs,
|
|
||||||
// ignoring fields that change on every refresh but don't affect functionality.
|
|
||||||
// This prevents unnecessary file writes that would trigger watcher events and
|
|
||||||
// create refresh loops.
|
|
||||||
func metadataEqualIgnoringTimestamps(a, b []byte) bool {
|
func metadataEqualIgnoringTimestamps(a, b []byte) bool {
|
||||||
var objA, objB map[string]any
|
var objA map[string]any
|
||||||
if err := json.Unmarshal(a, &objA); err != nil {
|
var objB map[string]any
|
||||||
|
if errUnmarshalA := json.Unmarshal(a, &objA); errUnmarshalA != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(b, &objB); err != nil {
|
if errUnmarshalB := json.Unmarshal(b, &objB); errUnmarshalB != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
stripVolatileMetadataFields(objA)
|
||||||
// Fields to ignore: these change on every refresh but don't affect authentication logic.
|
stripVolatileMetadataFields(objB)
|
||||||
// - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh
|
return reflect.DeepEqual(objA, objB)
|
||||||
// - access_token: Google OAuth returns a new access_token on each refresh, this is expected
|
|
||||||
// and shouldn't trigger file writes (the new token will be fetched again when needed)
|
|
||||||
ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"}
|
|
||||||
for _, field := range ignoredFields {
|
|
||||||
delete(objA, field)
|
|
||||||
delete(objB, field)
|
|
||||||
}
|
|
||||||
|
|
||||||
return deepEqualJSON(objA, objB)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func deepEqualJSON(a, b any) bool {
|
func stripVolatileMetadataFields(metadata map[string]any) {
|
||||||
switch valA := a.(type) {
|
if metadata == nil {
|
||||||
case map[string]any:
|
return
|
||||||
valB, ok := b.(map[string]any)
|
}
|
||||||
if !ok || len(valA) != len(valB) {
|
// These fields change on refresh and would otherwise trigger watcher reload loops.
|
||||||
return false
|
for _, field := range []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"} {
|
||||||
}
|
delete(metadata, field)
|
||||||
for key, subA := range valA {
|
|
||||||
subB, ok1 := valB[key]
|
|
||||||
if !ok1 || !deepEqualJSON(subA, subB) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
case []any:
|
|
||||||
sliceB, ok := b.([]any)
|
|
||||||
if !ok || len(valA) != len(sliceB) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := range valA {
|
|
||||||
if !deepEqualJSON(valA[i], sliceB[i]) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
case float64:
|
|
||||||
valB, ok := b.(float64)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return valA == valB
|
|
||||||
case string:
|
|
||||||
valB, ok := b.(string)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return valA == valB
|
|
||||||
case bool:
|
|
||||||
valB, ok := b.(bool)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return valA == valB
|
|
||||||
case nil:
|
|
||||||
return b == nil
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,8 +45,9 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
|||||||
|
|
||||||
geminiAuth := gemini.NewGeminiAuth()
|
geminiAuth := gemini.NewGeminiAuth()
|
||||||
_, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{
|
_, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{
|
||||||
NoBrowser: opts.NoBrowser,
|
NoBrowser: opts.NoBrowser,
|
||||||
Prompt: opts.Prompt,
|
CallbackPort: opts.CallbackPort,
|
||||||
|
Prompt: opts.Prompt,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("gemini authentication failed: %w", err)
|
return nil, fmt.Errorf("gemini authentication failed: %w", err)
|
||||||
|
|||||||
@@ -42,9 +42,14 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
opts = &LoginOptions{}
|
opts = &LoginOptions{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
callbackPort := iflow.CallbackPort
|
||||||
|
if opts.CallbackPort > 0 {
|
||||||
|
callbackPort = opts.CallbackPort
|
||||||
|
}
|
||||||
|
|
||||||
authSvc := iflow.NewIFlowAuth(cfg)
|
authSvc := iflow.NewIFlowAuth(cfg)
|
||||||
|
|
||||||
oauthServer := iflow.NewOAuthServer(iflow.CallbackPort)
|
oauthServer := iflow.NewOAuthServer(callbackPort)
|
||||||
if err := oauthServer.Start(); err != nil {
|
if err := oauthServer.Start(); err != nil {
|
||||||
if strings.Contains(err.Error(), "already in use") {
|
if strings.Contains(err.Error(), "already in use") {
|
||||||
return nil, fmt.Errorf("iflow authentication server port in use: %w", err)
|
return nil, fmt.Errorf("iflow authentication server port in use: %w", err)
|
||||||
@@ -64,21 +69,21 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
|||||||
return nil, fmt.Errorf("iflow auth: failed to generate state: %w", err)
|
return nil, fmt.Errorf("iflow auth: failed to generate state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflow.CallbackPort)
|
authURL, redirectURI := authSvc.AuthorizationURL(state, callbackPort)
|
||||||
|
|
||||||
if !opts.NoBrowser {
|
if !opts.NoBrowser {
|
||||||
fmt.Println("Opening browser for iFlow authentication")
|
fmt.Println("Opening browser for iFlow authentication")
|
||||||
if !browser.IsAvailable() {
|
if !browser.IsAvailable() {
|
||||||
log.Warn("No browser available; please open the URL manually")
|
log.Warn("No browser available; please open the URL manually")
|
||||||
util.PrintSSHTunnelInstructions(iflow.CallbackPort)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
} else if err = browser.OpenURL(authURL); err != nil {
|
} else if err = browser.OpenURL(authURL); err != nil {
|
||||||
log.Warnf("Failed to open browser automatically: %v", err)
|
log.Warnf("Failed to open browser automatically: %v", err)
|
||||||
util.PrintSSHTunnelInstructions(iflow.CallbackPort)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
util.PrintSSHTunnelInstructions(iflow.CallbackPort)
|
util.PrintSSHTunnelInstructions(callbackPort)
|
||||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,10 +14,11 @@ var ErrRefreshNotSupported = errors.New("cliproxy auth: refresh not supported")
|
|||||||
// LoginOptions captures generic knobs shared across authenticators.
|
// LoginOptions captures generic knobs shared across authenticators.
|
||||||
// Provider-specific logic can inspect Metadata for extra parameters.
|
// Provider-specific logic can inspect Metadata for extra parameters.
|
||||||
type LoginOptions struct {
|
type LoginOptions struct {
|
||||||
NoBrowser bool
|
NoBrowser bool
|
||||||
ProjectID string
|
ProjectID string
|
||||||
Metadata map[string]string
|
CallbackPort int
|
||||||
Prompt func(prompt string) (string, error)
|
Metadata map[string]string
|
||||||
|
Prompt func(prompt string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticator manages login and optional refresh flows for a provider.
|
// Authenticator manages login and optional refresh flows for a provider.
|
||||||
|
|||||||
@@ -271,7 +271,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
|||||||
if len(normalized) == 0 {
|
if len(normalized) == 0 {
|
||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
rotated := m.rotateProviders(req.Model, normalized)
|
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
retryTimes, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
attempts := retryTimes + 1
|
||||||
@@ -281,14 +280,12 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
|||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
|
||||||
return m.executeWithProvider(execCtx, provider, req, opts)
|
|
||||||
})
|
|
||||||
if errExec == nil {
|
if errExec == nil {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
lastErr = errExec
|
lastErr = errExec
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -309,7 +306,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
|||||||
if len(normalized) == 0 {
|
if len(normalized) == 0 {
|
||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
rotated := m.rotateProviders(req.Model, normalized)
|
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
retryTimes, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
attempts := retryTimes + 1
|
||||||
@@ -319,14 +315,12 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
|||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
|
||||||
return m.executeCountWithProvider(execCtx, provider, req, opts)
|
|
||||||
})
|
|
||||||
if errExec == nil {
|
if errExec == nil {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
lastErr = errExec
|
lastErr = errExec
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -347,7 +341,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
if len(normalized) == 0 {
|
if len(normalized) == 0 {
|
||||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
rotated := m.rotateProviders(req.Model, normalized)
|
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
retryTimes, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
attempts := retryTimes + 1
|
||||||
@@ -357,14 +350,12 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) {
|
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||||
return m.executeStreamWithProvider(execCtx, provider, req, opts)
|
|
||||||
})
|
|
||||||
if errStream == nil {
|
if errStream == nil {
|
||||||
return chunks, nil
|
return chunks, nil
|
||||||
}
|
}
|
||||||
lastErr = errStream
|
lastErr = errStream
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -378,6 +369,167 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
if len(providers) == 0 {
|
||||||
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
|
}
|
||||||
|
routeModel := req.Model
|
||||||
|
tried := make(map[string]struct{})
|
||||||
|
var lastErr error
|
||||||
|
for {
|
||||||
|
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||||
|
if errPick != nil {
|
||||||
|
if lastErr != nil {
|
||||||
|
return cliproxyexecutor.Response{}, lastErr
|
||||||
|
}
|
||||||
|
return cliproxyexecutor.Response{}, errPick
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := logEntryWithRequestID(ctx)
|
||||||
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||||
|
|
||||||
|
tried[auth.ID] = struct{}{}
|
||||||
|
execCtx := ctx
|
||||||
|
if rt := m.roundTripperFor(auth); rt != nil {
|
||||||
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||||
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||||
|
}
|
||||||
|
execReq := req
|
||||||
|
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||||
|
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||||
|
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||||
|
if errExec != nil {
|
||||||
|
result.Error = &Error{Message: errExec.Error()}
|
||||||
|
var se cliproxyexecutor.StatusError
|
||||||
|
if errors.As(errExec, &se) && se != nil {
|
||||||
|
result.Error.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
if ra := retryAfterFromError(errExec); ra != nil {
|
||||||
|
result.RetryAfter = ra
|
||||||
|
}
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
lastErr = errExec
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
if len(providers) == 0 {
|
||||||
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
|
}
|
||||||
|
routeModel := req.Model
|
||||||
|
tried := make(map[string]struct{})
|
||||||
|
var lastErr error
|
||||||
|
for {
|
||||||
|
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||||
|
if errPick != nil {
|
||||||
|
if lastErr != nil {
|
||||||
|
return cliproxyexecutor.Response{}, lastErr
|
||||||
|
}
|
||||||
|
return cliproxyexecutor.Response{}, errPick
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := logEntryWithRequestID(ctx)
|
||||||
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||||
|
|
||||||
|
tried[auth.ID] = struct{}{}
|
||||||
|
execCtx := ctx
|
||||||
|
if rt := m.roundTripperFor(auth); rt != nil {
|
||||||
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||||
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||||
|
}
|
||||||
|
execReq := req
|
||||||
|
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||||
|
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||||
|
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||||
|
if errExec != nil {
|
||||||
|
result.Error = &Error{Message: errExec.Error()}
|
||||||
|
var se cliproxyexecutor.StatusError
|
||||||
|
if errors.As(errExec, &se) && se != nil {
|
||||||
|
result.Error.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
if ra := retryAfterFromError(errExec); ra != nil {
|
||||||
|
result.RetryAfter = ra
|
||||||
|
}
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
lastErr = errExec
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||||
|
if len(providers) == 0 {
|
||||||
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
|
}
|
||||||
|
routeModel := req.Model
|
||||||
|
tried := make(map[string]struct{})
|
||||||
|
var lastErr error
|
||||||
|
for {
|
||||||
|
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||||
|
if errPick != nil {
|
||||||
|
if lastErr != nil {
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
return nil, errPick
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := logEntryWithRequestID(ctx)
|
||||||
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||||
|
|
||||||
|
tried[auth.ID] = struct{}{}
|
||||||
|
execCtx := ctx
|
||||||
|
if rt := m.roundTripperFor(auth); rt != nil {
|
||||||
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||||
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||||
|
}
|
||||||
|
execReq := req
|
||||||
|
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||||
|
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||||
|
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||||
|
if errStream != nil {
|
||||||
|
rerr := &Error{Message: errStream.Error()}
|
||||||
|
var se cliproxyexecutor.StatusError
|
||||||
|
if errors.As(errStream, &se) && se != nil {
|
||||||
|
rerr.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||||
|
result.RetryAfter = retryAfterFromError(errStream)
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
lastErr = errStream
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||||
|
defer close(out)
|
||||||
|
var failed bool
|
||||||
|
for chunk := range streamChunks {
|
||||||
|
if chunk.Err != nil && !failed {
|
||||||
|
failed = true
|
||||||
|
rerr := &Error{Message: chunk.Err.Error()}
|
||||||
|
var se cliproxyexecutor.StatusError
|
||||||
|
if errors.As(chunk.Err, &se) && se != nil {
|
||||||
|
rerr.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||||
|
}
|
||||||
|
out <- chunk
|
||||||
|
}
|
||||||
|
if !failed {
|
||||||
|
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||||
|
}
|
||||||
|
}(execCtx, auth.Clone(), provider, chunks)
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
if provider == "" {
|
if provider == "" {
|
||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||||
@@ -1191,6 +1343,77 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
|||||||
return authCopy, executor, nil
|
return authCopy, executor, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||||
|
providerSet := make(map[string]struct{}, len(providers))
|
||||||
|
for _, provider := range providers {
|
||||||
|
p := strings.TrimSpace(strings.ToLower(provider))
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
providerSet[p] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(providerSet) == 0 {
|
||||||
|
return nil, nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
candidates := make([]*Auth, 0, len(m.auths))
|
||||||
|
modelKey := strings.TrimSpace(model)
|
||||||
|
registryRef := registry.GetGlobalRegistry()
|
||||||
|
for _, candidate := range m.auths {
|
||||||
|
if candidate == nil || candidate.Disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
|
||||||
|
if providerKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := providerSet[providerKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, used := tried[candidate.ID]; used {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := m.executors[providerKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidates = append(candidates, candidate)
|
||||||
|
}
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
|
||||||
|
if errPick != nil {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return nil, nil, "", errPick
|
||||||
|
}
|
||||||
|
if selected == nil {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
||||||
|
}
|
||||||
|
providerKey := strings.TrimSpace(strings.ToLower(selected.Provider))
|
||||||
|
executor, okExecutor := m.executors[providerKey]
|
||||||
|
if !okExecutor {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||||
|
}
|
||||||
|
authCopy := selected.Clone()
|
||||||
|
m.mu.RUnlock()
|
||||||
|
if !selected.indexAssigned {
|
||||||
|
m.mu.Lock()
|
||||||
|
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
||||||
|
current.EnsureIndex()
|
||||||
|
authCopy = current.Clone()
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
return authCopy, executor, providerKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
||||||
if m.store == nil || auth == nil {
|
if m.store == nil || auth == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -103,13 +104,29 @@ func (e *modelCooldownError) Headers() http.Header {
|
|||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
|
|
||||||
func collectAvailable(auths []*Auth, model string, now time.Time) (available []*Auth, cooldownCount int, earliest time.Time) {
|
func authPriority(auth *Auth) int {
|
||||||
available = make([]*Auth, 0, len(auths))
|
if auth == nil || auth.Attributes == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
raw := strings.TrimSpace(auth.Attributes["priority"])
|
||||||
|
if raw == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
parsed, err := strconv.Atoi(raw)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
||||||
|
available = make(map[int][]*Auth)
|
||||||
for i := 0; i < len(auths); i++ {
|
for i := 0; i < len(auths); i++ {
|
||||||
candidate := auths[i]
|
candidate := auths[i]
|
||||||
blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
|
blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
|
||||||
if !blocked {
|
if !blocked {
|
||||||
available = append(available, candidate)
|
priority := authPriority(candidate)
|
||||||
|
available[priority] = append(available[priority], candidate)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if reason == blockReasonCooldown {
|
if reason == blockReasonCooldown {
|
||||||
@@ -119,9 +136,6 @@ func collectAvailable(auths []*Auth, model string, now time.Time) (available []*
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(available) > 1 {
|
|
||||||
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
|
|
||||||
}
|
|
||||||
return available, cooldownCount, earliest
|
return available, cooldownCount, earliest
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,18 +144,35 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
|||||||
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
|
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
|
||||||
}
|
}
|
||||||
|
|
||||||
available, cooldownCount, earliest := collectAvailable(auths, model, now)
|
availableByPriority, cooldownCount, earliest := collectAvailableByPriority(auths, model, now)
|
||||||
if len(available) == 0 {
|
if len(availableByPriority) == 0 {
|
||||||
if cooldownCount == len(auths) && !earliest.IsZero() {
|
if cooldownCount == len(auths) && !earliest.IsZero() {
|
||||||
|
providerForError := provider
|
||||||
|
if providerForError == "mixed" {
|
||||||
|
providerForError = ""
|
||||||
|
}
|
||||||
resetIn := earliest.Sub(now)
|
resetIn := earliest.Sub(now)
|
||||||
if resetIn < 0 {
|
if resetIn < 0 {
|
||||||
resetIn = 0
|
resetIn = 0
|
||||||
}
|
}
|
||||||
return nil, newModelCooldownError(model, provider, resetIn)
|
return nil, newModelCooldownError(model, providerForError, resetIn)
|
||||||
}
|
}
|
||||||
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
|
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bestPriority := 0
|
||||||
|
found := false
|
||||||
|
for priority := range availableByPriority {
|
||||||
|
if !found || priority > bestPriority {
|
||||||
|
bestPriority = priority
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
available := availableByPriority[bestPriority]
|
||||||
|
if len(available) > 1 {
|
||||||
|
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
|
||||||
|
}
|
||||||
return available, nil
|
return available, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
)
|
)
|
||||||
@@ -56,6 +57,69 @@ func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinSelectorPick_PriorityBuckets(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &RoundRobinSelector{}
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "c", Attributes: map[string]string{"priority": "0"}},
|
||||||
|
{ID: "a", Attributes: map[string]string{"priority": "10"}},
|
||||||
|
{ID: "b", Attributes: map[string]string{"priority": "10"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []string{"a", "b", "a", "b"}
|
||||||
|
for i, id := range want {
|
||||||
|
got, err := selector.Pick(context.Background(), "mixed", "", cliproxyexecutor.Options{}, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() #%d auth = nil", i)
|
||||||
|
}
|
||||||
|
if got.ID != id {
|
||||||
|
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id)
|
||||||
|
}
|
||||||
|
if got.ID == "c" {
|
||||||
|
t.Fatalf("Pick() #%d unexpectedly selected lower priority auth", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFillFirstSelectorPick_PriorityFallbackCooldown(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &FillFirstSelector{}
|
||||||
|
now := time.Now()
|
||||||
|
model := "test-model"
|
||||||
|
|
||||||
|
high := &Auth{
|
||||||
|
ID: "high",
|
||||||
|
Attributes: map[string]string{"priority": "10"},
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
model: {
|
||||||
|
Status: StatusActive,
|
||||||
|
Unavailable: true,
|
||||||
|
NextRetryAfter: now.Add(30 * time.Minute),
|
||||||
|
Quota: QuotaState{
|
||||||
|
Exceeded: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
low := &Auth{ID: "low", Attributes: map[string]string{"priority": "0"}}
|
||||||
|
|
||||||
|
got, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, []*Auth{high, low})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() error = %v", err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() auth = nil")
|
||||||
|
}
|
||||||
|
if got.ID != "low" {
|
||||||
|
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
|
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
|
||||||
selector := &RoundRobinSelector{}
|
selector := &RoundRobinSelector{}
|
||||||
auths := []*Auth{
|
auths := []*Auth{
|
||||||
|
|||||||
Reference in New Issue
Block a user