feat(auth): add OAuth2 support for xAI with PKCE and token persistence

- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support.
- Added logic for token exchange, refresh, and persistent storage in JSON format.
- Created `xai` package with helpers for OAuth discovery, API token handling, and URL building.
- Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests.
- Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
Luis Pater
2026-05-17 01:02:35 +08:00
parent cd0cea393c
commit e4c957078c
24 changed files with 2050 additions and 4 deletions
@@ -27,6 +27,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini"
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/kimi"
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
@@ -2132,6 +2133,185 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestXAIToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing xAI authentication...")
pkceCodes, errPKCE := xaiauth.GeneratePKCECodes()
if errPKCE != nil {
log.Errorf("Failed to generate xAI PKCE codes: %v", errPKCE)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
state, errState := misc.GenerateRandomState()
if errState != nil {
log.Errorf("Failed to generate state parameter: %v", errState)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
nonce, errNonce := misc.GenerateRandomState()
if errNonce != nil {
log.Errorf("Failed to generate nonce parameter: %v", errNonce)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate nonce parameter"})
return
}
authSvc := xaiauth.NewXAIAuth(h.cfg)
discovery, errDiscover := authSvc.Discover(ctx)
if errDiscover != nil {
log.Errorf("Failed to discover xAI OAuth endpoints: %v", errDiscover)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to discover oauth endpoints"})
return
}
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, xaiauth.CallbackPort, xaiauth.RedirectPath)
authURL, errAuthURL := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
RedirectURI: redirectURI,
CodeChallenge: pkceCodes.CodeChallenge,
State: state,
Nonce: nonce,
})
if errAuthURL != nil {
log.Errorf("Failed to generate xAI authorization URL: %v", errAuthURL)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
RegisterOAuthSession(state, "xai")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/xai/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute xai callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
var errStart error
if forwarder, errStart = startCallbackForwarder(xaiauth.CallbackPort, "xai", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start xai callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarderInstance(xaiauth.CallbackPort, forwarder)
}
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-xai-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
var authCode string
for {
if !IsOAuthSessionPending(state, "xai") {
return
}
if time.Now().After(deadline) {
log.Error("xai oauth flow timed out")
SetOAuthSessionError(state, "OAuth flow timed out")
return
}
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
var payload map[string]string
_ = json.Unmarshal(data, &payload)
_ = os.Remove(waitFile)
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
log.Errorf("xAI authentication failed: %s", errStr)
SetOAuthSessionError(state, "Authentication failed: "+errStr)
return
}
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
log.Errorf("xAI authentication failed: state mismatch")
SetOAuthSessionError(state, "Authentication failed: state mismatch")
return
}
authCode = strings.TrimSpace(payload["code"])
if authCode == "" {
log.Error("xAI authentication failed: code not found")
SetOAuthSessionError(state, "Authentication failed: code not found")
return
}
break
}
time.Sleep(500 * time.Millisecond)
}
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI, pkceCodes, discovery.TokenEndpoint)
if errExchange != nil {
log.Errorf("Failed to exchange xAI token: %v", errExchange)
SetOAuthSessionError(state, oauthSessionErrorWithCause("Failed to exchange authorization code for tokens", errExchange))
return
}
tokenStorage := authSvc.CreateTokenStorage(bundle)
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
log.Error("xAI token exchange returned empty access token")
SetOAuthSessionError(state, "Failed to exchange token")
return
}
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
label := strings.TrimSpace(tokenStorage.Email)
if label == "" {
label = "xAI"
}
metadata := map[string]any{
"type": "xai",
"access_token": tokenStorage.AccessToken,
"refresh_token": tokenStorage.RefreshToken,
"id_token": tokenStorage.IDToken,
"token_type": tokenStorage.TokenType,
"expires_in": tokenStorage.ExpiresIn,
"expired": tokenStorage.Expire,
"last_refresh": tokenStorage.LastRefresh,
"base_url": tokenStorage.BaseURL,
"redirect_uri": tokenStorage.RedirectURI,
"token_endpoint": tokenStorage.TokenEndpoint,
"auth_kind": "oauth",
}
if tokenStorage.Email != "" {
metadata["email"] = tokenStorage.Email
}
if tokenStorage.Subject != "" {
metadata["sub"] = tokenStorage.Subject
}
record := &coreauth.Auth{
ID: fileName,
Provider: "xai",
FileName: fileName,
Label: label,
Storage: tokenStorage,
Metadata: metadata,
Attributes: map[string]string{
"auth_kind": "oauth",
"base_url": tokenStorage.BaseURL,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save xAI token to file: %v", errSave)
SetOAuthSessionError(state, "Failed to save token to file")
return
}
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("xai")
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use xAI services through this CLI")
}()
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
@@ -242,6 +242,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "gemini", nil
case "antigravity", "anti-gravity":
return "antigravity", nil
case "xai", "x-ai", "x.ai", "grok":
return "xai", nil
default:
return "", errUnsupportedOAuthFlow
}
+15
View File
@@ -484,6 +484,20 @@ func (s *Server) setupRoutes() {
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
s.engine.GET("/xai/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" {
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "xai", state, code, errStr)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
}
@@ -685,6 +699,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
mgmt.GET("/xai-auth-url", s.mgmt.RequestXAIToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
}