feat(auth): add OAuth2 support for xAI with PKCE and token persistence

- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support.
- Added logic for token exchange, refresh, and persistent storage in JSON format.
- Created `xai` package with helpers for OAuth discovery, API token handling, and URL building.
- Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests.
- Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
Luis Pater
2026-05-17 01:02:35 +08:00
parent cd0cea393c
commit e4c957078c
24 changed files with 2050 additions and 4 deletions
+4
View File
@@ -182,6 +182,7 @@ func main() {
var oauthCallbackPort int
var antigravityLogin bool
var kimiLogin bool
var xaiLogin bool
var projectID string
var vertexImport string
var vertexImportPrefix string
@@ -203,6 +204,7 @@ func main() {
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
flag.BoolVar(&xaiLogin, "xai-login", false, "Login to xAI using OAuth")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
@@ -656,6 +658,8 @@ func main() {
cmd.DoClaudeLogin(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
} else if xaiLogin {
cmd.DoXAILogin(cfg, options)
} else {
// In cloud deploy mode without config file, just wait for shutdown signals
if isCloudDeploy && !configFileExists {
+6 -1
View File
@@ -345,7 +345,7 @@ nonstream-keepalive-interval: 0
# Global OAuth model name aliases (per channel)
# These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
@@ -375,6 +375,9 @@ nonstream-keepalive-interval: 0
# kimi:
# - name: "kimi-k2.5"
# alias: "k2.5"
# xai:
# - name: "grok-4.3"
# alias: "grok-latest"
# OAuth provider excluded models
# oauth-excluded-models:
@@ -395,6 +398,8 @@ nonstream-keepalive-interval: 0
# - "gpt-5-codex-mini"
# kimi:
# - "kimi-k2-thinking"
# xai:
# - "grok-3-mini"
# Optional payload configuration
# payload:
@@ -27,6 +27,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini"
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/kimi"
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
@@ -2132,6 +2133,185 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestXAIToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing xAI authentication...")
pkceCodes, errPKCE := xaiauth.GeneratePKCECodes()
if errPKCE != nil {
log.Errorf("Failed to generate xAI PKCE codes: %v", errPKCE)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
state, errState := misc.GenerateRandomState()
if errState != nil {
log.Errorf("Failed to generate state parameter: %v", errState)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
nonce, errNonce := misc.GenerateRandomState()
if errNonce != nil {
log.Errorf("Failed to generate nonce parameter: %v", errNonce)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate nonce parameter"})
return
}
authSvc := xaiauth.NewXAIAuth(h.cfg)
discovery, errDiscover := authSvc.Discover(ctx)
if errDiscover != nil {
log.Errorf("Failed to discover xAI OAuth endpoints: %v", errDiscover)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to discover oauth endpoints"})
return
}
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, xaiauth.CallbackPort, xaiauth.RedirectPath)
authURL, errAuthURL := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
RedirectURI: redirectURI,
CodeChallenge: pkceCodes.CodeChallenge,
State: state,
Nonce: nonce,
})
if errAuthURL != nil {
log.Errorf("Failed to generate xAI authorization URL: %v", errAuthURL)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
RegisterOAuthSession(state, "xai")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/xai/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute xai callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
var errStart error
if forwarder, errStart = startCallbackForwarder(xaiauth.CallbackPort, "xai", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start xai callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarderInstance(xaiauth.CallbackPort, forwarder)
}
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-xai-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
var authCode string
for {
if !IsOAuthSessionPending(state, "xai") {
return
}
if time.Now().After(deadline) {
log.Error("xai oauth flow timed out")
SetOAuthSessionError(state, "OAuth flow timed out")
return
}
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
var payload map[string]string
_ = json.Unmarshal(data, &payload)
_ = os.Remove(waitFile)
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
log.Errorf("xAI authentication failed: %s", errStr)
SetOAuthSessionError(state, "Authentication failed: "+errStr)
return
}
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
log.Errorf("xAI authentication failed: state mismatch")
SetOAuthSessionError(state, "Authentication failed: state mismatch")
return
}
authCode = strings.TrimSpace(payload["code"])
if authCode == "" {
log.Error("xAI authentication failed: code not found")
SetOAuthSessionError(state, "Authentication failed: code not found")
return
}
break
}
time.Sleep(500 * time.Millisecond)
}
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI, pkceCodes, discovery.TokenEndpoint)
if errExchange != nil {
log.Errorf("Failed to exchange xAI token: %v", errExchange)
SetOAuthSessionError(state, oauthSessionErrorWithCause("Failed to exchange authorization code for tokens", errExchange))
return
}
tokenStorage := authSvc.CreateTokenStorage(bundle)
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
log.Error("xAI token exchange returned empty access token")
SetOAuthSessionError(state, "Failed to exchange token")
return
}
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
label := strings.TrimSpace(tokenStorage.Email)
if label == "" {
label = "xAI"
}
metadata := map[string]any{
"type": "xai",
"access_token": tokenStorage.AccessToken,
"refresh_token": tokenStorage.RefreshToken,
"id_token": tokenStorage.IDToken,
"token_type": tokenStorage.TokenType,
"expires_in": tokenStorage.ExpiresIn,
"expired": tokenStorage.Expire,
"last_refresh": tokenStorage.LastRefresh,
"base_url": tokenStorage.BaseURL,
"redirect_uri": tokenStorage.RedirectURI,
"token_endpoint": tokenStorage.TokenEndpoint,
"auth_kind": "oauth",
}
if tokenStorage.Email != "" {
metadata["email"] = tokenStorage.Email
}
if tokenStorage.Subject != "" {
metadata["sub"] = tokenStorage.Subject
}
record := &coreauth.Auth{
ID: fileName,
Provider: "xai",
FileName: fileName,
Label: label,
Storage: tokenStorage,
Metadata: metadata,
Attributes: map[string]string{
"auth_kind": "oauth",
"base_url": tokenStorage.BaseURL,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save xAI token to file: %v", errSave)
SetOAuthSessionError(state, "Failed to save token to file")
return
}
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("xai")
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use xAI services through this CLI")
}()
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
@@ -242,6 +242,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "gemini", nil
case "antigravity", "anti-gravity":
return "antigravity", nil
case "xai", "x-ai", "x.ai", "grok":
return "xai", nil
default:
return "", errUnsupportedOAuthFlow
}
+15
View File
@@ -484,6 +484,20 @@ func (s *Server) setupRoutes() {
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
s.engine.GET("/xai/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" {
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "xai", state, code, errStr)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
}
@@ -685,6 +699,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
mgmt.GET("/xai-auth-url", s.mgmt.RequestXAIToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
}
+20
View File
@@ -0,0 +1,20 @@
package xai
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
)
// GeneratePKCECodes creates a verifier/challenge pair for the OAuth flow.
func GeneratePKCECodes() (*PKCECodes, error) {
bytes := make([]byte, 96)
if _, err := rand.Read(bytes); err != nil {
return nil, fmt.Errorf("xai pkce: generate verifier: %w", err)
}
verifier := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes)
hash := sha256.Sum256([]byte(verifier))
challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
return &PKCECodes{CodeVerifier: verifier, CodeChallenge: challenge}, nil
}
+104
View File
@@ -0,0 +1,104 @@
package xai
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
log "github.com/sirupsen/logrus"
)
// TokenStorage stores xAI OAuth credentials on disk.
type TokenStorage struct {
Type string `json:"type"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token,omitempty"`
TokenType string `json:"token_type,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
Expire string `json:"expired,omitempty"`
LastRefresh string `json:"last_refresh,omitempty"`
Email string `json:"email,omitempty"`
Subject string `json:"sub,omitempty"`
BaseURL string `json:"base_url,omitempty"`
RedirectURI string `json:"redirect_uri,omitempty"`
TokenEndpoint string `json:"token_endpoint,omitempty"`
AuthKind string `json:"auth_kind,omitempty"`
Metadata map[string]any `json:"-"`
}
// SetMetadata allows the token store to merge status fields before saving.
func (ts *TokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile writes xAI credentials to a JSON auth file.
func (ts *TokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "xai"
ts.AuthKind = "oauth"
if errMkdirAll := os.MkdirAll(filepath.Dir(authFilePath), 0o700); errMkdirAll != nil {
return fmt.Errorf("xai token storage: create directory: %w", errMkdirAll)
}
file, err := os.Create(authFilePath)
if err != nil {
return fmt.Errorf("xai token storage: create token file: %w", err)
}
defer func() {
if errClose := file.Close(); errClose != nil {
log.Errorf("xai token storage: close token file error: %v", errClose)
}
}()
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("xai token storage: merge metadata: %w", errMerge)
}
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err = encoder.Encode(data); err != nil {
return fmt.Errorf("xai token storage: write token file: %w", err)
}
return nil
}
// CredentialFileName returns the filename used for xAI credentials.
func CredentialFileName(email, subject string) string {
email = sanitizeFileSegment(email)
if email != "" {
return fmt.Sprintf("xai-%s.json", email)
}
subject = sanitizeFileSegment(subject)
if subject != "" {
return fmt.Sprintf("xai-%s.json", subject)
}
return fmt.Sprintf("xai-%d.json", time.Now().UnixMilli())
}
func sanitizeFileSegment(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
var b strings.Builder
for _, r := range value {
switch {
case r >= 'a' && r <= 'z':
b.WriteRune(r)
case r >= 'A' && r <= 'Z':
b.WriteRune(r)
case r >= '0' && r <= '9':
b.WriteRune(r)
case r == '@' || r == '.' || r == '_' || r == '-':
b.WriteRune(r)
default:
b.WriteRune('-')
}
}
return strings.Trim(b.String(), "-")
}
+72
View File
@@ -0,0 +1,72 @@
// Package xai provides OAuth2 authentication helpers for xAI Grok.
package xai
import "time"
const (
// DefaultAPIBaseURL is the default xAI Responses API base URL.
DefaultAPIBaseURL = "https://api.x.ai/v1"
// Issuer is xAI's OAuth issuer.
Issuer = "https://auth.x.ai"
// DiscoveryURL is the OIDC discovery endpoint used to resolve OAuth endpoints.
DiscoveryURL = Issuer + "/.well-known/openid-configuration"
// ClientID is the public xAI Grok CLI OAuth client ID.
ClientID = "b1a00492-073a-47ea-816f-4c329264a828"
// Scope is the OAuth scope set required for xAI API access.
Scope = "openid profile email offline_access grok-cli:access api:access"
// RedirectHost is the loopback host used by xAI OAuth.
RedirectHost = "127.0.0.1"
// CallbackPort is the preferred loopback callback port.
CallbackPort = 56121
// RedirectPath is the loopback callback path registered by the xAI client.
RedirectPath = "/callback"
)
var refreshLead = 5 * time.Minute
// RefreshLead returns the refresh lead time for xAI OAuth credentials.
func RefreshLead() time.Duration {
return refreshLead
}
// PKCECodes holds the PKCE verifier/challenge pair.
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
// AuthorizeURLParams contains the values used to build the xAI OAuth URL.
type AuthorizeURLParams struct {
AuthorizationEndpoint string
RedirectURI string
CodeChallenge string
State string
Nonce string
}
// Discovery contains OAuth endpoints resolved from xAI OIDC discovery.
type Discovery struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
}
// TokenData holds xAI OAuth token data.
type TokenData struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token,omitempty"`
TokenType string `json:"token_type,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
Expire string `json:"expired,omitempty"`
Email string `json:"email,omitempty"`
Subject string `json:"sub,omitempty"`
}
// AuthBundle aggregates token data and OAuth metadata for persistence.
type AuthBundle struct {
TokenData TokenData
LastRefresh string
BaseURL string
RedirectURI string
TokenEndpoint string
}
+304
View File
@@ -0,0 +1,304 @@
package xai
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
log "github.com/sirupsen/logrus"
)
// XAIAuth performs xAI OAuth discovery, token exchange, and refresh.
type XAIAuth struct {
httpClient *http.Client
}
// NewXAIAuth creates an xAI OAuth helper using config proxy settings.
func NewXAIAuth(cfg *config.Config) *XAIAuth {
return NewXAIAuthWithProxyURL(cfg, "")
}
// NewXAIAuthWithProxyURL creates an xAI OAuth helper with an explicit proxy URL.
func NewXAIAuthWithProxyURL(cfg *config.Config, proxyURL string) *XAIAuth {
effectiveProxyURL := strings.TrimSpace(proxyURL)
var sdkCfg config.SDKConfig
if cfg != nil {
sdkCfg = cfg.SDKConfig
if effectiveProxyURL == "" {
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
}
}
sdkCfg.ProxyURL = effectiveProxyURL
return &XAIAuth{httpClient: util.SetProxy(&sdkCfg, &http.Client{})}
}
// ValidateOAuthEndpoint validates an endpoint returned by xAI discovery.
func ValidateOAuthEndpoint(rawURL string, field string) (string, error) {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return "", fmt.Errorf("xai discovery %s is empty", field)
}
parsed, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("xai discovery %s is invalid: %w", field, err)
}
if parsed.Scheme != "https" {
return "", fmt.Errorf("xai discovery %s must use https: %q", field, rawURL)
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host != "x.ai" && !strings.HasSuffix(host, ".x.ai") {
return "", fmt.Errorf("xai discovery %s host %q is not on x.ai", field, host)
}
return rawURL, nil
}
// BuildAuthorizeURL builds the browser URL for xAI OAuth.
func BuildAuthorizeURL(params AuthorizeURLParams) (string, error) {
endpoint, err := ValidateOAuthEndpoint(params.AuthorizationEndpoint, "authorization_endpoint")
if err != nil {
return "", err
}
if strings.TrimSpace(params.RedirectURI) == "" {
return "", fmt.Errorf("xai authorize URL: redirect URI is required")
}
if strings.TrimSpace(params.CodeChallenge) == "" {
return "", fmt.Errorf("xai authorize URL: code challenge is required")
}
if strings.TrimSpace(params.State) == "" {
return "", fmt.Errorf("xai authorize URL: state is required")
}
if strings.TrimSpace(params.Nonce) == "" {
return "", fmt.Errorf("xai authorize URL: nonce is required")
}
values := url.Values{
"response_type": {"code"},
"client_id": {ClientID},
"redirect_uri": {strings.TrimSpace(params.RedirectURI)},
"scope": {Scope},
"code_challenge": {strings.TrimSpace(params.CodeChallenge)},
"code_challenge_method": {"S256"},
"state": {strings.TrimSpace(params.State)},
"nonce": {strings.TrimSpace(params.Nonce)},
"plan": {"generic"},
"referrer": {"cli-proxy-api"},
}
return endpoint + "?" + values.Encode(), nil
}
// Discover resolves xAI OAuth endpoints through OIDC discovery.
func (a *XAIAuth) Discover(ctx context.Context) (*Discovery, error) {
if ctx == nil {
ctx = context.Background()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DiscoveryURL, nil)
if err != nil {
return nil, fmt.Errorf("xai discovery: create request: %w", err)
}
req.Header.Set("Accept", "application/json")
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("xai discovery: request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("xai discovery: close response body error: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("xai discovery: read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("xai discovery failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var payload struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
}
if err = json.Unmarshal(body, &payload); err != nil {
return nil, fmt.Errorf("xai discovery: parse response: %w", err)
}
authorizationEndpoint, err := ValidateOAuthEndpoint(payload.AuthorizationEndpoint, "authorization_endpoint")
if err != nil {
return nil, err
}
tokenEndpoint, err := ValidateOAuthEndpoint(payload.TokenEndpoint, "token_endpoint")
if err != nil {
return nil, err
}
return &Discovery{AuthorizationEndpoint: authorizationEndpoint, TokenEndpoint: tokenEndpoint}, nil
}
// ExchangeCodeForTokens exchanges an authorization code for xAI OAuth tokens.
func (a *XAIAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes, tokenEndpoint string) (*AuthBundle, error) {
if pkceCodes == nil {
return nil, fmt.Errorf("xai token exchange: PKCE codes are required")
}
if strings.TrimSpace(code) == "" {
return nil, fmt.Errorf("xai token exchange: authorization code is required")
}
if strings.TrimSpace(redirectURI) == "" {
return nil, fmt.Errorf("xai token exchange: redirect URI is required")
}
if strings.TrimSpace(tokenEndpoint) == "" {
discovery, errDiscover := a.Discover(ctx)
if errDiscover != nil {
return nil, errDiscover
}
tokenEndpoint = discovery.TokenEndpoint
}
form := url.Values{
"grant_type": {"authorization_code"},
"code": {strings.TrimSpace(code)},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"client_id": {ClientID},
"code_verifier": {pkceCodes.CodeVerifier},
}
tokenData, err := a.postTokenForm(ctx, tokenEndpoint, form)
if err != nil {
return nil, err
}
return &AuthBundle{
TokenData: *tokenData,
LastRefresh: time.Now().UTC().Format(time.RFC3339),
BaseURL: DefaultAPIBaseURL,
RedirectURI: strings.TrimSpace(redirectURI),
TokenEndpoint: strings.TrimSpace(tokenEndpoint),
}, nil
}
// RefreshTokens refreshes an xAI access token.
func (a *XAIAuth) RefreshTokens(ctx context.Context, refreshToken, tokenEndpoint string) (*TokenData, error) {
if strings.TrimSpace(refreshToken) == "" {
return nil, fmt.Errorf("xai token refresh: refresh token is required")
}
if strings.TrimSpace(tokenEndpoint) == "" {
discovery, errDiscover := a.Discover(ctx)
if errDiscover != nil {
return nil, errDiscover
}
tokenEndpoint = discovery.TokenEndpoint
}
form := url.Values{
"grant_type": {"refresh_token"},
"client_id": {ClientID},
"refresh_token": {strings.TrimSpace(refreshToken)},
}
return a.postTokenForm(ctx, tokenEndpoint, form)
}
func (a *XAIAuth) postTokenForm(ctx context.Context, tokenEndpoint string, form url.Values) (*TokenData, error) {
if ctx == nil {
ctx = context.Background()
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSpace(tokenEndpoint), strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("xai token request: create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("xai token request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("xai token request: close response body error: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("xai token response: read body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("xai token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var payload struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
if err = json.Unmarshal(body, &payload); err != nil {
return nil, fmt.Errorf("xai token response: parse body: %w", err)
}
if strings.TrimSpace(payload.AccessToken) == "" {
return nil, fmt.Errorf("xai token response missing access_token")
}
email, subject := parseJWTIdentity(payload.IDToken)
return &TokenData{
AccessToken: strings.TrimSpace(payload.AccessToken),
RefreshToken: strings.TrimSpace(payload.RefreshToken),
IDToken: strings.TrimSpace(payload.IDToken),
TokenType: strings.TrimSpace(payload.TokenType),
ExpiresIn: payload.ExpiresIn,
Expire: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second).UTC().Format(time.RFC3339),
Email: email,
Subject: subject,
}, nil
}
// CreateTokenStorage converts an auth bundle into persistable storage.
func (a *XAIAuth) CreateTokenStorage(bundle *AuthBundle) *TokenStorage {
if bundle == nil {
return nil
}
return &TokenStorage{
Type: "xai",
AccessToken: bundle.TokenData.AccessToken,
RefreshToken: bundle.TokenData.RefreshToken,
IDToken: bundle.TokenData.IDToken,
TokenType: bundle.TokenData.TokenType,
ExpiresIn: bundle.TokenData.ExpiresIn,
Expire: bundle.TokenData.Expire,
LastRefresh: bundle.LastRefresh,
Email: strings.TrimSpace(bundle.TokenData.Email),
Subject: bundle.TokenData.Subject,
BaseURL: firstNonEmpty(bundle.BaseURL, DefaultAPIBaseURL),
RedirectURI: bundle.RedirectURI,
TokenEndpoint: bundle.TokenEndpoint,
AuthKind: "oauth",
}
}
func parseJWTIdentity(token string) (email string, subject string) {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return "", ""
}
payload := parts[1]
payload += strings.Repeat("=", (4-len(payload)%4)%4)
raw, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
return "", ""
}
var claims map[string]any
if err = json.Unmarshal(raw, &claims); err != nil {
return "", ""
}
if v, ok := claims["email"].(string); ok {
email = strings.TrimSpace(v)
}
if v, ok := claims["sub"].(string); ok {
subject = strings.TrimSpace(v)
}
return email, subject
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
+105
View File
@@ -0,0 +1,105 @@
package xai
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestBuildAuthorizeURLIncludesXAIRequiredParameters(t *testing.T) {
authURL, err := BuildAuthorizeURL(AuthorizeURLParams{
AuthorizationEndpoint: "https://auth.x.ai/oauth/authorize",
RedirectURI: "http://127.0.0.1:56121/callback",
CodeChallenge: "challenge",
State: "state-123",
Nonce: "nonce-123",
})
if err != nil {
t.Fatalf("BuildAuthorizeURL() error = %v", err)
}
parsed, errParse := url.Parse(authURL)
if errParse != nil {
t.Fatalf("parse authorize URL: %v", errParse)
}
if parsed.Scheme != "https" || parsed.Host != "auth.x.ai" || parsed.Path != "/oauth/authorize" {
t.Fatalf("authorize URL endpoint = %s://%s%s", parsed.Scheme, parsed.Host, parsed.Path)
}
query := parsed.Query()
want := map[string]string{
"response_type": "code",
"client_id": ClientID,
"redirect_uri": "http://127.0.0.1:56121/callback",
"scope": Scope,
"code_challenge": "challenge",
"code_challenge_method": "S256",
"state": "state-123",
"nonce": "nonce-123",
"plan": "generic",
"referrer": "cli-proxy-api",
}
for key, value := range want {
if got := query.Get(key); got != value {
t.Fatalf("%s = %q, want %q", key, got, value)
}
}
}
func TestValidateOAuthEndpointRejectsNonXAIOrigin(t *testing.T) {
if _, err := ValidateOAuthEndpoint("https://auth.x.ai/oauth/token", "token_endpoint"); err != nil {
t.Fatalf("ValidateOAuthEndpoint(xai) error = %v", err)
}
if _, err := ValidateOAuthEndpoint("http://auth.x.ai/oauth/token", "token_endpoint"); err == nil {
t.Fatal("expected non-HTTPS endpoint to be rejected")
}
if _, err := ValidateOAuthEndpoint("https://evil.example/oauth/token", "token_endpoint"); err == nil {
t.Fatal("expected non-xAI endpoint to be rejected")
}
}
func TestRefreshTokensPostsClientIDAndRefreshToken(t *testing.T) {
var gotForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("method = %s, want POST", r.Method)
}
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/x-www-form-urlencoded") {
t.Fatalf("Content-Type = %q, want form", got)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("ParseForm() error = %v", err)
}
gotForm = r.PostForm
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-access",
"refresh_token": "new-refresh",
"token_type": "Bearer",
"expires_in": 3600,
})
}))
defer server.Close()
auth := NewXAIAuth(nil)
tokenData, err := auth.RefreshTokens(context.Background(), "old-refresh", server.URL)
if err != nil {
t.Fatalf("RefreshTokens() error = %v", err)
}
if tokenData.AccessToken != "new-access" {
t.Fatalf("access token = %q, want new-access", tokenData.AccessToken)
}
if gotForm.Get("grant_type") != "refresh_token" {
t.Fatalf("grant_type = %q, want refresh_token", gotForm.Get("grant_type"))
}
if gotForm.Get("client_id") != ClientID {
t.Fatalf("client_id = %q, want %q", gotForm.Get("client_id"), ClientID)
}
if gotForm.Get("refresh_token") != "old-refresh" {
t.Fatalf("refresh_token = %q, want old-refresh", gotForm.Get("refresh_token"))
}
}
+2 -1
View File
@@ -6,7 +6,7 @@ import (
// newAuthManager creates a new authentication manager instance with all supported
// authenticators and a file-based token store. It initializes authenticators for
// Gemini, Codex, Claude, Antigravity, and Kimi providers.
// Gemini, Codex, Claude, Antigravity, Kimi, and xAI providers.
//
// Returns:
// - *sdkAuth.Manager: A configured authentication manager instance
@@ -18,6 +18,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewClaudeAuthenticator(),
sdkAuth.NewAntigravityAuthenticator(),
sdkAuth.NewKimiAuthenticator(),
sdkAuth.NewXAIAuthenticator(),
)
return manager
}
+44
View File
@@ -0,0 +1,44 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoXAILogin triggers the OAuth flow for the xAI provider and saves tokens.
func DoXAILogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
record, savedPath, err := manager.Login(context.Background(), "xai", cfg, authOpts)
if err != nil {
log.Errorf("xAI authentication failed: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("xAI authentication successful!")
}
+1 -1
View File
@@ -137,7 +137,7 @@ type Config struct {
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
// These aliases affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi.
// gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
+10
View File
@@ -21,6 +21,7 @@ type staticModelsJSON struct {
CodexPro []*ModelInfo `json:"codex-pro"`
Kimi []*ModelInfo `json:"kimi"`
Antigravity []*ModelInfo `json:"antigravity"`
XAI []*ModelInfo `json:"xai"`
}
// GetClaudeModels returns the standard Claude model definitions.
@@ -78,6 +79,11 @@ func GetAntigravityModels() []*ModelInfo {
return cloneModelInfos(getModels().Antigravity)
}
// GetXAIModels returns the standard xAI Grok model definitions.
func GetXAIModels() []*ModelInfo {
return cloneModelInfos(getModels().XAI)
}
// WithCodexBuiltins injects hard-coded Codex-only model definitions that should
// not depend on remote models.json updates. Built-ins replace any matching IDs
// already present in the provided slice.
@@ -167,6 +173,7 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
// - codex
// - kimi
// - antigravity
// - xai
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
key := strings.ToLower(strings.TrimSpace(channel))
switch key {
@@ -186,6 +193,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetKimiModels()
case "antigravity":
return GetAntigravityModels()
case "xai", "x-ai", "grok":
return GetXAIModels()
default:
return nil
}
@@ -208,6 +217,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
data.CodexPro,
data.Kimi,
data.Antigravity,
data.XAI,
}
for _, models := range allModels {
for _, m := range models {
+2
View File
@@ -215,6 +215,7 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
{"codex", oldData.CodexPro, newData.CodexPro},
{"kimi", oldData.Kimi, newData.Kimi},
{"antigravity", oldData.Antigravity, newData.Antigravity},
{"xai", oldData.XAI, newData.XAI},
}
seen := make(map[string]bool, len(sections))
@@ -335,6 +336,7 @@ func validateModelsCatalog(data *staticModelsJSON) error {
{name: "codex-pro", models: data.CodexPro},
{name: "kimi", models: data.Kimi},
{name: "antigravity", models: data.Antigravity},
{name: "xai", models: data.XAI},
}
for _, section := range requiredSections {
+106 -1
View File
@@ -46,7 +46,8 @@
"levels": [
"low",
"medium",
"high"
"high",
"xhigh"
]
}
},
@@ -2064,5 +2065,109 @@
]
}
}
],
"xai": [
{
"id": "grok-4.3",
"object": "model",
"created": 1775606400,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 4.3",
"name": "grok-4.3",
"description": "xAI Grok 4.3 model for the Responses API.",
"context_length": 1000000,
"max_completion_tokens": 65536,
"thinking": {
"zero_allowed": true,
"levels": [
"none",
"low",
"medium",
"high"
]
}
},
{
"id": "grok-4.20-0309-reasoning",
"object": "model",
"created": 1773014400,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 4.20 0309 Reasoning",
"name": "grok-4.20-0309-reasoning",
"description": "xAI Grok 4.20 0309 reasoning model for the Responses API.",
"context_length": 2000000,
"max_completion_tokens": 65536
},
{
"id": "grok-4.20-0309-non-reasoning",
"object": "model",
"created": 1773014400,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 4.20 0309 Non Reasoning",
"name": "grok-4.20-0309-non-reasoning",
"description": "xAI Grok 4.20 0309 non-reasoning model for the Responses API.",
"context_length": 2000000,
"max_completion_tokens": 65536
},
{
"id": "grok-4.20-multi-agent-0309",
"object": "model",
"created": 1773014400,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 4.20 Multi Agent 0309",
"name": "grok-4.20-multi-agent-0309",
"description": "xAI Grok 4.20 multi-agent model for the Responses API.",
"context_length": 2000000,
"max_completion_tokens": 65536,
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "grok-3-mini",
"object": "model",
"created": 1740960000,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 3 Mini",
"name": "grok-3-mini",
"description": "xAI Grok 3 Mini model for the Responses API.",
"context_length": 131072,
"max_completion_tokens": 32768,
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "grok-3-mini-fast",
"object": "model",
"created": 1740960000,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 3 Mini Fast",
"name": "grok-3-mini-fast",
"description": "xAI Grok 3 Mini Fast model for the Responses API.",
"context_length": 131072,
"max_completion_tokens": 32768,
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
}
]
}
+570
View File
@@ -0,0 +1,570 @@
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"sort"
"strings"
"time"
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/tiktoken-go/tokenizer"
)
var xaiDataTag = []byte("data:")
// XAIExecutor is a stateless executor for xAI Grok's Responses API.
type XAIExecutor struct {
cfg *config.Config
}
// NewXAIExecutor creates a new xAI executor.
func NewXAIExecutor(cfg *config.Config) *XAIExecutor {
return &XAIExecutor{cfg: cfg}
}
// Identifier returns the provider identifier.
func (e *XAIExecutor) Identifier() string {
return "xai"
}
// PrepareRequest injects xAI credentials into the outgoing HTTP request.
func (e *XAIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
token, _ := xaiCreds(auth)
if strings.TrimSpace(token) != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
// HttpRequest injects xAI credentials into the request and executes it.
func (e *XAIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("xai executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
return nil, errPrepare
}
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
token, baseURL := xaiCreds(auth)
if baseURL == "" {
baseURL = xaiauth.DefaultAPIBaseURL
}
prepared, err := e.prepareResponsesRequest(ctx, req, opts, true)
if err != nil {
return resp, err
}
reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body))
if err != nil {
return resp, err
}
applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID)
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("xai executor: close response body error: %v", errClose)
}
}()
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return resp, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
outputItemsByIndex := make(map[int64][]byte)
var outputItemsFallback [][]byte
for _, line := range bytes.Split(data, []byte("\n")) {
if !bytes.HasPrefix(line, xaiDataTag) {
continue
}
eventData := bytes.TrimSpace(line[len(xaiDataTag):])
switch gjson.GetBytes(eventData, "type").String() {
case "response.output_item.done":
xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
case "response.completed":
if detail, ok := helps.ParseCodexUsage(eventData); ok {
reporter.Publish(ctx, detail)
}
completedData := xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
var param any
out := sdktranslator.TranslateNonStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, completedData, &param)
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
}
}
return resp, statusErr{code: http.StatusRequestTimeout, msg: "xai stream error: stream disconnected before response.completed"}
}
func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
token, baseURL := xaiCreds(auth)
if baseURL == "" {
baseURL = xaiauth.DefaultAPIBaseURL
}
prepared, err := e.prepareResponsesRequest(ctx, req, opts, true)
if err != nil {
return nil, err
}
reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body))
if err != nil {
return nil, err
}
applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID)
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("xai executor: close response body error: %v", errClose)
}
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return nil, errRead
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return nil, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("xai executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800)
var param any
outputItemsByIndex := make(map[int64][]byte)
var outputItemsFallback [][]byte
for scanner.Scan() {
line := scanner.Bytes()
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
translatedLine := bytes.Clone(line)
if bytes.HasPrefix(line, xaiDataTag) {
eventData := bytes.TrimSpace(line[len(xaiDataTag):])
switch gjson.GetBytes(eventData, "type").String() {
case "response.output_item.done":
xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
case "response.completed":
if detail, ok := helps.ParseCodexUsage(eventData); ok {
reporter.Publish(ctx, detail)
}
eventData = xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
translatedLine = append([]byte("data: "), eventData...)
}
}
chunks := sdktranslator.TranslateStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, translatedLine, &param)
for i := range chunks {
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx, errScan)
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens estimates token count for xAI Responses requests.
func (e *XAIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
prepared, err := e.prepareResponsesRequest(ctx, req, opts, false)
if err != nil {
return cliproxyexecutor.Response{}, err
}
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: tokenizer init failed: %w", err)
}
count, err := enc.Count(string(prepared.body))
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: token counting failed: %w", err)
}
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
translated := sdktranslator.TranslateTokenCount(ctx, prepared.to, prepared.from, int64(count), []byte(usageJSON))
return cliproxyexecutor.Response{Payload: translated}, nil
}
// Refresh refreshes xAI OAuth credentials using the stored refresh token.
func (e *XAIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
log.Debugf("xai executor: refresh called")
if refreshed, handled, err := helps.RefreshAuthViaHome(ctx, e.cfg, auth); handled {
return refreshed, err
}
if auth == nil {
return nil, statusErr{code: http.StatusInternalServerError, msg: "xai executor: auth is nil"}
}
refreshToken := xaiMetadataString(auth.Metadata, "refresh_token")
if refreshToken == "" {
return auth, nil
}
tokenEndpoint := xaiMetadataString(auth.Metadata, "token_endpoint")
svc := xaiauth.NewXAIAuthWithProxyURL(e.cfg, auth.ProxyURL)
td, err := svc.RefreshTokens(ctx, refreshToken, tokenEndpoint)
if err != nil {
return nil, err
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
auth.Metadata["type"] = "xai"
auth.Metadata["auth_kind"] = "oauth"
auth.Metadata["access_token"] = td.AccessToken
if td.RefreshToken != "" {
auth.Metadata["refresh_token"] = td.RefreshToken
}
if td.IDToken != "" {
auth.Metadata["id_token"] = td.IDToken
}
if td.TokenType != "" {
auth.Metadata["token_type"] = td.TokenType
}
if td.ExpiresIn > 0 {
auth.Metadata["expires_in"] = td.ExpiresIn
}
if td.Expire != "" {
auth.Metadata["expired"] = td.Expire
}
if td.Email != "" {
auth.Metadata["email"] = td.Email
}
if td.Subject != "" {
auth.Metadata["sub"] = td.Subject
}
if tokenEndpoint != "" {
auth.Metadata["token_endpoint"] = tokenEndpoint
}
if xaiMetadataString(auth.Metadata, "base_url") == "" {
auth.Metadata["base_url"] = xaiauth.DefaultAPIBaseURL
}
auth.Metadata["last_refresh"] = time.Now().UTC().Format(time.RFC3339)
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
auth.Attributes["auth_kind"] = "oauth"
if strings.TrimSpace(auth.Attributes["base_url"]) == "" {
auth.Attributes["base_url"] = xaiauth.DefaultAPIBaseURL
}
return auth, nil
}
type xaiPreparedRequest struct {
baseModel string
from sdktranslator.Format
to sdktranslator.Format
originalPayload []byte
body []byte
sessionID string
}
func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) (*xaiPreparedRequest, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := bytes.Clone(originalPayloadSource)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
var err error
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", stream)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body = normalizeCodexInstructions(body)
body = sanitizeXAIResponsesBody(body, baseModel)
sessionID := xaiExecutionSessionID(req, opts)
if sessionID != "" {
body, _ = sjson.SetBytes(body, "prompt_cache_key", sessionID)
}
return &xaiPreparedRequest{
baseModel: baseModel,
from: from,
to: to,
originalPayload: originalPayload,
body: body,
sessionID: sessionID,
}, nil
}
func (e *XAIExecutor) recordXAIRequest(ctx context.Context, auth *cliproxyauth.Auth, url string, headers http.Header, body []byte) {
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: headers,
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
}
func xaiCreds(auth *cliproxyauth.Auth) (token, baseURL string) {
if auth == nil {
return "", ""
}
if auth.Attributes != nil {
token = strings.TrimSpace(auth.Attributes["api_key"])
baseURL = strings.TrimSpace(auth.Attributes["base_url"])
}
if auth.Metadata != nil {
if token == "" {
token = xaiMetadataString(auth.Metadata, "access_token")
}
if baseURL == "" {
baseURL = xaiMetadataString(auth.Metadata, "base_url")
}
}
return token, baseURL
}
func applyXAIHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, sessionID string) {
r.Header.Set("Content-Type", "application/json")
if strings.TrimSpace(token) != "" {
r.Header.Set("Authorization", "Bearer "+token)
}
if stream {
r.Header.Set("Accept", "text/event-stream")
} else {
r.Header.Set("Accept", "application/json")
}
r.Header.Set("Connection", "Keep-Alive")
if sessionID != "" {
r.Header.Set("x-grok-conv-id", sessionID)
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(r, attrs)
}
func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) string {
if value := xaiMetadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" {
return value
}
if value := xaiMetadataString(req.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" {
return value
}
if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
return strings.TrimSpace(promptCacheKey.String())
}
return ""
}
func xaiMetadataString(meta map[string]any, key string) string {
if len(meta) == 0 || key == "" {
return ""
}
value, ok := meta[key]
if !ok || value == nil {
return ""
}
switch typed := value.(type) {
case string:
return strings.TrimSpace(typed)
case fmt.Stringer:
return strings.TrimSpace(typed.String())
default:
return strings.TrimSpace(fmt.Sprint(typed))
}
}
func sanitizeXAIResponsesBody(body []byte, model string) []byte {
body = removeXAIEncryptedReasoningInclude(body)
if !xaiSupportsReasoningEffort(model) {
body, _ = sjson.DeleteBytes(body, "reasoning")
}
return body
}
func removeXAIEncryptedReasoningInclude(body []byte) []byte {
include := gjson.GetBytes(body, "include")
if !include.Exists() || !include.IsArray() {
return body
}
kept := make([]string, 0, len(include.Array()))
for _, item := range include.Array() {
value := strings.TrimSpace(item.String())
if value == "" || value == "reasoning.encrypted_content" {
continue
}
kept = append(kept, value)
}
body, _ = sjson.SetBytes(body, "include", kept)
return body
}
func xaiSupportsReasoningEffort(model string) bool {
name := strings.ToLower(strings.TrimSpace(thinking.ParseSuffix(model).ModelName))
if idx := strings.LastIndex(name, "/"); idx >= 0 {
name = name[idx+1:]
}
switch {
case strings.HasPrefix(name, "grok-3-mini"):
return true
case strings.HasPrefix(name, "grok-4.20-multi-agent"):
return true
case strings.HasPrefix(name, "grok-4.3"):
return true
default:
return false
}
}
func xaiCollectOutputItemDone(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback *[][]byte) {
itemResult := gjson.GetBytes(eventData, "item")
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
return
}
outputIndexResult := gjson.GetBytes(eventData, "output_index")
if outputIndexResult.Exists() {
outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw)
return
}
*outputItemsFallback = append(*outputItemsFallback, []byte(itemResult.Raw))
}
func xaiPatchCompletedOutput(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback [][]byte) []byte {
outputResult := gjson.GetBytes(eventData, "response.output")
shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0)
if !shouldPatchOutput {
return eventData
}
indexes := make([]int64, 0, len(outputItemsByIndex))
for idx := range outputItemsByIndex {
indexes = append(indexes, idx)
}
sort.Slice(indexes, func(i, j int) bool {
return indexes[i] < indexes[j]
})
outputArray := []byte("[]")
var buf bytes.Buffer
buf.WriteByte('[')
wrote := false
for _, idx := range indexes {
if wrote {
buf.WriteByte(',')
}
buf.Write(outputItemsByIndex[idx])
wrote = true
}
for _, item := range outputItemsFallback {
if wrote {
buf.WriteByte(',')
}
buf.Write(item)
wrote = true
}
buf.WriteByte(']')
if wrote {
outputArray = buf.Bytes()
}
patched, _ := sjson.SetRawBytes(eventData, "response.output", outputArray)
return patched
}
@@ -0,0 +1,138 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
"github.com/tidwall/gjson"
)
func TestXAIExecutorExecuteShapesResponsesRequest(t *testing.T) {
var gotPath string
var gotAuth string
var gotGrokConvID string
var gotOriginator string
var gotAccountID string
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuth = r.Header.Get("Authorization")
gotGrokConvID = r.Header.Get("x-grok-conv-id")
gotOriginator = r.Header.Get("Originator")
gotAccountID = r.Header.Get("Chatgpt-Account-Id")
var errRead error
gotBody, errRead = io.ReadAll(r.Body)
if errRead != nil {
t.Fatalf("read body: %v", errRead)
}
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4.3\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}],\"usage\":{\"input_tokens\":1,\"output_tokens\":1,\"total_tokens\":2}}}\n\n"))
}))
defer server.Close()
exec := NewXAIExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
ID: "xai-auth",
Provider: "xai",
Attributes: map[string]string{
"base_url": server.URL,
"auth_kind": "oauth",
},
Metadata: map[string]any{
"access_token": "xai-token",
"email": "user@example.com",
},
}
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "grok-4.3",
Payload: []byte(`{"model":"grok-4.3","input":"hello","include":["reasoning.encrypted_content"],"reasoning":{"effort":"high"}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatOpenAIResponse,
Stream: false,
Metadata: map[string]any{
cliproxyexecutor.ExecutionSessionMetadataKey: "conv-xai-1",
},
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gotPath != "/responses" {
t.Fatalf("path = %q, want /responses", gotPath)
}
if gotAuth != "Bearer xai-token" {
t.Fatalf("Authorization = %q, want Bearer xai-token", gotAuth)
}
if gotGrokConvID != "conv-xai-1" {
t.Fatalf("x-grok-conv-id = %q, want conv-xai-1", gotGrokConvID)
}
if gotOriginator != "" {
t.Fatalf("Originator = %q, want empty", gotOriginator)
}
if gotAccountID != "" {
t.Fatalf("Chatgpt-Account-Id = %q, want empty", gotAccountID)
}
if gjson.GetBytes(gotBody, "prompt_cache_key").String() != "conv-xai-1" {
t.Fatalf("prompt_cache_key missing from body: %s", string(gotBody))
}
if !gjson.GetBytes(gotBody, "stream").Bool() {
t.Fatalf("stream = false, want true; body=%s", string(gotBody))
}
if gjson.GetBytes(gotBody, "reasoning.effort").String() != "high" {
t.Fatalf("reasoning.effort = %q, want high; body=%s", gjson.GetBytes(gotBody, "reasoning.effort").String(), string(gotBody))
}
for _, include := range gjson.GetBytes(gotBody, "include").Array() {
if include.String() == "reasoning.encrypted_content" {
t.Fatalf("xai request must not ask for encrypted reasoning content: %s", string(gotBody))
}
}
}
func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) {
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var errRead error
gotBody, errRead = io.ReadAll(r.Body)
if errRead != nil {
t.Fatalf("read body: %v", errRead)
}
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}]}}\n\n"))
}))
defer server.Close()
exec := NewXAIExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "xai",
Attributes: map[string]string{
"base_url": server.URL,
"auth_kind": "oauth",
},
Metadata: map[string]any{"access_token": "xai-token"},
}
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "grok-4",
Payload: []byte(`{"model":"grok-4","input":"hello","reasoning":{"effort":"high"}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatOpenAIResponse,
Stream: false,
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gjson.GetBytes(gotBody, "reasoning").Exists() {
t.Fatalf("unsupported xAI model must omit reasoning key: %s", string(gotBody))
}
}
+3
View File
@@ -24,6 +24,7 @@ var oauthProviders = []oauthProvider{
{"Codex (OpenAI)", "codex-auth-url", "🟩"},
{"Antigravity", "antigravity-auth-url", "🟪"},
{"Kimi", "kimi-auth-url", "🟫"},
{"xAI", "xai-auth-url", "⬛"},
}
// oauthTabModel handles OAuth login flows.
@@ -280,6 +281,8 @@ func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
providerKey = "antigravity"
case "kimi-auth-url":
providerKey = "kimi"
case "xai-auth-url":
providerKey = "xai"
}
break
}
+1
View File
@@ -13,6 +13,7 @@ func init() {
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
registerRefreshLead("xai", func() Authenticator { return NewXAIAuthenticator() })
}
func registerRefreshLead(provider string, factory func() Authenticator) {
+282
View File
@@ -0,0 +1,282 @@
package auth
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"time"
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
"github.com/router-for-me/CLIProxyAPI/v7/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// XAIAuthenticator implements the xAI Grok OAuth loopback flow.
type XAIAuthenticator struct{}
// NewXAIAuthenticator constructs a new xAI authenticator.
func NewXAIAuthenticator() Authenticator {
return &XAIAuthenticator{}
}
// Provider returns the provider key for xAI.
func (XAIAuthenticator) Provider() string {
return "xai"
}
// RefreshLead instructs the manager to refresh before token expiry.
func (XAIAuthenticator) RefreshLead() *time.Duration {
lead := xaiauth.RefreshLead()
return &lead
}
// Login launches a local OAuth flow to obtain xAI tokens and persists them.
func (a XAIAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("cliproxy auth: configuration is required")
}
if ctx == nil {
ctx = context.Background()
}
if opts == nil {
opts = &LoginOptions{}
}
callbackPort := xaiauth.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
pkceCodes, err := xaiauth.GeneratePKCECodes()
if err != nil {
return nil, fmt.Errorf("xai pkce generation failed: %w", err)
}
state, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("xai state generation failed: %w", err)
}
nonce, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("xai nonce generation failed: %w", err)
}
authSvc := xaiauth.NewXAIAuth(cfg)
discovery, err := authSvc.Discover(ctx)
if err != nil {
return nil, err
}
srv, port, callbackCh, errServer := startXAICallbackServer(callbackPort)
if errServer != nil {
return nil, fmt.Errorf("xai: failed to start callback server: %w", errServer)
}
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if errShutdown := srv.Shutdown(shutdownCtx); errShutdown != nil {
log.Warnf("xai callback server shutdown error: %v", errShutdown)
}
}()
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, port, xaiauth.RedirectPath)
authURL, err := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
RedirectURI: redirectURI,
CodeChallenge: pkceCodes.CodeChallenge,
State: state,
Nonce: nonce,
})
if err != nil {
return nil, err
}
if !opts.NoBrowser {
fmt.Println("Opening browser for xAI authentication")
if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually")
util.PrintSSHTunnelInstructions(port)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if errOpen := browser.OpenURL(authURL); errOpen != nil {
log.Warnf("Failed to open browser automatically: %v", errOpen)
util.PrintSSHTunnelInstructions(port)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
} else {
util.PrintSSHTunnelInstructions(port)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
fmt.Println("Waiting for xAI authentication callback...")
var result callbackResult
timeoutTimer := time.NewTimer(5 * time.Minute)
defer timeoutTimer.Stop()
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
var manualInputCh <-chan string
var manualInputErrCh <-chan error
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
default:
}
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the xAI callback Token (or press Enter to keep waiting): ")
continue
case input := <-manualInputCh:
manualInputCh = nil
manualInputErrCh = nil
manualResult, ok, errParse := parseXAIManualCallbackToken(input, state)
if errParse != nil {
return nil, errParse
}
if !ok {
continue
}
result = manualResult
break waitForCallback
case errManual := <-manualInputErrCh:
return nil, errManual
case <-timeoutTimer.C:
return nil, fmt.Errorf("xai: authentication timed out")
}
}
if result.Error != "" {
return nil, fmt.Errorf("xai: authentication failed: %s", result.Error)
}
if result.State != state {
return nil, fmt.Errorf("xai: invalid state")
}
if result.Code == "" {
return nil, fmt.Errorf("xai: missing authorization code")
}
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, result.Code, redirectURI, pkceCodes, discovery.TokenEndpoint)
if errExchange != nil {
return nil, fmt.Errorf("xai: token exchange failed: %w", errExchange)
}
tokenStorage := authSvc.CreateTokenStorage(bundle)
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
return nil, fmt.Errorf("xai token storage missing access token")
}
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
label := strings.TrimSpace(tokenStorage.Email)
if label == "" {
label = "xAI"
}
metadata := map[string]any{
"type": "xai",
"access_token": tokenStorage.AccessToken,
"refresh_token": tokenStorage.RefreshToken,
"id_token": tokenStorage.IDToken,
"token_type": tokenStorage.TokenType,
"expires_in": tokenStorage.ExpiresIn,
"expired": tokenStorage.Expire,
"last_refresh": tokenStorage.LastRefresh,
"base_url": tokenStorage.BaseURL,
"redirect_uri": tokenStorage.RedirectURI,
"token_endpoint": tokenStorage.TokenEndpoint,
"auth_kind": "oauth",
}
if tokenStorage.Email != "" {
metadata["email"] = tokenStorage.Email
}
if tokenStorage.Subject != "" {
metadata["sub"] = tokenStorage.Subject
}
fmt.Println("xAI authentication successful")
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Label: label,
Storage: tokenStorage,
Metadata: metadata,
Attributes: map[string]string{
"auth_kind": "oauth",
"base_url": tokenStorage.BaseURL,
},
}, nil
}
func parseXAIManualCallbackToken(input string, state string) (callbackResult, bool, error) {
token := strings.TrimSpace(input)
if token == "" {
return callbackResult{}, false, nil
}
if strings.Contains(token, "://") || strings.Contains(token, "?") || strings.Contains(token, "code=") {
return callbackResult{}, false, fmt.Errorf("xai: paste only the callback token")
}
return callbackResult{Code: token, State: state}, true, nil
}
func startXAICallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
if port <= 0 {
port = xaiauth.CallbackPort
}
addr := fmt.Sprintf("%s:%d", xaiauth.RedirectHost, port)
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, 0, nil, err
}
port = listener.Addr().(*net.TCPAddr).Port
resultCh := make(chan callbackResult, 1)
mux := http.NewServeMux()
mux.HandleFunc(xaiauth.RedirectPath, func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
result := callbackResult{
Code: strings.TrimSpace(q.Get("code")),
Error: strings.TrimSpace(q.Get("error")),
State: strings.TrimSpace(q.Get("state")),
}
resultCh <- result
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if result.Code != "" && result.Error == "" {
_, _ = w.Write([]byte("<h1>Login successful</h1><p>You can close this window.</p>"))
return
}
_, _ = w.Write([]byte("<h1>Login failed</h1><p>Please check the CLI output.</p>"))
})
srv := &http.Server{
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
}
go func() {
if errServe := srv.Serve(listener); errServe != nil && !strings.Contains(errServe.Error(), "Server closed") {
log.Warnf("xai callback server error: %v", errServe)
}
}()
return srv, port, resultCh, nil
}
+37
View File
@@ -0,0 +1,37 @@
package auth
import "testing"
func TestXAIAuthenticatorProviderAndRefreshLead(t *testing.T) {
authenticator := NewXAIAuthenticator()
if authenticator.Provider() != "xai" {
t.Fatalf("Provider() = %q, want xai", authenticator.Provider())
}
lead := authenticator.RefreshLead()
if lead == nil || *lead <= 0 {
t.Fatalf("RefreshLead() = %v, want positive duration", lead)
}
}
func TestParseXAIManualCallbackTokenAcceptsRawCode(t *testing.T) {
result, ok, err := parseXAIManualCallbackToken(" V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg ", "state-1")
if err != nil {
t.Fatalf("parseXAIManualCallbackToken() error = %v", err)
}
if !ok {
t.Fatal("parseXAIManualCallbackToken() ok = false, want true")
}
if result.Code != "V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg" {
t.Fatalf("Code = %q", result.Code)
}
if result.State != "state-1" {
t.Fatalf("State = %q, want state-1", result.State)
}
}
func TestParseXAIManualCallbackTokenRejectsCallbackURL(t *testing.T) {
_, _, err := parseXAIManualCallbackToken("http://127.0.0.1:56121/callback?state=state-1&code=token-1", "state-1")
if err == nil {
t.Fatal("parseXAIManualCallbackToken() error = nil, want error")
}
}
+6
View File
@@ -116,6 +116,7 @@ func newDefaultAuthManager() *sdkAuth.Manager {
sdkAuth.NewGeminiAuthenticator(),
sdkAuth.NewCodexAuthenticator(),
sdkAuth.NewClaudeAuthenticator(),
sdkAuth.NewXAIAuthenticator(),
)
}
@@ -433,6 +434,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
case "kimi":
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
case "xai":
s.coreManager.RegisterExecutor(executor.NewXAIExecutor(s.cfg))
default:
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
if providerKey == "" {
@@ -1156,6 +1159,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "kimi":
models = registry.GetKimiModels()
models = applyExcludedModels(models, excluded)
case "xai":
models = registry.GetXAIModels()
models = applyExcludedModels(models, excluded)
default:
// Handle OpenAI-compatibility providers by name using config
if s.cfg != nil {
@@ -0,0 +1,36 @@
package cliproxy
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
)
func TestEnsureExecutorsForAuth_XAIBindsIndependentExecutor(t *testing.T) {
service := &Service{
cfg: &config.Config{},
coreManager: coreauth.NewManager(nil, nil, nil),
}
auth := &coreauth.Auth{
ID: "xai-auth-1",
Provider: "xai",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
},
}
service.ensureExecutorsForAuth(auth)
resolved, ok := service.coreManager.Executor("xai")
if !ok || resolved == nil {
t.Fatal("expected xai executor after bind")
}
if _, isXAI := resolved.(*executor.XAIExecutor); !isXAI {
t.Fatalf("executor type = %T, want *executor.XAIExecutor", resolved)
}
if _, isCodex := resolved.(*executor.CodexAutoExecutor); isCodex {
t.Fatal("xai must not bind the codex auto executor")
}
}