feat(auth): add OAuth2 support for xAI with PKCE and token persistence

- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support.
- Added logic for token exchange, refresh, and persistent storage in JSON format.
- Created `xai` package with helpers for OAuth discovery, API token handling, and URL building.
- Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests.
- Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
Luis Pater
2026-05-17 01:02:35 +08:00
parent cd0cea393c
commit e4c957078c
24 changed files with 2050 additions and 4 deletions
+4
View File
@@ -182,6 +182,7 @@ func main() {
var oauthCallbackPort int var oauthCallbackPort int
var antigravityLogin bool var antigravityLogin bool
var kimiLogin bool var kimiLogin bool
var xaiLogin bool
var projectID string var projectID string
var vertexImport string var vertexImport string
var vertexImportPrefix string var vertexImportPrefix string
@@ -203,6 +204,7 @@ func main() {
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)") flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth") flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth") flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
flag.BoolVar(&xaiLogin, "xai-login", false, "Login to xAI using OAuth")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
@@ -656,6 +658,8 @@ func main() {
cmd.DoClaudeLogin(cfg, options) cmd.DoClaudeLogin(cfg, options)
} else if kimiLogin { } else if kimiLogin {
cmd.DoKimiLogin(cfg, options) cmd.DoKimiLogin(cfg, options)
} else if xaiLogin {
cmd.DoXAILogin(cfg, options)
} else { } else {
// In cloud deploy mode without config file, just wait for shutdown signals // In cloud deploy mode without config file, just wait for shutdown signals
if isCloudDeploy && !configFileExists { if isCloudDeploy && !configFileExists {
+6 -1
View File
@@ -345,7 +345,7 @@ nonstream-keepalive-interval: 0
# Global OAuth model name aliases (per channel) # Global OAuth model name aliases (per channel)
# These aliases rename model IDs for both model listing and request routing. # These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi. # Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode. # NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping # NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps # client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
@@ -375,6 +375,9 @@ nonstream-keepalive-interval: 0
# kimi: # kimi:
# - name: "kimi-k2.5" # - name: "kimi-k2.5"
# alias: "k2.5" # alias: "k2.5"
# xai:
# - name: "grok-4.3"
# alias: "grok-latest"
# OAuth provider excluded models # OAuth provider excluded models
# oauth-excluded-models: # oauth-excluded-models:
@@ -395,6 +398,8 @@ nonstream-keepalive-interval: 0
# - "gpt-5-codex-mini" # - "gpt-5-codex-mini"
# kimi: # kimi:
# - "kimi-k2-thinking" # - "kimi-k2-thinking"
# xai:
# - "grok-3-mini"
# Optional payload configuration # Optional payload configuration
# payload: # payload:
@@ -27,6 +27,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini" geminiAuth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini"
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/kimi" "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/kimi"
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc" "github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry" "github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
@@ -2132,6 +2133,185 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
} }
func (h *Handler) RequestXAIToken(c *gin.Context) {
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing xAI authentication...")
pkceCodes, errPKCE := xaiauth.GeneratePKCECodes()
if errPKCE != nil {
log.Errorf("Failed to generate xAI PKCE codes: %v", errPKCE)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
state, errState := misc.GenerateRandomState()
if errState != nil {
log.Errorf("Failed to generate state parameter: %v", errState)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
nonce, errNonce := misc.GenerateRandomState()
if errNonce != nil {
log.Errorf("Failed to generate nonce parameter: %v", errNonce)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate nonce parameter"})
return
}
authSvc := xaiauth.NewXAIAuth(h.cfg)
discovery, errDiscover := authSvc.Discover(ctx)
if errDiscover != nil {
log.Errorf("Failed to discover xAI OAuth endpoints: %v", errDiscover)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to discover oauth endpoints"})
return
}
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, xaiauth.CallbackPort, xaiauth.RedirectPath)
authURL, errAuthURL := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
RedirectURI: redirectURI,
CodeChallenge: pkceCodes.CodeChallenge,
State: state,
Nonce: nonce,
})
if errAuthURL != nil {
log.Errorf("Failed to generate xAI authorization URL: %v", errAuthURL)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
RegisterOAuthSession(state, "xai")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/xai/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute xai callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
var errStart error
if forwarder, errStart = startCallbackForwarder(xaiauth.CallbackPort, "xai", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start xai callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarderInstance(xaiauth.CallbackPort, forwarder)
}
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-xai-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
var authCode string
for {
if !IsOAuthSessionPending(state, "xai") {
return
}
if time.Now().After(deadline) {
log.Error("xai oauth flow timed out")
SetOAuthSessionError(state, "OAuth flow timed out")
return
}
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
var payload map[string]string
_ = json.Unmarshal(data, &payload)
_ = os.Remove(waitFile)
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
log.Errorf("xAI authentication failed: %s", errStr)
SetOAuthSessionError(state, "Authentication failed: "+errStr)
return
}
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
log.Errorf("xAI authentication failed: state mismatch")
SetOAuthSessionError(state, "Authentication failed: state mismatch")
return
}
authCode = strings.TrimSpace(payload["code"])
if authCode == "" {
log.Error("xAI authentication failed: code not found")
SetOAuthSessionError(state, "Authentication failed: code not found")
return
}
break
}
time.Sleep(500 * time.Millisecond)
}
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI, pkceCodes, discovery.TokenEndpoint)
if errExchange != nil {
log.Errorf("Failed to exchange xAI token: %v", errExchange)
SetOAuthSessionError(state, oauthSessionErrorWithCause("Failed to exchange authorization code for tokens", errExchange))
return
}
tokenStorage := authSvc.CreateTokenStorage(bundle)
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
log.Error("xAI token exchange returned empty access token")
SetOAuthSessionError(state, "Failed to exchange token")
return
}
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
label := strings.TrimSpace(tokenStorage.Email)
if label == "" {
label = "xAI"
}
metadata := map[string]any{
"type": "xai",
"access_token": tokenStorage.AccessToken,
"refresh_token": tokenStorage.RefreshToken,
"id_token": tokenStorage.IDToken,
"token_type": tokenStorage.TokenType,
"expires_in": tokenStorage.ExpiresIn,
"expired": tokenStorage.Expire,
"last_refresh": tokenStorage.LastRefresh,
"base_url": tokenStorage.BaseURL,
"redirect_uri": tokenStorage.RedirectURI,
"token_endpoint": tokenStorage.TokenEndpoint,
"auth_kind": "oauth",
}
if tokenStorage.Email != "" {
metadata["email"] = tokenStorage.Email
}
if tokenStorage.Subject != "" {
metadata["sub"] = tokenStorage.Subject
}
record := &coreauth.Auth{
ID: fileName,
Provider: "xai",
FileName: fileName,
Label: label,
Storage: tokenStorage,
Metadata: metadata,
Attributes: map[string]string{
"auth_kind": "oauth",
"base_url": tokenStorage.BaseURL,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save xAI token to file: %v", errSave)
SetOAuthSessionError(state, "Failed to save token to file")
return
}
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("xai")
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use xAI services through this CLI")
}()
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestKimiToken(c *gin.Context) { func (h *Handler) RequestKimiToken(c *gin.Context) {
ctx := context.Background() ctx := context.Background()
ctx = PopulateAuthContext(ctx, c) ctx = PopulateAuthContext(ctx, c)
@@ -242,6 +242,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
return "gemini", nil return "gemini", nil
case "antigravity", "anti-gravity": case "antigravity", "anti-gravity":
return "antigravity", nil return "antigravity", nil
case "xai", "x-ai", "x.ai", "grok":
return "xai", nil
default: default:
return "", errUnsupportedOAuthFlow return "", errUnsupportedOAuthFlow
} }
+15
View File
@@ -484,6 +484,20 @@ func (s *Server) setupRoutes() {
c.String(http.StatusOK, oauthCallbackSuccessHTML) c.String(http.StatusOK, oauthCallbackSuccessHTML)
}) })
s.engine.GET("/xai/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" {
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "xai", state, code, errStr)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML)
})
// Management routes are registered lazily by registerManagementRoutes when a secret is configured. // Management routes are registered lazily by registerManagementRoutes when a secret is configured.
} }
@@ -685,6 +699,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken) mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken) mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken) mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
mgmt.GET("/xai-auth-url", s.mgmt.RequestXAIToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback) mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
} }
+20
View File
@@ -0,0 +1,20 @@
package xai
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
)
// GeneratePKCECodes creates a verifier/challenge pair for the OAuth flow.
func GeneratePKCECodes() (*PKCECodes, error) {
bytes := make([]byte, 96)
if _, err := rand.Read(bytes); err != nil {
return nil, fmt.Errorf("xai pkce: generate verifier: %w", err)
}
verifier := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes)
hash := sha256.Sum256([]byte(verifier))
challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
return &PKCECodes{CodeVerifier: verifier, CodeChallenge: challenge}, nil
}
+104
View File
@@ -0,0 +1,104 @@
package xai
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
log "github.com/sirupsen/logrus"
)
// TokenStorage stores xAI OAuth credentials on disk.
type TokenStorage struct {
Type string `json:"type"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token,omitempty"`
TokenType string `json:"token_type,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
Expire string `json:"expired,omitempty"`
LastRefresh string `json:"last_refresh,omitempty"`
Email string `json:"email,omitempty"`
Subject string `json:"sub,omitempty"`
BaseURL string `json:"base_url,omitempty"`
RedirectURI string `json:"redirect_uri,omitempty"`
TokenEndpoint string `json:"token_endpoint,omitempty"`
AuthKind string `json:"auth_kind,omitempty"`
Metadata map[string]any `json:"-"`
}
// SetMetadata allows the token store to merge status fields before saving.
func (ts *TokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile writes xAI credentials to a JSON auth file.
func (ts *TokenStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
ts.Type = "xai"
ts.AuthKind = "oauth"
if errMkdirAll := os.MkdirAll(filepath.Dir(authFilePath), 0o700); errMkdirAll != nil {
return fmt.Errorf("xai token storage: create directory: %w", errMkdirAll)
}
file, err := os.Create(authFilePath)
if err != nil {
return fmt.Errorf("xai token storage: create token file: %w", err)
}
defer func() {
if errClose := file.Close(); errClose != nil {
log.Errorf("xai token storage: close token file error: %v", errClose)
}
}()
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
if errMerge != nil {
return fmt.Errorf("xai token storage: merge metadata: %w", errMerge)
}
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
if err = encoder.Encode(data); err != nil {
return fmt.Errorf("xai token storage: write token file: %w", err)
}
return nil
}
// CredentialFileName returns the filename used for xAI credentials.
func CredentialFileName(email, subject string) string {
email = sanitizeFileSegment(email)
if email != "" {
return fmt.Sprintf("xai-%s.json", email)
}
subject = sanitizeFileSegment(subject)
if subject != "" {
return fmt.Sprintf("xai-%s.json", subject)
}
return fmt.Sprintf("xai-%d.json", time.Now().UnixMilli())
}
func sanitizeFileSegment(value string) string {
value = strings.TrimSpace(value)
if value == "" {
return ""
}
var b strings.Builder
for _, r := range value {
switch {
case r >= 'a' && r <= 'z':
b.WriteRune(r)
case r >= 'A' && r <= 'Z':
b.WriteRune(r)
case r >= '0' && r <= '9':
b.WriteRune(r)
case r == '@' || r == '.' || r == '_' || r == '-':
b.WriteRune(r)
default:
b.WriteRune('-')
}
}
return strings.Trim(b.String(), "-")
}
+72
View File
@@ -0,0 +1,72 @@
// Package xai provides OAuth2 authentication helpers for xAI Grok.
package xai
import "time"
const (
// DefaultAPIBaseURL is the default xAI Responses API base URL.
DefaultAPIBaseURL = "https://api.x.ai/v1"
// Issuer is xAI's OAuth issuer.
Issuer = "https://auth.x.ai"
// DiscoveryURL is the OIDC discovery endpoint used to resolve OAuth endpoints.
DiscoveryURL = Issuer + "/.well-known/openid-configuration"
// ClientID is the public xAI Grok CLI OAuth client ID.
ClientID = "b1a00492-073a-47ea-816f-4c329264a828"
// Scope is the OAuth scope set required for xAI API access.
Scope = "openid profile email offline_access grok-cli:access api:access"
// RedirectHost is the loopback host used by xAI OAuth.
RedirectHost = "127.0.0.1"
// CallbackPort is the preferred loopback callback port.
CallbackPort = 56121
// RedirectPath is the loopback callback path registered by the xAI client.
RedirectPath = "/callback"
)
var refreshLead = 5 * time.Minute
// RefreshLead returns the refresh lead time for xAI OAuth credentials.
func RefreshLead() time.Duration {
return refreshLead
}
// PKCECodes holds the PKCE verifier/challenge pair.
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
// AuthorizeURLParams contains the values used to build the xAI OAuth URL.
type AuthorizeURLParams struct {
AuthorizationEndpoint string
RedirectURI string
CodeChallenge string
State string
Nonce string
}
// Discovery contains OAuth endpoints resolved from xAI OIDC discovery.
type Discovery struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
}
// TokenData holds xAI OAuth token data.
type TokenData struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token,omitempty"`
TokenType string `json:"token_type,omitempty"`
ExpiresIn int `json:"expires_in,omitempty"`
Expire string `json:"expired,omitempty"`
Email string `json:"email,omitempty"`
Subject string `json:"sub,omitempty"`
}
// AuthBundle aggregates token data and OAuth metadata for persistence.
type AuthBundle struct {
TokenData TokenData
LastRefresh string
BaseURL string
RedirectURI string
TokenEndpoint string
}
+304
View File
@@ -0,0 +1,304 @@
package xai
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
log "github.com/sirupsen/logrus"
)
// XAIAuth performs xAI OAuth discovery, token exchange, and refresh.
type XAIAuth struct {
httpClient *http.Client
}
// NewXAIAuth creates an xAI OAuth helper using config proxy settings.
func NewXAIAuth(cfg *config.Config) *XAIAuth {
return NewXAIAuthWithProxyURL(cfg, "")
}
// NewXAIAuthWithProxyURL creates an xAI OAuth helper with an explicit proxy URL.
func NewXAIAuthWithProxyURL(cfg *config.Config, proxyURL string) *XAIAuth {
effectiveProxyURL := strings.TrimSpace(proxyURL)
var sdkCfg config.SDKConfig
if cfg != nil {
sdkCfg = cfg.SDKConfig
if effectiveProxyURL == "" {
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
}
}
sdkCfg.ProxyURL = effectiveProxyURL
return &XAIAuth{httpClient: util.SetProxy(&sdkCfg, &http.Client{})}
}
// ValidateOAuthEndpoint validates an endpoint returned by xAI discovery.
func ValidateOAuthEndpoint(rawURL string, field string) (string, error) {
rawURL = strings.TrimSpace(rawURL)
if rawURL == "" {
return "", fmt.Errorf("xai discovery %s is empty", field)
}
parsed, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("xai discovery %s is invalid: %w", field, err)
}
if parsed.Scheme != "https" {
return "", fmt.Errorf("xai discovery %s must use https: %q", field, rawURL)
}
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
if host != "x.ai" && !strings.HasSuffix(host, ".x.ai") {
return "", fmt.Errorf("xai discovery %s host %q is not on x.ai", field, host)
}
return rawURL, nil
}
// BuildAuthorizeURL builds the browser URL for xAI OAuth.
func BuildAuthorizeURL(params AuthorizeURLParams) (string, error) {
endpoint, err := ValidateOAuthEndpoint(params.AuthorizationEndpoint, "authorization_endpoint")
if err != nil {
return "", err
}
if strings.TrimSpace(params.RedirectURI) == "" {
return "", fmt.Errorf("xai authorize URL: redirect URI is required")
}
if strings.TrimSpace(params.CodeChallenge) == "" {
return "", fmt.Errorf("xai authorize URL: code challenge is required")
}
if strings.TrimSpace(params.State) == "" {
return "", fmt.Errorf("xai authorize URL: state is required")
}
if strings.TrimSpace(params.Nonce) == "" {
return "", fmt.Errorf("xai authorize URL: nonce is required")
}
values := url.Values{
"response_type": {"code"},
"client_id": {ClientID},
"redirect_uri": {strings.TrimSpace(params.RedirectURI)},
"scope": {Scope},
"code_challenge": {strings.TrimSpace(params.CodeChallenge)},
"code_challenge_method": {"S256"},
"state": {strings.TrimSpace(params.State)},
"nonce": {strings.TrimSpace(params.Nonce)},
"plan": {"generic"},
"referrer": {"cli-proxy-api"},
}
return endpoint + "?" + values.Encode(), nil
}
// Discover resolves xAI OAuth endpoints through OIDC discovery.
func (a *XAIAuth) Discover(ctx context.Context) (*Discovery, error) {
if ctx == nil {
ctx = context.Background()
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DiscoveryURL, nil)
if err != nil {
return nil, fmt.Errorf("xai discovery: create request: %w", err)
}
req.Header.Set("Accept", "application/json")
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("xai discovery: request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("xai discovery: close response body error: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("xai discovery: read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("xai discovery failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var payload struct {
AuthorizationEndpoint string `json:"authorization_endpoint"`
TokenEndpoint string `json:"token_endpoint"`
}
if err = json.Unmarshal(body, &payload); err != nil {
return nil, fmt.Errorf("xai discovery: parse response: %w", err)
}
authorizationEndpoint, err := ValidateOAuthEndpoint(payload.AuthorizationEndpoint, "authorization_endpoint")
if err != nil {
return nil, err
}
tokenEndpoint, err := ValidateOAuthEndpoint(payload.TokenEndpoint, "token_endpoint")
if err != nil {
return nil, err
}
return &Discovery{AuthorizationEndpoint: authorizationEndpoint, TokenEndpoint: tokenEndpoint}, nil
}
// ExchangeCodeForTokens exchanges an authorization code for xAI OAuth tokens.
func (a *XAIAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes, tokenEndpoint string) (*AuthBundle, error) {
if pkceCodes == nil {
return nil, fmt.Errorf("xai token exchange: PKCE codes are required")
}
if strings.TrimSpace(code) == "" {
return nil, fmt.Errorf("xai token exchange: authorization code is required")
}
if strings.TrimSpace(redirectURI) == "" {
return nil, fmt.Errorf("xai token exchange: redirect URI is required")
}
if strings.TrimSpace(tokenEndpoint) == "" {
discovery, errDiscover := a.Discover(ctx)
if errDiscover != nil {
return nil, errDiscover
}
tokenEndpoint = discovery.TokenEndpoint
}
form := url.Values{
"grant_type": {"authorization_code"},
"code": {strings.TrimSpace(code)},
"redirect_uri": {strings.TrimSpace(redirectURI)},
"client_id": {ClientID},
"code_verifier": {pkceCodes.CodeVerifier},
}
tokenData, err := a.postTokenForm(ctx, tokenEndpoint, form)
if err != nil {
return nil, err
}
return &AuthBundle{
TokenData: *tokenData,
LastRefresh: time.Now().UTC().Format(time.RFC3339),
BaseURL: DefaultAPIBaseURL,
RedirectURI: strings.TrimSpace(redirectURI),
TokenEndpoint: strings.TrimSpace(tokenEndpoint),
}, nil
}
// RefreshTokens refreshes an xAI access token.
func (a *XAIAuth) RefreshTokens(ctx context.Context, refreshToken, tokenEndpoint string) (*TokenData, error) {
if strings.TrimSpace(refreshToken) == "" {
return nil, fmt.Errorf("xai token refresh: refresh token is required")
}
if strings.TrimSpace(tokenEndpoint) == "" {
discovery, errDiscover := a.Discover(ctx)
if errDiscover != nil {
return nil, errDiscover
}
tokenEndpoint = discovery.TokenEndpoint
}
form := url.Values{
"grant_type": {"refresh_token"},
"client_id": {ClientID},
"refresh_token": {strings.TrimSpace(refreshToken)},
}
return a.postTokenForm(ctx, tokenEndpoint, form)
}
func (a *XAIAuth) postTokenForm(ctx context.Context, tokenEndpoint string, form url.Values) (*TokenData, error) {
if ctx == nil {
ctx = context.Background()
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSpace(tokenEndpoint), strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("xai token request: create request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, err := a.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("xai token request failed: %w", err)
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("xai token request: close response body error: %v", errClose)
}
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("xai token response: read body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("xai token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var payload struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
if err = json.Unmarshal(body, &payload); err != nil {
return nil, fmt.Errorf("xai token response: parse body: %w", err)
}
if strings.TrimSpace(payload.AccessToken) == "" {
return nil, fmt.Errorf("xai token response missing access_token")
}
email, subject := parseJWTIdentity(payload.IDToken)
return &TokenData{
AccessToken: strings.TrimSpace(payload.AccessToken),
RefreshToken: strings.TrimSpace(payload.RefreshToken),
IDToken: strings.TrimSpace(payload.IDToken),
TokenType: strings.TrimSpace(payload.TokenType),
ExpiresIn: payload.ExpiresIn,
Expire: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second).UTC().Format(time.RFC3339),
Email: email,
Subject: subject,
}, nil
}
// CreateTokenStorage converts an auth bundle into persistable storage.
func (a *XAIAuth) CreateTokenStorage(bundle *AuthBundle) *TokenStorage {
if bundle == nil {
return nil
}
return &TokenStorage{
Type: "xai",
AccessToken: bundle.TokenData.AccessToken,
RefreshToken: bundle.TokenData.RefreshToken,
IDToken: bundle.TokenData.IDToken,
TokenType: bundle.TokenData.TokenType,
ExpiresIn: bundle.TokenData.ExpiresIn,
Expire: bundle.TokenData.Expire,
LastRefresh: bundle.LastRefresh,
Email: strings.TrimSpace(bundle.TokenData.Email),
Subject: bundle.TokenData.Subject,
BaseURL: firstNonEmpty(bundle.BaseURL, DefaultAPIBaseURL),
RedirectURI: bundle.RedirectURI,
TokenEndpoint: bundle.TokenEndpoint,
AuthKind: "oauth",
}
}
func parseJWTIdentity(token string) (email string, subject string) {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return "", ""
}
payload := parts[1]
payload += strings.Repeat("=", (4-len(payload)%4)%4)
raw, err := base64.URLEncoding.DecodeString(payload)
if err != nil {
return "", ""
}
var claims map[string]any
if err = json.Unmarshal(raw, &claims); err != nil {
return "", ""
}
if v, ok := claims["email"].(string); ok {
email = strings.TrimSpace(v)
}
if v, ok := claims["sub"].(string); ok {
subject = strings.TrimSpace(v)
}
return email, subject
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed
}
}
return ""
}
+105
View File
@@ -0,0 +1,105 @@
package xai
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func TestBuildAuthorizeURLIncludesXAIRequiredParameters(t *testing.T) {
authURL, err := BuildAuthorizeURL(AuthorizeURLParams{
AuthorizationEndpoint: "https://auth.x.ai/oauth/authorize",
RedirectURI: "http://127.0.0.1:56121/callback",
CodeChallenge: "challenge",
State: "state-123",
Nonce: "nonce-123",
})
if err != nil {
t.Fatalf("BuildAuthorizeURL() error = %v", err)
}
parsed, errParse := url.Parse(authURL)
if errParse != nil {
t.Fatalf("parse authorize URL: %v", errParse)
}
if parsed.Scheme != "https" || parsed.Host != "auth.x.ai" || parsed.Path != "/oauth/authorize" {
t.Fatalf("authorize URL endpoint = %s://%s%s", parsed.Scheme, parsed.Host, parsed.Path)
}
query := parsed.Query()
want := map[string]string{
"response_type": "code",
"client_id": ClientID,
"redirect_uri": "http://127.0.0.1:56121/callback",
"scope": Scope,
"code_challenge": "challenge",
"code_challenge_method": "S256",
"state": "state-123",
"nonce": "nonce-123",
"plan": "generic",
"referrer": "cli-proxy-api",
}
for key, value := range want {
if got := query.Get(key); got != value {
t.Fatalf("%s = %q, want %q", key, got, value)
}
}
}
func TestValidateOAuthEndpointRejectsNonXAIOrigin(t *testing.T) {
if _, err := ValidateOAuthEndpoint("https://auth.x.ai/oauth/token", "token_endpoint"); err != nil {
t.Fatalf("ValidateOAuthEndpoint(xai) error = %v", err)
}
if _, err := ValidateOAuthEndpoint("http://auth.x.ai/oauth/token", "token_endpoint"); err == nil {
t.Fatal("expected non-HTTPS endpoint to be rejected")
}
if _, err := ValidateOAuthEndpoint("https://evil.example/oauth/token", "token_endpoint"); err == nil {
t.Fatal("expected non-xAI endpoint to be rejected")
}
}
func TestRefreshTokensPostsClientIDAndRefreshToken(t *testing.T) {
var gotForm url.Values
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Fatalf("method = %s, want POST", r.Method)
}
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/x-www-form-urlencoded") {
t.Fatalf("Content-Type = %q, want form", got)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("ParseForm() error = %v", err)
}
gotForm = r.PostForm
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "new-access",
"refresh_token": "new-refresh",
"token_type": "Bearer",
"expires_in": 3600,
})
}))
defer server.Close()
auth := NewXAIAuth(nil)
tokenData, err := auth.RefreshTokens(context.Background(), "old-refresh", server.URL)
if err != nil {
t.Fatalf("RefreshTokens() error = %v", err)
}
if tokenData.AccessToken != "new-access" {
t.Fatalf("access token = %q, want new-access", tokenData.AccessToken)
}
if gotForm.Get("grant_type") != "refresh_token" {
t.Fatalf("grant_type = %q, want refresh_token", gotForm.Get("grant_type"))
}
if gotForm.Get("client_id") != ClientID {
t.Fatalf("client_id = %q, want %q", gotForm.Get("client_id"), ClientID)
}
if gotForm.Get("refresh_token") != "old-refresh" {
t.Fatalf("refresh_token = %q, want old-refresh", gotForm.Get("refresh_token"))
}
}
+2 -1
View File
@@ -6,7 +6,7 @@ import (
// newAuthManager creates a new authentication manager instance with all supported // newAuthManager creates a new authentication manager instance with all supported
// authenticators and a file-based token store. It initializes authenticators for // authenticators and a file-based token store. It initializes authenticators for
// Gemini, Codex, Claude, Antigravity, and Kimi providers. // Gemini, Codex, Claude, Antigravity, Kimi, and xAI providers.
// //
// Returns: // Returns:
// - *sdkAuth.Manager: A configured authentication manager instance // - *sdkAuth.Manager: A configured authentication manager instance
@@ -18,6 +18,7 @@ func newAuthManager() *sdkAuth.Manager {
sdkAuth.NewClaudeAuthenticator(), sdkAuth.NewClaudeAuthenticator(),
sdkAuth.NewAntigravityAuthenticator(), sdkAuth.NewAntigravityAuthenticator(),
sdkAuth.NewKimiAuthenticator(), sdkAuth.NewKimiAuthenticator(),
sdkAuth.NewXAIAuthenticator(),
) )
return manager return manager
} }
+44
View File
@@ -0,0 +1,44 @@
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoXAILogin triggers the OAuth flow for the xAI provider and saves tokens.
func DoXAILogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
record, savedPath, err := manager.Login(context.Background(), "xai", cfg, authOpts)
if err != nil {
log.Errorf("xAI authentication failed: %v", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
if record != nil && record.Label != "" {
fmt.Printf("Authenticated as %s\n", record.Label)
}
fmt.Println("xAI authentication successful!")
}
+1 -1
View File
@@ -137,7 +137,7 @@ type Config struct {
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels. // OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
// These aliases affect both model listing and model routing for supported channels: // These aliases affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi. // gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai.
// //
// NOTE: This does not apply to existing per-credential model alias features under: // NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode. // gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
+10
View File
@@ -21,6 +21,7 @@ type staticModelsJSON struct {
CodexPro []*ModelInfo `json:"codex-pro"` CodexPro []*ModelInfo `json:"codex-pro"`
Kimi []*ModelInfo `json:"kimi"` Kimi []*ModelInfo `json:"kimi"`
Antigravity []*ModelInfo `json:"antigravity"` Antigravity []*ModelInfo `json:"antigravity"`
XAI []*ModelInfo `json:"xai"`
} }
// GetClaudeModels returns the standard Claude model definitions. // GetClaudeModels returns the standard Claude model definitions.
@@ -78,6 +79,11 @@ func GetAntigravityModels() []*ModelInfo {
return cloneModelInfos(getModels().Antigravity) return cloneModelInfos(getModels().Antigravity)
} }
// GetXAIModels returns the standard xAI Grok model definitions.
func GetXAIModels() []*ModelInfo {
return cloneModelInfos(getModels().XAI)
}
// WithCodexBuiltins injects hard-coded Codex-only model definitions that should // WithCodexBuiltins injects hard-coded Codex-only model definitions that should
// not depend on remote models.json updates. Built-ins replace any matching IDs // not depend on remote models.json updates. Built-ins replace any matching IDs
// already present in the provided slice. // already present in the provided slice.
@@ -167,6 +173,7 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
// - codex // - codex
// - kimi // - kimi
// - antigravity // - antigravity
// - xai
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
key := strings.ToLower(strings.TrimSpace(channel)) key := strings.ToLower(strings.TrimSpace(channel))
switch key { switch key {
@@ -186,6 +193,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetKimiModels() return GetKimiModels()
case "antigravity": case "antigravity":
return GetAntigravityModels() return GetAntigravityModels()
case "xai", "x-ai", "grok":
return GetXAIModels()
default: default:
return nil return nil
} }
@@ -208,6 +217,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
data.CodexPro, data.CodexPro,
data.Kimi, data.Kimi,
data.Antigravity, data.Antigravity,
data.XAI,
} }
for _, models := range allModels { for _, models := range allModels {
for _, m := range models { for _, m := range models {
+2
View File
@@ -215,6 +215,7 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
{"codex", oldData.CodexPro, newData.CodexPro}, {"codex", oldData.CodexPro, newData.CodexPro},
{"kimi", oldData.Kimi, newData.Kimi}, {"kimi", oldData.Kimi, newData.Kimi},
{"antigravity", oldData.Antigravity, newData.Antigravity}, {"antigravity", oldData.Antigravity, newData.Antigravity},
{"xai", oldData.XAI, newData.XAI},
} }
seen := make(map[string]bool, len(sections)) seen := make(map[string]bool, len(sections))
@@ -335,6 +336,7 @@ func validateModelsCatalog(data *staticModelsJSON) error {
{name: "codex-pro", models: data.CodexPro}, {name: "codex-pro", models: data.CodexPro},
{name: "kimi", models: data.Kimi}, {name: "kimi", models: data.Kimi},
{name: "antigravity", models: data.Antigravity}, {name: "antigravity", models: data.Antigravity},
{name: "xai", models: data.XAI},
} }
for _, section := range requiredSections { for _, section := range requiredSections {
+106 -1
View File
@@ -46,7 +46,8 @@
"levels": [ "levels": [
"low", "low",
"medium", "medium",
"high" "high",
"xhigh"
] ]
} }
}, },
@@ -2064,5 +2065,109 @@
] ]
} }
} }
],
"xai": [
{
"id": "grok-4.3",
"object": "model",
"created": 1775606400,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 4.3",
"name": "grok-4.3",
"description": "xAI Grok 4.3 model for the Responses API.",
"context_length": 1000000,
"max_completion_tokens": 65536,
"thinking": {
"zero_allowed": true,
"levels": [
"none",
"low",
"medium",
"high"
]
}
},
{
"id": "grok-4.20-0309-reasoning",
"object": "model",
"created": 1773014400,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 4.20 0309 Reasoning",
"name": "grok-4.20-0309-reasoning",
"description": "xAI Grok 4.20 0309 reasoning model for the Responses API.",
"context_length": 2000000,
"max_completion_tokens": 65536
},
{
"id": "grok-4.20-0309-non-reasoning",
"object": "model",
"created": 1773014400,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 4.20 0309 Non Reasoning",
"name": "grok-4.20-0309-non-reasoning",
"description": "xAI Grok 4.20 0309 non-reasoning model for the Responses API.",
"context_length": 2000000,
"max_completion_tokens": 65536
},
{
"id": "grok-4.20-multi-agent-0309",
"object": "model",
"created": 1773014400,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 4.20 Multi Agent 0309",
"name": "grok-4.20-multi-agent-0309",
"description": "xAI Grok 4.20 multi-agent model for the Responses API.",
"context_length": 2000000,
"max_completion_tokens": 65536,
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "grok-3-mini",
"object": "model",
"created": 1740960000,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 3 Mini",
"name": "grok-3-mini",
"description": "xAI Grok 3 Mini model for the Responses API.",
"context_length": 131072,
"max_completion_tokens": 32768,
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "grok-3-mini-fast",
"object": "model",
"created": 1740960000,
"owned_by": "xai",
"type": "xai",
"display_name": "Grok 3 Mini Fast",
"name": "grok-3-mini-fast",
"description": "xAI Grok 3 Mini Fast model for the Responses API.",
"context_length": 131072,
"max_completion_tokens": 32768,
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
}
] ]
} }
+570
View File
@@ -0,0 +1,570 @@
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"sort"
"strings"
"time"
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps"
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/tiktoken-go/tokenizer"
)
var xaiDataTag = []byte("data:")
// XAIExecutor is a stateless executor for xAI Grok's Responses API.
type XAIExecutor struct {
cfg *config.Config
}
// NewXAIExecutor creates a new xAI executor.
func NewXAIExecutor(cfg *config.Config) *XAIExecutor {
return &XAIExecutor{cfg: cfg}
}
// Identifier returns the provider identifier.
func (e *XAIExecutor) Identifier() string {
return "xai"
}
// PrepareRequest injects xAI credentials into the outgoing HTTP request.
func (e *XAIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
token, _ := xaiCreds(auth)
if strings.TrimSpace(token) != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
// HttpRequest injects xAI credentials into the request and executes it.
func (e *XAIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("xai executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
return nil, errPrepare
}
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
token, baseURL := xaiCreds(auth)
if baseURL == "" {
baseURL = xaiauth.DefaultAPIBaseURL
}
prepared, err := e.prepareResponsesRequest(ctx, req, opts, true)
if err != nil {
return resp, err
}
reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body))
if err != nil {
return resp, err
}
applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID)
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("xai executor: close response body error: %v", errClose)
}
}()
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return resp, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
outputItemsByIndex := make(map[int64][]byte)
var outputItemsFallback [][]byte
for _, line := range bytes.Split(data, []byte("\n")) {
if !bytes.HasPrefix(line, xaiDataTag) {
continue
}
eventData := bytes.TrimSpace(line[len(xaiDataTag):])
switch gjson.GetBytes(eventData, "type").String() {
case "response.output_item.done":
xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
case "response.completed":
if detail, ok := helps.ParseCodexUsage(eventData); ok {
reporter.Publish(ctx, detail)
}
completedData := xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
var param any
out := sdktranslator.TranslateNonStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, completedData, &param)
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
}
}
return resp, statusErr{code: http.StatusRequestTimeout, msg: "xai stream error: stream disconnected before response.completed"}
}
func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
token, baseURL := xaiCreds(auth)
if baseURL == "" {
baseURL = xaiauth.DefaultAPIBaseURL
}
prepared, err := e.prepareResponsesRequest(ctx, req, opts, true)
if err != nil {
return nil, err
}
reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth)
defer reporter.TrackFailure(ctx, &err)
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body))
if err != nil {
return nil, err
}
applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID)
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("xai executor: close response body error: %v", errClose)
}
if errRead != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
return nil, errRead
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return nil, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("xai executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800)
var param any
outputItemsByIndex := make(map[int64][]byte)
var outputItemsFallback [][]byte
for scanner.Scan() {
line := scanner.Bytes()
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
translatedLine := bytes.Clone(line)
if bytes.HasPrefix(line, xaiDataTag) {
eventData := bytes.TrimSpace(line[len(xaiDataTag):])
switch gjson.GetBytes(eventData, "type").String() {
case "response.output_item.done":
xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
case "response.completed":
if detail, ok := helps.ParseCodexUsage(eventData); ok {
reporter.Publish(ctx, detail)
}
eventData = xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
translatedLine = append([]byte("data: "), eventData...)
}
}
chunks := sdktranslator.TranslateStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, translatedLine, &param)
for i := range chunks {
select {
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
case <-ctx.Done():
return
}
}
}
if errScan := scanner.Err(); errScan != nil {
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
reporter.PublishFailure(ctx, errScan)
select {
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
case <-ctx.Done():
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens estimates token count for xAI Responses requests.
func (e *XAIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
prepared, err := e.prepareResponsesRequest(ctx, req, opts, false)
if err != nil {
return cliproxyexecutor.Response{}, err
}
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: tokenizer init failed: %w", err)
}
count, err := enc.Count(string(prepared.body))
if err != nil {
return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: token counting failed: %w", err)
}
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
translated := sdktranslator.TranslateTokenCount(ctx, prepared.to, prepared.from, int64(count), []byte(usageJSON))
return cliproxyexecutor.Response{Payload: translated}, nil
}
// Refresh refreshes xAI OAuth credentials using the stored refresh token.
func (e *XAIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
log.Debugf("xai executor: refresh called")
if refreshed, handled, err := helps.RefreshAuthViaHome(ctx, e.cfg, auth); handled {
return refreshed, err
}
if auth == nil {
return nil, statusErr{code: http.StatusInternalServerError, msg: "xai executor: auth is nil"}
}
refreshToken := xaiMetadataString(auth.Metadata, "refresh_token")
if refreshToken == "" {
return auth, nil
}
tokenEndpoint := xaiMetadataString(auth.Metadata, "token_endpoint")
svc := xaiauth.NewXAIAuthWithProxyURL(e.cfg, auth.ProxyURL)
td, err := svc.RefreshTokens(ctx, refreshToken, tokenEndpoint)
if err != nil {
return nil, err
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
auth.Metadata["type"] = "xai"
auth.Metadata["auth_kind"] = "oauth"
auth.Metadata["access_token"] = td.AccessToken
if td.RefreshToken != "" {
auth.Metadata["refresh_token"] = td.RefreshToken
}
if td.IDToken != "" {
auth.Metadata["id_token"] = td.IDToken
}
if td.TokenType != "" {
auth.Metadata["token_type"] = td.TokenType
}
if td.ExpiresIn > 0 {
auth.Metadata["expires_in"] = td.ExpiresIn
}
if td.Expire != "" {
auth.Metadata["expired"] = td.Expire
}
if td.Email != "" {
auth.Metadata["email"] = td.Email
}
if td.Subject != "" {
auth.Metadata["sub"] = td.Subject
}
if tokenEndpoint != "" {
auth.Metadata["token_endpoint"] = tokenEndpoint
}
if xaiMetadataString(auth.Metadata, "base_url") == "" {
auth.Metadata["base_url"] = xaiauth.DefaultAPIBaseURL
}
auth.Metadata["last_refresh"] = time.Now().UTC().Format(time.RFC3339)
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
auth.Attributes["auth_kind"] = "oauth"
if strings.TrimSpace(auth.Attributes["base_url"]) == "" {
auth.Attributes["base_url"] = xaiauth.DefaultAPIBaseURL
}
return auth, nil
}
type xaiPreparedRequest struct {
baseModel string
from sdktranslator.Format
to sdktranslator.Format
originalPayload []byte
body []byte
sessionID string
}
func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) (*xaiPreparedRequest, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := bytes.Clone(originalPayloadSource)
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
var err error
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
requestPath := helps.PayloadRequestPath(opts)
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", stream)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.DeleteBytes(body, "stream_options")
body = normalizeCodexInstructions(body)
body = sanitizeXAIResponsesBody(body, baseModel)
sessionID := xaiExecutionSessionID(req, opts)
if sessionID != "" {
body, _ = sjson.SetBytes(body, "prompt_cache_key", sessionID)
}
return &xaiPreparedRequest{
baseModel: baseModel,
from: from,
to: to,
originalPayload: originalPayload,
body: body,
sessionID: sessionID,
}, nil
}
func (e *XAIExecutor) recordXAIRequest(ctx context.Context, auth *cliproxyauth.Auth, url string, headers http.Header, body []byte) {
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: headers,
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
}
func xaiCreds(auth *cliproxyauth.Auth) (token, baseURL string) {
if auth == nil {
return "", ""
}
if auth.Attributes != nil {
token = strings.TrimSpace(auth.Attributes["api_key"])
baseURL = strings.TrimSpace(auth.Attributes["base_url"])
}
if auth.Metadata != nil {
if token == "" {
token = xaiMetadataString(auth.Metadata, "access_token")
}
if baseURL == "" {
baseURL = xaiMetadataString(auth.Metadata, "base_url")
}
}
return token, baseURL
}
func applyXAIHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, sessionID string) {
r.Header.Set("Content-Type", "application/json")
if strings.TrimSpace(token) != "" {
r.Header.Set("Authorization", "Bearer "+token)
}
if stream {
r.Header.Set("Accept", "text/event-stream")
} else {
r.Header.Set("Accept", "application/json")
}
r.Header.Set("Connection", "Keep-Alive")
if sessionID != "" {
r.Header.Set("x-grok-conv-id", sessionID)
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(r, attrs)
}
func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) string {
if value := xaiMetadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" {
return value
}
if value := xaiMetadataString(req.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" {
return value
}
if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
return strings.TrimSpace(promptCacheKey.String())
}
return ""
}
func xaiMetadataString(meta map[string]any, key string) string {
if len(meta) == 0 || key == "" {
return ""
}
value, ok := meta[key]
if !ok || value == nil {
return ""
}
switch typed := value.(type) {
case string:
return strings.TrimSpace(typed)
case fmt.Stringer:
return strings.TrimSpace(typed.String())
default:
return strings.TrimSpace(fmt.Sprint(typed))
}
}
func sanitizeXAIResponsesBody(body []byte, model string) []byte {
body = removeXAIEncryptedReasoningInclude(body)
if !xaiSupportsReasoningEffort(model) {
body, _ = sjson.DeleteBytes(body, "reasoning")
}
return body
}
func removeXAIEncryptedReasoningInclude(body []byte) []byte {
include := gjson.GetBytes(body, "include")
if !include.Exists() || !include.IsArray() {
return body
}
kept := make([]string, 0, len(include.Array()))
for _, item := range include.Array() {
value := strings.TrimSpace(item.String())
if value == "" || value == "reasoning.encrypted_content" {
continue
}
kept = append(kept, value)
}
body, _ = sjson.SetBytes(body, "include", kept)
return body
}
func xaiSupportsReasoningEffort(model string) bool {
name := strings.ToLower(strings.TrimSpace(thinking.ParseSuffix(model).ModelName))
if idx := strings.LastIndex(name, "/"); idx >= 0 {
name = name[idx+1:]
}
switch {
case strings.HasPrefix(name, "grok-3-mini"):
return true
case strings.HasPrefix(name, "grok-4.20-multi-agent"):
return true
case strings.HasPrefix(name, "grok-4.3"):
return true
default:
return false
}
}
func xaiCollectOutputItemDone(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback *[][]byte) {
itemResult := gjson.GetBytes(eventData, "item")
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
return
}
outputIndexResult := gjson.GetBytes(eventData, "output_index")
if outputIndexResult.Exists() {
outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw)
return
}
*outputItemsFallback = append(*outputItemsFallback, []byte(itemResult.Raw))
}
func xaiPatchCompletedOutput(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback [][]byte) []byte {
outputResult := gjson.GetBytes(eventData, "response.output")
shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0)
if !shouldPatchOutput {
return eventData
}
indexes := make([]int64, 0, len(outputItemsByIndex))
for idx := range outputItemsByIndex {
indexes = append(indexes, idx)
}
sort.Slice(indexes, func(i, j int) bool {
return indexes[i] < indexes[j]
})
outputArray := []byte("[]")
var buf bytes.Buffer
buf.WriteByte('[')
wrote := false
for _, idx := range indexes {
if wrote {
buf.WriteByte(',')
}
buf.Write(outputItemsByIndex[idx])
wrote = true
}
for _, item := range outputItemsFallback {
if wrote {
buf.WriteByte(',')
}
buf.Write(item)
wrote = true
}
buf.WriteByte(']')
if wrote {
outputArray = buf.Bytes()
}
patched, _ := sjson.SetRawBytes(eventData, "response.output", outputArray)
return patched
}
@@ -0,0 +1,138 @@
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
"github.com/tidwall/gjson"
)
func TestXAIExecutorExecuteShapesResponsesRequest(t *testing.T) {
var gotPath string
var gotAuth string
var gotGrokConvID string
var gotOriginator string
var gotAccountID string
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuth = r.Header.Get("Authorization")
gotGrokConvID = r.Header.Get("x-grok-conv-id")
gotOriginator = r.Header.Get("Originator")
gotAccountID = r.Header.Get("Chatgpt-Account-Id")
var errRead error
gotBody, errRead = io.ReadAll(r.Body)
if errRead != nil {
t.Fatalf("read body: %v", errRead)
}
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4.3\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}],\"usage\":{\"input_tokens\":1,\"output_tokens\":1,\"total_tokens\":2}}}\n\n"))
}))
defer server.Close()
exec := NewXAIExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
ID: "xai-auth",
Provider: "xai",
Attributes: map[string]string{
"base_url": server.URL,
"auth_kind": "oauth",
},
Metadata: map[string]any{
"access_token": "xai-token",
"email": "user@example.com",
},
}
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "grok-4.3",
Payload: []byte(`{"model":"grok-4.3","input":"hello","include":["reasoning.encrypted_content"],"reasoning":{"effort":"high"}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatOpenAIResponse,
Stream: false,
Metadata: map[string]any{
cliproxyexecutor.ExecutionSessionMetadataKey: "conv-xai-1",
},
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gotPath != "/responses" {
t.Fatalf("path = %q, want /responses", gotPath)
}
if gotAuth != "Bearer xai-token" {
t.Fatalf("Authorization = %q, want Bearer xai-token", gotAuth)
}
if gotGrokConvID != "conv-xai-1" {
t.Fatalf("x-grok-conv-id = %q, want conv-xai-1", gotGrokConvID)
}
if gotOriginator != "" {
t.Fatalf("Originator = %q, want empty", gotOriginator)
}
if gotAccountID != "" {
t.Fatalf("Chatgpt-Account-Id = %q, want empty", gotAccountID)
}
if gjson.GetBytes(gotBody, "prompt_cache_key").String() != "conv-xai-1" {
t.Fatalf("prompt_cache_key missing from body: %s", string(gotBody))
}
if !gjson.GetBytes(gotBody, "stream").Bool() {
t.Fatalf("stream = false, want true; body=%s", string(gotBody))
}
if gjson.GetBytes(gotBody, "reasoning.effort").String() != "high" {
t.Fatalf("reasoning.effort = %q, want high; body=%s", gjson.GetBytes(gotBody, "reasoning.effort").String(), string(gotBody))
}
for _, include := range gjson.GetBytes(gotBody, "include").Array() {
if include.String() == "reasoning.encrypted_content" {
t.Fatalf("xai request must not ask for encrypted reasoning content: %s", string(gotBody))
}
}
}
func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) {
var gotBody []byte
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var errRead error
gotBody, errRead = io.ReadAll(r.Body)
if errRead != nil {
t.Fatalf("read body: %v", errRead)
}
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}]}}\n\n"))
}))
defer server.Close()
exec := NewXAIExecutor(&config.Config{})
auth := &cliproxyauth.Auth{
Provider: "xai",
Attributes: map[string]string{
"base_url": server.URL,
"auth_kind": "oauth",
},
Metadata: map[string]any{"access_token": "xai-token"},
}
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "grok-4",
Payload: []byte(`{"model":"grok-4","input":"hello","reasoning":{"effort":"high"}}`),
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FormatOpenAIResponse,
Stream: false,
})
if err != nil {
t.Fatalf("Execute() error = %v", err)
}
if gjson.GetBytes(gotBody, "reasoning").Exists() {
t.Fatalf("unsupported xAI model must omit reasoning key: %s", string(gotBody))
}
}
+3
View File
@@ -24,6 +24,7 @@ var oauthProviders = []oauthProvider{
{"Codex (OpenAI)", "codex-auth-url", "🟩"}, {"Codex (OpenAI)", "codex-auth-url", "🟩"},
{"Antigravity", "antigravity-auth-url", "🟪"}, {"Antigravity", "antigravity-auth-url", "🟪"},
{"Kimi", "kimi-auth-url", "🟫"}, {"Kimi", "kimi-auth-url", "🟫"},
{"xAI", "xai-auth-url", "⬛"},
} }
// oauthTabModel handles OAuth login flows. // oauthTabModel handles OAuth login flows.
@@ -280,6 +281,8 @@ func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
providerKey = "antigravity" providerKey = "antigravity"
case "kimi-auth-url": case "kimi-auth-url":
providerKey = "kimi" providerKey = "kimi"
case "xai-auth-url":
providerKey = "xai"
} }
break break
} }
+1
View File
@@ -13,6 +13,7 @@ func init() {
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() }) registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() }) registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() }) registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
registerRefreshLead("xai", func() Authenticator { return NewXAIAuthenticator() })
} }
func registerRefreshLead(provider string, factory func() Authenticator) { func registerRefreshLead(provider string, factory func() Authenticator) {
+282
View File
@@ -0,0 +1,282 @@
package auth
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"time"
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
"github.com/router-for-me/CLIProxyAPI/v7/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// XAIAuthenticator implements the xAI Grok OAuth loopback flow.
type XAIAuthenticator struct{}
// NewXAIAuthenticator constructs a new xAI authenticator.
func NewXAIAuthenticator() Authenticator {
return &XAIAuthenticator{}
}
// Provider returns the provider key for xAI.
func (XAIAuthenticator) Provider() string {
return "xai"
}
// RefreshLead instructs the manager to refresh before token expiry.
func (XAIAuthenticator) RefreshLead() *time.Duration {
lead := xaiauth.RefreshLead()
return &lead
}
// Login launches a local OAuth flow to obtain xAI tokens and persists them.
func (a XAIAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("cliproxy auth: configuration is required")
}
if ctx == nil {
ctx = context.Background()
}
if opts == nil {
opts = &LoginOptions{}
}
callbackPort := xaiauth.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
pkceCodes, err := xaiauth.GeneratePKCECodes()
if err != nil {
return nil, fmt.Errorf("xai pkce generation failed: %w", err)
}
state, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("xai state generation failed: %w", err)
}
nonce, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("xai nonce generation failed: %w", err)
}
authSvc := xaiauth.NewXAIAuth(cfg)
discovery, err := authSvc.Discover(ctx)
if err != nil {
return nil, err
}
srv, port, callbackCh, errServer := startXAICallbackServer(callbackPort)
if errServer != nil {
return nil, fmt.Errorf("xai: failed to start callback server: %w", errServer)
}
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if errShutdown := srv.Shutdown(shutdownCtx); errShutdown != nil {
log.Warnf("xai callback server shutdown error: %v", errShutdown)
}
}()
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, port, xaiauth.RedirectPath)
authURL, err := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
RedirectURI: redirectURI,
CodeChallenge: pkceCodes.CodeChallenge,
State: state,
Nonce: nonce,
})
if err != nil {
return nil, err
}
if !opts.NoBrowser {
fmt.Println("Opening browser for xAI authentication")
if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually")
util.PrintSSHTunnelInstructions(port)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if errOpen := browser.OpenURL(authURL); errOpen != nil {
log.Warnf("Failed to open browser automatically: %v", errOpen)
util.PrintSSHTunnelInstructions(port)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
} else {
util.PrintSSHTunnelInstructions(port)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
fmt.Println("Waiting for xAI authentication callback...")
var result callbackResult
timeoutTimer := time.NewTimer(5 * time.Minute)
defer timeoutTimer.Stop()
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
var manualInputCh <-chan string
var manualInputErrCh <-chan error
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
default:
}
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the xAI callback Token (or press Enter to keep waiting): ")
continue
case input := <-manualInputCh:
manualInputCh = nil
manualInputErrCh = nil
manualResult, ok, errParse := parseXAIManualCallbackToken(input, state)
if errParse != nil {
return nil, errParse
}
if !ok {
continue
}
result = manualResult
break waitForCallback
case errManual := <-manualInputErrCh:
return nil, errManual
case <-timeoutTimer.C:
return nil, fmt.Errorf("xai: authentication timed out")
}
}
if result.Error != "" {
return nil, fmt.Errorf("xai: authentication failed: %s", result.Error)
}
if result.State != state {
return nil, fmt.Errorf("xai: invalid state")
}
if result.Code == "" {
return nil, fmt.Errorf("xai: missing authorization code")
}
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, result.Code, redirectURI, pkceCodes, discovery.TokenEndpoint)
if errExchange != nil {
return nil, fmt.Errorf("xai: token exchange failed: %w", errExchange)
}
tokenStorage := authSvc.CreateTokenStorage(bundle)
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
return nil, fmt.Errorf("xai token storage missing access token")
}
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
label := strings.TrimSpace(tokenStorage.Email)
if label == "" {
label = "xAI"
}
metadata := map[string]any{
"type": "xai",
"access_token": tokenStorage.AccessToken,
"refresh_token": tokenStorage.RefreshToken,
"id_token": tokenStorage.IDToken,
"token_type": tokenStorage.TokenType,
"expires_in": tokenStorage.ExpiresIn,
"expired": tokenStorage.Expire,
"last_refresh": tokenStorage.LastRefresh,
"base_url": tokenStorage.BaseURL,
"redirect_uri": tokenStorage.RedirectURI,
"token_endpoint": tokenStorage.TokenEndpoint,
"auth_kind": "oauth",
}
if tokenStorage.Email != "" {
metadata["email"] = tokenStorage.Email
}
if tokenStorage.Subject != "" {
metadata["sub"] = tokenStorage.Subject
}
fmt.Println("xAI authentication successful")
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Label: label,
Storage: tokenStorage,
Metadata: metadata,
Attributes: map[string]string{
"auth_kind": "oauth",
"base_url": tokenStorage.BaseURL,
},
}, nil
}
func parseXAIManualCallbackToken(input string, state string) (callbackResult, bool, error) {
token := strings.TrimSpace(input)
if token == "" {
return callbackResult{}, false, nil
}
if strings.Contains(token, "://") || strings.Contains(token, "?") || strings.Contains(token, "code=") {
return callbackResult{}, false, fmt.Errorf("xai: paste only the callback token")
}
return callbackResult{Code: token, State: state}, true, nil
}
func startXAICallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
if port <= 0 {
port = xaiauth.CallbackPort
}
addr := fmt.Sprintf("%s:%d", xaiauth.RedirectHost, port)
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, 0, nil, err
}
port = listener.Addr().(*net.TCPAddr).Port
resultCh := make(chan callbackResult, 1)
mux := http.NewServeMux()
mux.HandleFunc(xaiauth.RedirectPath, func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
result := callbackResult{
Code: strings.TrimSpace(q.Get("code")),
Error: strings.TrimSpace(q.Get("error")),
State: strings.TrimSpace(q.Get("state")),
}
resultCh <- result
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if result.Code != "" && result.Error == "" {
_, _ = w.Write([]byte("<h1>Login successful</h1><p>You can close this window.</p>"))
return
}
_, _ = w.Write([]byte("<h1>Login failed</h1><p>Please check the CLI output.</p>"))
})
srv := &http.Server{
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
}
go func() {
if errServe := srv.Serve(listener); errServe != nil && !strings.Contains(errServe.Error(), "Server closed") {
log.Warnf("xai callback server error: %v", errServe)
}
}()
return srv, port, resultCh, nil
}
+37
View File
@@ -0,0 +1,37 @@
package auth
import "testing"
func TestXAIAuthenticatorProviderAndRefreshLead(t *testing.T) {
authenticator := NewXAIAuthenticator()
if authenticator.Provider() != "xai" {
t.Fatalf("Provider() = %q, want xai", authenticator.Provider())
}
lead := authenticator.RefreshLead()
if lead == nil || *lead <= 0 {
t.Fatalf("RefreshLead() = %v, want positive duration", lead)
}
}
func TestParseXAIManualCallbackTokenAcceptsRawCode(t *testing.T) {
result, ok, err := parseXAIManualCallbackToken(" V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg ", "state-1")
if err != nil {
t.Fatalf("parseXAIManualCallbackToken() error = %v", err)
}
if !ok {
t.Fatal("parseXAIManualCallbackToken() ok = false, want true")
}
if result.Code != "V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg" {
t.Fatalf("Code = %q", result.Code)
}
if result.State != "state-1" {
t.Fatalf("State = %q, want state-1", result.State)
}
}
func TestParseXAIManualCallbackTokenRejectsCallbackURL(t *testing.T) {
_, _, err := parseXAIManualCallbackToken("http://127.0.0.1:56121/callback?state=state-1&code=token-1", "state-1")
if err == nil {
t.Fatal("parseXAIManualCallbackToken() error = nil, want error")
}
}
+6
View File
@@ -116,6 +116,7 @@ func newDefaultAuthManager() *sdkAuth.Manager {
sdkAuth.NewGeminiAuthenticator(), sdkAuth.NewGeminiAuthenticator(),
sdkAuth.NewCodexAuthenticator(), sdkAuth.NewCodexAuthenticator(),
sdkAuth.NewClaudeAuthenticator(), sdkAuth.NewClaudeAuthenticator(),
sdkAuth.NewXAIAuthenticator(),
) )
} }
@@ -433,6 +434,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg)) s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
case "kimi": case "kimi":
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg)) s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
case "xai":
s.coreManager.RegisterExecutor(executor.NewXAIExecutor(s.cfg))
default: default:
providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
if providerKey == "" { if providerKey == "" {
@@ -1156,6 +1159,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
case "kimi": case "kimi":
models = registry.GetKimiModels() models = registry.GetKimiModels()
models = applyExcludedModels(models, excluded) models = applyExcludedModels(models, excluded)
case "xai":
models = registry.GetXAIModels()
models = applyExcludedModels(models, excluded)
default: default:
// Handle OpenAI-compatibility providers by name using config // Handle OpenAI-compatibility providers by name using config
if s.cfg != nil { if s.cfg != nil {
@@ -0,0 +1,36 @@
package cliproxy
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
)
func TestEnsureExecutorsForAuth_XAIBindsIndependentExecutor(t *testing.T) {
service := &Service{
cfg: &config.Config{},
coreManager: coreauth.NewManager(nil, nil, nil),
}
auth := &coreauth.Auth{
ID: "xai-auth-1",
Provider: "xai",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "oauth",
},
}
service.ensureExecutorsForAuth(auth)
resolved, ok := service.coreManager.Executor("xai")
if !ok || resolved == nil {
t.Fatal("expected xai executor after bind")
}
if _, isXAI := resolved.(*executor.XAIExecutor); !isXAI {
t.Fatalf("executor type = %T, want *executor.XAIExecutor", resolved)
}
if _, isCodex := resolved.(*executor.CodexAutoExecutor); isCodex {
t.Fatal("xai must not bind the codex auto executor")
}
}