Incorporate DINOv3, DINOv2
This commit is contained in:
+41
-5
@@ -1,7 +1,3 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# Partly revised by YZ @UCL&Moorfields
|
||||
# --------------------------------------------------------
|
||||
|
||||
from functools import partial
|
||||
|
||||
@@ -10,7 +6,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from timm.models.layers import trunc_normal_
|
||||
|
||||
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
|
||||
""" Vision Transformer with support for global average pooling
|
||||
@@ -56,6 +52,30 @@ def RETFound_mae(**kwargs):
|
||||
|
||||
|
||||
|
||||
def Dinov2(args, **kwargs):
|
||||
|
||||
if args.model_arch == 'dinov2_vits14':
|
||||
arch = 'vit_small_patch14_dinov2.lvd142m'
|
||||
elif args.model_arch == 'dinov2_vitb14':
|
||||
arch = 'vit_base_patch14_dinov2.lvd142m'
|
||||
elif args.model_arch == 'dinov2_vitl14':
|
||||
arch = 'vit_large_patch14_dinov2.lvd142m'
|
||||
elif args.model_arch == 'dinov2_vitg14':
|
||||
arch = 'vit_giant_patch14_dinov2.lvd142m'
|
||||
else:
|
||||
raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
|
||||
f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
|
||||
|
||||
model = timm.create_model(
|
||||
arch,
|
||||
pretrained=True,
|
||||
img_size=224,
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def RETFound_dinov2(args, **kwargs):
|
||||
model = timm.create_model(
|
||||
'vit_large_patch14_dinov2.lvd142m',
|
||||
@@ -66,4 +86,20 @@ def RETFound_dinov2(args, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def Dinov3(args, **kwargs):
|
||||
# Load ViT-L/16 backbone (hub model has `head = Identity` by default)
|
||||
model = torch.hub.load(
|
||||
repo_or_dir="facebookresearch/dinov3",
|
||||
model=args.model_arch,
|
||||
pretrained=False, # main() will load your checkpoint
|
||||
trust_repo=True,
|
||||
)
|
||||
|
||||
# Figure out feature dimension for the probe
|
||||
feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
|
||||
model.head = nn.Linear(feat_dim, args.nb_classes)
|
||||
trunc_normal_(model.head.weight, std=2e-5)
|
||||
if model.head.bias is not None:
|
||||
nn.init.zeros_(model.head.bias)
|
||||
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user