Merge pull request #1922 from shenshuoyaoyouguang/pr/model-registry-safety
fix(registry): clone model snapshots and invalidate available-model cache
This commit is contained in:
@@ -62,6 +62,11 @@ type ModelInfo struct {
|
|||||||
UserDefined bool `json:"-"`
|
UserDefined bool `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type availableModelsCacheEntry struct {
|
||||||
|
models []map[string]any
|
||||||
|
expiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
// ThinkingSupport describes a model family's supported internal reasoning budget range.
|
||||||
// Values are interpreted in provider-native token units.
|
// Values are interpreted in provider-native token units.
|
||||||
type ThinkingSupport struct {
|
type ThinkingSupport struct {
|
||||||
@@ -116,6 +121,8 @@ type ModelRegistry struct {
|
|||||||
clientProviders map[string]string
|
clientProviders map[string]string
|
||||||
// mutex ensures thread-safe access to the registry
|
// mutex ensures thread-safe access to the registry
|
||||||
mutex *sync.RWMutex
|
mutex *sync.RWMutex
|
||||||
|
// availableModelsCache stores per-handler snapshots for GetAvailableModels.
|
||||||
|
availableModelsCache map[string]availableModelsCacheEntry
|
||||||
// hook is an optional callback sink for model registration changes
|
// hook is an optional callback sink for model registration changes
|
||||||
hook ModelRegistryHook
|
hook ModelRegistryHook
|
||||||
}
|
}
|
||||||
@@ -132,11 +139,24 @@ func GetGlobalRegistry() *ModelRegistry {
|
|||||||
clientModels: make(map[string][]string),
|
clientModels: make(map[string][]string),
|
||||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||||
clientProviders: make(map[string]string),
|
clientProviders: make(map[string]string),
|
||||||
|
availableModelsCache: make(map[string]availableModelsCacheEntry),
|
||||||
mutex: &sync.RWMutex{},
|
mutex: &sync.RWMutex{},
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return globalRegistry
|
return globalRegistry
|
||||||
}
|
}
|
||||||
|
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
|
||||||
|
if r.availableModelsCache == nil {
|
||||||
|
r.availableModelsCache = make(map[string]availableModelsCacheEntry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
|
||||||
|
if len(r.availableModelsCache) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clear(r.availableModelsCache)
|
||||||
|
}
|
||||||
|
|
||||||
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
// LookupModelInfo searches dynamic registry (provider-specific > global) then static definitions.
|
||||||
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
||||||
@@ -151,9 +171,9 @@ func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
if info := GetGlobalRegistry().GetModelInfo(modelID, p); info != nil {
|
||||||
return info
|
return cloneModelInfo(info)
|
||||||
}
|
}
|
||||||
return LookupStaticModelInfo(modelID)
|
return cloneModelInfo(LookupStaticModelInfo(modelID))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetHook sets an optional hook for observing model registration changes.
|
// SetHook sets an optional hook for observing model registration changes.
|
||||||
@@ -211,6 +231,7 @@ func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
|
|||||||
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
|
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
provider := strings.ToLower(clientProvider)
|
provider := strings.ToLower(clientProvider)
|
||||||
uniqueModelIDs := make([]string, 0, len(models))
|
uniqueModelIDs := make([]string, 0, len(models))
|
||||||
@@ -236,6 +257,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
delete(r.clientModels, clientID)
|
delete(r.clientModels, clientID)
|
||||||
delete(r.clientModelInfos, clientID)
|
delete(r.clientModelInfos, clientID)
|
||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
misc.LogCredentialSeparator()
|
misc.LogCredentialSeparator()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -263,6 +285,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
} else {
|
} else {
|
||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
}
|
}
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
r.triggerModelsRegistered(provider, clientID, models)
|
r.triggerModelsRegistered(provider, clientID, models)
|
||||||
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
||||||
misc.LogCredentialSeparator()
|
misc.LogCredentialSeparator()
|
||||||
@@ -406,6 +429,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
r.triggerModelsRegistered(provider, clientID, models)
|
r.triggerModelsRegistered(provider, clientID, models)
|
||||||
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
||||||
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
||||||
@@ -509,6 +533,13 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
|||||||
if len(model.SupportedOutputModalities) > 0 {
|
if len(model.SupportedOutputModalities) > 0 {
|
||||||
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
|
||||||
}
|
}
|
||||||
|
if model.Thinking != nil {
|
||||||
|
copyThinking := *model.Thinking
|
||||||
|
if len(model.Thinking.Levels) > 0 {
|
||||||
|
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
|
||||||
|
}
|
||||||
|
copyModel.Thinking = &copyThinking
|
||||||
|
}
|
||||||
return &copyModel
|
return &copyModel
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -538,6 +569,7 @@ func (r *ModelRegistry) UnregisterClient(clientID string) {
|
|||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
r.unregisterClientInternal(clientID)
|
r.unregisterClientInternal(clientID)
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
}
|
}
|
||||||
|
|
||||||
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
|
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
|
||||||
@@ -604,9 +636,12 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
|||||||
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
if registration, exists := r.models[modelID]; exists {
|
if registration, exists := r.models[modelID]; exists {
|
||||||
registration.QuotaExceededClients[clientID] = new(time.Now())
|
now := time.Now()
|
||||||
|
registration.QuotaExceededClients[clientID] = &now
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -618,9 +653,11 @@ func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
|
|||||||
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
if registration, exists := r.models[modelID]; exists {
|
if registration, exists := r.models[modelID]; exists {
|
||||||
delete(registration.QuotaExceededClients, clientID)
|
delete(registration.QuotaExceededClients, clientID)
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
// log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -636,6 +673,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
|||||||
}
|
}
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
registration, exists := r.models[modelID]
|
registration, exists := r.models[modelID]
|
||||||
if !exists || registration == nil {
|
if !exists || registration == nil {
|
||||||
@@ -649,6 +687,7 @@ func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
|
|||||||
}
|
}
|
||||||
registration.SuspendedClients[clientID] = reason
|
registration.SuspendedClients[clientID] = reason
|
||||||
registration.LastUpdated = time.Now()
|
registration.LastUpdated = time.Now()
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
if reason != "" {
|
if reason != "" {
|
||||||
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
|
||||||
} else {
|
} else {
|
||||||
@@ -666,6 +705,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
|||||||
}
|
}
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
registration, exists := r.models[modelID]
|
registration, exists := r.models[modelID]
|
||||||
if !exists || registration == nil || registration.SuspendedClients == nil {
|
if !exists || registration == nil || registration.SuspendedClients == nil {
|
||||||
@@ -676,6 +716,7 @@ func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
|
|||||||
}
|
}
|
||||||
delete(registration.SuspendedClients, clientID)
|
delete(registration.SuspendedClients, clientID)
|
||||||
registration.LastUpdated = time.Now()
|
registration.LastUpdated = time.Now()
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
log.Debugf("Resumed client %s for model %s", clientID, modelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -711,22 +752,52 @@ func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - []map[string]any: List of available models in the requested format
|
// - []map[string]any: List of available models in the requested format
|
||||||
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
|
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
|
||||||
r.mutex.RLock()
|
|
||||||
defer r.mutex.RUnlock()
|
|
||||||
|
|
||||||
models := make([]map[string]any, 0)
|
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
|
||||||
|
|
||||||
for _, registration := range r.models {
|
|
||||||
// Check if model has any non-quota-exceeded clients
|
|
||||||
availableClients := registration.Count
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
// Count clients that have exceeded quota but haven't recovered yet
|
r.mutex.RLock()
|
||||||
|
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||||
|
models := cloneModelMaps(cache.models)
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
r.mutex.RUnlock()
|
||||||
|
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
r.ensureAvailableModelsCacheLocked()
|
||||||
|
|
||||||
|
if cache, ok := r.availableModelsCache[handlerType]; ok && (cache.expiresAt.IsZero() || now.Before(cache.expiresAt)) {
|
||||||
|
return cloneModelMaps(cache.models)
|
||||||
|
}
|
||||||
|
|
||||||
|
models, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
|
||||||
|
r.availableModelsCache[handlerType] = availableModelsCacheEntry{
|
||||||
|
models: cloneModelMaps(models),
|
||||||
|
expiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
|
||||||
|
models := make([]map[string]any, 0, len(r.models))
|
||||||
|
quotaExpiredDuration := 5 * time.Minute
|
||||||
|
var expiresAt time.Time
|
||||||
|
|
||||||
|
for _, registration := range r.models {
|
||||||
|
availableClients := registration.Count
|
||||||
|
|
||||||
expiredClients := 0
|
expiredClients := 0
|
||||||
for _, quotaTime := range registration.QuotaExceededClients {
|
for _, quotaTime := range registration.QuotaExceededClients {
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
|
if quotaTime == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
recoveryAt := quotaTime.Add(quotaExpiredDuration)
|
||||||
|
if now.Before(recoveryAt) {
|
||||||
expiredClients++
|
expiredClients++
|
||||||
|
if expiresAt.IsZero() || recoveryAt.Before(expiresAt) {
|
||||||
|
expiresAt = recoveryAt
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -747,7 +818,6 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
|||||||
effectiveClients = 0
|
effectiveClients = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Include models that have available clients, or those solely cooling down.
|
|
||||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||||
model := r.convertModelToMap(registration.Info, handlerType)
|
model := r.convertModelToMap(registration.Info, handlerType)
|
||||||
if model != nil {
|
if model != nil {
|
||||||
@@ -756,7 +826,44 @@ func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return models
|
return models, expiresAt
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneModelMaps(models []map[string]any) []map[string]any {
|
||||||
|
cloned := make([]map[string]any, 0, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil {
|
||||||
|
cloned = append(cloned, nil)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
copyModel := make(map[string]any, len(model))
|
||||||
|
for key, value := range model {
|
||||||
|
copyModel[key] = cloneModelMapValue(value)
|
||||||
|
}
|
||||||
|
cloned = append(cloned, copyModel)
|
||||||
|
}
|
||||||
|
return cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneModelMapValue(value any) any {
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case map[string]any:
|
||||||
|
copyMap := make(map[string]any, len(typed))
|
||||||
|
for key, entry := range typed {
|
||||||
|
copyMap[key] = cloneModelMapValue(entry)
|
||||||
|
}
|
||||||
|
return copyMap
|
||||||
|
case []any:
|
||||||
|
copySlice := make([]any, len(typed))
|
||||||
|
for i, entry := range typed {
|
||||||
|
copySlice[i] = cloneModelMapValue(entry)
|
||||||
|
}
|
||||||
|
return copySlice
|
||||||
|
case []string:
|
||||||
|
return append([]string(nil), typed...)
|
||||||
|
default:
|
||||||
|
return value
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
// GetAvailableModelsByProvider returns models available for the given provider identifier.
|
||||||
@@ -872,11 +979,11 @@ func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelIn
|
|||||||
|
|
||||||
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
if effectiveClients > 0 || (availableClients > 0 && (expiredClients > 0 || cooldownSuspended > 0) && otherSuspended == 0) {
|
||||||
if entry.info != nil {
|
if entry.info != nil {
|
||||||
result = append(result, entry.info)
|
result = append(result, cloneModelInfo(entry.info))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if ok && registration != nil && registration.Info != nil {
|
if ok && registration != nil && registration.Info != nil {
|
||||||
result = append(result, registration.Info)
|
result = append(result, cloneModelInfo(registration.Info))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -985,13 +1092,13 @@ func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
|
|||||||
if reg.Providers != nil {
|
if reg.Providers != nil {
|
||||||
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
if count, ok := reg.Providers[provider]; ok && count > 0 {
|
||||||
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
if info, ok := reg.InfoByProvider[provider]; ok && info != nil {
|
||||||
return info
|
return cloneModelInfo(info)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Fallback to global info (last registered)
|
// Fallback to global info (last registered)
|
||||||
return reg.Info
|
return cloneModelInfo(reg.Info)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1031,7 +1138,7 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
result["max_completion_tokens"] = model.MaxCompletionTokens
|
result["max_completion_tokens"] = model.MaxCompletionTokens
|
||||||
}
|
}
|
||||||
if len(model.SupportedParameters) > 0 {
|
if len(model.SupportedParameters) > 0 {
|
||||||
result["supported_parameters"] = model.SupportedParameters
|
result["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -1075,13 +1182,13 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
|||||||
result["outputTokenLimit"] = model.OutputTokenLimit
|
result["outputTokenLimit"] = model.OutputTokenLimit
|
||||||
}
|
}
|
||||||
if len(model.SupportedGenerationMethods) > 0 {
|
if len(model.SupportedGenerationMethods) > 0 {
|
||||||
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
|
result["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
|
||||||
}
|
}
|
||||||
if len(model.SupportedInputModalities) > 0 {
|
if len(model.SupportedInputModalities) > 0 {
|
||||||
result["supportedInputModalities"] = model.SupportedInputModalities
|
result["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
|
||||||
}
|
}
|
||||||
if len(model.SupportedOutputModalities) > 0 {
|
if len(model.SupportedOutputModalities) > 0 {
|
||||||
result["supportedOutputModalities"] = model.SupportedOutputModalities
|
result["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -1111,15 +1218,20 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
|||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
quotaExpiredDuration := 5 * time.Minute
|
quotaExpiredDuration := 5 * time.Minute
|
||||||
|
invalidated := false
|
||||||
|
|
||||||
for modelID, registration := range r.models {
|
for modelID, registration := range r.models {
|
||||||
for clientID, quotaTime := range registration.QuotaExceededClients {
|
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||||
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
||||||
delete(registration.QuotaExceededClients, clientID)
|
delete(registration.QuotaExceededClients, clientID)
|
||||||
|
invalidated = true
|
||||||
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if invalidated {
|
||||||
|
r.invalidateAvailableModelsCacheLocked()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFirstAvailableModel returns the first available model for the given handler type.
|
// GetFirstAvailableModel returns the first available model for the given handler type.
|
||||||
@@ -1133,8 +1245,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
|||||||
// - string: The model ID of the first available model, or empty string if none available
|
// - string: The model ID of the first available model, or empty string if none available
|
||||||
// - error: An error if no models are available
|
// - error: An error if no models are available
|
||||||
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
|
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
|
||||||
r.mutex.RLock()
|
|
||||||
defer r.mutex.RUnlock()
|
|
||||||
|
|
||||||
// Get all available models for this handler type
|
// Get all available models for this handler type
|
||||||
models := r.GetAvailableModels(handlerType)
|
models := r.GetAvailableModels(handlerType)
|
||||||
@@ -1194,13 +1304,13 @@ func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
|
|||||||
// Prefer client's own model info to preserve original type/owned_by
|
// Prefer client's own model info to preserve original type/owned_by
|
||||||
if clientInfos != nil {
|
if clientInfos != nil {
|
||||||
if info, ok := clientInfos[modelID]; ok && info != nil {
|
if info, ok := clientInfos[modelID]; ok && info != nil {
|
||||||
result = append(result, info)
|
result = append(result, cloneModelInfo(info))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Fallback to global registry (for backwards compatibility)
|
// Fallback to global registry (for backwards compatibility)
|
||||||
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
||||||
result = append(result, reg.Info)
|
result = append(result, cloneModelInfo(reg.Info))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||||
|
|
||||||
|
first := r.GetAvailableModels("openai")
|
||||||
|
if len(first) != 1 {
|
||||||
|
t.Fatalf("expected 1 model, got %d", len(first))
|
||||||
|
}
|
||||||
|
first[0]["id"] = "mutated"
|
||||||
|
first[0]["display_name"] = "Mutated"
|
||||||
|
|
||||||
|
second := r.GetAvailableModels("openai")
|
||||||
|
if got := second[0]["id"]; got != "m1" {
|
||||||
|
t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
|
||||||
|
}
|
||||||
|
if got := second[0]["display_name"]; got != "Model One" {
|
||||||
|
t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
|
||||||
|
|
||||||
|
models := r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 1 {
|
||||||
|
t.Fatalf("expected 1 model, got %d", len(models))
|
||||||
|
}
|
||||||
|
if got := models[0]["display_name"]; got != "Model One" {
|
||||||
|
t.Fatalf("expected initial display_name Model One, got %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
|
||||||
|
models = r.GetAvailableModels("openai")
|
||||||
|
if got := models[0]["display_name"]; got != "Model One Updated" {
|
||||||
|
t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.SuspendClientModel("client-1", "m1", "manual")
|
||||||
|
models = r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 0 {
|
||||||
|
t.Fatalf("expected no available models after suspension, got %d", len(models))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ResumeClientModel("client-1", "m1")
|
||||||
|
models = r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 1 {
|
||||||
|
t.Fatalf("expected model to reappear after resume, got %d", len(models))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
package registry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetModelInfoReturnsClone(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
Thinking: &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetModelInfo("m1", "gemini")
|
||||||
|
if first == nil {
|
||||||
|
t.Fatal("expected model info")
|
||||||
|
}
|
||||||
|
first.DisplayName = "mutated"
|
||||||
|
first.Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetModelInfo("m1", "gemini")
|
||||||
|
if second.DisplayName != "Model One" {
|
||||||
|
t.Fatalf("expected cloned display name, got %q", second.DisplayName)
|
||||||
|
}
|
||||||
|
if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
|
||||||
|
t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelsForClientReturnsClones(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetModelsForClient("client-1")
|
||||||
|
if len(first) != 1 || first[0] == nil {
|
||||||
|
t.Fatalf("expected one model, got %+v", first)
|
||||||
|
}
|
||||||
|
first[0].DisplayName = "mutated"
|
||||||
|
first[0].Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetModelsForClient("client-1")
|
||||||
|
if len(second) != 1 || second[0] == nil {
|
||||||
|
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||||
|
}
|
||||||
|
if second[0].DisplayName != "Model One" {
|
||||||
|
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||||
|
}
|
||||||
|
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||||
|
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "gemini", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
Thinking: &ThinkingSupport{Levels: []string{"low", "high"}},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetAvailableModelsByProvider("gemini")
|
||||||
|
if len(first) != 1 || first[0] == nil {
|
||||||
|
t.Fatalf("expected one model, got %+v", first)
|
||||||
|
}
|
||||||
|
first[0].DisplayName = "mutated"
|
||||||
|
first[0].Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetAvailableModelsByProvider("gemini")
|
||||||
|
if len(second) != 1 || second[0] == nil {
|
||||||
|
t.Fatalf("expected one model on second fetch, got %+v", second)
|
||||||
|
}
|
||||||
|
if second[0].DisplayName != "Model One" {
|
||||||
|
t.Fatalf("expected cloned display name, got %q", second[0].DisplayName)
|
||||||
|
}
|
||||||
|
if second[0].Thinking == nil || len(second[0].Thinking.Levels) == 0 || second[0].Thinking.Levels[0] != "low" {
|
||||||
|
t.Fatalf("expected cloned thinking levels, got %+v", second[0].Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
|
||||||
|
r.SetModelQuotaExceeded("client-1", "m1")
|
||||||
|
if models := r.GetAvailableModels("openai"); len(models) != 1 {
|
||||||
|
t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.mutex.Lock()
|
||||||
|
quotaTime := time.Now().Add(-6 * time.Minute)
|
||||||
|
r.models["m1"].QuotaExceededClients["client-1"] = &quotaTime
|
||||||
|
r.mutex.Unlock()
|
||||||
|
|
||||||
|
r.CleanupExpiredQuotas()
|
||||||
|
|
||||||
|
if count := r.GetModelCount("m1"); count != 1 {
|
||||||
|
t.Fatalf("expected model count 1 after cleanup, got %d", count)
|
||||||
|
}
|
||||||
|
models := r.GetAvailableModels("openai")
|
||||||
|
if len(models) != 1 {
|
||||||
|
t.Fatalf("expected model to stay available after cleanup, got %d", len(models))
|
||||||
|
}
|
||||||
|
if got := models[0]["id"]; got != "m1" {
|
||||||
|
t.Fatalf("expected model id m1, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
|
||||||
|
r := newTestModelRegistry()
|
||||||
|
r.RegisterClient("client-1", "openai", []*ModelInfo{{
|
||||||
|
ID: "m1",
|
||||||
|
DisplayName: "Model One",
|
||||||
|
SupportedParameters: []string{"temperature", "top_p"},
|
||||||
|
}})
|
||||||
|
|
||||||
|
first := r.GetAvailableModels("openai")
|
||||||
|
if len(first) != 1 {
|
||||||
|
t.Fatalf("expected one model, got %d", len(first))
|
||||||
|
}
|
||||||
|
params, ok := first[0]["supported_parameters"].([]string)
|
||||||
|
if !ok || len(params) != 2 {
|
||||||
|
t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
|
||||||
|
}
|
||||||
|
params[0] = "mutated"
|
||||||
|
|
||||||
|
second := r.GetAvailableModels("openai")
|
||||||
|
params, ok = second[0]["supported_parameters"].([]string)
|
||||||
|
if !ok || len(params) != 2 || params[0] != "temperature" {
|
||||||
|
t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
|
||||||
|
first := LookupModelInfo("glm-4.6")
|
||||||
|
if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
|
||||||
|
t.Fatalf("expected static model with thinking levels, got %+v", first)
|
||||||
|
}
|
||||||
|
first.Thinking.Levels[0] = "mutated"
|
||||||
|
|
||||||
|
second := LookupModelInfo("glm-4.6")
|
||||||
|
if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
|
||||||
|
t.Fatalf("expected static lookup clone, got %+v", second)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user