Compare commits

...

9 Commits

Author SHA1 Message Date
Luis Pater
7248f65c36 feat(auth): prevent filestore writes on unchanged metadata
Some checks failed
docker-image / docker (push) Has been cancelled
goreleaser / goreleaser (push) Has been cancelled
- Added `metadataEqualIgnoringTimestamps` to compare metadata while ignoring volatile fields.
- Prevented redundant writes caused by changes in timestamp-related fields.
- Improved efficiency in filestore operations by skipping unnecessary updates.
2026-01-15 14:05:23 +08:00
Luis Pater
086eb3df7a refactor(auth): simplify file handling logic and remove redundant comparison functions
Some checks failed
docker-image / docker (push) Has been cancelled
goreleaser / goreleaser (push) Has been cancelled
feat(auth): fetch and update Antigravity project ID from metadata during filestore operations

- Added support to retrieve and update `project_id` using the access token if missing in metadata.
- Integrated HTTP client to fetch project ID dynamically.
- Enhanced metadata persistence logic.
2026-01-15 13:29:14 +08:00
Luis Pater
5a7e5bd870 feat(auth): add Antigravity onboarding with tier selection
Some checks failed
docker-image / docker (push) Has been cancelled
goreleaser / goreleaser (push) Has been cancelled
- Updated `ideType` to `ANTIGRAVITY` in request payload.
- Introduced tier-selection logic to determine default tier for onboarding.
- Added `antigravityOnboardUser` function for project ID retrieval via polling.
- Enhanced error handling and response decoding for onboarding flow.
2026-01-15 11:43:02 +08:00
Luis Pater
6f8a8f8136 feat(selector): add priority support for auth selection
Some checks failed
docker-image / docker (push) Has been cancelled
goreleaser / goreleaser (push) Has been cancelled
2026-01-15 07:08:24 +08:00
Luis Pater
b163f8ed9e Fixed: #1004
Some checks failed
docker-image / docker (push) Has been cancelled
goreleaser / goreleaser (push) Has been cancelled
feat(translator): add function name to response output item serialization

- Included `item.name` in the serialized response output to enhance output item handling.
2026-01-15 03:27:00 +08:00
Luis Pater
a1da6ff5ac Fixed: #499 #985
Some checks failed
docker-image / docker (push) Has been cancelled
goreleaser / goreleaser (push) Has been cancelled
feat(oauth): add support for customizable OAuth callback ports

- Introduced `oauth-callback-port` flag to override default callback ports.
- Updated SDK and login flows for `iflow`, `gemini`, `antigravity`, `codex`, `claude`, and `openai` to respect configurable callback ports.
- Refactored internal OAuth servers to dynamically assign ports based on the provided options.
- Revised tests and documentation to reflect the new flag and behavior.
2026-01-14 04:29:15 +08:00
Luis Pater
43652d044c refactor(config): replace nonstream-keepalive with nonstream-keepalive-interval
Some checks failed
docker-image / docker (push) Has been cancelled
goreleaser / goreleaser (push) Has been cancelled
- Updated `SDKConfig` to use `nonstream-keepalive-interval` (seconds) instead of the boolean `nonstream-keepalive`.
- Refactored handlers and logic to incorporate the new interval-based configuration.
- Updated config diff, tests, and example YAML to reflect the changes.
2026-01-13 03:14:38 +08:00
Luis Pater
b1b379ea18 feat(api): add non-streaming keep-alive support for idle timeout prevention
- Introduced `StartNonStreamingKeepAlive` to emit periodic blank lines during non-streaming responses.
- Added `nonstream-keepalive` configuration option in `SDKConfig`.
- Updated handlers to utilize `StartNonStreamingKeepAlive` and ensure proper cleanup.
- Extended config diff and tests to include `nonstream-keepalive` changes.
2026-01-13 02:36:07 +08:00
hkfires
21ac161b21 fix(test): implement missing HttpRequest method in stream bootstrap mock 2026-01-12 16:33:43 +08:00
33 changed files with 825 additions and 191 deletions

View File

@@ -61,6 +61,7 @@ func main() {
var iflowLogin bool var iflowLogin bool
var iflowCookie bool var iflowCookie bool
var noBrowser bool var noBrowser bool
var oauthCallbackPort int
var antigravityLogin bool var antigravityLogin bool
var projectID string var projectID string
var vertexImport string var vertexImport string
@@ -75,6 +76,7 @@ func main() {
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth") flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie") flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth") flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth") flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
@@ -425,7 +427,8 @@ func main() {
// Create login options to be used in authentication flows. // Create login options to be used in authentication flows.
options := &cmd.LoginOptions{ options := &cmd.LoginOptions{
NoBrowser: noBrowser, NoBrowser: noBrowser,
CallbackPort: oauthCallbackPort,
} }
// Register the shared token store once so all components use the same persistence backend. // Register the shared token store once so all components use the same persistence backend.

View File

@@ -77,6 +77,9 @@ routing:
# When true, enable authentication for the WebSocket API (/v1/ws). # When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false ws-auth: false
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
nonstream-keepalive-interval: 0
# Streaming behavior (SSE keep-alives + safe bootstrap retries). # Streaming behavior (SSE keep-alives + safe bootstrap retries).
# streaming: # streaming:
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives. # keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.

View File

@@ -29,8 +29,9 @@ import (
) )
const ( const (
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
geminiOauthClientSecret = "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" geminiOauthClientSecret = "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"
geminiDefaultCallbackPort = 8085
) )
var ( var (
@@ -49,8 +50,9 @@ type GeminiAuth struct {
// WebLoginOptions customizes the interactive OAuth flow. // WebLoginOptions customizes the interactive OAuth flow.
type WebLoginOptions struct { type WebLoginOptions struct {
NoBrowser bool NoBrowser bool
Prompt func(string) (string, error) CallbackPort int
Prompt func(string) (string, error)
} }
// NewGeminiAuth creates a new instance of GeminiAuth. // NewGeminiAuth creates a new instance of GeminiAuth.
@@ -72,6 +74,12 @@ func NewGeminiAuth() *GeminiAuth {
// - *http.Client: An HTTP client configured with authentication // - *http.Client: An HTTP client configured with authentication
// - error: An error if the client configuration fails, nil otherwise // - error: An error if the client configuration fails, nil otherwise
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) { func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
callbackPort := geminiDefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Configure proxy settings for the HTTP client if a proxy URL is provided. // Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyURL) proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil { if err == nil {
@@ -106,7 +114,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
conf := &oauth2.Config{ conf := &oauth2.Config{
ClientID: geminiOauthClientID, ClientID: geminiOauthClientID,
ClientSecret: geminiOauthClientSecret, ClientSecret: geminiOauthClientSecret,
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server. RedirectURL: callbackURL, // This will be used by the local server.
Scopes: geminiOauthScopes, Scopes: geminiOauthScopes,
Endpoint: google.Endpoint, Endpoint: google.Endpoint,
} }
@@ -218,14 +226,20 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow // - *oauth2.Token: The OAuth2 token obtained from the authorization flow
// - error: An error if the token acquisition fails, nil otherwise // - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
callbackPort := geminiDefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Use a channel to pass the authorization code from the HTTP handler to the main function. // Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string, 1) codeChan := make(chan string, 1)
errChan := make(chan error, 1) errChan := make(chan error, 1)
// Create a new HTTP server with its own multiplexer. // Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux() mux := http.NewServeMux()
server := &http.Server{Addr: ":8085", Handler: mux} server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux}
config.RedirectURL = "http://localhost:8085/oauth2callback" config.RedirectURL = callbackURL
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" { if err := r.URL.Query().Get("error"); err != "" {
@@ -277,13 +291,13 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Check if browser is available // Check if browser is available
if !browser.IsAvailable() { if !browser.IsAvailable() {
log.Warn("No browser available on this system") log.Warn("No browser available on this system")
util.PrintSSHTunnelInstructions(8085) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL) fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
} else { } else {
if err := browser.OpenURL(authURL); err != nil { if err := browser.OpenURL(authURL); err != nil {
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err) authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
log.Warn(codex.GetUserFriendlyMessage(authErr)) log.Warn(codex.GetUserFriendlyMessage(authErr))
util.PrintSSHTunnelInstructions(8085) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL) fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
// Log platform info for debugging // Log platform info for debugging
@@ -294,7 +308,7 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
} }
} }
} else { } else {
util.PrintSSHTunnelInstructions(8085) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL) fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
} }

