feat(auth): add OAuth2 support for xAI with PKCE and token persistence
- Implemented xAI OAuth2 integration with PKCE (Proof Key for Code Exchange) support. - Added logic for token exchange, refresh, and persistent storage in JSON format. - Created `xai` package with helpers for OAuth discovery, API token handling, and URL building. - Introduced `XAIExecutor` for integrating xAI credentials into runtime HTTP requests. - Added unit tests to validate OAuth flow, token persistence, and endpoint validation.
This commit is contained in:
@@ -0,0 +1,570 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
xaiauth "github.com/router-for-me/CLIProxyAPI/v7/internal/auth/xai"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/runtime/executor/helps"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
)
|
||||
|
||||
var xaiDataTag = []byte("data:")
|
||||
|
||||
// XAIExecutor is a stateless executor for xAI Grok's Responses API.
|
||||
type XAIExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewXAIExecutor creates a new xAI executor.
|
||||
func NewXAIExecutor(cfg *config.Config) *XAIExecutor {
|
||||
return &XAIExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns the provider identifier.
|
||||
func (e *XAIExecutor) Identifier() string {
|
||||
return "xai"
|
||||
}
|
||||
|
||||
// PrepareRequest injects xAI credentials into the outgoing HTTP request.
|
||||
func (e *XAIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
token, _ := xaiCreds(auth)
|
||||
if strings.TrimSpace(token) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects xAI credentials into the request and executes it.
|
||||
func (e *XAIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("xai executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if errPrepare := e.PrepareRequest(httpReq, auth); errPrepare != nil {
|
||||
return nil, errPrepare
|
||||
}
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
token, baseURL := xaiCreds(auth)
|
||||
if baseURL == "" {
|
||||
baseURL = xaiauth.DefaultAPIBaseURL
|
||||
}
|
||||
|
||||
prepared, err := e.prepareResponsesRequest(ctx, req, opts, true)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID)
|
||||
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body)
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
return resp, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
|
||||
outputItemsByIndex := make(map[int64][]byte)
|
||||
var outputItemsFallback [][]byte
|
||||
for _, line := range bytes.Split(data, []byte("\n")) {
|
||||
if !bytes.HasPrefix(line, xaiDataTag) {
|
||||
continue
|
||||
}
|
||||
eventData := bytes.TrimSpace(line[len(xaiDataTag):])
|
||||
switch gjson.GetBytes(eventData, "type").String() {
|
||||
case "response.output_item.done":
|
||||
xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
|
||||
case "response.completed":
|
||||
if detail, ok := helps.ParseCodexUsage(eventData); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
completedData := xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, completedData, ¶m)
|
||||
return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return resp, statusErr{code: http.StatusRequestTimeout, msg: "xai stream error: stream disconnected before response.completed"}
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
token, baseURL := xaiCreds(auth)
|
||||
if baseURL == "" {
|
||||
baseURL = xaiauth.DefaultAPIBaseURL
|
||||
}
|
||||
|
||||
prepared, err := e.prepareResponsesRequest(ctx, req, opts, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), prepared.baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + "/responses"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(prepared.body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyXAIHeaders(httpReq, auth, token, true, prepared.sessionID)
|
||||
e.recordXAIRequest(ctx, auth, url, httpReq.Header.Clone(), prepared.body)
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return nil, errRead
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("xai executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 52_428_800)
|
||||
var param any
|
||||
outputItemsByIndex := make(map[int64][]byte)
|
||||
var outputItemsFallback [][]byte
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, line)
|
||||
translatedLine := bytes.Clone(line)
|
||||
if bytes.HasPrefix(line, xaiDataTag) {
|
||||
eventData := bytes.TrimSpace(line[len(xaiDataTag):])
|
||||
switch gjson.GetBytes(eventData, "type").String() {
|
||||
case "response.output_item.done":
|
||||
xaiCollectOutputItemDone(eventData, outputItemsByIndex, &outputItemsFallback)
|
||||
case "response.completed":
|
||||
if detail, ok := helps.ParseCodexUsage(eventData); ok {
|
||||
reporter.Publish(ctx, detail)
|
||||
}
|
||||
eventData = xaiPatchCompletedOutput(eventData, outputItemsByIndex, outputItemsFallback)
|
||||
translatedLine = append([]byte("data: "), eventData...)
|
||||
}
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, prepared.to, prepared.from, req.Model, prepared.originalPayload, prepared.body, translatedLine, ¶m)
|
||||
for i := range chunks {
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunks[i]}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.PublishFailure(ctx, errScan)
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errScan}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
// CountTokens estimates token count for xAI Responses requests.
|
||||
func (e *XAIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
prepared, err := e.prepareResponsesRequest(ctx, req, opts, false)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
enc, err := tokenizer.Get(tokenizer.Cl100kBase)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
count, err := enc.Count(string(prepared.body))
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("xai executor: token counting failed: %w", err)
|
||||
}
|
||||
usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, count, count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, prepared.to, prepared.from, int64(count), []byte(usageJSON))
|
||||
return cliproxyexecutor.Response{Payload: translated}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes xAI OAuth credentials using the stored refresh token.
|
||||
func (e *XAIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
log.Debugf("xai executor: refresh called")
|
||||
if refreshed, handled, err := helps.RefreshAuthViaHome(ctx, e.cfg, auth); handled {
|
||||
return refreshed, err
|
||||
}
|
||||
if auth == nil {
|
||||
return nil, statusErr{code: http.StatusInternalServerError, msg: "xai executor: auth is nil"}
|
||||
}
|
||||
refreshToken := xaiMetadataString(auth.Metadata, "refresh_token")
|
||||
if refreshToken == "" {
|
||||
return auth, nil
|
||||
}
|
||||
tokenEndpoint := xaiMetadataString(auth.Metadata, "token_endpoint")
|
||||
svc := xaiauth.NewXAIAuthWithProxyURL(e.cfg, auth.ProxyURL)
|
||||
td, err := svc.RefreshTokens(ctx, refreshToken, tokenEndpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
auth.Metadata["type"] = "xai"
|
||||
auth.Metadata["auth_kind"] = "oauth"
|
||||
auth.Metadata["access_token"] = td.AccessToken
|
||||
if td.RefreshToken != "" {
|
||||
auth.Metadata["refresh_token"] = td.RefreshToken
|
||||
}
|
||||
if td.IDToken != "" {
|
||||
auth.Metadata["id_token"] = td.IDToken
|
||||
}
|
||||
if td.TokenType != "" {
|
||||
auth.Metadata["token_type"] = td.TokenType
|
||||
}
|
||||
if td.ExpiresIn > 0 {
|
||||
auth.Metadata["expires_in"] = td.ExpiresIn
|
||||
}
|
||||
if td.Expire != "" {
|
||||
auth.Metadata["expired"] = td.Expire
|
||||
}
|
||||
if td.Email != "" {
|
||||
auth.Metadata["email"] = td.Email
|
||||
}
|
||||
if td.Subject != "" {
|
||||
auth.Metadata["sub"] = td.Subject
|
||||
}
|
||||
if tokenEndpoint != "" {
|
||||
auth.Metadata["token_endpoint"] = tokenEndpoint
|
||||
}
|
||||
if xaiMetadataString(auth.Metadata, "base_url") == "" {
|
||||
auth.Metadata["base_url"] = xaiauth.DefaultAPIBaseURL
|
||||
}
|
||||
auth.Metadata["last_refresh"] = time.Now().UTC().Format(time.RFC3339)
|
||||
if auth.Attributes == nil {
|
||||
auth.Attributes = make(map[string]string)
|
||||
}
|
||||
auth.Attributes["auth_kind"] = "oauth"
|
||||
if strings.TrimSpace(auth.Attributes["base_url"]) == "" {
|
||||
auth.Attributes["base_url"] = xaiauth.DefaultAPIBaseURL
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
type xaiPreparedRequest struct {
|
||||
baseModel string
|
||||
from sdktranslator.Format
|
||||
to sdktranslator.Format
|
||||
originalPayload []byte
|
||||
body []byte
|
||||
sessionID string
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) prepareResponsesRequest(ctx context.Context, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) (*xaiPreparedRequest, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
originalPayloadSource := req.Payload
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
}
|
||||
originalPayload := bytes.Clone(originalPayloadSource)
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
|
||||
var err error
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
requestedModel := helps.PayloadRequestedModel(opts, req.Model)
|
||||
requestPath := helps.PayloadRequestPath(opts)
|
||||
body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel, requestPath)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
body, _ = sjson.SetBytes(body, "stream", stream)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
|
||||
body, _ = sjson.DeleteBytes(body, "safety_identifier")
|
||||
body, _ = sjson.DeleteBytes(body, "stream_options")
|
||||
body = normalizeCodexInstructions(body)
|
||||
body = sanitizeXAIResponsesBody(body, baseModel)
|
||||
|
||||
sessionID := xaiExecutionSessionID(req, opts)
|
||||
if sessionID != "" {
|
||||
body, _ = sjson.SetBytes(body, "prompt_cache_key", sessionID)
|
||||
}
|
||||
|
||||
return &xaiPreparedRequest{
|
||||
baseModel: baseModel,
|
||||
from: from,
|
||||
to: to,
|
||||
originalPayload: originalPayload,
|
||||
body: body,
|
||||
sessionID: sessionID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *XAIExecutor) recordXAIRequest(ctx context.Context, auth *cliproxyauth.Auth, url string, headers http.Header, body []byte) {
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
}
|
||||
|
||||
func xaiCreds(auth *cliproxyauth.Auth) (token, baseURL string) {
|
||||
if auth == nil {
|
||||
return "", ""
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
token = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
baseURL = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
if auth.Metadata != nil {
|
||||
if token == "" {
|
||||
token = xaiMetadataString(auth.Metadata, "access_token")
|
||||
}
|
||||
if baseURL == "" {
|
||||
baseURL = xaiMetadataString(auth.Metadata, "base_url")
|
||||
}
|
||||
}
|
||||
return token, baseURL
|
||||
}
|
||||
|
||||
func applyXAIHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, sessionID string) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
if strings.TrimSpace(token) != "" {
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
} else {
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
r.Header.Set("Connection", "Keep-Alive")
|
||||
if sessionID != "" {
|
||||
r.Header.Set("x-grok-conv-id", sessionID)
|
||||
}
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(r, attrs)
|
||||
}
|
||||
|
||||
func xaiExecutionSessionID(req cliproxyexecutor.Request, opts cliproxyexecutor.Options) string {
|
||||
if value := xaiMetadataString(opts.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" {
|
||||
return value
|
||||
}
|
||||
if value := xaiMetadataString(req.Metadata, cliproxyexecutor.ExecutionSessionMetadataKey); value != "" {
|
||||
return value
|
||||
}
|
||||
if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
|
||||
return strings.TrimSpace(promptCacheKey.String())
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func xaiMetadataString(meta map[string]any, key string) string {
|
||||
if len(meta) == 0 || key == "" {
|
||||
return ""
|
||||
}
|
||||
value, ok := meta[key]
|
||||
if !ok || value == nil {
|
||||
return ""
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
case fmt.Stringer:
|
||||
return strings.TrimSpace(typed.String())
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(typed))
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeXAIResponsesBody(body []byte, model string) []byte {
|
||||
body = removeXAIEncryptedReasoningInclude(body)
|
||||
if !xaiSupportsReasoningEffort(model) {
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning")
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func removeXAIEncryptedReasoningInclude(body []byte) []byte {
|
||||
include := gjson.GetBytes(body, "include")
|
||||
if !include.Exists() || !include.IsArray() {
|
||||
return body
|
||||
}
|
||||
kept := make([]string, 0, len(include.Array()))
|
||||
for _, item := range include.Array() {
|
||||
value := strings.TrimSpace(item.String())
|
||||
if value == "" || value == "reasoning.encrypted_content" {
|
||||
continue
|
||||
}
|
||||
kept = append(kept, value)
|
||||
}
|
||||
body, _ = sjson.SetBytes(body, "include", kept)
|
||||
return body
|
||||
}
|
||||
|
||||
func xaiSupportsReasoningEffort(model string) bool {
|
||||
name := strings.ToLower(strings.TrimSpace(thinking.ParseSuffix(model).ModelName))
|
||||
if idx := strings.LastIndex(name, "/"); idx >= 0 {
|
||||
name = name[idx+1:]
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(name, "grok-3-mini"):
|
||||
return true
|
||||
case strings.HasPrefix(name, "grok-4.20-multi-agent"):
|
||||
return true
|
||||
case strings.HasPrefix(name, "grok-4.3"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func xaiCollectOutputItemDone(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback *[][]byte) {
|
||||
itemResult := gjson.GetBytes(eventData, "item")
|
||||
if !itemResult.Exists() || itemResult.Type != gjson.JSON {
|
||||
return
|
||||
}
|
||||
outputIndexResult := gjson.GetBytes(eventData, "output_index")
|
||||
if outputIndexResult.Exists() {
|
||||
outputItemsByIndex[outputIndexResult.Int()] = []byte(itemResult.Raw)
|
||||
return
|
||||
}
|
||||
*outputItemsFallback = append(*outputItemsFallback, []byte(itemResult.Raw))
|
||||
}
|
||||
|
||||
func xaiPatchCompletedOutput(eventData []byte, outputItemsByIndex map[int64][]byte, outputItemsFallback [][]byte) []byte {
|
||||
outputResult := gjson.GetBytes(eventData, "response.output")
|
||||
shouldPatchOutput := (!outputResult.Exists() || !outputResult.IsArray() || len(outputResult.Array()) == 0) && (len(outputItemsByIndex) > 0 || len(outputItemsFallback) > 0)
|
||||
if !shouldPatchOutput {
|
||||
return eventData
|
||||
}
|
||||
|
||||
indexes := make([]int64, 0, len(outputItemsByIndex))
|
||||
for idx := range outputItemsByIndex {
|
||||
indexes = append(indexes, idx)
|
||||
}
|
||||
sort.Slice(indexes, func(i, j int) bool {
|
||||
return indexes[i] < indexes[j]
|
||||
})
|
||||
|
||||
outputArray := []byte("[]")
|
||||
var buf bytes.Buffer
|
||||
buf.WriteByte('[')
|
||||
wrote := false
|
||||
for _, idx := range indexes {
|
||||
if wrote {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
buf.Write(outputItemsByIndex[idx])
|
||||
wrote = true
|
||||
}
|
||||
for _, item := range outputItemsFallback {
|
||||
if wrote {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
buf.Write(item)
|
||||
wrote = true
|
||||
}
|
||||
buf.WriteByte(']')
|
||||
if wrote {
|
||||
outputArray = buf.Bytes()
|
||||
}
|
||||
|
||||
patched, _ := sjson.SetRawBytes(eventData, "response.output", outputArray)
|
||||
return patched
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v7/internal/translator"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v7/sdk/translator"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestXAIExecutorExecuteShapesResponsesRequest(t *testing.T) {
|
||||
var gotPath string
|
||||
var gotAuth string
|
||||
var gotGrokConvID string
|
||||
var gotOriginator string
|
||||
var gotAccountID string
|
||||
var gotBody []byte
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotPath = r.URL.Path
|
||||
gotAuth = r.Header.Get("Authorization")
|
||||
gotGrokConvID = r.Header.Get("x-grok-conv-id")
|
||||
gotOriginator = r.Header.Get("Originator")
|
||||
gotAccountID = r.Header.Get("Chatgpt-Account-Id")
|
||||
var errRead error
|
||||
gotBody, errRead = io.ReadAll(r.Body)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read body: %v", errRead)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4.3\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}],\"usage\":{\"input_tokens\":1,\"output_tokens\":1,\"total_tokens\":2}}}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewXAIExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
ID: "xai-auth",
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
Metadata: map[string]any{
|
||||
"access_token": "xai-token",
|
||||
"email": "user@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "grok-4.3",
|
||||
Payload: []byte(`{"model":"grok-4.3","input":"hello","include":["reasoning.encrypted_content"],"reasoning":{"effort":"high"}}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||
Stream: false,
|
||||
Metadata: map[string]any{
|
||||
cliproxyexecutor.ExecutionSessionMetadataKey: "conv-xai-1",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
|
||||
if gotPath != "/responses" {
|
||||
t.Fatalf("path = %q, want /responses", gotPath)
|
||||
}
|
||||
if gotAuth != "Bearer xai-token" {
|
||||
t.Fatalf("Authorization = %q, want Bearer xai-token", gotAuth)
|
||||
}
|
||||
if gotGrokConvID != "conv-xai-1" {
|
||||
t.Fatalf("x-grok-conv-id = %q, want conv-xai-1", gotGrokConvID)
|
||||
}
|
||||
if gotOriginator != "" {
|
||||
t.Fatalf("Originator = %q, want empty", gotOriginator)
|
||||
}
|
||||
if gotAccountID != "" {
|
||||
t.Fatalf("Chatgpt-Account-Id = %q, want empty", gotAccountID)
|
||||
}
|
||||
if gjson.GetBytes(gotBody, "prompt_cache_key").String() != "conv-xai-1" {
|
||||
t.Fatalf("prompt_cache_key missing from body: %s", string(gotBody))
|
||||
}
|
||||
if !gjson.GetBytes(gotBody, "stream").Bool() {
|
||||
t.Fatalf("stream = false, want true; body=%s", string(gotBody))
|
||||
}
|
||||
if gjson.GetBytes(gotBody, "reasoning.effort").String() != "high" {
|
||||
t.Fatalf("reasoning.effort = %q, want high; body=%s", gjson.GetBytes(gotBody, "reasoning.effort").String(), string(gotBody))
|
||||
}
|
||||
for _, include := range gjson.GetBytes(gotBody, "include").Array() {
|
||||
if include.String() == "reasoning.encrypted_content" {
|
||||
t.Fatalf("xai request must not ask for encrypted reasoning content: %s", string(gotBody))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestXAIExecutorOmitsUnsupportedReasoningEffort(t *testing.T) {
|
||||
var gotBody []byte
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var errRead error
|
||||
gotBody, errRead = io.ReadAll(r.Body)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read body: %v", errRead)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
_, _ = w.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"object\":\"response\",\"created_at\":0,\"status\":\"completed\",\"model\":\"grok-4\",\"output\":[{\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"output_text\",\"text\":\"ok\"}]}]}}\n\n"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
exec := NewXAIExecutor(&config.Config{})
|
||||
auth := &cliproxyauth.Auth{
|
||||
Provider: "xai",
|
||||
Attributes: map[string]string{
|
||||
"base_url": server.URL,
|
||||
"auth_kind": "oauth",
|
||||
},
|
||||
Metadata: map[string]any{"access_token": "xai-token"},
|
||||
}
|
||||
|
||||
_, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
|
||||
Model: "grok-4",
|
||||
Payload: []byte(`{"model":"grok-4","input":"hello","reasoning":{"effort":"high"}}`),
|
||||
}, cliproxyexecutor.Options{
|
||||
SourceFormat: sdktranslator.FormatOpenAIResponse,
|
||||
Stream: false,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Execute() error = %v", err)
|
||||
}
|
||||
|
||||
if gjson.GetBytes(gotBody, "reasoning").Exists() {
|
||||
t.Fatalf("unsupported xAI model must omit reasoning key: %s", string(gotBody))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user