|
|
|
|
@@ -6,6 +6,7 @@ package handlers
|
|
|
|
|
import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"net/http"
|
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
|
|
|
|
@@ -46,6 +47,9 @@ type BaseAPIHandler struct {
|
|
|
|
|
|
|
|
|
|
// Cfg holds the current application configuration.
|
|
|
|
|
Cfg *config.SDKConfig
|
|
|
|
|
|
|
|
|
|
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
|
|
|
|
OpenAICompatProviders []string
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NewBaseAPIHandlers creates a new API handlers instance.
|
|
|
|
|
@@ -57,10 +61,11 @@ type BaseAPIHandler struct {
|
|
|
|
|
//
|
|
|
|
|
// Returns:
|
|
|
|
|
// - *BaseAPIHandler: A new API handlers instance
|
|
|
|
|
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
|
|
|
|
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
|
|
|
|
return &BaseAPIHandler{
|
|
|
|
|
Cfg: cfg,
|
|
|
|
|
AuthManager: authManager,
|
|
|
|
|
Cfg: cfg,
|
|
|
|
|
AuthManager: authManager,
|
|
|
|
|
OpenAICompatProviders: openAICompatProviders,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
|
|
|
|
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
|
|
|
|
// This path is the only supported execution route.
|
|
|
|
|
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
|
|
|
|
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
|
|
|
|
providers := util.GetProviderName(normalizedModel)
|
|
|
|
|
if len(providers) == 0 {
|
|
|
|
|
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
|
|
|
|
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
|
|
|
|
if errMsg != nil {
|
|
|
|
|
return nil, errMsg
|
|
|
|
|
}
|
|
|
|
|
req := coreexecutor.Request{
|
|
|
|
|
Model: normalizedModel,
|
|
|
|
|
@@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
|
|
|
|
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
|
|
|
|
// This path is the only supported execution route.
|
|
|
|
|
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
|
|
|
|
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
|
|
|
|
providers := util.GetProviderName(normalizedModel)
|
|
|
|
|
if len(providers) == 0 {
|
|
|
|
|
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
|
|
|
|
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
|
|
|
|
if errMsg != nil {
|
|
|
|
|
return nil, errMsg
|
|
|
|
|
}
|
|
|
|
|
req := coreexecutor.Request{
|
|
|
|
|
Model: normalizedModel,
|
|
|
|
|
@@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
|
|
|
|
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
|
|
|
|
// This path is the only supported execution route.
|
|
|
|
|
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
|
|
|
|
normalizedModel, metadata := normalizeModelMetadata(modelName)
|
|
|
|
|
providers := util.GetProviderName(normalizedModel)
|
|
|
|
|
if len(providers) == 0 {
|
|
|
|
|
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
|
|
|
|
if errMsg != nil {
|
|
|
|
|
errChan := make(chan *interfaces.ErrorMessage, 1)
|
|
|
|
|
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
|
|
|
|
errChan <- errMsg
|
|
|
|
|
close(errChan)
|
|
|
|
|
return nil, errChan
|
|
|
|
|
}
|
|
|
|
|
@@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|
|
|
|
return dataChan, errChan
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
|
|
|
|
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
|
|
|
|
|
|
|
|
|
|
// First, normalize the model name to handle suffixes like "-thinking-128"
|
|
|
|
|
// This needs to happen before determining the provider for non-dynamic models.
|
|
|
|
|
normalizedModel, metadata = normalizeModelMetadata(modelName)
|
|
|
|
|
|
|
|
|
|
if isDynamic {
|
|
|
|
|
providers = []string{providerName}
|
|
|
|
|
// For dynamic models, the extractedModelName is already normalized by parseDynamicModel
|
|
|
|
|
// so we use it as the final normalizedModel.
|
|
|
|
|
normalizedModel = extractedModelName
|
|
|
|
|
} else {
|
|
|
|
|
// For non-dynamic models, use the normalizedModel to get the provider name.
|
|
|
|
|
providers = util.GetProviderName(normalizedModel)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(providers) == 0 {
|
|
|
|
|
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
|
|
|
|
|
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
|
|
|
|
|
// So, normalizedModel is already correctly set at this point.
|
|
|
|
|
|
|
|
|
|
return providers, normalizedModel, metadata, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
|
|
|
|
|
var providerPart, modelPart string
|
|
|
|
|
for _, sep := range []string{"://"} {
|
|
|
|
|
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
|
|
|
|
|
providerPart = parts[0]
|
|
|
|
|
modelPart = parts[1]
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if providerPart == "" {
|
|
|
|
|
return "", modelName, false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Check if the provider is a configured openai-compatibility provider
|
|
|
|
|
for _, pName := range h.OpenAICompatProviders {
|
|
|
|
|
if pName == providerPart {
|
|
|
|
|
return providerPart, modelPart, true
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return "", modelName, false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func cloneBytes(src []byte) []byte {
|
|
|
|
|
if len(src) == 0 {
|
|
|
|
|
return nil
|
|
|
|
|
|