View File

@@ -32,9 +32,10 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, CallbackPort: options.CallbackPort,
Prompt: promptFn, Metadata: map[string]string{},
Prompt: promptFn,
} }
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts) _, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)

View File

@@ -22,9 +22,10 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, CallbackPort: options.CallbackPort,
Prompt: promptFn, Metadata: map[string]string{},
Prompt: promptFn,
} }
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts) record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)

View File

@@ -24,9 +24,10 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
} }
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, CallbackPort: options.CallbackPort,
Prompt: promptFn, Metadata: map[string]string{},
Prompt: promptFn,
} }
_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts) _, savedPath, err := manager.Login(context.Background(), "iflow", cfg, authOpts)

View File

@@ -67,10 +67,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
} }
loginOpts := &sdkAuth.LoginOptions{ loginOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
ProjectID: trimmedProjectID, ProjectID: trimmedProjectID,
Metadata: map[string]string{}, CallbackPort: options.CallbackPort,
Prompt: callbackPrompt, Metadata: map[string]string{},
Prompt: callbackPrompt,
} }
authenticator := sdkAuth.NewGeminiAuthenticator() authenticator := sdkAuth.NewGeminiAuthenticator()
@@ -88,8 +89,9 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
geminiAuth := gemini.NewGeminiAuth() geminiAuth := gemini.NewGeminiAuth()
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{ httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Prompt: callbackPrompt, CallbackPort: options.CallbackPort,
Prompt: callbackPrompt,
}) })
if errClient != nil { if errClient != nil {
log.Errorf("Gemini authentication failed: %v", errClient) log.Errorf("Gemini authentication failed: %v", errClient)

View File

@@ -19,6 +19,9 @@ type LoginOptions struct {
// NoBrowser indicates whether to skip opening the browser automatically. // NoBrowser indicates whether to skip opening the browser automatically.
NoBrowser bool NoBrowser bool
// CallbackPort overrides the local OAuth callback port when set (>0).
CallbackPort int
// Prompt allows the caller to provide interactive input when needed. // Prompt allows the caller to provide interactive input when needed.
Prompt func(prompt string) (string, error) Prompt func(prompt string) (string, error)
} }
@@ -43,9 +46,10 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, CallbackPort: options.CallbackPort,
Prompt: promptFn, Metadata: map[string]string{},
Prompt: promptFn,
} }
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts) _, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)

View File

@@ -36,9 +36,10 @@ func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
} }
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, CallbackPort: options.CallbackPort,
Prompt: promptFn, Metadata: map[string]string{},
Prompt: promptFn,
} }
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts) _, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)

View File

@@ -242,6 +242,10 @@ type ClaudeKey struct {
// APIKey is the authentication key for accessing Claude API services. // APIKey is the authentication key for accessing Claude API services.
APIKey string `yaml:"api-key" json:"api-key"` APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4"). // Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -280,6 +284,10 @@ type CodexKey struct {
// APIKey is the authentication key for accessing Codex API services. // APIKey is the authentication key for accessing Codex API services.
APIKey string `yaml:"api-key" json:"api-key"` APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex"). // Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -318,6 +326,10 @@ type GeminiKey struct {
// APIKey is the authentication key for accessing Gemini API services. // APIKey is the authentication key for accessing Gemini API services.
APIKey string `yaml:"api-key" json:"api-key"` APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview"). // Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
@@ -355,6 +367,10 @@ type OpenAICompatibility struct {
// Name is the identifier for this OpenAI compatibility configuration. // Name is the identifier for this OpenAI compatibility configuration.
Name string `yaml:"name" json:"name"` Name string `yaml:"name" json:"name"`
// Priority controls selection preference when multiple providers or credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2"). // Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`

View File

@@ -25,6 +25,10 @@ type SDKConfig struct {
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries). // Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
Streaming StreamingConfig `yaml:"streaming" json:"streaming"` Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
// NonStreamKeepAliveInterval controls how often blank lines are emitted for non-streaming responses.
// <= 0 disables keep-alives. Value is in seconds.
NonStreamKeepAliveInterval int `yaml:"nonstream-keepalive-interval,omitempty" json:"nonstream-keepalive-interval,omitempty"`
} }
// StreamingConfig holds server streaming behavior configuration. // StreamingConfig holds server streaming behavior configuration.

View File

