feat(model_registry): enhance model registration and refresh mechanisms
This commit is contained in:
@@ -15,7 +15,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
modelsFetchTimeout = 30 * time.Second
|
||||
modelsFetchTimeout = 30 * time.Second
|
||||
modelsRefreshInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
var modelsURLs = []string{
|
||||
@@ -35,6 +36,34 @@ var modelsCatalogStore = &modelStore{}
|
||||
|
||||
var updaterOnce sync.Once
|
||||
|
||||
// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes.
|
||||
// changedProviders contains the provider names whose model definitions changed.
|
||||
type ModelRefreshCallback func(changedProviders []string)
|
||||
|
||||
var (
|
||||
refreshCallbackMu sync.Mutex
|
||||
refreshCallback ModelRefreshCallback
|
||||
pendingRefreshChanges []string
|
||||
)
|
||||
|
||||
// SetModelRefreshCallback registers a callback that is invoked when startup or
|
||||
// periodic model refresh detects changes. Only one callback is supported;
|
||||
// subsequent calls replace the previous callback.
|
||||
func SetModelRefreshCallback(cb ModelRefreshCallback) {
|
||||
refreshCallbackMu.Lock()
|
||||
refreshCallback = cb
|
||||
var pending []string
|
||||
if cb != nil && len(pendingRefreshChanges) > 0 {
|
||||
pending = append([]string(nil), pendingRefreshChanges...)
|
||||
pendingRefreshChanges = nil
|
||||
}
|
||||
refreshCallbackMu.Unlock()
|
||||
|
||||
if cb != nil && len(pending) > 0 {
|
||||
cb(pending)
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Load embedded data as fallback on startup.
|
||||
if err := loadModelsFromBytes(embeddedModelsJSON, "embed"); err != nil {
|
||||
@@ -42,23 +71,76 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
// StartModelsUpdater runs a one-time models refresh on startup.
|
||||
// It blocks until the startup fetch attempt finishes so service initialization
|
||||
// can wait for the refreshed catalog before registering auth-backed models.
|
||||
// Safe to call multiple times; only one refresh will run.
|
||||
// StartModelsUpdater starts a background updater that fetches models
|
||||
// immediately on startup and then refreshes the model catalog every 3 hours.
|
||||
// Safe to call multiple times; only one updater will run.
|
||||
func StartModelsUpdater(ctx context.Context) {
|
||||
updaterOnce.Do(func() {
|
||||
runModelsUpdater(ctx)
|
||||
go runModelsUpdater(ctx)
|
||||
})
|
||||
}
|
||||
|
||||
func runModelsUpdater(ctx context.Context) {
|
||||
// Try network fetch once on startup, then stop.
|
||||
// Periodic refresh is disabled - models are only refreshed at startup.
|
||||
tryRefreshModels(ctx)
|
||||
tryStartupRefresh(ctx)
|
||||
periodicRefresh(ctx)
|
||||
}
|
||||
|
||||
func tryRefreshModels(ctx context.Context) {
|
||||
func periodicRefresh(ctx context.Context) {
|
||||
ticker := time.NewTicker(modelsRefreshInterval)
|
||||
defer ticker.Stop()
|
||||
log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
tryPeriodicRefresh(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryPeriodicRefresh fetches models from remote, compares with the current
|
||||
// catalog, and notifies the registered callback if any provider changed.
|
||||
func tryPeriodicRefresh(ctx context.Context) {
|
||||
tryRefreshModels(ctx, "periodic model refresh")
|
||||
}
|
||||
|
||||
// tryStartupRefresh fetches models from remote in the background during
|
||||
// process startup. It uses the same change detection as periodic refresh so
|
||||
// existing auth registrations can be updated after the callback is registered.
|
||||
func tryStartupRefresh(ctx context.Context) {
|
||||
tryRefreshModels(ctx, "startup model refresh")
|
||||
}
|
||||
|
||||
func tryRefreshModels(ctx context.Context, label string) {
|
||||
oldData := getModels()
|
||||
|
||||
parsed, url := fetchModelsFromRemote(ctx)
|
||||
if parsed == nil {
|
||||
log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
|
||||
return
|
||||
}
|
||||
|
||||
// Detect changes before updating store.
|
||||
changed := detectChangedProviders(oldData, parsed)
|
||||
|
||||
// Update store with new data regardless.
|
||||
modelsCatalogStore.mu.Lock()
|
||||
modelsCatalogStore.data = parsed
|
||||
modelsCatalogStore.mu.Unlock()
|
||||
|
||||
if len(changed) == 0 {
|
||||
log.Infof("%s completed from %s, no changes detected", label, url)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed)
|
||||
notifyModelRefresh(changed)
|
||||
}
|
||||
|
||||
// fetchModelsFromRemote tries all remote URLs and returns the parsed model catalog
|
||||
// along with the URL it was fetched from. Returns (nil, "") if all fetches fail.
|
||||
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
|
||||
client := &http.Client{Timeout: modelsFetchTimeout}
|
||||
for _, url := range modelsURLs {
|
||||
reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
|
||||
@@ -92,15 +174,126 @@ func tryRefreshModels(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := loadModelsFromBytes(data, url); err != nil {
|
||||
var parsed staticModelsJSON
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
log.Warnf("models parse failed from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
if err := validateModelsCatalog(&parsed); err != nil {
|
||||
log.Warnf("models validate failed from %s: %v", url, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("models updated from %s", url)
|
||||
return &parsed, url
|
||||
}
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// detectChangedProviders compares two model catalogs and returns provider names
|
||||
// whose model definitions differ. Codex tiers (free/team/plus/pro) are grouped
|
||||
// under a single "codex" provider.
|
||||
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
|
||||
if oldData == nil || newData == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
type section struct {
|
||||
provider string
|
||||
oldList []*ModelInfo
|
||||
newList []*ModelInfo
|
||||
}
|
||||
|
||||
sections := []section{
|
||||
{"claude", oldData.Claude, newData.Claude},
|
||||
{"gemini", oldData.Gemini, newData.Gemini},
|
||||
{"vertex", oldData.Vertex, newData.Vertex},
|
||||
{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
|
||||
{"aistudio", oldData.AIStudio, newData.AIStudio},
|
||||
{"codex", oldData.CodexFree, newData.CodexFree},
|
||||
{"codex", oldData.CodexTeam, newData.CodexTeam},
|
||||
{"codex", oldData.CodexPlus, newData.CodexPlus},
|
||||
{"codex", oldData.CodexPro, newData.CodexPro},
|
||||
{"qwen", oldData.Qwen, newData.Qwen},
|
||||
{"iflow", oldData.IFlow, newData.IFlow},
|
||||
{"kimi", oldData.Kimi, newData.Kimi},
|
||||
{"antigravity", oldData.Antigravity, newData.Antigravity},
|
||||
}
|
||||
|
||||
seen := make(map[string]bool, len(sections))
|
||||
var changed []string
|
||||
for _, s := range sections {
|
||||
if seen[s.provider] {
|
||||
continue
|
||||
}
|
||||
if modelSectionChanged(s.oldList, s.newList) {
|
||||
changed = append(changed, s.provider)
|
||||
seen[s.provider] = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
// modelSectionChanged reports whether two model slices differ.
|
||||
func modelSectionChanged(a, b []*ModelInfo) bool {
|
||||
if len(a) != len(b) {
|
||||
return true
|
||||
}
|
||||
if len(a) == 0 {
|
||||
return false
|
||||
}
|
||||
aj, err1 := json.Marshal(a)
|
||||
bj, err2 := json.Marshal(b)
|
||||
if err1 != nil || err2 != nil {
|
||||
return true
|
||||
}
|
||||
return string(aj) != string(bj)
|
||||
}
|
||||
|
||||
func notifyModelRefresh(changedProviders []string) {
|
||||
if len(changedProviders) == 0 {
|
||||
return
|
||||
}
|
||||
log.Warn("models refresh failed from all URLs, using local data")
|
||||
|
||||
refreshCallbackMu.Lock()
|
||||
cb := refreshCallback
|
||||
if cb == nil {
|
||||
pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
|
||||
refreshCallbackMu.Unlock()
|
||||
return
|
||||
}
|
||||
refreshCallbackMu.Unlock()
|
||||
cb(changedProviders)
|
||||
}
|
||||
|
||||
func mergeProviderNames(existing, incoming []string) []string {
|
||||
if len(incoming) == 0 {
|
||||
return existing
|
||||
}
|
||||
seen := make(map[string]struct{}, len(existing)+len(incoming))
|
||||
merged := make([]string, 0, len(existing)+len(incoming))
|
||||
for _, provider := range existing {
|
||||
name := strings.ToLower(strings.TrimSpace(provider))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
merged = append(merged, name)
|
||||
}
|
||||
for _, provider := range incoming {
|
||||
name := strings.ToLower(strings.TrimSpace(provider))
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
seen[name] = struct{}{}
|
||||
merged = append(merged, name)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func loadModelsFromBytes(data []byte, source string) error {
|
||||
|
||||
Reference in New Issue
Block a user