feat(auth): add OAuth2 support for xAI with PKCE and token persistence

- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support.
- Added logic for token exchange, refresh, and persistent storage in JSON format.
- Created `xai` package with helpers for OAuth discovery, API token handling, and URL building.
- Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests.
- Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
Luis Pater
2026-05-17 01:02:35 +08:00
parent cd0cea393c
commit e4c957078c
24 changed files with 2050 additions and 4 deletions
+20
View File
@@ -0,0 +1,20 @@
package xai
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
)
// GeneratePKCECodes creates a verifier/challenge pair for the OAuth flow.
func GeneratePKCECodes() (*PKCECodes, error) {
bytes := make([]byte, 96)
if _, err := rand.Read(bytes); err != nil {
return nil, fmt.Errorf("xai pkce: generate verifier: %w", err)
}
verifier := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes)
hash := sha256.Sum256([]byte(verifier))
challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
return &PKCECodes{CodeVerifier: verifier, CodeChallenge: challenge}, nil
}
+104
View File
@@ -0,0 +1,104 @@
package xai
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
log "github.com/sirupsen/logrus"
)
// TokenStorage stores xAI OAuth credentials on disk.
type TokenStorage struct {
Type string `json:"type"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token,omitempty"`
TokenType string `json:"token_type,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
Expire string `json:"expired,omitempty"`
LastRefresh string `json:"last_refresh,omitempty"`
Email string `json:"email,omitempty"`
Subject string `json:"sub,omitempty"`
BaseURL string `json:"base_url,omitempty"`
RedirectURI string `json:"redirect_uri,omitempty"`
TokenEndpoint string `json:"token_endpoint,omitempty"`
AuthKind string `json:"auth_kind,omitempty"`
Metadata map[string]any `json:"-"`
}
// SetMetadata allows the token store to merge status fields before saving.
func (ts *TokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile writes xAI credentials to a JSON auth file.
func (ts *TokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "xai"
ts.AuthKind = "oauth"
if errMkdirAll := os.MkdirAll(filepath.Dir(authFilePath), 0o700); errMkdirAll != nil {
return fmt.Errorf("xai token storage: create directory: %w", errMkdirAll)
}
file, err := os.Create(authFilePath)
if err != nil {
return fmt.Errorf("xai token storage: create token file: %w", err)
}
defer func() {
if errClose := file.Close(); errClose != nil {
log.Errorf("xai token storage: close token file error: %v", errClose)
}
}()
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("xai token storage: merge metadata: %w", errMerge)
}
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err = encoder.Encode(data); err != nil {
return fmt.Errorf("xai token storage: write token file: %w", err)
}
return nil
}
// CredentialFileName returns the filename used for xAI credentials.
func CredentialFileName(email, subject string) string {
email = sanitizeFileSegment(email)
if email != "" {
return fmt.Sprintf("xai-%s.json", email)
}
subject = sanitizeFileSegment(subject)
if subject != "" {
return fmt.Sprintf("xai-%s.json", subject)
}
return fmt.Sprintf("xai-%d.json", time.Now().UnixMilli())
}
func sanitizeFileSegment(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
var b strings.Builder
for _, r := range value {
switch {
case r >= 'a' && r <= 'z':
b.WriteRune(r)
case r >= 'A' && r <= 'Z':
b.WriteRune(r)
case r >= '0' && r <= '9':
b.WriteRune(r)
case r == '@' || r == '.' || r == '_' || r == '-':
b.WriteRune(r)
default:
b.WriteRune('-')
}
}
return strings.Trim(b.String(), "-")
}
+72
View File
@@ -0,0 +1,72 @@
// Package xai provides OAuth2 authentication helpers for xAI Grok.
package xai
import "time"
const (
// DefaultAPIBaseURL is the default xAI Responses API base URL.
DefaultAPIBaseURL = "https://api.x.ai/v1"
// Issuer is xAI's OAuth issuer.
Issuer = "https://auth.x.ai"
// DiscoveryURL is the OIDC discovery endpoint used to resolve OAuth endpoints.
DiscoveryURL = Issuer + "/.well-known/openid-configuration"
// ClientID is the public xAI Grok CLI OAuth client ID.
ClientID = "b1a00492-073a-47ea-816f-4c329264a828"
// Scope is the OAuth scope set required for xAI API access.
Scope = "openid profile email offline_access grok-cli:access api:access"
// RedirectHost is the loopback host used by xAI OAuth.
RedirectHost = "127.0.0.1"
// CallbackPort is the preferred loopback callback port.
CallbackPort = 56121
// RedirectPath is the loopback callback path registered by the xAI client.
RedirectPath = "/callback"
)
var refreshLead = 5 * time.Minute
// RefreshLead returns the refresh lead time for xAI OAuth credentials.
func RefreshLead() time.Duration {
return refreshLead
}
// PKCECodes holds the PKCE verifier/challenge pair.
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
// AuthorizeURLParams contains the values used to build the xAI OAuth URL.
type AuthorizeURLParams struct {
AuthorizationEndpoint string
RedirectURI string
CodeChallenge string
State string
Nonce string
}
// Discovery contains OAuth endpoints resolved from xAI OIDC discovery.
type Discovery struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
}
// TokenData holds xAI OAuth token data.
type TokenData struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token,omitempty"`
TokenType string `json:"token_type,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
Expire string `json:"expired,omitempty"`
Email string `json:"email,omitempty"`
Subject string `json:"sub,omitempty"`
}
// AuthBundle aggregates token data and OAuth metadata for persistence.
type AuthBundle struct {
TokenData TokenData
LastRefresh string
BaseURL string
RedirectURI string
TokenEndpoint string
}
+304
View File
@@ -0,0 +1,304 @@
package xai
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
log "github.com/sirupsen/logrus"
)
// XAIAuth performs xAI OAuth discovery, token exchange, and refresh.
type XAIAuth struct {
httpClient *http.Client
}
// NewXAIAuth creates an xAI OAuth helper using config proxy settings.
func NewXAIAuth(cfg *config.Config) *XAIAuth {
return NewXAIAuthWithProxyURL(cfg, "")
}
// NewXAIAuthWithProxyURL creates an xAI OAuth helper with an explicit proxy URL.
func NewXAIAuthWithProxyURL(cfg *config.Config, proxyURL string) *XAIAuth {
effectiveProxyURL := strings.TrimSpace(proxyURL)
var sdkCfg config.SDKConfig
if cfg != nil {
sdkCfg = cfg.SDKConfig
if effectiveProxyURL == "" {
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
}
}
sdkCfg.ProxyURL = effectiveProxyURL
return &XAIAuth{httpClient: util.SetProxy(&sdkCfg, &http.Client{})}
}
// ValidateOAuthEndpoint validates an endpoint returned by xAI discovery.
func ValidateOAuthEndpoint(rawURL string, field string) (string, error) {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return "", fmt.Errorf("xai discovery %s is empty", field)
}
parsed, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("xai discovery %s is invalid: %w", field, err)
}
if parsed.Scheme != "https" {
return "", fmt.Errorf("xai discovery %s must use https: %q", field, rawURL)
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host != "x.ai" && !strings.HasSuffix(host, ".x.ai") {
return "", fmt.Errorf("xai discovery %s host %q is not on x.ai", field, host)
}
return rawURL, nil
}
// BuildAuthorizeURL builds the browser URL for xAI OAuth.
func BuildAuthorizeURL(params AuthorizeURLParams) (string, error) {
endpoint, err := ValidateOAuthEndpoint(params.AuthorizationEndpoint, "authorization_endpoint")
if err != nil {
return "", err
}
if strings.TrimSpace(params.RedirectURI) == "" {
return "", fmt.Errorf("xai authorize URL: redirect URI is required")
}
if strings.TrimSpace(params.CodeChallenge) == "" {
return "", fmt.Errorf("xai authorize URL: code challenge is required")
}
if strings.TrimSpace(params.State) == "" {
return "", fmt.Errorf("xai authorize URL: state is required")
}
if strings.TrimSpace(params.Nonce) == "" {
return "", fmt.Errorf("xai authorize URL: nonce is required")
}
values := url.Values{
"response_type": {"code"},
"client_id": {ClientID},
"redirect_uri": {strings.TrimSpace(params.RedirectURI)},
"scope": {Scope},
"code_challenge": {strings.TrimSpace(params.CodeChallenge)},
"code_challenge_method": {"S256"},
"state": {strings.TrimSpace(params.State)},
"nonce": {strings.TrimSpace(params.Nonce)},
"plan": {"generic"},
"referrer": {"cli-proxy-api"},
}
return endpoint + "?" + values.Encode(), nil
}
// Discover resolves xAI OAuth endpoints through OIDC discovery.
func (a *XAIAuth) Discover(ctx context.Context) (*Discovery, error) {
if ctx == nil {
ctx = context.Background()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DiscoveryURL, nil)
if err != nil {
return nil, fmt.Errorf("xai discovery: create request: %w", err)
}
req.Header.Set("Accept", "application/json")
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("xai discovery: request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("xai discovery: close response body error: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("xai discovery: read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("xai discovery failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var payload struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
}
if err = json.Unmarshal(body, &payload); err != nil {
return nil, fmt.Errorf("xai discovery: parse response: %w", err)
}
authorizationEndpoint, err := ValidateOAuthEndpoint(payload.AuthorizationEndpoint, "authorization_endpoint")
if err != nil {
return nil, err
}
tokenEndpoint, err := ValidateOAuthEndpoint(payload.TokenEndpoint, "token_endpoint")
if err != nil {
return nil, err
}
return &Discovery{AuthorizationEndpoint: authorizationEndpoint, TokenEndpoint: tokenEndpoint}, nil
}
// ExchangeCodeForTokens exchanges an authorization code for xAI OAuth tokens.
func (a *XAIAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes, tokenEndpoint string) (*AuthBundle, error) {
if pkceCodes == nil {
return nil, fmt.Errorf("xai token exchange: PKCE codes are required")
}
if strings.TrimSpace(code) == "" {
return nil, fmt.Errorf("xai token exchange: authorization code is required")
}
if strings.TrimSpace(redirectURI) == "" {
return nil, fmt.Errorf("xai token exchange: redirect URI is required")
}
if strings.TrimSpace(tokenEndpoint) == "" {
discovery, errDiscover := a.Discover(ctx)
if errDiscover != nil {
return nil, errDiscover
}
tokenEndpoint = discovery.TokenEndpoint
}
form := url.Values{
"grant_type": {"authorization_code"},
"code": {strings.TrimSpace(code)},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"client_id": {ClientID},
"code_verifier": {pkceCodes.CodeVerifier},
}
tokenData, err := a.postTokenForm(ctx, tokenEndpoint, form)
if err != nil {
return nil, err
}
return &AuthBundle{
TokenData: *tokenData,
LastRefresh: time.Now().UTC().Format(time.RFC3339),
BaseURL: DefaultAPIBaseURL,
RedirectURI: strings.TrimSpace(redirectURI),
TokenEndpoint: strings.TrimSpace(tokenEndpoint),
}, nil
}
// RefreshTokens refreshes an xAI access token.
func (a *XAIAuth) RefreshTokens(ctx context.Context, refreshToken, tokenEndpoint string) (*TokenData, error) {
if strings.TrimSpace(refreshToken) == "" {
return nil, fmt.Errorf("xai token refresh: refresh token is required")
}
if strings.TrimSpace(tokenEndpoint) == "" {
discovery, errDiscover := a.Discover(ctx)
if errDiscover != nil {
return nil, errDiscover
}
tokenEndpoint = discovery.TokenEndpoint
}
form := url.Values{
"grant_type": {"refresh_token"},
"client_id": {ClientID},
"refresh_token": {strings.TrimSpace(refreshToken)},
}
return a.postTokenForm(ctx, tokenEndpoint, form)
}
func (a *XAIAuth) postTokenForm(ctx context.Context, tokenEndpoint string, form url.Values) (*TokenData, error) {
if ctx == nil {
ctx = context.Background()
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSpace(tokenEndpoint), strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("xai token request: create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("xai token request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("xai token request: close response body error: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("xai token response: read body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("xai token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var payload struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
if err = json.Unmarshal(body, &payload); err != nil {
return nil, fmt.Errorf("xai token response: parse body: %w", err)
}
if strings.TrimSpace(payload.AccessToken) == "" {
return nil, fmt.Errorf("xai token response missing access_token")
}
email, subject := parseJWTIdentity(payload.IDToken)
return &TokenData{
AccessToken: strings.TrimSpace(payload.AccessToken),
RefreshToken: strings.TrimSpace(payload.RefreshToken),
IDToken: strings.TrimSpace(payload.IDToken),
TokenType: strings.TrimSpace(payload.TokenType),
ExpiresIn: payload.ExpiresIn,
Expire: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second).UTC().Format(time.RFC3339),
Email: email,
Subject: subject,
}, nil
}
// CreateTokenStorage converts an auth bundle into persistable storage.
func (a *XAIAuth) CreateTokenStorage(bundle *AuthBundle) *TokenStorage {
if bundle == nil {
return nil
}
return &TokenStorage{
Type: "xai",
AccessToken: bundle.TokenData.AccessToken,
RefreshToken: bundle.TokenData.RefreshToken,
IDToken: bundle.TokenData.IDToken,
TokenType: bundle.TokenData.TokenType,
ExpiresIn: bundle.TokenData.ExpiresIn,
Expire: bundle.TokenData.Expire,
LastRefresh: bundle.LastRefresh,
Email: strings.TrimSpace(bundle.TokenData.Email),
Subject: bundle.TokenData.Subject,
BaseURL: firstNonEmpty(bundle.BaseURL, DefaultAPIBaseURL),
RedirectURI: bundle.RedirectURI,
TokenEndpoint: bundle.TokenEndpoint,
AuthKind: "oauth",
}
}
func parseJWTIdentity(token string) (email string, subject string) {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return "", ""
}
payload := parts[1]
payload += strings.Repeat("=", (4-len(payload)%4)%4)
raw, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
return "", ""
}
var claims map[string]any
if err = json.Unmarshal(raw, &claims); err != nil {
return "", ""
}
if v, ok := claims["email"].(string); ok {
email = strings.TrimSpace(v)
}
if v, ok := claims["sub"].(string); ok {
subject = strings.TrimSpace(v)
}
return email, subject
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
+105
View File
@@ -0,0 +1,105 @@
package xai
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestBuildAuthorizeURLIncludesXAIRequiredParameters(t *testing.T) {
authURL, err := BuildAuthorizeURL(AuthorizeURLParams{
AuthorizationEndpoint: "https://auth.x.ai/oauth/authorize",
RedirectURI: "http://127.0.0.1:56121/callback",
CodeChallenge: "challenge",
State: "state-123",
Nonce: "nonce-123",
})
if err != nil {
t.Fatalf("BuildAuthorizeURL() error = %v", err)
}
parsed, errParse := url.Parse(authURL)
if errParse != nil {
t.Fatalf("parse authorize URL: %v", errParse)
}
if parsed.Scheme != "https" || parsed.Host != "auth.x.ai" || parsed.Path != "/oauth/authorize" {
t.Fatalf("authorize URL endpoint = %s://%s%s", parsed.Scheme, parsed.Host, parsed.Path)
}
query := parsed.Query()
want := map[string]string{
"response_type": "code",
"client_id": ClientID,
"redirect_uri": "http://127.0.0.1:56121/callback",
"scope": Scope,
"code_challenge": "challenge",
"code_challenge_method": "S256",
"state": "state-123",
"nonce": "nonce-123",
"plan": "generic",
"referrer": "cli-proxy-api",
}
for key, value := range want {
if got := query.Get(key); got != value {
t.Fatalf("%s = %q, want %q", key, got, value)
}
}
}
func TestValidateOAuthEndpointRejectsNonXAIOrigin(t *testing.T) {
if _, err := ValidateOAuthEndpoint("https://auth.x.ai/oauth/token", "token_endpoint"); err != nil {
t.Fatalf("ValidateOAuthEndpoint(xai) error = %v", err)
}
if _, err := ValidateOAuthEndpoint("http://auth.x.ai/oauth/token", "token_endpoint"); err == nil {
t.Fatal("expected non-HTTPS endpoint to be rejected")
}
if _, err := ValidateOAuthEndpoint("https://evil.example/oauth/token", "token_endpoint"); err == nil {
t.Fatal("expected non-xAI endpoint to be rejected")
}
}
func TestRefreshTokensPostsClientIDAndRefreshToken(t *testing.T) {
var gotForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("method = %s, want POST", r.Method)
}
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/x-www-form-urlencoded") {
t.Fatalf("Content-Type = %q, want form", got)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("ParseForm() error = %v", err)
}
gotForm = r.PostForm
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-access",
"refresh_token": "new-refresh",
"token_type": "Bearer",
"expires_in": 3600,
})
}))
defer server.Close()
auth := NewXAIAuth(nil)
tokenData, err := auth.RefreshTokens(context.Background(), "old-refresh", server.URL)
if err != nil {
t.Fatalf("RefreshTokens() error = %v", err)
}
if tokenData.AccessToken != "new-access" {
t.Fatalf("access token = %q, want new-access", tokenData.AccessToken)
}
if gotForm.Get("grant_type") != "refresh_token" {
t.Fatalf("grant_type = %q, want refresh_token", gotForm.Get("grant_type"))
}
if gotForm.Get("client_id") != ClientID {
t.Fatalf("client_id = %q, want %q", gotForm.Get("client_id"), ClientID)
}
if gotForm.Get("refresh_token") != "old-refresh" {
t.Fatalf("refresh_token = %q, want old-refresh", gotForm.Get("refresh_token"))
}
}