@@ -13,6 +13,10 @@ type VertexCompatKey struct {
// Maps to the x-goog-api-key header. // Maps to the x-goog-api-key header.
APIKey string `yaml:"api-key" json:"api-key"` APIKey string `yaml:"api-key" json:"api-key"`
// Priority controls selection preference when multiple credentials match.
// Higher values are preferred; defaults to 0.
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro"). // Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"` Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`

View File

@@ -25,6 +25,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
@@ -1104,12 +1105,49 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
auth.Metadata["refresh_token"] = tokenResp.RefreshToken auth.Metadata["refresh_token"] = tokenResp.RefreshToken
} }
auth.Metadata["expires_in"] = tokenResp.ExpiresIn auth.Metadata["expires_in"] = tokenResp.ExpiresIn
auth.Metadata["timestamp"] = time.Now().UnixMilli() now := time.Now()
auth.Metadata["expired"] = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339) auth.Metadata["timestamp"] = now.UnixMilli()
auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
auth.Metadata["type"] = antigravityAuthType auth.Metadata["type"] = antigravityAuthType
if errProject := e.ensureAntigravityProjectID(ctx, auth, tokenResp.AccessToken); errProject != nil {
log.Warnf("antigravity executor: ensure project id failed: %v", errProject)
}
return auth, nil return auth, nil
} }
func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) error {
if auth == nil {
return nil
}
if auth.Metadata["project_id"] != nil {
return nil
}
token := strings.TrimSpace(accessToken)
if token == "" {
token = metaStringValue(auth.Metadata, "access_token")
}
if token == "" {
return nil
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
if errFetch != nil {
return errFetch
}
if strings.TrimSpace(projectID) == "" {
return nil
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
auth.Metadata["project_id"] = strings.TrimSpace(projectID)
return nil
}
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) { func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) {
if token == "" { if token == "" {
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"} return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}

View File

@@ -251,6 +251,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID)) itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
itemDone, _ = sjson.Set(itemDone, "item.arguments", args) itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID) itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID)
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
out = append(out, emitEvent("response.output_item.done", itemDone)) out = append(out, emitEvent("response.output_item.done", itemDone))
st.InFuncBlock = false st.InFuncBlock = false
} else if st.ReasoningActive { } else if st.ReasoningActive {

View File

@@ -54,6 +54,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix { if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix {
changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix)) changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix))
} }
if oldCfg.NonStreamKeepAliveInterval != newCfg.NonStreamKeepAliveInterval {
changes = append(changes, fmt.Sprintf("nonstream-keepalive-interval: %d -> %d", oldCfg.NonStreamKeepAliveInterval, newCfg.NonStreamKeepAliveInterval))
}
// Quota-exceeded behavior // Quota-exceeded behavior
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject { if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {

View File

@@ -231,10 +231,11 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false}, AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"}, RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"},
SDKConfig: sdkconfig.SDKConfig{ SDKConfig: sdkconfig.SDKConfig{
RequestLog: false, RequestLog: false,
ProxyURL: "http://old-proxy", ProxyURL: "http://old-proxy",
APIKeys: []string{"key-1"}, APIKeys: []string{"key-1"},
ForceModelPrefix: false, ForceModelPrefix: false,
NonStreamKeepAliveInterval: 0,
}, },
} }
newCfg := &config.Config{ newCfg := &config.Config{
@@ -267,10 +268,11 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
SecretKey: "", SecretKey: "",
}, },
SDKConfig: sdkconfig.SDKConfig{ SDKConfig: sdkconfig.SDKConfig{
RequestLog: true, RequestLog: true,
ProxyURL: "http://new-proxy", ProxyURL: "http://new-proxy",
APIKeys: []string{" key-1 ", "key-2"}, APIKeys: []string{" key-1 ", "key-2"},
ForceModelPrefix: true, ForceModelPrefix: true,
NonStreamKeepAliveInterval: 5,
}, },
} }
@@ -285,6 +287,7 @@ func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy") expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
expectContains(t, details, "ws-auth: false -> true") expectContains(t, details, "ws-auth: false -> true")
expectContains(t, details, "force-model-prefix: false -> true") expectContains(t, details, "force-model-prefix: false -> true")
expectContains(t, details, "nonstream-keepalive-interval: 0 -> 5")
expectContains(t, details, "quota-exceeded.switch-project: false -> true") expectContains(t, details, "quota-exceeded.switch-project: false -> true")
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true") expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
expectContains(t, details, "api-keys count: 1 -> 2") expectContains(t, details, "api-keys count: 1 -> 2")

View File

@@ -2,6 +2,7 @@ package synthesizer
import ( import (
"fmt" "fmt"
"strconv"
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
@@ -59,6 +60,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
"source": fmt.Sprintf("config:gemini[%s]", token), "source": fmt.Sprintf("config:gemini[%s]", token),
"api_key": key, "api_key": key,
} }
if entry.Priority != 0 {
attrs["priority"] = strconv.Itoa(entry.Priority)
}
if base != "" { if base != "" {
attrs["base_url"] = base attrs["base_url"] = base
} }
@@ -103,6 +107,9 @@ func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*corea
"source": fmt.Sprintf("config:claude[%s]", token), "source": fmt.Sprintf("config:claude[%s]", token),
"api_key": key, "api_key": key,
} }
if ck.Priority != 0 {
attrs["priority"] = strconv.Itoa(ck.Priority)
}
if base != "" { if base != "" {
attrs["base_url"] = base attrs["base_url"] = base
} }
@@ -147,6 +154,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
"source": fmt.Sprintf("config:codex[%s]", token), "source": fmt.Sprintf("config:codex[%s]", token),
"api_key": key, "api_key": key,
} }
if ck.Priority != 0 {
attrs["priority"] = strconv.Itoa(ck.Priority)
}
if ck.BaseURL != "" { if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL attrs["base_url"] = ck.BaseURL
} }
@@ -202,6 +212,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
"compat_name": compat.Name, "compat_name": compat.Name,
"provider_key": providerName, "provider_key": providerName,
} }
if compat.Priority != 0 {
attrs["priority"] = strconv.Itoa(compat.Priority)
}
if key != "" { if key != "" {
attrs["api_key"] = key attrs["api_key"] = key
} }
@@ -233,6 +246,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
"compat_name": compat.Name, "compat_name": compat.Name,
"provider_key": providerName, "provider_key": providerName,
} }
if compat.Priority != 0 {
attrs["priority"] = strconv.Itoa(compat.Priority)
}
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" { if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
attrs["models_hash"] = hash attrs["models_hash"] = hash
} }
@@ -275,6 +291,9 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor
"base_url": base, "base_url": base,
"provider_key": providerName, "provider_key": providerName,
} }
if compat.Priority != 0 {
attrs["priority"] = strconv.Itoa(compat.Priority)
}
if key != "" { if key != "" {
attrs["api_key"] = key attrs["api_key"] = key
} }

