feat(api): add OpenAI compatibility for image models
- Introduced OpenAI-compatible image model support in the API, enabling integration through image generation and editing endpoints. - Added registry type for OpenAIImageModelType to classify and validate compatibility. - Implemented request handling for OpenAI-compatible image models, including JSON and multipart formats. - Enhanced executor methods to support OpenAI-compatible image streaming and non-streaming requests. - Included tests to validate model registration, streaming behavior, and multipart payload formatting.
This commit is contained in:
@@ -3,14 +3,17 @@ package openai
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v7/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v7/sdk/api/handlers"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v7/sdk/config"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -40,7 +43,7 @@ func assertUnsupportedImagesModelResponse(t *testing.T, resp *httptest.ResponseR
|
||||
}
|
||||
|
||||
message := gjson.GetBytes(resp.Body.Bytes(), "error.message").String()
|
||||
expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", or " + xaiImagesQualityModel + "."
|
||||
expectedMessage := "Model " + model + " is not supported on " + imagesGenerationsPath + " or " + imagesEditsPath + ". Use " + defaultImagesToolModel + ", " + defaultXAIImagesModel + ", " + xaiImagesQualityModel + ", or a configured openai-compatibility image model."
|
||||
if message != expectedMessage {
|
||||
t.Fatalf("error message = %q, want %q", message, expectedMessage)
|
||||
}
|
||||
@@ -63,6 +66,25 @@ func TestImagesModelValidationAllowsGPTImage2AndXAIModels(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestImagesModelValidationAllowsOpenAICompatImageModels(t *testing.T) {
|
||||
modelRegistry := registry.GetGlobalRegistry()
|
||||
clientID := "test-openai-compat-image-model-validation"
|
||||
modelRegistry.RegisterClient(clientID, "openai-compatibility", []*registry.ModelInfo{
|
||||
{ID: "compat-image-model", Object: "model", OwnedBy: "compat", Type: registry.OpenAIImageModelType},
|
||||
{ID: "compat-chat-model", Object: "model", OwnedBy: "compat", Type: "openai-compatibility"},
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
modelRegistry.UnregisterClient(clientID)
|
||||
})
|
||||
|
||||
if !isSupportedImagesModel("compat-image-model") {
|
||||
t.Fatal("expected configured openai-compatibility image model to be supported")
|
||||
}
|
||||
if isSupportedImagesModel("compat-chat-model") {
|
||||
t.Fatal("expected non-image openai-compatibility model to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildXAIImagesGenerationsRequest(t *testing.T) {
|
||||
rawJSON := []byte(`{"model":"xai/grok-imagine-image-quality","prompt":"abstract art","aspect_ratio":"landscape","resolution":"2k","n":2,"response_format":"url"}`)
|
||||
|
||||
@@ -122,6 +144,100 @@ func TestBuildXAIImagesEditRequestSingleImage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompatImagesJSONRequestPreservesStreamForStreaming(t *testing.T) {
|
||||
req := buildOpenAICompatImagesJSONRequest([]byte(`{"model":"compat-image","prompt":"draw","stream":false}`), "upstream-image", true)
|
||||
|
||||
if got := gjson.GetBytes(req, "model").String(); got != "upstream-image" {
|
||||
t.Fatalf("model = %q, want upstream-image; body=%s", got, string(req))
|
||||
}
|
||||
if !gjson.GetBytes(req, "stream").Bool() {
|
||||
t.Fatalf("stream flag missing: %s", string(req))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompatImagesJSONRequestDropsStreamForNonStreaming(t *testing.T) {
|
||||
req := buildOpenAICompatImagesJSONRequest([]byte(`{"model":"compat-image","prompt":"draw","stream":true}`), "upstream-image", false)
|
||||
|
||||
if got := gjson.GetBytes(req, "model").String(); got != "upstream-image" {
|
||||
t.Fatalf("model = %q, want upstream-image; body=%s", got, string(req))
|
||||
}
|
||||
if gjson.GetBytes(req, "stream").Exists() {
|
||||
t.Fatalf("stream flag should be removed from non-streaming request: %s", string(req))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAICompatImagesMultipartRequestPreservesStreamAndFileContentType(t *testing.T) {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
if errWrite := writer.WriteField("model", "compat-image"); errWrite != nil {
|
||||
t.Fatalf("write model field: %v", errWrite)
|
||||
}
|
||||
if errWrite := writer.WriteField("stream", "false"); errWrite != nil {
|
||||
t.Fatalf("write stream field: %v", errWrite)
|
||||
}
|
||||
if errWrite := writer.WriteField("prompt", "edit"); errWrite != nil {
|
||||
t.Fatalf("write prompt field: %v", errWrite)
|
||||
}
|
||||
header := make(textproto.MIMEHeader)
|
||||
header.Set("Content-Disposition", multipart.FileContentDisposition("image", "image.png"))
|
||||
header.Set("Content-Type", "image/png")
|
||||
part, errCreate := writer.CreatePart(header)
|
||||
if errCreate != nil {
|
||||
t.Fatalf("create image field: %v", errCreate)
|
||||
}
|
||||
if _, errWrite := part.Write([]byte("png-data")); errWrite != nil {
|
||||
t.Fatalf("write image field: %v", errWrite)
|
||||
}
|
||||
if errClose := writer.Close(); errClose != nil {
|
||||
t.Fatalf("close multipart writer: %v", errClose)
|
||||
}
|
||||
|
||||
reader := multipart.NewReader(bytes.NewReader(body.Bytes()), writer.Boundary())
|
||||
form, errRead := reader.ReadForm(32 << 20)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read source form: %v", errRead)
|
||||
}
|
||||
defer func() {
|
||||
if errRemove := form.RemoveAll(); errRemove != nil {
|
||||
t.Fatalf("remove source form files: %v", errRemove)
|
||||
}
|
||||
}()
|
||||
|
||||
out, contentType, errBuild := buildOpenAICompatImagesMultipartRequest(form, "upstream-image", true)
|
||||
if errBuild != nil {
|
||||
t.Fatalf("buildOpenAICompatImagesMultipartRequest error: %v", errBuild)
|
||||
}
|
||||
mediaType, params, errParse := mime.ParseMediaType(contentType)
|
||||
if errParse != nil {
|
||||
t.Fatalf("parse content type: %v", errParse)
|
||||
}
|
||||
if mediaType != "multipart/form-data" {
|
||||
t.Fatalf("media type = %q, want multipart/form-data", mediaType)
|
||||
}
|
||||
rewrittenReader := multipart.NewReader(bytes.NewReader(out), params["boundary"])
|
||||
rewrittenForm, errRead := rewrittenReader.ReadForm(32 << 20)
|
||||
if errRead != nil {
|
||||
t.Fatalf("read rewritten form: %v", errRead)
|
||||
}
|
||||
defer func() {
|
||||
if errRemove := rewrittenForm.RemoveAll(); errRemove != nil {
|
||||
t.Fatalf("remove rewritten form files: %v", errRemove)
|
||||
}
|
||||
}()
|
||||
if got := rewrittenForm.Value["model"]; len(got) != 1 || got[0] != "upstream-image" {
|
||||
t.Fatalf("model values = %#v, want upstream-image", got)
|
||||
}
|
||||
if got := rewrittenForm.Value["stream"]; len(got) != 1 || got[0] != "true" {
|
||||
t.Fatalf("stream values = %#v, want true", got)
|
||||
}
|
||||
if got := rewrittenForm.Value["prompt"]; len(got) != 1 || got[0] != "edit" {
|
||||
t.Fatalf("prompt values = %#v, want edit", got)
|
||||
}
|
||||
if got := rewrittenForm.File["image"]; len(got) != 1 || got[0].Header.Get("Content-Type") != "image/png" {
|
||||
t.Fatalf("image headers = %#v, want image/png", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildImagesAPIResponseFromXAI(t *testing.T) {
|
||||
payload := []byte(`{"created":123,"data":[{"b64_json":"AA==","revised_prompt":"refined","mime_type":"image/png"}],"usage":{"total_tokens":0}}`)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user