feat(auth): add OAuth2 support for xAI with PKCE and token persistence
- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support. - Added logic for token exchange, refresh, and persistent storage in JSON format. - Created `xai` package with helpers for OAuth discovery, API token handling, and URL building. - Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests. - Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/kimi"
|
||||
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
@@ -2132,6 +2133,185 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestXAIToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing xAI authentication...")
|
||||
|
||||
pkceCodes, errPKCE := xaiauth.GeneratePKCECodes()
|
||||
if errPKCE != nil {
|
||||
log.Errorf("Failed to generate xAI PKCE codes: %v", errPKCE)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
|
||||
return
|
||||
}
|
||||
|
||||
state, errState := misc.GenerateRandomState()
|
||||
if errState != nil {
|
||||
log.Errorf("Failed to generate state parameter: %v", errState)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
nonce, errNonce := misc.GenerateRandomState()
|
||||
if errNonce != nil {
|
||||
log.Errorf("Failed to generate nonce parameter: %v", errNonce)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate nonce parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
authSvc := xaiauth.NewXAIAuth(h.cfg)
|
||||
discovery, errDiscover := authSvc.Discover(ctx)
|
||||
if errDiscover != nil {
|
||||
log.Errorf("Failed to discover xAI OAuth endpoints: %v", errDiscover)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to discover oauth endpoints"})
|
||||
return
|
||||
}
|
||||
|
||||
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, xaiauth.CallbackPort, xaiauth.RedirectPath)
|
||||
authURL, errAuthURL := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
|
||||
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
|
||||
RedirectURI: redirectURI,
|
||||
CodeChallenge: pkceCodes.CodeChallenge,
|
||||
State: state,
|
||||
Nonce: nonce,
|
||||
})
|
||||
if errAuthURL != nil {
|
||||
log.Errorf("Failed to generate xAI authorization URL: %v", errAuthURL)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "xai")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/xai/callback")
|
||||
if errTarget != nil {
|
||||
log.WithError(errTarget).Error("failed to compute xai callback target")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(xaiauth.CallbackPort, "xai", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start xai callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarderInstance(xaiauth.CallbackPort, forwarder)
|
||||
}
|
||||
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-xai-%s.oauth", state))
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var authCode string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "xai") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("xai oauth flow timed out")
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
||||
var payload map[string]string
|
||||
_ = json.Unmarshal(data, &payload)
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||
log.Errorf("xAI authentication failed: %s", errStr)
|
||||
SetOAuthSessionError(state, "Authentication failed: "+errStr)
|
||||
return
|
||||
}
|
||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
||||
log.Errorf("xAI authentication failed: state mismatch")
|
||||
SetOAuthSessionError(state, "Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
authCode = strings.TrimSpace(payload["code"])
|
||||
if authCode == "" {
|
||||
log.Error("xAI authentication failed: code not found")
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
|
||||
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI, pkceCodes, discovery.TokenEndpoint)
|
||||
if errExchange != nil {
|
||||
log.Errorf("Failed to exchange xAI token: %v", errExchange)
|
||||
SetOAuthSessionError(state, oauthSessionErrorWithCause("Failed to exchange authorization code for tokens", errExchange))
|
||||
return
|
||||
}
|
||||
|
||||
tokenStorage := authSvc.CreateTokenStorage(bundle)
|
||||
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
|
||||
log.Error("xAI token exchange returned empty access token")
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
|
||||
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
|
||||
label := strings.TrimSpace(tokenStorage.Email)
|
||||
if label == "" {
|
||||
label = "xAI"
|
||||
}
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "xai",
|
||||
"access_token": tokenStorage.AccessToken,
|
||||
"refresh_token": tokenStorage.RefreshToken,
|
||||
"id_token": tokenStorage.IDToken,
|
||||
"token_type": tokenStorage.TokenType,
|
||||
"expires_in": tokenStorage.ExpiresIn,
|
||||
"expired": tokenStorage.Expire,
|
||||
"last_refresh": tokenStorage.LastRefresh,
|
||||
"base_url": tokenStorage.BaseURL,
|
||||
"redirect_uri": tokenStorage.RedirectURI,
|
||||
"token_endpoint": tokenStorage.TokenEndpoint,
|
||||
"auth_kind": "oauth",
|
||||
}
|
||||
if tokenStorage.Email != "" {
|
||||
metadata["email"] = tokenStorage.Email
|
||||
}
|
||||
if tokenStorage.Subject != "" {
|
||||
metadata["sub"] = tokenStorage.Subject
|
||||
}
|
||||
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "xai",
|
||||
FileName: fileName,
|
||||
Label: label,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
Attributes: map[string]string{
|
||||
"auth_kind": "oauth",
|
||||
"base_url": tokenStorage.BaseURL,
|
||||
},
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save xAI token to file: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("xai")
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use xAI services through this CLI")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
@@ -242,6 +242,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
return "gemini", nil
|
||||
case "antigravity", "anti-gravity":
|
||||
return "antigravity", nil
|
||||
case "xai", "x-ai", "x.ai", "grok":
|
||||
return "xai", nil
|
||||
default:
|
||||
return "", errUnsupportedOAuthFlow
|
||||
}
|
||||
|
||||
@@ -484,6 +484,20 @@ func (s *Server) setupRoutes() {
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
s.engine.GET("/xai/callback", func(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "xai", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
})
|
||||
|
||||
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
||||
}
|
||||
|
||||
@@ -685,6 +699,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||
mgmt.GET("/xai-auth-url", s.mgmt.RequestXAIToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GeneratePKCECodes creates a verifier/challenge pair for the OAuth flow.
|
||||
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||
bytes := make([]byte, 96)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return nil, fmt.Errorf("xai pkce: generate verifier: %w", err)
|
||||
}
|
||||
verifier := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes)
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
|
||||
return &PKCECodes{CodeVerifier: verifier, CodeChallenge: challenge}, nil
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TokenStorage stores xAI OAuth credentials on disk.
|
||||
type TokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
ExpiresIn int `json:"expires_in,omitempty"`
|
||||
Expire string `json:"expired,omitempty"`
|
||||
LastRefresh string `json:"last_refresh,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
BaseURL string `json:"base_url,omitempty"`
|
||||
RedirectURI string `json:"redirect_uri,omitempty"`
|
||||
TokenEndpoint string `json:"token_endpoint,omitempty"`
|
||||
AuthKind string `json:"auth_kind,omitempty"`
|
||||
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows the token store to merge status fields before saving.
|
||||
func (ts *TokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile writes xAI credentials to a JSON auth file.
|
||||
func (ts *TokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "xai"
|
||||
ts.AuthKind = "oauth"
|
||||
if errMkdirAll := os.MkdirAll(filepath.Dir(authFilePath), 0o700); errMkdirAll != nil {
|
||||
return fmt.Errorf("xai token storage: create directory: %w", errMkdirAll)
|
||||
}
|
||||
file, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("xai token storage: create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := file.Close(); errClose != nil {
|
||||
log.Errorf("xai token storage: close token file error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("xai token storage: merge metadata: %w", errMerge)
|
||||
}
|
||||
encoder := json.NewEncoder(file)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(data); err != nil {
|
||||
return fmt.Errorf("xai token storage: write token file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialFileName returns the filename used for xAI credentials.
|
||||
func CredentialFileName(email, subject string) string {
|
||||
email = sanitizeFileSegment(email)
|
||||
if email != "" {
|
||||
return fmt.Sprintf("xai-%s.json", email)
|
||||
}
|
||||
subject = sanitizeFileSegment(subject)
|
||||
if subject != "" {
|
||||
return fmt.Sprintf("xai-%s.json", subject)
|
||||
}
|
||||
return fmt.Sprintf("xai-%d.json", time.Now().UnixMilli())
|
||||
}
|
||||
|
||||
func sanitizeFileSegment(value string) string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, r := range value {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
b.WriteRune(r)
|
||||
case r >= 'A' && r <= 'Z':
|
||||
b.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
b.WriteRune(r)
|
||||
case r == '@' || r == '.' || r == '_' || r == '-':
|
||||
b.WriteRune(r)
|
||||
default:
|
||||
b.WriteRune('-')
|
||||
}
|
||||
}
|
||||
return strings.Trim(b.String(), "-")
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
// Package xai provides OAuth2 authentication helpers for xAI Grok.
|
||||
package xai
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
// DefaultAPIBaseURL is the default xAI Responses API base URL.
|
||||
DefaultAPIBaseURL = "https://api.x.ai/v1"
|
||||
// Issuer is xAI's OAuth issuer.
|
||||
Issuer = "https://auth.x.ai"
|
||||
// DiscoveryURL is the OIDC discovery endpoint used to resolve OAuth endpoints.
|
||||
DiscoveryURL = Issuer + "/.well-known/openid-configuration"
|
||||
// ClientID is the public xAI Grok CLI OAuth client ID.
|
||||
ClientID = "b1a00492-073a-47ea-816f-4c329264a828"
|
||||
// Scope is the OAuth scope set required for xAI API access.
|
||||
Scope = "openid profile email offline_access grok-cli:access api:access"
|
||||
// RedirectHost is the loopback host used by xAI OAuth.
|
||||
RedirectHost = "127.0.0.1"
|
||||
// CallbackPort is the preferred loopback callback port.
|
||||
CallbackPort = 56121
|
||||
// RedirectPath is the loopback callback path registered by the xAI client.
|
||||
RedirectPath = "/callback"
|
||||
)
|
||||
|
||||
var refreshLead = 5 * time.Minute
|
||||
|
||||
// RefreshLead returns the refresh lead time for xAI OAuth credentials.
|
||||
func RefreshLead() time.Duration {
|
||||
return refreshLead
|
||||
}
|
||||
|
||||
// PKCECodes holds the PKCE verifier/challenge pair.
|
||||
type PKCECodes struct {
|
||||
CodeVerifier string
|
||||
CodeChallenge string
|
||||
}
|
||||
|
||||
// AuthorizeURLParams contains the values used to build the xAI OAuth URL.
|
||||
type AuthorizeURLParams struct {
|
||||
AuthorizationEndpoint string
|
||||
RedirectURI string
|
||||
CodeChallenge string
|
||||
State string
|
||||
Nonce string
|
||||
}
|
||||
|
||||
// Discovery contains OAuth endpoints resolved from xAI OIDC discovery.
|
||||
type Discovery struct {
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
}
|
||||
|
||||
// TokenData holds xAI OAuth token data.
|
||||
type TokenData struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
TokenType string `json:"token_type,omitempty"`
|
||||
ExpiresIn int `json:"expires_in,omitempty"`
|
||||
Expire string `json:"expired,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Subject string `json:"sub,omitempty"`
|
||||
}
|
||||
|
||||
// AuthBundle aggregates token data and OAuth metadata for persistence.
|
||||
type AuthBundle struct {
|
||||
TokenData TokenData
|
||||
LastRefresh string
|
||||
BaseURL string
|
||||
RedirectURI string
|
||||
TokenEndpoint string
|
||||
}
|
||||
@@ -0,0 +1,304 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// XAIAuth performs xAI OAuth discovery, token exchange, and refresh.
|
||||
type XAIAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewXAIAuth creates an xAI OAuth helper using config proxy settings.
|
||||
func NewXAIAuth(cfg *config.Config) *XAIAuth {
|
||||
return NewXAIAuthWithProxyURL(cfg, "")
|
||||
}
|
||||
|
||||
// NewXAIAuthWithProxyURL creates an xAI OAuth helper with an explicit proxy URL.
|
||||
func NewXAIAuthWithProxyURL(cfg *config.Config, proxyURL string) *XAIAuth {
|
||||
effectiveProxyURL := strings.TrimSpace(proxyURL)
|
||||
var sdkCfg config.SDKConfig
|
||||
if cfg != nil {
|
||||
sdkCfg = cfg.SDKConfig
|
||||
if effectiveProxyURL == "" {
|
||||
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||
}
|
||||
}
|
||||
sdkCfg.ProxyURL = effectiveProxyURL
|
||||
return &XAIAuth{httpClient: util.SetProxy(&sdkCfg, &http.Client{})}
|
||||
}
|
||||
|
||||
// ValidateOAuthEndpoint validates an endpoint returned by xAI discovery.
|
||||
func ValidateOAuthEndpoint(rawURL string, field string) (string, error) {
|
||||
rawURL = strings.TrimSpace(rawURL)
|
||||
if rawURL == "" {
|
||||
return "", fmt.Errorf("xai discovery %s is empty", field)
|
||||
}
|
||||
parsed, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("xai discovery %s is invalid: %w", field, err)
|
||||
}
|
||||
if parsed.Scheme != "https" {
|
||||
return "", fmt.Errorf("xai discovery %s must use https: %q", field, rawURL)
|
||||
}
|
||||
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||
if host != "x.ai" && !strings.HasSuffix(host, ".x.ai") {
|
||||
return "", fmt.Errorf("xai discovery %s host %q is not on x.ai", field, host)
|
||||
}
|
||||
return rawURL, nil
|
||||
}
|
||||
|
||||
// BuildAuthorizeURL builds the browser URL for xAI OAuth.
|
||||
func BuildAuthorizeURL(params AuthorizeURLParams) (string, error) {
|
||||
endpoint, err := ValidateOAuthEndpoint(params.AuthorizationEndpoint, "authorization_endpoint")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimSpace(params.RedirectURI) == "" {
|
||||
return "", fmt.Errorf("xai authorize URL: redirect URI is required")
|
||||
}
|
||||
if strings.TrimSpace(params.CodeChallenge) == "" {
|
||||
return "", fmt.Errorf("xai authorize URL: code challenge is required")
|
||||
}
|
||||
if strings.TrimSpace(params.State) == "" {
|
||||
return "", fmt.Errorf("xai authorize URL: state is required")
|
||||
}
|
||||
if strings.TrimSpace(params.Nonce) == "" {
|
||||
return "", fmt.Errorf("xai authorize URL: nonce is required")
|
||||
}
|
||||
values := url.Values{
|
||||
"response_type": {"code"},
|
||||
"client_id": {ClientID},
|
||||
"redirect_uri": {strings.TrimSpace(params.RedirectURI)},
|
||||
"scope": {Scope},
|
||||
"code_challenge": {strings.TrimSpace(params.CodeChallenge)},
|
||||
"code_challenge_method": {"S256"},
|
||||
"state": {strings.TrimSpace(params.State)},
|
||||
"nonce": {strings.TrimSpace(params.Nonce)},
|
||||
"plan": {"generic"},
|
||||
"referrer": {"cli-proxy-api"},
|
||||
}
|
||||
return endpoint + "?" + values.Encode(), nil
|
||||
}
|
||||
|
||||
// Discover resolves xAI OAuth endpoints through OIDC discovery.
|
||||
func (a *XAIAuth) Discover(ctx context.Context) (*Discovery, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DiscoveryURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai discovery: create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai discovery: request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai discovery: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai discovery: read response: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("xai discovery failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var payload struct {
|
||||
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, fmt.Errorf("xai discovery: parse response: %w", err)
|
||||
}
|
||||
authorizationEndpoint, err := ValidateOAuthEndpoint(payload.AuthorizationEndpoint, "authorization_endpoint")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenEndpoint, err := ValidateOAuthEndpoint(payload.TokenEndpoint, "token_endpoint")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Discovery{AuthorizationEndpoint: authorizationEndpoint, TokenEndpoint: tokenEndpoint}, nil
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokens exchanges an authorization code for xAI OAuth tokens.
|
||||
func (a *XAIAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes, tokenEndpoint string) (*AuthBundle, error) {
|
||||
if pkceCodes == nil {
|
||||
return nil, fmt.Errorf("xai token exchange: PKCE codes are required")
|
||||
}
|
||||
if strings.TrimSpace(code) == "" {
|
||||
return nil, fmt.Errorf("xai token exchange: authorization code is required")
|
||||
}
|
||||
if strings.TrimSpace(redirectURI) == "" {
|
||||
return nil, fmt.Errorf("xai token exchange: redirect URI is required")
|
||||
}
|
||||
if strings.TrimSpace(tokenEndpoint) == "" {
|
||||
discovery, errDiscover := a.Discover(ctx)
|
||||
if errDiscover != nil {
|
||||
return nil, errDiscover
|
||||
}
|
||||
tokenEndpoint = discovery.TokenEndpoint
|
||||
}
|
||||
form := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"code": {strings.TrimSpace(code)},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"client_id": {ClientID},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
tokenData, err := a.postTokenForm(ctx, tokenEndpoint, form)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AuthBundle{
|
||||
TokenData: *tokenData,
|
||||
LastRefresh: time.Now().UTC().Format(time.RFC3339),
|
||||
BaseURL: DefaultAPIBaseURL,
|
||||
RedirectURI: strings.TrimSpace(redirectURI),
|
||||
TokenEndpoint: strings.TrimSpace(tokenEndpoint),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshTokens refreshes an xAI access token.
|
||||
func (a *XAIAuth) RefreshTokens(ctx context.Context, refreshToken, tokenEndpoint string) (*TokenData, error) {
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
return nil, fmt.Errorf("xai token refresh: refresh token is required")
|
||||
}
|
||||
if strings.TrimSpace(tokenEndpoint) == "" {
|
||||
discovery, errDiscover := a.Discover(ctx)
|
||||
if errDiscover != nil {
|
||||
return nil, errDiscover
|
||||
}
|
||||
tokenEndpoint = discovery.TokenEndpoint
|
||||
}
|
||||
form := url.Values{
|
||||
"grant_type": {"refresh_token"},
|
||||
"client_id": {ClientID},
|
||||
"refresh_token": {strings.TrimSpace(refreshToken)},
|
||||
}
|
||||
return a.postTokenForm(ctx, tokenEndpoint, form)
|
||||
}
|
||||
|
||||
func (a *XAIAuth) postTokenForm(ctx context.Context, tokenEndpoint string, form url.Values) (*TokenData, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSpace(tokenEndpoint), strings.NewReader(form.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai token request: create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai token request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai token request: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("xai token response: read body: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("xai token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
var payload struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, fmt.Errorf("xai token response: parse body: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(payload.AccessToken) == "" {
|
||||
return nil, fmt.Errorf("xai token response missing access_token")
|
||||
}
|
||||
email, subject := parseJWTIdentity(payload.IDToken)
|
||||
return &TokenData{
|
||||
AccessToken: strings.TrimSpace(payload.AccessToken),
|
||||
RefreshToken: strings.TrimSpace(payload.RefreshToken),
|
||||
IDToken: strings.TrimSpace(payload.IDToken),
|
||||
TokenType: strings.TrimSpace(payload.TokenType),
|
||||
ExpiresIn: payload.ExpiresIn,
|
||||
Expire: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second).UTC().Format(time.RFC3339),
|
||||
Email: email,
|
||||
Subject: subject,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage converts an auth bundle into persistable storage.
|
||||
func (a *XAIAuth) CreateTokenStorage(bundle *AuthBundle) *TokenStorage {
|
||||
if bundle == nil {
|
||||
return nil
|
||||
}
|
||||
return &TokenStorage{
|
||||
Type: "xai",
|
||||
AccessToken: bundle.TokenData.AccessToken,
|
||||
RefreshToken: bundle.TokenData.RefreshToken,
|
||||
IDToken: bundle.TokenData.IDToken,
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
ExpiresIn: bundle.TokenData.ExpiresIn,
|
||||
Expire: bundle.TokenData.Expire,
|
||||
LastRefresh: bundle.LastRefresh,
|
||||
Email: strings.TrimSpace(bundle.TokenData.Email),
|
||||
Subject: bundle.TokenData.Subject,
|
||||
BaseURL: firstNonEmpty(bundle.BaseURL, DefaultAPIBaseURL),
|
||||
RedirectURI: bundle.RedirectURI,
|
||||
TokenEndpoint: bundle.TokenEndpoint,
|
||||
AuthKind: "oauth",
|
||||
}
|
||||
}
|
||||
|
||||
func parseJWTIdentity(token string) (email string, subject string) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) < 2 {
|
||||
return "", ""
|
||||
}
|
||||
payload := parts[1]
|
||||
payload += strings.Repeat("=", (4-len(payload)%4)%4)
|
||||
raw, err := base64.URLEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
var claims map[string]any
|
||||
if err = json.Unmarshal(raw, &claims); err != nil {
|
||||
return "", ""
|
||||
}
|
||||
if v, ok := claims["email"].(string); ok {
|
||||
email = strings.TrimSpace(v)
|
||||
}
|
||||
if v, ok := claims["sub"].(string); ok {
|
||||
subject = strings.TrimSpace(v)
|
||||
}
|
||||
return email, subject
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, value := range values {
|
||||
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package xai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildAuthorizeURLIncludesXAIRequiredParameters(t *testing.T) {
|
||||
authURL, err := BuildAuthorizeURL(AuthorizeURLParams{
|
||||
AuthorizationEndpoint: "https://auth.x.ai/oauth/authorize",
|
||||
RedirectURI: "http://127.0.0.1:56121/callback",
|
||||
CodeChallenge: "challenge",
|
||||
State: "state-123",
|
||||
Nonce: "nonce-123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildAuthorizeURL() error = %v", err)
|
||||
}
|
||||
|
||||
parsed, errParse := url.Parse(authURL)
|
||||
if errParse != nil {
|
||||
t.Fatalf("parse authorize URL: %v", errParse)
|
||||
}
|
||||
if parsed.Scheme != "https" || parsed.Host != "auth.x.ai" || parsed.Path != "/oauth/authorize" {
|
||||
t.Fatalf("authorize URL endpoint = %s://%s%s", parsed.Scheme, parsed.Host, parsed.Path)
|
||||
}
|
||||
|
||||
query := parsed.Query()
|
||||
want := map[string]string{
|
||||
"response_type": "code",
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": "http://127.0.0.1:56121/callback",
|
||||
"scope": Scope,
|
||||
"code_challenge": "challenge",
|
||||
"code_challenge_method": "S256",
|
||||
"state": "state-123",
|
||||
"nonce": "nonce-123",
|
||||
"plan": "generic",
|
||||
"referrer": "cli-proxy-api",
|
||||
}
|
||||
for key, value := range want {
|
||||
if got := query.Get(key); got != value {
|
||||
t.Fatalf("%s = %q, want %q", key, got, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOAuthEndpointRejectsNonXAIOrigin(t *testing.T) {
|
||||
if _, err := ValidateOAuthEndpoint("https://auth.x.ai/oauth/token", "token_endpoint"); err != nil {
|
||||
t.Fatalf("ValidateOAuthEndpoint(xai) error = %v", err)
|
||||
}
|
||||
if _, err := ValidateOAuthEndpoint("http://auth.x.ai/oauth/token", "token_endpoint"); err == nil {
|
||||
t.Fatal("expected non-HTTPS endpoint to be rejected")
|
||||
}
|
||||
if _, err := ValidateOAuthEndpoint("https://evil.example/oauth/token", "token_endpoint"); err == nil {
|
||||
t.Fatal("expected non-xAI endpoint to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokensPostsClientIDAndRefreshToken(t *testing.T) {
|
||||
var gotForm url.Values
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("method = %s, want POST", r.Method)
|
||||
}
|
||||
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/x-www-form-urlencoded") {
|
||||
t.Fatalf("Content-Type = %q, want form", got)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("ParseForm() error = %v", err)
|
||||
}
|
||||
gotForm = r.PostForm
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"access_token": "new-access",
|
||||
"refresh_token": "new-refresh",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
auth := NewXAIAuth(nil)
|
||||
tokenData, err := auth.RefreshTokens(context.Background(), "old-refresh", server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshTokens() error = %v", err)
|
||||
}
|
||||
if tokenData.AccessToken != "new-access" {
|
||||
t.Fatalf("access token = %q, want new-access", tokenData.AccessToken)
|
||||
}
|
||||
if gotForm.Get("grant_type") != "refresh_token" {
|
||||
t.Fatalf("grant_type = %q, want refresh_token", gotForm.Get("grant_type"))
|
||||
}
|
||||
if gotForm.Get("client_id") != ClientID {
|
||||
t.Fatalf("client_id = %q, want %q", gotForm.Get("client_id"), ClientID)
|
||||
}
|
||||
if gotForm.Get("refresh_token") != "old-refresh" {
|
||||
t.Fatalf("refresh_token = %q, want old-refresh", gotForm.Get("refresh_token"))
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
// newAuthManager creates a new authentication manager instance with all supported
|
||||
// authenticators and a file-based token store. It initializes authenticators for
|
||||
// Gemini, Codex, Claude, Antigravity, and Kimi providers.
|
||||
// Gemini, Codex, Claude, Antigravity, Kimi, and xAI providers.
|
||||
//
|
||||
// Returns:
|
||||
// - *sdkAuth.Manager: A configured authentication manager instance
|
||||
@@ -18,6 +18,7 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewClaudeAuthenticator(),
|
||||
sdkAuth.NewAntigravityAuthenticator(),
|
||||
sdkAuth.NewKimiAuthenticator(),
|
||||
sdkAuth.NewXAIAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoXAILogin triggers the OAuth flow for the xAI provider and saves tokens.
|
||||
func DoXAILogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "xai", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("xAI authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("xAI authentication successful!")
|
||||
}
|
||||
@@ -137,7 +137,7 @@ type Config struct {
|
||||
|
||||
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
|
||||
// These aliases affect both model listing and model routing for supported channels:
|
||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi.
|
||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai.
|
||||
//
|
||||
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||
|
||||
@@ -21,6 +21,7 @@ type staticModelsJSON struct {
|
||||
CodexPro []*ModelInfo `json:"codex-pro"`
|
||||
Kimi []*ModelInfo `json:"kimi"`
|
||||
Antigravity []*ModelInfo `json:"antigravity"`
|
||||
XAI []*ModelInfo `json:"xai"`
|
||||
}
|
||||
|
||||
// GetClaudeModels returns the standard Claude model definitions.
|
||||
@@ -78,6 +79,11 @@ func GetAntigravityModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().Antigravity)
|
||||
}
|
||||
|
||||
// GetXAIModels returns the standard xAI Grok model definitions.
|
||||
func GetXAIModels() []*ModelInfo {
|
||||
return cloneModelInfos(getModels().XAI)
|
||||
}
|
||||
|
||||
// WithCodexBuiltins injects hard-coded Codex-only model definitions that should
|
||||
// not depend on remote models.json updates. Built-ins replace any matching IDs
|
||||
// already present in the provided slice.
|
||||
@@ -167,6 +173,7 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
||||
// - codex
|
||||
// - kimi
|
||||
// - antigravity
|
||||
// - xai
|
||||
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
key := strings.ToLower(strings.TrimSpace(channel))
|
||||
switch key {
|
||||
@@ -186,6 +193,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetKimiModels()
|
||||
case "antigravity":
|
||||
return GetAntigravityModels()
|
||||
case "xai", "x-ai", "grok":
|
||||
return GetXAIModels()
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
@@ -208,6 +217,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
data.CodexPro,
|
||||
data.Kimi,
|
||||
data.Antigravity,
|
||||
data.XAI,
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
|
||||
@@ -215,6 +215,7 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
|
||||
{"codex", oldData.CodexPro, newData.CodexPro},
|
||||
{"kimi", oldData.Kimi, newData.Kimi},
|
||||
{"antigravity", oldData.Antigravity, newData.Antigravity},
|
||||
{"xai", oldData.XAI, newData.XAI},
|
||||
}
|
||||
|
||||
seen := make(map[string]bool, len(sections))
|
||||
@@ -335,6 +336,7 @@ func validateModelsCatalog(data *staticModelsJSON) error {
|
||||
{name: "codex-pro", models: data.CodexPro},
|
||||
{name: "kimi", models: data.Kimi},
|
||||
{name: "antigravity", models: data.Antigravity},
|
||||
{name: "xai", models: data.XAI},
|
||||
}
|
||||
|
||||
for _, section := range requiredSections {
|
||||
|
||||
@@ -46,7 +46,8 @@
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
"high",
|
||||
"xhigh"
|
||||
]
|
||||
}
|
||||
},
|
||||
@@ -2064,5 +2065,109 @@
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"xai": [
|
||||
{
|
||||
"id": "grok-4.3",
|
||||
"object": "model",
|
||||
"created": 1775606400,
|
||||
"owned_by": "xai",
|
||||
"type": "xai",
|
||||
"display_name": "Grok 4.3",
|
||||
"name": "grok-4.3",
|
||||
"description": "xAI Grok 4.3 model for the Responses API.",
|
||||
"context_length": 1000000,
|
||||
"max_completion_tokens": 65536,
|
||||
"thinking": {
|
||||
"zero_allowed": true,
|
||||
"levels": [
|
||||
"none",
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "grok-4.20-0309-reasoning",
|
||||
"object": "model",
|
||||
"created": 1773014400,
|
||||
"owned_by": "xai",
|
||||
"type": "xai",
|
||||
"display_name": "Grok 4.20 0309 Reasoning",
|
||||
"name": "grok-4.20-0309-reasoning",
|
||||
"description": "xAI Grok 4.20 0309 reasoning model for the Responses API.",
|
||||
"context_length": 2000000,
|
||||
"max_completion_tokens": 65536
|
||||
},
|
||||
{
|
||||
"id": "grok-4.20-0309-non-reasoning",
|
||||
"object": "model",
|
||||
"created": 1773014400,
|
||||
"owned_by": "xai",
|
||||
"type": "xai",
|
||||
"display_name": "Grok 4.20 0309 Non Reasoning",
|
||||
"name": "grok-4.20-0309-non-reasoning",
|
||||
"description": "xAI Grok 4.20 0309 non-reasoning model for the Responses API.",
|
||||
"context_length": 2000000,
|
||||
"max_completion_tokens": 65536
|
||||
},
|
||||
{
|
||||
"id": "grok-4.20-multi-agent-0309",
|
||||
"object": "model",
|
||||
"created": 1773014400,
|
||||
"owned_by": "xai",
|
||||
"type": "xai",
|
||||
"display_name": "Grok 4.20 Multi Agent 0309",
|
||||
"name": "grok-4.20-multi-agent-0309",
|
||||
"description": "xAI Grok 4.20 multi-agent model for the Responses API.",
|
||||
"context_length": 2000000,
|
||||
"max_completion_tokens": 65536,
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "grok-3-mini",
|
||||
"object": "model",
|
||||
"created": 1740960000,
|
||||
"owned_by": "xai",
|
||||
"type": "xai",
|
||||
"display_name": "Grok 3 Mini",
|
||||
"name": "grok-3-mini",
|
||||
"description": "xAI Grok 3 Mini model for the Responses API.",
|
||||
"context_length": 131072,
|
||||
"max_completion_tokens": 32768,
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "grok-3-mini-fast",
|
||||
"object": "model",
|
||||
"created": 1740960000,
|
||||
"owned_by": "xai",
|
||||
"type": "xai",
|
||||
"display_name": "Grok 3 Mini Fast",
|
||||
"name": "grok-3-mini-fast",
|
||||
"description": "xAI Grok 3 Mini Fast model for the Responses API.",
|
||||
"context_length": 131072,
|
||||
"max_completion_tokens": 32768,
|
||||
"thinking": {
|
||||
"levels": [
|
||||
"low",
|
||||
"medium",
|
||||
"high"
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -0,0 +1,570 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
var xaiDataTag = []byte("data:")
|
||||
|
||||
// XAIExecutor is a stateless executor for xAI Grok's Responses API.
|
||||
type XAIExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewXAIExecutor creates a new xAI executor.
|
||||
func NewXAIExecutor(cfg *config.Config) *XAIExecutor {
|
||||
return &XAIExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns the provider identifier.
|
||||
func (e *XAIExecutor) Identifier() string {
|
||||
return "xai"
|
||||
}
|
||||
|
||||
// PrepareRequest injects xAI credentials into the outgoing HTTP request.
|
||||
func (e *XAIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
token, _ := xaiCreds(auth)
|
||||
if strings.TrimSpace(token) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects xAI credentials into the request and executes it.
|
||||
func (e *XAIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("xai executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
|
||||
return nil, errPrepare
|
||||
}
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
token, baseURL := xaiCreds(auth)
|
||||
if baseURL == "" {
|
||||
baseURL = xaiauth.DefaultAPIBaseURL
|
||||
}
|
||||
|
||||
prepared, err := e.prepareResponsesRequest(ctx, req, opts, true)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID)
|
||||
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body)
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
return resp, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
|
||||
outputItemsByIndex := make(map[int64][]byte)
|
||||
var outputItemsFallback [][]byte
|
||||
for _, line := range bytes.Split(data, []byte("\n")) {
|
||||
if !bytes.HasPrefix(line, xaiDataTag) {
|
||||
continue
|
||||
}
|
||||
eventData := bytes.TrimSpace(line[len(xaiDataTag):])
|
||||
switch gjson.GetBytes(eventData, "type").String() {
|
||||
case "response.output_item.done":
|
||||
xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
|
||||
case "response.completed":
|
||||
if detail, ok := helps.ParseCodexUsage(eventData); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
completedData := xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, completedData, ¶m)
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return resp, statusErr{code: http.StatusRequestTimeout, msg: "xai stream error: stream disconnected before response.completed"}
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
token, baseURL := xaiCreds(auth)
|
||||
if baseURL == "" {
|
||||
baseURL = xaiauth.DefaultAPIBaseURL
|
||||
}
|
||||
|
||||
prepared, err := e.prepareResponsesRequest(ctx, req, opts, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID)
|
||||
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body)
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return nil, errRead
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 52_428_800)
|
||||
var param any
|
||||
outputItemsByIndex := make(map[int64][]byte)
|
||||
var outputItemsFallback [][]byte
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
translatedLine := bytes.Clone(line)
|
||||
if bytes.HasPrefix(line, xaiDataTag) {
|
||||
eventData := bytes.TrimSpace(line[len(xaiDataTag):])
|
||||
switch gjson.GetBytes(eventData, "type").String() {
|
||||
case "response.output_item.done":
|
||||
xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
|
||||
case "response.completed":
|
||||
if detail, ok := helps.ParseCodexUsage(eventData); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
eventData = xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
|
||||
translatedLine = append([]byte("data: "), eventData...)
|
||||
}
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, translatedLine, ¶m)
|
||||
for i := range chunks {
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx, errScan)
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens estimates token count for xAI Responses requests.
|
||||
func (e *XAIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
prepared, err := e.prepareResponsesRequest(ctx, req, opts, false)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
count, err := enc.Count(string(prepared.body))
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: token counting failed: %w", err)
|
||||
}
|
||||
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, prepared.to, prepared.from, int64(count), []byte(usageJSON))
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes xAI OAuth credentials using the stored refresh token.
|
||||
func (e *XAIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
log.Debugf("xai executor: refresh called")
|
||||
if refreshed, handled, err := helps.RefreshAuthViaHome(ctx, e.cfg, auth); handled {
|
||||
return refreshed, err
|
||||
}
|
||||
if auth == nil {
|
||||
return nil, statusErr{code: http.StatusInternalServerError, msg: "xai executor: auth is nil"}
|
||||
}
|
||||
refreshToken := xaiMetadataString(auth.Metadata, "refresh_token")
|
||||
if refreshToken == "" {
|
||||
return auth, nil
|
||||
}
|
||||
tokenEndpoint := xaiMetadataString(auth.Metadata, "token_endpoint")
|
||||
svc := xaiauth.NewXAIAuthWithProxyURL(e.cfg, auth.ProxyURL)
|
||||
td, err := svc.RefreshTokens(ctx, refreshToken, tokenEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
auth.Metadata["type"] = "xai"
|
||||
auth.Metadata["auth_kind"] = "oauth"
|
||||
auth.Metadata["access_token"] = td.AccessToken
|
||||
if td.RefreshToken != "" {
|
||||
auth.Metadata["refresh_token"] = td.RefreshToken
|
||||
}
|
||||
if td.IDToken != "" {
|
||||
auth.Metadata["id_token"] = td.IDToken
|
||||
}
|
||||
if td.TokenType != "" {
|
||||
auth.Metadata["token_type"] = td.TokenType
|
||||
}
|
||||
if td.ExpiresIn > 0 {
|
||||
auth.Metadata["expires_in"] = td.ExpiresIn
|
||||
}
|
||||
if td.Expire != "" {
|
||||
auth.Metadata["expired"] = td.Expire
|
||||
}
|
||||
if td.Email != "" {
|
||||
auth.Metadata["email"] = td.Email
|
||||
}
|
||||
if td.Subject != "" {
|
||||
auth.Metadata["sub"] = td.Subject
|
||||
}
|
||||
if tokenEndpoint != "" {
|
||||
auth.Metadata["token_endpoint"] = tokenEndpoint
|
||||
}
|
||||
if xaiMetadataString(auth.Metadata, "base_url") == "" {
|
||||
auth.Metadata["base_url"] = xaiauth.DefaultAPIBaseURL
|
||||
}
|
||||
auth.Metadata["last_refresh"] = time.Now().UTC().Format(time.RFC3339)
|
||||
if auth.Attributes == nil {
|
||||
auth.Attributes = make(map[string]string)
|
||||
}
|
||||
auth.Attributes["auth_kind"] = "oauth"
|
||||
if strings.TrimSpace(auth.Attributes["base_url"]) == "" {
|
||||
auth.Attributes["base_url"] = xaiauth.DefaultAPIBaseURL
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
type xaiPreparedRequest struct {
|
||||
baseModel string
|
||||
from sdktranslator.Format
|
||||
to sdktranslator.Format
|
||||
originalPayload []byte
|
||||
body []byte
|
||||
sessionID string
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) (*xaiPreparedRequest, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := bytes.Clone(originalPayloadSource)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
|
||||
var err error
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
requestPath := helps.PayloadRequestPath(opts)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", stream)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = sanitizeXAIResponsesBody(body, baseModel)
|
||||
|
||||
sessionID := xaiExecutionSessionID(req, opts)
|
||||
if sessionID != "" {
|
||||
body, _ = sjson.SetBytes(body, "prompt_cache_key", sessionID)
|
||||
}
|
||||
|
||||
return &xaiPreparedRequest{
|
||||
baseModel: baseModel,
|
||||
from: from,
|
||||
to: to,
|
||||
originalPayload: originalPayload,
|
||||
body: body,
|
||||
sessionID: sessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) recordXAIRequest(ctx context.Context, auth *cliproxyauth.Auth, url string, headers http.Header, body []byte) {
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
}
|
||||
|
||||
func xaiCreds(auth *cliproxyauth.Auth) (token, baseURL string) {
|
||||
if auth == nil {
|
||||
return "", ""
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
token = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
baseURL = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
if auth.Metadata != nil {
|
||||
if token == "" {
|
||||
token = xaiMetadataString(auth.Metadata, "access_token")
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = xaiMetadataString(auth.Metadata, "base_url")
|
||||
}
|
||||
}
|
||||
return token, baseURL
|
||||
}
|
||||
|
||||
func applyXAIHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, sessionID string) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
if strings.TrimSpace(token) != "" {
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
r.Header.Set("Connection", "Keep-Alive")
|
||||
if sessionID != "" {
|
||||
r.Header.Set("x-grok-conv-id", sessionID)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
}
|
||||
|
||||
func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) string {
|
||||
if value := xaiMetadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" {
|
||||
return value
|
||||
}
|
||||
if value := xaiMetadataString(req.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" {
|
||||
return value
|
||||
}
|
||||
if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
|
||||
return strings.TrimSpace(promptCacheKey.String())
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func xaiMetadataString(meta map[string]any, key string) string {
|
||||
if len(meta) == 0 || key == "" {
|
||||
return ""
|
||||
}
|
||||
value, ok := meta[key]
|
||||
if !ok || value == nil {
|
||||
return ""
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
case fmt.Stringer:
|
||||
return strings.TrimSpace(typed.String())
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(typed))
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeXAIResponsesBody(body []byte, model string) []byte {
|
||||
body = removeXAIEncryptedReasoningInclude(body)
|
||||
if !xaiSupportsReasoningEffort(model) {
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning")
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func removeXAIEncryptedReasoningInclude(body []byte) []byte {
|
||||
include := gjson.GetBytes(body, "include")
|
||||
if !include.Exists() || !include.IsArray() {
|
||||
return body
|
||||
}
|
||||
kept := make([]string, 0, len(include.Array()))
|
||||
for _, item := range include.Array() {
|
||||
value := strings.TrimSpace(item.String())
|
||||
if value == "" || value == "reasoning.encrypted_content" {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, value)
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "include", kept)
|
||||
return body
|
||||
}
|
||||
|
||||
func xaiSupportsReasoningEffort(model string) bool {
|
||||
name := strings.ToLower(strings.TrimSpace(thinking.ParseSuffix(model).ModelName))
|
||||
if idx := strings.LastIndex(name, "/"); idx >= 0 {
|
||||
name = name[idx+1:]
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(name, "grok-3-mini"):
|
||||
return true
|
||||
case strings.HasPrefix(name, "grok-4.20-multi-agent"):
|
||||
return true
|
||||
case strings.HasPrefix(name, "grok-4.3"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func xaiCollectOutputItemDone(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback *[][]byte) {
|
||||
itemResult := gjson.GetBytes(eventData, "item")
|
||||
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
|
||||
return
|
||||
}
|
||||
outputIndexResult := gjson.GetBytes(eventData, "output_index")
|
||||
if outputIndexResult.Exists() {
|
||||
outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw)
|
||||
return
|
||||
}
|
||||
*outputItemsFallback = append(*outputItemsFallback, []byte(itemResult.Raw))
|
||||
}
|
||||
|
||||
func xaiPatchCompletedOutput(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback [][]byte) []byte {
|
||||
outputResult := gjson.GetBytes(eventData, "response.output")
|
||||
shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0)
|
||||
if !shouldPatchOutput {
|
||||
return eventData
|
||||
}
|
||||
|
||||
indexes := make([]int64, 0, len(outputItemsByIndex))
|
||||
for idx := range outputItemsByIndex {
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
sort.Slice(indexes, func(i, j int) bool {
|
||||
return indexes[i] < indexes[j]
|
||||
})
|
||||
|
||||
outputArray := []byte("[]")
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('[')
|
||||
wrote := false
|
||||
for _, idx := range indexes {
|
||||
if wrote {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
buf.Write(outputItemsByIndex[idx])
|
||||
wrote = true
|
||||
}
|
||||
for _, item := range outputItemsFallback {
|
||||
if wrote {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
buf.Write(item)
|
||||
wrote = true
|
||||
}
|
||||
buf.WriteByte(']')
|
||||
if wrote {
|
||||
outputArray = buf.Bytes()
|
||||
}
|
||||
|
||||
patched, _ := sjson.SetRawBytes(eventData, "response.output", outputArray)
|
||||
return patched
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestXAIExecutorExecuteShapesResponsesRequest(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotAuth string
|
||||
var gotGrokConvID string
|
||||
var gotOriginator string
|
||||
var gotAccountID string
|
||||
var gotBody []byte
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotGrokConvID = r.Header.Get("x-grok-conv-id")
|
||||
gotOriginator = r.Header.Get("Originator")
|
||||
gotAccountID = r.Header.Get("Chatgpt-Account-Id")
|
||||
var errRead error
|
||||
gotBody, errRead = io.ReadAll(r.Body)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read body: %v", errRead)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4.3\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}],\"usage\":{\"input_tokens\":1,\"output_tokens\":1,\"total_tokens\":2}}}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewXAIExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "xai-auth",
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "xai-token",
|
||||
"email": "user@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "grok-4.3",
|
||||
Payload: []byte(`{"model":"grok-4.3","input":"hello","include":["reasoning.encrypted_content"],"reasoning":{"effort":"high"}}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||
Stream: false,
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.ExecutionSessionMetadataKey: "conv-xai-1",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
|
||||
if gotPath != "/responses" {
|
||||
t.Fatalf("path = %q, want /responses", gotPath)
|
||||
}
|
||||
if gotAuth != "Bearer xai-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer xai-token", gotAuth)
|
||||
}
|
||||
if gotGrokConvID != "conv-xai-1" {
|
||||
t.Fatalf("x-grok-conv-id = %q, want conv-xai-1", gotGrokConvID)
|
||||
}
|
||||
if gotOriginator != "" {
|
||||
t.Fatalf("Originator = %q, want empty", gotOriginator)
|
||||
}
|
||||
if gotAccountID != "" {
|
||||
t.Fatalf("Chatgpt-Account-Id = %q, want empty", gotAccountID)
|
||||
}
|
||||
if gjson.GetBytes(gotBody, "prompt_cache_key").String() != "conv-xai-1" {
|
||||
t.Fatalf("prompt_cache_key missing from body: %s", string(gotBody))
|
||||
}
|
||||
if !gjson.GetBytes(gotBody, "stream").Bool() {
|
||||
t.Fatalf("stream = false, want true; body=%s", string(gotBody))
|
||||
}
|
||||
if gjson.GetBytes(gotBody, "reasoning.effort").String() != "high" {
|
||||
t.Fatalf("reasoning.effort = %q, want high; body=%s", gjson.GetBytes(gotBody, "reasoning.effort").String(), string(gotBody))
|
||||
}
|
||||
for _, include := range gjson.GetBytes(gotBody, "include").Array() {
|
||||
if include.String() == "reasoning.encrypted_content" {
|
||||
t.Fatalf("xai request must not ask for encrypted reasoning content: %s", string(gotBody))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) {
|
||||
var gotBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var errRead error
|
||||
gotBody, errRead = io.ReadAll(r.Body)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read body: %v", errRead)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}]}}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewXAIExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
Metadata: map[string]any{"access_token": "xai-token"},
|
||||
}
|
||||
|
||||
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "grok-4",
|
||||
Payload: []byte(`{"model":"grok-4","input":"hello","reasoning":{"effort":"high"}}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||
Stream: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
|
||||
if gjson.GetBytes(gotBody, "reasoning").Exists() {
|
||||
t.Fatalf("unsupported xAI model must omit reasoning key: %s", string(gotBody))
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,7 @@ var oauthProviders = []oauthProvider{
|
||||
{"Codex (OpenAI)", "codex-auth-url", "🟩"},
|
||||
{"Antigravity", "antigravity-auth-url", "🟪"},
|
||||
{"Kimi", "kimi-auth-url", "🟫"},
|
||||
{"xAI", "xai-auth-url", "⬛"},
|
||||
}
|
||||
|
||||
// oauthTabModel handles OAuth login flows.
|
||||
@@ -280,6 +281,8 @@ func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
|
||||
providerKey = "antigravity"
|
||||
case "kimi-auth-url":
|
||||
providerKey = "kimi"
|
||||
case "xai-auth-url":
|
||||
providerKey = "xai"
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user