View File

@@ -146,10 +146,12 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
alt := h.GetAlt(c) alt := h.GetAlt(c)
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
modelName := gjson.GetBytes(rawJSON, "model").String() modelName := gjson.GetBytes(rawJSON, "model").String()
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
stopKeepAlive()
if errMsg != nil { if errMsg != nil {
h.WriteErrorResponse(c, errMsg) h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error) cliCancel(errMsg.Error)
@@ -159,13 +161,18 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
// Decompress gzipped responses - Claude API sometimes returns gzip without Content-Encoding header // Decompress gzipped responses - Claude API sometimes returns gzip without Content-Encoding header
// This fixes title generation and other non-streaming responses that arrive compressed // This fixes title generation and other non-streaming responses that arrive compressed
if len(resp) >= 2 && resp[0] == 0x1f && resp[1] == 0x8b { if len(resp) >= 2 && resp[0] == 0x1f && resp[1] == 0x8b {
gzReader, err := gzip.NewReader(bytes.NewReader(resp)) gzReader, errGzip := gzip.NewReader(bytes.NewReader(resp))
if err != nil { if errGzip != nil {
log.Warnf("failed to decompress gzipped Claude response: %v", err) log.Warnf("failed to decompress gzipped Claude response: %v", errGzip)
} else { } else {
defer gzReader.Close() defer func() {
if decompressed, err := io.ReadAll(gzReader); err != nil { if errClose := gzReader.Close(); errClose != nil {
log.Warnf("failed to read decompressed Claude response: %v", err) log.Warnf("failed to close Claude gzip reader: %v", errClose)
}
}()
decompressed, errRead := io.ReadAll(gzReader)
if errRead != nil {
log.Warnf("failed to read decompressed Claude response: %v", errRead)
} else { } else {
resp = decompressed resp = decompressed
} }

View File

@@ -336,7 +336,9 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
alt := h.GetAlt(c) alt := h.GetAlt(c)
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
stopKeepAlive()
if errMsg != nil { if errMsg != nil {
h.WriteErrorResponse(c, errMsg) h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error) cliCancel(errMsg.Error)

View File

@@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -113,6 +114,19 @@ func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
return time.Duration(seconds) * time.Second return time.Duration(seconds) * time.Second
} }
// NonStreamingKeepAliveInterval returns the keep-alive interval for non-streaming responses.
// Returning 0 disables keep-alives (default when unset).
func NonStreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
seconds := 0
if cfg != nil {
seconds = cfg.NonStreamKeepAliveInterval
}
if seconds <= 0 {
return 0
}
return time.Duration(seconds) * time.Second
}
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent. // StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
func StreamingBootstrapRetries(cfg *config.SDKConfig) int { func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
retries := defaultStreamingBootstrapRetries retries := defaultStreamingBootstrapRetries
@@ -293,6 +307,53 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
} }
} }
// StartNonStreamingKeepAlive emits blank lines every 5 seconds while waiting for a non-streaming response.
// It returns a stop function that must be called before writing the final response.
func (h *BaseAPIHandler) StartNonStreamingKeepAlive(c *gin.Context, ctx context.Context) func() {
if h == nil || c == nil {
return func() {}
}
interval := NonStreamingKeepAliveInterval(h.Cfg)
if interval <= 0 {
return func() {}
}
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return func() {}
}
if ctx == nil {
ctx = context.Background()
}
stopChan := make(chan struct{})
var stopOnce sync.Once
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-stopChan:
return
case <-ctx.Done():
return
case <-ticker.C:
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
}
}
}()
return func() {
stopOnce.Do(func() {
close(stopChan)
})
wg.Wait()
}
}
// appendAPIResponse preserves any previously captured API response and appends new data. // appendAPIResponse preserves any previously captured API response and appends new data.
func appendAPIResponse(c *gin.Context, data []byte) { func appendAPIResponse(c *gin.Context, data []byte) {
if c == nil || len(data) == 0 { if c == nil || len(data) == 0 {

View File

@@ -56,6 +56,14 @@ func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, co
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
} }
func (e *failOnceStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
return nil, &coreauth.Error{
Code: "not_implemented",
Message: "HttpRequest not implemented",
HTTPStatus: http.StatusNotImplemented,
}
}
func (e *failOnceStreamExecutor) Calls() int { func (e *failOnceStreamExecutor) Calls() int {
e.mu.Lock() e.mu.Lock()
defer e.mu.Unlock() defer e.mu.Unlock()

View File

@@ -524,7 +524,9 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String() modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "") resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
stopKeepAlive()
if errMsg != nil { if errMsg != nil {
h.WriteErrorResponse(c, errMsg) h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error) cliCancel(errMsg.Error)

View File

@@ -103,20 +103,17 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
modelName := gjson.GetBytes(rawJSON, "model").String() modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
defer func() { stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
cliCancel()
}()
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "") resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
stopKeepAlive()
if errMsg != nil { if errMsg != nil {
h.WriteErrorResponse(c, errMsg) h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return return
} }
_, _ = c.Writer.Write(resp) _, _ = c.Writer.Write(resp)
return cliCancel()
// no legacy fallback
} }
// handleStreamingResponse handles streaming responses for Gemini models. // handleStreamingResponse handles streaming responses for Gemini models.

View File

@@ -60,6 +60,11 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
opts = &LoginOptions{} opts = &LoginOptions{}
} }
callbackPort := antigravityCallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{}) httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
state, err := misc.GenerateRandomState() state, err := misc.GenerateRandomState()
@@ -67,7 +72,7 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
return nil, fmt.Errorf("antigravity: failed to generate state: %w", err) return nil, fmt.Errorf("antigravity: failed to generate state: %w", err)
} }
srv, port, cbChan, errServer := startAntigravityCallbackServer() srv, port, cbChan, errServer := startAntigravityCallbackServer(callbackPort)
if errServer != nil { if errServer != nil {
return nil, fmt.Errorf("antigravity: failed to start callback server: %w", errServer) return nil, fmt.Errorf("antigravity: failed to start callback server: %w", errServer)
} }
@@ -224,13 +229,16 @@ type callbackResult struct {
State string State string
} }
func startAntigravityCallbackServer() (*http.Server, int, <-chan callbackResult, error) { func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
addr := fmt.Sprintf(":%d", antigravityCallbackPort) if port <= 0 {
port = antigravityCallbackPort
}
addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr) listener, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
return nil, 0, nil, err return nil, 0, nil, err
} }
port := listener.Addr().(*net.TCPAddr).Port port = listener.Addr().(*net.TCPAddr).Port
resultCh := make(chan callbackResult, 1) resultCh := make(chan callbackResult, 1)
mux := http.NewServeMux() mux := http.NewServeMux()
@@ -374,7 +382,7 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie
// Call loadCodeAssist to get the project // Call loadCodeAssist to get the project
loadReqBody := map[string]any{ loadReqBody := map[string]any{
"metadata": map[string]string{ "metadata": map[string]string{
"ideType": "IDE_UNSPECIFIED", "ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED", "platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI", "pluginType": "GEMINI",
}, },
@@ -434,8 +442,134 @@ func fetchAntigravityProjectID(ctx context.Context, accessToken string, httpClie
} }
if projectID == "" { if projectID == "" {
return "", fmt.Errorf("no cloudaicompanionProject in response") tierID := "legacy-tier"
if tiers, okTiers := loadResp["allowedTiers"].([]any); okTiers {
for _, rawTier := range tiers {
tier, okTier := rawTier.(map[string]any)
if !okTier {
continue
}
if isDefault, okDefault := tier["isDefault"].(bool); okDefault && isDefault {
if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
tierID = strings.TrimSpace(id)
break
}
}
}
}
projectID, err = antigravityOnboardUser(ctx, accessToken, tierID, httpClient)
if err != nil {
return "", err
}
return projectID, nil
} }
return projectID, nil return projectID, nil
} }
// antigravityOnboardUser attempts to fetch the project ID via onboardUser by polling for completion.
// It returns an empty string when the operation times out or completes without a project ID.
func antigravityOnboardUser(ctx context.Context, accessToken, tierID string, httpClient *http.Client) (string, error) {
if httpClient == nil {
httpClient = http.DefaultClient
}
fmt.Println("Antigravity: onboarding user...", tierID)
requestBody := map[string]any{
"tierId": tierID,
"metadata": map[string]string{
"ideType": "ANTIGRAVITY",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
},
}
rawBody, errMarshal := json.Marshal(requestBody)
if errMarshal != nil {
return "", fmt.Errorf("marshal request body: %w", errMarshal)
}
maxAttempts := 5
for attempt := 1; attempt <= maxAttempts; attempt++ {
log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)
reqCtx := ctx
var cancel context.CancelFunc
if reqCtx == nil {
reqCtx = context.Background()
}
reqCtx, cancel = context.WithTimeout(reqCtx, 30*time.Second)
endpointURL := fmt.Sprintf("%s/%s:onboardUser", antigravityAPIEndpoint, antigravityAPIVersion)
req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
if errRequest != nil {
cancel()
return "", fmt.Errorf("create request: %w", errRequest)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", antigravityAPIUserAgent)
req.Header.Set("X-Goog-Api-Client", antigravityAPIClient)
req.Header.Set("Client-Metadata", antigravityClientMetadata)
resp, errDo := httpClient.Do(req)
if errDo != nil {
cancel()
return "", fmt.Errorf("execute request: %w", errDo)
}
bodyBytes, errRead := io.ReadAll(resp.Body)
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("close body error: %v", errClose)
}
cancel()
if errRead != nil {
return "", fmt.Errorf("read response: %w", errRead)
}
if resp.StatusCode == http.StatusOK {
var data map[string]any
if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
return "", fmt.Errorf("decode response: %w", errDecode)
}
if done, okDone := data["done"].(bool); okDone && done {
projectID := ""
if responseData, okResp := data["response"].(map[string]any); okResp {
switch projectValue := responseData["cloudaicompanionProject"].(type) {
case map[string]any:
if id, okID := projectValue["id"].(string); okID {
projectID = strings.TrimSpace(id)
}
case string:
projectID = strings.TrimSpace(projectValue)
}
}
if projectID != "" {
log.Infof("Successfully fetched project_id: %s", projectID)
return projectID, nil
}
return "", fmt.Errorf("no project_id in response")
}
time.Sleep(2 * time.Second)
continue
}
responsePreview := strings.TrimSpace(string(bodyBytes))
if len(responsePreview) > 500 {
responsePreview = responsePreview[:500]
}
responseErr := responsePreview
if len(responseErr) > 200 {
responseErr = responseErr[:200]
}
return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
}
return "", nil
}

