test(api): add validation for unsupported models in OpenAI image handlers
- Introduced tests to ensure unsupported models are rejected in `/images/generations` and `/images/edits`. - Added `isSupportedImagesModel` and `rejectUnsupportedImagesModel` functions for consistent model validation. - Enhanced image handler logic to apply validation checks for model compatibility.
This commit is contained in:
@@ -24,6 +24,8 @@ import (
|
||||
const (
|
||||
defaultImagesMainModel = "gpt-5.4-mini"
|
||||
defaultImagesToolModel = "gpt-image-2"
|
||||
imagesGenerationsPath = "/v1/images/generations"
|
||||
imagesEditsPath = "/v1/images/edits"
|
||||
)
|
||||
|
||||
type imageCallResult struct {
|
||||
@@ -99,6 +101,28 @@ func (a *sseFrameAccumulator) Flush() [][]byte {
|
||||
return frames
|
||||
}
|
||||
|
||||
func isSupportedImagesModel(model string) bool {
|
||||
baseModel := strings.TrimSpace(model)
|
||||
if idx := strings.LastIndex(baseModel, "/"); idx >= 0 && idx < len(baseModel)-1 {
|
||||
baseModel = strings.TrimSpace(baseModel[idx+1:])
|
||||
}
|
||||
return baseModel == defaultImagesToolModel
|
||||
}
|
||||
|
||||
func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
|
||||
if isSupportedImagesModel(model) {
|
||||
return false
|
||||
}
|
||||
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
func mimeTypeFromOutputFormat(outputFormat string) string {
|
||||
if outputFormat == "" {
|
||||
return "image/png"
|
||||
@@ -194,6 +218,14 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String())
|
||||
if imageModel == "" {
|
||||
imageModel = defaultImagesToolModel
|
||||
}
|
||||
if rejectUnsupportedImagesModel(c, imageModel) {
|
||||
return
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())
|
||||
if prompt == "" {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
@@ -205,10 +237,6 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String())
|
||||
if imageModel == "" {
|
||||
imageModel = defaultImagesToolModel
|
||||
}
|
||||
responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String())
|
||||
if responseFormat == "" {
|
||||
responseFormat = "b64_json"
|
||||
@@ -283,6 +311,14 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
imageModel := strings.TrimSpace(c.PostForm("model"))
|
||||
if imageModel == "" {
|
||||
imageModel = defaultImagesToolModel
|
||||
}
|
||||
if rejectUnsupportedImagesModel(c, imageModel) {
|
||||
return
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(c.PostForm("prompt"))
|
||||
if prompt == "" {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
@@ -340,10 +376,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
|
||||
maskDataURL = &dataURL
|
||||
}
|
||||
|
||||
imageModel := strings.TrimSpace(c.PostForm("model"))
|
||||
if imageModel == "" {
|
||||
imageModel = defaultImagesToolModel
|
||||
}
|
||||
responseFormat := strings.TrimSpace(c.PostForm("response_format"))
|
||||
if responseFormat == "" {
|
||||
responseFormat = "b64_json"
|
||||
@@ -412,6 +444,14 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String())
|
||||
if imageModel == "" {
|
||||
imageModel = defaultImagesToolModel
|
||||
}
|
||||
if rejectUnsupportedImagesModel(c, imageModel) {
|
||||
return
|
||||
}
|
||||
|
||||
prompt := strings.TrimSpace(gjson.GetBytes(rawJSON, "prompt").String())
|
||||
if prompt == "" {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
@@ -460,10 +500,6 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
imageModel := strings.TrimSpace(gjson.GetBytes(rawJSON, "model").String())
|
||||
if imageModel == "" {
|
||||
imageModel = defaultImagesToolModel
|
||||
}
|
||||
responseFormat := strings.TrimSpace(gjson.GetBytes(rawJSON, "response_format").String())
|
||||
if responseFormat == "" {
|
||||
responseFormat = "b64_json"
|
||||
|
||||
Reference in New Issue
Block a user