feat(auth): add OAuth2 support for xAI with PKCE and token persistence
- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support. - Added logic for token exchange, refresh, and persistent storage in JSON format. - Created `xai` package with helpers for OAuth discovery, API token handling, and URL building. - Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests. - Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/kimi"
|
||||
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
@@ -2132,6 +2133,185 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestXAIToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing xAI authentication...")
|
||||
|
||||
pkceCodes, errPKCE := xaiauth.GeneratePKCECodes()
|
||||
if errPKCE != nil {
|
||||
log.Errorf("Failed to generate xAI PKCE codes: %v", errPKCE)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
|
||||
return
|
||||
}
|
||||
|
||||
state, errState := misc.GenerateRandomState()
|
||||
if errState != nil {
|
||||
log.Errorf("Failed to generate state parameter: %v", errState)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
nonce, errNonce := misc.GenerateRandomState()
|
||||
if errNonce != nil {
|
||||
log.Errorf("Failed to generate nonce parameter: %v", errNonce)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate nonce parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
authSvc := xaiauth.NewXAIAuth(h.cfg)
|
||||
discovery, errDiscover := authSvc.Discover(ctx)
|
||||
if errDiscover != nil {
|
||||
log.Errorf("Failed to discover xAI OAuth endpoints: %v", errDiscover)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to discover oauth endpoints"})
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, xaiauth.CallbackPort, xaiauth.RedirectPath)
|
||||
authURL, errAuthURL := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
|
||||
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
|
||||
RedirectURI: redirectURI,
|
||||
CodeChallenge: pkceCodes.CodeChallenge,
|
||||
State: state,
|
||||
Nonce: nonce,
|
||||
})
|
||||
if errAuthURL != nil {
|
||||
log.Errorf("Failed to generate xAI authorization URL: %v", errAuthURL)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "xai")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/xai/callback")
|
||||
if errTarget != nil {
|
||||
log.WithError(errTarget).Error("failed to compute xai callback target")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(xaiauth.CallbackPort, "xai", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start xai callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarderInstance(xaiauth.CallbackPort, forwarder)
|
||||
}
|
||||
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-xai-%s.oauth", state))
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var authCode string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "xai") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("xai oauth flow timed out")
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
||||
var payload map[string]string
|
||||
_ = json.Unmarshal(data, &payload)
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||
log.Errorf("xAI authentication failed: %s", errStr)
|
||||
SetOAuthSessionError(state, "Authentication failed: "+errStr)
|
||||
return
|
||||
}
|
||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
||||
log.Errorf("xAI authentication failed: state mismatch")
|
||||
SetOAuthSessionError(state, "Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
authCode = strings.TrimSpace(payload["code"])
|
||||
if authCode == "" {
|
||||
log.Error("xAI authentication failed: code not found")
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI, pkceCodes, discovery.TokenEndpoint)
|
||||
if errExchange != nil {
|
||||
log.Errorf("Failed to exchange xAI token: %v", errExchange)
|
||||
SetOAuthSessionError(state, oauthSessionErrorWithCause("Failed to exchange authorization code for tokens", errExchange))
|
||||
return
|
||||
}
|
||||
|
||||
tokenStorage := authSvc.CreateTokenStorage(bundle)
|
||||
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
|
||||
log.Error("xAI token exchange returned empty access token")
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
|
||||
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
|
||||
label := strings.TrimSpace(tokenStorage.Email)
|
||||
if label == "" {
|
||||
label = "xAI"
|
||||
}
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "xai",
|
||||
"access_token": tokenStorage.AccessToken,
|
||||
"refresh_token": tokenStorage.RefreshToken,
|
||||
"id_token": tokenStorage.IDToken,
|
||||
"token_type": tokenStorage.TokenType,
|
||||
"expires_in": tokenStorage.ExpiresIn,
|
||||
"expired": tokenStorage.Expire,
|
||||
"last_refresh": tokenStorage.LastRefresh,
|
||||
"base_url": tokenStorage.BaseURL,
|
||||
"redirect_uri": tokenStorage.RedirectURI,
|
||||
"token_endpoint": tokenStorage.TokenEndpoint,
|
||||
"auth_kind": "oauth",
|
||||
}
|
||||
if tokenStorage.Email != "" {
|
||||
metadata["email"] = tokenStorage.Email
|
||||
}
|
||||
if tokenStorage.Subject != "" {
|
||||
metadata["sub"] = tokenStorage.Subject
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "xai",
|
||||
FileName: fileName,
|
||||
Label: label,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
"base_url": tokenStorage.BaseURL,
|
||||
},
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save xAI token to file: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("xai")
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use xAI services through this CLI")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
@@ -242,6 +242,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
return "gemini", nil
|
||||
case "antigravity", "anti-gravity":
|
||||
return "antigravity", nil
|
||||
case "xai", "x-ai", "x.ai", "grok":
|
||||
return "xai", nil
|
||||
default:
|
||||
return "", errUnsupportedOAuthFlow
|
||||
}
|
||||
|
||||
@@ -484,6 +484,20 @@ func (s *Server) setupRoutes() {
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
s.engine.GET("/xai/callback", func(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "xai", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
||||
}
|
||||
|
||||
@@ -685,6 +699,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||
mgmt.GET("/xai-auth-url", s.mgmt.RequestXAIToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user