e4c957078c
- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support. - Added logic for token exchange, refresh, and persistent storage in JSON format. - Created `xai` package with helpers for OAuth discovery, API token handling, and URL building. - Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests. - Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
283 lines
8.3 KiB
Go
283 lines
8.3 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/browser"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
|
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// XAIAuthenticator implements the xAI Grok OAuth loopback flow.
//
// The type is an empty struct: it carries no fields, and its methods in this
// file are declared on value receivers.
type XAIAuthenticator struct{}
|
|
|
|
// NewXAIAuthenticator constructs a new xAI authenticator.
|
|
func NewXAIAuthenticator() Authenticator {
|
|
return &XAIAuthenticator{}
|
|
}
|
|
|
|
// Provider returns the provider key for xAI.
|
|
func (XAIAuthenticator) Provider() string {
|
|
return "xai"
|
|
}
|
|
|
|
// RefreshLead instructs the manager to refresh before token expiry.
|
|
func (XAIAuthenticator) RefreshLead() *time.Duration {
|
|
lead := xaiauth.RefreshLead()
|
|
return &lead
|
|
}
|
|
|
|
// Login launches a local OAuth flow to obtain xAI tokens and persists them.
|
|
func (a XAIAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
|
if cfg == nil {
|
|
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
|
}
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
if opts == nil {
|
|
opts = &LoginOptions{}
|
|
}
|
|
|
|
callbackPort := xaiauth.CallbackPort
|
|
if opts.CallbackPort > 0 {
|
|
callbackPort = opts.CallbackPort
|
|
}
|
|
|
|
pkceCodes, err := xaiauth.GeneratePKCECodes()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xai pkce generation failed: %w", err)
|
|
}
|
|
state, err := misc.GenerateRandomState()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xai state generation failed: %w", err)
|
|
}
|
|
nonce, err := misc.GenerateRandomState()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xai nonce generation failed: %w", err)
|
|
}
|
|
|
|
authSvc := xaiauth.NewXAIAuth(cfg)
|
|
discovery, err := authSvc.Discover(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
srv, port, callbackCh, errServer := startXAICallbackServer(callbackPort)
|
|
if errServer != nil {
|
|
return nil, fmt.Errorf("xai: failed to start callback server: %w", errServer)
|
|
}
|
|
defer func() {
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
if errShutdown := srv.Shutdown(shutdownCtx); errShutdown != nil {
|
|
log.Warnf("xai callback server shutdown error: %v", errShutdown)
|
|
}
|
|
}()
|
|
|
|
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, port, xaiauth.RedirectPath)
|
|
authURL, err := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
|
|
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
|
|
RedirectURI: redirectURI,
|
|
CodeChallenge: pkceCodes.CodeChallenge,
|
|
State: state,
|
|
Nonce: nonce,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !opts.NoBrowser {
|
|
fmt.Println("Opening browser for xAI authentication")
|
|
if !browser.IsAvailable() {
|
|
log.Warn("No browser available; please open the URL manually")
|
|
util.PrintSSHTunnelInstructions(port)
|
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
|
} else if errOpen := browser.OpenURL(authURL); errOpen != nil {
|
|
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
|
util.PrintSSHTunnelInstructions(port)
|
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
|
}
|
|
} else {
|
|
util.PrintSSHTunnelInstructions(port)
|
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
|
}
|
|
|
|
fmt.Println("Waiting for xAI authentication callback...")
|
|
|
|
var result callbackResult
|
|
timeoutTimer := time.NewTimer(5 * time.Minute)
|
|
defer timeoutTimer.Stop()
|
|
|
|
var manualPromptTimer *time.Timer
|
|
var manualPromptC <-chan time.Time
|
|
if opts.Prompt != nil {
|
|
manualPromptTimer = time.NewTimer(15 * time.Second)
|
|
manualPromptC = manualPromptTimer.C
|
|
defer manualPromptTimer.Stop()
|
|
}
|
|
|
|
var manualInputCh <-chan string
|
|
var manualInputErrCh <-chan error
|
|
|
|
waitForCallback:
|
|
for {
|
|
select {
|
|
case result = <-callbackCh:
|
|
break waitForCallback
|
|
case <-manualPromptC:
|
|
manualPromptC = nil
|
|
if manualPromptTimer != nil {
|
|
manualPromptTimer.Stop()
|
|
}
|
|
select {
|
|
case result = <-callbackCh:
|
|
break waitForCallback
|
|
default:
|
|
}
|
|
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the xAI callback Token (or press Enter to keep waiting): ")
|
|
continue
|
|
case input := <-manualInputCh:
|
|
manualInputCh = nil
|
|
manualInputErrCh = nil
|
|
manualResult, ok, errParse := parseXAIManualCallbackToken(input, state)
|
|
if errParse != nil {
|
|
return nil, errParse
|
|
}
|
|
if !ok {
|
|
continue
|
|
}
|
|
result = manualResult
|
|
break waitForCallback
|
|
case errManual := <-manualInputErrCh:
|
|
return nil, errManual
|
|
case <-timeoutTimer.C:
|
|
return nil, fmt.Errorf("xai: authentication timed out")
|
|
}
|
|
}
|
|
|
|
if result.Error != "" {
|
|
return nil, fmt.Errorf("xai: authentication failed: %s", result.Error)
|
|
}
|
|
if result.State != state {
|
|
return nil, fmt.Errorf("xai: invalid state")
|
|
}
|
|
if result.Code == "" {
|
|
return nil, fmt.Errorf("xai: missing authorization code")
|
|
}
|
|
|
|
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, result.Code, redirectURI, pkceCodes, discovery.TokenEndpoint)
|
|
if errExchange != nil {
|
|
return nil, fmt.Errorf("xai: token exchange failed: %w", errExchange)
|
|
}
|
|
tokenStorage := authSvc.CreateTokenStorage(bundle)
|
|
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
|
|
return nil, fmt.Errorf("xai token storage missing access token")
|
|
}
|
|
|
|
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
|
|
label := strings.TrimSpace(tokenStorage.Email)
|
|
if label == "" {
|
|
label = "xAI"
|
|
}
|
|
|
|
metadata := map[string]any{
|
|
"type": "xai",
|
|
"access_token": tokenStorage.AccessToken,
|
|
"refresh_token": tokenStorage.RefreshToken,
|
|
"id_token": tokenStorage.IDToken,
|
|
"token_type": tokenStorage.TokenType,
|
|
"expires_in": tokenStorage.ExpiresIn,
|
|
"expired": tokenStorage.Expire,
|
|
"last_refresh": tokenStorage.LastRefresh,
|
|
"base_url": tokenStorage.BaseURL,
|
|
"redirect_uri": tokenStorage.RedirectURI,
|
|
"token_endpoint": tokenStorage.TokenEndpoint,
|
|
"auth_kind": "oauth",
|
|
}
|
|
if tokenStorage.Email != "" {
|
|
metadata["email"] = tokenStorage.Email
|
|
}
|
|
if tokenStorage.Subject != "" {
|
|
metadata["sub"] = tokenStorage.Subject
|
|
}
|
|
|
|
fmt.Println("xAI authentication successful")
|
|
|
|
return &coreauth.Auth{
|
|
ID: fileName,
|
|
Provider: a.Provider(),
|
|
FileName: fileName,
|
|
Label: label,
|
|
Storage: tokenStorage,
|
|
Metadata: metadata,
|
|
Attributes: map[string]string{
|
|
"auth_kind": "oauth",
|
|
"base_url": tokenStorage.BaseURL,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func parseXAIManualCallbackToken(input string, state string) (callbackResult, bool, error) {
|
|
token := strings.TrimSpace(input)
|
|
if token == "" {
|
|
return callbackResult{}, false, nil
|
|
}
|
|
if strings.Contains(token, "://") || strings.Contains(token, "?") || strings.Contains(token, "code=") {
|
|
return callbackResult{}, false, fmt.Errorf("xai: paste only the callback token")
|
|
}
|
|
return callbackResult{Code: token, State: state}, true, nil
|
|
}
|
|
|
|
func startXAICallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
|
|
if port <= 0 {
|
|
port = xaiauth.CallbackPort
|
|
}
|
|
addr := fmt.Sprintf("%s:%d", xaiauth.RedirectHost, port)
|
|
listener, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
return nil, 0, nil, err
|
|
}
|
|
port = listener.Addr().(*net.TCPAddr).Port
|
|
resultCh := make(chan callbackResult, 1)
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(xaiauth.RedirectPath, func(w http.ResponseWriter, r *http.Request) {
|
|
q := r.URL.Query()
|
|
result := callbackResult{
|
|
Code: strings.TrimSpace(q.Get("code")),
|
|
Error: strings.TrimSpace(q.Get("error")),
|
|
State: strings.TrimSpace(q.Get("state")),
|
|
}
|
|
resultCh <- result
|
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
if result.Code != "" && result.Error == "" {
|
|
_, _ = w.Write([]byte("<h1>Login successful</h1><p>You can close this window.</p>"))
|
|
return
|
|
}
|
|
_, _ = w.Write([]byte("<h1>Login failed</h1><p>Please check the CLI output.</p>"))
|
|
})
|
|
|
|
srv := &http.Server{
|
|
Handler: mux,
|
|
ReadHeaderTimeout: 5 * time.Second,
|
|
WriteTimeout: 5 * time.Second,
|
|
}
|
|
go func() {
|
|
if errServe := srv.Serve(listener); errServe != nil && !strings.Contains(errServe.Error(), "Server closed") {
|
|
log.Warnf("xai callback server error: %v", errServe)
|
|
}
|
|
}()
|
|
|
|
return srv, port, resultCh, nil
|
|
}
|