View File

@@ -47,6 +47,11 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
opts = &LoginOptions{} opts = &LoginOptions{}
} }
callbackPort := a.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
pkceCodes, err := claude.GeneratePKCECodes() pkceCodes, err := claude.GeneratePKCECodes()
if err != nil { if err != nil {
return nil, fmt.Errorf("claude pkce generation failed: %w", err) return nil, fmt.Errorf("claude pkce generation failed: %w", err)
@@ -57,7 +62,7 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
return nil, fmt.Errorf("claude state generation failed: %w", err) return nil, fmt.Errorf("claude state generation failed: %w", err)
} }
oauthServer := claude.NewOAuthServer(a.CallbackPort) oauthServer := claude.NewOAuthServer(callbackPort)
if err = oauthServer.Start(); err != nil { if err = oauthServer.Start(); err != nil {
if strings.Contains(err.Error(), "already in use") { if strings.Contains(err.Error(), "already in use") {
return nil, claude.NewAuthenticationError(claude.ErrPortInUse, err) return nil, claude.NewAuthenticationError(claude.ErrPortInUse, err)
@@ -84,15 +89,15 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
fmt.Println("Opening browser for Claude authentication") fmt.Println("Opening browser for Claude authentication")
if !browser.IsAvailable() { if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually") log.Warn("No browser available; please open the URL manually")
util.PrintSSHTunnelInstructions(a.CallbackPort) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if err = browser.OpenURL(authURL); err != nil { } else if err = browser.OpenURL(authURL); err != nil {
log.Warnf("Failed to open browser automatically: %v", err) log.Warnf("Failed to open browser automatically: %v", err)
util.PrintSSHTunnelInstructions(a.CallbackPort) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} }
} else { } else {
util.PrintSSHTunnelInstructions(a.CallbackPort) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} }

View File

@@ -47,6 +47,11 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
opts = &LoginOptions{} opts = &LoginOptions{}
} }
callbackPort := a.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
pkceCodes, err := codex.GeneratePKCECodes() pkceCodes, err := codex.GeneratePKCECodes()
if err != nil { if err != nil {
return nil, fmt.Errorf("codex pkce generation failed: %w", err) return nil, fmt.Errorf("codex pkce generation failed: %w", err)
@@ -57,7 +62,7 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
return nil, fmt.Errorf("codex state generation failed: %w", err) return nil, fmt.Errorf("codex state generation failed: %w", err)
} }
oauthServer := codex.NewOAuthServer(a.CallbackPort) oauthServer := codex.NewOAuthServer(callbackPort)
if err = oauthServer.Start(); err != nil { if err = oauthServer.Start(); err != nil {
if strings.Contains(err.Error(), "already in use") { if strings.Contains(err.Error(), "already in use") {
return nil, codex.NewAuthenticationError(codex.ErrPortInUse, err) return nil, codex.NewAuthenticationError(codex.ErrPortInUse, err)
@@ -83,15 +88,15 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
fmt.Println("Opening browser for Codex authentication") fmt.Println("Opening browser for Codex authentication")
if !browser.IsAvailable() { if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually") log.Warn("No browser available; please open the URL manually")
util.PrintSSHTunnelInstructions(a.CallbackPort) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if err = browser.OpenURL(authURL); err != nil { } else if err = browser.OpenURL(authURL); err != nil {
log.Warnf("Failed to open browser automatically: %v", err) log.Warnf("Failed to open browser automatically: %v", err)
util.PrintSSHTunnelInstructions(a.CallbackPort) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} }
} else { } else {
util.PrintSSHTunnelInstructions(a.CallbackPort) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} }

View File

@@ -5,8 +5,10 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/fs" "io/fs"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -77,15 +79,23 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
if metadataEqualIgnoringTimestamps(existing, raw) { if metadataEqualIgnoringTimestamps(existing, raw) {
return path, nil return path, nil
} }
} else if errRead != nil && !os.IsNotExist(errRead) { file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600)
if errOpen != nil {
return "", fmt.Errorf("auth filestore: open existing failed: %w", errOpen)
}
if _, errWrite := file.Write(raw); errWrite != nil {
_ = file.Close()
return "", fmt.Errorf("auth filestore: write existing failed: %w", errWrite)
}
if errClose := file.Close(); errClose != nil {
return "", fmt.Errorf("auth filestore: close existing failed: %w", errClose)
}
return path, nil
} else if !os.IsNotExist(errRead) {
return "", fmt.Errorf("auth filestore: read existing failed: %w", errRead) return "", fmt.Errorf("auth filestore: read existing failed: %w", errRead)
} }
tmp := path + ".tmp" if errWrite := os.WriteFile(path, raw, 0o600); errWrite != nil {
if errWrite := os.WriteFile(tmp, raw, 0o600); errWrite != nil { return "", fmt.Errorf("auth filestore: write file failed: %w", errWrite)
return "", fmt.Errorf("auth filestore: write temp failed: %w", errWrite)
}
if errRename := os.Rename(tmp, path); errRename != nil {
return "", fmt.Errorf("auth filestore: rename failed: %w", errRename)
} }
default: default:
return "", fmt.Errorf("auth filestore: nothing to persist for %s", auth.ID) return "", fmt.Errorf("auth filestore: nothing to persist for %s", auth.ID)
@@ -178,6 +188,30 @@ func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth,
if provider == "" { if provider == "" {
provider = "unknown" provider = "unknown"
} }
if provider == "antigravity" {
projectID := ""
if pid, ok := metadata["project_id"].(string); ok {
projectID = strings.TrimSpace(pid)
}
if projectID == "" {
accessToken := ""
if token, ok := metadata["access_token"].(string); ok {
accessToken = strings.TrimSpace(token)
}
if accessToken != "" {
fetchedProjectID, errFetch := FetchAntigravityProjectID(context.Background(), accessToken, http.DefaultClient)
if errFetch == nil && strings.TrimSpace(fetchedProjectID) != "" {
metadata["project_id"] = strings.TrimSpace(fetchedProjectID)
if raw, errMarshal := json.Marshal(metadata); errMarshal == nil {
if file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600); errOpen == nil {
_, _ = file.Write(raw)
_ = file.Close()
}
}
}
}
}
}
info, err := os.Stat(path) info, err := os.Stat(path)
if err != nil { if err != nil {
return nil, fmt.Errorf("stat file: %w", err) return nil, fmt.Errorf("stat file: %w", err)
@@ -266,92 +300,28 @@ func (s *FileTokenStore) baseDirSnapshot() string {
return s.baseDir return s.baseDir
} }
// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata. // metadataEqualIgnoringTimestamps compares two metadata JSON blobs, ignoring volatile fields that
// This function is kept for backward compatibility but can cause refresh loops. // change on every refresh but don't affect authentication logic.
func jsonEqual(a, b []byte) bool {
var objA any
var objB any
if err := json.Unmarshal(a, &objA); err != nil {
return false
}
if err := json.Unmarshal(b, &objB); err != nil {
return false
}
return deepEqualJSON(objA, objB)
}
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs,
// ignoring fields that change on every refresh but don't affect functionality.
// This prevents unnecessary file writes that would trigger watcher events and
// create refresh loops.
func metadataEqualIgnoringTimestamps(a, b []byte) bool { func metadataEqualIgnoringTimestamps(a, b []byte) bool {
var objA, objB map[string]any var objA map[string]any
if err := json.Unmarshal(a, &objA); err != nil { var objB map[string]any
if errUnmarshalA := json.Unmarshal(a, &objA); errUnmarshalA != nil {
return false return false
} }
if err := json.Unmarshal(b, &objB); err != nil { if errUnmarshalB := json.Unmarshal(b, &objB); errUnmarshalB != nil {
return false return false
} }
stripVolatileMetadataFields(objA)
// Fields to ignore: these change on every refresh but don't affect authentication logic. stripVolatileMetadataFields(objB)
// - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh return reflect.DeepEqual(objA, objB)
// - access_token: Google OAuth returns a new access_token on each refresh, this is expected
// and shouldn't trigger file writes (the new token will be fetched again when needed)
ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"}
for _, field := range ignoredFields {
delete(objA, field)
delete(objB, field)
}
return deepEqualJSON(objA, objB)
} }
func deepEqualJSON(a, b any) bool { func stripVolatileMetadataFields(metadata map[string]any) {
switch valA := a.(type) { if metadata == nil {
case map[string]any: return
valB, ok := b.(map[string]any) }
if !ok || len(valA) != len(valB) { // These fields change on refresh and would otherwise trigger watcher reload loops.
return false for _, field := range []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"} {
} delete(metadata, field)
for key, subA := range valA {
subB, ok1 := valB[key]
if !ok1 || !deepEqualJSON(subA, subB) {
return false
}
}
return true
case []any:
sliceB, ok := b.([]any)
if !ok || len(valA) != len(sliceB) {
return false
}
for i := range valA {
if !deepEqualJSON(valA[i], sliceB[i]) {
return false
}
}
return true
case float64:
valB, ok := b.(float64)
if !ok {
return false
}
return valA == valB
case string:
valB, ok := b.(string)
if !ok {
return false
}
return valA == valB
case bool:
valB, ok := b.(bool)
if !ok {
return false
}
return valA == valB
case nil:
return b == nil
default:
return false
} }
} }

