feat(api): add OpenAI compatibility for image models
- Introduced OpenAI-compatible image model support in the API, enabling integration through image generation and editing endpoints. - Added registry type for OpenAIImageModelType to classify and validate compatibility. - Implemented request handling for OpenAI-compatible image models, including JSON and multipart formats. - Enhanced executor methods to support OpenAI-compatible image streaming and non-streaming requests. - Included tests to validate model registration, streaming behavior, and multipart payload formatting.
This commit is contained in:
@@ -4,9 +4,13 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -21,6 +25,14 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// Constants used by the OpenAI-compatible image request paths in this file.
const (
|
||||
// openAICompatImageHandlerType is the source-format string that
// openAICompatImageEndpointPath compares against opts.SourceFormat to
// recognise an image request.
openAICompatImageHandlerType = "openai-image"
|
||||
// openAICompatImagesGenerationsPath is the upstream path for image generation;
// it is appended to the provider base URL.
openAICompatImagesGenerationsPath = "/images/generations"
|
||||
// openAICompatImagesEditsPath is the upstream path for image edits;
// it is appended to the provider base URL.
openAICompatImagesEditsPath = "/images/edits"
|
||||
// openAICompatDefaultImageEndpoint is returned for image requests whose
// path ends in neither of the two known suffixes.
openAICompatDefaultImageEndpoint = openAICompatImagesGenerationsPath
|
||||
// openAICompatMultipartMemory is the maxMemory argument handed to
// multipart.Reader.ReadForm when a multipart payload is re-encoded (32 MiB).
openAICompatMultipartMemory int64 = 32 << 20
|
||||
)
|
||||
|
||||
// OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers.
|
||||
// It performs request/response translation and executes against the provider base URL
|
||||
// using per-auth credentials (API key) and per-auth HTTP transport (proxy) from context.
|
||||
@@ -71,6 +83,10 @@ func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyau
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
if endpointPath := openAICompatImageEndpointPath(opts); endpointPath != "" {
|
||||
return e.executeImages(ctx, auth, req, opts, endpointPath)
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -179,7 +195,98 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) executeImages(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, endpointPath string) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
baseURL, apiKey := e.resolveCredentials(auth)
|
||||
if baseURL == "" {
|
||||
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
payload, contentType, errPrepare := prepareOpenAICompatImagesPayload(req.Payload, baseModel, opts.Headers.Get("Content-Type"), false)
|
||||
if errPrepare != nil {
|
||||
err = errPrepare
|
||||
return resp, err
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + endpointPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", contentType)
|
||||
if apiKey != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat")
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
|
||||
body, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
|
||||
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), body))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(body)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
reporter.Publish(ctx, helps.ParseOpenAIUsage(body))
|
||||
reporter.EnsurePublished(ctx)
|
||||
resp = cliproxyexecutor.Response{Payload: body, Headers: httpResp.Header.Clone()}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
if endpointPath := openAICompatImageEndpointPath(opts); endpointPath != "" {
|
||||
return e.executeImagesStream(ctx, auth, req, opts, endpointPath)
|
||||
}
|
||||
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
@@ -342,6 +449,121 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) executeImagesStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, endpointPath string) (_ *cliproxyexecutor.StreamResult, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.TrackFailure(ctx, &err)
|
||||
|
||||
baseURL, apiKey := e.resolveCredentials(auth)
|
||||
if baseURL == "" {
|
||||
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload, contentType, errPrepare := prepareOpenAICompatImagesPayload(req.Payload, baseModel, opts.Headers.Get("Content-Type"), true)
|
||||
if errPrepare != nil {
|
||||
err = errPrepare
|
||||
return nil, err
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
|
||||
url := strings.TrimSuffix(baseURL, "/") + endpointPath
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", contentType)
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
httpReq.Header.Set("Cache-Control", "no-cache")
|
||||
if apiKey != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat")
|
||||
var attrs map[string]string
|
||||
if auth != nil {
|
||||
attrs = auth.Attributes
|
||||
}
|
||||
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
body, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return nil, errRead
|
||||
}
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, body)
|
||||
helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), body))
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: string(body)}
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("openai compat executor: close response body error: %v", errClose)
|
||||
}
|
||||
reporter.EnsurePublished(ctx)
|
||||
}()
|
||||
buffer := make([]byte, 32*1024)
|
||||
for {
|
||||
n, errRead := httpResp.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
chunk := bytes.Clone(buffer[:n])
|
||||
helps.AppendAPIResponseChunk(ctx, e.cfg, chunk)
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Payload: chunk}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
if errRead != nil {
|
||||
if errRead != io.EOF {
|
||||
helps.RecordAPIResponseError(ctx, e.cfg, errRead)
|
||||
reporter.PublishFailure(ctx, errRead)
|
||||
select {
|
||||
case out <- cliproxyexecutor.StreamChunk{Err: errRead}:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
@@ -380,6 +602,124 @@ func (e *OpenAICompatExecutor) Refresh(ctx context.Context, auth *cliproxyauth.A
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func openAICompatImageEndpointPath(opts cliproxyexecutor.Options) string {
|
||||
if opts.SourceFormat.String() != openAICompatImageHandlerType {
|
||||
return ""
|
||||
}
|
||||
path := helps.PayloadRequestPath(opts)
|
||||
if strings.HasSuffix(path, "/images/edits") {
|
||||
return openAICompatImagesEditsPath
|
||||
}
|
||||
if strings.HasSuffix(path, "/images/generations") {
|
||||
return openAICompatImagesGenerationsPath
|
||||
}
|
||||
return openAICompatDefaultImageEndpoint
|
||||
}
|
||||
|
||||
func prepareOpenAICompatImagesPayload(payload []byte, model string, contentType string, stream bool) ([]byte, string, error) {
|
||||
model = strings.TrimSpace(model)
|
||||
contentType = strings.TrimSpace(contentType)
|
||||
if json.Valid(payload) {
|
||||
if model != "" {
|
||||
payload, _ = sjson.SetBytes(payload, "model", model)
|
||||
}
|
||||
if stream {
|
||||
payload, _ = sjson.SetBytes(payload, "stream", true)
|
||||
} else {
|
||||
payload, _ = sjson.DeleteBytes(payload, "stream")
|
||||
}
|
||||
return payload, "application/json", nil
|
||||
}
|
||||
|
||||
mediaType, params, errParse := mime.ParseMediaType(contentType)
|
||||
if errParse != nil || !strings.HasPrefix(strings.ToLower(strings.TrimSpace(mediaType)), "multipart/") {
|
||||
return payload, contentType, nil
|
||||
}
|
||||
boundary := strings.TrimSpace(params["boundary"])
|
||||
if boundary == "" {
|
||||
return nil, "", fmt.Errorf("multipart boundary is missing")
|
||||
}
|
||||
return rewriteOpenAICompatImagesMultipartPayload(payload, model, boundary, stream)
|
||||
}
|
||||
|
||||
// cloneOpenAICompatMIMEHeader returns a copy of src in which every value slice
// is freshly allocated, so changes to the copy never affect src. A nil src
// yields an empty, non-nil header.
func cloneOpenAICompatMIMEHeader(src textproto.MIMEHeader) textproto.MIMEHeader {
	clone := make(textproto.MIMEHeader, len(src))
	for name := range src {
		// Appending to a nil slice forces a new backing array (and keeps an
		// empty value list as nil).
		var values []string
		clone[name] = append(values, src[name]...)
	}
	return clone
}
|
||||
|
||||
func rewriteOpenAICompatImagesMultipartPayload(payload []byte, model string, boundary string, stream bool) ([]byte, string, error) {
|
||||
reader := multipart.NewReader(bytes.NewReader(payload), boundary)
|
||||
form, errRead := reader.ReadForm(openAICompatMultipartMemory)
|
||||
if errRead != nil {
|
||||
return nil, "", fmt.Errorf("read multipart form failed: %w", errRead)
|
||||
}
|
||||
defer func() {
|
||||
if errRemove := form.RemoveAll(); errRemove != nil {
|
||||
log.Errorf("openai compat executor: remove multipart form files error: %v", errRemove)
|
||||
}
|
||||
}()
|
||||
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
if model != "" {
|
||||
if errWrite := writer.WriteField("model", model); errWrite != nil {
|
||||
return nil, "", fmt.Errorf("write model field failed: %w", errWrite)
|
||||
}
|
||||
}
|
||||
if stream {
|
||||
if errWrite := writer.WriteField("stream", "true"); errWrite != nil {
|
||||
return nil, "", fmt.Errorf("write stream field failed: %w", errWrite)
|
||||
}
|
||||
}
|
||||
for key, values := range form.Value {
|
||||
if key == "model" || key == "stream" {
|
||||
continue
|
||||
}
|
||||
for _, value := range values {
|
||||
if errWrite := writer.WriteField(key, value); errWrite != nil {
|
||||
return nil, "", fmt.Errorf("write form field %s failed: %w", key, errWrite)
|
||||
}
|
||||
}
|
||||
}
|
||||
for key, files := range form.File {
|
||||
for _, fileHeader := range files {
|
||||
if fileHeader == nil {
|
||||
continue
|
||||
}
|
||||
header := cloneOpenAICompatMIMEHeader(fileHeader.Header)
|
||||
header.Set("Content-Disposition", multipart.FileContentDisposition(key, fileHeader.Filename))
|
||||
if header.Get("Content-Type") == "" {
|
||||
header.Set("Content-Type", "application/octet-stream")
|
||||
}
|
||||
part, errCreate := writer.CreatePart(header)
|
||||
if errCreate != nil {
|
||||
return nil, "", fmt.Errorf("create file field %s failed: %w", key, errCreate)
|
||||
}
|
||||
src, errOpen := fileHeader.Open()
|
||||
if errOpen != nil {
|
||||
return nil, "", fmt.Errorf("open upload file failed: %w", errOpen)
|
||||
}
|
||||
_, errCopy := io.Copy(part, src)
|
||||
if errClose := src.Close(); errClose != nil {
|
||||
log.Errorf("openai compat executor: close upload file error: %v", errClose)
|
||||
if errCopy == nil {
|
||||
errCopy = errClose
|
||||
}
|
||||
}
|
||||
if errCopy != nil {
|
||||
return nil, "", fmt.Errorf("copy upload file failed: %w", errCopy)
|
||||
}
|
||||
}
|
||||
}
|
||||
if errClose := writer.Close(); errClose != nil {
|
||||
return nil, "", fmt.Errorf("close multipart writer failed: %w", errClose)
|
||||
}
|
||||
return body.Bytes(), writer.FormDataContentType(), nil
|
||||
}
|
||||
|
||||
func (e *OpenAICompatExecutor) resolveCredentials(auth *cliproxyauth.Auth) (baseURL, apiKey string) {
|
||||
if auth == nil {
|
||||
return "", ""
|
||||
|
||||
Reference in New Issue
Block a user