From 53d1fd6c5c8f8703458501e8b8bf5d23408caead Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 17 May 2026 02:53:50 +0800 Subject: [PATCH] feat(api, xai): add xAI Grok video model support with API integration - Introduced new xAI `grok-imagine-video` model for video generation with configurable options (e.g., duration, size, resolution). - Implemented video-specific API endpoints (`/v1/videos`, `/v1/videos/generations`, `/v1/videos/edits`, `/v1/videos/extensions`), including request validation and model handling. - Enhanced model registry with `xaiBuiltinVideoModelID` and metadata for video capabilities. - Added unit tests to validate video model support, request structures, and API response handling. - Extended `XAIExecutor` to integrate video generation and retrieval via runtime requests. --- internal/api/server.go | 5 + internal/logging/gin_logger.go | 1 + internal/logging/gin_logger_test.go | 6 + internal/registry/model_definitions.go | 18 +- internal/registry/model_definitions_test.go | 16 + internal/runtime/executor/xai_executor.go | 96 +++ .../runtime/executor/xai_executor_test.go | 165 +++++ .../handlers/openai/openai_videos_handlers.go | 598 ++++++++++++++++++ .../openai/openai_videos_handlers_test.go | 227 +++++++ 9 files changed, 1130 insertions(+), 2 deletions(-) create mode 100644 sdk/api/handlers/openai/openai_videos_handlers.go create mode 100644 sdk/api/handlers/openai/openai_videos_handlers_test.go diff --git a/internal/api/server.go b/internal/api/server.go index 499c4acb..110a827d 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -387,6 +387,11 @@ func (s *Server) setupRoutes() { v1.POST("/completions", openaiHandlers.Completions) v1.POST("/images/generations", openaiHandlers.ImagesGenerations) v1.POST("/images/edits", openaiHandlers.ImagesEdits) + v1.POST("/videos", openaiHandlers.VideosCreate) + v1.POST("/videos/generations", openaiHandlers.XAIVideosGenerations) + v1.POST("/videos/edits", openaiHandlers.XAIVideosEdits) + v1.POST("/videos/extensions", openaiHandlers.XAIVideosExtensions) + v1.GET("/videos/:request_id", openaiHandlers.XAIVideosRetrieve) v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket) diff --git a/internal/logging/gin_logger.go b/internal/logging/gin_logger.go index 6e3559b8..80821376 100644 --- a/internal/logging/gin_logger.go +++ b/internal/logging/gin_logger.go @@ -21,6 +21,7 @@ var aiAPIPrefixes = []string{ "/v1/chat/completions", "/v1/completions", "/v1/images", + "/v1/videos", "/v1/messages", "/v1/responses", "/v1beta/models/", diff --git a/internal/logging/gin_logger_test.go b/internal/logging/gin_logger_test.go index 9bd3ddfb..73480dec 100644 --- a/internal/logging/gin_logger_test.go +++ b/internal/logging/gin_logger_test.go @@ -66,4 +66,10 @@ func TestIsAIAPIPathIncludesImages(t *testing.T) { if !isAIAPIPath("/v1/images/edits") { t.Fatalf("expected /v1/images/edits to be treated as AI API path") } + if !isAIAPIPath("/v1/videos") { + t.Fatalf("expected /v1/videos to be treated as AI API path") + } + if !isAIAPIPath("/v1/videos/video_123") { + t.Fatalf("expected /v1/videos/video_123 to be treated as AI API path") + } } diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index fcb5827d..f160325f 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -10,6 +10,7 @@ const ( codexBuiltinImageModelID = "gpt-image-2" xaiBuiltinImageModelID = "grok-imagine-image" xaiBuiltinImageQualityModelID = "grok-imagine-image-quality" + xaiBuiltinVideoModelID = "grok-imagine-video" ) // staticModelsJSON mirrors the top-level structure of models.json. @@ -95,10 +96,10 @@ func WithCodexBuiltins(models []*ModelInfo) []*ModelInfo { return upsertModelInfos(models, codexBuiltinImageModelInfo()) } -// WithXAIBuiltins injects hard-coded xAI image model definitions that should +// WithXAIBuiltins injects hard-coded xAI image/video model definitions that should // not depend on remote models.json updates. func WithXAIBuiltins(models []*ModelInfo) []*ModelInfo { - return upsertModelInfos(models, xaiBuiltinImageModelInfo(), xaiBuiltinImageQualityModelInfo()) + return upsertModelInfos(models, xaiBuiltinImageModelInfo(), xaiBuiltinImageQualityModelInfo(), xaiBuiltinVideoModelInfo()) } func codexBuiltinImageModelInfo() *ModelInfo { @@ -139,6 +140,19 @@ func xaiBuiltinImageQualityModelInfo() *ModelInfo { } } +func xaiBuiltinVideoModelInfo() *ModelInfo { + return &ModelInfo{ + ID: xaiBuiltinVideoModelID, + Object: "model", + Created: 1735689600, // 2025-01-01 + OwnedBy: "xai", + Type: "xai", + DisplayName: "Grok Imagine Video", + Name: xaiBuiltinVideoModelID, + Description: "xAI Grok video generation model.", + } +} + func upsertModelInfos(models []*ModelInfo, extras ...*ModelInfo) []*ModelInfo { if len(extras) == 0 { return models diff --git a/internal/registry/model_definitions_test.go b/internal/registry/model_definitions_test.go index bb2fc460..f7ce02bc 100644 --- a/internal/registry/model_definitions_test.go +++ b/internal/registry/model_definitions_test.go @@ -33,6 +33,22 @@ func TestCodexStaticModelsIncludeGPT55(t *testing.T) { assertGPT55ModelInfo(t, "lookup", model) } +func TestWithXAIBuiltinsAddsVideoModel(t *testing.T) { + models := WithXAIBuiltins(nil) + found := false + for _, model := range models { + if model != nil && model.ID == xaiBuiltinVideoModelID { + found = true + if model.OwnedBy != "xai" { + t.Fatalf("OwnedBy = %q, want xai", model.OwnedBy) + } + } + } + if !found { + t.Fatalf("expected %s builtin model", xaiBuiltinVideoModelID) + } +} + func findModelInfo(models []*ModelInfo, id string) *ModelInfo { for _, model := range models { if model != nil && model.ID == id { diff --git a/internal/runtime/executor/xai_executor.go b/internal/runtime/executor/xai_executor.go index 592506ac..507ad6a7 100644 --- a/internal/runtime/executor/xai_executor.go +++ b/internal/runtime/executor/xai_executor.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "net/url" "sort" "strings" "time" @@ -29,9 +30,15 @@ var xaiDataTag = []byte("data:") const ( xaiImageHandlerType = "openai-image" + xaiVideoHandlerType = "openai-video" xaiImagesGenerationsPath = "/images/generations" xaiImagesEditsPath = "/images/edits" xaiDefaultImageEndpointPath = xaiImagesGenerationsPath + xaiVideosGenerationsPath = "/videos/generations" + xaiVideosEditsPath = "/videos/edits" + xaiVideosExtensionsPath = "/videos/extensions" + xaiVideosPath = "/videos" + xaiIdempotencyKeyMetaKey = "idempotency_key" ) // XAIExecutor is a stateless executor for xAI Grok's Responses API. @@ -86,6 +93,9 @@ func (e *XAIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req if endpointPath := xaiImageEndpointPath(opts); endpointPath != "" { return e.executeImages(ctx, auth, req, endpointPath) } + if xaiIsVideoRequest(opts) { + return e.executeVideos(ctx, auth, req, opts) + } token, baseURL := xaiCreds(auth) if baseURL == "" { @@ -207,6 +217,71 @@ func (e *XAIExecutor) executeImages(ctx context.Context, auth *cliproxyauth.Auth return cliproxyexecutor.Response{Payload: data, Headers: httpResp.Header.Clone()}, nil } +func (e *XAIExecutor) executeVideos(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + token, baseURL := xaiCreds(auth) + if baseURL == "" { + baseURL = xaiauth.DefaultAPIBaseURL + } + + method := http.MethodPost + endpointPath := xaiVideosGenerationsPath + var body io.Reader = bytes.NewReader(req.Payload) + + switch path := xaiVideoEndpointPath(opts); path { + case xaiVideosGenerationsPath, xaiVideosEditsPath, xaiVideosExtensionsPath: + endpointPath = path + default: + if requestID := strings.TrimSpace(gjson.GetBytes(req.Payload, "request_id").String()); requestID != "" { + method = http.MethodGet + endpointPath = xaiVideosPath + "/" + url.PathEscape(requestID) + body = nil + } + } + requestURL := strings.TrimSuffix(baseURL, "/") + endpointPath + httpReq, err := http.NewRequestWithContext(ctx, method, requestURL, body) + if err != nil { + return resp, err + } + applyXAIHeaders(httpReq, auth, token, false, "") + if method == http.MethodPost { + key := xaiMetadataString(opts.Metadata, xaiIdempotencyKeyMetaKey) + if key == "" && opts.Headers != nil { + key = strings.TrimSpace(opts.Headers.Get("x-idempotency-key")) + } + if key != "" { + httpReq.Header.Set("x-idempotency-key", key) + } + } + e.recordXAIRequest(ctx, auth, requestURL, httpReq.Header.Clone(), req.Payload) + + httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, err := httpClient.Do(httpReq) + if err != nil { + helps.RecordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("xai executor: close response body error: %v", errClose) + } + }() + helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + + data, err := io.ReadAll(httpResp.Body) + if err != nil { + helps.RecordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + helps.AppendAPIResponseChunk(ctx, e.cfg, data) + + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return resp, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + + return cliproxyexecutor.Response{Payload: data, Headers: httpResp.Header.Clone()}, nil +} + func (e *XAIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { token, baseURL := xaiCreds(auth) if baseURL == "" { @@ -525,6 +600,27 @@ func xaiImageEndpointPath(opts cliproxyexecutor.Options) string { return xaiDefaultImageEndpointPath } +func xaiIsVideoRequest(opts cliproxyexecutor.Options) bool { + return opts.SourceFormat.String() == xaiVideoHandlerType +} + +func xaiVideoEndpointPath(opts cliproxyexecutor.Options) string { + if !xaiIsVideoRequest(opts) { + return "" + } + path := xaiMetadataString(opts.Metadata, cliproxyexecutor.RequestPathMetadataKey) + if strings.HasSuffix(path, "/videos/edits") { + return xaiVideosEditsPath + } + if strings.HasSuffix(path, "/videos/extensions") { + return xaiVideosExtensionsPath + } + if strings.HasSuffix(path, "/videos/generations") { + return xaiVideosGenerationsPath + } + return "" +} + func xaiMetadataString(meta map[string]any, key string) string { if len(meta) == 0 || key == "" { return "" diff --git a/internal/runtime/executor/xai_executor_test.go b/internal/runtime/executor/xai_executor_test.go index 1a517f75..1f8683ff 100644 --- a/internal/runtime/executor/xai_executor_test.go +++ b/internal/runtime/executor/xai_executor_test.go @@ -229,3 +229,168 @@ func TestXAIExecutorExecuteImagesUsesEditsEndpoint(t *testing.T) { t.Fatalf("path = %q, want /images/edits", gotPath) } } + +func TestXAIExecutorExecuteVideosCreate(t *testing.T) { + var gotPath string + var gotMethod string + var gotAuth string + var gotIdempotencyKey string + var gotBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotMethod = r.Method + gotAuth = r.Header.Get("Authorization") + gotIdempotencyKey = r.Header.Get("x-idempotency-key") + var errRead error + gotBody, errRead = io.ReadAll(r.Body) + if errRead != nil { + t.Fatalf("read body: %v", errRead) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"request_id":"vid_123"}`)) + })) + defer server.Close() + + exec := NewXAIExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + Provider: "xai", + Attributes: map[string]string{"base_url": server.URL}, + Metadata: map[string]any{"access_token": "xai-token"}, + } + + resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "grok-imagine-video", + Payload: []byte(`{"model":"grok-imagine-video","prompt":"animate","duration":4}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai-video"), + Metadata: map[string]any{ + "idempotency_key": "idem-123", + }, + }) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + + if gotMethod != http.MethodPost { + t.Fatalf("method = %q, want POST", gotMethod) + } + if gotPath != "/videos/generations" { + t.Fatalf("path = %q, want /videos/generations", gotPath) + } + if gotAuth != "Bearer xai-token" { + t.Fatalf("Authorization = %q, want Bearer xai-token", gotAuth) + } + if gotIdempotencyKey != "idem-123" { + t.Fatalf("x-idempotency-key = %q, want idem-123", gotIdempotencyKey) + } + if string(gotBody) != `{"model":"grok-imagine-video","prompt":"animate","duration":4}` { + t.Fatalf("body = %s", string(gotBody)) + } + if gjson.GetBytes(resp.Payload, "request_id").String() != "vid_123" { + t.Fatalf("payload = %s", string(resp.Payload)) + } +} + +func TestXAIExecutorExecuteVideosRetrieve(t *testing.T) { + var gotPath string + var gotMethod string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotMethod = r.Method + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"status":"done","video":{"url":"https://vidgen.x.ai/video.mp4","duration":6},"model":"grok-imagine-video","progress":100}`)) + })) + defer server.Close() + + exec := NewXAIExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + Provider: "xai", + Attributes: map[string]string{"base_url": server.URL}, + Metadata: map[string]any{"access_token": "xai-token"}, + } + + resp, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "grok-imagine-video", + Payload: []byte(`{"request_id":"vid_123"}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai-video"), + }) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + + if gotMethod != http.MethodGet { + t.Fatalf("method = %q, want GET", gotMethod) + } + if gotPath != "/videos/vid_123" { + t.Fatalf("path = %q, want /videos/vid_123", gotPath) + } + if gjson.GetBytes(resp.Payload, "video.url").String() != "https://vidgen.x.ai/video.mp4" { + t.Fatalf("payload = %s", string(resp.Payload)) + } +} + +func TestXAIExecutorExecuteVideosUsesNativeEndpointFromRequestPath(t *testing.T) { + tests := []struct { + name string + requestPath string + wantPath string + }{ + { + name: "generations", + requestPath: "/v1/videos/generations", + wantPath: "/videos/generations", + }, + { + name: "edits", + requestPath: "/v1/videos/edits", + wantPath: "/videos/edits", + }, + { + name: "extensions", + requestPath: "/v1/videos/extensions", + wantPath: "/videos/extensions", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotPath string + var gotMethod string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotMethod = r.Method + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"request_id":"vid_123"}`)) + })) + defer server.Close() + + exec := NewXAIExecutor(&config.Config{}) + auth := &cliproxyauth.Auth{ + Provider: "xai", + Attributes: map[string]string{"base_url": server.URL}, + Metadata: map[string]any{"access_token": "xai-token"}, + } + + _, err := exec.Execute(context.Background(), auth, cliproxyexecutor.Request{ + Model: "grok-imagine-video", + Payload: []byte(`{"model":"grok-imagine-video","prompt":"animate"}`), + }, cliproxyexecutor.Options{ + SourceFormat: sdktranslator.FromString("openai-video"), + Metadata: map[string]any{ + cliproxyexecutor.RequestPathMetadataKey: tt.requestPath, + }, + }) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + + if gotMethod != http.MethodPost { + t.Fatalf("method = %q, want POST", gotMethod) + } + if gotPath != tt.wantPath { + t.Fatalf("path = %q, want %s", gotPath, tt.wantPath) + } + }) + } +} diff --git a/sdk/api/handlers/openai/openai_videos_handlers.go b/sdk/api/handlers/openai/openai_videos_handlers.go new file mode 100644 index 00000000..15e69a68 --- /dev/null +++ b/sdk/api/handlers/openai/openai_videos_handlers.go @@ -0,0 +1,598 @@ +package openai + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + videosPath = "/v1/videos" + xaiVideosGenerationsAPI = "/v1/videos/generations" + xaiVideosEditsAPI = "/v1/videos/edits" + xaiVideosExtensionsAPI = "/v1/videos/extensions" + defaultXAIVideosModel = "grok-imagine-video" + xaiVideosHandlerType = "openai-video" + defaultVideosSeconds = "4" + defaultVideosSize = "720x1280" + defaultVideosResolution = "720p" + maxXAIVideoReferences = 7 +) + +type xaiVideoCreateMetadata struct { + Model string + Prompt string + Seconds string + Size string + CreatedAt int64 +} + +func videosModelBase(model string) string { + _, baseModel := imagesModelParts(model) + return strings.ToLower(strings.TrimSpace(baseModel)) +} + +func isXAIVideosModel(model string) bool { + prefix, baseModel := imagesModelParts(model) + baseModel = strings.ToLower(strings.TrimSpace(baseModel)) + if baseModel != defaultXAIVideosModel { + return false + } + + prefix = strings.ToLower(strings.TrimSpace(prefix)) + return prefix == "" || prefix == "xai" || prefix == "x-ai" || prefix == "grok" +} + +func isSupportedVideosModel(model string) bool { + return isXAIVideosModel(model) +} + +func rejectUnsupportedVideosModel(c *gin.Context, model string) bool { + if isSupportedVideosModel(model) { + return false + } + + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Model %s is not supported on %s. Use %s.", model, videosPath, defaultXAIVideosModel), + Type: "invalid_request_error", + }, + }) + return true +} + +func rejectUnsupportedNativeVideosModel(c *gin.Context, model string) bool { + if isSupportedVideosModel(model) { + return false + } + + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Model %s is not supported on %s, %s, or %s. Use %s.", model, xaiVideosGenerationsAPI, xaiVideosEditsAPI, xaiVideosExtensionsAPI, defaultXAIVideosModel), + Type: "invalid_request_error", + }, + }) + return true +} + +func canonicalXAIVideosModel(model string) string { + if videosModelBase(model) == defaultXAIVideosModel { + return defaultXAIVideosModel + } + return defaultXAIVideosModel +} + +func readVideosCreateRequest(c *gin.Context) ([]byte, error) { + contentType := strings.ToLower(strings.TrimSpace(c.ContentType())) + switch contentType { + case "multipart/form-data", "application/x-www-form-urlencoded": + return videosCreateRequestFromForm(c) + default: + rawJSON, err := handlers.ReadRequestBody(c) + if err != nil { + return nil, err + } + if !json.Valid(rawJSON) { + return nil, fmt.Errorf("body must be valid JSON") + } + return rawJSON, nil + } +} + +func readXAIVideosNativeRequest(c *gin.Context) ([]byte, error) { + rawJSON, err := handlers.ReadRequestBody(c) + if err != nil { + return nil, err + } + if !json.Valid(rawJSON) { + return nil, fmt.Errorf("body must be valid JSON") + } + return rawJSON, nil +} + +func videosCreateRequestFromForm(c *gin.Context) ([]byte, error) { + rawJSON := []byte(`{}`) + for _, field := range []string{"model", "prompt", "seconds", "size", "aspect_ratio", "resolution"} { + if value := strings.TrimSpace(c.PostForm(field)); value != "" { + rawJSON, _ = sjson.SetBytes(rawJSON, field, value) + } + } + if value := strings.TrimSpace(firstPostForm(c, "input_reference[image_url]", "input_reference.image_url", "image_url")); value != "" { + rawJSON, _ = sjson.SetBytes(rawJSON, "input_reference.image_url", value) + } + if value := strings.TrimSpace(firstPostForm(c, "input_reference[file_id]", "input_reference.file_id", "file_id")); value != "" { + rawJSON, _ = sjson.SetBytes(rawJSON, "input_reference.file_id", value) + } + if refs := strings.TrimSpace(c.PostForm("reference_image_urls")); refs != "" { + for _, ref := range strings.Split(refs, ",") { + if ref = strings.TrimSpace(ref); ref != "" { + rawJSON, _ = sjson.SetBytes(rawJSON, "reference_image_urls.-1", ref) + } + } + } + return rawJSON, nil +} + +func firstPostForm(c *gin.Context, keys ...string) string { + for _, key := range keys { + if value := c.PostForm(key); strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func buildXAIVideosCreateRequest(rawJSON []byte, model string) ([]byte, xaiVideoCreateMetadata, error) { + prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String()) + if prompt == "" { + return nil, xaiVideoCreateMetadata{}, fmt.Errorf("prompt is required") + } + + seconds, duration, err := normalizeXAIVideosSeconds(gjson.GetBytes(rawJSON, "seconds").String()) + if err != nil { + return nil, xaiVideoCreateMetadata{}, err + } + + size, aspectRatio, resolution, err := xaiVideosSizeOptions(gjson.GetBytes(rawJSON, "size").String()) + if err != nil { + return nil, xaiVideoCreateMetadata{}, err + } + if value := xaiVideosAspectRatio(gjson.GetBytes(rawJSON, "aspect_ratio").String(), ""); value != "" { + aspectRatio = value + } + if value := xaiVideosResolution(gjson.GetBytes(rawJSON, "resolution").String(), ""); value != "" { + resolution = value + } + + imageURL, err := xaiVideosInputImageURL(rawJSON) + if err != nil { + return nil, xaiVideoCreateMetadata{}, err + } + referenceImages := collectXAIVideoReferenceImages(rawJSON) + if len(referenceImages) > maxXAIVideoReferences { + return nil, xaiVideoCreateMetadata{}, fmt.Errorf("reference_images supports at most %d images on xAI", maxXAIVideoReferences) + } + if imageURL != "" && len(referenceImages) > 0 { + return nil, xaiVideoCreateMetadata{}, fmt.Errorf("image and reference_images cannot be combined on xAI") + } + if len(referenceImages) > 0 && duration > 10 { + duration = 10 + seconds = "10" + } + + req := []byte(`{}`) + req, _ = sjson.SetBytes(req, "model", canonicalXAIVideosModel(model)) + req, _ = sjson.SetBytes(req, "prompt", prompt) + req, _ = sjson.SetRawBytes(req, "duration", []byte(strconv.FormatInt(duration, 10))) + req, _ = sjson.SetBytes(req, "aspect_ratio", aspectRatio) + req, _ = sjson.SetBytes(req, "resolution", resolution) + if imageURL != "" { + req, _ = sjson.SetBytes(req, "image.url", imageURL) + } + for _, image := range referenceImages { + req, _ = sjson.SetBytes(req, "reference_images.-1.url", image) + } + + meta := xaiVideoCreateMetadata{ + Model: defaultXAIVideosModel, + Prompt: prompt, + Seconds: seconds, + Size: size, + CreatedAt: time.Now().Unix(), + } + return req, meta, nil +} + +func normalizeXAIVideosSeconds(raw string) (string, int64, error) { + seconds := strings.TrimSpace(raw) + if seconds == "" { + seconds = defaultVideosSeconds + } + duration, err := strconv.ParseInt(seconds, 10, 64) + if err != nil { + return "", 0, fmt.Errorf("seconds must be an integer") + } + if duration < 1 { + duration = 1 + } + if duration > 15 { + duration = 15 + } + return strconv.FormatInt(duration, 10), duration, nil +} + +func xaiVideosSizeOptions(raw string) (size string, aspectRatio string, resolution string, err error) { + size = strings.TrimSpace(raw) + if size == "" { + size = defaultVideosSize + } + switch size { + case "720x1280", "1024x1792": + return size, "9:16", defaultVideosResolution, nil + case "1280x720", "1792x1024": + return size, "16:9", defaultVideosResolution, nil + default: + return "", "", "", fmt.Errorf("size must be one of 720x1280, 1280x720, 1024x1792, or 1792x1024") + } +} + +func xaiVideosAspectRatio(raw string, fallback string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1:1", "square": + return "1:1" + case "16:9", "landscape": + return "16:9" + case "9:16", "portrait": + return "9:16" + case "4:3": + return "4:3" + case "3:4": + return "3:4" + case "3:2": + return "3:2" + case "2:3": + return "2:3" + default: + return fallback + } +} + +func xaiVideosResolution(raw string, fallback string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "480p": + return "480p" + case "720p": + return "720p" + default: + return fallback + } +} + +func xaiVideosInputImageURL(rawJSON []byte) (string, error) { + inputRef := gjson.GetBytes(rawJSON, "input_reference") + if inputRef.Exists() { + imageURL := strings.TrimSpace(inputRef.Get("image_url").String()) + fileID := strings.TrimSpace(inputRef.Get("file_id").String()) + if imageURL != "" && fileID != "" { + return "", fmt.Errorf("input_reference must provide exactly one of image_url or file_id") + } + if fileID != "" { + return "", fmt.Errorf("input_reference.file_id is not supported for xAI video generation; use input_reference.image_url") + } + if imageURL != "" { + return imageURL, nil + } + } + + image := gjson.GetBytes(rawJSON, "image") + if image.Exists() { + if image.Type == gjson.String { + return strings.TrimSpace(image.String()), nil + } + if value := strings.TrimSpace(image.Get("url").String()); value != "" { + return value, nil + } + if value := strings.TrimSpace(image.Get("image_url.url").String()); value != "" { + return value, nil + } + } + + return strings.TrimSpace(gjson.GetBytes(rawJSON, "image_url").String()), nil +} + +func collectXAIVideoReferenceImages(rawJSON []byte) []string { + out := make([]string, 0) + appendRef := func(value string) { + value = strings.TrimSpace(value) + if value != "" { + out = append(out, value) + } + } + collectArray := func(result gjson.Result) { + if !result.IsArray() { + return + } + result.ForEach(func(_, item gjson.Result) bool { + if item.Type == gjson.String { + appendRef(item.String()) + return true + } + if value := item.Get("url").String(); value != "" { + appendRef(value) + return true + } + if value := item.Get("image_url.url").String(); value != "" { + appendRef(value) + } + return true + }) + } + collectArray(gjson.GetBytes(rawJSON, "reference_images")) + collectArray(gjson.GetBytes(rawJSON, "reference_image_urls")) + return out +} + +func buildVideosCreateAPIResponseFromXAI(payload []byte, meta xaiVideoCreateMetadata) ([]byte, error) { + requestID := strings.TrimSpace(gjson.GetBytes(payload, "request_id").String()) + if requestID == "" { + requestID = strings.TrimSpace(gjson.GetBytes(payload, "id").String()) + } + if requestID == "" { + return nil, fmt.Errorf("xAI video response did not include request_id") + } + + out := []byte(`{"object":"video","progress":0,"status":"queued"}`) + out, _ = sjson.SetBytes(out, "id", requestID) + out, _ = sjson.SetBytes(out, "model", meta.Model) + out, _ = sjson.SetBytes(out, "prompt", meta.Prompt) + out, _ = sjson.SetBytes(out, "seconds", meta.Seconds) + out, _ = sjson.SetBytes(out, "size", meta.Size) + out, _ = sjson.SetBytes(out, "created_at", meta.CreatedAt) + if status := openAIVideoStatus(gjson.GetBytes(payload, "status").String()); status != "" { + out, _ = sjson.SetBytes(out, "status", status) + } + if progress := gjson.GetBytes(payload, "progress"); progress.Exists() { + out, _ = sjson.SetRawBytes(out, "progress", []byte(progress.Raw)) + } + return out, nil +} + +func buildVideosRetrieveAPIResponseFromXAI(videoID string, payload []byte, fallbackModel string) ([]byte, error) { + out := []byte(`{"object":"video"}`) + out, _ = sjson.SetBytes(out, "id", videoID) + + model := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + if model == "" { + model = fallbackModel + } + out, _ = sjson.SetBytes(out, "model", model) + + if status := openAIVideoStatus(gjson.GetBytes(payload, "status").String()); status != "" { + out, _ = sjson.SetBytes(out, "status", status) + } + if progress := gjson.GetBytes(payload, "progress"); progress.Exists() { + out, _ = sjson.SetRawBytes(out, "progress", []byte(progress.Raw)) + } + if duration := gjson.GetBytes(payload, "video.duration"); duration.Exists() { + out, _ = sjson.SetBytes(out, "seconds", duration.String()) + } + if video := gjson.GetBytes(payload, "video"); video.Exists() && json.Valid([]byte(video.Raw)) { + out, _ = sjson.SetRawBytes(out, "video", []byte(video.Raw)) + } + if usage := gjson.GetBytes(payload, "usage"); usage.Exists() && json.Valid([]byte(usage.Raw)) { + out, _ = sjson.SetRawBytes(out, "usage", []byte(usage.Raw)) + } + if errPayload := gjson.GetBytes(payload, "error"); errPayload.Exists() && json.Valid([]byte(errPayload.Raw)) { + out, _ = sjson.SetRawBytes(out, "error", []byte(errPayload.Raw)) + } + return out, nil +} + +func openAIVideoStatus(status string) string { + switch strings.ToLower(strings.TrimSpace(status)) { + case "queued", "pending": + return "queued" + case "in_progress", "processing", "running": + return "in_progress" + case "completed", "done", "succeeded", "success": + return "completed" + case "failed", "error", "expired", "cancelled", "canceled": + return "failed" + default: + return "" + } +} + +func (h *OpenAIAPIHandler) VideosCreate(c *gin.Context) { + rawJSON, err := readVideosCreateRequest(c) + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + + videoModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String()) + if videoModel == "" { + videoModel = defaultXAIVideosModel + } + if rejectUnsupportedVideosModel(c, videoModel) { + return + } + + xaiReq, meta, err := buildXAIVideosCreateRequest(rawJSON, videoModel) + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + + h.collectXAIVideosCreate(c, xaiReq, meta) +} + +func (h *OpenAIAPIHandler) XAIVideosGenerations(c *gin.Context) { + h.handleXAIVideosNativePost(c) +} + +func (h *OpenAIAPIHandler) XAIVideosEdits(c *gin.Context) { + h.handleXAIVideosNativePost(c) +} + +func (h *OpenAIAPIHandler) XAIVideosExtensions(c *gin.Context) { + h.handleXAIVideosNativePost(c) +} + +func (h *OpenAIAPIHandler) handleXAIVideosNativePost(c *gin.Context) { + rawJSON, err := readXAIVideosNativeRequest(c) + if err != nil { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + + videoModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String()) + if videoModel == "" { + videoModel = defaultXAIVideosModel + } + if rejectUnsupportedNativeVideosModel(c, videoModel) { + return + } + + h.collectXAIVideosNative(c, rawJSON, videoModel) +} + +func (h *OpenAIAPIHandler) XAIVideosRetrieve(c *gin.Context) { + requestID := strings.TrimSpace(c.Param("request_id")) + if requestID == "" { + requestID = strings.TrimSpace(c.Param("video_id")) + } + if requestID == "" { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: request_id is required", + Type: "invalid_request_error", + }, + }) + return + } + + payload := []byte(`{}`) + payload, _ = sjson.SetBytes(payload, "request_id", requestID) + h.collectXAIVideosNative(c, payload, defaultXAIVideosModel) +} + +func (h *OpenAIAPIHandler) VideosRetrieve(c *gin.Context) { + videoID := strings.TrimSpace(c.Param("video_id")) + if videoID == "" { + c.JSON(http.StatusBadRequest, handlers.ErrorResponse{ + Error: handlers.ErrorDetail{ + Message: "Invalid request: video_id is required", + Type: "invalid_request_error", + }, + }) + return + } + + payload := []byte(`{}`) + payload, _ = sjson.SetBytes(payload, "request_id", videoID) + + c.Header("Content-Type", "application/json") + cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) + stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiVideosHandlerType, defaultXAIVideosModel, payload, "") + stopKeepAlive() + if errMsg != nil { + h.WriteErrorResponse(c, errMsg) + if errMsg.Error != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + } + + out, err := buildVideosRetrieveAPIResponseFromXAI(videoID, resp, defaultXAIVideosModel) + if err != nil { + errMsg := &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err} + h.WriteErrorResponse(c, errMsg) + cliCancel(err) + return + } + + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + _, _ = c.Writer.Write(out) + cliCancel(nil) +} + +func (h *OpenAIAPIHandler) collectXAIVideosNative(c *gin.Context, rawJSON []byte, model string) { + c.Header("Content-Type", "application/json") + + cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) + stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiVideosHandlerType, model, rawJSON, "") + stopKeepAlive() + if errMsg != nil { + h.WriteErrorResponse(c, errMsg) + if errMsg.Error != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + } + + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + _, _ = c.Writer.Write(resp) + cliCancel(nil) +} + +func (h *OpenAIAPIHandler) collectXAIVideosCreate(c *gin.Context, xaiReq []byte, meta xaiVideoCreateMetadata) { + c.Header("Content-Type", "application/json") + + cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) + stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx) + resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiVideosHandlerType, meta.Model, xaiReq, "") + stopKeepAlive() + if errMsg != nil { + h.WriteErrorResponse(c, errMsg) + if errMsg.Error != nil { + cliCancel(errMsg.Error) + } else { + cliCancel(nil) + } + return + } + + out, err := buildVideosCreateAPIResponseFromXAI(resp, meta) + if err != nil { + errMsg := &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err} + h.WriteErrorResponse(c, errMsg) + cliCancel(err) + return + } + + handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) + _, _ = c.Writer.Write(out) + cliCancel(nil) +} diff --git a/sdk/api/handlers/openai/openai_videos_handlers_test.go b/sdk/api/handlers/openai/openai_videos_handlers_test.go new file mode 100644 index 00000000..d4fed8b4 --- /dev/null +++ b/sdk/api/handlers/openai/openai_videos_handlers_test.go @@ -0,0 +1,227 @@ +package openai + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +func performVideosEndpointRequest(t *testing.T, method string, endpointPath string, contentType string, body io.Reader, handler gin.HandlerFunc) *httptest.ResponseRecorder { + t.Helper() + + gin.SetMode(gin.TestMode) + router := gin.New() + switch method { + case http.MethodGet: + router.GET(endpointPath, handler) + default: + router.POST(endpointPath, handler) + } + + req := httptest.NewRequest(method, endpointPath, body) + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + return resp +} + +func TestVideosModelValidationAllowsXAIVideoModel(t *testing.T) { + for _, model := range []string{"grok-imagine-video", "xai/grok-imagine-video", "x-ai/grok-imagine-video", "grok/grok-imagine-video"} { + if !isSupportedVideosModel(model) { + t.Fatalf("expected %s to be supported", model) + } + } + if isSupportedVideosModel("sora-2") { + t.Fatal("expected sora-2 to be rejected") + } + if isSupportedVideosModel("codex/grok-imagine-video") { + t.Fatal("expected codex/grok-imagine-video to be rejected") + } +} + +func TestBuildXAIVideosCreateRequest(t *testing.T) { + rawJSON := []byte(`{"model":"xai/grok-imagine-video","prompt":"a cat playing piano","seconds":"8","size":"1280x720","input_reference":{"image_url":"https://example.com/cat.png"}}`) + + req, meta, err := buildXAIVideosCreateRequest(rawJSON, "xai/grok-imagine-video") + if err != nil { + t.Fatalf("buildXAIVideosCreateRequest() error = %v", err) + } + + if got := gjson.GetBytes(req, "model").String(); got != defaultXAIVideosModel { + t.Fatalf("model = %q, want %s", got, defaultXAIVideosModel) + } + if got := gjson.GetBytes(req, "prompt").String(); got != "a cat playing piano" { + t.Fatalf("prompt = %q", got) + } + if got := gjson.GetBytes(req, "duration").Int(); got != 8 { + t.Fatalf("duration = %d, want 8", got) + } + if got := gjson.GetBytes(req, "aspect_ratio").String(); got != "16:9" { + t.Fatalf("aspect_ratio = %q, want 16:9", got) + } + if got := gjson.GetBytes(req, "resolution").String(); got != "720p" { + t.Fatalf("resolution = %q, want 720p", got) + } + if got := gjson.GetBytes(req, "image.url").String(); got != "https://example.com/cat.png" { + t.Fatalf("image.url = %q", got) + } + if meta.Seconds != "8" || meta.Size != "1280x720" || meta.Prompt != "a cat playing piano" { + t.Fatalf("unexpected meta: %+v", meta) + } +} + +func TestBuildXAIVideosCreateRequestAllowsCustomSeconds(t *testing.T) { + rawJSON := []byte(`{"model":"grok-imagine-video","prompt":"a cat playing piano","seconds":"6"}`) + + req, meta, err := buildXAIVideosCreateRequest(rawJSON, "grok-imagine-video") + if err != nil { + t.Fatalf("buildXAIVideosCreateRequest() error = %v", err) + } + + if got := gjson.GetBytes(req, "duration").Int(); got != 6 { + t.Fatalf("duration = %d, want 6", got) + } + if meta.Seconds != "6" { + t.Fatalf("meta seconds = %q, want 6", meta.Seconds) + } +} + +func TestBuildXAIVideosCreateRequestRejectsFileIDReference(t *testing.T) { + rawJSON := []byte(`{"prompt":"animate","input_reference":{"file_id":"file_123"}}`) + + _, _, err := buildXAIVideosCreateRequest(rawJSON, defaultXAIVideosModel) + if err == nil || !strings.Contains(err.Error(), "input_reference.file_id is not supported") { + t.Fatalf("error = %v, want unsupported file_id error", err) + } +} + +func TestBuildVideosCreateAPIResponseFromXAI(t *testing.T) { + meta := xaiVideoCreateMetadata{ + Model: defaultXAIVideosModel, + Prompt: "animate", + Seconds: "4", + Size: "720x1280", + CreatedAt: 123, + } + out, err := buildVideosCreateAPIResponseFromXAI([]byte(`{"request_id":"vid_123"}`), meta) + if err != nil { + t.Fatalf("buildVideosCreateAPIResponseFromXAI() error = %v", err) + } + + if got := gjson.GetBytes(out, "id").String(); got != "vid_123" { + t.Fatalf("id = %q, want vid_123", got) + } + if got := gjson.GetBytes(out, "object").String(); got != "video" { + t.Fatalf("object = %q, want video", got) + } + if got := gjson.GetBytes(out, "status").String(); got != "queued" { + t.Fatalf("status = %q, want queued", got) + } + if got := gjson.GetBytes(out, "created_at").Int(); got != 123 { + t.Fatalf("created_at = %d, want 123", got) + } +} + +func TestBuildVideosRetrieveAPIResponseFromXAI(t *testing.T) { + payload := []byte(`{"status":"done","video":{"url":"https://vidgen.x.ai/video.mp4","duration":6,"respect_moderation":true},"model":"grok-imagine-video","usage":{"cost_in_usd_ticks":500000000},"progress":100}`) + + out, err := buildVideosRetrieveAPIResponseFromXAI("vid_123", payload, defaultXAIVideosModel) + if err != nil { + t.Fatalf("buildVideosRetrieveAPIResponseFromXAI() error = %v", err) + } + + if got := gjson.GetBytes(out, "id").String(); got != "vid_123" { + t.Fatalf("id = %q, want vid_123", got) + } + if got := gjson.GetBytes(out, "status").String(); got != "completed" { + t.Fatalf("status = %q, want completed", got) + } + if got := gjson.GetBytes(out, "seconds").String(); got != "6" { + t.Fatalf("seconds = %q, want 6", got) + } + if got := gjson.GetBytes(out, "video.url").String(); got != "https://vidgen.x.ai/video.mp4" { + t.Fatalf("video.url = %q", got) + } + if !gjson.GetBytes(out, "usage").Exists() { + t.Fatalf("usage missing: %s", string(out)) + } +} + +func TestVideosCreateRejectsUnsupportedModel(t *testing.T) { + handler := &OpenAIAPIHandler{} + body := strings.NewReader(`{"model":"sora-2","prompt":"make a video"}`) + + resp := performVideosEndpointRequest(t, http.MethodPost, videosPath, "application/json", body, handler.VideosCreate) + + if resp.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusBadRequest, resp.Body.String()) + } + message := gjson.GetBytes(resp.Body.Bytes(), "error.message").String() + expectedMessage := "Model sora-2 is not supported on " + videosPath + ". Use " + defaultXAIVideosModel + "." + if message != expectedMessage { + t.Fatalf("error message = %q, want %q", message, expectedMessage) + } +} + +func TestXAIVideosNativeRejectsUnsupportedModel(t *testing.T) { + handler := &OpenAIAPIHandler{} + body := strings.NewReader(`{"model":"sora-2","prompt":"make a video"}`) + + resp := performVideosEndpointRequest(t, http.MethodPost, xaiVideosGenerationsAPI, "application/json", body, handler.XAIVideosGenerations) + + if resp.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusBadRequest, resp.Body.String()) + } + message := gjson.GetBytes(resp.Body.Bytes(), "error.message").String() + expectedMessage := "Model sora-2 is not supported on " + xaiVideosGenerationsAPI + ", " + xaiVideosEditsAPI + ", or " + xaiVideosExtensionsAPI + ". Use " + defaultXAIVideosModel + "." + if message != expectedMessage { + t.Fatalf("error message = %q, want %q", message, expectedMessage) + } +} + +func TestXAIVideosNativeRejectsInvalidJSON(t *testing.T) { + handler := &OpenAIAPIHandler{} + body := strings.NewReader(`{"model":`) + + resp := performVideosEndpointRequest(t, http.MethodPost, xaiVideosEditsAPI, "application/json", body, handler.XAIVideosEdits) + + if resp.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d: %s", resp.Code, http.StatusBadRequest, resp.Body.String()) + } + if got := gjson.GetBytes(resp.Body.Bytes(), "error.type").String(); got != "invalid_request_error" { + t.Fatalf("error type = %q, want invalid_request_error", got) + } +} + +func TestVideosCreateFormRequest(t *testing.T) { + rawJSON, err := videosCreateRequestFromFormContext("model=grok-imagine-video&prompt=make+a+video&seconds=4&size=720x1280&input_reference%5Bimage_url%5D=https%3A%2F%2Fexample.com%2Fa.png") + if err != nil { + t.Fatalf("videosCreateRequestFromFormContext() error = %v", err) + } + + if got := gjson.GetBytes(rawJSON, "input_reference.image_url").String(); got != "https://example.com/a.png" { + t.Fatalf("input_reference.image_url = %q", got) + } +} + +func videosCreateRequestFromFormContext(body string) ([]byte, error) { + gin.SetMode(gin.TestMode) + router := gin.New() + var rawJSON []byte + var err error + router.POST(videosPath, func(c *gin.Context) { + rawJSON, err = videosCreateRequestFromForm(c) + }) + req := httptest.NewRequest(http.MethodPost, videosPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + return rawJSON, err +}