Incorporate DINOv3, DINOv2

This commit is contained in:
rmaphoh
2025-08-31 18:03:57 +01:00
parent 897d71c8c9
commit 409f7b6167
5 changed files with 521 additions and 327 deletions
+41 -5
View File
@@ -1,7 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------
from functools import partial
@@ -10,7 +6,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from timm.models.layers import trunc_normal_
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
""" Vision Transformer with support for global average pooling
@@ -56,6 +52,30 @@ def RETFound_mae(**kwargs):
def Dinov2(args, **kwargs):
if args.model_arch == 'dinov2_vits14':
arch = 'vit_small_patch14_dinov2.lvd142m'
elif args.model_arch == 'dinov2_vitb14':
arch = 'vit_base_patch14_dinov2.lvd142m'
elif args.model_arch == 'dinov2_vitl14':
arch = 'vit_large_patch14_dinov2.lvd142m'
elif args.model_arch == 'dinov2_vitg14':
arch = 'vit_giant_patch14_dinov2.lvd142m'
else:
raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
model = timm.create_model(
arch,
pretrained=True,
img_size=224,
**kwargs
)
return model
def RETFound_dinov2(args, **kwargs):
model = timm.create_model(
'vit_large_patch14_dinov2.lvd142m',
@@ -66,4 +86,20 @@ def RETFound_dinov2(args, **kwargs):
return model
def Dinov3(args, **kwargs):
# Load ViT-L/16 backbone (hub model has `head = Identity` by default)
model = torch.hub.load(
repo_or_dir="facebookresearch/dinov3",
model=args.model_arch,
pretrained=False, # main() will load your checkpoint
trust_repo=True,
)
# Figure out feature dimension for the probe
feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
model.head = nn.Linear(feat_dim, args.nb_classes)
trunc_normal_(model.head.weight, std=2e-5)
if model.head.bias is not None:
nn.init.zeros_(model.head.bias)
return model