feat(auth): add OAuth2 support for xAI with PKCE and token persistence

- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support.
- Added logic for token exchange, refresh, and persistent storage in JSON format.
- Created `xai` package with helpers for OAuth discovery, API token handling, and URL building.
- Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests.
- Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
Luis Pater
2026-05-17 01:02:35 +08:00
parent cd0cea393c
commit e4c957078c
24 changed files with 2050 additions and 4 deletions
+1
View File
@@ -13,6 +13,7 @@ func init() {
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
registerRefreshLead("xai", func() Authenticator { return NewXAIAuthenticator() })
}
func registerRefreshLead(provider string, factory func() Authenticator) {
+282
View File
@@ -0,0 +1,282 @@
package auth
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"time"
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
"github.com/router-for-me/CLIProxyAPI/v7/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// XAIAuthenticator implements the xAI Grok OAuth loopback flow.
type XAIAuthenticator struct{}
// NewXAIAuthenticator constructs a new xAI authenticator.
func NewXAIAuthenticator() Authenticator {
return &XAIAuthenticator{}
}
// Provider returns the provider key for xAI.
func (XAIAuthenticator) Provider() string {
return "xai"
}
// RefreshLead instructs the manager to refresh before token expiry.
func (XAIAuthenticator) RefreshLead() *time.Duration {
lead := xaiauth.RefreshLead()
return &lead
}
// Login launches a local OAuth flow to obtain xAI tokens and persists them.
func (a XAIAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("cliproxy auth: configuration is required")
}
if ctx == nil {
ctx = context.Background()
}
if opts == nil {
opts = &LoginOptions{}
}
callbackPort := xaiauth.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
pkceCodes, err := xaiauth.GeneratePKCECodes()
if err != nil {
return nil, fmt.Errorf("xai pkce generation failed: %w", err)
}
state, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("xai state generation failed: %w", err)
}
nonce, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("xai nonce generation failed: %w", err)
}
authSvc := xaiauth.NewXAIAuth(cfg)
discovery, err := authSvc.Discover(ctx)
if err != nil {
return nil, err
}
srv, port, callbackCh, errServer := startXAICallbackServer(callbackPort)
if errServer != nil {
return nil, fmt.Errorf("xai: failed to start callback server: %w", errServer)
}
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if errShutdown := srv.Shutdown(shutdownCtx); errShutdown != nil {
log.Warnf("xai callback server shutdown error: %v", errShutdown)
}
}()
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, port, xaiauth.RedirectPath)
authURL, err := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
RedirectURI: redirectURI,
CodeChallenge: pkceCodes.CodeChallenge,
State: state,
Nonce: nonce,
})
if err != nil {
return nil, err
}
if !opts.NoBrowser {
fmt.Println("Opening browser for xAI authentication")
if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually")
util.PrintSSHTunnelInstructions(port)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if errOpen := browser.OpenURL(authURL); errOpen != nil {
log.Warnf("Failed to open browser automatically: %v", errOpen)
util.PrintSSHTunnelInstructions(port)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
} else {
util.PrintSSHTunnelInstructions(port)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
fmt.Println("Waiting for xAI authentication callback...")
var result callbackResult
timeoutTimer := time.NewTimer(5 * time.Minute)
defer timeoutTimer.Stop()
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
var manualInputCh <-chan string
var manualInputErrCh <-chan error
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
default:
}
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the xAI callback Token (or press Enter to keep waiting): ")
continue
case input := <-manualInputCh:
manualInputCh = nil
manualInputErrCh = nil
manualResult, ok, errParse := parseXAIManualCallbackToken(input, state)
if errParse != nil {
return nil, errParse
}
if !ok {
continue
}
result = manualResult
break waitForCallback
case errManual := <-manualInputErrCh:
return nil, errManual
case <-timeoutTimer.C:
return nil, fmt.Errorf("xai: authentication timed out")
}
}
if result.Error != "" {
return nil, fmt.Errorf("xai: authentication failed: %s", result.Error)
}
if result.State != state {
return nil, fmt.Errorf("xai: invalid state")
}
if result.Code == "" {
return nil, fmt.Errorf("xai: missing authorization code")
}
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, result.Code, redirectURI, pkceCodes, discovery.TokenEndpoint)
if errExchange != nil {
return nil, fmt.Errorf("xai: token exchange failed: %w", errExchange)
}
tokenStorage := authSvc.CreateTokenStorage(bundle)
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
return nil, fmt.Errorf("xai token storage missing access token")
}
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
label := strings.TrimSpace(tokenStorage.Email)
if label == "" {
label = "xAI"
}
metadata := map[string]any{
"type": "xai",
"access_token": tokenStorage.AccessToken,
"refresh_token": tokenStorage.RefreshToken,
"id_token": tokenStorage.IDToken,
"token_type": tokenStorage.TokenType,
"expires_in": tokenStorage.ExpiresIn,
"expired": tokenStorage.Expire,
"last_refresh": tokenStorage.LastRefresh,
"base_url": tokenStorage.BaseURL,
"redirect_uri": tokenStorage.RedirectURI,
"token_endpoint": tokenStorage.TokenEndpoint,
"auth_kind": "oauth",
}
if tokenStorage.Email != "" {
metadata["email"] = tokenStorage.Email
}
if tokenStorage.Subject != "" {
metadata["sub"] = tokenStorage.Subject
}
fmt.Println("xAI authentication successful")
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Label: label,
Storage: tokenStorage,
Metadata: metadata,
Attributes: map[string]string{
"auth_kind": "oauth",
"base_url": tokenStorage.BaseURL,
},
}, nil
}
func parseXAIManualCallbackToken(input string, state string) (callbackResult, bool, error) {
token := strings.TrimSpace(input)
if token == "" {
return callbackResult{}, false, nil
}
if strings.Contains(token, "://") || strings.Contains(token, "?") || strings.Contains(token, "code=") {
return callbackResult{}, false, fmt.Errorf("xai: paste only the callback token")
}
return callbackResult{Code: token, State: state}, true, nil
}
func startXAICallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
if port <= 0 {
port = xaiauth.CallbackPort
}
addr := fmt.Sprintf("%s:%d", xaiauth.RedirectHost, port)
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, 0, nil, err
}
port = listener.Addr().(*net.TCPAddr).Port
resultCh := make(chan callbackResult, 1)
mux := http.NewServeMux()
mux.HandleFunc(xaiauth.RedirectPath, func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
result := callbackResult{
Code: strings.TrimSpace(q.Get("code")),
Error: strings.TrimSpace(q.Get("error")),
State: strings.TrimSpace(q.Get("state")),
}
resultCh <- result
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if result.Code != "" && result.Error == "" {
_, _ = w.Write([]byte("<h1>Login successful</h1><p>You can close this window.</p>"))
return
}
_, _ = w.Write([]byte("<h1>Login failed</h1><p>Please check the CLI output.</p>"))
})
srv := &http.Server{
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
}
go func() {
if errServe := srv.Serve(listener); errServe != nil && !strings.Contains(errServe.Error(), "Server closed") {
log.Warnf("xai callback server error: %v", errServe)
}
}()
return srv, port, resultCh, nil
}
+37
View File
@@ -0,0 +1,37 @@
package auth
import "testing"
func TestXAIAuthenticatorProviderAndRefreshLead(t *testing.T) {
authenticator := NewXAIAuthenticator()
if authenticator.Provider() != "xai" {
t.Fatalf("Provider() = %q, want xai", authenticator.Provider())
}
lead := authenticator.RefreshLead()
if lead == nil || *lead <= 0 {
t.Fatalf("RefreshLead() = %v, want positive duration", lead)
}
}
func TestParseXAIManualCallbackTokenAcceptsRawCode(t *testing.T) {
result, ok, err := parseXAIManualCallbackToken(" V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg ", "state-1")
if err != nil {
t.Fatalf("parseXAIManualCallbackToken() error = %v", err)
}
if !ok {
t.Fatal("parseXAIManualCallbackToken() ok = false, want true")
}
if result.Code != "V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg" {
t.Fatalf("Code = %q", result.Code)
}
if result.State != "state-1" {
t.Fatalf("State = %q, want state-1", result.State)
}
}
func TestParseXAIManualCallbackTokenRejectsCallbackURL(t *testing.T) {
_, _, err := parseXAIManualCallbackToken("http://127.0.0.1:56121/callback?state=state-1&code=token-1", "state-1")
if err == nil {
t.Fatal("parseXAIManualCallbackToken() error = nil, want error")
}
}
+6
View File
@@ -116,6 +116,7 @@ func newDefaultAuthManager() *sdkAuth.Manager {
sdkAuth.NewGeminiAuthenticator(),
sdkAuth.NewCodexAuthenticator(),
sdkAuth.NewClaudeAuthenticator(),
sdkAuth.NewXAIAuthenticator(),
)
}
@@ -433,6 +434,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
case "kimi":
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
case "xai":
s.coreManager.RegisterExecutor(executor.NewXAIExecutor(s.cfg))
default:
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
if providerKey == "" {
@@ -1156,6 +1159,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "kimi":
models = registry.GetKimiModels()
models = applyExcludedModels(models, excluded)
case "xai":
models = registry.GetXAIModels()
models = applyExcludedModels(models, excluded)
default:
// Handle OpenAI-compatibility providers by name using config
if s.cfg != nil {
@@ -0,0 +1,36 @@
package cliproxy
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
)
func TestEnsureExecutorsForAuth_XAIBindsIndependentExecutor(t *testing.T) {
service := &Service{
cfg: &config.Config{},
coreManager: coreauth.NewManager(nil, nil, nil),
}
auth := &coreauth.Auth{
ID: "xai-auth-1",
Provider: "xai",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
},
}
service.ensureExecutorsForAuth(auth)
resolved, ok := service.coreManager.Executor("xai")
if !ok || resolved == nil {
t.Fatal("expected xai executor after bind")
}
if _, isXAI := resolved.(*executor.XAIExecutor); !isXAI {
t.Fatalf("executor type = %T, want *executor.XAIExecutor", resolved)
}
if _, isCodex := resolved.(*executor.CodexAutoExecutor); isCodex {
t.Fatal("xai must not bind the codex auto executor")
}
}