feat(auth): add OAuth2 support for xAI with PKCE and token persistence
- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support. - Added logic for token exchange, refresh, and persistent storage in JSON format. - Created `xai` package with helpers for OAuth discovery, API token handling, and URL building. - Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests. - Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
@@ -182,6 +182,7 @@ func main() {
|
|||||||
var oauthCallbackPort int
|
var oauthCallbackPort int
|
||||||
var antigravityLogin bool
|
var antigravityLogin bool
|
||||||
var kimiLogin bool
|
var kimiLogin bool
|
||||||
|
var xaiLogin bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var vertexImport string
|
var vertexImport string
|
||||||
var vertexImportPrefix string
|
var vertexImportPrefix string
|
||||||
@@ -203,6 +204,7 @@ func main() {
|
|||||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||||
|
flag.BoolVar(&xaiLogin, "xai-login", false, "Login to xAI using OAuth")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||||
@@ -656,6 +658,8 @@ func main() {
|
|||||||
cmd.DoClaudeLogin(cfg, options)
|
cmd.DoClaudeLogin(cfg, options)
|
||||||
} else if kimiLogin {
|
} else if kimiLogin {
|
||||||
cmd.DoKimiLogin(cfg, options)
|
cmd.DoKimiLogin(cfg, options)
|
||||||
|
} else if xaiLogin {
|
||||||
|
cmd.DoXAILogin(cfg, options)
|
||||||
} else {
|
} else {
|
||||||
// In cloud deploy mode without config file, just wait for shutdown signals
|
// In cloud deploy mode without config file, just wait for shutdown signals
|
||||||
if isCloudDeploy && !configFileExists {
|
if isCloudDeploy && !configFileExists {
|
||||||
|
|||||||
+6
-1
@@ -345,7 +345,7 @@ nonstream-keepalive-interval: 0
|
|||||||
|
|
||||||
# Global OAuth model name aliases (per channel)
|
# Global OAuth model name aliases (per channel)
|
||||||
# These aliases rename model IDs for both model listing and request routing.
|
# These aliases rename model IDs for both model listing and request routing.
|
||||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi.
|
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai.
|
||||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||||
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
|
# NOTE: Because aliases affect the merged /v1 model list and merged request routing, overlapping
|
||||||
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
|
# client-visible names can become ambiguous across providers. /api/provider/{provider}/... helps
|
||||||
@@ -375,6 +375,9 @@ nonstream-keepalive-interval: 0
|
|||||||
# kimi:
|
# kimi:
|
||||||
# - name: "kimi-k2.5"
|
# - name: "kimi-k2.5"
|
||||||
# alias: "k2.5"
|
# alias: "k2.5"
|
||||||
|
# xai:
|
||||||
|
# - name: "grok-4.3"
|
||||||
|
# alias: "grok-latest"
|
||||||
|
|
||||||
# OAuth provider excluded models
|
# OAuth provider excluded models
|
||||||
# oauth-excluded-models:
|
# oauth-excluded-models:
|
||||||
@@ -395,6 +398,8 @@ nonstream-keepalive-interval: 0
|
|||||||
# - "gpt-5-codex-mini"
|
# - "gpt-5-codex-mini"
|
||||||
# kimi:
|
# kimi:
|
||||||
# - "kimi-k2-thinking"
|
# - "kimi-k2-thinking"
|
||||||
|
# xai:
|
||||||
|
# - "grok-3-mini"
|
||||||
|
|
||||||
# Optional payload configuration
|
# Optional payload configuration
|
||||||
# payload:
|
# payload:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex"
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/codex"
|
||||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini"
|
geminiAuth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/gemini"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/kimi"
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/auth/kimi"
|
||||||
|
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||||
@@ -2132,6 +2133,185 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
|||||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) RequestXAIToken(c *gin.Context) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|
||||||
|
fmt.Println("Initializing xAI authentication...")
|
||||||
|
|
||||||
|
pkceCodes, errPKCE := xaiauth.GeneratePKCECodes()
|
||||||
|
if errPKCE != nil {
|
||||||
|
log.Errorf("Failed to generate xAI PKCE codes: %v", errPKCE)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state, errState := misc.GenerateRandomState()
|
||||||
|
if errState != nil {
|
||||||
|
log.Errorf("Failed to generate state parameter: %v", errState)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nonce, errNonce := misc.GenerateRandomState()
|
||||||
|
if errNonce != nil {
|
||||||
|
log.Errorf("Failed to generate nonce parameter: %v", errNonce)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate nonce parameter"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
authSvc := xaiauth.NewXAIAuth(h.cfg)
|
||||||
|
discovery, errDiscover := authSvc.Discover(ctx)
|
||||||
|
if errDiscover != nil {
|
||||||
|
log.Errorf("Failed to discover xAI OAuth endpoints: %v", errDiscover)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to discover oauth endpoints"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, xaiauth.CallbackPort, xaiauth.RedirectPath)
|
||||||
|
authURL, errAuthURL := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
|
||||||
|
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
|
||||||
|
RedirectURI: redirectURI,
|
||||||
|
CodeChallenge: pkceCodes.CodeChallenge,
|
||||||
|
State: state,
|
||||||
|
Nonce: nonce,
|
||||||
|
})
|
||||||
|
if errAuthURL != nil {
|
||||||
|
log.Errorf("Failed to generate xAI authorization URL: %v", errAuthURL)
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
RegisterOAuthSession(state, "xai")
|
||||||
|
|
||||||
|
isWebUI := isWebUIRequest(c)
|
||||||
|
var forwarder *callbackForwarder
|
||||||
|
if isWebUI {
|
||||||
|
targetURL, errTarget := h.managementCallbackURL("/xai/callback")
|
||||||
|
if errTarget != nil {
|
||||||
|
log.WithError(errTarget).Error("failed to compute xai callback target")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var errStart error
|
||||||
|
if forwarder, errStart = startCallbackForwarder(xaiauth.CallbackPort, "xai", targetURL); errStart != nil {
|
||||||
|
log.WithError(errStart).Error("failed to start xai callback forwarder")
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if isWebUI {
|
||||||
|
defer stopCallbackForwarderInstance(xaiauth.CallbackPort, forwarder)
|
||||||
|
}
|
||||||
|
|
||||||
|
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-xai-%s.oauth", state))
|
||||||
|
deadline := time.Now().Add(5 * time.Minute)
|
||||||
|
var authCode string
|
||||||
|
for {
|
||||||
|
if !IsOAuthSessionPending(state, "xai") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
log.Error("xai oauth flow timed out")
|
||||||
|
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
||||||
|
var payload map[string]string
|
||||||
|
_ = json.Unmarshal(data, &payload)
|
||||||
|
_ = os.Remove(waitFile)
|
||||||
|
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||||
|
log.Errorf("xAI authentication failed: %s", errStr)
|
||||||
|
SetOAuthSessionError(state, "Authentication failed: "+errStr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
||||||
|
log.Errorf("xAI authentication failed: state mismatch")
|
||||||
|
SetOAuthSessionError(state, "Authentication failed: state mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authCode = strings.TrimSpace(payload["code"])
|
||||||
|
if authCode == "" {
|
||||||
|
log.Error("xAI authentication failed: code not found")
|
||||||
|
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI, pkceCodes, discovery.TokenEndpoint)
|
||||||
|
if errExchange != nil {
|
||||||
|
log.Errorf("Failed to exchange xAI token: %v", errExchange)
|
||||||
|
SetOAuthSessionError(state, oauthSessionErrorWithCause("Failed to exchange authorization code for tokens", errExchange))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenStorage := authSvc.CreateTokenStorage(bundle)
|
||||||
|
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
|
||||||
|
log.Error("xAI token exchange returned empty access token")
|
||||||
|
SetOAuthSessionError(state, "Failed to exchange token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
|
||||||
|
label := strings.TrimSpace(tokenStorage.Email)
|
||||||
|
if label == "" {
|
||||||
|
label = "xAI"
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := map[string]any{
|
||||||
|
"type": "xai",
|
||||||
|
"access_token": tokenStorage.AccessToken,
|
||||||
|
"refresh_token": tokenStorage.RefreshToken,
|
||||||
|
"id_token": tokenStorage.IDToken,
|
||||||
|
"token_type": tokenStorage.TokenType,
|
||||||
|
"expires_in": tokenStorage.ExpiresIn,
|
||||||
|
"expired": tokenStorage.Expire,
|
||||||
|
"last_refresh": tokenStorage.LastRefresh,
|
||||||
|
"base_url": tokenStorage.BaseURL,
|
||||||
|
"redirect_uri": tokenStorage.RedirectURI,
|
||||||
|
"token_endpoint": tokenStorage.TokenEndpoint,
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
}
|
||||||
|
if tokenStorage.Email != "" {
|
||||||
|
metadata["email"] = tokenStorage.Email
|
||||||
|
}
|
||||||
|
if tokenStorage.Subject != "" {
|
||||||
|
metadata["sub"] = tokenStorage.Subject
|
||||||
|
}
|
||||||
|
|
||||||
|
record := &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: "xai",
|
||||||
|
FileName: fileName,
|
||||||
|
Label: label,
|
||||||
|
Storage: tokenStorage,
|
||||||
|
Metadata: metadata,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
"base_url": tokenStorage.BaseURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||||
|
if errSave != nil {
|
||||||
|
log.Errorf("Failed to save xAI token to file: %v", errSave)
|
||||||
|
SetOAuthSessionError(state, "Failed to save token to file")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
CompleteOAuthSession(state)
|
||||||
|
CompleteOAuthSessionsByProvider("xai")
|
||||||
|
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||||
|
fmt.Println("You can now use xAI services through this CLI")
|
||||||
|
}()
|
||||||
|
|
||||||
|
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
ctx = PopulateAuthContext(ctx, c)
|
ctx = PopulateAuthContext(ctx, c)
|
||||||
|
|||||||
@@ -242,6 +242,8 @@ func NormalizeOAuthProvider(provider string) (string, error) {
|
|||||||
return "gemini", nil
|
return "gemini", nil
|
||||||
case "antigravity", "anti-gravity":
|
case "antigravity", "anti-gravity":
|
||||||
return "antigravity", nil
|
return "antigravity", nil
|
||||||
|
case "xai", "x-ai", "x.ai", "grok":
|
||||||
|
return "xai", nil
|
||||||
default:
|
default:
|
||||||
return "", errUnsupportedOAuthFlow
|
return "", errUnsupportedOAuthFlow
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -484,6 +484,20 @@ func (s *Server) setupRoutes() {
|
|||||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
s.engine.GET("/xai/callback", func(c *gin.Context) {
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
errStr := c.Query("error")
|
||||||
|
if errStr == "" {
|
||||||
|
errStr = c.Query("error_description")
|
||||||
|
}
|
||||||
|
if state != "" {
|
||||||
|
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "xai", state, code, errStr)
|
||||||
|
}
|
||||||
|
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||||
|
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||||
|
})
|
||||||
|
|
||||||
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -685,6 +699,7 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||||
|
mgmt.GET("/xai-auth-url", s.mgmt.RequestXAIToken)
|
||||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package xai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GeneratePKCECodes creates a verifier/challenge pair for the OAuth flow.
|
||||||
|
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||||
|
bytes := make([]byte, 96)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("xai pkce: generate verifier: %w", err)
|
||||||
|
}
|
||||||
|
verifier := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(bytes)
|
||||||
|
hash := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
|
||||||
|
return &PKCECodes{CodeVerifier: verifier, CodeChallenge: challenge}, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package xai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenStorage stores xAI OAuth credentials on disk.
|
||||||
|
type TokenStorage struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
IDToken string `json:"id_token,omitempty"`
|
||||||
|
TokenType string `json:"token_type,omitempty"`
|
||||||
|
ExpiresIn int `json:"expires_in,omitempty"`
|
||||||
|
Expire string `json:"expired,omitempty"`
|
||||||
|
LastRefresh string `json:"last_refresh,omitempty"`
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Subject string `json:"sub,omitempty"`
|
||||||
|
BaseURL string `json:"base_url,omitempty"`
|
||||||
|
RedirectURI string `json:"redirect_uri,omitempty"`
|
||||||
|
TokenEndpoint string `json:"token_endpoint,omitempty"`
|
||||||
|
AuthKind string `json:"auth_kind,omitempty"`
|
||||||
|
|
||||||
|
Metadata map[string]any `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetadata allows the token store to merge status fields before saving.
|
||||||
|
func (ts *TokenStorage) SetMetadata(meta map[string]any) {
|
||||||
|
ts.Metadata = meta
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveTokenToFile writes xAI credentials to a JSON auth file.
|
||||||
|
func (ts *TokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
|
misc.LogSavingCredentials(authFilePath)
|
||||||
|
ts.Type = "xai"
|
||||||
|
ts.AuthKind = "oauth"
|
||||||
|
if errMkdirAll := os.MkdirAll(filepath.Dir(authFilePath), 0o700); errMkdirAll != nil {
|
||||||
|
return fmt.Errorf("xai token storage: create directory: %w", errMkdirAll)
|
||||||
|
}
|
||||||
|
file, err := os.Create(authFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("xai token storage: create token file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := file.Close(); errClose != nil {
|
||||||
|
log.Errorf("xai token storage: close token file error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||||
|
if errMerge != nil {
|
||||||
|
return fmt.Errorf("xai token storage: merge metadata: %w", errMerge)
|
||||||
|
}
|
||||||
|
encoder := json.NewEncoder(file)
|
||||||
|
encoder.SetIndent("", " ")
|
||||||
|
if err = encoder.Encode(data); err != nil {
|
||||||
|
return fmt.Errorf("xai token storage: write token file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CredentialFileName returns the filename used for xAI credentials.
|
||||||
|
func CredentialFileName(email, subject string) string {
|
||||||
|
email = sanitizeFileSegment(email)
|
||||||
|
if email != "" {
|
||||||
|
return fmt.Sprintf("xai-%s.json", email)
|
||||||
|
}
|
||||||
|
subject = sanitizeFileSegment(subject)
|
||||||
|
if subject != "" {
|
||||||
|
return fmt.Sprintf("xai-%s.json", subject)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("xai-%d.json", time.Now().UnixMilli())
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeFileSegment(value string) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
for _, r := range value {
|
||||||
|
switch {
|
||||||
|
case r >= 'a' && r <= 'z':
|
||||||
|
b.WriteRune(r)
|
||||||
|
case r >= 'A' && r <= 'Z':
|
||||||
|
b.WriteRune(r)
|
||||||
|
case r >= '0' && r <= '9':
|
||||||
|
b.WriteRune(r)
|
||||||
|
case r == '@' || r == '.' || r == '_' || r == '-':
|
||||||
|
b.WriteRune(r)
|
||||||
|
default:
|
||||||
|
b.WriteRune('-')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Trim(b.String(), "-")
|
||||||
|
}
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
// Package xai provides OAuth2 authentication helpers for xAI Grok.
|
||||||
|
package xai
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultAPIBaseURL is the default xAI Responses API base URL.
|
||||||
|
DefaultAPIBaseURL = "https://api.x.ai/v1"
|
||||||
|
// Issuer is xAI's OAuth issuer.
|
||||||
|
Issuer = "https://auth.x.ai"
|
||||||
|
// DiscoveryURL is the OIDC discovery endpoint used to resolve OAuth endpoints.
|
||||||
|
DiscoveryURL = Issuer + "/.well-known/openid-configuration"
|
||||||
|
// ClientID is the public xAI Grok CLI OAuth client ID.
|
||||||
|
ClientID = "b1a00492-073a-47ea-816f-4c329264a828"
|
||||||
|
// Scope is the OAuth scope set required for xAI API access.
|
||||||
|
Scope = "openid profile email offline_access grok-cli:access api:access"
|
||||||
|
// RedirectHost is the loopback host used by xAI OAuth.
|
||||||
|
RedirectHost = "127.0.0.1"
|
||||||
|
// CallbackPort is the preferred loopback callback port.
|
||||||
|
CallbackPort = 56121
|
||||||
|
// RedirectPath is the loopback callback path registered by the xAI client.
|
||||||
|
RedirectPath = "/callback"
|
||||||
|
)
|
||||||
|
|
||||||
|
var refreshLead = 5 * time.Minute
|
||||||
|
|
||||||
|
// RefreshLead returns the refresh lead time for xAI OAuth credentials.
|
||||||
|
func RefreshLead() time.Duration {
|
||||||
|
return refreshLead
|
||||||
|
}
|
||||||
|
|
||||||
|
// PKCECodes holds the PKCE verifier/challenge pair.
|
||||||
|
type PKCECodes struct {
|
||||||
|
CodeVerifier string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeURLParams contains the values used to build the xAI OAuth URL.
|
||||||
|
type AuthorizeURLParams struct {
|
||||||
|
AuthorizationEndpoint string
|
||||||
|
RedirectURI string
|
||||||
|
CodeChallenge string
|
||||||
|
State string
|
||||||
|
Nonce string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discovery contains OAuth endpoints resolved from xAI OIDC discovery.
|
||||||
|
type Discovery struct {
|
||||||
|
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||||
|
TokenEndpoint string `json:"token_endpoint"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenData holds xAI OAuth token data.
|
||||||
|
type TokenData struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
IDToken string `json:"id_token,omitempty"`
|
||||||
|
TokenType string `json:"token_type,omitempty"`
|
||||||
|
ExpiresIn int `json:"expires_in,omitempty"`
|
||||||
|
Expire string `json:"expired,omitempty"`
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Subject string `json:"sub,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthBundle aggregates token data and OAuth metadata for persistence.
|
||||||
|
type AuthBundle struct {
|
||||||
|
TokenData TokenData
|
||||||
|
LastRefresh string
|
||||||
|
BaseURL string
|
||||||
|
RedirectURI string
|
||||||
|
TokenEndpoint string
|
||||||
|
}
|
||||||
@@ -0,0 +1,304 @@
|
|||||||
|
package xai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// XAIAuth performs xAI OAuth discovery, token exchange, and refresh.
|
||||||
|
type XAIAuth struct {
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewXAIAuth creates an xAI OAuth helper using config proxy settings.
|
||||||
|
func NewXAIAuth(cfg *config.Config) *XAIAuth {
|
||||||
|
return NewXAIAuthWithProxyURL(cfg, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewXAIAuthWithProxyURL creates an xAI OAuth helper with an explicit proxy URL.
|
||||||
|
func NewXAIAuthWithProxyURL(cfg *config.Config, proxyURL string) *XAIAuth {
|
||||||
|
effectiveProxyURL := strings.TrimSpace(proxyURL)
|
||||||
|
var sdkCfg config.SDKConfig
|
||||||
|
if cfg != nil {
|
||||||
|
sdkCfg = cfg.SDKConfig
|
||||||
|
if effectiveProxyURL == "" {
|
||||||
|
effectiveProxyURL = strings.TrimSpace(cfg.ProxyURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sdkCfg.ProxyURL = effectiveProxyURL
|
||||||
|
return &XAIAuth{httpClient: util.SetProxy(&sdkCfg, &http.Client{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateOAuthEndpoint validates an endpoint returned by xAI discovery.
|
||||||
|
func ValidateOAuthEndpoint(rawURL string, field string) (string, error) {
|
||||||
|
rawURL = strings.TrimSpace(rawURL)
|
||||||
|
if rawURL == "" {
|
||||||
|
return "", fmt.Errorf("xai discovery %s is empty", field)
|
||||||
|
}
|
||||||
|
parsed, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("xai discovery %s is invalid: %w", field, err)
|
||||||
|
}
|
||||||
|
if parsed.Scheme != "https" {
|
||||||
|
return "", fmt.Errorf("xai discovery %s must use https: %q", field, rawURL)
|
||||||
|
}
|
||||||
|
host := strings.ToLower(strings.TrimSpace(parsed.Hostname()))
|
||||||
|
if host != "x.ai" && !strings.HasSuffix(host, ".x.ai") {
|
||||||
|
return "", fmt.Errorf("xai discovery %s host %q is not on x.ai", field, host)
|
||||||
|
}
|
||||||
|
return rawURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildAuthorizeURL builds the browser URL for xAI OAuth.
|
||||||
|
func BuildAuthorizeURL(params AuthorizeURLParams) (string, error) {
|
||||||
|
endpoint, err := ValidateOAuthEndpoint(params.AuthorizationEndpoint, "authorization_endpoint")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(params.RedirectURI) == "" {
|
||||||
|
return "", fmt.Errorf("xai authorize URL: redirect URI is required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(params.CodeChallenge) == "" {
|
||||||
|
return "", fmt.Errorf("xai authorize URL: code challenge is required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(params.State) == "" {
|
||||||
|
return "", fmt.Errorf("xai authorize URL: state is required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(params.Nonce) == "" {
|
||||||
|
return "", fmt.Errorf("xai authorize URL: nonce is required")
|
||||||
|
}
|
||||||
|
values := url.Values{
|
||||||
|
"response_type": {"code"},
|
||||||
|
"client_id": {ClientID},
|
||||||
|
"redirect_uri": {strings.TrimSpace(params.RedirectURI)},
|
||||||
|
"scope": {Scope},
|
||||||
|
"code_challenge": {strings.TrimSpace(params.CodeChallenge)},
|
||||||
|
"code_challenge_method": {"S256"},
|
||||||
|
"state": {strings.TrimSpace(params.State)},
|
||||||
|
"nonce": {strings.TrimSpace(params.Nonce)},
|
||||||
|
"plan": {"generic"},
|
||||||
|
"referrer": {"cli-proxy-api"},
|
||||||
|
}
|
||||||
|
return endpoint + "?" + values.Encode(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discover resolves xAI OAuth endpoints through OIDC discovery.
|
||||||
|
func (a *XAIAuth) Discover(ctx context.Context) (*Discovery, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, DiscoveryURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("xai discovery: create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("xai discovery: request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("xai discovery: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("xai discovery: read response: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("xai discovery failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
var payload struct {
|
||||||
|
AuthorizationEndpoint string `json:"authorization_endpoint"`
|
||||||
|
TokenEndpoint string `json:"token_endpoint"`
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(body, &payload); err != nil {
|
||||||
|
return nil, fmt.Errorf("xai discovery: parse response: %w", err)
|
||||||
|
}
|
||||||
|
authorizationEndpoint, err := ValidateOAuthEndpoint(payload.AuthorizationEndpoint, "authorization_endpoint")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tokenEndpoint, err := ValidateOAuthEndpoint(payload.TokenEndpoint, "token_endpoint")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Discovery{AuthorizationEndpoint: authorizationEndpoint, TokenEndpoint: tokenEndpoint}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCodeForTokens exchanges an authorization code for xAI OAuth tokens.
|
||||||
|
func (a *XAIAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes, tokenEndpoint string) (*AuthBundle, error) {
|
||||||
|
if pkceCodes == nil {
|
||||||
|
return nil, fmt.Errorf("xai token exchange: PKCE codes are required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(code) == "" {
|
||||||
|
return nil, fmt.Errorf("xai token exchange: authorization code is required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(redirectURI) == "" {
|
||||||
|
return nil, fmt.Errorf("xai token exchange: redirect URI is required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(tokenEndpoint) == "" {
|
||||||
|
discovery, errDiscover := a.Discover(ctx)
|
||||||
|
if errDiscover != nil {
|
||||||
|
return nil, errDiscover
|
||||||
|
}
|
||||||
|
tokenEndpoint = discovery.TokenEndpoint
|
||||||
|
}
|
||||||
|
form := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {strings.TrimSpace(code)},
|
||||||
|
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||||
|
"client_id": {ClientID},
|
||||||
|
"code_verifier": {pkceCodes.CodeVerifier},
|
||||||
|
}
|
||||||
|
tokenData, err := a.postTokenForm(ctx, tokenEndpoint, form)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &AuthBundle{
|
||||||
|
TokenData: *tokenData,
|
||||||
|
LastRefresh: time.Now().UTC().Format(time.RFC3339),
|
||||||
|
BaseURL: DefaultAPIBaseURL,
|
||||||
|
RedirectURI: strings.TrimSpace(redirectURI),
|
||||||
|
TokenEndpoint: strings.TrimSpace(tokenEndpoint),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokens refreshes an xAI access token.
|
||||||
|
func (a *XAIAuth) RefreshTokens(ctx context.Context, refreshToken, tokenEndpoint string) (*TokenData, error) {
|
||||||
|
if strings.TrimSpace(refreshToken) == "" {
|
||||||
|
return nil, fmt.Errorf("xai token refresh: refresh token is required")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(tokenEndpoint) == "" {
|
||||||
|
discovery, errDiscover := a.Discover(ctx)
|
||||||
|
if errDiscover != nil {
|
||||||
|
return nil, errDiscover
|
||||||
|
}
|
||||||
|
tokenEndpoint = discovery.TokenEndpoint
|
||||||
|
}
|
||||||
|
form := url.Values{
|
||||||
|
"grant_type": {"refresh_token"},
|
||||||
|
"client_id": {ClientID},
|
||||||
|
"refresh_token": {strings.TrimSpace(refreshToken)},
|
||||||
|
}
|
||||||
|
return a.postTokenForm(ctx, tokenEndpoint, form)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *XAIAuth) postTokenForm(ctx context.Context, tokenEndpoint string, form url.Values) (*TokenData, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimSpace(tokenEndpoint), strings.NewReader(form.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("xai token request: create request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("xai token request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := resp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("xai token request: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("xai token response: read body: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("xai token request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||||
|
}
|
||||||
|
var payload struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
}
|
||||||
|
if err = json.Unmarshal(body, &payload); err != nil {
|
||||||
|
return nil, fmt.Errorf("xai token response: parse body: %w", err)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(payload.AccessToken) == "" {
|
||||||
|
return nil, fmt.Errorf("xai token response missing access_token")
|
||||||
|
}
|
||||||
|
email, subject := parseJWTIdentity(payload.IDToken)
|
||||||
|
return &TokenData{
|
||||||
|
AccessToken: strings.TrimSpace(payload.AccessToken),
|
||||||
|
RefreshToken: strings.TrimSpace(payload.RefreshToken),
|
||||||
|
IDToken: strings.TrimSpace(payload.IDToken),
|
||||||
|
TokenType: strings.TrimSpace(payload.TokenType),
|
||||||
|
ExpiresIn: payload.ExpiresIn,
|
||||||
|
Expire: time.Now().Add(time.Duration(payload.ExpiresIn) * time.Second).UTC().Format(time.RFC3339),
|
||||||
|
Email: email,
|
||||||
|
Subject: subject,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTokenStorage converts an auth bundle into persistable storage.
|
||||||
|
func (a *XAIAuth) CreateTokenStorage(bundle *AuthBundle) *TokenStorage {
|
||||||
|
if bundle == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &TokenStorage{
|
||||||
|
Type: "xai",
|
||||||
|
AccessToken: bundle.TokenData.AccessToken,
|
||||||
|
RefreshToken: bundle.TokenData.RefreshToken,
|
||||||
|
IDToken: bundle.TokenData.IDToken,
|
||||||
|
TokenType: bundle.TokenData.TokenType,
|
||||||
|
ExpiresIn: bundle.TokenData.ExpiresIn,
|
||||||
|
Expire: bundle.TokenData.Expire,
|
||||||
|
LastRefresh: bundle.LastRefresh,
|
||||||
|
Email: strings.TrimSpace(bundle.TokenData.Email),
|
||||||
|
Subject: bundle.TokenData.Subject,
|
||||||
|
BaseURL: firstNonEmpty(bundle.BaseURL, DefaultAPIBaseURL),
|
||||||
|
RedirectURI: bundle.RedirectURI,
|
||||||
|
TokenEndpoint: bundle.TokenEndpoint,
|
||||||
|
AuthKind: "oauth",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseJWTIdentity(token string) (email string, subject string) {
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
payload := parts[1]
|
||||||
|
payload += strings.Repeat("=", (4-len(payload)%4)%4)
|
||||||
|
raw, err := base64.URLEncoding.DecodeString(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
var claims map[string]any
|
||||||
|
if err = json.Unmarshal(raw, &claims); err != nil {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
if v, ok := claims["email"].(string); ok {
|
||||||
|
email = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
if v, ok := claims["sub"].(string); ok {
|
||||||
|
subject = strings.TrimSpace(v)
|
||||||
|
}
|
||||||
|
return email, subject
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstNonEmpty(values ...string) string {
|
||||||
|
for _, value := range values {
|
||||||
|
if trimmed := strings.TrimSpace(value); trimmed != "" {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,105 @@
|
|||||||
|
package xai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildAuthorizeURLIncludesXAIRequiredParameters(t *testing.T) {
|
||||||
|
authURL, err := BuildAuthorizeURL(AuthorizeURLParams{
|
||||||
|
AuthorizationEndpoint: "https://auth.x.ai/oauth/authorize",
|
||||||
|
RedirectURI: "http://127.0.0.1:56121/callback",
|
||||||
|
CodeChallenge: "challenge",
|
||||||
|
State: "state-123",
|
||||||
|
Nonce: "nonce-123",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BuildAuthorizeURL() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, errParse := url.Parse(authURL)
|
||||||
|
if errParse != nil {
|
||||||
|
t.Fatalf("parse authorize URL: %v", errParse)
|
||||||
|
}
|
||||||
|
if parsed.Scheme != "https" || parsed.Host != "auth.x.ai" || parsed.Path != "/oauth/authorize" {
|
||||||
|
t.Fatalf("authorize URL endpoint = %s://%s%s", parsed.Scheme, parsed.Host, parsed.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := parsed.Query()
|
||||||
|
want := map[string]string{
|
||||||
|
"response_type": "code",
|
||||||
|
"client_id": ClientID,
|
||||||
|
"redirect_uri": "http://127.0.0.1:56121/callback",
|
||||||
|
"scope": Scope,
|
||||||
|
"code_challenge": "challenge",
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
"state": "state-123",
|
||||||
|
"nonce": "nonce-123",
|
||||||
|
"plan": "generic",
|
||||||
|
"referrer": "cli-proxy-api",
|
||||||
|
}
|
||||||
|
for key, value := range want {
|
||||||
|
if got := query.Get(key); got != value {
|
||||||
|
t.Fatalf("%s = %q, want %q", key, got, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOAuthEndpointRejectsNonXAIOrigin(t *testing.T) {
|
||||||
|
if _, err := ValidateOAuthEndpoint("https://auth.x.ai/oauth/token", "token_endpoint"); err != nil {
|
||||||
|
t.Fatalf("ValidateOAuthEndpoint(xai) error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := ValidateOAuthEndpoint("http://auth.x.ai/oauth/token", "token_endpoint"); err == nil {
|
||||||
|
t.Fatal("expected non-HTTPS endpoint to be rejected")
|
||||||
|
}
|
||||||
|
if _, err := ValidateOAuthEndpoint("https://evil.example/oauth/token", "token_endpoint"); err == nil {
|
||||||
|
t.Fatal("expected non-xAI endpoint to be rejected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokensPostsClientIDAndRefreshToken(t *testing.T) {
|
||||||
|
var gotForm url.Values
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
t.Fatalf("method = %s, want POST", r.Method)
|
||||||
|
}
|
||||||
|
if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/x-www-form-urlencoded") {
|
||||||
|
t.Fatalf("Content-Type = %q, want form", got)
|
||||||
|
}
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
t.Fatalf("ParseForm() error = %v", err)
|
||||||
|
}
|
||||||
|
gotForm = r.PostForm
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"access_token": "new-access",
|
||||||
|
"refresh_token": "new-refresh",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
auth := NewXAIAuth(nil)
|
||||||
|
tokenData, err := auth.RefreshTokens(context.Background(), "old-refresh", server.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RefreshTokens() error = %v", err)
|
||||||
|
}
|
||||||
|
if tokenData.AccessToken != "new-access" {
|
||||||
|
t.Fatalf("access token = %q, want new-access", tokenData.AccessToken)
|
||||||
|
}
|
||||||
|
if gotForm.Get("grant_type") != "refresh_token" {
|
||||||
|
t.Fatalf("grant_type = %q, want refresh_token", gotForm.Get("grant_type"))
|
||||||
|
}
|
||||||
|
if gotForm.Get("client_id") != ClientID {
|
||||||
|
t.Fatalf("client_id = %q, want %q", gotForm.Get("client_id"), ClientID)
|
||||||
|
}
|
||||||
|
if gotForm.Get("refresh_token") != "old-refresh" {
|
||||||
|
t.Fatalf("refresh_token = %q, want old-refresh", gotForm.Get("refresh_token"))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
// newAuthManager creates a new authentication manager instance with all supported
|
// newAuthManager creates a new authentication manager instance with all supported
|
||||||
// authenticators and a file-based token store. It initializes authenticators for
|
// authenticators and a file-based token store. It initializes authenticators for
|
||||||
// Gemini, Codex, Claude, Antigravity, and Kimi providers.
|
// Gemini, Codex, Claude, Antigravity, Kimi, and xAI providers.
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *sdkAuth.Manager: A configured authentication manager instance
|
// - *sdkAuth.Manager: A configured authentication manager instance
|
||||||
@@ -18,6 +18,7 @@ func newAuthManager() *sdkAuth.Manager {
|
|||||||
sdkAuth.NewClaudeAuthenticator(),
|
sdkAuth.NewClaudeAuthenticator(),
|
||||||
sdkAuth.NewAntigravityAuthenticator(),
|
sdkAuth.NewAntigravityAuthenticator(),
|
||||||
sdkAuth.NewKimiAuthenticator(),
|
sdkAuth.NewKimiAuthenticator(),
|
||||||
|
sdkAuth.NewXAIAuthenticator(),
|
||||||
)
|
)
|
||||||
return manager
|
return manager
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||||
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DoXAILogin triggers the OAuth flow for the xAI provider and saves tokens.
|
||||||
|
func DoXAILogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
promptFn := options.Prompt
|
||||||
|
if promptFn == nil {
|
||||||
|
promptFn = defaultProjectPrompt()
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := newAuthManager()
|
||||||
|
authOpts := &sdkAuth.LoginOptions{
|
||||||
|
NoBrowser: options.NoBrowser,
|
||||||
|
CallbackPort: options.CallbackPort,
|
||||||
|
Metadata: map[string]string{},
|
||||||
|
Prompt: promptFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
record, savedPath, err := manager.Login(context.Background(), "xai", cfg, authOpts)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("xAI authentication failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if savedPath != "" {
|
||||||
|
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||||
|
}
|
||||||
|
if record != nil && record.Label != "" {
|
||||||
|
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||||
|
}
|
||||||
|
fmt.Println("xAI authentication successful!")
|
||||||
|
}
|
||||||
@@ -137,7 +137,7 @@ type Config struct {
|
|||||||
|
|
||||||
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
|
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
|
||||||
// These aliases affect both model listing and model routing for supported channels:
|
// These aliases affect both model listing and model routing for supported channels:
|
||||||
// gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi.
|
// gemini-cli, vertex, aistudio, antigravity, claude, codex, kimi, xai.
|
||||||
//
|
//
|
||||||
// NOTE: This does not apply to existing per-credential model alias features under:
|
// NOTE: This does not apply to existing per-credential model alias features under:
|
||||||
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type staticModelsJSON struct {
|
|||||||
CodexPro []*ModelInfo `json:"codex-pro"`
|
CodexPro []*ModelInfo `json:"codex-pro"`
|
||||||
Kimi []*ModelInfo `json:"kimi"`
|
Kimi []*ModelInfo `json:"kimi"`
|
||||||
Antigravity []*ModelInfo `json:"antigravity"`
|
Antigravity []*ModelInfo `json:"antigravity"`
|
||||||
|
XAI []*ModelInfo `json:"xai"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClaudeModels returns the standard Claude model definitions.
|
// GetClaudeModels returns the standard Claude model definitions.
|
||||||
@@ -78,6 +79,11 @@ func GetAntigravityModels() []*ModelInfo {
|
|||||||
return cloneModelInfos(getModels().Antigravity)
|
return cloneModelInfos(getModels().Antigravity)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetXAIModels returns the standard xAI Grok model definitions.
|
||||||
|
func GetXAIModels() []*ModelInfo {
|
||||||
|
return cloneModelInfos(getModels().XAI)
|
||||||
|
}
|
||||||
|
|
||||||
// WithCodexBuiltins injects hard-coded Codex-only model definitions that should
|
// WithCodexBuiltins injects hard-coded Codex-only model definitions that should
|
||||||
// not depend on remote models.json updates. Built-ins replace any matching IDs
|
// not depend on remote models.json updates. Built-ins replace any matching IDs
|
||||||
// already present in the provided slice.
|
// already present in the provided slice.
|
||||||
@@ -167,6 +173,7 @@ func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
|
|||||||
// - codex
|
// - codex
|
||||||
// - kimi
|
// - kimi
|
||||||
// - antigravity
|
// - antigravity
|
||||||
|
// - xai
|
||||||
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||||
key := strings.ToLower(strings.TrimSpace(channel))
|
key := strings.ToLower(strings.TrimSpace(channel))
|
||||||
switch key {
|
switch key {
|
||||||
@@ -186,6 +193,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
|||||||
return GetKimiModels()
|
return GetKimiModels()
|
||||||
case "antigravity":
|
case "antigravity":
|
||||||
return GetAntigravityModels()
|
return GetAntigravityModels()
|
||||||
|
case "xai", "x-ai", "grok":
|
||||||
|
return GetXAIModels()
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -208,6 +217,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
|||||||
data.CodexPro,
|
data.CodexPro,
|
||||||
data.Kimi,
|
data.Kimi,
|
||||||
data.Antigravity,
|
data.Antigravity,
|
||||||
|
data.XAI,
|
||||||
}
|
}
|
||||||
for _, models := range allModels {
|
for _, models := range allModels {
|
||||||
for _, m := range models {
|
for _, m := range models {
|
||||||
|
|||||||
@@ -215,6 +215,7 @@ func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
|
|||||||
{"codex", oldData.CodexPro, newData.CodexPro},
|
{"codex", oldData.CodexPro, newData.CodexPro},
|
||||||
{"kimi", oldData.Kimi, newData.Kimi},
|
{"kimi", oldData.Kimi, newData.Kimi},
|
||||||
{"antigravity", oldData.Antigravity, newData.Antigravity},
|
{"antigravity", oldData.Antigravity, newData.Antigravity},
|
||||||
|
{"xai", oldData.XAI, newData.XAI},
|
||||||
}
|
}
|
||||||
|
|
||||||
seen := make(map[string]bool, len(sections))
|
seen := make(map[string]bool, len(sections))
|
||||||
@@ -335,6 +336,7 @@ func validateModelsCatalog(data *staticModelsJSON) error {
|
|||||||
{name: "codex-pro", models: data.CodexPro},
|
{name: "codex-pro", models: data.CodexPro},
|
||||||
{name: "kimi", models: data.Kimi},
|
{name: "kimi", models: data.Kimi},
|
||||||
{name: "antigravity", models: data.Antigravity},
|
{name: "antigravity", models: data.Antigravity},
|
||||||
|
{name: "xai", models: data.XAI},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, section := range requiredSections {
|
for _, section := range requiredSections {
|
||||||
|
|||||||
@@ -46,7 +46,8 @@
|
|||||||
"levels": [
|
"levels": [
|
||||||
"low",
|
"low",
|
||||||
"medium",
|
"medium",
|
||||||
"high"
|
"high",
|
||||||
|
"xhigh"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -2064,5 +2065,109 @@
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
],
|
||||||
|
"xai": [
|
||||||
|
{
|
||||||
|
"id": "grok-4.3",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1775606400,
|
||||||
|
"owned_by": "xai",
|
||||||
|
"type": "xai",
|
||||||
|
"display_name": "Grok 4.3",
|
||||||
|
"name": "grok-4.3",
|
||||||
|
"description": "xAI Grok 4.3 model for the Responses API.",
|
||||||
|
"context_length": 1000000,
|
||||||
|
"max_completion_tokens": 65536,
|
||||||
|
"thinking": {
|
||||||
|
"zero_allowed": true,
|
||||||
|
"levels": [
|
||||||
|
"none",
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "grok-4.20-0309-reasoning",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1773014400,
|
||||||
|
"owned_by": "xai",
|
||||||
|
"type": "xai",
|
||||||
|
"display_name": "Grok 4.20 0309 Reasoning",
|
||||||
|
"name": "grok-4.20-0309-reasoning",
|
||||||
|
"description": "xAI Grok 4.20 0309 reasoning model for the Responses API.",
|
||||||
|
"context_length": 2000000,
|
||||||
|
"max_completion_tokens": 65536
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "grok-4.20-0309-non-reasoning",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1773014400,
|
||||||
|
"owned_by": "xai",
|
||||||
|
"type": "xai",
|
||||||
|
"display_name": "Grok 4.20 0309 Non Reasoning",
|
||||||
|
"name": "grok-4.20-0309-non-reasoning",
|
||||||
|
"description": "xAI Grok 4.20 0309 non-reasoning model for the Responses API.",
|
||||||
|
"context_length": 2000000,
|
||||||
|
"max_completion_tokens": 65536
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "grok-4.20-multi-agent-0309",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1773014400,
|
||||||
|
"owned_by": "xai",
|
||||||
|
"type": "xai",
|
||||||
|
"display_name": "Grok 4.20 Multi Agent 0309",
|
||||||
|
"name": "grok-4.20-multi-agent-0309",
|
||||||
|
"description": "xAI Grok 4.20 multi-agent model for the Responses API.",
|
||||||
|
"context_length": 2000000,
|
||||||
|
"max_completion_tokens": 65536,
|
||||||
|
"thinking": {
|
||||||
|
"levels": [
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "grok-3-mini",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1740960000,
|
||||||
|
"owned_by": "xai",
|
||||||
|
"type": "xai",
|
||||||
|
"display_name": "Grok 3 Mini",
|
||||||
|
"name": "grok-3-mini",
|
||||||
|
"description": "xAI Grok 3 Mini model for the Responses API.",
|
||||||
|
"context_length": 131072,
|
||||||
|
"max_completion_tokens": 32768,
|
||||||
|
"thinking": {
|
||||||
|
"levels": [
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "grok-3-mini-fast",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1740960000,
|
||||||
|
"owned_by": "xai",
|
||||||
|
"type": "xai",
|
||||||
|
"display_name": "Grok 3 Mini Fast",
|
||||||
|
"name": "grok-3-mini-fast",
|
||||||
|
"description": "xAI Grok 3 Mini Fast model for the Responses API.",
|
||||||
|
"context_length": 131072,
|
||||||
|
"max_completion_tokens": 32768,
|
||||||
|
"thinking": {
|
||||||
|
"levels": [
|
||||||
|
"low",
|
||||||
|
"medium",
|
||||||
|
"high"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,570 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
"github.com/tiktoken-go/tokenizer"
|
||||||
|
)
|
||||||
|
|
||||||
|
var xaiDataTag = []byte("data:")
|
||||||
|
|
||||||
|
// XAIExecutor is a stateless executor for xAI Grok's Responses API.
|
||||||
|
type XAIExecutor struct {
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewXAIExecutor creates a new xAI executor.
|
||||||
|
func NewXAIExecutor(cfg *config.Config) *XAIExecutor {
|
||||||
|
return &XAIExecutor{cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identifier returns the provider identifier.
|
||||||
|
func (e *XAIExecutor) Identifier() string {
|
||||||
|
return "xai"
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrepareRequest injects xAI credentials into the outgoing HTTP request.
|
||||||
|
func (e *XAIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||||
|
if req == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
token, _ := xaiCreds(auth)
|
||||||
|
if strings.TrimSpace(token) != "" {
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HttpRequest injects xAI credentials into the request and executes it.
|
||||||
|
func (e *XAIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("xai executor: request is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = req.Context()
|
||||||
|
}
|
||||||
|
httpReq := req.WithContext(ctx)
|
||||||
|
if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
|
||||||
|
return nil, errPrepare
|
||||||
|
}
|
||||||
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
return httpClient.Do(httpReq)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
|
token, baseURL := xaiCreds(auth)
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = xaiauth.DefaultAPIBaseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
prepared, err := e.prepareResponsesRequest(ctx, req, opts, true)
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth)
|
||||||
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
|
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body))
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID)
|
||||||
|
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body)
|
||||||
|
|
||||||
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("xai executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
return resp, errRead
|
||||||
|
}
|
||||||
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
|
return resp, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(httpResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
|
||||||
|
outputItemsByIndex := make(map[int64][]byte)
|
||||||
|
var outputItemsFallback [][]byte
|
||||||
|
for _, line := range bytes.Split(data, []byte("\n")) {
|
||||||
|
if !bytes.HasPrefix(line, xaiDataTag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
eventData := bytes.TrimSpace(line[len(xaiDataTag):])
|
||||||
|
switch gjson.GetBytes(eventData, "type").String() {
|
||||||
|
case "response.output_item.done":
|
||||||
|
xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
|
||||||
|
case "response.completed":
|
||||||
|
if detail, ok := helps.ParseCodexUsage(eventData); ok {
|
||||||
|
reporter.Publish(ctx, detail)
|
||||||
|
}
|
||||||
|
completedData := xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
|
||||||
|
var param any
|
||||||
|
out := sdktranslator.TranslateNonStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, completedData, ¶m)
|
||||||
|
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, statusErr{code: http.StatusRequestTimeout, msg: "xai stream error: stream disconnected before response.completed"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||||
|
token, baseURL := xaiCreds(auth)
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = xaiauth.DefaultAPIBaseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
prepared, err := e.prepareResponsesRequest(ctx, req, opts, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth)
|
||||||
|
defer reporter.TrackFailure(ctx, &err)
|
||||||
|
|
||||||
|
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||||
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID)
|
||||||
|
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body)
|
||||||
|
|
||||||
|
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||||
|
httpResp, err := httpClient.Do(httpReq)
|
||||||
|
if err != nil {
|
||||||
|
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||||
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||||
|
data, errRead := io.ReadAll(httpResp.Body)
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("xai executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
if errRead != nil {
|
||||||
|
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||||
|
return nil, errRead
|
||||||
|
}
|
||||||
|
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
|
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||||
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
go func() {
|
||||||
|
defer close(out)
|
||||||
|
defer func() {
|
||||||
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||||
|
log.Errorf("xai executor: close response body error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
scanner := bufio.NewScanner(httpResp.Body)
|
||||||
|
scanner.Buffer(nil, 52_428_800)
|
||||||
|
var param any
|
||||||
|
outputItemsByIndex := make(map[int64][]byte)
|
||||||
|
var outputItemsFallback [][]byte
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||||
|
translatedLine := bytes.Clone(line)
|
||||||
|
if bytes.HasPrefix(line, xaiDataTag) {
|
||||||
|
eventData := bytes.TrimSpace(line[len(xaiDataTag):])
|
||||||
|
switch gjson.GetBytes(eventData, "type").String() {
|
||||||
|
case "response.output_item.done":
|
||||||
|
xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
|
||||||
|
case "response.completed":
|
||||||
|
if detail, ok := helps.ParseCodexUsage(eventData); ok {
|
||||||
|
reporter.Publish(ctx, detail)
|
||||||
|
}
|
||||||
|
eventData = xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
|
||||||
|
translatedLine = append([]byte("data: "), eventData...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chunks := sdktranslator.TranslateStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, translatedLine, ¶m)
|
||||||
|
for i := range chunks {
|
||||||
|
select {
|
||||||
|
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errScan := scanner.Err(); errScan != nil {
|
||||||
|
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||||
|
reporter.PublishFailure(ctx, errScan)
|
||||||
|
select {
|
||||||
|
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountTokens estimates token count for xAI Responses requests.
|
||||||
|
func (e *XAIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
prepared, err := e.prepareResponsesRequest(ctx, req, opts, false)
|
||||||
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, err
|
||||||
|
}
|
||||||
|
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||||
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: tokenizer init failed: %w", err)
|
||||||
|
}
|
||||||
|
count, err := enc.Count(string(prepared.body))
|
||||||
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: token counting failed: %w", err)
|
||||||
|
}
|
||||||
|
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
|
||||||
|
translated := sdktranslator.TranslateTokenCount(ctx, prepared.to, prepared.from, int64(count), []byte(usageJSON))
|
||||||
|
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh refreshes xAI OAuth credentials using the stored refresh token.
|
||||||
|
func (e *XAIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
|
log.Debugf("xai executor: refresh called")
|
||||||
|
if refreshed, handled, err := helps.RefreshAuthViaHome(ctx, e.cfg, auth); handled {
|
||||||
|
return refreshed, err
|
||||||
|
}
|
||||||
|
if auth == nil {
|
||||||
|
return nil, statusErr{code: http.StatusInternalServerError, msg: "xai executor: auth is nil"}
|
||||||
|
}
|
||||||
|
refreshToken := xaiMetadataString(auth.Metadata, "refresh_token")
|
||||||
|
if refreshToken == "" {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
tokenEndpoint := xaiMetadataString(auth.Metadata, "token_endpoint")
|
||||||
|
svc := xaiauth.NewXAIAuthWithProxyURL(e.cfg, auth.ProxyURL)
|
||||||
|
td, err := svc.RefreshTokens(ctx, refreshToken, tokenEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if auth.Metadata == nil {
|
||||||
|
auth.Metadata = make(map[string]any)
|
||||||
|
}
|
||||||
|
auth.Metadata["type"] = "xai"
|
||||||
|
auth.Metadata["auth_kind"] = "oauth"
|
||||||
|
auth.Metadata["access_token"] = td.AccessToken
|
||||||
|
if td.RefreshToken != "" {
|
||||||
|
auth.Metadata["refresh_token"] = td.RefreshToken
|
||||||
|
}
|
||||||
|
if td.IDToken != "" {
|
||||||
|
auth.Metadata["id_token"] = td.IDToken
|
||||||
|
}
|
||||||
|
if td.TokenType != "" {
|
||||||
|
auth.Metadata["token_type"] = td.TokenType
|
||||||
|
}
|
||||||
|
if td.ExpiresIn > 0 {
|
||||||
|
auth.Metadata["expires_in"] = td.ExpiresIn
|
||||||
|
}
|
||||||
|
if td.Expire != "" {
|
||||||
|
auth.Metadata["expired"] = td.Expire
|
||||||
|
}
|
||||||
|
if td.Email != "" {
|
||||||
|
auth.Metadata["email"] = td.Email
|
||||||
|
}
|
||||||
|
if td.Subject != "" {
|
||||||
|
auth.Metadata["sub"] = td.Subject
|
||||||
|
}
|
||||||
|
if tokenEndpoint != "" {
|
||||||
|
auth.Metadata["token_endpoint"] = tokenEndpoint
|
||||||
|
}
|
||||||
|
if xaiMetadataString(auth.Metadata, "base_url") == "" {
|
||||||
|
auth.Metadata["base_url"] = xaiauth.DefaultAPIBaseURL
|
||||||
|
}
|
||||||
|
auth.Metadata["last_refresh"] = time.Now().UTC().Format(time.RFC3339)
|
||||||
|
if auth.Attributes == nil {
|
||||||
|
auth.Attributes = make(map[string]string)
|
||||||
|
}
|
||||||
|
auth.Attributes["auth_kind"] = "oauth"
|
||||||
|
if strings.TrimSpace(auth.Attributes["base_url"]) == "" {
|
||||||
|
auth.Attributes["base_url"] = xaiauth.DefaultAPIBaseURL
|
||||||
|
}
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type xaiPreparedRequest struct {
|
||||||
|
baseModel string
|
||||||
|
from sdktranslator.Format
|
||||||
|
to sdktranslator.Format
|
||||||
|
originalPayload []byte
|
||||||
|
body []byte
|
||||||
|
sessionID string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) (*xaiPreparedRequest, error) {
|
||||||
|
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("codex")
|
||||||
|
originalPayloadSource := req.Payload
|
||||||
|
if len(opts.OriginalRequest) > 0 {
|
||||||
|
originalPayloadSource = opts.OriginalRequest
|
||||||
|
}
|
||||||
|
originalPayload := bytes.Clone(originalPayloadSource)
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||||
|
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||||
|
requestPath := helps.PayloadRequestPath(opts)
|
||||||
|
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
body, _ = sjson.SetBytes(body, "stream", stream)
|
||||||
|
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||||
|
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||||
|
body = normalizeCodexInstructions(body)
|
||||||
|
body = sanitizeXAIResponsesBody(body, baseModel)
|
||||||
|
|
||||||
|
sessionID := xaiExecutionSessionID(req, opts)
|
||||||
|
if sessionID != "" {
|
||||||
|
body, _ = sjson.SetBytes(body, "prompt_cache_key", sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &xaiPreparedRequest{
|
||||||
|
baseModel: baseModel,
|
||||||
|
from: from,
|
||||||
|
to: to,
|
||||||
|
originalPayload: originalPayload,
|
||||||
|
body: body,
|
||||||
|
sessionID: sessionID,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *XAIExecutor) recordXAIRequest(ctx context.Context, auth *cliproxyauth.Auth, url string, headers http.Header, body []byte) {
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||||
|
URL: url,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: headers,
|
||||||
|
Body: body,
|
||||||
|
Provider: e.Identifier(),
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func xaiCreds(auth *cliproxyauth.Auth) (token, baseURL string) {
|
||||||
|
if auth == nil {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
token = strings.TrimSpace(auth.Attributes["api_key"])
|
||||||
|
baseURL = strings.TrimSpace(auth.Attributes["base_url"])
|
||||||
|
}
|
||||||
|
if auth.Metadata != nil {
|
||||||
|
if token == "" {
|
||||||
|
token = xaiMetadataString(auth.Metadata, "access_token")
|
||||||
|
}
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = xaiMetadataString(auth.Metadata, "base_url")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return token, baseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyXAIHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, sessionID string) {
|
||||||
|
r.Header.Set("Content-Type", "application/json")
|
||||||
|
if strings.TrimSpace(token) != "" {
|
||||||
|
r.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
}
|
||||||
|
if stream {
|
||||||
|
r.Header.Set("Accept", "text/event-stream")
|
||||||
|
} else {
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
}
|
||||||
|
r.Header.Set("Connection", "Keep-Alive")
|
||||||
|
if sessionID != "" {
|
||||||
|
r.Header.Set("x-grok-conv-id", sessionID)
|
||||||
|
}
|
||||||
|
var attrs map[string]string
|
||||||
|
if auth != nil {
|
||||||
|
attrs = auth.Attributes
|
||||||
|
}
|
||||||
|
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) string {
|
||||||
|
if value := xaiMetadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
if value := xaiMetadataString(req.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
|
||||||
|
return strings.TrimSpace(promptCacheKey.String())
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func xaiMetadataString(meta map[string]any, key string) string {
|
||||||
|
if len(meta) == 0 || key == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
value, ok := meta[key]
|
||||||
|
if !ok || value == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(typed)
|
||||||
|
case fmt.Stringer:
|
||||||
|
return strings.TrimSpace(typed.String())
|
||||||
|
default:
|
||||||
|
return strings.TrimSpace(fmt.Sprint(typed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeXAIResponsesBody(body []byte, model string) []byte {
|
||||||
|
body = removeXAIEncryptedReasoningInclude(body)
|
||||||
|
if !xaiSupportsReasoningEffort(model) {
|
||||||
|
body, _ = sjson.DeleteBytes(body, "reasoning")
|
||||||
|
}
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeXAIEncryptedReasoningInclude(body []byte) []byte {
|
||||||
|
include := gjson.GetBytes(body, "include")
|
||||||
|
if !include.Exists() || !include.IsArray() {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
kept := make([]string, 0, len(include.Array()))
|
||||||
|
for _, item := range include.Array() {
|
||||||
|
value := strings.TrimSpace(item.String())
|
||||||
|
if value == "" || value == "reasoning.encrypted_content" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
kept = append(kept, value)
|
||||||
|
}
|
||||||
|
body, _ = sjson.SetBytes(body, "include", kept)
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
func xaiSupportsReasoningEffort(model string) bool {
|
||||||
|
name := strings.ToLower(strings.TrimSpace(thinking.ParseSuffix(model).ModelName))
|
||||||
|
if idx := strings.LastIndex(name, "/"); idx >= 0 {
|
||||||
|
name = name[idx+1:]
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(name, "grok-3-mini"):
|
||||||
|
return true
|
||||||
|
case strings.HasPrefix(name, "grok-4.20-multi-agent"):
|
||||||
|
return true
|
||||||
|
case strings.HasPrefix(name, "grok-4.3"):
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func xaiCollectOutputItemDone(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback *[][]byte) {
|
||||||
|
itemResult := gjson.GetBytes(eventData, "item")
|
||||||
|
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
outputIndexResult := gjson.GetBytes(eventData, "output_index")
|
||||||
|
if outputIndexResult.Exists() {
|
||||||
|
outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*outputItemsFallback = append(*outputItemsFallback, []byte(itemResult.Raw))
|
||||||
|
}
|
||||||
|
|
||||||
|
func xaiPatchCompletedOutput(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback [][]byte) []byte {
|
||||||
|
outputResult := gjson.GetBytes(eventData, "response.output")
|
||||||
|
shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0)
|
||||||
|
if !shouldPatchOutput {
|
||||||
|
return eventData
|
||||||
|
}
|
||||||
|
|
||||||
|
indexes := make([]int64, 0, len(outputItemsByIndex))
|
||||||
|
for idx := range outputItemsByIndex {
|
||||||
|
indexes = append(indexes, idx)
|
||||||
|
}
|
||||||
|
sort.Slice(indexes, func(i, j int) bool {
|
||||||
|
return indexes[i] < indexes[j]
|
||||||
|
})
|
||||||
|
|
||||||
|
outputArray := []byte("[]")
|
||||||
|
var buf bytes.Buffer
|
||||||
|
buf.WriteByte('[')
|
||||||
|
wrote := false
|
||||||
|
for _, idx := range indexes {
|
||||||
|
if wrote {
|
||||||
|
buf.WriteByte(',')
|
||||||
|
}
|
||||||
|
buf.Write(outputItemsByIndex[idx])
|
||||||
|
wrote = true
|
||||||
|
}
|
||||||
|
for _, item := range outputItemsFallback {
|
||||||
|
if wrote {
|
||||||
|
buf.WriteByte(',')
|
||||||
|
}
|
||||||
|
buf.Write(item)
|
||||||
|
wrote = true
|
||||||
|
}
|
||||||
|
buf.WriteByte(']')
|
||||||
|
if wrote {
|
||||||
|
outputArray = buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
patched, _ := sjson.SetRawBytes(eventData, "response.output", outputArray)
|
||||||
|
return patched
|
||||||
|
}
|
||||||
@@ -0,0 +1,138 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||||
|
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestXAIExecutorExecuteShapesResponsesRequest(t *testing.T) {
|
||||||
|
var gotPath string
|
||||||
|
var gotAuth string
|
||||||
|
var gotGrokConvID string
|
||||||
|
var gotOriginator string
|
||||||
|
var gotAccountID string
|
||||||
|
var gotBody []byte
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotPath = r.URL.Path
|
||||||
|
gotAuth = r.Header.Get("Authorization")
|
||||||
|
gotGrokConvID = r.Header.Get("x-grok-conv-id")
|
||||||
|
gotOriginator = r.Header.Get("Originator")
|
||||||
|
gotAccountID = r.Header.Get("Chatgpt-Account-Id")
|
||||||
|
var errRead error
|
||||||
|
gotBody, errRead = io.ReadAll(r.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
t.Fatalf("read body: %v", errRead)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4.3\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}],\"usage\":{\"input_tokens\":1,\"output_tokens\":1,\"total_tokens\":2}}}\n\n"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
exec := NewXAIExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
ID: "xai-auth",
|
||||||
|
Provider: "xai",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{
|
||||||
|
"access_token": "xai-token",
|
||||||
|
"email": "user@example.com",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "grok-4.3",
|
||||||
|
Payload: []byte(`{"model":"grok-4.3","input":"hello","include":["reasoning.encrypted_content"],"reasoning":{"effort":"high"}}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||||
|
Stream: false,
|
||||||
|
Metadata: map[string]any{
|
||||||
|
cliproxyexecutor.ExecutionSessionMetadataKey: "conv-xai-1",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotPath != "/responses" {
|
||||||
|
t.Fatalf("path = %q, want /responses", gotPath)
|
||||||
|
}
|
||||||
|
if gotAuth != "Bearer xai-token" {
|
||||||
|
t.Fatalf("Authorization = %q, want Bearer xai-token", gotAuth)
|
||||||
|
}
|
||||||
|
if gotGrokConvID != "conv-xai-1" {
|
||||||
|
t.Fatalf("x-grok-conv-id = %q, want conv-xai-1", gotGrokConvID)
|
||||||
|
}
|
||||||
|
if gotOriginator != "" {
|
||||||
|
t.Fatalf("Originator = %q, want empty", gotOriginator)
|
||||||
|
}
|
||||||
|
if gotAccountID != "" {
|
||||||
|
t.Fatalf("Chatgpt-Account-Id = %q, want empty", gotAccountID)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(gotBody, "prompt_cache_key").String() != "conv-xai-1" {
|
||||||
|
t.Fatalf("prompt_cache_key missing from body: %s", string(gotBody))
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(gotBody, "stream").Bool() {
|
||||||
|
t.Fatalf("stream = false, want true; body=%s", string(gotBody))
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(gotBody, "reasoning.effort").String() != "high" {
|
||||||
|
t.Fatalf("reasoning.effort = %q, want high; body=%s", gjson.GetBytes(gotBody, "reasoning.effort").String(), string(gotBody))
|
||||||
|
}
|
||||||
|
for _, include := range gjson.GetBytes(gotBody, "include").Array() {
|
||||||
|
if include.String() == "reasoning.encrypted_content" {
|
||||||
|
t.Fatalf("xai request must not ask for encrypted reasoning content: %s", string(gotBody))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) {
|
||||||
|
var gotBody []byte
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var errRead error
|
||||||
|
gotBody, errRead = io.ReadAll(r.Body)
|
||||||
|
if errRead != nil {
|
||||||
|
t.Fatalf("read body: %v", errRead)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}]}}\n\n"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
exec := NewXAIExecutor(&config.Config{})
|
||||||
|
auth := &cliproxyauth.Auth{
|
||||||
|
Provider: "xai",
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"base_url": server.URL,
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
},
|
||||||
|
Metadata: map[string]any{"access_token": "xai-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||||
|
Model: "grok-4",
|
||||||
|
Payload: []byte(`{"model":"grok-4","input":"hello","reasoning":{"effort":"high"}}`),
|
||||||
|
}, cliproxyexecutor.Options{
|
||||||
|
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||||
|
Stream: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Execute() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gjson.GetBytes(gotBody, "reasoning").Exists() {
|
||||||
|
t.Fatalf("unsupported xAI model must omit reasoning key: %s", string(gotBody))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -24,6 +24,7 @@ var oauthProviders = []oauthProvider{
|
|||||||
{"Codex (OpenAI)", "codex-auth-url", "🟩"},
|
{"Codex (OpenAI)", "codex-auth-url", "🟩"},
|
||||||
{"Antigravity", "antigravity-auth-url", "🟪"},
|
{"Antigravity", "antigravity-auth-url", "🟪"},
|
||||||
{"Kimi", "kimi-auth-url", "🟫"},
|
{"Kimi", "kimi-auth-url", "🟫"},
|
||||||
|
{"xAI", "xai-auth-url", "⬛"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// oauthTabModel handles OAuth login flows.
|
// oauthTabModel handles OAuth login flows.
|
||||||
@@ -280,6 +281,8 @@ func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
|
|||||||
providerKey = "antigravity"
|
providerKey = "antigravity"
|
||||||
case "kimi-auth-url":
|
case "kimi-auth-url":
|
||||||
providerKey = "kimi"
|
providerKey = "kimi"
|
||||||
|
case "xai-auth-url":
|
||||||
|
providerKey = "xai"
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ func init() {
|
|||||||
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
|
||||||
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
|
||||||
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
|
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
|
||||||
|
registerRefreshLead("xai", func() Authenticator { return NewXAIAuthenticator() })
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerRefreshLead(provider string, factory func() Authenticator) {
|
func registerRefreshLead(provider string, factory func() Authenticator) {
|
||||||
|
|||||||
+282
@@ -0,0 +1,282 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/browser"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/misc"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// XAIAuthenticator implements the xAI Grok OAuth loopback flow.
|
||||||
|
type XAIAuthenticator struct{}
|
||||||
|
|
||||||
|
// NewXAIAuthenticator constructs a new xAI authenticator.
|
||||||
|
func NewXAIAuthenticator() Authenticator {
|
||||||
|
return &XAIAuthenticator{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider returns the provider key for xAI.
|
||||||
|
func (XAIAuthenticator) Provider() string {
|
||||||
|
return "xai"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshLead instructs the manager to refresh before token expiry.
|
||||||
|
func (XAIAuthenticator) RefreshLead() *time.Duration {
|
||||||
|
lead := xaiauth.RefreshLead()
|
||||||
|
return &lead
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login launches a local OAuth flow to obtain xAI tokens and persists them.
|
||||||
|
func (a XAIAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if opts == nil {
|
||||||
|
opts = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
callbackPort := xaiauth.CallbackPort
|
||||||
|
if opts.CallbackPort > 0 {
|
||||||
|
callbackPort = opts.CallbackPort
|
||||||
|
}
|
||||||
|
|
||||||
|
pkceCodes, err := xaiauth.GeneratePKCECodes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("xai pkce generation failed: %w", err)
|
||||||
|
}
|
||||||
|
state, err := misc.GenerateRandomState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("xai state generation failed: %w", err)
|
||||||
|
}
|
||||||
|
nonce, err := misc.GenerateRandomState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("xai nonce generation failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authSvc := xaiauth.NewXAIAuth(cfg)
|
||||||
|
discovery, err := authSvc.Discover(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
srv, port, callbackCh, errServer := startXAICallbackServer(callbackPort)
|
||||||
|
if errServer != nil {
|
||||||
|
return nil, fmt.Errorf("xai: failed to start callback server: %w", errServer)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if errShutdown := srv.Shutdown(shutdownCtx); errShutdown != nil {
|
||||||
|
log.Warnf("xai callback server shutdown error: %v", errShutdown)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
redirectURI := fmt.Sprintf("http://%s:%d%s", xaiauth.RedirectHost, port, xaiauth.RedirectPath)
|
||||||
|
authURL, err := xaiauth.BuildAuthorizeURL(xaiauth.AuthorizeURLParams{
|
||||||
|
AuthorizationEndpoint: discovery.AuthorizationEndpoint,
|
||||||
|
RedirectURI: redirectURI,
|
||||||
|
CodeChallenge: pkceCodes.CodeChallenge,
|
||||||
|
State: state,
|
||||||
|
Nonce: nonce,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.NoBrowser {
|
||||||
|
fmt.Println("Opening browser for xAI authentication")
|
||||||
|
if !browser.IsAvailable() {
|
||||||
|
log.Warn("No browser available; please open the URL manually")
|
||||||
|
util.PrintSSHTunnelInstructions(port)
|
||||||
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
|
} else if errOpen := browser.OpenURL(authURL); errOpen != nil {
|
||||||
|
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||||
|
util.PrintSSHTunnelInstructions(port)
|
||||||
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
util.PrintSSHTunnelInstructions(port)
|
||||||
|
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Waiting for xAI authentication callback...")
|
||||||
|
|
||||||
|
var result callbackResult
|
||||||
|
timeoutTimer := time.NewTimer(5 * time.Minute)
|
||||||
|
defer timeoutTimer.Stop()
|
||||||
|
|
||||||
|
var manualPromptTimer *time.Timer
|
||||||
|
var manualPromptC <-chan time.Time
|
||||||
|
if opts.Prompt != nil {
|
||||||
|
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||||
|
manualPromptC = manualPromptTimer.C
|
||||||
|
defer manualPromptTimer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
var manualInputCh <-chan string
|
||||||
|
var manualInputErrCh <-chan error
|
||||||
|
|
||||||
|
waitForCallback:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case result = <-callbackCh:
|
||||||
|
break waitForCallback
|
||||||
|
case <-manualPromptC:
|
||||||
|
manualPromptC = nil
|
||||||
|
if manualPromptTimer != nil {
|
||||||
|
manualPromptTimer.Stop()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case result = <-callbackCh:
|
||||||
|
break waitForCallback
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
manualInputCh, manualInputErrCh = misc.AsyncPrompt(opts.Prompt, "Paste the xAI callback Token (or press Enter to keep waiting): ")
|
||||||
|
continue
|
||||||
|
case input := <-manualInputCh:
|
||||||
|
manualInputCh = nil
|
||||||
|
manualInputErrCh = nil
|
||||||
|
manualResult, ok, errParse := parseXAIManualCallbackToken(input, state)
|
||||||
|
if errParse != nil {
|
||||||
|
return nil, errParse
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = manualResult
|
||||||
|
break waitForCallback
|
||||||
|
case errManual := <-manualInputErrCh:
|
||||||
|
return nil, errManual
|
||||||
|
case <-timeoutTimer.C:
|
||||||
|
return nil, fmt.Errorf("xai: authentication timed out")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error != "" {
|
||||||
|
return nil, fmt.Errorf("xai: authentication failed: %s", result.Error)
|
||||||
|
}
|
||||||
|
if result.State != state {
|
||||||
|
return nil, fmt.Errorf("xai: invalid state")
|
||||||
|
}
|
||||||
|
if result.Code == "" {
|
||||||
|
return nil, fmt.Errorf("xai: missing authorization code")
|
||||||
|
}
|
||||||
|
|
||||||
|
bundle, errExchange := authSvc.ExchangeCodeForTokens(ctx, result.Code, redirectURI, pkceCodes, discovery.TokenEndpoint)
|
||||||
|
if errExchange != nil {
|
||||||
|
return nil, fmt.Errorf("xai: token exchange failed: %w", errExchange)
|
||||||
|
}
|
||||||
|
tokenStorage := authSvc.CreateTokenStorage(bundle)
|
||||||
|
if tokenStorage == nil || strings.TrimSpace(tokenStorage.AccessToken) == "" {
|
||||||
|
return nil, fmt.Errorf("xai token storage missing access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := xaiauth.CredentialFileName(tokenStorage.Email, tokenStorage.Subject)
|
||||||
|
label := strings.TrimSpace(tokenStorage.Email)
|
||||||
|
if label == "" {
|
||||||
|
label = "xAI"
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := map[string]any{
|
||||||
|
"type": "xai",
|
||||||
|
"access_token": tokenStorage.AccessToken,
|
||||||
|
"refresh_token": tokenStorage.RefreshToken,
|
||||||
|
"id_token": tokenStorage.IDToken,
|
||||||
|
"token_type": tokenStorage.TokenType,
|
||||||
|
"expires_in": tokenStorage.ExpiresIn,
|
||||||
|
"expired": tokenStorage.Expire,
|
||||||
|
"last_refresh": tokenStorage.LastRefresh,
|
||||||
|
"base_url": tokenStorage.BaseURL,
|
||||||
|
"redirect_uri": tokenStorage.RedirectURI,
|
||||||
|
"token_endpoint": tokenStorage.TokenEndpoint,
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
}
|
||||||
|
if tokenStorage.Email != "" {
|
||||||
|
metadata["email"] = tokenStorage.Email
|
||||||
|
}
|
||||||
|
if tokenStorage.Subject != "" {
|
||||||
|
metadata["sub"] = tokenStorage.Subject
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("xAI authentication successful")
|
||||||
|
|
||||||
|
return &coreauth.Auth{
|
||||||
|
ID: fileName,
|
||||||
|
Provider: a.Provider(),
|
||||||
|
FileName: fileName,
|
||||||
|
Label: label,
|
||||||
|
Storage: tokenStorage,
|
||||||
|
Metadata: metadata,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
"base_url": tokenStorage.BaseURL,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseXAIManualCallbackToken(input string, state string) (callbackResult, bool, error) {
|
||||||
|
token := strings.TrimSpace(input)
|
||||||
|
if token == "" {
|
||||||
|
return callbackResult{}, false, nil
|
||||||
|
}
|
||||||
|
if strings.Contains(token, "://") || strings.Contains(token, "?") || strings.Contains(token, "code=") {
|
||||||
|
return callbackResult{}, false, fmt.Errorf("xai: paste only the callback token")
|
||||||
|
}
|
||||||
|
return callbackResult{Code: token, State: state}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func startXAICallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
|
||||||
|
if port <= 0 {
|
||||||
|
port = xaiauth.CallbackPort
|
||||||
|
}
|
||||||
|
addr := fmt.Sprintf("%s:%d", xaiauth.RedirectHost, port)
|
||||||
|
listener, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, nil, err
|
||||||
|
}
|
||||||
|
port = listener.Addr().(*net.TCPAddr).Port
|
||||||
|
resultCh := make(chan callbackResult, 1)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc(xaiauth.RedirectPath, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
q := r.URL.Query()
|
||||||
|
result := callbackResult{
|
||||||
|
Code: strings.TrimSpace(q.Get("code")),
|
||||||
|
Error: strings.TrimSpace(q.Get("error")),
|
||||||
|
State: strings.TrimSpace(q.Get("state")),
|
||||||
|
}
|
||||||
|
resultCh <- result
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
if result.Code != "" && result.Error == "" {
|
||||||
|
_, _ = w.Write([]byte("<h1>Login successful</h1><p>You can close this window.</p>"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte("<h1>Login failed</h1><p>Please check the CLI output.</p>"))
|
||||||
|
})
|
||||||
|
|
||||||
|
srv := &http.Server{
|
||||||
|
Handler: mux,
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
|
WriteTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
if errServe := srv.Serve(listener); errServe != nil && !strings.Contains(errServe.Error(), "Server closed") {
|
||||||
|
log.Warnf("xai callback server error: %v", errServe)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return srv, port, resultCh, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestXAIAuthenticatorProviderAndRefreshLead(t *testing.T) {
|
||||||
|
authenticator := NewXAIAuthenticator()
|
||||||
|
if authenticator.Provider() != "xai" {
|
||||||
|
t.Fatalf("Provider() = %q, want xai", authenticator.Provider())
|
||||||
|
}
|
||||||
|
lead := authenticator.RefreshLead()
|
||||||
|
if lead == nil || *lead <= 0 {
|
||||||
|
t.Fatalf("RefreshLead() = %v, want positive duration", lead)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseXAIManualCallbackTokenAcceptsRawCode(t *testing.T) {
|
||||||
|
result, ok, err := parseXAIManualCallbackToken(" V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg ", "state-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseXAIManualCallbackToken() error = %v", err)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("parseXAIManualCallbackToken() ok = false, want true")
|
||||||
|
}
|
||||||
|
if result.Code != "V0auoESADonzF4bY_Ag2whBFnVeqzHJm6nW2uW012rqCCW5cstFV58qvDFBvnPBXXe0rZSKOcs3PwwfACKp1qg" {
|
||||||
|
t.Fatalf("Code = %q", result.Code)
|
||||||
|
}
|
||||||
|
if result.State != "state-1" {
|
||||||
|
t.Fatalf("State = %q, want state-1", result.State)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseXAIManualCallbackTokenRejectsCallbackURL(t *testing.T) {
|
||||||
|
_, _, err := parseXAIManualCallbackToken("http://127.0.0.1:56121/callback?state=state-1&code=token-1", "state-1")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("parseXAIManualCallbackToken() error = nil, want error")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -116,6 +116,7 @@ func newDefaultAuthManager() *sdkAuth.Manager {
|
|||||||
sdkAuth.NewGeminiAuthenticator(),
|
sdkAuth.NewGeminiAuthenticator(),
|
||||||
sdkAuth.NewCodexAuthenticator(),
|
sdkAuth.NewCodexAuthenticator(),
|
||||||
sdkAuth.NewClaudeAuthenticator(),
|
sdkAuth.NewClaudeAuthenticator(),
|
||||||
|
sdkAuth.NewXAIAuthenticator(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -433,6 +434,8 @@ func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace
|
|||||||
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
|
||||||
case "kimi":
|
case "kimi":
|
||||||
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
|
||||||
|
case "xai":
|
||||||
|
s.coreManager.RegisterExecutor(executor.NewXAIExecutor(s.cfg))
|
||||||
default:
|
default:
|
||||||
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||||
if providerKey == "" {
|
if providerKey == "" {
|
||||||
@@ -1156,6 +1159,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
case "kimi":
|
case "kimi":
|
||||||
models = registry.GetKimiModels()
|
models = registry.GetKimiModels()
|
||||||
models = applyExcludedModels(models, excluded)
|
models = applyExcludedModels(models, excluded)
|
||||||
|
case "xai":
|
||||||
|
models = registry.GetXAIModels()
|
||||||
|
models = applyExcludedModels(models, excluded)
|
||||||
default:
|
default:
|
||||||
// Handle OpenAI-compatibility providers by name using config
|
// Handle OpenAI-compatibility providers by name using config
|
||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
package cliproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor"
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEnsureExecutorsForAuth_XAIBindsIndependentExecutor(t *testing.T) {
|
||||||
|
service := &Service{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
coreManager: coreauth.NewManager(nil, nil, nil),
|
||||||
|
}
|
||||||
|
auth := &coreauth.Auth{
|
||||||
|
ID: "xai-auth-1",
|
||||||
|
Provider: "xai",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"auth_kind": "oauth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
service.ensureExecutorsForAuth(auth)
|
||||||
|
resolved, ok := service.coreManager.Executor("xai")
|
||||||
|
if !ok || resolved == nil {
|
||||||
|
t.Fatal("expected xai executor after bind")
|
||||||
|
}
|
||||||
|
if _, isXAI := resolved.(*executor.XAIExecutor); !isXAI {
|
||||||
|
t.Fatalf("executor type = %T, want *executor.XAIExecutor", resolved)
|
||||||
|
}
|
||||||
|
if _, isCodex := resolved.(*executor.CodexAutoExecutor); isCodex {
|
||||||
|
t.Fatal("xai must not bind the codex auto executor")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user