feat(api, xai): integrate xAI Grok image models and extend API endpoints for image support

- Added new xAI Grok image models (`grok-imagine-image`, `grok-imagine-image-quality`) with high-fidelity and aspect ratio configurations.
- Extended `isSupportedImagesModel` logic to validate xAI models.
- Implemented API request builders for image generation/editing with customizable options (e.g., resolution, aspect ratio, response format).
- Enhanced `/v1/images` endpoints to handle xAI model capabilities, including response normalization and model-specific handlers.
- Updated unit tests to cover xAI model validation, request structure, and API integration.
This commit is contained in:
Luis Pater
2026-05-17 01:30:23 +08:00
parent e4c957078c
commit 2ff9e33e26
9 changed files with 779 additions and 43 deletions
+445 -22
View File
@@ -23,10 +23,15 @@ import (
)
// Constants shared by the /v1/images handlers.
const (
	// Default model names for the existing (non-xAI) images flow.
	defaultImagesMainModel = "gpt-5.4-mini"
	defaultImagesToolModel = "gpt-image-2"

	// xAI Grok image model names accepted by isXAIImagesModel.
	defaultXAIImagesModel = "grok-imagine-image"
	xaiImagesQualityModel = "grok-imagine-image-quality"

	// xaiImagesHandlerType is the handler type passed to
	// ExecuteWithAuthManager for xAI image requests.
	xaiImagesHandlerType = "openai-image"

	// Values applied to xAI generation requests when the caller supplies
	// neither a recognised aspect ratio/size nor a recognised resolution.
	xaiImagesDefaultAspectRatio = "1:1"
	xaiImagesDefaultResolution  = "1k"

	// Route paths, used when building the unsupported-model error message.
	imagesGenerationsPath = "/v1/images/generations"
	imagesEditsPath       = "/v1/images/edits"
)
type imageCallResult struct {
@@ -42,6 +47,13 @@ type sseFrameAccumulator struct {
pending []byte
}
// xaiImageResult is one image entry extracted from an upstream xAI images
// response. At least one of B64JSON or URL is non-empty for every value
// produced by extractXAIImagesResponse.
type xaiImageResult struct {
	// B64JSON is the trimmed "b64_json" field of the upstream item.
	B64JSON string

	// URL is the trimmed "url" field of the upstream item.
	URL string

	// RevisedPrompt is the trimmed "revised_prompt" field, if any.
	RevisedPrompt string

	// MimeType is the image MIME type used when building data URLs.
	MimeType string
}
func (a *sseFrameAccumulator) AddChunk(chunk []byte) [][]byte {
if len(chunk) == 0 {
return nil
@@ -102,12 +114,36 @@ func (a *sseFrameAccumulator) Flush() [][]byte {
return frames
}
func isSupportedImagesModel(model string) bool {
baseModel := strings.TrimSpace(model)
if idx := strings.LastIndex(baseModel, "/"); idx >= 0 && idx < len(baseModel)-1 {
baseModel = strings.TrimSpace(baseModel[idx+1:])
// imagesModelParts splits a model identifier into an optional routing prefix
// and a base model name, using the last "/" as the separator.
//
// The input is whitespace-trimmed first. When there is no "/" or the "/" is
// the final character, prefix is empty and baseModel is the trimmed input
// unchanged. Both returned parts are whitespace-trimmed; case is preserved.
func imagesModelParts(model string) (prefix string, baseModel string) {
	model = strings.TrimSpace(model)
	// A trailing "/" would leave an empty base name, so it is not treated as
	// a separator.
	if idx := strings.LastIndex(model, "/"); idx >= 0 && idx < len(model)-1 {
		return strings.TrimSpace(model[:idx]), strings.TrimSpace(model[idx+1:])
	}
	return "", model
}
// imagesModelBase returns the lower-cased, trimmed base name of a model
// identifier, i.e. the second value of imagesModelParts normalised for
// comparison against the model-name constants.
func imagesModelBase(model string) string {
	_, name := imagesModelParts(model)
	name = strings.TrimSpace(name)
	return strings.ToLower(name)
}
// isXAIImagesModel reports whether model names one of the xAI Grok image
// models. The base name must match defaultXAIImagesModel or
// xaiImagesQualityModel (case-insensitively), and the prefix, if present,
// must be "xai", "x-ai" or "grok".
func isXAIImagesModel(model string) bool {
	rawPrefix, rawBase := imagesModelParts(model)

	switch strings.ToLower(strings.TrimSpace(rawBase)) {
	case defaultXAIImagesModel, xaiImagesQualityModel:
		// Known xAI image model; go on to validate the prefix.
	default:
		return false
	}

	switch strings.ToLower(strings.TrimSpace(rawPrefix)) {
	case "", "xai", "x-ai", "grok":
		return true
	}
	return false
}
// isSupportedImagesModel reports whether model may be used on the images
// endpoints: either its base name equals defaultImagesToolModel, or it is an
// xAI image model according to isXAIImagesModel.
func isSupportedImagesModel(model string) bool {
	return imagesModelBase(model) == defaultImagesToolModel || isXAIImagesModel(model)
}
func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
@@ -117,13 +153,182 @@ func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel),
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, or %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel),
Type: "invalid_request_error",
},
})
return true
}
// normalizeImagesResponseFormat maps a caller-supplied response_format to one
// of the two supported values. Only "url" (ignoring case and surrounding
// whitespace) selects "url"; everything else, including the empty string,
// yields "b64_json".
func normalizeImagesResponseFormat(responseFormat string) string {
	format := strings.TrimSpace(responseFormat)
	if !strings.EqualFold(format, "url") {
		return "b64_json"
	}
	return "url"
}
// canonicalXAIImagesModel returns the upstream xAI model name for model:
// xaiImagesQualityModel when the base name matches it, and
// defaultXAIImagesModel for every other input.
func canonicalXAIImagesModel(model string) string {
	if imagesModelBase(model) != xaiImagesQualityModel {
		return defaultXAIImagesModel
	}
	return xaiImagesQualityModel
}
// xaiImagesAspectRatio normalises a caller-supplied aspect ratio. The input is
// trimmed and lower-cased; the aliases "square", "landscape" and "portrait"
// map to "1:1", "16:9" and "9:16", and the literal ratios 1:1, 16:9, 9:16,
// 4:3, 3:4, 3:2 and 2:3 are returned as-is. Any other value yields fallback.
func xaiImagesAspectRatio(raw string, fallback string) string {
	ratio := strings.ToLower(strings.TrimSpace(raw))
	switch ratio {
	case "square":
		return "1:1"
	case "landscape":
		return "16:9"
	case "portrait":
		return "9:16"
	case "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3":
		return ratio
	}
	return fallback
}
// xaiImagesAspectRatioFromSize derives an aspect ratio from an OpenAI-style
// "size" value. The input is trimmed and lower-cased. Known pixel sizes map to
// their ratio, the literal ratios 1:1, 16:9, 9:16, 3:2 and 2:3 pass through,
// and anything else (including the empty string) yields fallback.
func xaiImagesAspectRatioFromSize(size string, fallback string) string {
	switch dims := strings.ToLower(strings.TrimSpace(size)); dims {
	case "1:1", "16:9", "9:16", "3:2", "2:3":
		return dims
	case "1024x1024", "2048x2048":
		return "1:1"
	case "1792x1024":
		return "16:9"
	case "1024x1792":
		return "9:16"
	case "1536x1024":
		return "3:2"
	case "1024x1536":
		return "2:3"
	}
	return fallback
}
// xaiImagesResolution picks the resolution for an xAI request. An explicit
// "1k" or "2k" (ignoring case and surrounding whitespace) wins; otherwise a
// size containing "2048" selects "2k"; otherwise fallback is returned.
func xaiImagesResolution(raw string, size string, fallback string) string {
	requested := strings.ToLower(strings.TrimSpace(raw))
	if requested == "1k" || requested == "2k" {
		return requested
	}

	normalizedSize := strings.ToLower(strings.TrimSpace(size))
	if !strings.Contains(normalizedSize, "2048") {
		return fallback
	}
	return "2k"
}
// xaiImagesRef builds the JSON image reference object
// {"type":"image_url","url":<imageURL>} with imageURL whitespace-trimmed.
// The sjson error is ignored, as in the rest of the request builders here.
func xaiImagesRef(imageURL string) []byte {
	out, _ := sjson.SetBytes([]byte(`{"type":"image_url","url":""}`), "url", strings.TrimSpace(imageURL))
	return out
}
// buildXAIImagesBaseRequest assembles the JSON body shared by xAI generation
// and edit requests. "model", "prompt" and "response_format" are always set
// (canonicalised, trimmed and normalised respectively); "aspect_ratio" and
// "resolution" are set only when non-empty, and "n" only when positive.
func buildXAIImagesBaseRequest(model string, prompt string, responseFormat string, aspectRatio string, resolution string, n int64) []byte {
	body := []byte(`{}`)
	// put writes one field into body; sjson errors are ignored.
	put := func(key string, value interface{}) {
		body, _ = sjson.SetBytes(body, key, value)
	}

	put("model", canonicalXAIImagesModel(model))
	put("prompt", strings.TrimSpace(prompt))
	put("response_format", normalizeImagesResponseFormat(responseFormat))
	if aspectRatio != "" {
		put("aspect_ratio", aspectRatio)
	}
	if resolution != "" {
		put("resolution", resolution)
	}
	if n > 0 {
		put("n", n)
	}
	return body
}
// buildXAIImagesGenerationsRequest translates an images-generations JSON body
// into an xAI request body.
//
// The aspect ratio is taken from "size" when it maps to a known ratio,
// otherwise from "aspect_ratio", otherwise xaiImagesDefaultAspectRatio. The
// resolution falls back to xaiImagesDefaultResolution. "n" is forwarded only
// when it is a JSON number (and, downstream, only when positive).
func buildXAIImagesGenerationsRequest(rawJSON []byte, model string, responseFormat string) []byte {
	field := func(path string) string {
		return gjson.GetBytes(rawJSON, path).String()
	}

	prompt := strings.TrimSpace(field("prompt"))
	size := strings.TrimSpace(field("size"))

	// The explicit aspect_ratio acts as the fallback for the size-derived one.
	ratio := xaiImagesAspectRatioFromSize(size, xaiImagesAspectRatio(field("aspect_ratio"), ""))
	if ratio == "" {
		ratio = xaiImagesDefaultAspectRatio
	}
	res := xaiImagesResolution(field("resolution"), size, xaiImagesDefaultResolution)

	var count int64
	if nField := gjson.GetBytes(rawJSON, "n"); nField.Type == gjson.Number {
		count = nField.Int()
	}
	return buildXAIImagesBaseRequest(model, prompt, responseFormat, ratio, res, count)
}
// buildXAIImagesEditRequest builds an xAI image-edit request body on top of
// buildXAIImagesBaseRequest. Blank image entries are dropped. Exactly one
// remaining image is sent as a single "image" reference; any other count is
// appended to the "images" array (so zero images adds neither field).
func buildXAIImagesEditRequest(model string, prompt string, images []string, responseFormat string, aspectRatio string, resolution string, n int64) []byte {
	body := buildXAIImagesBaseRequest(model, prompt, responseFormat, aspectRatio, resolution, n)

	cleaned := make([]string, 0, len(images))
	for _, candidate := range images {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			cleaned = append(cleaned, trimmed)
		}
	}

	if len(cleaned) == 1 {
		body, _ = sjson.SetRawBytes(body, "image", xaiImagesRef(cleaned[0]))
		return body
	}
	for _, ref := range cleaned {
		body, _ = sjson.SetRawBytes(body, "images.-1", xaiImagesRef(ref))
	}
	return body
}
// collectXAIImagesFromJSON gathers image references from an edit request
// body. It reads the top-level "image" value and every element of the
// top-level "images" array. Each node may be a plain string, or an object
// carrying the reference under "image_url.url", "image_url" (as a string) or
// "url". Blank values are skipped; order is "image" first, then "images".
func collectXAIImagesFromJSON(rawJSON []byte) []string {
	var found []string
	add := func(candidate string) {
		if candidate = strings.TrimSpace(candidate); candidate != "" {
			found = append(found, candidate)
		}
	}
	// collect extracts every supported reference form from a single node.
	collect := func(node gjson.Result) {
		if node.Type == gjson.String {
			add(node.String())
			return
		}
		add(node.Get("image_url.url").String())
		if nested := node.Get("image_url"); nested.Type == gjson.String {
			add(nested.String())
		}
		add(node.Get("url").String())
	}

	// The single "image" field is only inspected when it is a string or a
	// JSON container.
	if single := gjson.GetBytes(rawJSON, "image"); single.Exists() && (single.Type == gjson.String || single.Type == gjson.JSON) {
		collect(single)
	}

	if list := gjson.GetBytes(rawJSON, "images"); list.IsArray() {
		for _, entry := range list.Array() {
			collect(entry)
		}
	}
	return found
}
// xaiImagesEditOptionsFromJSON reads the optional sizing fields of a JSON
// edit request. Unlike the generations path no defaults are applied:
// aspectRatio and resolution are empty when nothing recognisable was
// supplied, and n is zero unless "n" is a JSON number.
func xaiImagesEditOptionsFromJSON(rawJSON []byte) (aspectRatio string, resolution string, n int64) {
	field := func(path string) string {
		return gjson.GetBytes(rawJSON, path).String()
	}
	size := strings.TrimSpace(field("size"))

	// A ratio derived from size wins over the explicit aspect_ratio.
	aspectRatio = xaiImagesAspectRatioFromSize(size, xaiImagesAspectRatio(field("aspect_ratio"), ""))
	resolution = xaiImagesResolution(field("resolution"), size, "")
	if count := gjson.GetBytes(rawJSON, "n"); count.Type == gjson.Number {
		n = count.Int()
	}
	return aspectRatio, resolution, n
}
func mimeTypeFromOutputFormat(outputFormat string) string {
if outputFormat == "" {
return "image/png"
@@ -249,6 +454,12 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
}
stream := gjson.GetBytes(rawJSON, "stream").Bool()
if isXAIImagesModel(imageModel) {
xaiReq := buildXAIImagesGenerationsRequest(rawJSON, imageModel, responseFormat)
h.handleXAIImages(c, xaiReq, responseFormat, "image_generation", stream)
return
}
tool := []byte(`{"type":"image_generation","action":"generate"}`)
tool, _ = sjson.SetBytes(tool, "model", imageModel)
@@ -372,6 +583,22 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
images = append(images, dataURL)
}
responseFormat := strings.TrimSpace(c.PostForm("response_format"))
if responseFormat == "" {
responseFormat = "b64_json"
}
stream := parseBoolField(c.PostForm("stream"), false)
if isXAIImagesModel(imageModel) {
aspectRatio := xaiImagesAspectRatio(c.PostForm("aspect_ratio"), "")
aspectRatio = xaiImagesAspectRatioFromSize(c.PostForm("size"), aspectRatio)
resolution := xaiImagesResolution(c.PostForm("resolution"), c.PostForm("size"), "")
n := parseIntField(c.PostForm("n"), 0)
xaiReq := buildXAIImagesEditRequest(imageModel, prompt, images, responseFormat, aspectRatio, resolution, n)
h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
return
}
var maskDataURL *string
if maskFiles := form.File["mask"]; len(maskFiles) > 0 && maskFiles[0] != nil {
dataURL, err := multipartFileToDataURL(maskFiles[0])
@@ -387,12 +614,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
maskDataURL = &dataURL
}
responseFormat := strings.TrimSpace(c.PostForm("response_format"))
if responseFormat == "" {
responseFormat = "b64_json"
}
stream := parseBoolField(c.PostForm("stream"), false)
tool := []byte(`{"type":"image_generation","action":"edit"}`)
tool, _ = sjson.SetBytes(tool, "model", imageModel)
@@ -474,6 +695,29 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
return
}
responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String())
if responseFormat == "" {
responseFormat = "b64_json"
}
stream := gjson.GetBytes(rawJSON, "stream").Bool()
if isXAIImagesModel(imageModel) {
images := collectXAIImagesFromJSON(rawJSON)
if len(images) == 0 {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Invalid request: image is required",
Type: "invalid_request_error",
},
})
return
}
aspectRatio, resolution, n := xaiImagesEditOptionsFromJSON(rawJSON)
xaiReq := buildXAIImagesEditRequest(imageModel, prompt, images, responseFormat, aspectRatio, resolution, n)
h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
return
}
var images []string
imagesResult := gjson.GetBytes(rawJSON, "images")
if imagesResult.IsArray() {
@@ -511,12 +755,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
return
}
responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String())
if responseFormat == "" {
responseFormat = "b64_json"
}
stream := gjson.GetBytes(rawJSON, "stream").Bool()
tool := []byte(`{"type":"image_generation","action":"edit"}`)
tool, _ = sjson.SetBytes(tool, "model", imageModel)
@@ -580,6 +818,191 @@ func buildImagesResponsesRequest(prompt string, images []string, toolJSON []byte
return req
}
// extractXAIImagesResponse parses an upstream images response.
//
// It returns one xaiImageResult per element of "data" that carries a
// "b64_json" or "url" value, the "created" timestamp (the current Unix time
// when missing or non-positive), and the raw "usage" object when present.
// An error is returned when payload is not valid JSON or when no usable
// image entry is found; in that case the other results are zero values.
func extractXAIImagesResponse(payload []byte) (results []xaiImageResult, createdAt int64, usageRaw []byte, err error) {
	if !json.Valid(payload) {
		return nil, 0, nil, fmt.Errorf("upstream returned invalid image response JSON")
	}

	text := func(item gjson.Result, key string) string {
		return strings.TrimSpace(item.Get(key).String())
	}

	if data := gjson.GetBytes(payload, "data"); data.IsArray() {
		for _, item := range data.Array() {
			entry := xaiImageResult{
				B64JSON:       text(item, "b64_json"),
				URL:           text(item, "url"),
				RevisedPrompt: text(item, "revised_prompt"),
				MimeType:      text(item, "mime_type"),
			}
			if entry.MimeType == "" {
				entry.MimeType = mimeTypeFromOutputFormat(text(item, "output_format"))
			}
			if entry.MimeType == "" {
				entry.MimeType = "image/png"
			}
			// Entries without any image payload are ignored.
			if entry.B64JSON != "" || entry.URL != "" {
				results = append(results, entry)
			}
		}
	}
	if len(results) == 0 {
		return nil, 0, nil, fmt.Errorf("upstream did not return image output")
	}

	if createdAt = gjson.GetBytes(payload, "created").Int(); createdAt <= 0 {
		createdAt = time.Now().Unix()
	}
	if usage := gjson.GetBytes(payload, "usage"); usage.Exists() && usage.IsObject() {
		usageRaw = []byte(usage.Raw)
	}
	return results, createdAt, usageRaw, nil
}
// buildImagesAPIResponseFromXAI converts an upstream xAI images payload into
// a {"created":...,"data":[...]} response body.
//
// With response format "url" each item carries "url": the upstream URL when
// available, otherwise a data URL built from the base64 payload. With
// "b64_json" each item carries "b64_json" when available, otherwise the
// upstream URL under "url". "revised_prompt" and "usage" are copied through
// when present. Errors from extractXAIImagesResponse are returned unchanged.
func buildImagesAPIResponseFromXAI(payload []byte, responseFormat string) ([]byte, error) {
	images, created, usage, err := extractXAIImagesResponse(payload)
	if err != nil {
		return nil, err
	}

	body := []byte(`{"created":0,"data":[]}`)
	body, _ = sjson.SetBytes(body, "created", created)

	wantURL := normalizeImagesResponseFormat(responseFormat) == "url"
	for _, img := range images {
		entry := []byte(`{}`)
		switch {
		case wantURL && img.URL != "":
			entry, _ = sjson.SetBytes(entry, "url", img.URL)
		case wantURL:
			entry, _ = sjson.SetBytes(entry, "url", "data:"+mimeTypeFromOutputFormat(img.MimeType)+";base64,"+img.B64JSON)
		case img.B64JSON != "":
			entry, _ = sjson.SetBytes(entry, "b64_json", img.B64JSON)
		default:
			entry, _ = sjson.SetBytes(entry, "url", img.URL)
		}
		if img.RevisedPrompt != "" {
			entry, _ = sjson.SetBytes(entry, "revised_prompt", img.RevisedPrompt)
		}
		body, _ = sjson.SetRawBytes(body, "data.-1", entry)
	}

	if len(usage) > 0 && json.Valid(usage) {
		body, _ = sjson.SetRawBytes(body, "usage", usage)
	}
	return body, nil
}
// handleXAIImages dispatches a prepared xAI image request to the streaming
// handler when stream is true and to the non-streaming handler otherwise.
// streamPrefix is only used by the streaming path.
func (h *OpenAIAPIHandler) handleXAIImages(c *gin.Context, xaiReq []byte, responseFormat string, streamPrefix string, stream bool) {
	if !stream {
		h.collectXAIImages(c, xaiReq, responseFormat)
		return
	}
	h.streamXAIImages(c, xaiReq, responseFormat, streamPrefix)
}
// collectXAIImages executes xaiReq via ExecuteWithAuthManager and writes the
// result as a single JSON images response.
//
// Upstream execution errors are written with WriteErrorResponse. A payload
// that cannot be converted by buildImagesAPIResponseFromXAI is reported as
// 502 Bad Gateway. The request context's cancel function is invoked on every
// path, with the error when there is one.
func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, responseFormat string) {
	c.Header("Content-Type", "application/json")

	ctx, cancel := h.GetContextWithCancel(h, c, context.Background())
	stopKeepAlive := h.StartNonStreamingKeepAlive(c, ctx)
	upstreamModel := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
	payload, headers, execErr := h.ExecuteWithAuthManager(ctx, xaiImagesHandlerType, upstreamModel, xaiReq, "")
	// The keep-alive is stopped before anything is written to the client.
	stopKeepAlive()

	if execErr != nil {
		h.WriteErrorResponse(c, execErr)
		if execErr.Error == nil {
			cancel(nil)
		} else {
			cancel(execErr.Error)
		}
		return
	}

	body, err := buildImagesAPIResponseFromXAI(payload, responseFormat)
	if err != nil {
		h.WriteErrorResponse(c, &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
		cancel(err)
		return
	}

	handlers.WriteUpstreamHeaders(c.Writer.Header(), headers)
	_, _ = c.Writer.Write(body)
	cancel(nil)
}
// streamXAIImages executes xaiReq via ExecuteWithAuthManager and emits the
// result as server-sent events, one "<streamPrefix>.completed" event per
// image.
//
// The upstream call completes before any event is written, so failures
// (missing http.Flusher support, upstream errors, unusable payloads) are
// still reported as ordinary error responses. Each event carries "type",
// either "url" or "b64_json" chosen the same way as in the non-streaming
// response, and the upstream "usage" object when it is valid JSON.
func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, responseFormat string, streamPrefix string) {
	flusher, canFlush := c.Writer.(http.Flusher)
	if !canFlush {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	ctx, cancel := h.GetContextWithCancel(h, c, context.Background())
	upstreamModel := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
	payload, headers, execErr := h.ExecuteWithAuthManager(ctx, xaiImagesHandlerType, upstreamModel, xaiReq, "")
	if execErr != nil {
		h.WriteErrorResponse(c, execErr)
		if execErr.Error == nil {
			cancel(nil)
		} else {
			cancel(execErr.Error)
		}
		return
	}

	images, _, usage, err := extractXAIImagesResponse(payload)
	if err != nil {
		h.WriteErrorResponse(c, &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
		cancel(err)
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")
	handlers.WriteUpstreamHeaders(c.Writer.Header(), headers)

	// eventName always ends in ".completed", so it is never blank.
	eventName := streamPrefix + ".completed"
	wantURL := normalizeImagesResponseFormat(responseFormat) == "url"
	attachUsage := len(usage) > 0 && json.Valid(usage)

	for _, img := range images {
		event := []byte(`{"type":""}`)
		event, _ = sjson.SetBytes(event, "type", eventName)
		switch {
		case wantURL && img.URL != "":
			event, _ = sjson.SetBytes(event, "url", img.URL)
		case wantURL:
			event, _ = sjson.SetBytes(event, "url", "data:"+mimeTypeFromOutputFormat(img.MimeType)+";base64,"+img.B64JSON)
		case img.B64JSON != "":
			event, _ = sjson.SetBytes(event, "b64_json", img.B64JSON)
		default:
			event, _ = sjson.SetBytes(event, "url", img.URL)
		}
		if attachUsage {
			event, _ = sjson.SetRawBytes(event, "usage", usage)
		}

		_, _ = fmt.Fprintf(c.Writer, "event: %s\n", eventName)
		_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", event)
		flusher.Flush()
	}
	cancel(nil)
}
func (h *OpenAIAPIHandler) collectImagesFromResponses(c *gin.Context, responsesReq []byte, responseFormat string) {
c.Header("Content-Type", "application/json")