feat(auth): add OAuth2 support for xAI with PKCE and token persistence
- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support. - Added logic for token exchange, refresh, and persistent storage in JSON format. - Created `xai` package with helpers for OAuth discovery, API token handling, and URL building. - Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests. - Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
@@ -13,6 +13,7 @@ func init() {
|
||||
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
||||
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
||||
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
|
||||
registerRefreshLead("xai", func() Authenticator { return NewXAIAuthenticator() })
|
||||
}
|
||||
|
||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||
|
||||
+282
@@ -0,0 +1,282 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// XAIAuthenticator implements the xAI Grok OAuth loopback flow.
|
||||
type XAIAuthenticator struct{}
|
||||
|
||||
// NewXAIAuthenticator constructs a new xAI authenticator.
|
||||
func NewXAIAuthenticator() Authenticator {
|
||||
return &XAIAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for xAI.
|
||||
func (XAIAuthenticator) Provider() string {
|
||||
return "xai"
|
||||
}
|
||||
|
||||
// RefreshLead instructs the manager to refresh before token expiry.
|
||||
func (XAIAuthenticator) RefreshLead() *time.Duration {
|
||||
lead := xaiauth.RefreshLead()
|
||||
return &lead
|
||||
}
|
||||
|
||||
// Login launches a local OAuth flow to obtain xAI tokens and persists them.
|
||||
func (a XAIAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
callbackPort := xaiauth.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
}
|
||||
|
||||
pkceCodes, err := xaiauth.GeneratePKCECodes()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai pkce generation failed: %w", err)
|
||||
}
|
||||
state, err := misc.GenerateRandomState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai state generation failed: %w", err)
|
||||
}
|
||||
nonce, err := misc.GenerateRandomState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai nonce generation failed: %w", err)
|
||||
}
|
||||
|
||||
authSvc := xaiauth.NewXAIAuth(cfg)
|
||||
discovery, err := authSvc.Discover(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
srv, port, callbackCh, errServer := startXAICallbackServer(callbackPort)
|
||||
if errServer != nil {
|
||||
return nil, fmt.Errorf("xai: failed to start callback server: %w", errServer)
|
||||
}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if errShutdown := srv.Shutdown(shutdownCtx); errShutdown != nil {
|
||||
log.Warnf("xai callback server shutdown error: %v", errShutdown)
|
||||
}
|
||||
}()
|
||||
|
||||
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, port, xaiauth.RedirectPath)
|
||||
authURL, err := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
|
||||
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
|
||||
RedirectURI: redirectURI,
|
||||
CodeChallenge: pkceCodes.CodeChallenge,
|
||||
State: state,
|
||||
Nonce: nonce,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !opts.NoBrowser {
|
||||
fmt.Println("Opening browser for xAI authentication")
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the URL manually")
|
||||
util.PrintSSHTunnelInstructions(port)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
} else if errOpen := browser.OpenURL(authURL); errOpen != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||
util.PrintSSHTunnelInstructions(port)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
} else {
|
||||
util.PrintSSHTunnelInstructions(port)
|
||||
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for xAI authentication callback...")
|
||||
|
||||
var result callbackResult
|
||||
timeoutTimer := time.NewTimer(5 * time.Minute)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
var manualInputCh <-chan string
|
||||
var manualInputErrCh <-chan error
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
default:
|
||||
}
|
||||
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the xAI callback Token (or press Enter to keep waiting): ")
|
||||
continue
|
||||
case input := <-manualInputCh:
|
||||
manualInputCh = nil
|
||||
manualInputErrCh = nil
|
||||
manualResult, ok, errParse := parseXAIManualCallbackToken(input, state)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
result = manualResult
|
||||
break waitForCallback
|
||||
case errManual := <-manualInputErrCh:
|
||||
return nil, errManual
|
||||
case <-timeoutTimer.C:
|
||||
return nil, fmt.Errorf("xai: authentication timed out")
|
||||
}
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, fmt.Errorf("xai: authentication failed: %s", result.Error)
|
||||
}
|
||||
if result.State != state {
|
||||
return nil, fmt.Errorf("xai: invalid state")
|
||||
}
|
||||
if result.Code == "" {
|
||||
return nil, fmt.Errorf("xai: missing authorization code")
|
||||
}
|
||||
|
||||
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, result.Code, redirectURI, pkceCodes, discovery.TokenEndpoint)
|
||||
if errExchange != nil {
|
||||
return nil, fmt.Errorf("xai: token exchange failed: %w", errExchange)
|
||||
}
|
||||
tokenStorage := authSvc.CreateTokenStorage(bundle)
|
||||
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("xai token storage missing access token")
|
||||
}
|
||||
|
||||
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
|
||||
label := strings.TrimSpace(tokenStorage.Email)
|
||||
if label == "" {
|
||||
label = "xAI"
|
||||
}
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "xai",
|
||||
"access_token": tokenStorage.AccessToken,
|
||||
"refresh_token": tokenStorage.RefreshToken,
|
||||
"id_token": tokenStorage.IDToken,
|
||||
"token_type": tokenStorage.TokenType,
|
||||
"expires_in": tokenStorage.ExpiresIn,
|
||||
"expired": tokenStorage.Expire,
|
||||
"last_refresh": tokenStorage.LastRefresh,
|
||||
"base_url": tokenStorage.BaseURL,
|
||||
"redirect_uri": tokenStorage.RedirectURI,
|
||||
"token_endpoint": tokenStorage.TokenEndpoint,
|
||||
"auth_kind": "oauth",
|
||||
}
|
||||
if tokenStorage.Email != "" {
|
||||
metadata["email"] = tokenStorage.Email
|
||||
}
|
||||
if tokenStorage.Subject != "" {
|
||||
metadata["sub"] = tokenStorage.Subject
|
||||
}
|
||||
|
||||
fmt.Println("xAI authentication successful")
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: label,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
"base_url": tokenStorage.BaseURL,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseXAIManualCallbackToken(input string, state string) (callbackResult, bool, error) {
|
||||
token := strings.TrimSpace(input)
|
||||
if token == "" {
|
||||
return callbackResult{}, false, nil
|
||||
}
|
||||
if strings.Contains(token, "://") || strings.Contains(token, "?") || strings.Contains(token, "code=") {
|
||||
return callbackResult{}, false, fmt.Errorf("xai: paste only the callback token")
|
||||
}
|
||||
return callbackResult{Code: token, State: state}, true, nil
|
||||
}
|
||||
|
||||
func startXAICallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
|
||||
if port <= 0 {
|
||||
port = xaiauth.CallbackPort
|
||||
}
|
||||
addr := fmt.Sprintf("%s:%d", xaiauth.RedirectHost, port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, 0, nil, err
|
||||
}
|
||||
port = listener.Addr().(*net.TCPAddr).Port
|
||||
resultCh := make(chan callbackResult, 1)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(xaiauth.RedirectPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
result := callbackResult{
|
||||
Code: strings.TrimSpace(q.Get("code")),
|
||||
Error: strings.TrimSpace(q.Get("error")),
|
||||
State: strings.TrimSpace(q.Get("state")),
|
||||
}
|
||||
resultCh <- result
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if result.Code != "" && result.Error == "" {
|
||||
_, _ = w.Write([]byte("<h1>Login successful</h1><p>You can close this window.</p>"))
|
||||
return
|
||||
}
|
||||
_, _ = w.Write([]byte("<h1>Login failed</h1><p>Please check the CLI output.</p>"))
|
||||
})
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
}
|
||||
go func() {
|
||||
if errServe := srv.Serve(listener); errServe != nil && !strings.Contains(errServe.Error(), "Server closed") {
|
||||
log.Warnf("xai callback server error: %v", errServe)
|
||||
}
|
||||
}()
|
||||
|
||||
return srv, port, resultCh, nil
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package auth
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestXAIAuthenticatorProviderAndRefreshLead(t *testing.T) {
|
||||
authenticator := NewXAIAuthenticator()
|
||||
if authenticator.Provider() != "xai" {
|
||||
t.Fatalf("Provider() = %q, want xai", authenticator.Provider())
|
||||
}
|
||||
lead := authenticator.RefreshLead()
|
||||
if lead == nil || *lead <= 0 {
|
||||
t.Fatalf("RefreshLead() = %v, want positive duration", lead)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseXAIManualCallbackTokenAcceptsRawCode(t *testing.T) {
|
||||
result, ok, err := parseXAIManualCallbackToken(" V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg ", "state-1")
|
||||
if err != nil {
|
||||
t.Fatalf("parseXAIManualCallbackToken() error = %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("parseXAIManualCallbackToken() ok = false, want true")
|
||||
}
|
||||
if result.Code != "V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg" {
|
||||
t.Fatalf("Code = %q", result.Code)
|
||||
}
|
||||
if result.State != "state-1" {
|
||||
t.Fatalf("State = %q, want state-1", result.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseXAIManualCallbackTokenRejectsCallbackURL(t *testing.T) {
|
||||
_, _, err := parseXAIManualCallbackToken("http://127.0.0.1:56121/callback?state=state-1&code=token-1", "state-1")
|
||||
if err == nil {
|
||||
t.Fatal("parseXAIManualCallbackToken() error = nil, want error")
|
||||
}
|
||||
}
|
||||
@@ -116,6 +116,7 @@ func newDefaultAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewGeminiAuthenticator(),
|
||||
sdkAuth.NewCodexAuthenticator(),
|
||||
sdkAuth.NewClaudeAuthenticator(),
|
||||
sdkAuth.NewXAIAuthenticator(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -433,6 +434,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
|
||||
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
|
||||
case "kimi":
|
||||
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
|
||||
case "xai":
|
||||
s.coreManager.RegisterExecutor(executor.NewXAIExecutor(s.cfg))
|
||||
default:
|
||||
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
if providerKey == "" {
|
||||
@@ -1156,6 +1159,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
case "kimi":
|
||||
models = registry.GetKimiModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "xai":
|
||||
models = registry.GetXAIModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
default:
|
||||
// Handle OpenAI-compatibility providers by name using config
|
||||
if s.cfg != nil {
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
|
||||
)
|
||||
|
||||
func TestEnsureExecutorsForAuth_XAIBindsIndependentExecutor(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: coreauth.NewManager(nil, nil, nil),
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "xai-auth-1",
|
||||
Provider: "xai",
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
}
|
||||
|
||||
service.ensureExecutorsForAuth(auth)
|
||||
resolved, ok := service.coreManager.Executor("xai")
|
||||
if !ok || resolved == nil {
|
||||
t.Fatal("expected xai executor after bind")
|
||||
}
|
||||
if _, isXAI := resolved.(*executor.XAIExecutor); !isXAI {
|
||||
t.Fatalf("executor type = %T, want *executor.XAIExecutor", resolved)
|
||||
}
|
||||
if _, isCodex := resolved.(*executor.CodexAutoExecutor); isCodex {
|
||||
t.Fatal("xai must not bind the codex auto executor")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user