feat(api): add OpenAI compatibility for image models

- Introduced OpenAI-compatible image model support in the API, enabling integration through image generation and editing endpoints.
- Added registry type for OpenAIImageModelType to classify and validate compatibility.
- Implemented request handling for OpenAI-compatible image models, including JSON and multipart formats.
- Enhanced executor methods to support OpenAI-compatible image streaming and non-streaming requests.
- Included tests to validate model registration, streaming behavior, and multipart payload formatting.
This commit is contained in:
Luis Pater
2026-05-19 09:36:05 +08:00
parent b67eb6f25d
commit feebe6c7f2
16 changed files with 1962 additions and 37 deletions
+34 -4
View File
@@ -535,7 +535,16 @@ func appendAPIResponse(c *gin.Context, data []byte) {
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
return h.executeWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, false)
}
// ExecuteImageWithAuthManager executes an OpenAI-compatible image endpoint request.
func (h *BaseAPIHandler) ExecuteImageWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
return h.executeWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, true)
}
func (h *BaseAPIHandler) executeWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string, allowImageModel bool) ([]byte, http.Header, *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetailsWithOptions(modelName, allowImageModel)
if errMsg != nil {
return nil, nil, errMsg
}
@@ -632,7 +641,16 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
// This path is the only supported execution route.
// The returned http.Header carries upstream response headers captured before streaming begins.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
return h.executeStreamWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, false)
}
// ExecuteImageStreamWithAuthManager executes a streaming OpenAI-compatible image endpoint request.
func (h *BaseAPIHandler) ExecuteImageStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
return h.executeStreamWithAuthManager(ctx, handlerType, modelName, rawJSON, alt, true)
}
func (h *BaseAPIHandler) executeStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string, allowImageModel bool) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
providers, normalizedModel, errMsg := h.getRequestDetailsWithOptions(modelName, allowImageModel)
if errMsg != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- errMsg
@@ -848,6 +866,10 @@ func statusFromError(err error) int {
}
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
return h.getRequestDetailsWithOptions(modelName, false)
}
func (h *BaseAPIHandler) getRequestDetailsWithOptions(modelName string, allowImageModel bool) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
resolvedModelName := modelName
initialSuffix := thinking.ParseSuffix(modelName)
if initialSuffix.ModelName == "auto" {
@@ -872,10 +894,10 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
parsed := thinking.ParseSuffix(resolvedModelName)
baseModel := strings.TrimSpace(parsed.ModelName)
if strings.EqualFold(baseModel, "gpt-image-2") {
if strings.EqualFold(routeModelBaseName(baseModel), "gpt-image-2") && !allowImageModel {
return nil, "", &interfaces.ErrorMessage{
StatusCode: http.StatusServiceUnavailable,
Error: fmt.Errorf("model %s is only supported on /v1/images/generations and /v1/images/edits", baseModel),
Error: fmt.Errorf("model %s is only supported on /v1/images/generations and /v1/images/edits", routeModelBaseName(baseModel)),
}
}
@@ -902,6 +924,14 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
return providers, resolvedModelName, nil
}
func routeModelBaseName(model string) string {
model = strings.TrimSpace(model)
if idx := strings.LastIndex(model, "/"); idx >= 0 && idx < len(model)-1 {
return strings.TrimSpace(model[idx+1:])
}
return model
}
func cloneBytes(src []byte) []byte {
if len(src) == 0 {
return nil
@@ -104,6 +104,9 @@ func applyCodexClientModelMetadata(entry map[string]any, id string, model map[st
if info.ContextLength > 0 {
contextWindow = info.ContextLength
}
if info.Type == registry.OpenAIImageModelType {
entry["visibility"] = "hide"
}
applyCodexClientThinkingMetadata(entry, info.Thinking)
}
@@ -9,6 +9,7 @@ import (
"io"
"mime/multipart"
"net/http"
"net/textproto"
"strconv"
"strings"
"time"
@@ -16,6 +17,7 @@ import (
"github.com/gin-gonic/gin"
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
@@ -143,7 +145,20 @@ func isSupportedImagesModel(model string) bool {
if baseModel == defaultImagesToolModel {
return true
}
return isXAIImagesModel(model)
return isXAIImagesModel(model) || isOpenAICompatImagesModel(model)
}
func isDefaultImagesToolModel(model string) bool {
return imagesModelBase(model) == defaultImagesToolModel
}
func isOpenAICompatImagesModel(model string) bool {
model = strings.TrimSpace(model)
if model == "" {
return false
}
info := registry.LookupModelInfo(model)
return info != nil && info.Type == registry.OpenAIImageModelType
}
func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
@@ -153,7 +168,7 @@ func rejectUnsupportedImagesModel(c *gin.Context, model string) bool {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, or %s.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel),
Message: fmt.Sprintf("Model %s is not supported on %s or %s. Use %s, %s, %s, or a configured openai-compatibility image model.", model, imagesGenerationsPath, imagesEditsPath, defaultImagesToolModel, defaultXAIImagesModel, xaiImagesQualityModel),
Type: "invalid_request_error",
},
})
@@ -376,6 +391,90 @@ func multipartFileToDataURL(fileHeader *multipart.FileHeader) (string, error) {
return "data:" + mediaType + ";base64," + b64, nil
}
func buildOpenAICompatImagesJSONRequest(rawJSON []byte, imageModel string, stream bool) []byte {
payload := rawJSON
if model := strings.TrimSpace(imageModel); model != "" {
payload, _ = sjson.SetBytes(payload, "model", model)
}
if stream {
payload, _ = sjson.SetBytes(payload, "stream", true)
} else {
payload, _ = sjson.DeleteBytes(payload, "stream")
}
return payload
}
func cloneMIMEHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
dst := make(textproto.MIMEHeader, len(src))
for key, values := range src {
dst[key] = append([]string(nil), values...)
}
return dst
}
func buildOpenAICompatImagesMultipartRequest(form *multipart.Form, imageModel string, stream bool) ([]byte, string, error) {
if form == nil {
return nil, "", fmt.Errorf("multipart form is nil")
}
var body bytes.Buffer
writer := multipart.NewWriter(&body)
if errWrite := writer.WriteField("model", imageModel); errWrite != nil {
return nil, "", fmt.Errorf("write model field failed: %w", errWrite)
}
if stream {
if errWrite := writer.WriteField("stream", "true"); errWrite != nil {
return nil, "", fmt.Errorf("write stream field failed: %w", errWrite)
}
}
for key, values := range form.Value {
if key == "model" || key == "stream" {
continue
}
for _, value := range values {
if errWrite := writer.WriteField(key, value); errWrite != nil {
return nil, "", fmt.Errorf("write form field %s failed: %w", key, errWrite)
}
}
}
for key, files := range form.File {
for _, fileHeader := range files {
if fileHeader == nil {
continue
}
header := cloneMIMEHeader(fileHeader.Header)
header.Set("Content-Disposition", multipart.FileContentDisposition(key, fileHeader.Filename))
if header.Get("Content-Type") == "" {
header.Set("Content-Type", "application/octet-stream")
}
part, errCreate := writer.CreatePart(header)
if errCreate != nil {
return nil, "", fmt.Errorf("create file field %s failed: %w", key, errCreate)
}
src, errOpen := fileHeader.Open()
if errOpen != nil {
return nil, "", fmt.Errorf("open upload file failed: %w", errOpen)
}
_, errCopy := io.Copy(part, src)
if errClose := src.Close(); errClose != nil {
log.Errorf("openai images: close upload file error: %v", errClose)
if errCopy == nil {
errCopy = errClose
}
}
if errCopy != nil {
return nil, "", fmt.Errorf("copy upload file failed: %w", errCopy)
}
}
}
if errClose := writer.Close(); errClose != nil {
return nil, "", fmt.Errorf("close multipart writer failed: %w", errClose)
}
return body.Bytes(), writer.FormDataContentType(), nil
}
func parseIntField(raw string, fallback int64) int64 {
raw = strings.TrimSpace(raw)
if raw == "" {
@@ -454,11 +553,21 @@ func (h *OpenAIAPIHandler) ImagesGenerations(c *gin.Context) {
}
stream := gjson.GetBytes(rawJSON, "stream").Bool()
if isDefaultImagesToolModel(imageModel) {
imageReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleRoutedImages(c, imageReq, imageModel, stream)
return
}
if isXAIImagesModel(imageModel) {
xaiReq := buildXAIImagesGenerationsRequest(rawJSON, imageModel, responseFormat)
h.handleXAIImages(c, xaiReq, responseFormat, "image_generation", stream)
return
}
if isOpenAICompatImagesModel(imageModel) {
compatReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_generation", stream)
return
}
tool := []byte(`{"type":"image_generation","action":"generate"}`)
tool, _ = sjson.SetBytes(tool, "model", imageModel)
@@ -589,6 +698,21 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
}
stream := parseBoolField(c.PostForm("stream"), false)
if isDefaultImagesToolModel(imageModel) {
imageReq, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, imageModel, stream)
if errBuild != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", errBuild),
Type: "invalid_request_error",
},
})
return
}
c.Request.Header.Set("Content-Type", contentType)
h.handleRoutedImages(c, imageReq, imageModel, stream)
return
}
if isXAIImagesModel(imageModel) {
aspectRatio := xaiImagesAspectRatio(c.PostForm("aspect_ratio"), "")
aspectRatio = xaiImagesAspectRatioFromSize(c.PostForm("size"), aspectRatio)
@@ -598,6 +722,21 @@ func (h *OpenAIAPIHandler) imagesEditsFromMultipart(c *gin.Context) {
h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
return
}
if isOpenAICompatImagesModel(imageModel) {
compatReq, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, imageModel, stream)
if errBuild != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", errBuild),
Type: "invalid_request_error",
},
})
return
}
c.Request.Header.Set("Content-Type", contentType)
h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_edit", stream)
return
}
var maskDataURL *string
if maskFiles := form.File["mask"]; len(maskFiles) > 0 && maskFiles[0] != nil {
@@ -701,6 +840,11 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
}
stream := gjson.GetBytes(rawJSON, "stream").Bool()
if isDefaultImagesToolModel(imageModel) {
imageReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleRoutedImages(c, imageReq, imageModel, stream)
return
}
if isXAIImagesModel(imageModel) {
images := collectXAIImagesFromJSON(rawJSON)
if len(images) == 0 {
@@ -717,6 +861,11 @@ func (h *OpenAIAPIHandler) imagesEditsFromJSON(c *gin.Context) {
h.handleXAIImages(c, xaiReq, responseFormat, "image_edit", stream)
return
}
if isOpenAICompatImagesModel(imageModel) {
compatReq := buildOpenAICompatImagesJSONRequest(rawJSON, imageModel, stream)
h.handleOpenAICompatImages(c, compatReq, imageModel, responseFormat, "image_edit", stream)
return
}
var images []string
imagesResult := gjson.GetBytes(rawJSON, "images")
@@ -904,14 +1053,247 @@ func (h *OpenAIAPIHandler) handleXAIImages(c *gin.Context, xaiReq []byte, respon
h.collectXAIImages(c, xaiReq, responseFormat)
}
func (h *OpenAIAPIHandler) handleOpenAICompatImages(c *gin.Context, compatReq []byte, imageModel string, responseFormat string, streamPrefix string, stream bool) {
if stream {
h.streamOpenAICompatImages(c, compatReq, imageModel)
return
}
h.collectImagesWithModel(c, compatReq, imageModel, responseFormat)
}
func (h *OpenAIAPIHandler) handleRoutedImages(c *gin.Context, imageReq []byte, imageModel string, stream bool) {
if stream {
h.streamRoutedImages(c, imageReq, imageModel)
return
}
h.collectRoutedImages(c, imageReq, imageModel)
}
func (h *OpenAIAPIHandler) collectRoutedImages(c *gin.Context, imageReq []byte, imageModel string) {
c.Header("Content-Type", "application/json")
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
cliCtx = handlers.WithDisallowFreeAuth(cliCtx)
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
model := strings.TrimSpace(imageModel)
resp, upstreamHeaders, errMsg := h.ExecuteImageWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
if errMsg.Error != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
}
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(resp)
cliCancel(nil)
}
func (h *OpenAIAPIHandler) streamRoutedImages(c *gin.Context, imageReq []byte, imageModel string) {
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
cliCtx = handlers.WithDisallowFreeAuth(cliCtx)
model := strings.TrimSpace(imageModel)
dataChan, upstreamHeaders, errChan := h.ExecuteImageStreamWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case errMsg, ok := <-errChan:
if !ok {
errChan = nil
continue
}
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
cliCancel(nil)
return
}
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(chunk)
flusher.Flush()
h.forwardRawImageStream(cliCtx, c, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
}
}
func (h *OpenAIAPIHandler) forwardRawImageStream(ctx context.Context, c *gin.Context, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
emitError := func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
errText := http.StatusText(status)
if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
errText = errMsg.Error.Error()
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case <-ctx.Done():
cancel(ctx.Err())
return
case errMsg, ok := <-errs:
if ok && errMsg != nil {
emitError(errMsg)
cancel(errMsg.Error)
return
}
errs = nil
case chunk, ok := <-data:
if !ok {
cancel(nil)
return
}
_, _ = c.Writer.Write(chunk)
if flusher, ok := c.Writer.(http.Flusher); ok {
flusher.Flush()
}
}
}
}
func (h *OpenAIAPIHandler) streamOpenAICompatImages(c *gin.Context, compatReq []byte, imageModel string) {
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
model := strings.TrimSpace(imageModel)
dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, xaiImagesHandlerType, model, compatReq, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case errMsg, ok := <-errChan:
if !ok {
errChan = nil
continue
}
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
flusher.Flush()
cliCancel(nil)
return
}
setSSEHeaders()
handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
_, _ = c.Writer.Write(chunk)
flusher.Flush()
h.ForwardStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan, handlers.StreamForwardOptions{
WriteChunk: func(next []byte) {
_, _ = c.Writer.Write(next)
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
errText := http.StatusText(status)
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
},
})
return
}
}
}
func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, responseFormat string) {
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
h.collectImagesWithModel(c, xaiReq, model, responseFormat)
}
func (h *OpenAIAPIHandler) collectImagesWithModel(c *gin.Context, imageReq []byte, model string, responseFormat string) {
c.Header("Content-Type", "application/json")
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "")
model = strings.TrimSpace(model)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
stopKeepAlive()
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
@@ -937,6 +1319,11 @@ func (h *OpenAIAPIHandler) collectXAIImages(c *gin.Context, xaiReq []byte, respo
}
func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, responseFormat string, streamPrefix string) {
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
h.streamImagesWithModel(c, xaiReq, model, responseFormat, streamPrefix)
}
func (h *OpenAIAPIHandler) streamImagesWithModel(c *gin.Context, imageReq []byte, model string, responseFormat string, streamPrefix string) {
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
@@ -949,8 +1336,8 @@ func (h *OpenAIAPIHandler) streamXAIImages(c *gin.Context, xaiReq []byte, respon
}
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
model := strings.TrimSpace(gjson.GetBytes(xaiReq, "model").String())
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, xaiReq, "")
model = strings.TrimSpace(model)
resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, xaiImagesHandlerType, model, imageReq, "")
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
if errMsg.Error != nil {
@@ -3,14 +3,17 @@ package openai
import (
"bytes"
"io"
"mime"
"mime/multipart"
"net/http"
"net/http/httptest"
"net/textproto"
"strings"
"testing"
"github.com/gin-gonic/gin"
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
"github.com/tidwall/gjson"
@@ -40,7 +43,7 @@ func assertUnsupportedImagesModelResponse(t *testing.T, resp *httptest.ResponseR
}
message := gjson.GetBytes(resp.Body.Bytes(), "error.message").String()
expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", or " + xaiImagesQualityModel + "."
expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", " + xaiImagesQualityModel + ", or a configured openai-compatibility image model."
if message != expectedMessage {
t.Fatalf("error message = %q, want %q", message, expectedMessage)
}
@@ -63,6 +66,25 @@ func TestImagesModelValidationAllowsGPTImage2AndXAIModels(t *testing.T) {
}
}
func TestImagesModelValidationAllowsOpenAICompatImageModels(t *testing.T) {
modelRegistry := registry.GetGlobalRegistry()
clientID := "test-openai-compat-image-model-validation"
modelRegistry.RegisterClient(clientID, "openai-compatibility", []*registry.ModelInfo{
{ID: "compat-image-model", Object: "model", OwnedBy: "compat", Type: registry.OpenAIImageModelType},
{ID: "compat-chat-model", Object: "model", OwnedBy: "compat", Type: "openai-compatibility"},
})
t.Cleanup(func() {
modelRegistry.UnregisterClient(clientID)
})
if !isSupportedImagesModel("compat-image-model") {
t.Fatal("expected configured openai-compatibility image model to be supported")
}
if isSupportedImagesModel("compat-chat-model") {
t.Fatal("expected non-image openai-compatibility model to be rejected")
}
}
func TestBuildXAIImagesGenerationsRequest(t *testing.T) {
rawJSON := []byte(`{"model":"xai/grok-imagine-image-quality","prompt":"abstract art","aspect_ratio":"landscape","resolution":"2k","n":2,"response_format":"url"}`)
@@ -122,6 +144,100 @@ func TestBuildXAIImagesEditRequestSingleImage(t *testing.T) {
}
}
func TestBuildOpenAICompatImagesJSONRequestPreservesStreamForStreaming(t *testing.T) {
req := buildOpenAICompatImagesJSONRequest([]byte(`{"model":"compat-image","prompt":"draw","stream":false}`), "upstream-image", true)
if got := gjson.GetBytes(req, "model").String(); got != "upstream-image" {
t.Fatalf("model = %q, want upstream-image; body=%s", got, string(req))
}
if !gjson.GetBytes(req, "stream").Bool() {
t.Fatalf("stream flag missing: %s", string(req))
}
}
func TestBuildOpenAICompatImagesJSONRequestDropsStreamForNonStreaming(t *testing.T) {
req := buildOpenAICompatImagesJSONRequest([]byte(`{"model":"compat-image","prompt":"draw","stream":true}`), "upstream-image", false)
if got := gjson.GetBytes(req, "model").String(); got != "upstream-image" {
t.Fatalf("model = %q, want upstream-image; body=%s", got, string(req))
}
if gjson.GetBytes(req, "stream").Exists() {
t.Fatalf("stream flag should be removed from non-streaming request: %s", string(req))
}
}
func TestBuildOpenAICompatImagesMultipartRequestPreservesStreamAndFileContentType(t *testing.T) {
var body bytes.Buffer
writer := multipart.NewWriter(&body)
if errWrite := writer.WriteField("model", "compat-image"); errWrite != nil {
t.Fatalf("write model field: %v", errWrite)
}
if errWrite := writer.WriteField("stream", "false"); errWrite != nil {
t.Fatalf("write stream field: %v", errWrite)
}
if errWrite := writer.WriteField("prompt", "edit"); errWrite != nil {
t.Fatalf("write prompt field: %v", errWrite)
}
header := make(textproto.MIMEHeader)
header.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.png"))
header.Set("Content-Type", "image/png")
part, errCreate := writer.CreatePart(header)
if errCreate != nil {
t.Fatalf("create image field: %v", errCreate)
}
if _, errWrite := part.Write([]byte("png-data")); errWrite != nil {
t.Fatalf("write image field: %v", errWrite)
}
if errClose := writer.Close(); errClose != nil {
t.Fatalf("close multipart writer: %v", errClose)
}
reader := multipart.NewReader(bytes.NewReader(body.Bytes()), writer.Boundary())
form, errRead := reader.ReadForm(32 << 20)
if errRead != nil {
t.Fatalf("read source form: %v", errRead)
}
defer func() {
if errRemove := form.RemoveAll(); errRemove != nil {
t.Fatalf("remove source form files: %v", errRemove)
}
}()
out, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, "upstream-image", true)
if errBuild != nil {
t.Fatalf("buildOpenAICompatImagesMultipartRequest error: %v", errBuild)
}
mediaType, params, errParse := mime.ParseMediaType(contentType)
if errParse != nil {
t.Fatalf("parse content type: %v", errParse)
}
if mediaType != "multipart/form-data" {
t.Fatalf("media type = %q, want multipart/form-data", mediaType)
}
rewrittenReader := multipart.NewReader(bytes.NewReader(out), params["boundary"])
rewrittenForm, errRead := rewrittenReader.ReadForm(32 << 20)
if errRead != nil {
t.Fatalf("read rewritten form: %v", errRead)
}
defer func() {
if errRemove := rewrittenForm.RemoveAll(); errRemove != nil {
t.Fatalf("remove rewritten form files: %v", errRemove)
}
}()
if got := rewrittenForm.Value["model"]; len(got) != 1 || got[0] != "upstream-image" {
t.Fatalf("model values = %#v, want upstream-image", got)
}
if got := rewrittenForm.Value["stream"]; len(got) != 1 || got[0] != "true" {
t.Fatalf("stream values = %#v, want true", got)
}
if got := rewrittenForm.Value["prompt"]; len(got) != 1 || got[0] != "edit" {
t.Fatalf("prompt values = %#v, want edit", got)
}
if got := rewrittenForm.File["image"]; len(got) != 1 || got[0].Header.Get("Content-Type") != "image/png" {
t.Fatalf("image headers = %#v, want image/png", got)
}
}
func TestBuildImagesAPIResponseFromXAI(t *testing.T) {
payload := []byte(`{"created":123,"data":[{"b64_json":"AA==","revised_prompt":"refined","mime_type":"image/png"}],"usage":{"total_tokens":0}}`)
+38 -24
View File
@@ -1208,30 +1208,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
if strings.EqualFold(compat.Name, compatName) {
isCompatAuth = true
// Convert compatibility models to registry models
ms := make([]*ModelInfo, 0, len(compat.Models))
for j := range compat.Models {
m := compat.Models[j]
// Use alias as model ID, fallback to name if alias is empty
modelID := m.Alias
if modelID == "" {
modelID = m.Name
}
thinking := m.Thinking
if thinking == nil {
thinking = &registry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}
}
ms = append(ms, &ModelInfo{
ID: modelID,
Object: "model",
Created: time.Now().Unix(),
OwnedBy: compat.Name,
Type: "openai-compatibility",
DisplayName: modelID,
UserDefined: false,
Thinking: thinking,
})
}
ms := buildOpenAICompatibilityConfigModels(compat)
// Register and return
if len(ms) > 0 {
if providerKey == "" {
@@ -1578,6 +1555,43 @@ type modelEntry interface {
GetAlias() string
}
func buildOpenAICompatibilityConfigModels(compat *config.OpenAICompatibility) []*ModelInfo {
if compat == nil || len(compat.Models) == 0 {
return nil
}
now := time.Now().Unix()
models := make([]*ModelInfo, 0, len(compat.Models))
for i := range compat.Models {
model := compat.Models[i]
modelID := strings.TrimSpace(model.Alias)
if modelID == "" {
modelID = strings.TrimSpace(model.Name)
}
if modelID == "" {
continue
}
modelType := "openai-compatibility"
if model.Image {
modelType = registry.OpenAIImageModelType
}
thinking := model.Thinking
if thinking == nil && !model.Image {
thinking = &registry.ThinkingSupport{Levels: []string{"low", "medium", "high"}}
}
models = append(models, &ModelInfo{
ID: modelID,
Object: "model",
Created: now,
OwnedBy: compat.Name,
Type: modelType,
DisplayName: modelID,
UserDefined: false,
Thinking: thinking,
})
}
return models
}
func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo {
if len(models) == 0 {
return nil
@@ -4,6 +4,7 @@ import (
"strings"
"testing"
internalregistry "github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
)
@@ -63,3 +64,71 @@ func TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute(t *testing.T
t.Fatal("expected global excluded model to be present when attribute override is set")
}
}
func TestRegisterModelsForAuth_OpenAICompatibilityImageModelType(t *testing.T) {
service := &Service{
cfg: &config.Config{
OpenAICompatibility: []config.OpenAICompatibility{
{
Name: "images",
BaseURL: "https://example.com/v1",
Models: []config.OpenAICompatibilityModel{
{Name: "upstream-image", Alias: "compat-image", Image: true},
{Name: "upstream-chat", Alias: "compat-chat"},
},
},
},
},
}
auth := &coreauth.Auth{
ID: "auth-openai-compat-image",
Provider: "openai-compatibility",
Status: coreauth.StatusActive,
Attributes: map[string]string{
"auth_kind": "api_key",
"compat_name": "images",
"provider_key": "images",
},
}
modelRegistry := internalregistry.GetGlobalRegistry()
modelRegistry.UnregisterClient(auth.ID)
t.Cleanup(func() {
modelRegistry.UnregisterClient(auth.ID)
})
service.registerModelsForAuth(auth)
models := modelRegistry.GetModelsForClient(auth.ID)
var imageModel *internalregistry.ModelInfo
var chatModel *internalregistry.ModelInfo
for _, model := range models {
if model == nil {
continue
}
switch strings.TrimSpace(model.ID) {
case "compat-image":
imageModel = model
case "compat-chat":
chatModel = model
}
}
if imageModel == nil {
t.Fatal("expected compat-image to be registered")
}
if imageModel.Type != internalregistry.OpenAIImageModelType {
t.Fatalf("image model type = %q, want %q", imageModel.Type, internalregistry.OpenAIImageModelType)
}
if imageModel.Thinking != nil {
t.Fatalf("image model thinking = %+v, want nil", imageModel.Thinking)
}
if chatModel == nil {
t.Fatal("expected compat-chat to be registered")
}
if chatModel.Type != "openai-compatibility" {
t.Fatalf("chat model type = %q, want openai-compatibility", chatModel.Type)
}
if chatModel.Thinking == nil {
t.Fatal("expected chat model to keep default thinking support")
}
}