e4c957078c
- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support. - Added logic for token exchange, refresh, and persistent storage in JSON format. - Created `xai` package with helpers for OAuth discovery, API token handling, and URL building. - Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests. - Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
305 lines
10 KiB
Go
305 lines
10 KiB
Go
package xai
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// XAIAuth performs xAI OAuth discovery, token exchange, and refresh.
|
|
type XAIAuth struct {
|
|
httpClient *http.Client
|
|
}
|
|
|
|
// NewXAIAuth creates an xAI OAuth helper using config proxy settings.
|
|
func NewXAIAuth(cfg *config.Config) *XAIAuth {
|
|
return NewXAIAuthWithProxyURL(cfg, "")
|
|
}
|
|
|
|
// NewXAIAuthWithProxyURL creates an xAI OAuth helper with an explicit proxy URL.
|
|
func NewXAIAuthWithProxyURL(cfg *config.Config, proxyURL string) *XAIAuth {
|
|
effectiveProxyURL := strings.TrimSpace(proxyURL)
|
|
var sdkCfg config.SDKConfig
|
|
if cfg != nil {
|
|
sdkCfg = cfg.SDKConfig
|
|
if effectiveProxyURL == "" {
|
|
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
|
|
}
|
|
}
|
|
sdkCfg.ProxyURL = effectiveProxyURL
|
|
return &XAIAuth{httpClient: util.SetProxy(&sdkCfg, &http.Client{})}
|
|
}
|
|
|
|
// ValidateOAuthEndpoint validates an endpoint returned by xAI discovery.
|
|
func ValidateOAuthEndpoint(rawURL string, field string) (string, error) {
|
|
rawURL = strings.TrimSpace(rawURL)
|
|
if rawURL == "" {
|
|
return "", fmt.Errorf("xai discovery %s is empty", field)
|
|
}
|
|
parsed, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
return "", fmt.Errorf("xai discovery %s is invalid: %w", field, err)
|
|
}
|
|
if parsed.Scheme != "https" {
|
|
return "", fmt.Errorf("xai discovery %s must use https: %q", field, rawURL)
|
|
}
|
|
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
|
if host != "x.ai" && !strings.HasSuffix(host, ".x.ai") {
|
|
return "", fmt.Errorf("xai discovery %s host %q is not on x.ai", field, host)
|
|
}
|
|
return rawURL, nil
|
|
}
|
|
|
|
// BuildAuthorizeURL builds the browser URL for xAI OAuth.
|
|
func BuildAuthorizeURL(params AuthorizeURLParams) (string, error) {
|
|
endpoint, err := ValidateOAuthEndpoint(params.AuthorizationEndpoint, "authorization_endpoint")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if strings.TrimSpace(params.RedirectURI) == "" {
|
|
return "", fmt.Errorf("xai authorize URL: redirect URI is required")
|
|
}
|
|
if strings.TrimSpace(params.CodeChallenge) == "" {
|
|
return "", fmt.Errorf("xai authorize URL: code challenge is required")
|
|
}
|
|
if strings.TrimSpace(params.State) == "" {
|
|
return "", fmt.Errorf("xai authorize URL: state is required")
|
|
}
|
|
if strings.TrimSpace(params.Nonce) == "" {
|
|
return "", fmt.Errorf("xai authorize URL: nonce is required")
|
|
}
|
|
values := url.Values{
|
|
"response_type": {"code"},
|
|
"client_id": {ClientID},
|
|
"redirect_uri": {strings.TrimSpace(params.RedirectURI)},
|
|
"scope": {Scope},
|
|
"code_challenge": {strings.TrimSpace(params.CodeChallenge)},
|
|
"code_challenge_method": {"S256"},
|
|
"state": {strings.TrimSpace(params.State)},
|
|
"nonce": {strings.TrimSpace(params.Nonce)},
|
|
"plan": {"generic"},
|
|
"referrer": {"cli-proxy-api"},
|
|
}
|
|
return endpoint + "?" + values.Encode(), nil
|
|
}
|
|
|
|
// Discover resolves xAI OAuth endpoints through OIDC discovery.
|
|
func (a *XAIAuth) Discover(ctx context.Context) (*Discovery, error) {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DiscoveryURL, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xai discovery: create request: %w", err)
|
|
}
|
|
req.Header.Set("Accept", "application/json")
|
|
resp, err := a.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xai discovery: request failed: %w", err)
|
|
}
|
|
defer func() {
|
|
if errClose := resp.Body.Close(); errClose != nil {
|
|
log.Errorf("xai discovery: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xai discovery: read response: %w", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("xai discovery failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
var payload struct {
|
|
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
|
TokenEndpoint string `json:"token_endpoint"`
|
|
}
|
|
if err = json.Unmarshal(body, &payload); err != nil {
|
|
return nil, fmt.Errorf("xai discovery: parse response: %w", err)
|
|
}
|
|
authorizationEndpoint, err := ValidateOAuthEndpoint(payload.AuthorizationEndpoint, "authorization_endpoint")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tokenEndpoint, err := ValidateOAuthEndpoint(payload.TokenEndpoint, "token_endpoint")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Discovery{AuthorizationEndpoint: authorizationEndpoint, TokenEndpoint: tokenEndpoint}, nil
|
|
}
|
|
|
|
// ExchangeCodeForTokens exchanges an authorization code for xAI OAuth tokens.
|
|
func (a *XAIAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes, tokenEndpoint string) (*AuthBundle, error) {
|
|
if pkceCodes == nil {
|
|
return nil, fmt.Errorf("xai token exchange: PKCE codes are required")
|
|
}
|
|
if strings.TrimSpace(code) == "" {
|
|
return nil, fmt.Errorf("xai token exchange: authorization code is required")
|
|
}
|
|
if strings.TrimSpace(redirectURI) == "" {
|
|
return nil, fmt.Errorf("xai token exchange: redirect URI is required")
|
|
}
|
|
if strings.TrimSpace(tokenEndpoint) == "" {
|
|
discovery, errDiscover := a.Discover(ctx)
|
|
if errDiscover != nil {
|
|
return nil, errDiscover
|
|
}
|
|
tokenEndpoint = discovery.TokenEndpoint
|
|
}
|
|
form := url.Values{
|
|
"grant_type": {"authorization_code"},
|
|
"code": {strings.TrimSpace(code)},
|
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
|
"client_id": {ClientID},
|
|
"code_verifier": {pkceCodes.CodeVerifier},
|
|
}
|
|
tokenData, err := a.postTokenForm(ctx, tokenEndpoint, form)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &AuthBundle{
|
|
TokenData: *tokenData,
|
|
LastRefresh: time.Now().UTC().Format(time.RFC3339),
|
|
BaseURL: DefaultAPIBaseURL,
|
|
RedirectURI: strings.TrimSpace(redirectURI),
|
|
TokenEndpoint: strings.TrimSpace(tokenEndpoint),
|
|
}, nil
|
|
}
|
|
|
|
// RefreshTokens refreshes an xAI access token.
|
|
func (a *XAIAuth) RefreshTokens(ctx context.Context, refreshToken, tokenEndpoint string) (*TokenData, error) {
|
|
if strings.TrimSpace(refreshToken) == "" {
|
|
return nil, fmt.Errorf("xai token refresh: refresh token is required")
|
|
}
|
|
if strings.TrimSpace(tokenEndpoint) == "" {
|
|
discovery, errDiscover := a.Discover(ctx)
|
|
if errDiscover != nil {
|
|
return nil, errDiscover
|
|
}
|
|
tokenEndpoint = discovery.TokenEndpoint
|
|
}
|
|
form := url.Values{
|
|
"grant_type": {"refresh_token"},
|
|
"client_id": {ClientID},
|
|
"refresh_token": {strings.TrimSpace(refreshToken)},
|
|
}
|
|
return a.postTokenForm(ctx, tokenEndpoint, form)
|
|
}
|
|
|
|
func (a *XAIAuth) postTokenForm(ctx context.Context, tokenEndpoint string, form url.Values) (*TokenData, error) {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSpace(tokenEndpoint), strings.NewReader(form.Encode()))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xai token request: create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.Header.Set("Accept", "application/json")
|
|
resp, err := a.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xai token request failed: %w", err)
|
|
}
|
|
defer func() {
|
|
if errClose := resp.Body.Close(); errClose != nil {
|
|
log.Errorf("xai token request: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("xai token response: read body: %w", err)
|
|
}
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("xai token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
|
}
|
|
var payload struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
IDToken string `json:"id_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
}
|
|
if err = json.Unmarshal(body, &payload); err != nil {
|
|
return nil, fmt.Errorf("xai token response: parse body: %w", err)
|
|
}
|
|
if strings.TrimSpace(payload.AccessToken) == "" {
|
|
return nil, fmt.Errorf("xai token response missing access_token")
|
|
}
|
|
email, subject := parseJWTIdentity(payload.IDToken)
|
|
return &TokenData{
|
|
AccessToken: strings.TrimSpace(payload.AccessToken),
|
|
RefreshToken: strings.TrimSpace(payload.RefreshToken),
|
|
IDToken: strings.TrimSpace(payload.IDToken),
|
|
TokenType: strings.TrimSpace(payload.TokenType),
|
|
ExpiresIn: payload.ExpiresIn,
|
|
Expire: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second).UTC().Format(time.RFC3339),
|
|
Email: email,
|
|
Subject: subject,
|
|
}, nil
|
|
}
|
|
|
|
// CreateTokenStorage converts an auth bundle into persistable storage.
|
|
func (a *XAIAuth) CreateTokenStorage(bundle *AuthBundle) *TokenStorage {
|
|
if bundle == nil {
|
|
return nil
|
|
}
|
|
return &TokenStorage{
|
|
Type: "xai",
|
|
AccessToken: bundle.TokenData.AccessToken,
|
|
RefreshToken: bundle.TokenData.RefreshToken,
|
|
IDToken: bundle.TokenData.IDToken,
|
|
TokenType: bundle.TokenData.TokenType,
|
|
ExpiresIn: bundle.TokenData.ExpiresIn,
|
|
Expire: bundle.TokenData.Expire,
|
|
LastRefresh: bundle.LastRefresh,
|
|
Email: strings.TrimSpace(bundle.TokenData.Email),
|
|
Subject: bundle.TokenData.Subject,
|
|
BaseURL: firstNonEmpty(bundle.BaseURL, DefaultAPIBaseURL),
|
|
RedirectURI: bundle.RedirectURI,
|
|
TokenEndpoint: bundle.TokenEndpoint,
|
|
AuthKind: "oauth",
|
|
}
|
|
}
|
|
|
|
func parseJWTIdentity(token string) (email string, subject string) {
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) < 2 {
|
|
return "", ""
|
|
}
|
|
payload := parts[1]
|
|
payload += strings.Repeat("=", (4-len(payload)%4)%4)
|
|
raw, err := base64.URLEncoding.DecodeString(payload)
|
|
if err != nil {
|
|
return "", ""
|
|
}
|
|
var claims map[string]any
|
|
if err = json.Unmarshal(raw, &claims); err != nil {
|
|
return "", ""
|
|
}
|
|
if v, ok := claims["email"].(string); ok {
|
|
email = strings.TrimSpace(v)
|
|
}
|
|
if v, ok := claims["sub"].(string); ok {
|
|
subject = strings.TrimSpace(v)
|
|
}
|
|
return email, subject
|
|
}
|
|
|
|
func firstNonEmpty(values ...string) string {
|
|
for _, value := range values {
|
|
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
|
return trimmed
|
|
}
|
|
}
|
|
return ""
|
|
}
|