feat(api, xai): add xAI Grok video model support with API integration

- Introduced new xAI `grok-imagine-video` model for video generation with configurable options (e.g., duration, size, resolution).
- Implemented video-specific API endpoints (`/v1/videos`, `/v1/videos/generations`, `/v1/videos/edits`, `/v1/videos/extensions`), including request validation and model handling.
- Enhanced model registry with `xaiBuiltinVideoModelID` and metadata for video capabilities.
- Added unit tests to validate video model support, request structures, and API response handling.
- Extended `XAIExecutor` to integrate video generation and retrieval via runtime requests.
This commit is contained in:
Luis Pater
2026-05-17 02:53:50 +08:00
parent 2ff9e33e26
commit 53d1fd6c5c
9 changed files with 1130 additions and 2 deletions
+5
View File
@@ -387,6 +387,11 @@ func (s *Server) setupRoutes() {
v1.POST("/completions", openaiHandlers.Completions)
v1.POST("/images/generations", openaiHandlers.ImagesGenerations)
v1.POST("/images/edits", openaiHandlers.ImagesEdits)
v1.POST("/videos", openaiHandlers.VideosCreate)
v1.POST("/videos/generations", openaiHandlers.XAIVideosGenerations)
v1.POST("/videos/edits", openaiHandlers.XAIVideosEdits)
v1.POST("/videos/extensions", openaiHandlers.XAIVideosExtensions)
v1.GET("/videos/:request_id", openaiHandlers.XAIVideosRetrieve)
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
+1
View File
@@ -21,6 +21,7 @@ var aiAPIPrefixes = []string{
"/v1/chat/completions",
"/v1/completions",
"/v1/images",
"/v1/videos",
"/v1/messages",
"/v1/responses",
"/v1beta/models/",
+6
View File
@@ -66,4 +66,10 @@ func TestIsAIAPIPathIncludesImages(t *testing.T) {
if !isAIAPIPath("/v1/images/edits") {
t.Fatalf("expected /v1/images/edits to be treated as AI API path")
}
if !isAIAPIPath("/v1/videos") {
t.Fatalf("expected /v1/videos to be treated as AI API path")
}
if !isAIAPIPath("/v1/videos/video_123") {
t.Fatalf("expected /v1/videos/video_123 to be treated as AI API path")
}
}
+16 -2
View File
@@ -10,6 +10,7 @@ const (
codexBuiltinImageModelID = "gpt-image-2"
xaiBuiltinImageModelID = "grok-imagine-image"
xaiBuiltinImageQualityModelID = "grok-imagine-image-quality"
xaiBuiltinVideoModelID = "grok-imagine-video"
)
// staticModelsJSON mirrors the top-level structure of models.json.
@@ -95,10 +96,10 @@ func WithCodexBuiltins(models []*ModelInfo) []*ModelInfo {
return upsertModelInfos(models, codexBuiltinImageModelInfo())
}
// WithXAIBuiltins injects hard-coded xAI image/video model definitions that should
// not depend on remote models.json updates.
//
// The image, image-quality, and video builtins are all passed to
// upsertModelInfos together with the caller-supplied list.
func WithXAIBuiltins(models []*ModelInfo) []*ModelInfo {
	return upsertModelInfos(models, xaiBuiltinImageModelInfo(), xaiBuiltinImageQualityModelInfo(), xaiBuiltinVideoModelInfo())
}
func codexBuiltinImageModelInfo() *ModelInfo {
@@ -139,6 +140,19 @@ func xaiBuiltinImageQualityModelInfo() *ModelInfo {
}
}
// xaiBuiltinVideoModelInfo returns the hard-coded registry entry for the xAI
// video generation model identified by xaiBuiltinVideoModelID.
func xaiBuiltinVideoModelInfo() *ModelInfo {
	info := ModelInfo{
		ID:          xaiBuiltinVideoModelID,
		Object:      "model",
		Created:     1735689600, // 2025-01-01
		OwnedBy:     "xai",
		Type:        "xai",
		DisplayName: "Grok Imagine Video",
		Name:        xaiBuiltinVideoModelID,
		Description: "xAI Grok video generation model.",
	}
	return &info
}
func upsertModelInfos(models []*ModelInfo, extras ...*ModelInfo) []*ModelInfo {
if len(extras) == 0 {
return models
@@ -33,6 +33,22 @@ func TestCodexStaticModelsIncludeGPT55(t *testing.T) {
assertGPT55ModelInfo(t, "lookup", model)
}
// TestWithXAIBuiltinsAddsVideoModel checks that WithXAIBuiltins, given no
// input models, yields an entry for the video builtin owned by "xai".
func TestWithXAIBuiltinsAddsVideoModel(t *testing.T) {
	matches := 0
	for _, candidate := range WithXAIBuiltins(nil) {
		if candidate == nil || candidate.ID != xaiBuiltinVideoModelID {
			continue
		}
		matches++
		if owner := candidate.OwnedBy; owner != "xai" {
			t.Fatalf("OwnedBy = %q, want xai", owner)
		}
	}
	if matches == 0 {
		t.Fatalf("expected %s builtin model", xaiBuiltinVideoModelID)
	}
}
func findModelInfo(models []*ModelInfo, id string) *ModelInfo {
for _, model := range models {
if model != nil && model.ID == id {
+96
View File
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
"time"
@@ -29,9 +30,15 @@ var xaiDataTag = []byte("data:")
// Handler-type identifiers, upstream endpoint paths, and metadata keys used by
// the xAI executor.
const (
// xaiImageHandlerType is the handler-type identifier for image requests.
xaiImageHandlerType = "openai-image"
// xaiVideoHandlerType is the source-format string that xaiIsVideoRequest
// compares against to recognise video requests.
xaiVideoHandlerType = "openai-video"
// Image endpoint paths; generations is the default image endpoint.
xaiImagesGenerationsPath = "/images/generations"
xaiImagesEditsPath = "/images/edits"
xaiDefaultImageEndpointPath = xaiImagesGenerationsPath
// Video endpoint paths appended to the base URL by executeVideos.
// xaiVideosPath is the prefix used for GET /videos/{request_id} lookups.
xaiVideosGenerationsPath = "/videos/generations"
xaiVideosEditsPath = "/videos/edits"
xaiVideosExtensionsPath = "/videos/extensions"
xaiVideosPath = "/videos"
// xaiIdempotencyKeyMetaKey is the request-metadata key executeVideos reads to
// populate the x-idempotency-key header on POST requests.
xaiIdempotencyKeyMetaKey = "idempotency_key"
)
// XAIExecutor is a stateless executor for xAI Grok's Responses API.
@@ -86,6 +93,9 @@ func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
if endpointPath := xaiImageEndpointPath(opts); endpointPath != "" {
return e.executeImages(ctx, auth, req, endpointPath)
}
if xaiIsVideoRequest(opts) {
return e.executeVideos(ctx, auth, req, opts)
}
token, baseURL := xaiCreds(auth)
if baseURL == "" {
@@ -207,6 +217,71 @@ func (e *XAIExecutor) executeImages(ctx context.Context, auth *cliproxyauth.Auth
return cliproxyexecutor.Response{Payload: data, Headers: httpResp.Header.Clone()}, nil
}
// executeVideos sends a video request upstream and returns the raw response
// body together with a copy of the response headers.
//
// Routing:
//   - By default the payload is POSTed to /videos/generations.
//   - When xaiVideoEndpointPath resolves to the generations, edits, or
//     extensions path, the payload is POSTed to that path.
//   - Otherwise, when the payload carries a non-empty "request_id", the call
//     becomes a body-less GET to /videos/{request_id} (path-escaped).
//
// A response status outside 2xx is returned as a statusErr whose message is
// the upstream body.
func (e *XAIExecutor) executeVideos(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
token, baseURL := xaiCreds(auth)
if baseURL == "" {
baseURL = xaiauth.DefaultAPIBaseURL
}
// Defaults: create a generation by POSTing the payload unchanged.
method := http.MethodPost
endpointPath := xaiVideosGenerationsPath
var body io.Reader = bytes.NewReader(req.Payload)
switch path := xaiVideoEndpointPath(opts); path {
case xaiVideosGenerationsPath, xaiVideosEditsPath, xaiVideosExtensionsPath:
endpointPath = path
default:
// No native endpoint was resolved: a payload carrying request_id is a
// status lookup rather than a create call.
if requestID := strings.TrimSpace(gjson.GetBytes(req.Payload, "request_id").String()); requestID != "" {
method = http.MethodGet
endpointPath = xaiVideosPath + "/" + url.PathEscape(requestID)
body = nil
}
}
requestURL := strings.TrimSuffix(baseURL, "/") + endpointPath
httpReq, err := http.NewRequestWithContext(ctx, method, requestURL, body)
if err != nil {
return resp, err
}
applyXAIHeaders(httpReq, auth, token, false, "")
// The idempotency key is forwarded on POST only; the metadata value takes
// precedence over the x-idempotency-key entry in opts.Headers.
if method == http.MethodPost {
key := xaiMetadataString(opts.Metadata, xaiIdempotencyKeyMetaKey)
if key == "" && opts.Headers != nil {
key = strings.TrimSpace(opts.Headers.Get("x-idempotency-key"))
}
if key != "" {
httpReq.Header.Set("x-idempotency-key", key)
}
}
// The original payload is recorded even for GET lookups, where no body is sent.
e.recordXAIRequest(ctx, auth, requestURL, httpReq.Header.Clone(), req.Payload)
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("xai executor: close response body error: %v", errClose)
}
}()
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
data, err := io.ReadAll(httpResp.Body)
if err != nil {
helps.RecordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
helps.AppendAPIResponseChunk(ctx, e.cfg, data)
// Any non-2xx status is surfaced to the caller with the upstream body as the message.
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return resp, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
return cliproxyexecutor.Response{Payload: data, Headers: httpResp.Header.Clone()}, nil
}
func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
token, baseURL := xaiCreds(auth)
if baseURL == "" {
@@ -525,6 +600,27 @@ func xaiImageEndpointPath(opts cliproxyexecutor.Options) string {
return xaiDefaultImageEndpointPath
}
// xaiIsVideoRequest reports whether the request's source format string equals
// xaiVideoHandlerType.
func xaiIsVideoRequest(opts cliproxyexecutor.Options) bool {
	sourceFormat := opts.SourceFormat.String()
	return sourceFormat == xaiVideoHandlerType
}
// xaiVideoEndpointPath maps the inbound request path stored in the request
// metadata to the matching upstream video endpoint. It returns "" when the
// request is not a video request or the path ends in none of the known
// native suffixes.
func xaiVideoEndpointPath(opts cliproxyexecutor.Options) string {
	if !xaiIsVideoRequest(opts) {
		return ""
	}
	requestPath := xaiMetadataString(opts.Metadata, cliproxyexecutor.RequestPathMetadataKey)
	switch {
	case strings.HasSuffix(requestPath, "/videos/edits"):
		return xaiVideosEditsPath
	case strings.HasSuffix(requestPath, "/videos/extensions"):
		return xaiVideosExtensionsPath
	case strings.HasSuffix(requestPath, "/videos/generations"):
		return xaiVideosGenerationsPath
	default:
		return ""
	}
}
func xaiMetadataString(meta map[string]any, key string) string {
if len(meta) == 0 || key == "" {
return ""
@@ -229,3 +229,168 @@ func TestXAIExecutorExecuteImagesUsesEditsEndpoint(t *testing.T) {
t.Fatalf("path = %q, want /images/edits", gotPath)
}
}
// TestXAIExecutorExecuteVideosCreate verifies that a video request without a
// request path or request_id is POSTed unchanged to /videos/generations, with
// the bearer token and the idempotency key taken from the request metadata.
func TestXAIExecutorExecuteVideosCreate(t *testing.T) {
	var gotPath string
	var gotMethod string
	var gotAuth string
	var gotIdempotencyKey string
	var gotBody []byte

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotMethod = r.Method
		gotAuth = r.Header.Get("Authorization")
		gotIdempotencyKey = r.Header.Get("x-idempotency-key")
		var errRead error
		gotBody, errRead = io.ReadAll(r.Body)
		if errRead != nil {
			// This handler runs on the server's goroutine, where FailNow-style
			// helpers (t.Fatalf) must not be called. Record the failure and
			// answer with an error status so Execute fails on the test goroutine.
			t.Errorf("read body: %v", errRead)
			http.Error(w, "read body failed", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"request_id":"vid_123"}`))
	}))
	defer server.Close()

	exec := NewXAIExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{
		Provider:   "xai",
		Attributes: map[string]string{"base_url": server.URL},
		Metadata:   map[string]any{"access_token": "xai-token"},
	}
	resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "grok-imagine-video",
		Payload: []byte(`{"model":"grok-imagine-video","prompt":"animate","duration":4}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-video"),
		Metadata: map[string]any{
			"idempotency_key": "idem-123",
		},
	})
	if err != nil {
		t.Fatalf("Execute() error = %v", err)
	}
	if gotMethod != http.MethodPost {
		t.Fatalf("method = %q, want POST", gotMethod)
	}
	if gotPath != "/videos/generations" {
		t.Fatalf("path = %q, want /videos/generations", gotPath)
	}
	if gotAuth != "Bearer xai-token" {
		t.Fatalf("Authorization = %q, want Bearer xai-token", gotAuth)
	}
	if gotIdempotencyKey != "idem-123" {
		t.Fatalf("x-idempotency-key = %q, want idem-123", gotIdempotencyKey)
	}
	if string(gotBody) != `{"model":"grok-imagine-video","prompt":"animate","duration":4}` {
		t.Fatalf("body = %s", string(gotBody))
	}
	if gjson.GetBytes(resp.Payload, "request_id").String() != "vid_123" {
		t.Fatalf("payload = %s", string(resp.Payload))
	}
}
// TestXAIExecutorExecuteVideosRetrieve verifies that a payload containing only
// request_id is turned into GET /videos/{request_id} and that the upstream
// body is returned unchanged.
func TestXAIExecutorExecuteVideosRetrieve(t *testing.T) {
	var seenMethod, seenPath string
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		seenPath, seenMethod = r.URL.Path, r.Method
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"status":"done","video":{"url":"https://vidgen.x.ai/video.mp4","duration":6},"model":"grok-imagine-video","progress":100}`))
	}))
	defer upstream.Close()

	executor := NewXAIExecutor(&config.Config{})
	credentials := &cliproxyauth.Auth{
		Provider:   "xai",
		Attributes: map[string]string{"base_url": upstream.URL},
		Metadata:   map[string]any{"access_token": "xai-token"},
	}
	lookup := cliproxyexecutor.Request{
		Model:   "grok-imagine-video",
		Payload: []byte(`{"request_id":"vid_123"}`),
	}
	options := cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-video"),
	}

	resp, err := executor.Execute(context.Background(), credentials, lookup, options)
	if err != nil {
		t.Fatalf("Execute() error = %v", err)
	}
	if seenMethod != http.MethodGet {
		t.Fatalf("method = %q, want GET", seenMethod)
	}
	if seenPath != "/videos/vid_123" {
		t.Fatalf("path = %q, want /videos/vid_123", seenPath)
	}
	if videoURL := gjson.GetBytes(resp.Payload, "video.url").String(); videoURL != "https://vidgen.x.ai/video.mp4" {
		t.Fatalf("payload = %s", string(resp.Payload))
	}
}
// TestXAIExecutorExecuteVideosUsesNativeEndpointFromRequestPath verifies that
// the inbound request path stored in the metadata selects the matching
// upstream video endpoint and that the call is always a POST.
func TestXAIExecutorExecuteVideosUsesNativeEndpointFromRequestPath(t *testing.T) {
	cases := []struct {
		name        string
		requestPath string
		wantPath    string
	}{
		{name: "generations", requestPath: "/v1/videos/generations", wantPath: "/videos/generations"},
		{name: "edits", requestPath: "/v1/videos/edits", wantPath: "/videos/edits"},
		{name: "extensions", requestPath: "/v1/videos/extensions", wantPath: "/videos/extensions"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var seenPath, seenMethod string
			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				seenPath, seenMethod = r.URL.Path, r.Method
				w.Header().Set("Content-Type", "application/json")
				_, _ = w.Write([]byte(`{"request_id":"vid_123"}`))
			}))
			defer upstream.Close()

			executor := NewXAIExecutor(&config.Config{})
			credentials := &cliproxyauth.Auth{
				Provider:   "xai",
				Attributes: map[string]string{"base_url": upstream.URL},
				Metadata:   map[string]any{"access_token": "xai-token"},
			}
			request := cliproxyexecutor.Request{
				Model:   "grok-imagine-video",
				Payload: []byte(`{"model":"grok-imagine-video","prompt":"animate"}`),
			}
			options := cliproxyexecutor.Options{
				SourceFormat: sdktranslator.FromString("openai-video"),
				Metadata: map[string]any{
					cliproxyexecutor.RequestPathMetadataKey: tc.requestPath,
				},
			}

			if _, err := executor.Execute(context.Background(), credentials, request, options); err != nil {
				t.Fatalf("Execute() error = %v", err)
			}
			if seenMethod != http.MethodPost {
				t.Fatalf("method = %q, want POST", seenMethod)
			}
			if seenPath != tc.wantPath {
				t.Fatalf("path = %q, want %s", seenPath, tc.wantPath)
			}
		})
	}
}
@@ -0,0 +1,598 @@
package openai
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Paths, model identifiers, and defaults used by the video handlers.
const (
// videosPath is quoted in the unsupported-model error for the create endpoint.
videosPath = "/v1/videos"
// Native xAI video endpoints, quoted in the unsupported-model error for the
// native handlers.
xaiVideosGenerationsAPI = "/v1/videos/generations"
xaiVideosEditsAPI = "/v1/videos/edits"
xaiVideosExtensionsAPI = "/v1/videos/extensions"
// defaultXAIVideosModel is used when a request omits "model" and is the only
// base model accepted by isXAIVideosModel.
defaultXAIVideosModel = "grok-imagine-video"
// xaiVideosHandlerType is the handler type passed to ExecuteWithAuthManager.
xaiVideosHandlerType = "openai-video"
// Defaults applied when the create request omits seconds or size, and the
// resolution paired with every accepted size.
defaultVideosSeconds = "4"
defaultVideosSize = "720x1280"
defaultVideosResolution = "720p"
// maxXAIVideoReferences caps the number of reference images accepted by
// buildXAIVideosCreateRequest.
maxXAIVideoReferences = 7
)
// xaiVideoCreateMetadata carries the normalized request values that are echoed
// back in the create response, since they are not read from the upstream
// payload there.
type xaiVideoCreateMetadata struct {
	// Model, Prompt, Seconds and Size are copied into the response fields of
	// the same (lower-case) names.
	Model, Prompt, Seconds, Size string
	// CreatedAt is a Unix timestamp in seconds.
	CreatedAt int64
}
// videosModelBase returns the base-model part of model (as split by
// imagesModelParts), trimmed and lower-cased.
func videosModelBase(model string) string {
	_, base := imagesModelParts(model)
	trimmed := strings.TrimSpace(base)
	return strings.ToLower(trimmed)
}
// isXAIVideosModel reports whether model names the xAI video model. The base
// name must equal defaultXAIVideosModel and the prefix must be empty or one of
// "xai", "x-ai", or "grok"; both are compared trimmed and case-insensitively.
func isXAIVideosModel(model string) bool {
	rawPrefix, rawBase := imagesModelParts(model)
	if strings.ToLower(strings.TrimSpace(rawBase)) != defaultXAIVideosModel {
		return false
	}
	switch strings.ToLower(strings.TrimSpace(rawPrefix)) {
	case "", "xai", "x-ai", "grok":
		return true
	default:
		return false
	}
}
// isSupportedVideosModel reports whether the video endpoints accept model.
// Only the xAI video model is accepted.
func isSupportedVideosModel(model string) bool {
	supported := isXAIVideosModel(model)
	return supported
}
// rejectUnsupportedVideosModel writes a 400 invalid_request_error naming the
// create endpoint when model is not supported. It returns true when the
// request was rejected and a response has been written.
func rejectUnsupportedVideosModel(c *gin.Context, model string) bool {
	if isSupportedVideosModel(model) {
		return false
	}
	detail := handlers.ErrorDetail{
		Message: fmt.Sprintf("Model %s is not supported on %s. Use %s.", model, videosPath, defaultXAIVideosModel),
		Type:    "invalid_request_error",
	}
	c.JSON(http.StatusBadRequest, handlers.ErrorResponse{Error: detail})
	return true
}
// rejectUnsupportedNativeVideosModel writes a 400 invalid_request_error naming
// the native xAI video endpoints when model is not supported. It returns true
// when the request was rejected and a response has been written.
func rejectUnsupportedNativeVideosModel(c *gin.Context, model string) bool {
	if isSupportedVideosModel(model) {
		return false
	}
	detail := handlers.ErrorDetail{
		Message: fmt.Sprintf("Model %s is not supported on %s, %s, or %s. Use %s.", model, xaiVideosGenerationsAPI, xaiVideosEditsAPI, xaiVideosExtensionsAPI, defaultXAIVideosModel),
		Type:    "invalid_request_error",
	}
	c.JSON(http.StatusBadRequest, handlers.ErrorResponse{Error: detail})
	return true
}
func canonicalXAIVideosModel(model string) string {
if videosModelBase(model) == defaultXAIVideosModel {
return defaultXAIVideosModel
}
return defaultXAIVideosModel
}
// readVideosCreateRequest returns the create request as JSON bytes. Form
// submissions (multipart or URL-encoded) are converted to JSON; any other
// content type is read as a raw body that must be valid JSON.
func readVideosCreateRequest(c *gin.Context) ([]byte, error) {
	mediaType := strings.ToLower(strings.TrimSpace(c.ContentType()))
	if mediaType == "multipart/form-data" || mediaType == "application/x-www-form-urlencoded" {
		return videosCreateRequestFromForm(c)
	}

	body, err := handlers.ReadRequestBody(c)
	if err != nil {
		return nil, err
	}
	if json.Valid(body) {
		return body, nil
	}
	return nil, fmt.Errorf("body must be valid JSON")
}
// readXAIVideosNativeRequest reads the request body for the native xAI video
// endpoints and rejects it unless it is valid JSON.
func readXAIVideosNativeRequest(c *gin.Context) ([]byte, error) {
	body, err := handlers.ReadRequestBody(c)
	switch {
	case err != nil:
		return nil, err
	case !json.Valid(body):
		return nil, fmt.Errorf("body must be valid JSON")
	}
	return body, nil
}
// videosCreateRequestFromForm converts posted form fields into the JSON shape
// consumed by buildXAIVideosCreateRequest. Empty or whitespace-only values are
// skipped, and reference_image_urls is split on commas into a JSON array.
func videosCreateRequestFromForm(c *gin.Context) ([]byte, error) {
	doc := []byte(`{}`)
	// put writes one string value; sjson errors are ignored as in every other
	// builder in this file.
	put := func(path string, value string) {
		doc, _ = sjson.SetBytes(doc, path, value)
	}

	scalarFields := []string{"model", "prompt", "seconds", "size", "aspect_ratio", "resolution"}
	for _, name := range scalarFields {
		trimmed := strings.TrimSpace(c.PostForm(name))
		if trimmed == "" {
			continue
		}
		put(name, trimmed)
	}

	imageURL := strings.TrimSpace(firstPostForm(c, "input_reference[image_url]", "input_reference.image_url", "image_url"))
	if imageURL != "" {
		put("input_reference.image_url", imageURL)
	}
	fileID := strings.TrimSpace(firstPostForm(c, "input_reference[file_id]", "input_reference.file_id", "file_id"))
	if fileID != "" {
		put("input_reference.file_id", fileID)
	}

	if joined := strings.TrimSpace(c.PostForm("reference_image_urls")); joined != "" {
		for _, piece := range strings.Split(joined, ",") {
			piece = strings.TrimSpace(piece)
			if piece == "" {
				continue
			}
			put("reference_image_urls.-1", piece)
		}
	}
	return doc, nil
}
// firstPostForm returns the untrimmed value of the first listed form key whose
// value is not blank, or "" when every key is blank or missing.
func firstPostForm(c *gin.Context, keys ...string) string {
	for i := range keys {
		candidate := c.PostForm(keys[i])
		if strings.TrimSpace(candidate) == "" {
			continue
		}
		return candidate
	}
	return ""
}
// buildXAIVideosCreateRequest translates a create request (rawJSON) into the
// xAI request body and the metadata echoed back in the create response.
//
// Validation runs in this order, and the first failure is returned: prompt
// present, seconds parseable, size accepted, input image reference valid,
// reference image count within maxXAIVideoReferences, and no mix of a single
// input image with reference images.
func buildXAIVideosCreateRequest(rawJSON []byte, model string) ([]byte, xaiVideoCreateMetadata, error) {
prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())
if prompt == "" {
return nil, xaiVideoCreateMetadata{}, fmt.Errorf("prompt is required")
}
seconds, duration, err := normalizeXAIVideosSeconds(gjson.GetBytes(rawJSON, "seconds").String())
if err != nil {
return nil, xaiVideoCreateMetadata{}, err
}
size, aspectRatio, resolution, err := xaiVideosSizeOptions(gjson.GetBytes(rawJSON, "size").String())
if err != nil {
return nil, xaiVideoCreateMetadata{}, err
}
// Explicit aspect_ratio / resolution values override the ones derived from
// size; unrecognised values are ignored because the fallback is "".
if value := xaiVideosAspectRatio(gjson.GetBytes(rawJSON, "aspect_ratio").String(), ""); value != "" {
aspectRatio = value
}
if value := xaiVideosResolution(gjson.GetBytes(rawJSON, "resolution").String(), ""); value != "" {
resolution = value
}
imageURL, err := xaiVideosInputImageURL(rawJSON)
if err != nil {
return nil, xaiVideoCreateMetadata{}, err
}
referenceImages := collectXAIVideoReferenceImages(rawJSON)
if len(referenceImages) > maxXAIVideoReferences {
return nil, xaiVideoCreateMetadata{}, fmt.Errorf("reference_images supports at most %d images on xAI", maxXAIVideoReferences)
}
if imageURL != "" && len(referenceImages) > 0 {
return nil, xaiVideoCreateMetadata{}, fmt.Errorf("image and reference_images cannot be combined on xAI")
}
// With reference images the duration is capped at 10 seconds; the echoed
// seconds value is kept in sync.
if len(referenceImages) > 0 && duration > 10 {
duration = 10
seconds = "10"
}
req := []byte(`{}`)
req, _ = sjson.SetBytes(req, "model", canonicalXAIVideosModel(model))
req, _ = sjson.SetBytes(req, "prompt", prompt)
// duration is written as a raw JSON number rather than a string.
req, _ = sjson.SetRawBytes(req, "duration", []byte(strconv.FormatInt(duration, 10)))
req, _ = sjson.SetBytes(req, "aspect_ratio", aspectRatio)
req, _ = sjson.SetBytes(req, "resolution", resolution)
if imageURL != "" {
req, _ = sjson.SetBytes(req, "image.url", imageURL)
}
for _, image := range referenceImages {
req, _ = sjson.SetBytes(req, "reference_images.-1.url", image)
}
meta := xaiVideoCreateMetadata{
Model: defaultXAIVideosModel,
Prompt: prompt,
Seconds: seconds,
Size: size,
CreatedAt: time.Now().Unix(),
}
return req, meta, nil
}
// normalizeXAIVideosSeconds parses the requested clip length. A blank value
// falls back to defaultVideosSeconds, and the parsed value is clamped to the
// range 1..15. It returns the clamped value both as a decimal string and as an
// int64, or an error when the input is not a base-10 integer.
func normalizeXAIVideosSeconds(raw string) (string, int64, error) {
	text := strings.TrimSpace(raw)
	if text == "" {
		text = defaultVideosSeconds
	}

	clamped, errParse := strconv.ParseInt(text, 10, 64)
	if errParse != nil {
		return "", 0, fmt.Errorf("seconds must be an integer")
	}

	switch {
	case clamped < 1:
		clamped = 1
	case clamped > 15:
		clamped = 15
	}
	return strconv.FormatInt(clamped, 10), clamped, nil
}
// xaiVideosSizeOptions validates the requested size and derives the aspect
// ratio and resolution paired with it. A blank size falls back to
// defaultVideosSize. The two portrait sizes map to "9:16" and the two
// landscape sizes to "16:9", always with defaultVideosResolution; any other
// size is an error.
func xaiVideosSizeOptions(raw string) (size string, aspectRatio string, resolution string, err error) {
	requested := strings.TrimSpace(raw)
	if requested == "" {
		requested = defaultVideosSize
	}

	portrait := requested == "720x1280" || requested == "1024x1792"
	landscape := requested == "1280x720" || requested == "1792x1024"
	if portrait {
		return requested, "9:16", defaultVideosResolution, nil
	}
	if landscape {
		return requested, "16:9", defaultVideosResolution, nil
	}
	return "", "", "", fmt.Errorf("size must be one of 720x1280, 1280x720, 1024x1792, or 1792x1024")
}
// xaiVideosAspectRatio normalizes an aspect-ratio value. Input is trimmed and
// lower-cased; the aliases "square", "landscape" and "portrait" map to "1:1",
// "16:9" and "9:16", the known ratios are returned as-is, and anything else
// yields fallback.
func xaiVideosAspectRatio(raw string, fallback string) string {
	ratio := strings.ToLower(strings.TrimSpace(raw))
	switch ratio {
	case "square":
		return "1:1"
	case "landscape":
		return "16:9"
	case "portrait":
		return "9:16"
	case "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3":
		// Already in canonical form.
		return ratio
	}
	return fallback
}
// xaiVideosResolution normalizes a resolution value. Input is trimmed and
// lower-cased; "480p" and "720p" are returned as-is and anything else yields
// fallback.
func xaiVideosResolution(raw string, fallback string) string {
	if res := strings.ToLower(strings.TrimSpace(raw)); res == "480p" || res == "720p" {
		return res
	}
	return fallback
}
// xaiVideosInputImageURL extracts the single input image URL from a create
// request. Sources are consulted in this order:
//
//  1. input_reference.image_url (input_reference.file_id is rejected, alone or
//     combined with image_url);
//  2. image, either as a plain string or as an object with url or
//     image_url.url;
//  3. top-level image_url.
//
// It returns "" with a nil error when no image is supplied.
func xaiVideosInputImageURL(rawJSON []byte) (string, error) {
inputRef := gjson.GetBytes(rawJSON, "input_reference")
if inputRef.Exists() {
imageURL := strings.TrimSpace(inputRef.Get("image_url").String())
fileID := strings.TrimSpace(inputRef.Get("file_id").String())
if imageURL != "" && fileID != "" {
return "", fmt.Errorf("input_reference must provide exactly one of image_url or file_id")
}
if fileID != "" {
return "", fmt.Errorf("input_reference.file_id is not supported for xAI video generation; use input_reference.image_url")
}
if imageURL != "" {
return imageURL, nil
}
// An input_reference with neither field set falls through to the other sources.
}
image := gjson.GetBytes(rawJSON, "image")
if image.Exists() {
// A string-typed image is returned as-is (trimmed), even when it is empty.
if image.Type == gjson.String {
return strings.TrimSpace(image.String()), nil
}
if value := strings.TrimSpace(image.Get("url").String()); value != "" {
return value, nil
}
if value := strings.TrimSpace(image.Get("image_url.url").String()); value != "" {
return value, nil
}
}
return strings.TrimSpace(gjson.GetBytes(rawJSON, "image_url").String()), nil
}
// collectXAIVideoReferenceImages gathers reference image URLs from the
// reference_images and reference_image_urls arrays, in that order. Each array
// element may be a plain string or an object carrying url or image_url.url.
// Values are trimmed and blank ones are dropped; the result is never nil.
func collectXAIVideoReferenceImages(rawJSON []byte) []string {
	refs := make([]string, 0)
	for _, field := range []string{"reference_images", "reference_image_urls"} {
		list := gjson.GetBytes(rawJSON, field)
		if !list.IsArray() {
			continue
		}
		for _, item := range list.Array() {
			var candidate string
			switch {
			case item.Type == gjson.String:
				candidate = item.String()
			case item.Get("url").String() != "":
				// A non-empty url wins even if it trims down to nothing.
				candidate = item.Get("url").String()
			default:
				candidate = item.Get("image_url.url").String()
			}
			if candidate = strings.TrimSpace(candidate); candidate != "" {
				refs = append(refs, candidate)
			}
		}
	}
	return refs
}
// buildVideosCreateAPIResponseFromXAI converts the upstream create response
// into a "video" object. The id comes from the upstream request_id (or id as a
// fallback); model, prompt, seconds, size and created_at come from meta. It
// returns an error when the upstream payload carries neither identifier.
func buildVideosCreateAPIResponseFromXAI(payload []byte, meta xaiVideoCreateMetadata) ([]byte, error) {
requestID := strings.TrimSpace(gjson.GetBytes(payload, "request_id").String())
if requestID == "" {
requestID = strings.TrimSpace(gjson.GetBytes(payload, "id").String())
}
if requestID == "" {
return nil, fmt.Errorf("xAI video response did not include request_id")
}
// Start from a queued job at 0% progress; both are overridden below when the
// upstream payload reports them.
out := []byte(`{"object":"video","progress":0,"status":"queued"}`)
out, _ = sjson.SetBytes(out, "id", requestID)
out, _ = sjson.SetBytes(out, "model", meta.Model)
out, _ = sjson.SetBytes(out, "prompt", meta.Prompt)
out, _ = sjson.SetBytes(out, "seconds", meta.Seconds)
out, _ = sjson.SetBytes(out, "size", meta.Size)
out, _ = sjson.SetBytes(out, "created_at", meta.CreatedAt)
// Unrecognised upstream statuses map to "" and leave the default in place.
if status := openAIVideoStatus(gjson.GetBytes(payload, "status").String()); status != "" {
out, _ = sjson.SetBytes(out, "status", status)
}
if progress := gjson.GetBytes(payload, "progress"); progress.Exists() {
out, _ = sjson.SetRawBytes(out, "progress", []byte(progress.Raw))
}
return out, nil
}
// buildVideosRetrieveAPIResponseFromXAI converts the upstream status payload
// into a "video" object identified by videoID. The model falls back to
// fallbackModel when the payload has none; status, progress, seconds, video,
// usage and error are copied only when present upstream. The returned error is
// always nil in the current implementation.
func buildVideosRetrieveAPIResponseFromXAI(videoID string, payload []byte, fallbackModel string) ([]byte, error) {
out := []byte(`{"object":"video"}`)
out, _ = sjson.SetBytes(out, "id", videoID)
model := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
if model == "" {
model = fallbackModel
}
out, _ = sjson.SetBytes(out, "model", model)
// Unrecognised upstream statuses map to "" and the status field is omitted.
if status := openAIVideoStatus(gjson.GetBytes(payload, "status").String()); status != "" {
out, _ = sjson.SetBytes(out, "status", status)
}
if progress := gjson.GetBytes(payload, "progress"); progress.Exists() {
out, _ = sjson.SetRawBytes(out, "progress", []byte(progress.Raw))
}
// seconds is emitted as a string taken from video.duration.
if duration := gjson.GetBytes(payload, "video.duration"); duration.Exists() {
out, _ = sjson.SetBytes(out, "seconds", duration.String())
}
// Nested objects are passed through verbatim when they are valid JSON.
if video := gjson.GetBytes(payload, "video"); video.Exists() && json.Valid([]byte(video.Raw)) {
out, _ = sjson.SetRawBytes(out, "video", []byte(video.Raw))
}
if usage := gjson.GetBytes(payload, "usage"); usage.Exists() && json.Valid([]byte(usage.Raw)) {
out, _ = sjson.SetRawBytes(out, "usage", []byte(usage.Raw))
}
if errPayload := gjson.GetBytes(payload, "error"); errPayload.Exists() && json.Valid([]byte(errPayload.Raw)) {
out, _ = sjson.SetRawBytes(out, "error", []byte(errPayload.Raw))
}
return out, nil
}
// openAIVideoStatus maps an upstream status string (trimmed, case-insensitive)
// to one of "queued", "in_progress", "completed" or "failed". Unknown or empty
// statuses yield "".
func openAIVideoStatus(status string) string {
	normalized := strings.ToLower(strings.TrimSpace(status))
	switch normalized {
	case "queued", "in_progress", "completed", "failed":
		// Already one of the canonical values.
		return normalized
	case "pending":
		return "queued"
	case "processing", "running":
		return "in_progress"
	case "done", "succeeded", "success":
		return "completed"
	case "error", "expired", "cancelled", "canceled":
		return "failed"
	}
	return ""
}
// VideosCreate handles the video create endpoint. It reads the request (JSON
// or form), defaults and validates the model, translates the request into the
// xAI shape, and forwards it upstream. Client-side failures are answered with
// a 400 invalid_request_error.
func (h *OpenAIAPIHandler) VideosCreate(c *gin.Context) {
	// badRequest writes the 400 envelope shared by the failures below.
	badRequest := func(cause error) {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", cause),
				Type:    "invalid_request_error",
			},
		})
	}

	body, errRead := readVideosCreateRequest(c)
	if errRead != nil {
		badRequest(errRead)
		return
	}

	requestedModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
	if requestedModel == "" {
		requestedModel = defaultXAIVideosModel
	}
	if rejectUnsupportedVideosModel(c, requestedModel) {
		return
	}

	upstreamReq, meta, errBuild := buildXAIVideosCreateRequest(body, requestedModel)
	if errBuild != nil {
		badRequest(errBuild)
		return
	}
	h.collectXAIVideosCreate(c, upstreamReq, meta)
}
// XAIVideosGenerations handles the native video generations endpoint by
// delegating to handleXAIVideosNativePost.
func (h *OpenAIAPIHandler) XAIVideosGenerations(c *gin.Context) {
h.handleXAIVideosNativePost(c)
}
// XAIVideosEdits handles the native video edits endpoint by delegating to
// handleXAIVideosNativePost.
func (h *OpenAIAPIHandler) XAIVideosEdits(c *gin.Context) {
h.handleXAIVideosNativePost(c)
}
// XAIVideosExtensions handles the native video extensions endpoint by
// delegating to handleXAIVideosNativePost.
func (h *OpenAIAPIHandler) XAIVideosExtensions(c *gin.Context) {
h.handleXAIVideosNativePost(c)
}
// handleXAIVideosNativePost is the shared implementation of the native xAI
// video POST endpoints. The JSON body is validated, the model is defaulted and
// checked, and the body is then forwarded upstream unchanged.
func (h *OpenAIAPIHandler) handleXAIVideosNativePost(c *gin.Context) {
	body, errRead := readXAIVideosNativeRequest(c)
	if errRead != nil {
		detail := handlers.ErrorDetail{
			Message: fmt.Sprintf("Invalid request: %v", errRead),
			Type:    "invalid_request_error",
		}
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{Error: detail})
		return
	}

	requestedModel := strings.TrimSpace(gjson.GetBytes(body, "model").String())
	if requestedModel == "" {
		requestedModel = defaultXAIVideosModel
	}
	if !rejectUnsupportedNativeVideosModel(c, requestedModel) {
		h.collectXAIVideosNative(c, body, requestedModel)
	}
}
// XAIVideosRetrieve handles the video status lookup. The identifier is read
// from the request_id route parameter, falling back to video_id; it is wrapped
// in a {"request_id": ...} payload and forwarded, and the upstream response is
// returned as-is. A missing identifier yields a 400 invalid_request_error.
func (h *OpenAIAPIHandler) XAIVideosRetrieve(c *gin.Context) {
	id := ""
	for _, param := range []string{"request_id", "video_id"} {
		if id = strings.TrimSpace(c.Param(param)); id != "" {
			break
		}
	}
	if id == "" {
		detail := handlers.ErrorDetail{
			Message: "Invalid request: request_id is required",
			Type:    "invalid_request_error",
		}
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{Error: detail})
		return
	}

	lookup, _ := sjson.SetBytes([]byte(`{}`), "request_id", id)
	h.collectXAIVideosNative(c, lookup, defaultXAIVideosModel)
}
// VideosRetrieve looks up a video by the video_id route parameter and returns
// it as a "video" object built by buildVideosRetrieveAPIResponseFromXAI. A
// missing video_id yields a 400 invalid_request_error; upstream failures are
// written through WriteErrorResponse.
func (h *OpenAIAPIHandler) VideosRetrieve(c *gin.Context) {
videoID := strings.TrimSpace(c.Param("video_id"))
if videoID == "" {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Invalid request: video_id is required",
Type: "invalid_request_error",
},
})
return
}
// The executor treats a payload carrying request_id as a status lookup.
payload := []byte(`{}`)
payload, _ = sjson.SetBytes(payload, "request_id", videoID)
c.Header("Content-Type", "application/json")
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiVideosHandlerType, defaultXAIVideosModel, payload, "")
// Keep-alive is stopped before anything is written to the response.
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
if errMsg.Error != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
}
out, err := buildVideosRetrieveAPIResponseFromXAI(videoID, resp, defaultXAIVideosModel)
if err != nil {
// A response that cannot be translated is reported as 502 Bad Gateway.
errMsg := &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err}
h.WriteErrorResponse(c, errMsg)
cliCancel(err)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(out)
cliCancel(nil)
}
// collectXAIVideosNative executes rawJSON against the video handler type for
// the given model and writes the upstream response body back unchanged,
// together with the upstream headers. Failures are written through
// WriteErrorResponse.
func (h *OpenAIAPIHandler) collectXAIVideosNative(c *gin.Context, rawJSON []byte, model string) {
c.Header("Content-Type", "application/json")
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiVideosHandlerType, model, rawJSON, "")
// Keep-alive is stopped before anything is written to the response.
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
// The cancel function receives the underlying error when there is one.
if errMsg.Error != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel(nil)
}
// collectXAIVideosCreate executes the translated create request and writes
// back the "video" object built by buildVideosCreateAPIResponseFromXAI from
// the upstream response and meta. Upstream failures are written through
// WriteErrorResponse.
func (h *OpenAIAPIHandler) collectXAIVideosCreate(c *gin.Context, xaiReq []byte, meta xaiVideoCreateMetadata) {
c.Header("Content-Type", "application/json")
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiVideosHandlerType, meta.Model, xaiReq, "")
// Keep-alive is stopped before anything is written to the response.
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
if errMsg.Error != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
}
out, err := buildVideosCreateAPIResponseFromXAI(resp, meta)
if err != nil {
// An upstream response without an identifier is reported as 502 Bad Gateway.
errMsg := &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err}
h.WriteErrorResponse(c, errMsg)
cliCancel(err)
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(out)
cliCancel(nil)
}
@@ -0,0 +1,227 @@
package openai
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
// performVideosEndpointRequest mounts handler on a fresh gin router at
// endpointPath, sends one request with the given method, content type and
// body, and returns the recorded response. The handler is registered as GET
// when method is GET and as POST for every other method.
func performVideosEndpointRequest(t *testing.T, method string, endpointPath string, contentType string, body io.Reader, handler gin.HandlerFunc) *httptest.ResponseRecorder {
	t.Helper()
	gin.SetMode(gin.TestMode)

	engine := gin.New()
	if method == http.MethodGet {
		engine.GET(endpointPath, handler)
	} else {
		engine.POST(endpointPath, handler)
	}

	request := httptest.NewRequest(method, endpointPath, body)
	if contentType != "" {
		request.Header.Set("Content-Type", contentType)
	}

	recorder := httptest.NewRecorder()
	engine.ServeHTTP(recorder, request)
	return recorder
}
// TestVideosModelValidationAllowsXAIVideoModel checks that the video model is
// accepted bare or with an xAI-style prefix, and that other models and foreign
// prefixes are rejected.
func TestVideosModelValidationAllowsXAIVideoModel(t *testing.T) {
	accepted := []string{"grok-imagine-video", "xai/grok-imagine-video", "x-ai/grok-imagine-video", "grok/grok-imagine-video"}
	for i := range accepted {
		if isSupportedVideosModel(accepted[i]) {
			continue
		}
		t.Fatalf("expected %s to be supported", accepted[i])
	}

	if isSupportedVideosModel("sora-2") {
		t.Fatal("expected sora-2 to be rejected")
	}
	if isSupportedVideosModel("codex/grok-imagine-video") {
		t.Fatal("expected codex/grok-imagine-video to be rejected")
	}
}
// TestBuildXAIVideosCreateRequest feeds a create request carrying a prefixed
// model, seconds, size and an image_url input reference through
// buildXAIVideosCreateRequest and checks the produced payload: the model
// equals defaultXAIVideosModel, seconds "8" becomes duration 8, size 1280x720
// becomes aspect_ratio 16:9 with resolution 720p, and the reference URL lands
// in image.url. The returned metadata must keep the original prompt, seconds
// and size values.
func TestBuildXAIVideosCreateRequest(t *testing.T) {
	input := []byte(`{"model":"xai/grok-imagine-video","prompt":"a cat playing piano","seconds":"8","size":"1280x720","input_reference":{"image_url":"https://example.com/cat.png"}}`)

	built, meta, err := buildXAIVideosCreateRequest(input, "xai/grok-imagine-video")
	if err != nil {
		t.Fatalf("buildXAIVideosCreateRequest() error = %v", err)
	}

	// field looks up a path in the generated upstream request.
	field := func(path string) gjson.Result {
		return gjson.GetBytes(built, path)
	}

	if got := field("model").String(); got != defaultXAIVideosModel {
		t.Fatalf("model = %q, want %s", got, defaultXAIVideosModel)
	}
	if got := field("prompt").String(); got != "a cat playing piano" {
		t.Fatalf("prompt = %q", got)
	}
	if got := field("duration").Int(); got != 8 {
		t.Fatalf("duration = %d, want 8", got)
	}
	if got := field("aspect_ratio").String(); got != "16:9" {
		t.Fatalf("aspect_ratio = %q, want 16:9", got)
	}
	if got := field("resolution").String(); got != "720p" {
		t.Fatalf("resolution = %q, want 720p", got)
	}
	if got := field("image.url").String(); got != "https://example.com/cat.png" {
		t.Fatalf("image.url = %q", got)
	}

	metaMatches := meta.Seconds == "8" && meta.Size == "1280x720" && meta.Prompt == "a cat playing piano"
	if !metaMatches {
		t.Fatalf("unexpected meta: %+v", meta)
	}
}
// TestBuildXAIVideosCreateRequestAllowsCustomSeconds checks that a seconds
// value of "6" is accepted by buildXAIVideosCreateRequest, is emitted as
// duration 6 in the built request, and is kept as "6" in the metadata.
func TestBuildXAIVideosCreateRequestAllowsCustomSeconds(t *testing.T) {
	input := []byte(`{"model":"grok-imagine-video","prompt":"a cat playing piano","seconds":"6"}`)

	built, meta, err := buildXAIVideosCreateRequest(input, "grok-imagine-video")
	if err != nil {
		t.Fatalf("buildXAIVideosCreateRequest() error = %v", err)
	}

	duration := gjson.GetBytes(built, "duration").Int()
	if duration != 6 {
		t.Fatalf("duration = %d, want 6", duration)
	}
	if meta.Seconds != "6" {
		t.Fatalf("meta seconds = %q, want 6", meta.Seconds)
	}
}
// TestBuildXAIVideosCreateRequestRejectsFileIDReference checks that a request
// whose input_reference carries a file_id makes buildXAIVideosCreateRequest
// return an error mentioning that input_reference.file_id is not supported.
func TestBuildXAIVideosCreateRequestRejectsFileIDReference(t *testing.T) {
	input := []byte(`{"prompt":"animate","input_reference":{"file_id":"file_123"}}`)

	_, _, err := buildXAIVideosCreateRequest(input, defaultXAIVideosModel)
	if err != nil && strings.Contains(err.Error(), "input_reference.file_id is not supported") {
		return
	}
	t.Fatalf("error = %v, want unsupported file_id error", err)
}
// TestBuildVideosCreateAPIResponseFromXAI converts an upstream payload that
// only carries a request_id and checks the resulting response: the id is the
// upstream request_id, object is "video", status is "queued", and created_at
// comes from the supplied metadata.
func TestBuildVideosCreateAPIResponseFromXAI(t *testing.T) {
	upstream := []byte(`{"request_id":"vid_123"}`)
	meta := xaiVideoCreateMetadata{
		Model:     defaultXAIVideosModel,
		Prompt:    "animate",
		Seconds:   "4",
		Size:      "720x1280",
		CreatedAt: 123,
	}

	converted, err := buildVideosCreateAPIResponseFromXAI(upstream, meta)
	if err != nil {
		t.Fatalf("buildVideosCreateAPIResponseFromXAI() error = %v", err)
	}

	// field looks up a path in the converted response.
	field := func(path string) gjson.Result {
		return gjson.GetBytes(converted, path)
	}

	if got := field("id").String(); got != "vid_123" {
		t.Fatalf("id = %q, want vid_123", got)
	}
	if got := field("object").String(); got != "video" {
		t.Fatalf("object = %q, want video", got)
	}
	if got := field("status").String(); got != "queued" {
		t.Fatalf("status = %q, want queued", got)
	}
	if got := field("created_at").Int(); got != 123 {
		t.Fatalf("created_at = %d, want 123", got)
	}
}
// TestBuildVideosRetrieveAPIResponseFromXAI converts a finished upstream
// retrieve payload and checks the resulting response: the id is the request id
// passed in, upstream status "done" is reported as "completed", the video
// duration 6 is exposed as seconds "6", the video URL is preserved, and a
// usage field is present.
func TestBuildVideosRetrieveAPIResponseFromXAI(t *testing.T) {
	upstream := []byte(`{"status":"done","video":{"url":"https://vidgen.x.ai/video.mp4","duration":6,"respect_moderation":true},"model":"grok-imagine-video","usage":{"cost_in_usd_ticks":500000000},"progress":100}`)

	converted, err := buildVideosRetrieveAPIResponseFromXAI("vid_123", upstream, defaultXAIVideosModel)
	if err != nil {
		t.Fatalf("buildVideosRetrieveAPIResponseFromXAI() error = %v", err)
	}

	// field looks up a path in the converted response.
	field := func(path string) gjson.Result {
		return gjson.GetBytes(converted, path)
	}

	if got := field("id").String(); got != "vid_123" {
		t.Fatalf("id = %q, want vid_123", got)
	}
	if got := field("status").String(); got != "completed" {
		t.Fatalf("status = %q, want completed", got)
	}
	if got := field("seconds").String(); got != "6" {
		t.Fatalf("seconds = %q, want 6", got)
	}
	if got := field("video.url").String(); got != "https://vidgen.x.ai/video.mp4" {
		t.Fatalf("video.url = %q", got)
	}
	if usage := field("usage"); !usage.Exists() {
		t.Fatalf("usage missing: %s", string(converted))
	}
}
// TestVideosCreateRejectsUnsupportedModel posts a create request for sora-2 to
// VideosCreate on a zero-value handler and expects a 400 response whose
// error.message names the rejected model, the videos path, and
// defaultXAIVideosModel as the model to use.
func TestVideosCreateRejectsUnsupportedModel(t *testing.T) {
	h := &OpenAIAPIHandler{}
	payload := strings.NewReader(`{"model":"sora-2","prompt":"make a video"}`)

	recorder := performVideosEndpointRequest(t, http.MethodPost, videosPath, "application/json", payload, h.VideosCreate)
	if recorder.Code != http.StatusBadRequest {
		t.Fatalf("status = %d, want %d: %s", recorder.Code, http.StatusBadRequest, recorder.Body.String())
	}

	expectedMessage := "Model sora-2 is not supported on " + videosPath + ". Use " + defaultXAIVideosModel + "."
	message := gjson.GetBytes(recorder.Body.Bytes(), "error.message").String()
	if message != expectedMessage {
		t.Fatalf("error message = %q, want %q", message, expectedMessage)
	}
}
// TestXAIVideosNativeRejectsUnsupportedModel posts a generation request for
// sora-2 to XAIVideosGenerations on a zero-value handler and expects a 400
// response whose error.message lists the generations, edits and extensions
// paths and names defaultXAIVideosModel as the model to use.
func TestXAIVideosNativeRejectsUnsupportedModel(t *testing.T) {
	h := &OpenAIAPIHandler{}
	payload := strings.NewReader(`{"model":"sora-2","prompt":"make a video"}`)

	recorder := performVideosEndpointRequest(t, http.MethodPost, xaiVideosGenerationsAPI, "application/json", payload, h.XAIVideosGenerations)
	if recorder.Code != http.StatusBadRequest {
		t.Fatalf("status = %d, want %d: %s", recorder.Code, http.StatusBadRequest, recorder.Body.String())
	}

	expectedMessage := "Model sora-2 is not supported on " + xaiVideosGenerationsAPI + ", " + xaiVideosEditsAPI + ", or " + xaiVideosExtensionsAPI + ". Use " + defaultXAIVideosModel + "."
	message := gjson.GetBytes(recorder.Body.Bytes(), "error.message").String()
	if message != expectedMessage {
		t.Fatalf("error message = %q, want %q", message, expectedMessage)
	}
}
// TestXAIVideosNativeRejectsInvalidJSON posts a truncated JSON body to
// XAIVideosEdits on a zero-value handler and expects a 400 response whose
// error.type is invalid_request_error.
func TestXAIVideosNativeRejectsInvalidJSON(t *testing.T) {
	h := &OpenAIAPIHandler{}
	truncated := strings.NewReader(`{"model":`)

	recorder := performVideosEndpointRequest(t, http.MethodPost, xaiVideosEditsAPI, "application/json", truncated, h.XAIVideosEdits)
	if recorder.Code != http.StatusBadRequest {
		t.Fatalf("status = %d, want %d: %s", recorder.Code, http.StatusBadRequest, recorder.Body.String())
	}

	errType := gjson.GetBytes(recorder.Body.Bytes(), "error.type").String()
	if errType != "invalid_request_error" {
		t.Fatalf("error type = %q, want invalid_request_error", errType)
	}
}
// TestVideosCreateFormRequest submits a URL-encoded form containing a
// bracketed input_reference[image_url] field and checks that the JSON produced
// from the form exposes that value at input_reference.image_url.
func TestVideosCreateFormRequest(t *testing.T) {
	const form = "model=grok-imagine-video&prompt=make+a+video&seconds=4&size=720x1280&input_reference%5Bimage_url%5D=https%3A%2F%2Fexample.com%2Fa.png"

	converted, err := videosCreateRequestFromFormContext(form)
	if err != nil {
		t.Fatalf("videosCreateRequestFromFormContext() error = %v", err)
	}

	imageURL := gjson.GetBytes(converted, "input_reference.image_url").String()
	if imageURL != "https://example.com/a.png" {
		t.Fatalf("input_reference.image_url = %q", imageURL)
	}
}
// videosCreateRequestFromFormContext posts body as an
// application/x-www-form-urlencoded request to videosPath on a fresh gin
// engine and returns whatever videosCreateRequestFromForm produced for the
// resulting gin context. The recorded HTTP response itself is not inspected.
func videosCreateRequestFromFormContext(body string) ([]byte, error) {
	gin.SetMode(gin.TestMode)

	var (
		converted []byte
		convErr   error
	)
	engine := gin.New()
	engine.POST(videosPath, func(c *gin.Context) {
		// Capture the conversion result so it can be returned to the caller.
		converted, convErr = videosCreateRequestFromForm(c)
	})

	request := httptest.NewRequest(http.MethodPost, videosPath, strings.NewReader(body))
	request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	engine.ServeHTTP(httptest.NewRecorder(), request)

	return converted, convErr
}