fix(home): implement home dispatch headers and enhance Gemini model handling
This commit is contained in:
+63
-15
@@ -409,7 +409,7 @@ func (s *Server) setupRoutes() {
|
|||||||
{
|
{
|
||||||
v1beta.GET("/models", s.geminiModelsHandler(geminiHandlers))
|
v1beta.GET("/models", s.geminiModelsHandler(geminiHandlers))
|
||||||
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
||||||
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
v1beta.GET("/models/*action", s.geminiGetHandler(geminiHandlers))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Root endpoint
|
// Root endpoint
|
||||||
@@ -851,6 +851,17 @@ func (s *Server) geminiModelsHandler(geminiHandler *gemini.GeminiAPIHandler) gin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) geminiGetHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if s != nil && s.cfg != nil && s.cfg.Home.Enabled {
|
||||||
|
s.handleHomeGeminiModel(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
geminiHandler.GeminiGetHandler(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type homeModelEntry struct {
|
type homeModelEntry struct {
|
||||||
id string
|
id string
|
||||||
created int64
|
created int64
|
||||||
@@ -933,6 +944,29 @@ func (s *Server) handleHomeGeminiModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleHomeGeminiModel(c *gin.Context) {
|
||||||
|
entries, ok := s.loadHomeModelEntries(c)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
action := strings.TrimPrefix(c.Param("action"), "/")
|
||||||
|
action = strings.TrimSpace(action)
|
||||||
|
for _, entry := range entries {
|
||||||
|
if homeGeminiModelMatches(entry, action) {
|
||||||
|
c.JSON(http.StatusOK, formatHomeGeminiModel(entry))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: "Not Found",
|
||||||
|
Type: "not_found",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) loadHomeModelEntries(c *gin.Context) ([]homeModelEntry, bool) {
|
func (s *Server) loadHomeModelEntries(c *gin.Context) ([]homeModelEntry, bool) {
|
||||||
if s == nil || c == nil || c.Request == nil {
|
if s == nil || c == nil || c.Request == nil {
|
||||||
return nil, false
|
return nil, false
|
||||||
@@ -976,24 +1010,38 @@ func (s *Server) loadHomeModelEntries(c *gin.Context) ([]homeModelEntry, bool) {
|
|||||||
func formatHomeGeminiModels(entries []homeModelEntry) []map[string]any {
|
func formatHomeGeminiModels(entries []homeModelEntry) []map[string]any {
|
||||||
out := make([]map[string]any, 0, len(entries))
|
out := make([]map[string]any, 0, len(entries))
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
name := entry.id
|
out = append(out, formatHomeGeminiModel(entry))
|
||||||
if !strings.HasPrefix(name, "models/") {
|
|
||||||
name = "models/" + name
|
|
||||||
}
|
|
||||||
displayName := entry.displayName
|
|
||||||
if displayName == "" {
|
|
||||||
displayName = entry.id
|
|
||||||
}
|
|
||||||
out = append(out, map[string]any{
|
|
||||||
"name": name,
|
|
||||||
"displayName": displayName,
|
|
||||||
"description": displayName,
|
|
||||||
"supportedGenerationMethods": []string{"generateContent"},
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func formatHomeGeminiModel(entry homeModelEntry) map[string]any {
|
||||||
|
name := entry.id
|
||||||
|
if !strings.HasPrefix(name, "models/") {
|
||||||
|
name = "models/" + name
|
||||||
|
}
|
||||||
|
displayName := entry.displayName
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = entry.id
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"name": name,
|
||||||
|
"displayName": displayName,
|
||||||
|
"description": displayName,
|
||||||
|
"supportedGenerationMethods": []string{"generateContent"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func homeGeminiModelMatches(entry homeModelEntry, action string) bool {
|
||||||
|
id := strings.TrimSpace(entry.id)
|
||||||
|
if id == "" || action == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
normalizedAction := strings.TrimPrefix(action, "models/")
|
||||||
|
normalizedID := strings.TrimPrefix(id, "models/")
|
||||||
|
return action == id || action == "models/"+id || normalizedAction == normalizedID
|
||||||
|
}
|
||||||
|
|
||||||
func decodeHomeModels(raw []byte) ([]homeModelEntry, error) {
|
func decodeHomeModels(raw []byte) ([]homeModelEntry, error) {
|
||||||
if len(raw) == 0 {
|
if len(raw) == 0 {
|
||||||
return nil, fmt.Errorf("home models payload is empty")
|
return nil, fmt.Errorf("home models payload is empty")
|
||||||
|
|||||||
@@ -3231,6 +3231,79 @@ func setHomeUserAPIKeyOnGinContext(ctx context.Context, apiKey string) {
|
|||||||
ginCtx.Set("userApiKey", apiKey)
|
ginCtx.Set("userApiKey", apiKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func homeDispatchHeaders(ctx context.Context, headers http.Header) http.Header {
|
||||||
|
apiKey, ok := homeQueryCredentialFromContext(ctx)
|
||||||
|
if !ok {
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
out := headers.Clone()
|
||||||
|
if out == nil {
|
||||||
|
out = http.Header{}
|
||||||
|
}
|
||||||
|
if out.Get("Authorization") != "" || out.Get("X-Goog-Api-Key") != "" || out.Get("X-Api-Key") != "" {
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
out.Set("X-Goog-Api-Key", apiKey)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func homeQueryCredentialFromContext(ctx context.Context) (string, bool) {
|
||||||
|
if ctx == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
if queryCtx, ok := ctx.Value("gin").(interface{ Query(string) string }); ok && queryCtx != nil {
|
||||||
|
if apiKey := strings.TrimSpace(queryCtx.Query("key")); apiKey != "" {
|
||||||
|
return apiKey, true
|
||||||
|
}
|
||||||
|
if apiKey := strings.TrimSpace(queryCtx.Query("auth_token")); apiKey != "" {
|
||||||
|
return apiKey, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ginCtx, ok := ctx.Value("gin").(interface{ Get(string) (any, bool) })
|
||||||
|
if !ok || ginCtx == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
rawMetadata, ok := ginCtx.Get("accessMetadata")
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
source := accessMetadataSource(rawMetadata)
|
||||||
|
if source != "query-key" && source != "query-auth-token" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
rawAPIKey, ok := ginCtx.Get("userApiKey")
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
apiKey := contextStringValue(rawAPIKey)
|
||||||
|
if apiKey == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return apiKey, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func accessMetadataSource(raw any) string {
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case map[string]string:
|
||||||
|
return strings.TrimSpace(v["source"])
|
||||||
|
case map[string]any:
|
||||||
|
return contextStringValue(v["source"])
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextStringValue(raw any) string {
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(v)
|
||||||
|
case []byte:
|
||||||
|
return strings.TrimSpace(string(v))
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func homeExecutionSessionIDFromMetadata(meta map[string]any) string {
|
func homeExecutionSessionIDFromMetadata(meta map[string]any) string {
|
||||||
if len(meta) == 0 {
|
if len(meta) == 0 {
|
||||||
return ""
|
return ""
|
||||||
@@ -3352,8 +3425,9 @@ func (m *Manager) pickNextViaHome(ctx context.Context, model string, opts clipro
|
|||||||
|
|
||||||
requestedModel := requestedModelFromMetadata(opts.Metadata, model)
|
requestedModel := requestedModelFromMetadata(opts.Metadata, model)
|
||||||
sessionID := ExtractSessionID(opts.Headers, opts.OriginalRequest, opts.Metadata)
|
sessionID := ExtractSessionID(opts.Headers, opts.OriginalRequest, opts.Metadata)
|
||||||
|
dispatchHeaders := homeDispatchHeaders(ctx, opts.Headers)
|
||||||
|
|
||||||
raw, err := client.RPopAuth(ctx, requestedModel, sessionID, opts.Headers, count)
|
raw, err := client.RPopAuth(ctx, requestedModel, sessionID, dispatchHeaders, count)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, "", &Error{Code: "auth_not_found", Message: err.Error(), HTTPStatus: http.StatusServiceUnavailable}
|
return nil, nil, "", &Error{Code: "auth_not_found", Message: err.Error(), HTTPStatus: http.StatusServiceUnavailable}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,87 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type homeDispatchTestGinContext struct {
|
||||||
|
values map[string]any
|
||||||
|
query map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c homeDispatchTestGinContext) Get(key string) (any, bool) {
|
||||||
|
v, ok := c.values[key]
|
||||||
|
return v, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c homeDispatchTestGinContext) Query(key string) string {
|
||||||
|
if c.query == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.query[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHomeDispatchHeadersAddsQueryKeyCredential(t *testing.T) {
|
||||||
|
ginCtx := homeDispatchTestGinContext{query: map[string]string{"key": "12345"}}
|
||||||
|
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
||||||
|
headers := http.Header{"User-Agent": {"client"}}
|
||||||
|
|
||||||
|
got := homeDispatchHeaders(ctx, headers)
|
||||||
|
|
||||||
|
if got.Get("X-Goog-Api-Key") != "12345" {
|
||||||
|
t.Fatalf("X-Goog-Api-Key = %q, want %q", got.Get("X-Goog-Api-Key"), "12345")
|
||||||
|
}
|
||||||
|
if headers.Get("X-Goog-Api-Key") != "" {
|
||||||
|
t.Fatalf("original headers were mutated: %v", headers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHomeDispatchHeadersAddsQueryCredentialFromAccessMetadata(t *testing.T) {
|
||||||
|
ginCtx := homeDispatchTestGinContext{values: map[string]any{
|
||||||
|
"accessMetadata": map[string]string{"source": "query-key"},
|
||||||
|
"userApiKey": "12345",
|
||||||
|
}}
|
||||||
|
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
||||||
|
headers := http.Header{"User-Agent": {"client"}}
|
||||||
|
|
||||||
|
got := homeDispatchHeaders(ctx, headers)
|
||||||
|
|
||||||
|
if got.Get("X-Goog-Api-Key") != "12345" {
|
||||||
|
t.Fatalf("X-Goog-Api-Key = %q, want %q", got.Get("X-Goog-Api-Key"), "12345")
|
||||||
|
}
|
||||||
|
if headers.Get("X-Goog-Api-Key") != "" {
|
||||||
|
t.Fatalf("original headers were mutated: %v", headers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHomeDispatchHeadersKeepsExistingCredentialHeader(t *testing.T) {
|
||||||
|
ginCtx := homeDispatchTestGinContext{query: map[string]string{"key": "query-key"}}
|
||||||
|
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
||||||
|
headers := http.Header{"X-Goog-Api-Key": {"header-key"}}
|
||||||
|
|
||||||
|
got := homeDispatchHeaders(ctx, headers)
|
||||||
|
|
||||||
|
if got.Get("X-Goog-Api-Key") != "header-key" {
|
||||||
|
t.Fatalf("X-Goog-Api-Key = %q, want %q", got.Get("X-Goog-Api-Key"), "header-key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHomeDispatchHeadersIgnoresHeaderCredentialSource(t *testing.T) {
|
||||||
|
ginCtx := homeDispatchTestGinContext{values: map[string]any{
|
||||||
|
"accessMetadata": map[string]string{"source": "authorization"},
|
||||||
|
"userApiKey": "12345",
|
||||||
|
}}
|
||||||
|
ctx := context.WithValue(context.Background(), "gin", ginCtx)
|
||||||
|
headers := http.Header{"Authorization": {"Bearer 12345"}}
|
||||||
|
|
||||||
|
got := homeDispatchHeaders(ctx, headers)
|
||||||
|
|
||||||
|
if got.Get("X-Goog-Api-Key") != "" {
|
||||||
|
t.Fatalf("X-Goog-Api-Key = %q, want empty", got.Get("X-Goog-Api-Key"))
|
||||||
|
}
|
||||||
|
if got.Get("Authorization") != "Bearer 12345" {
|
||||||
|
t.Fatalf("Authorization = %q, want %q", got.Get("Authorization"), "Bearer 12345")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user