View File

@@ -45,8 +45,9 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
geminiAuth := gemini.NewGeminiAuth() geminiAuth := gemini.NewGeminiAuth()
_, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{ _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{
NoBrowser: opts.NoBrowser, NoBrowser: opts.NoBrowser,
Prompt: opts.Prompt, CallbackPort: opts.CallbackPort,
Prompt: opts.Prompt,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("gemini authentication failed: %w", err) return nil, fmt.Errorf("gemini authentication failed: %w", err)

View File

@@ -42,9 +42,14 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
opts = &LoginOptions{} opts = &LoginOptions{}
} }
callbackPort := iflow.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
authSvc := iflow.NewIFlowAuth(cfg) authSvc := iflow.NewIFlowAuth(cfg)
oauthServer := iflow.NewOAuthServer(iflow.CallbackPort) oauthServer := iflow.NewOAuthServer(callbackPort)
if err := oauthServer.Start(); err != nil { if err := oauthServer.Start(); err != nil {
if strings.Contains(err.Error(), "already in use") { if strings.Contains(err.Error(), "already in use") {
return nil, fmt.Errorf("iflow authentication server port in use: %w", err) return nil, fmt.Errorf("iflow authentication server port in use: %w", err)
@@ -64,21 +69,21 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
return nil, fmt.Errorf("iflow auth: failed to generate state: %w", err) return nil, fmt.Errorf("iflow auth: failed to generate state: %w", err)
} }
authURL, redirectURI := authSvc.AuthorizationURL(state, iflow.CallbackPort) authURL, redirectURI := authSvc.AuthorizationURL(state, callbackPort)
if !opts.NoBrowser { if !opts.NoBrowser {
fmt.Println("Opening browser for iFlow authentication") fmt.Println("Opening browser for iFlow authentication")
if !browser.IsAvailable() { if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually") log.Warn("No browser available; please open the URL manually")
util.PrintSSHTunnelInstructions(iflow.CallbackPort) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if err = browser.OpenURL(authURL); err != nil { } else if err = browser.OpenURL(authURL); err != nil {
log.Warnf("Failed to open browser automatically: %v", err) log.Warnf("Failed to open browser automatically: %v", err)
util.PrintSSHTunnelInstructions(iflow.CallbackPort) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} }
} else { } else {
util.PrintSSHTunnelInstructions(iflow.CallbackPort) util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL) fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} }

View File

@@ -14,10 +14,11 @@ var ErrRefreshNotSupported = errors.New("cliproxy auth: refresh not supported")
// LoginOptions captures generic knobs shared across authenticators. // LoginOptions captures generic knobs shared across authenticators.
// Provider-specific logic can inspect Metadata for extra parameters. // Provider-specific logic can inspect Metadata for extra parameters.
type LoginOptions struct { type LoginOptions struct {
NoBrowser bool NoBrowser bool
ProjectID string ProjectID string
Metadata map[string]string CallbackPort int
Prompt func(prompt string) (string, error) Metadata map[string]string
Prompt func(prompt string) (string, error)
} }
// Authenticator manages login and optional refresh flows for a provider. // Authenticator manages login and optional refresh flows for a provider.

View File

@@ -271,7 +271,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
if len(normalized) == 0 { if len(normalized) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
rotated := m.rotateProviders(req.Model, normalized)
retryTimes, maxWait := m.retrySettings() retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1 attempts := retryTimes + 1
@@ -281,14 +280,12 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
var lastErr error var lastErr error
for attempt := 0; attempt < attempts; attempt++ { for attempt := 0; attempt < attempts; attempt++ {
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) { resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
return m.executeWithProvider(execCtx, provider, req, opts)
})
if errExec == nil { if errExec == nil {
return resp, nil return resp, nil
} }
lastErr = errExec lastErr = errExec
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait) wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
if !shouldRetry { if !shouldRetry {
break break
} }
@@ -309,7 +306,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
if len(normalized) == 0 { if len(normalized) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
rotated := m.rotateProviders(req.Model, normalized)
retryTimes, maxWait := m.retrySettings() retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1 attempts := retryTimes + 1
@@ -319,14 +315,12 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
var lastErr error var lastErr error
for attempt := 0; attempt < attempts; attempt++ { for attempt := 0; attempt < attempts; attempt++ {
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) { resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
return m.executeCountWithProvider(execCtx, provider, req, opts)
})
if errExec == nil { if errExec == nil {
return resp, nil return resp, nil
} }
lastErr = errExec lastErr = errExec
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait) wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
if !shouldRetry { if !shouldRetry {
break break
} }
@@ -347,7 +341,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
if len(normalized) == 0 { if len(normalized) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"} return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
} }
rotated := m.rotateProviders(req.Model, normalized)
retryTimes, maxWait := m.retrySettings() retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1 attempts := retryTimes + 1
@@ -357,14 +350,12 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
var lastErr error var lastErr error
for attempt := 0; attempt < attempts; attempt++ { for attempt := 0; attempt < attempts; attempt++ {
chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) { chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
return m.executeStreamWithProvider(execCtx, provider, req, opts)
})
if errStream == nil { if errStream == nil {
return chunks, nil return chunks, nil
} }
lastErr = errStream lastErr = errStream
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait) wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait)
if !shouldRetry { if !shouldRetry {
break break
} }
@@ -378,6 +369,167 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
return nil, &Error{Code: "auth_not_found", Message: "no auth available"} return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
} }
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
}
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
}
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
if len(providers) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return nil, lastErr
}
return nil, errPick
}
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
rerr := &Error{Message: errStream.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errStream, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
lastErr = errStream
continue
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
var se cliproxyexecutor.StatusError
if errors.As(chunk.Err, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
out <- chunk
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, chunks)
return out, nil
}
}
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if provider == "" { if provider == "" {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"} return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
@@ -1191,6 +1343,77 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
return authCopy, executor, nil return authCopy, executor, nil
} }
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
providerSet := make(map[string]struct{}, len(providers))
for _, provider := range providers {
p := strings.TrimSpace(strings.ToLower(provider))
if p == "" {
continue
}
providerSet[p] = struct{}{}
}
if len(providerSet) == 0 {
return nil, nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
m.mu.RLock()
candidates := make([]*Auth, 0, len(m.auths))
modelKey := strings.TrimSpace(model)
registryRef := registry.GetGlobalRegistry()
for _, candidate := range m.auths {
if candidate == nil || candidate.Disabled {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
if providerKey == "" {
continue
}
if _, ok := providerSet[providerKey]; !ok {
continue
}
if _, used := tried[candidate.ID]; used {
continue
}
if _, ok := m.executors[providerKey]; !ok {
continue
}
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
continue
}
candidates = append(candidates, candidate)
}
if len(candidates) == 0 {
m.mu.RUnlock()
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
}
selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
if errPick != nil {
m.mu.RUnlock()
return nil, nil, "", errPick
}
if selected == nil {
m.mu.RUnlock()
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
}
providerKey := strings.TrimSpace(strings.ToLower(selected.Provider))
executor, okExecutor := m.executors[providerKey]
if !okExecutor {
m.mu.RUnlock()
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
}
authCopy := selected.Clone()
m.mu.RUnlock()
if !selected.indexAssigned {
m.mu.Lock()
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
current.EnsureIndex()
authCopy = current.Clone()
}
m.mu.Unlock()
}
return authCopy, executor, providerKey, nil
}
func (m *Manager) persist(ctx context.Context, auth *Auth) error { func (m *Manager) persist(ctx context.Context, auth *Auth) error {
if m.store == nil || auth == nil { if m.store == nil || auth == nil {
return nil return nil

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"sort" "sort"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@@ -103,13 +104,29 @@ func (e *modelCooldownError) Headers() http.Header {
return headers return headers
} }
func collectAvailable(auths []*Auth, model string, now time.Time) (available []*Auth, cooldownCount int, earliest time.Time) { func authPriority(auth *Auth) int {
available = make([]*Auth, 0, len(auths)) if auth == nil || auth.Attributes == nil {
return 0
}
raw := strings.TrimSpace(auth.Attributes["priority"])
if raw == "" {
return 0
}
parsed, err := strconv.Atoi(raw)
if err != nil {
return 0
}
return parsed
}
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
available = make(map[int][]*Auth)
for i := 0; i < len(auths); i++ { for i := 0; i < len(auths); i++ {
candidate := auths[i] candidate := auths[i]
blocked, reason, next := isAuthBlockedForModel(candidate, model, now) blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
if !blocked { if !blocked {
available = append(available, candidate) priority := authPriority(candidate)
available[priority] = append(available[priority], candidate)
continue continue
} }
if reason == blockReasonCooldown { if reason == blockReasonCooldown {
@@ -119,9 +136,6 @@ func collectAvailable(auths []*Auth, model string, now time.Time) (available []*
} }
} }
} }
if len(available) > 1 {
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
}
return available, cooldownCount, earliest return available, cooldownCount, earliest
} }
@@ -130,18 +144,35 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"} return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
} }
available, cooldownCount, earliest := collectAvailable(auths, model, now) availableByPriority, cooldownCount, earliest := collectAvailableByPriority(auths, model, now)
if len(available) == 0 { if len(availableByPriority) == 0 {
if cooldownCount == len(auths) && !earliest.IsZero() { if cooldownCount == len(auths) && !earliest.IsZero() {
providerForError := provider
if providerForError == "mixed" {
providerForError = ""
}
resetIn := earliest.Sub(now) resetIn := earliest.Sub(now)
if resetIn < 0 { if resetIn < 0 {
resetIn = 0 resetIn = 0
} }
return nil, newModelCooldownError(model, provider, resetIn) return nil, newModelCooldownError(model, providerForError, resetIn)
} }
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"} return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
} }
bestPriority := 0
found := false
for priority := range availableByPriority {
if !found || priority > bestPriority {
bestPriority = priority
found = true
}
}
available := availableByPriority[bestPriority]
if len(available) > 1 {
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
}
return available, nil return available, nil
} }

