diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index fb0a7655..edb7a677 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -705,7 +705,7 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) { // oauth-model-mappings: map[string][]ModelNameMapping func (h *Handler) GetOAuthModelMappings(c *gin.Context) { - c.JSON(200, gin.H{"oauth-model-mappings": normalizeOAuthModelMappings(h.cfg.OAuthModelMappings)}) + c.JSON(200, gin.H{"oauth-model-mappings": sanitizedOAuthModelMappings(h.cfg.OAuthModelMappings)}) } func (h *Handler) PutOAuthModelMappings(c *gin.Context) { @@ -725,7 +725,7 @@ func (h *Handler) PutOAuthModelMappings(c *gin.Context) { } entries = wrapper.Items } - h.cfg.OAuthModelMappings = normalizeOAuthModelMappings(entries) + h.cfg.OAuthModelMappings = sanitizedOAuthModelMappings(entries) h.persist(c) } @@ -751,7 +751,8 @@ func (h *Handler) PatchOAuthModelMappings(c *gin.Context) { return } - normalized := normalizeOAuthModelMappingsList(body.Mappings) + normalizedMap := sanitizedOAuthModelMappings(map[string][]config.ModelNameMapping{channel: body.Mappings}) + normalized := normalizedMap[channel] if len(normalized) == 0 { if h.cfg.OAuthModelMappings == nil { c.JSON(404, gin.H{"error": "channel not found"}) @@ -1041,60 +1042,26 @@ func normalizeVertexCompatKey(entry *config.VertexCompatKey) { entry.Models = normalized } -func normalizeOAuthModelMappingsList(entries []config.ModelNameMapping) []config.ModelNameMapping { +func sanitizedOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string][]config.ModelNameMapping { if len(entries) == 0 { return nil } - seenName := make(map[string]struct{}, len(entries)) - seenAlias := make(map[string]struct{}, len(entries)) - clean := make([]config.ModelNameMapping, 0, len(entries)) - for _, mapping := range entries { - name := strings.TrimSpace(mapping.Name) - alias := strings.TrimSpace(mapping.Alias) - if name == "" || alias == "" { + copied := make(map[string][]config.ModelNameMapping, len(entries)) + for channel, mappings := range entries { + if len(mappings) == 0 { continue } - if strings.EqualFold(name, alias) { - continue - } - nameKey := strings.ToLower(name) - aliasKey := strings.ToLower(alias) - if _, ok := seenName[nameKey]; ok { - continue - } - if _, ok := seenAlias[aliasKey]; ok { - continue - } - seenName[nameKey] = struct{}{} - seenAlias[aliasKey] = struct{}{} - clean = append(clean, config.ModelNameMapping{Name: name, Alias: alias, Fork: mapping.Fork}) + copied[channel] = append([]config.ModelNameMapping(nil), mappings...) } - if len(clean) == 0 { + if len(copied) == 0 { return nil } - return clean -} - -func normalizeOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string][]config.ModelNameMapping { - if len(entries) == 0 { + cfg := config.Config{OAuthModelMappings: copied} + cfg.SanitizeOAuthModelMappings() + if len(cfg.OAuthModelMappings) == 0 { return nil } - out := make(map[string][]config.ModelNameMapping, len(entries)) - for rawChannel, mappings := range entries { - channel := strings.ToLower(strings.TrimSpace(rawChannel)) - if channel == "" { - continue - } - normalized := normalizeOAuthModelMappingsList(mappings) - if len(normalized) == 0 { - continue - } - out[channel] = normalized - } - if len(out) == 0 { - return nil - } - return out + return cfg.OAuthModelMappings } // GetAmpCode returns the complete ampcode configuration.