Incorporate DINOv3, DINOv2

rmaphoh
2025-08-31 18:03:57 +01:00
parent 897d71c8c9
commit 409f7b6167
5 changed files with 521 additions and 327 deletions
+114 -52
@@ -1,14 +1,15 @@
## RETFound - A foundation model for retinal imaging
Official repo including a series of retinal foundation models.<br>
[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae).<br>
[New checkpoints](https://huggingface.co/YukunZhou), some of which are based on [DINOV2](https://github.com/facebookresearch/dinov2):
Official repo including a series of foundation models and applications in retinal imaging.<br>
`[RETFound-MAE]`: [RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x).<br>
`[RETFound-DINOv2]`: [Revealing the Impact of Pre-training Data on Medical Foundation Models](https://www.researchsquare.com/article/rs-6080254/v1).<br>
`[DINOv2]`: [General-purpose vision foundation model DINOv2](https://github.com/facebookresearch/dinov2).<br>
`[DINOv3]`: [General-purpose vision foundation model DINOv3](https://github.com/facebookresearch/dinov3).<br>
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE)
### 📝Key features
@@ -19,13 +20,14 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
### 🎉News
- 🐉2025/09: **Benchmarking paper for DINOv3, DINOv2, and RETFound will come soon!**
- 🐉2025/09: **We included state-of-the-art DINOv3 into fine-tuning pipeline for retinal applications!**
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!**
- 🐉2025/02: **We updated the package versions, such as CUDA 12+ and PyTorch 2.3+!**
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) is now online!
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
- 2023/10: The [input_size](https://github.com/rmaphoh/RETFound_MAE#:~:text=finetune%20./RETFound_cfp_weights.pth%20%5C-,%2D%2Dinput_size%20224,-For%20evaluation%20only) hyperparameter can now be changed to support any image size
### 🔧Install environment
@@ -40,9 +42,9 @@ conda activate retfound
2. Install dependencies
```
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
git clone https://github.com/rmaphoh/RETFound_MAE/
cd RETFound_MAE
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121
git clone https://github.com/rmaphoh/RETFound/
cd RETFound
pip install -r requirements.txt
```
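To confirm that the pinned versions were picked up, a quick check along these lines may help. This is a minimal sketch and not part of the repository: `EXPECTED` and `check_versions` are hypothetical names, and the script only reads installed package metadata. Wheels from the PyTorch CUDA index typically report a local suffix such as `+cu121`, so the suffix is ignored in the comparison.

```python
from importlib import metadata

# Versions taken from the pip command above; adjust if you install others.
EXPECTED = {"torch": "2.5.1", "torchvision": "0.20.1"}


def check_versions(expected=EXPECTED):
    """Return {package: (installed_version_or_None, matches_expected)}."""
    report = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Drop a local version suffix such as "+cu121" before comparing.
        matches = installed is not None and installed.split("+")[0] == wanted
        report[name] = (installed, matches)
    return report


if __name__ == "__main__":
    for package, (installed, ok) in check_versions().items():
        status = "OK" if ok else "MISMATCH"
        print(f"{package}: installed={installed} [{status}]")
```

A mismatch is not necessarily an error; it only indicates that the environment differs from the versions listed here.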
@@ -71,22 +73,22 @@ To fine tune RETFound on your own data, follow these steps:
<!-- TABLE BODY -->
<tr><td align="left">RETFound_mae_meh</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
<td align="center">TBD</a></td>
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_mae_shanghai</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
<td align="center">TBD</a></td>
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_dinov2_meh</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
<td align="center">TBD</a></td>
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_dinov2_shanghai</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">access</a></td>
<td align="center">TBD</a></td>
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr>
</tbody></table>
@@ -118,56 +120,116 @@ export HF_ENDPOINT=https://hf-mirror.com
├──class_c
```
4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training.
4. To use DINOv2 or DINOv3, download the model weights from their GitHub repositories and place them in the RETFound folder.
The model and finetune can be selected:
4. Start fine-tuning by running `sh train.sh`.
The model can be selected by changing the hyperparameters `MODEL`, `MODEL_ARCH`, `FINETUNE` in `train.sh`:
**RETFound**:
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|-----------------|--------------------------|--------------------------|--------------------------|
| RETFound_mae | retfound_mae | RETFound_mae_natureCFP | ~300M |
| RETFound_mae | retfound_mae | RETFound_mae_natureOCT | ~300M |
| RETFound_mae | retfound_mae | RETFound_mae_meh | ~300M |
| RETFound_mae | retfound_mae | RETFound_mae_shanghai | ~300M |
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_meh | ~300M |
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_shanghai | ~300M |
**DINOv3**:
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|-----------------|--------------------------|----------------------------------|--------------------------|
| Dinov3 | dinov3_vits16 | dinov3_vits16_pretrain.pth | ~21M |
| Dinov3 | dinov3_vits16plus | dinov3_vits16plus_pretrain.pth | ~29M |
| Dinov3 | dinov3_vitb16 | dinov3_vitb16_pretrain.pth | ~86M |
| Dinov3 | dinov3_vitl16 | dinov3_vitl16_pretrain.pth | ~300M |
| Dinov3 | dinov3_vith16plus | dinov3_vith16plus_pretrain.pth | ~840M |
| Dinov3 | dinov3_vit7b16 | dinov3_vit7b16_pretrain.pth | ~6.7B |
**DINOv2**:
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|-----------------|--------------------------|------------------------------|--------------------------|
| Dinov2 | dinov2_vits14 | dinov2_vits14_pretrain.pth | ~21M |
| Dinov2 | dinov2_vitb14 | dinov2_vitb14_pretrain.pth | ~86M |
| Dinov2 | dinov2_vitl14 | dinov2_vitl14_pretrain.pth | ~300M |
| Dinov2 | dinov2_vitg14 | dinov2_vitg14_pretrain.pth | ~1.1B |
| model | finetune |
|-----------------|--------------------------|
| RETFound_mae | RETFound_mae_natureCFP |
| RETFound_mae | RETFound_mae_natureOCT |
| RETFound_mae | RETFound_mae_meh |
| RETFound_mae | RETFound_mae_shanghai |
| RETFound_dinov2 | RETFound_dinov2_meh |
| RETFound_dinov2 | RETFound_dinov2_shanghai |
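
How `FINETUNE` is interpreted depends on `MODEL`. The sketch below mirrors the branching in `main_finetune.py` from this commit: DINOv2/DINOv3 weights are read from a local path, while RETFound weights are fetched from the `YukunZhou/<FINETUNE>` repository on Hugging Face. The helper name `resolve_weights` is hypothetical and the function only describes the lookup; it does not download anything.

```python
# Illustrative only: resolve_weights is not a function in the repository.
LOCAL_MODELS = {"Dinov2", "Dinov3"}
HUB_MODELS = {"RETFound_mae", "RETFound_dinov2"}


def resolve_weights(model: str, finetune: str) -> dict:
    """Describe where the pre-trained weights for `model` come from."""
    if model in LOCAL_MODELS:
        # FINETUNE is used as a path to a file you downloaded yourself.
        return {"source": "local", "path": finetune}
    if model in HUB_MODELS:
        # FINETUNE names both the Hub repository and the .pth file inside it.
        return {
            "source": "hub",
            "repo_id": f"YukunZhou/{finetune}",
            "filename": f"{finetune}.pth",
        }
    raise ValueError(
        f"Unsupported model '{model}'. "
        "Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae"
    )


if __name__ == "__main__":
    print(resolve_weights("Dinov3", "dinov3_vitl16_pretrain.pth"))
    print(resolve_weights("RETFound_dinov2", "RETFound_dinov2_meh"))
```
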
```
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
--model RETFound_mae \
--savemodel \
--global_pool \
--batch_size 16 \
--world_size 1 \
--epochs 100 \
--blr 5e-3 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.2 \
--nb_classes 5 \
--data_path ./IDRiD \
--input_size 224 \
--task RETFound_mae_meh-IDRiD \
--finetune RETFound_mae_meh
# ==== Model settings ====
# adaptation {finetune,lp}
ADAPTATION="finetune"
MODEL="RETFound_dinov2"
MODEL_ARCH="retfound_dinov2"
FINETUNE="RETFound_dinov2_meh"
# ==== Data settings ====
# change the dataset name and corresponding class number
DATASET="MESSIDOR2"
NUM_CLASS=5
data_path="./${DATASET}"
task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
--model "${MODEL}" \
--model_arch "${MODEL_ARCH}" \
--finetune "${FINETUNE}" \
--savemodel \
--global_pool \
--batch_size 24 \
--world_size 1 \
--epochs 50 \
--nb_classes "${NUM_CLASS}" \
--data_path "${data_path}" \
--input_size 224 \
--task "${task}" \
--adaptation "${ADAPTATION}"
```
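
The command above does not pass `--lr`, so `main_finetune.py` derives the learning rate from the base rate as `lr = blr * total_batch_size / 256`, where the total batch size is `batch_size * accum_iter * world_size`. The sketch below reproduces that arithmetic; `effective_lr` is a hypothetical helper, and the default `blr` of `5e-3` is assumed from the argument parser in this commit.

```python
def effective_lr(blr: float, batch_size: int, accum_iter: int = 1, world_size: int = 1) -> float:
    """Scale the base learning rate by the effective batch size."""
    eff_batch_size = batch_size * accum_iter * world_size
    return blr * eff_batch_size / 256


if __name__ == "__main__":
    # Settings from the command above: one process, batch size 24.
    lr = effective_lr(blr=5e-3, batch_size=24)
    print(f"actual lr: {lr:.2e}")
```

With these settings the actual learning rate is roughly 4.7e-4. Changing the batch size or the number of GPUs changes it proportionally unless `--lr` is set explicitly.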
4. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the path below)
```
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
--model RETFound_mae \
--savemodel \
--eval \
--global_pool \
--batch_size 16 \
--world_size 1 \
--epochs 100 \
--blr 5e-3 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.2 \
--nb_classes 5 \
--data_path ./IDRiD \
--input_size 224 \
--task RETFound_mae_meh-IDRiD \
--resume ./RETFound_mae_meh-IDRiD/checkpoint-best.pth
# ==== Model/settings (match training) ====
ADAPTATION="finetune"
MODEL="RETFound_dinov2"
MODEL_ARCH="retfound_dinov2"
FINETUNE="RETFound_dinov2_meh"
# ==== Data/settings (match training) ====
DATASET="MESSIDOR2"
NUM_CLASS=5
DATA_PATH="./${DATASET}"
TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
# Path to the trained checkpoint (adjust if you saved elsewhere)
CKPT="./output_dir/${TASK}/checkpoint-best.pth"
# ==== Evaluation only ====
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
--model "${MODEL}" \
--model_arch "${MODEL_ARCH}" \
--savemodel \
--global_pool \
--batch_size 128 \
--world_size 1 \
--nb_classes "${NUM_CLASS}" \
--data_path "${DATA_PATH}" \
--input_size 224 \
--task "${TASK}" \
--adaptation "${ADAPTATION}" \
--eval \
--resume "${CKPT}"
```
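
The evaluation script finds the checkpoint only if `TASK` is built the same way as during training. The sketch below shows the naming scheme used by the two scripts above; `task_name` and `best_checkpoint` are hypothetical helpers, and `./output_dir` is assumed from the default `--output_dir`.

```python
def task_name(model_arch: str, dataset: str, adaptation: str) -> str:
    """Compose the task name as "${MODEL_ARCH}_${DATASET}_${ADAPTATION}"."""
    return f"{model_arch}_{dataset}_{adaptation}"


def best_checkpoint(task: str, output_dir: str = "./output_dir") -> str:
    """Path at which the evaluation script looks for the best checkpoint."""
    return f"{output_dir}/{task}/checkpoint-best.pth"


if __name__ == "__main__":
    task = task_name("retfound_dinov2", "MESSIDOR2", "finetune")
    print(best_checkpoint(task))
```

If you trained with `ADAPTATION="lp"` or a different dataset, the task name and therefore the checkpoint path change accordingly.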
+290 -249
@@ -1,349 +1,385 @@
#!/usr/bin/env python3
# =========================
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import warnings
import faulthandler
# =========================
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup
from huggingface_hub import hf_hub_download, login # login imported as in original
# =========================
import models_vit as models
import util.lr_decay as lrd
import util.misc as misc
from util.datasets import build_dataset
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from huggingface_hub import hf_hub_download, login
from engine_finetune import train_one_epoch, evaluate
import warnings
import faulthandler
# =========================
faulthandler.enable()
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action="ignore", category=FutureWarning)
def get_args_parser():
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
parser.add_argument('--batch_size', default=128, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
parser = argparse.ArgumentParser(
"MAE fine-tuning / linear probing for image classification", add_help=False
)
# Model parameters
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input_size', default=256, type=int,
help='images input size')
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
help='Drop path rate (default: 0.1)')
# ---- Core training
parser.add_argument("--batch_size", default=128, type=int,
help="Batch size per GPU (effective batch size = batch_size * accum_iter * #gpus)")
parser.add_argument("--epochs", default=50, type=int)
parser.add_argument("--accum_iter", default=1, type=int,
help="Gradient accumulation steps")
# Optimizer parameters
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--weight_decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--layer_decay', type=float, default=0.65,
help='layer-wise lr decay from ELECTRA/BEiT')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
help='epochs to warmup LR')
# ---- Model parameters
parser.add_argument("--model", default="vit_large_patch16", type=str, metavar="MODEL",
help="Model entry in models_vit.py")
parser.add_argument("--model_arch", default="dinov3_vits16", type=str, metavar="MODEL_ARCH",
help="Backbone architecture key (e.g., dinov2_vitl14, convnext_base, etc.)")
parser.add_argument("--input_size", default=256, type=int, help="Image size")
parser.add_argument("--drop_path", type=float, default=0.2, metavar="PCT", help="Drop path rate")
parser.add_argument("--global_pool", action="store_true"); parser.set_defaults(global_pool=True)
parser.add_argument("--cls_token", action="store_false", dest="global_pool",
help="Use class token instead of global pool for classification")
# Augmentation parameters
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
help='Color jitter factor (enabled only when not using Auto/RandAug)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
# ---- Optimizer parameters
parser.add_argument("--clip_grad", type=float, default=None, metavar="NORM", help="Clip grad norm")
parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay")
parser.add_argument("--lr", type=float, default=None, metavar="LR", help="Absolute LR (overrides blr)")
parser.add_argument("--blr", type=float, default=5e-3, metavar="LR",
help="Base LR: lr = blr * total_batch_size / 256")
parser.add_argument("--layer_decay", type=float, default=0.65, help="Layer-wise LR decay (ViT)")
parser.add_argument("--min_lr", type=float, default=1e-6, metavar="LR", help="Lower LR bound")
parser.add_argument("--warmup_epochs", type=int, default=10, metavar="N", help="Warmup epochs")
# * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# ---- Augmentation
parser.add_argument("--color_jitter", type=float, default=None, metavar="PCT")
parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1", metavar="NAME")
parser.add_argument("--smoothing", type=float, default=0.1)
# * Mixup params
parser.add_argument('--mixup', type=float, default=0,
help='mixup alpha, mixup enabled if > 0.')
parser.add_argument('--cutmix', type=float, default=0,
help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup_mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# ---- Random erase
parser.add_argument("--reprob", type=float, default=0.25, metavar="PCT")
parser.add_argument("--remode", type=str, default="pixel")
parser.add_argument("--recount", type=int, default=1)
parser.add_argument("--resplit", action="store_true", default=False)
# * Finetuning params
parser.add_argument('--finetune', default='', type=str,
help='finetune from checkpoint')
parser.add_argument('--task', default='', type=str,
help='finetune from checkpoint')
parser.add_argument('--global_pool', action='store_true')
parser.set_defaults(global_pool=True)
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
help='Use class token instead of global pool for classification')
# ---- Mixup/Cutmix
parser.add_argument("--mixup", type=float, default=0.0)
parser.add_argument("--cutmix", type=float, default=0.0)
parser.add_argument("--cutmix_minmax", type=float, nargs="+", default=None)
parser.add_argument("--mixup_prob", type=float, default=1.0)
parser.add_argument("--mixup_switch_prob", type=float, default=0.5)
parser.add_argument("--mixup_mode", type=str, default="batch")
# Dataset parameters
parser.add_argument('--data_path', default='./data/', type=str,
help='dataset path')
parser.add_argument('--nb_classes', default=8, type=int,
help='number of the classification types')
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_logs',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--dist_eval', action='store_true', default=False,
help='Enabling distributed evaluation (recommended during training for faster monitor')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.set_defaults(pin_mem=True)
# ---- Finetuning & adaptation
parser.add_argument("--finetune", default="", type=str, help="Checkpoint id/path (see model rules below)")
parser.add_argument("--task", default="", type=str, help="Task name for logging/output grouping")
parser.add_argument("--adaptation", default="finetune", choices=["finetune", "lp"],
help="Adaptation strategy: finetune=full fine-tune, lp=linear probe (train head only)")
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
# ---- Dataset & paths
parser.add_argument("--data_path", default="./data/", type=str)
parser.add_argument("--nb_classes", default=8, type=int)
parser.add_argument("--output_dir", default="./output_dir")
parser.add_argument("--log_dir", default="./output_logs")
# fine-tuning parameters
parser.add_argument('--savemodel', action='store_true', default=True,
help='Save model')
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
parser.add_argument('--datasets_seed', default=2026, type=int)
# >>> NEW: training data efficiency <<<
parser.add_argument(
"--dataratio", type=str, default="1.0",
help=('Training data ratio(s) for subsampling in build_dataset. '
'Use a single float in (0,1] (e.g., 0.25) or a comma-separated list '
'(e.g., "1.0,0.5,0.25") if your build_dataset supports sweeps.')
)
parser.add_argument(
"--stratified", action="store_true",
help="If set, subsample training data in a class-stratified manner (requires support in build_dataset)."
)
# ---- Runtime
parser.add_argument("--device", default="cuda")
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--resume", default="", help="Resume full state (optimizer, scaler, etc.)")
parser.add_argument("--start_epoch", default=0, type=int, metavar="N")
parser.add_argument("--eval", action="store_true", help="Evaluation only")
parser.add_argument("--dist_eval", action="store_true", default=False,
help="Distributed evaluation (faster monitoring during training)")
parser.add_argument("--num_workers", default=10, type=int)
parser.add_argument("--pin_mem", action="store_true"); parser.set_defaults(pin_mem=True)
# ---- Distributed
parser.add_argument("--world_size", default=1, type=int)
parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument("--dist_on_itp", action="store_true")
parser.add_argument("--dist_url", default="env://")
# ---- Misc
parser.add_argument("--savemodel", action="store_true", default=True, help="Save best model")
parser.add_argument("--norm", default="IMAGENET", type=str)
parser.add_argument("--enhance", action="store_true", default=False)
parser.add_argument("--datasets_seed", default=2026, type=int)
return parser
# =========================
# Main
# =========================
def main(args, criterion):
# ---- Optionally load args from resume (when training)
if args.resume and not args.eval:
resume = args.resume
checkpoint = torch.load(args.resume, map_location='cpu')
print("Load checkpoint from: %s" % args.resume)
args = checkpoint['args']
args.resume = resume
resume_path = args.resume
checkpoint = torch.load(args.resume, map_location="cpu")
print(f"Load checkpoint (args) from: {args.resume}")
args = checkpoint["args"]
args.resume = resume_path
# ---- Distributed setup
misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}")
print(f"{args}".replace(", ", ",\n"))
device = torch.device(args.device)
# fix the seed for reproducibility
# ---- Reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
if args.model=='RETFound_mae':
# ---- Build model
if args.model == "RETFound_mae":
model = models.__dict__[args.model](
img_size=args.input_size,
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
global_pool=args.global_pool,
)
img_size=args.input_size,
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
global_pool=args.global_pool,
)
else:
model = models.__dict__[args.model](
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
args=args,
)
if args.finetune and not args.eval:
print(f"Downloading pre-trained weights from: {args.finetune}")
checkpoint_path = hf_hub_download(
repo_id=f'YukunZhou/{args.finetune}',
filename=f'{args.finetune}.pth',
)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % args.finetune)
if args.model!='RETFound_mae':
checkpoint_model = checkpoint['teacher']
else:
checkpoint_model = checkpoint['model']
# ---- Load pre-trained weights (if requested and not eval-only)
if args.finetune and not args.eval:
print(f"Preparing to load pre-trained weights: {args.finetune}")
if args.model in ["Dinov3", "Dinov2"]:
checkpoint_path = args.finetune # local path
elif args.model in ["RETFound_dinov2", "RETFound_mae"]:
print(f"Downloading pre-trained weights from Hugging Face Hub: {args.finetune}")
checkpoint_path = hf_hub_download(
repo_id=f"YukunZhou/{args.finetune}",
filename=f"{args.finetune}.pth",
)
else:
raise ValueError(
f"Unsupported model '{args.model}'. "
f"Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae"
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(f"Loaded pre-trained checkpoint from: {checkpoint_path}")
if args.model in ["Dinov3", "Dinov2"]:
checkpoint_model = checkpoint
elif args.model == "RETFound_dinov2":
checkpoint_model = checkpoint["teacher"]
else: # RETFound_mae
checkpoint_model = checkpoint["model"]
# -- Key hygiene
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
# -- Remove classifier if shape mismatched
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
for k in ["head.weight", "head.bias"]:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
# -- Interpolate pos embed (ViT)
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
# -- Load backbone weights (non-strict)
_ = model.load_state_dict(checkpoint_model, strict=False)
trunc_normal_(model.head.weight, std=2e-5)
# -- Re-init head
if hasattr(model, "head") and hasattr(model.head, "weight"):
trunc_normal_(model.head.weight, std=2e-5)
dataset_train = build_dataset(is_train='train', args=args)
dataset_val = build_dataset(is_train='val', args=args)
dataset_test = build_dataset(is_train='test', args=args)
# ---- Datasets & samplers
dataset_train = build_dataset(is_train="train", args=args)
dataset_val = build_dataset(is_train="val", args=args)
dataset_test = build_dataset(is_train="test", args=args)
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
if True: # args.distributed:
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
if not args.eval:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print("Sampler_train = %s" % str(sampler_train))
if args.dist_eval:
if len(dataset_val) % num_tasks != 0:
print(
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
if not args.eval:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print(f"Sampler_train = {sampler_train}")
if args.dist_eval:
if len(dataset_test) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_test = torch.utils.data.DistributedSampler(
dataset_test, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
if len(dataset_val) % num_tasks != 0:
print("Warning: dist eval with dataset not divisible by #procs; results may differ slightly.")
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
else:
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
if args.dist_eval:
if len(dataset_test) % num_tasks != 0:
print("Warning: dist eval test set not divisible by #procs; results may differ slightly.")
sampler_test = torch.utils.data.DistributedSampler(
dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
else:
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
# ---- Logging
if global_rank == 0 and args.log_dir is not None and not args.eval:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
else:
log_writer = None
# ---- DataLoaders
if not args.eval:
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
batch_size=args.batch_size, num_workers=args.num_workers,
pin_memory=args.pin_mem, drop_last=True,
)
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
print(f"len of train_set: {len(data_loader_train) * args.batch_size}")
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
batch_size=args.batch_size, num_workers=args.num_workers,
pin_memory=args.pin_mem, drop_last=False,
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, sampler=sampler_test,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
batch_size=args.batch_size, num_workers=args.num_workers,
pin_memory=args.pin_mem, drop_last=False,
)
# ---- Mixup/CutMix
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
mixup_active = (args.mixup > 0) or (args.cutmix > 0.) or (args.cutmix_minmax is not None)
if mixup_active:
print("Mixup is activated!")
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)
label_smoothing=args.smoothing, num_classes=args.nb_classes
)
# ---- Eval-only: resume weights
if args.resume and args.eval:
checkpoint = torch.load(args.resume, map_location='cpu')
print("Load checkpoint from: %s" % args.resume)
model.load_state_dict(checkpoint['model'])
checkpoint = torch.load(args.resume, map_location="cpu")
print(f"Load checkpoint for eval from: {args.resume}")
model.load_state_dict(checkpoint["model"])
model.to(device)
model_without_ddp = model
# ---- Adaptation toggle
if args.adaptation == "lp":
for name, param in model.named_parameters():
param.requires_grad = ("head" in name)
print("[Adaptation] Linear probe: training classifier head only.")
else:
print("[Adaptation] Full fine-tuning: training all parameters.")
# ---- Count trainable params
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of model params (M): %.2f' % (n_parameters / 1.e6))
print(f"number of trainable params (M): {n_parameters / 1.e6:.2f}")
# ---- LR scaling by effective batch size
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None: # only base_lr is specified
if args.lr is None:
args.lr = args.blr * eff_batch_size / 256
print(f"base lr: {args.lr * 256 / eff_batch_size:.2e}")
print(f"actual lr: {args.lr:.2e}")
print(f"accumulate grad iterations: {args.accum_iter}")
print(f"effective batch size: {eff_batch_size}")
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
# ---- DDP (if available)
if args.distributed and torch.cuda.device_count() > 1:
ddp_kwargs = {}
if args.adaptation == "lp":
ddp_kwargs["find_unused_parameters"] = True
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.gpu], **ddp_kwargs
)
model_without_ddp = model.module
else:
model_without_ddp = model # single-GPU
# ---- Optimizer param groups (after freezing)
no_weight_decay = (model_without_ddp.no_weight_decay()
if hasattr(model_without_ddp, "no_weight_decay") else [])
param_groups = lrd.param_groups_lrd(
model_without_ddp,
weight_decay=args.weight_decay,
no_weight_decay_list=no_weight_decay,
layer_decay=args.layer_decay,
)
for g in param_groups:
g["params"] = [p for p in g["params"] if p.requires_grad]
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
no_weight_decay_list=no_weight_decay,
layer_decay=args.layer_decay
)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = NativeScaler()
print(f"criterion = {criterion}")
print("criterion = %s" % str(criterion))
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
# ---- Load previous full state (optimizer, scaler, etc.)
misc.load_model(args=args, model_without_ddp=model_without_ddp,
optimizer=optimizer, loss_scaler=loss_scaler)
# =========================
# Eval-only Short Circuit
# =========================
if args.eval:
if 'epoch' in checkpoint:
print("Test with the best model at epoch = %d" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
num_class=args.nb_classes, log_writer=log_writer)
exit(0)
if "checkpoint" in locals() and isinstance(checkpoint, dict) and ("epoch" in checkpoint):
print(f"Test with the best model at epoch = {checkpoint['epoch']}")
test_stats, auc_roc = evaluate(
data_loader_test, model, device, args, epoch=0, mode="test",
num_class=args.nb_classes, log_writer=log_writer
)
return
# =========================
# Train Loop
# =========================
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_score = 0.0
best_epoch = 0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
@@ -352,49 +388,55 @@ def main(args, criterion):
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
args.clip_grad, mixup_fn,
log_writer=log_writer,
args=args
log_writer=log_writer, args=args
)
val_stats, val_score = evaluate(
data_loader_val, model, device, args, epoch, mode="val",
num_class=args.nb_classes, log_writer=log_writer
)
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
num_class=args.nb_classes, log_writer=log_writer)
if max_score < val_score:
max_score = val_score
best_epoch = epoch
if args.output_dir and args.savemodel:
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch, mode='best')
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
if epoch == (args.epochs - 1):
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
model.to(device)
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
num_class=args.nb_classes, log_writer=None)
args=args, model=model, model_without_ddp=model_without_ddp,
optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode="best"
)
print(f"Best epoch = {best_epoch}, Best score = {max_score:.4f}")
if log_writer is not None:
log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
log_writer.add_scalar("loss/val", val_stats["loss"], epoch)
log_writer.flush()
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
log_stats = {**{f"train_{k}": v for k, v in train_stats.items()},
"epoch": epoch,
"n_parameters": n_parameters}
if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
with open(os.path.join(args.output_dir, args.task, "log.txt"), "a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
# =========================
# Final Test (Best Ckpt)
# =========================
ckpt_path = os.path.join(args.output_dir, args.task, "checkpoint-best.pth")
checkpoint = torch.load(ckpt_path, map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
print(f"Test with the best model, epoch = {checkpoint.get('epoch', -1)}:")
_test_stats, _auc_roc = evaluate(
data_loader_test, model, device, args, -1, mode="test",
num_class=args.nb_classes, log_writer=None
)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
print(f"Training time {total_time_str}")
if __name__ == '__main__':
if __name__ == "__main__":
args = get_args_parser()
args = args.parse_args()
@@ -402,6 +444,5 @@ if __name__ == '__main__':
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args, criterion)
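The learning-rate rule in the hunk above (`args.lr = args.blr * eff_batch_size / 256`) can be sketched in isolation. The function names `effective_batch_size` and `scaled_lr`, and the sample `blr` value, are illustrative assumptions rather than code from this repository.

```python
def effective_batch_size(batch_size: int, accum_iter: int, world_size: int) -> int:
    """Number of samples that contribute to one optimizer step across all processes."""
    return batch_size * accum_iter * world_size


def scaled_lr(blr: float, eff_batch_size: int, reference_batch: int = 256) -> float:
    """Linear scaling rule: the actual lr grows in proportion to the effective batch size."""
    return blr * eff_batch_size / reference_batch


if __name__ == "__main__":
    # batch_size and world_size mirror the example script; blr is an assumed value.
    eff = effective_batch_size(batch_size=24, accum_iter=1, world_size=1)
    lr = scaled_lr(blr=5e-3, eff_batch_size=eff)
    print(f"effective batch size: {eff}")
    print(f"actual lr: {lr:.2e}")
```

Passing `--lr` explicitly bypasses this rule, since the scaling only applies when `args.lr is None`.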
+41 -5
@@ -1,7 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------
from functools import partial
@@ -10,7 +6,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from timm.models.layers import trunc_normal_
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
""" Vision Transformer with support for global average pooling
@@ -56,6 +52,30 @@ def RETFound_mae(**kwargs):
def Dinov2(args, **kwargs):
    if args.model_arch == 'dinov2_vits14':
        arch = 'vit_small_patch14_dinov2.lvd142m'
    elif args.model_arch == 'dinov2_vitb14':
        arch = 'vit_base_patch14_dinov2.lvd142m'
    elif args.model_arch == 'dinov2_vitl14':
        arch = 'vit_large_patch14_dinov2.lvd142m'
    elif args.model_arch == 'dinov2_vitg14':
        arch = 'vit_giant_patch14_dinov2.lvd142m'
    else:
        raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
                         f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
    model = timm.create_model(
        arch,
        pretrained=True,
        img_size=224,
        **kwargs
    )
    return model
def RETFound_dinov2(args, **kwargs):
model = timm.create_model(
'vit_large_patch14_dinov2.lvd142m',
@@ -66,4 +86,20 @@ def RETFound_dinov2(args, **kwargs):
return model
def Dinov3(args, **kwargs):
    # Load the DINOv3 backbone named by args.model_arch
    # (hub model has `head = Identity` by default)
    model = torch.hub.load(
        repo_or_dir="facebookresearch/dinov3",
        model=args.model_arch,
        pretrained=False,  # main() will load your checkpoint
        trust_repo=True,
    )
    # Figure out feature dimension for the probe
    feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
    if feat_dim is None:
        # Fail early with a clear message instead of inside nn.Linear
        raise ValueError(f"Could not infer the feature dimension for '{args.model_arch}'")
    model.head = nn.Linear(feat_dim, args.nb_classes)
    trunc_normal_(model.head.weight, std=2e-5)
    if model.head.bias is not None:
        nn.init.zeros_(model.head.bias)
    return model
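The `if`/`elif` chain in `Dinov2` maps a `--model_arch` value to a timm model name. A minimal sketch of the same mapping as a lookup table follows; `DINOV2_TIMM_NAMES` and `resolve_dinov2_arch` are hypothetical names introduced here for illustration, and only the keys and values are taken from the diff.

```python
# Hypothetical lookup table; keys and timm names are copied from Dinov2() in the diff.
DINOV2_TIMM_NAMES = {
    "dinov2_vits14": "vit_small_patch14_dinov2.lvd142m",
    "dinov2_vitb14": "vit_base_patch14_dinov2.lvd142m",
    "dinov2_vitl14": "vit_large_patch14_dinov2.lvd142m",
    "dinov2_vitg14": "vit_giant_patch14_dinov2.lvd142m",
}


def resolve_dinov2_arch(model_arch: str) -> str:
    """Return the timm model name for a --model_arch value, or raise ValueError."""
    try:
        return DINOV2_TIMM_NAMES[model_arch]
    except KeyError:
        expected = ", ".join(DINOV2_TIMM_NAMES)
        raise ValueError(
            f"Unknown model_arch '{model_arch}'. Expected one of: {expected}"
        ) from None


if __name__ == "__main__":
    print(resolve_dinov2_arch("dinov2_vitl14"))
```

With a table, the error message is built from the same keys used for the lookup, so the two cannot drift apart when a variant is added.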
+27
@@ -0,0 +1,27 @@
# ==== Model settings ====
ADAPTATION="finetune"
MODEL="Dinov2"
MODEL_ARCH="dinov2_vitl14"
FINETUNE="dinov2_vitl14_pretrain.pth"
# ==== Data settings ====
DATASET="MESSIDOR2"
NUM_CLASS=5
data_path="/home/jupyter/public_dataset/${DATASET}"
task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
CUDA_VISIBLE_DEVICES=1 torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
--model "${MODEL}" \
--model_arch "${MODEL_ARCH}" \
--finetune "${FINETUNE}" \
--savemodel \
--global_pool \
--batch_size 24 \
--world_size 1 \
--epochs 50 \
--nb_classes "${NUM_CLASS}" \
--data_path "${data_path}" \
--input_size 224 \
--task "${task}" \
--adaptation "${ADAPTATION}"
+49 -21
@@ -1,29 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------
import os
import torch
from torch.utils.data import Subset
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args):
transform = build_transform(is_train, args)
root = os.path.join(args.data_path, is_train)
dataset = datasets.ImageFolder(root, transform=transform)
return dataset
if is_train == 'train':
ratio = float(getattr(args, "dataratio", 1.0))
seed = int(getattr(args, "seed", 0))
stratified = bool(getattr(args, "stratified", False))
if 0.0 < ratio < 1.0:
if stratified:
idx = _stratified_indices(dataset.targets, ratio, seed)
else:
# simple uniform subsample with torch.Generator for reproducibility
g = torch.Generator().manual_seed(seed)
n = len(dataset)
k = max(1, int(n * ratio))
idx = torch.randperm(n, generator=g)[:k].tolist()
dataset = Subset(dataset, idx)
return dataset
def build_transform(is_train, args):
mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD
# train transform
if is_train == 'train':
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
return create_transform(
input_size=args.input_size,
is_training=True,
color_jitter=args.color_jitter,
@@ -35,19 +45,37 @@ def build_transform(is_train, args):
mean=mean,
std=std,
)
return transform
# eval transform
t = []
if args.input_size <= 224:
crop_pct = 224 / 256
else:
crop_pct = 1.0
crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
size = int(args.input_size / crop_pct)
t.append(
t = [
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
)
t.append(transforms.CenterCrop(args.input_size))
t.append(transforms.ToTensor())
t.append(transforms.Normalize(mean, std))
transforms.CenterCrop(args.input_size),
transforms.ToTensor(),
transforms.Normalize(mean, std),
]
return transforms.Compose(t)
# ---- helpers ----
def _stratified_indices(targets, ratio: float, seed: int):
    """Maintain class proportions. Ensures at least 1 sample per class when possible."""
    t = torch.as_tensor(targets)
    classes = torch.unique(t)
    g = torch.Generator().manual_seed(seed)
    keep = []
    for c in classes.tolist():
        cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1)
        if len(cls_idx) == 0:
            continue
        k = max(1, int(round(len(cls_idx) * ratio)))
        sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]]
        keep.extend(sel.tolist())
    # shuffle final indices (stable across seed)
    g2 = torch.Generator().manual_seed(seed + 1)
    keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist()
    return keep
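The per-class sampling rule in `_stratified_indices` (`k = max(1, int(round(n * ratio)))` for each class) can be tried without PyTorch. The sketch below is an assumption-laden stand-in: it uses the standard library `random` module instead of `torch.Generator`, so the selected indices will differ from those of the function in the diff, while the per-class counts follow the same rule. The name `stratified_indices` and the toy labels are illustrative.

```python
import random
from collections import Counter, defaultdict


def stratified_indices(targets, ratio, seed):
    """Pick about `ratio` of the indices from each class, at least one per class."""
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)

    rng = random.Random(seed)  # stand-in for torch.Generator (assumption)
    keep = []
    for label in sorted(by_class):
        members = by_class[label]
        # Same count rule as the diff; capped so sample() never over-draws.
        k = min(len(members), max(1, int(round(len(members) * ratio))))
        keep.extend(rng.sample(members, k))
    rng.shuffle(keep)  # mix classes in the final index list
    return keep


if __name__ == "__main__":
    labels = [0] * 10 + [1] * 4 + [2] * 1  # toy, imbalanced label list
    idx = stratified_indices(labels, ratio=0.5, seed=0)
    print(Counter(labels[i] for i in idx))
```

Because of the `max(1, ...)` floor, a class with a single sample is always kept, so very small ratios can return slightly more than `ratio * n` samples overall.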