View File

@@ -5,6 +5,7 @@ import (
"errors" "errors"
"sync" "sync"
"testing" "testing"
"time"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
) )
@@ -56,6 +57,69 @@ func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) {
} }
} }
func TestRoundRobinSelectorPick_PriorityBuckets(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
auths := []*Auth{
{ID: "c", Attributes: map[string]string{"priority": "0"}},
{ID: "a", Attributes: map[string]string{"priority": "10"}},
{ID: "b", Attributes: map[string]string{"priority": "10"}},
}
want := []string{"a", "b", "a", "b"}
for i, id := range want {
got, err := selector.Pick(context.Background(), "mixed", "", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != id {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id)
}
if got.ID == "c" {
t.Fatalf("Pick() #%d unexpectedly selected lower priority auth", i)
}
}
}
func TestFillFirstSelectorPick_PriorityFallbackCooldown(t *testing.T) {
t.Parallel()
selector := &FillFirstSelector{}
now := time.Now()
model := "test-model"
high := &Auth{
ID: "high",
Attributes: map[string]string{"priority": "10"},
ModelStates: map[string]*ModelState{
model: {
Status: StatusActive,
Unavailable: true,
NextRetryAfter: now.Add(30 * time.Minute),
Quota: QuotaState{
Exceeded: true,
},
},
},
}
low := &Auth{ID: "low", Attributes: map[string]string{"priority": "0"}}
got, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, []*Auth{high, low})
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got == nil {
t.Fatalf("Pick() auth = nil")
}
if got.ID != "low" {
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
}
}
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) { func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
selector := &RoundRobinSelector{} selector := &RoundRobinSelector{}
auths := []*Auth{ auths := []*Auth{