test(api): add validation for unsupported models in OpenAI image handlers

- Introduced tests to ensure unsupported models are rejected in `/images/generations` and `/images/edits`.
- Added `isSupportedImagesModel` and `rejectUnsupportedImagesModel` functions for consistent model validation.
- Enhanced image handler logic to apply validation checks for model compatibility.
This commit is contained in:
Luis Pater
2026-04-28 17:19:12 +08:00
parent 34027da7f1
commit 9fb6a49260
2 changed files with 143 additions and 12 deletions
@@ -24,6 +24,8 @@ import (
const (
defaultImagesMainModel = "gpt-5.4-mini"
defaultImagesToolModel = "gpt-image-2"
imagesGenerationsPath = "/v1/images/generations"
imagesEditsPath = "/v1/images/edits"
)
type imageCallResult struct {
@@ -99,6 +101,28 @@ func (a *sseFrameAccumulator) Flush() [][]byte {
return frames
}
func isSupportedImagesModel(model string) bool {
baseModel := strings.TrimSpace(model)
if idx := strings.LastIndex(baseModel, "/"); idx >= 0 && idx < len(baseModel)-1 {
baseModel = strings.TrimSpace(baseModel[idx+1:])
}
return baseModel == defaultImagesToolModel
}
func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
if isSupportedImagesModel(model) {
return false
}
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel),
Type: "invalid_request_error",
},
})
return true
}
func mimeTypeFromOutputFormat(outputFormat string) string {
if outputFormat == "" {
return "image/png"
@@ -194,6 +218,14 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
return
}
imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String())
if imageModel == "" {
imageModel = defaultImagesToolModel
}
if rejectUnsupportedImagesModel(c, imageModel) {
return
}
prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())
if prompt == "" {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
@@ -205,10 +237,6 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
return
}
imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String())
if imageModel == "" {
imageModel = defaultImagesToolModel
}
responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String())
if responseFormat == "" {
responseFormat = "b64_json"
@@ -283,6 +311,14 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
return
}
imageModel := strings.TrimSpace(c.PostForm("model"))
if imageModel == "" {
imageModel = defaultImagesToolModel
}
if rejectUnsupportedImagesModel(c, imageModel) {
return
}
prompt := strings.TrimSpace(c.PostForm("prompt"))
if prompt == "" {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
@@ -340,10 +376,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
maskDataURL = &dataURL
}
imageModel := strings.TrimSpace(c.PostForm("model"))
if imageModel == "" {
imageModel = defaultImagesToolModel
}
responseFormat := strings.TrimSpace(c.PostForm("response_format"))
if responseFormat == "" {
responseFormat = "b64_json"
@@ -412,6 +444,14 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
return
}
imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String())
if imageModel == "" {
imageModel = defaultImagesToolModel
}
if rejectUnsupportedImagesModel(c, imageModel) {
return
}
prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())
if prompt == "" {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
@@ -460,10 +500,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
return
}
imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String())
if imageModel == "" {
imageModel = defaultImagesToolModel
}
responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String())
if responseFormat == "" {
responseFormat = "b64_json"