feat(auth): add OAuth2 support for xAI with PKCE and token persistence
- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support. - Added logic for token exchange, refresh, and persistent storage in JSON format. - Created `xai` package with helpers for OAuth discovery, API token handling, and URL building. - Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests. - Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GeneratePKCECodes creates a verifier/challenge pair for the OAuth flow.
|
||||
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||
bytes := make([]byte, 96)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return nil, fmt.Errorf("xai pkce: generate verifier: %w", err)
|
||||
}
|
||||
verifier := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes)
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
|
||||
return &PKCECodes{CodeVerifier: verifier, CodeChallenge: challenge}, nil
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TokenStorage stores xAI OAuth credentials on disk.
|
||||
type TokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
ExpiresIn int `json:"expires_in,omitempty"`
|
||||
Expire string `json:"expired,omitempty"`
|
||||
LastRefresh string `json:"last_refresh,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri,omitempty"`
|
||||
TokenEndpoint string `json:"token_endpoint,omitempty"`
|
||||
AuthKind string `json:"auth_kind,omitempty"`
|
||||
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows the token store to merge status fields before saving.
|
||||
func (ts *TokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile writes xAI credentials to a JSON auth file.
|
||||
func (ts *TokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "xai"
|
||||
ts.AuthKind = "oauth"
|
||||
if errMkdirAll := os.MkdirAll(filepath.Dir(authFilePath), 0o700); errMkdirAll != nil {
|
||||
return fmt.Errorf("xai token storage: create directory: %w", errMkdirAll)
|
||||
}
|
||||
file, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("xai token storage: create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := file.Close(); errClose != nil {
|
||||
log.Errorf("xai token storage: close token file error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("xai token storage: merge metadata: %w", errMerge)
|
||||
}
|
||||
encoder := json.NewEncoder(file)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(data); err != nil {
|
||||
return fmt.Errorf("xai token storage: write token file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialFileName returns the filename used for xAI credentials.
|
||||
func CredentialFileName(email, subject string) string {
|
||||
email = sanitizeFileSegment(email)
|
||||
if email != "" {
|
||||
return fmt.Sprintf("xai-%s.json", email)
|
||||
}
|
||||
subject = sanitizeFileSegment(subject)
|
||||
if subject != "" {
|
||||
return fmt.Sprintf("xai-%s.json", subject)
|
||||
}
|
||||
return fmt.Sprintf("xai-%d.json", time.Now().UnixMilli())
|
||||
}
|
||||
|
||||
func sanitizeFileSegment(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, r := range value {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
b.WriteRune(r)
|
||||
case r >= 'A' && r <= 'Z':
|
||||
b.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
b.WriteRune(r)
|
||||
case r == '@' || r == '.' || r == '_' || r == '-':
|
||||
b.WriteRune(r)
|
||||
default:
|
||||
b.WriteRune('-')
|
||||
}
|
||||
}
|
||||
return strings.Trim(b.String(), "-")
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
// Package xai provides OAuth2 authentication helpers for xAI Grok.
|
||||
package xai
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
// DefaultAPIBaseURL is the default xAI Responses API base URL.
|
||||
DefaultAPIBaseURL = "https://api.x.ai/v1"
|
||||
// Issuer is xAI's OAuth issuer.
|
||||
Issuer = "https://auth.x.ai"
|
||||
// DiscoveryURL is the OIDC discovery endpoint used to resolve OAuth endpoints.
|
||||
DiscoveryURL = Issuer + "/.well-known/openid-configuration"
|
||||
// ClientID is the public xAI Grok CLI OAuth client ID.
|
||||
ClientID = "b1a00492-073a-47ea-816f-4c329264a828"
|
||||
// Scope is the OAuth scope set required for xAI API access.
|
||||
Scope = "openid profile email offline_access grok-cli:access api:access"
|
||||
// RedirectHost is the loopback host used by xAI OAuth.
|
||||
RedirectHost = "127.0.0.1"
|
||||
// CallbackPort is the preferred loopback callback port.
|
||||
CallbackPort = 56121
|
||||
// RedirectPath is the loopback callback path registered by the xAI client.
|
||||
RedirectPath = "/callback"
|
||||
)
|
||||
|
||||
var refreshLead = 5 * time.Minute
|
||||
|
||||
// RefreshLead returns the refresh lead time for xAI OAuth credentials.
|
||||
func RefreshLead() time.Duration {
|
||||
return refreshLead
|
||||
}
|
||||
|
||||
// PKCECodes holds the PKCE verifier/challenge pair.
|
||||
type PKCECodes struct {
|
||||
CodeVerifier string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
// AuthorizeURLParams contains the values used to build the xAI OAuth URL.
|
||||
type AuthorizeURLParams struct {
|
||||
AuthorizationEndpoint string
|
||||
RedirectURI string
|
||||
CodeChallenge string
|
||||
State string
|
||||
Nonce string
|
||||
}
|
||||
|
||||
// Discovery contains OAuth endpoints resolved from xAI OIDC discovery.
|
||||
type Discovery struct {
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
}
|
||||
|
||||
// TokenData holds xAI OAuth token data.
|
||||
type TokenData struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
ExpiresIn int `json:"expires_in,omitempty"`
|
||||
Expire string `json:"expired,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
}
|
||||
|
||||
// AuthBundle aggregates token data and OAuth metadata for persistence.
|
||||
type AuthBundle struct {
|
||||
TokenData TokenData
|
||||
LastRefresh string
|
||||
BaseURL string
|
||||
RedirectURI string
|
||||
TokenEndpoint string
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// XAIAuth performs xAI OAuth discovery, token exchange, and refresh.
|
||||
type XAIAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewXAIAuth creates an xAI OAuth helper using config proxy settings.
|
||||
func NewXAIAuth(cfg *config.Config) *XAIAuth {
|
||||
return NewXAIAuthWithProxyURL(cfg, "")
|
||||
}
|
||||
|
||||
// NewXAIAuthWithProxyURL creates an xAI OAuth helper with an explicit proxy URL.
|
||||
func NewXAIAuthWithProxyURL(cfg *config.Config, proxyURL string) *XAIAuth {
|
||||
effectiveProxyURL := strings.TrimSpace(proxyURL)
|
||||
var sdkCfg config.SDKConfig
|
||||
if cfg != nil {
|
||||
sdkCfg = cfg.SDKConfig
|
||||
if effectiveProxyURL == "" {
|
||||
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||
}
|
||||
}
|
||||
sdkCfg.ProxyURL = effectiveProxyURL
|
||||
return &XAIAuth{httpClient: util.SetProxy(&sdkCfg, &http.Client{})}
|
||||
}
|
||||
|
||||
// ValidateOAuthEndpoint validates an endpoint returned by xAI discovery.
|
||||
func ValidateOAuthEndpoint(rawURL string, field string) (string, error) {
|
||||
rawURL = strings.TrimSpace(rawURL)
|
||||
if rawURL == "" {
|
||||
return "", fmt.Errorf("xai discovery %s is empty", field)
|
||||
}
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("xai discovery %s is invalid: %w", field, err)
|
||||
}
|
||||
if parsed.Scheme != "https" {
|
||||
return "", fmt.Errorf("xai discovery %s must use https: %q", field, rawURL)
|
||||
}
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host != "x.ai" && !strings.HasSuffix(host, ".x.ai") {
|
||||
return "", fmt.Errorf("xai discovery %s host %q is not on x.ai", field, host)
|
||||
}
|
||||
return rawURL, nil
|
||||
}
|
||||
|
||||
// BuildAuthorizeURL builds the browser URL for xAI OAuth.
|
||||
func BuildAuthorizeURL(params AuthorizeURLParams) (string, error) {
|
||||
endpoint, err := ValidateOAuthEndpoint(params.AuthorizationEndpoint, "authorization_endpoint")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(params.RedirectURI) == "" {
|
||||
return "", fmt.Errorf("xai authorize URL: redirect URI is required")
|
||||
}
|
||||
if strings.TrimSpace(params.CodeChallenge) == "" {
|
||||
return "", fmt.Errorf("xai authorize URL: code challenge is required")
|
||||
}
|
||||
if strings.TrimSpace(params.State) == "" {
|
||||
return "", fmt.Errorf("xai authorize URL: state is required")
|
||||
}
|
||||
if strings.TrimSpace(params.Nonce) == "" {
|
||||
return "", fmt.Errorf("xai authorize URL: nonce is required")
|
||||
}
|
||||
values := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {ClientID},
|
||||
"redirect_uri": {strings.TrimSpace(params.RedirectURI)},
|
||||
"scope": {Scope},
|
||||
"code_challenge": {strings.TrimSpace(params.CodeChallenge)},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {strings.TrimSpace(params.State)},
|
||||
"nonce": {strings.TrimSpace(params.Nonce)},
|
||||
"plan": {"generic"},
|
||||
"referrer": {"cli-proxy-api"},
|
||||
}
|
||||
return endpoint + "?" + values.Encode(), nil
|
||||
}
|
||||
|
||||
// Discover resolves xAI OAuth endpoints through OIDC discovery.
|
||||
func (a *XAIAuth) Discover(ctx context.Context) (*Discovery, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DiscoveryURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai discovery: create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai discovery: request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai discovery: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai discovery: read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("xai discovery failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var payload struct {
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, fmt.Errorf("xai discovery: parse response: %w", err)
|
||||
}
|
||||
authorizationEndpoint, err := ValidateOAuthEndpoint(payload.AuthorizationEndpoint, "authorization_endpoint")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenEndpoint, err := ValidateOAuthEndpoint(payload.TokenEndpoint, "token_endpoint")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Discovery{AuthorizationEndpoint: authorizationEndpoint, TokenEndpoint: tokenEndpoint}, nil
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokens exchanges an authorization code for xAI OAuth tokens.
|
||||
func (a *XAIAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes, tokenEndpoint string) (*AuthBundle, error) {
|
||||
if pkceCodes == nil {
|
||||
return nil, fmt.Errorf("xai token exchange: PKCE codes are required")
|
||||
}
|
||||
if strings.TrimSpace(code) == "" {
|
||||
return nil, fmt.Errorf("xai token exchange: authorization code is required")
|
||||
}
|
||||
if strings.TrimSpace(redirectURI) == "" {
|
||||
return nil, fmt.Errorf("xai token exchange: redirect URI is required")
|
||||
}
|
||||
if strings.TrimSpace(tokenEndpoint) == "" {
|
||||
discovery, errDiscover := a.Discover(ctx)
|
||||
if errDiscover != nil {
|
||||
return nil, errDiscover
|
||||
}
|
||||
tokenEndpoint = discovery.TokenEndpoint
|
||||
}
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {strings.TrimSpace(code)},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"client_id": {ClientID},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
tokenData, err := a.postTokenForm(ctx, tokenEndpoint, form)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthBundle{
|
||||
TokenData: *tokenData,
|
||||
LastRefresh: time.Now().UTC().Format(time.RFC3339),
|
||||
BaseURL: DefaultAPIBaseURL,
|
||||
RedirectURI: strings.TrimSpace(redirectURI),
|
||||
TokenEndpoint: strings.TrimSpace(tokenEndpoint),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshTokens refreshes an xAI access token.
|
||||
func (a *XAIAuth) RefreshTokens(ctx context.Context, refreshToken, tokenEndpoint string) (*TokenData, error) {
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
return nil, fmt.Errorf("xai token refresh: refresh token is required")
|
||||
}
|
||||
if strings.TrimSpace(tokenEndpoint) == "" {
|
||||
discovery, errDiscover := a.Discover(ctx)
|
||||
if errDiscover != nil {
|
||||
return nil, errDiscover
|
||||
}
|
||||
tokenEndpoint = discovery.TokenEndpoint
|
||||
}
|
||||
form := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"client_id": {ClientID},
|
||||
"refresh_token": {strings.TrimSpace(refreshToken)},
|
||||
}
|
||||
return a.postTokenForm(ctx, tokenEndpoint, form)
|
||||
}
|
||||
|
||||
func (a *XAIAuth) postTokenForm(ctx context.Context, tokenEndpoint string, form url.Values) (*TokenData, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSpace(tokenEndpoint), strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai token request: create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai token request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai token request: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai token response: read body: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("xai token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var payload struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, fmt.Errorf("xai token response: parse body: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(payload.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("xai token response missing access_token")
|
||||
}
|
||||
email, subject := parseJWTIdentity(payload.IDToken)
|
||||
return &TokenData{
|
||||
AccessToken: strings.TrimSpace(payload.AccessToken),
|
||||
RefreshToken: strings.TrimSpace(payload.RefreshToken),
|
||||
IDToken: strings.TrimSpace(payload.IDToken),
|
||||
TokenType: strings.TrimSpace(payload.TokenType),
|
||||
ExpiresIn: payload.ExpiresIn,
|
||||
Expire: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second).UTC().Format(time.RFC3339),
|
||||
Email: email,
|
||||
Subject: subject,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage converts an auth bundle into persistable storage.
|
||||
func (a *XAIAuth) CreateTokenStorage(bundle *AuthBundle) *TokenStorage {
|
||||
if bundle == nil {
|
||||
return nil
|
||||
}
|
||||
return &TokenStorage{
|
||||
Type: "xai",
|
||||
AccessToken: bundle.TokenData.AccessToken,
|
||||
RefreshToken: bundle.TokenData.RefreshToken,
|
||||
IDToken: bundle.TokenData.IDToken,
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
ExpiresIn: bundle.TokenData.ExpiresIn,
|
||||
Expire: bundle.TokenData.Expire,
|
||||
LastRefresh: bundle.LastRefresh,
|
||||
Email: strings.TrimSpace(bundle.TokenData.Email),
|
||||
Subject: bundle.TokenData.Subject,
|
||||
BaseURL: firstNonEmpty(bundle.BaseURL, DefaultAPIBaseURL),
|
||||
RedirectURI: bundle.RedirectURI,
|
||||
TokenEndpoint: bundle.TokenEndpoint,
|
||||
AuthKind: "oauth",
|
||||
}
|
||||
}
|
||||
|
||||
func parseJWTIdentity(token string) (email string, subject string) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) < 2 {
|
||||
return "", ""
|
||||
}
|
||||
payload := parts[1]
|
||||
payload += strings.Repeat("=", (4-len(payload)%4)%4)
|
||||
raw, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
var claims map[string]any
|
||||
if err = json.Unmarshal(raw, &claims); err != nil {
|
||||
return "", ""
|
||||
}
|
||||
if v, ok := claims["email"].(string); ok {
|
||||
email = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := claims["sub"].(string); ok {
|
||||
subject = strings.TrimSpace(v)
|
||||
}
|
||||
return email, subject
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildAuthorizeURLIncludesXAIRequiredParameters(t *testing.T) {
|
||||
authURL, err := BuildAuthorizeURL(AuthorizeURLParams{
|
||||
AuthorizationEndpoint: "https://auth.x.ai/oauth/authorize",
|
||||
RedirectURI: "http://127.0.0.1:56121/callback",
|
||||
CodeChallenge: "challenge",
|
||||
State: "state-123",
|
||||
Nonce: "nonce-123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizeURL() error = %v", err)
|
||||
}
|
||||
|
||||
parsed, errParse := url.Parse(authURL)
|
||||
if errParse != nil {
|
||||
t.Fatalf("parse authorize URL: %v", errParse)
|
||||
}
|
||||
if parsed.Scheme != "https" || parsed.Host != "auth.x.ai" || parsed.Path != "/oauth/authorize" {
|
||||
t.Fatalf("authorize URL endpoint = %s://%s%s", parsed.Scheme, parsed.Host, parsed.Path)
|
||||
}
|
||||
|
||||
query := parsed.Query()
|
||||
want := map[string]string{
|
||||
"response_type": "code",
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": "http://127.0.0.1:56121/callback",
|
||||
"scope": Scope,
|
||||
"code_challenge": "challenge",
|
||||
"code_challenge_method": "S256",
|
||||
"state": "state-123",
|
||||
"nonce": "nonce-123",
|
||||
"plan": "generic",
|
||||
"referrer": "cli-proxy-api",
|
||||
}
|
||||
for key, value := range want {
|
||||
if got := query.Get(key); got != value {
|
||||
t.Fatalf("%s = %q, want %q", key, got, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOAuthEndpointRejectsNonXAIOrigin(t *testing.T) {
|
||||
if _, err := ValidateOAuthEndpoint("https://auth.x.ai/oauth/token", "token_endpoint"); err != nil {
|
||||
t.Fatalf("ValidateOAuthEndpoint(xai) error = %v", err)
|
||||
}
|
||||
if _, err := ValidateOAuthEndpoint("http://auth.x.ai/oauth/token", "token_endpoint"); err == nil {
|
||||
t.Fatal("expected non-HTTPS endpoint to be rejected")
|
||||
}
|
||||
if _, err := ValidateOAuthEndpoint("https://evil.example/oauth/token", "token_endpoint"); err == nil {
|
||||
t.Fatal("expected non-xAI endpoint to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokensPostsClientIDAndRefreshToken(t *testing.T) {
|
||||
var gotForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("method = %s, want POST", r.Method)
|
||||
}
|
||||
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/x-www-form-urlencoded") {
|
||||
t.Fatalf("Content-Type = %q, want form", got)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("ParseForm() error = %v", err)
|
||||
}
|
||||
gotForm = r.PostForm
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "new-access",
|
||||
"refresh_token": "new-refresh",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
auth := NewXAIAuth(nil)
|
||||
tokenData, err := auth.RefreshTokens(context.Background(), "old-refresh", server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshTokens() error = %v", err)
|
||||
}
|
||||
if tokenData.AccessToken != "new-access" {
|
||||
t.Fatalf("access token = %q, want new-access", tokenData.AccessToken)
|
||||
}
|
||||
if gotForm.Get("grant_type") != "refresh_token" {
|
||||
t.Fatalf("grant_type = %q, want refresh_token", gotForm.Get("grant_type"))
|
||||
}
|
||||
if gotForm.Get("client_id") != ClientID {
|
||||
t.Fatalf("client_id = %q, want %q", gotForm.Get("client_id"), ClientID)
|
||||
}
|
||||
if gotForm.Get("refresh_token") != "old-refresh" {
|
||||
t.Fatalf("refresh_token = %q, want old-refresh", gotForm.Get("refresh_token"))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user