diff --git a/README.md b/README.md
index 78cbd54..4aa64d6 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,15 @@
## RETFound - A foundation model for retinal imaging
-Official repo including a series of retinal foundation models.
-[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae).
-[New checkpoints](https://huggingface.co/YukunZhou), some of which are based on [DINOV2](https://github.com/facebookresearch/dinov2):
+Official repo including a series of foundation models and applications in retinal imaging.
+`[RETFound-MAE]`:[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x).
+`[RETFound-DINOv2]`:[Revealing the Impact of Pre-training Data on Medical Foundation Models](https://www.researchsquare.com/article/rs-6080254/v1).
+`[DINOv2]`:[General-purpose vision foundation models DINOv2](https://github.com/facebookresearch/dinov2).
+`[DINOv3]`:[General-purpose vision foundation models DINOv3](https://github.com/facebookresearch/dinov3).
+
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
-Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE)
-
### πKey features
@@ -19,13 +20,14 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
### πNews
+- π2025/09: **Benchmarking paper for DINOv3, DINOv2, and RETFound will come soon!**
+- π2025/09: **We included state-of-the-art DINOv3 into fine-tuning pipeline for retinal applications!**
- π2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
- π2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!**
- π2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!**
- π2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) are now online!
- π2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
- π2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
-- 2023/10: change the hyperparameter of [input_size](https://github.com/rmaphoh/RETFound_MAE#:~:text=finetune%20./RETFound_cfp_weights.pth%20%5C-,%2D%2Dinput_size%20224,-For%20evaluation%20only) for any image size
### π§Install environment
@@ -40,9 +42,9 @@ conda activate retfound
2. Install dependencies
```
-conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
-git clone https://github.com/rmaphoh/RETFound_MAE/
-cd RETFound_MAE
+pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121
+git clone https://github.com/rmaphoh/RETFound/
+cd RETFound
pip install -r requirements.txt
```
@@ -71,22 +73,22 @@ To fine tune RETFound on your own data, follow these steps:
| RETFound_mae_meh |
access |
-TBD |
+FM data paper |
| RETFound_mae_shanghai |
access |
-TBD |
+FM data paper |
| RETFound_dinov2_meh |
access |
-TBD |
+FM data paper |
| RETFound_dinov2_shanghai |
access |
-TBD |
+FM data paper |
@@ -118,56 +120,116 @@ export HF_ENDPOINT=https://hf-mirror.com
βββclass_c
```
-4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training.
+4. If you would like to use DINOv2 and DINOv3, please visit their GitHub repositories to download the model weights and put them in the RETFound folder.
-The model and finetune can be selected:
+4. Start fine-tuning by running `sh train.sh`.
+
+
+The model can be selected by changing the hyperparameters `MODEL`, `MODEL_ARCH`, `FINETUNE` in `train.sh`:
+
+**RETFound**:
+
+| MODEL | MODEL_ARCH | FINETUNE | SIZE |
+|-----------------|--------------------------|--------------------------|--------------------------|
+| RETFound_mae | retfound_mae | RETFound_mae_natureCFP | ~300M |
+| RETFound_mae | retfound_mae | RETFound_mae_natureOCT | ~300M |
+| RETFound_mae | retfound_mae | RETFound_mae_meh | ~300M |
+| RETFound_mae | retfound_mae | RETFound_mae_shanghai | ~300M |
+| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_meh | ~300M |
+| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_shanghai | ~300M |
+
+
+**DINOv3**:
+
+| MODEL | MODEL_ARCH | FINETUNE | SIZE |
+|-----------------|--------------------------|----------------------------------|--------------------------|
+| Dinov3 | dinov3_vits16 | dinov3_vits16_pretrain.pth | ~21M |
+| Dinov3 | dinov3_vits16plus | dinov3_vits16plus_pretrain.pth | ~29M |
+| Dinov3 | dinov3_vitb16 | dinov3_vitb16_pretrain.pth | ~86M |
+| Dinov3 | dinov3_vitl16 | dinov3_vitl16_pretrain.pth | ~300M |
+| Dinov3 | dinov3_vith16plus | dinov3_vith16plus_pretrain.pth | ~840M |
+| Dinov3 | dinov3_vit7b16 | dinov3_vit7b16_pretrain.pth | ~6.7B |
+
+
+**DINOv2**:
+
+| MODEL | MODEL_ARCH | FINETUNE | SIZE |
+|-----------------|--------------------------|------------------------------|--------------------------|
+| Dinov2 | dinov2_vits14 | dinov2_vits14_pretrain.pth | ~21M |
+| Dinov2 | dinov2_vitb14 | dinov2_vitb14_pretrain.pth | ~86M |
+| Dinov2 | dinov2_vitl14 | dinov2_vitl14_pretrain.pth | ~300M |
+| Dinov2 | dinov2_vitg14 | dinov2_vitg14_pretrain.pth | ~1.1B |
-| model | finetune |
-|-----------------|--------------------------|
-| RETFound_mae | RETFound_mae_natureCFP |
-| RETFound_mae | RETFound_mae_natureOCT |
-| RETFound_mae | RETFound_mae_meh |
-| RETFound_mae | RETFound_mae_shanghai |
-| RETFound_dinov2 | RETFound_dinov2_meh |
-| RETFound_dinov2 | RETFound_dinov2_shanghai |
```
-torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
- --model RETFound_mae \
- --savemodel \
- --global_pool \
- --batch_size 16 \
- --world_size 1 \
- --epochs 100 \
- --blr 5e-3 --layer_decay 0.65 \
- --weight_decay 0.05 --drop_path 0.2 \
- --nb_classes 5 \
- --data_path ./IDRiD \
- --input_size 224 \
- --task RETFound_mae_meh-IDRiD \
- --finetune RETFound_mae_meh
+# ==== Model settings ====
+# adaptation {finetune,lp}
+ADAPTATION="finetune"
+MODEL="RETFound_dinov2"
+MODEL_ARCH="retfound_dinov2"
+FINETUNE="RETFound_dinov2_meh"
+
+# ==== Data settings ====
+# change the dataset name and corresponding class number
+DATASET="MESSIDOR2"
+NUM_CLASS=5
+data_path="./${DATASET}"
+task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
+
+torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
+ --model "${MODEL}" \
+ --model_arch "${MODEL_ARCH}" \
+ --finetune "${FINETUNE}" \
+ --savemodel \
+ --global_pool \
+ --batch_size 24 \
+ --world_size 1 \
+ --epochs 50 \
+ --nb_classes "${NUM_CLASS}" \
+ --data_path "${data_path}" \
+ --input_size 224 \
+ --task "${task}" \
+ --adaptation "${ADAPTATION}"
+
```
+
4. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the path below)
```
-torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
- --model RETFound_mae \
- --savemodel \
- --eval \
- --global_pool \
- --batch_size 16 \
- --world_size 1 \
- --epochs 100 \
- --blr 5e-3 --layer_decay 0.65 \
- --weight_decay 0.05 --drop_path 0.2 \
- --nb_classes 5 \
- --data_path ./IDRiD \
- --input_size 224 \
- --task RETFound_mae_meh-IDRiD \
- --resume ./RETFound_mae_meh-IDRiD/checkpoint-best.pth
+# ==== Model/settings (match training) ====
+ADAPTATION="finetune"
+MODEL="RETFound_dinov2"
+MODEL_ARCH="retfound_dinov2"
+FINETUNE="RETFound_dinov2_meh"
+
+# ==== Data/settings (match training) ====
+DATASET="MESSIDOR2"
+NUM_CLASS=5
+DATA_PATH="./${DATASET}"
+TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
+
+# Path to the trained checkpoint (adjust if you saved elsewhere)
+CKPT="./output_dir/${TASK}/checkpoint-best.pth"
+
+# ==== Evaluation only ====
+torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
+ --model "${MODEL}" \
+ --model_arch "${MODEL_ARCH}" \
+ --savemodel \
+ --global_pool \
+ --batch_size 128 \
+ --world_size 1 \
+ --nb_classes "${NUM_CLASS}" \
+ --data_path "${DATA_PATH}" \
+ --input_size 224 \
+ --task "${TASK}" \
+ --adaptation "${ADAPTATION}" \
+ --eval \
+ --resume "${CKPT}"
+
```
diff --git a/main_finetune.py b/main_finetune.py
index 75c822d..0f4513f 100644
--- a/main_finetune.py
+++ b/main_finetune.py
@@ -1,349 +1,385 @@
+#!/usr/bin/env python3
+
+# =========================
import argparse
import datetime
import json
-
-import numpy as np
import os
import time
from pathlib import Path
+import warnings
+import faulthandler
+# =========================
+import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup
+from huggingface_hub import hf_hub_download, login # login imported as in original
+# =========================
import models_vit as models
import util.lr_decay as lrd
import util.misc as misc
from util.datasets import build_dataset
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler
-from huggingface_hub import hf_hub_download, login
from engine_finetune import train_one_epoch, evaluate
-import warnings
-import faulthandler
-
+# =========================
faulthandler.enable()
-warnings.simplefilter(action='ignore', category=FutureWarning)
+warnings.simplefilter(action="ignore", category=FutureWarning)
def get_args_parser():
- parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
- parser.add_argument('--batch_size', default=128, type=int,
- help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
- parser.add_argument('--epochs', default=50, type=int)
- parser.add_argument('--accum_iter', default=1, type=int,
- help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+ parser = argparse.ArgumentParser(
+ "MAE fine-tuning / linear probing for image classification", add_help=False
+ )
- # Model parameters
- parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
- help='Name of model to train')
- parser.add_argument('--input_size', default=256, type=int,
- help='images input size')
- parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
- help='Drop path rate (default: 0.1)')
+ # ---- Core training
+ parser.add_argument("--batch_size", default=128, type=int,
+ help="Batch size per GPU (effective batch size = batch_size * accum_iter * #gpus)")
+ parser.add_argument("--epochs", default=50, type=int)
+ parser.add_argument("--accum_iter", default=1, type=int,
+ help="Gradient accumulation steps")
- # Optimizer parameters
- parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
- help='Clip gradient norm (default: None, no clipping)')
- parser.add_argument('--weight_decay', type=float, default=0.05,
- help='weight decay (default: 0.05)')
- parser.add_argument('--lr', type=float, default=None, metavar='LR',
- help='learning rate (absolute lr)')
- parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
- help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
- parser.add_argument('--layer_decay', type=float, default=0.65,
- help='layer-wise lr decay from ELECTRA/BEiT')
- parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
- help='lower lr bound for cyclic schedulers that hit 0')
- parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
- help='epochs to warmup LR')
+ # ---- Model parameters
+ parser.add_argument("--model", default="vit_large_patch16", type=str, metavar="MODEL",
+ help="Model entry in models_vit.py")
+ parser.add_argument("--model_arch", default="dinov3_vits16", type=str, metavar="MODEL_ARCH",
+ help="Backbone architecture key (e.g., dinov2_vitl14, convnext_base, etc.)")
+ parser.add_argument("--input_size", default=256, type=int, help="Image size")
+ parser.add_argument("--drop_path", type=float, default=0.2, metavar="PCT", help="Drop path rate")
+ parser.add_argument("--global_pool", action="store_true"); parser.set_defaults(global_pool=True)
+ parser.add_argument("--cls_token", action="store_false", dest="global_pool",
+ help="Use class token instead of global pool for classification")
- # Augmentation parameters
- parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
- help='Color jitter factor (enabled only when not using Auto/RandAug)')
- parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
- help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
- parser.add_argument('--smoothing', type=float, default=0.1,
- help='Label smoothing (default: 0.1)')
+ # ---- Optimizer parameters
+ parser.add_argument("--clip_grad", type=float, default=None, metavar="NORM", help="Clip grad norm")
+ parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay")
+ parser.add_argument("--lr", type=float, default=None, metavar="LR", help="Absolute LR (overrides blr)")
+ parser.add_argument("--blr", type=float, default=5e-3, metavar="LR",
+ help="Base LR: lr = blr * total_batch_size / 256")
+ parser.add_argument("--layer_decay", type=float, default=0.65, help="Layer-wise LR decay (ViT)")
+ parser.add_argument("--min_lr", type=float, default=1e-6, metavar="LR", help="Lower LR bound")
+ parser.add_argument("--warmup_epochs", type=int, default=10, metavar="N", help="Warmup epochs")
- # * Random Erase params
- parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
- help='Random erase prob (default: 0.25)')
- parser.add_argument('--remode', type=str, default='pixel',
- help='Random erase mode (default: "pixel")')
- parser.add_argument('--recount', type=int, default=1,
- help='Random erase count (default: 1)')
- parser.add_argument('--resplit', action='store_true', default=False,
- help='Do not random erase first (clean) augmentation split')
+ # ---- Augmentation
+ parser.add_argument("--color_jitter", type=float, default=None, metavar="PCT")
+ parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1", metavar="NAME")
+ parser.add_argument("--smoothing", type=float, default=0.1)
- # * Mixup params
- parser.add_argument('--mixup', type=float, default=0,
- help='mixup alpha, mixup enabled if > 0.')
- parser.add_argument('--cutmix', type=float, default=0,
- help='cutmix alpha, cutmix enabled if > 0.')
- parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
- help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
- parser.add_argument('--mixup_prob', type=float, default=1.0,
- help='Probability of performing mixup or cutmix when either/both is enabled')
- parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
- help='Probability of switching to cutmix when both mixup and cutmix enabled')
- parser.add_argument('--mixup_mode', type=str, default='batch',
- help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+ # ---- Random erase
+ parser.add_argument("--reprob", type=float, default=0.25, metavar="PCT")
+ parser.add_argument("--remode", type=str, default="pixel")
+ parser.add_argument("--recount", type=int, default=1)
+ parser.add_argument("--resplit", action="store_true", default=False)
- # * Finetuning params
- parser.add_argument('--finetune', default='', type=str,
- help='finetune from checkpoint')
- parser.add_argument('--task', default='', type=str,
- help='finetune from checkpoint')
- parser.add_argument('--global_pool', action='store_true')
- parser.set_defaults(global_pool=True)
- parser.add_argument('--cls_token', action='store_false', dest='global_pool',
- help='Use class token instead of global pool for classification')
+ # ---- Mixup/Cutmix
+ parser.add_argument("--mixup", type=float, default=0.0)
+ parser.add_argument("--cutmix", type=float, default=0.0)
+ parser.add_argument("--cutmix_minmax", type=float, nargs="+", default=None)
+ parser.add_argument("--mixup_prob", type=float, default=1.0)
+ parser.add_argument("--mixup_switch_prob", type=float, default=0.5)
+ parser.add_argument("--mixup_mode", type=str, default="batch")
- # Dataset parameters
- parser.add_argument('--data_path', default='./data/', type=str,
- help='dataset path')
- parser.add_argument('--nb_classes', default=8, type=int,
- help='number of the classification types')
- parser.add_argument('--output_dir', default='./output_dir',
- help='path where to save, empty for no saving')
- parser.add_argument('--log_dir', default='./output_logs',
- help='path where to tensorboard log')
- parser.add_argument('--device', default='cuda',
- help='device to use for training / testing')
- parser.add_argument('--seed', default=0, type=int)
- parser.add_argument('--resume', default='',
- help='resume from checkpoint')
- parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
- help='start epoch')
- parser.add_argument('--eval', action='store_true',
- help='Perform evaluation only')
- parser.add_argument('--dist_eval', action='store_true', default=False,
- help='Enabling distributed evaluation (recommended during training for faster monitor')
- parser.add_argument('--num_workers', default=10, type=int)
- parser.add_argument('--pin_mem', action='store_true',
- help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
- parser.set_defaults(pin_mem=True)
+ # ---- Finetuning & adaptation
+ parser.add_argument("--finetune", default="", type=str, help="Checkpoint id/path (see model rules below)")
+ parser.add_argument("--task", default="", type=str, help="Task name for logging/output grouping")
+ parser.add_argument("--adaptation", default="finetune", choices=["finetune", "lp"],
+ help="Adaptation strategy: finetune=full fine-tune, lp=linear probe (train head only)")
- # distributed training parameters
- parser.add_argument('--world_size', default=1, type=int,
- help='number of distributed processes')
- parser.add_argument('--local_rank', default=-1, type=int)
- parser.add_argument('--dist_on_itp', action='store_true')
- parser.add_argument('--dist_url', default='env://',
- help='url used to set up distributed training')
+ # ---- Dataset & paths
+ parser.add_argument("--data_path", default="./data/", type=str)
+ parser.add_argument("--nb_classes", default=8, type=int)
+ parser.add_argument("--output_dir", default="./output_dir")
+ parser.add_argument("--log_dir", default="./output_logs")
- # fine-tuning parameters
- parser.add_argument('--savemodel', action='store_true', default=True,
- help='Save model')
- parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
- parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
- parser.add_argument('--datasets_seed', default=2026, type=int)
+ # >>> NEW: training data efficiency <<<
+ parser.add_argument(
+ "--dataratio", type=str, default="1.0",
+ help=('Training data ratio(s) for subsampling in build_dataset. '
+ 'Use a single float in (0,1] (e.g., 0.25) or a comma-separated list '
+ '(e.g., "1.0,0.5,0.25") if your build_dataset supports sweeps.')
+ )
+ parser.add_argument(
+ "--stratified", action="store_true",
+ help="If set, subsample training data in a class-stratified manner (requires support in build_dataset)."
+ )
+
+ # ---- Runtime
+ parser.add_argument("--device", default="cuda")
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--resume", default="", help="Resume full state (optimizer, scaler, etc.)")
+ parser.add_argument("--start_epoch", default=0, type=int, metavar="N")
+ parser.add_argument("--eval", action="store_true", help="Evaluation only")
+ parser.add_argument("--dist_eval", action="store_true", default=False,
+ help="Distributed evaluation (faster monitoring during training)")
+ parser.add_argument("--num_workers", default=10, type=int)
+ parser.add_argument("--pin_mem", action="store_true"); parser.set_defaults(pin_mem=True)
+
+ # ---- Distributed
+ parser.add_argument("--world_size", default=1, type=int)
+ parser.add_argument("--local_rank", default=-1, type=int)
+ parser.add_argument("--dist_on_itp", action="store_true")
+ parser.add_argument("--dist_url", default="env://")
+
+ # ---- Misc
+ parser.add_argument("--savemodel", action="store_true", default=True, help="Save best model")
+ parser.add_argument("--norm", default="IMAGENET", type=str)
+ parser.add_argument("--enhance", action="store_true", default=False)
+ parser.add_argument("--datasets_seed", default=2026, type=int)
return parser
+# =========================
+# Main
+# =========================
def main(args, criterion):
+ # ---- Optionally load args from resume (when training)
if args.resume and not args.eval:
- resume = args.resume
- checkpoint = torch.load(args.resume, map_location='cpu')
- print("Load checkpoint from: %s" % args.resume)
- args = checkpoint['args']
- args.resume = resume
+ resume_path = args.resume
+ checkpoint = torch.load(args.resume, map_location="cpu")
+ print(f"Load checkpoint (args) from: {args.resume}")
+ args = checkpoint["args"]
+ args.resume = resume_path
+ # ---- Distributed setup
misc.init_distributed_mode(args)
- print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
- print("{}".format(args).replace(', ', ',\n'))
+ print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}")
+ print(f"{args}".replace(", ", ",\n"))
device = torch.device(args.device)
- # fix the seed for reproducibility
+ # ---- Reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
-
cudnn.benchmark = True
- if args.model=='RETFound_mae':
+ # ---- Build model
+ if args.model == "RETFound_mae":
model = models.__dict__[args.model](
- img_size=args.input_size,
- num_classes=args.nb_classes,
- drop_path_rate=args.drop_path,
- global_pool=args.global_pool,
- )
+ img_size=args.input_size,
+ num_classes=args.nb_classes,
+ drop_path_rate=args.drop_path,
+ global_pool=args.global_pool,
+ )
else:
model = models.__dict__[args.model](
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
args=args,
)
-
- if args.finetune and not args.eval:
-
- print(f"Downloading pre-trained weights from: {args.finetune}")
-
- checkpoint_path = hf_hub_download(
- repo_id=f'YukunZhou/{args.finetune}',
- filename=f'{args.finetune}.pth',
- )
-
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
- print("Load pre-trained checkpoint from: %s" % args.finetune)
-
- if args.model!='RETFound_mae':
- checkpoint_model = checkpoint['teacher']
- else:
- checkpoint_model = checkpoint['model']
+ # ---- Load pre-trained weights (if requested and not eval-only)
+ if args.finetune and not args.eval:
+ print(f"Preparing to load pre-trained weights: {args.finetune}")
+
+ if args.model in ["Dinov3", "Dinov2"]:
+ checkpoint_path = args.finetune # local path
+ elif args.model in ["RETFound_dinov2", "RETFound_mae"]:
+ print(f"Downloading pre-trained weights from Hugging Face Hub: {args.finetune}")
+ checkpoint_path = hf_hub_download(
+ repo_id=f"YukunZhou/{args.finetune}",
+ filename=f"{args.finetune}.pth",
+ )
+ else:
+ raise ValueError(
+ f"Unsupported model '{args.model}'. "
+ f"Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae"
+ )
+
+ checkpoint = torch.load(checkpoint_path, map_location="cpu")
+ print(f"Loaded pre-trained checkpoint from: {checkpoint_path}")
+
+ if args.model in ["Dinov3", "Dinov2"]:
+ checkpoint_model = checkpoint
+ elif args.model == "RETFound_dinov2":
+ checkpoint_model = checkpoint["teacher"]
+ else: # RETFound_mae
+ checkpoint_model = checkpoint["model"]
+
+ # -- Key hygiene
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
-
+
+ # -- Remove classifier if shape mismatched
state_dict = model.state_dict()
- for k in ['head.weight', 'head.bias']:
+ for k in ["head.weight", "head.bias"]:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
- # interpolate position embedding
+ # -- Interpolate pos embed (ViT)
interpolate_pos_embed(model, checkpoint_model)
- # load pre-trained model
- msg = model.load_state_dict(checkpoint_model, strict=False)
+ # -- Load backbone weights (non-strict)
+ _ = model.load_state_dict(checkpoint_model, strict=False)
- trunc_normal_(model.head.weight, std=2e-5)
+ # -- Re-init head
+ if hasattr(model, "head") and hasattr(model.head, "weight"):
+ trunc_normal_(model.head.weight, std=2e-5)
- dataset_train = build_dataset(is_train='train', args=args)
- dataset_val = build_dataset(is_train='val', args=args)
- dataset_test = build_dataset(is_train='test', args=args)
+ # ---- Datasets & samplers
+ dataset_train = build_dataset(is_train="train", args=args)
+ dataset_val = build_dataset(is_train="val", args=args)
+ dataset_test = build_dataset(is_train="test", args=args)
+ num_tasks = misc.get_world_size()
+ global_rank = misc.get_rank()
- if True: # args.distributed:
- num_tasks = misc.get_world_size()
- global_rank = misc.get_rank()
- if not args.eval:
- sampler_train = torch.utils.data.DistributedSampler(
- dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
- )
- print("Sampler_train = %s" % str(sampler_train))
- if args.dist_eval:
- if len(dataset_val) % num_tasks != 0:
- print(
- 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
- 'This will slightly alter validation results as extra duplicate entries are added to achieve '
- 'equal num of samples per-process.')
- sampler_val = torch.utils.data.DistributedSampler(
- dataset_val, num_replicas=num_tasks, rank=global_rank,
- shuffle=True) # shuffle=True to reduce monitor bias
- else:
- sampler_val = torch.utils.data.SequentialSampler(dataset_val)
-
+ if not args.eval:
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ print(f"Sampler_train = {sampler_train}")
if args.dist_eval:
- if len(dataset_test) % num_tasks != 0:
- print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
- 'This will slightly alter validation results as extra duplicate entries are added to achieve '
- 'equal num of samples per-process.')
- sampler_test = torch.utils.data.DistributedSampler(
- dataset_test, num_replicas=num_tasks, rank=global_rank,
- shuffle=True) # shuffle=True to reduce monitor bias
+ if len(dataset_val) % num_tasks != 0:
+ print("Warning: dist eval with dataset not divisible by #procs; results may differ slightly.")
+ sampler_val = torch.utils.data.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
else:
- sampler_test = torch.utils.data.SequentialSampler(dataset_test)
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+ if args.dist_eval:
+ if len(dataset_test) % num_tasks != 0:
+ print("Warning: dist eval test set not divisible by #procs; results may differ slightly.")
+ sampler_test = torch.utils.data.DistributedSampler(
+ dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ else:
+ sampler_test = torch.utils.data.SequentialSampler(dataset_test)
+
+ # ---- Logging
if global_rank == 0 and args.log_dir is not None and not args.eval:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
else:
log_writer = None
+ # ---- DataLoaders
if not args.eval:
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
- batch_size=args.batch_size,
- num_workers=args.num_workers,
- pin_memory=args.pin_mem,
- drop_last=True,
+ batch_size=args.batch_size, num_workers=args.num_workers,
+ pin_memory=args.pin_mem, drop_last=True,
)
-
- print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
+ print(f"len of train_set: {len(data_loader_train) * args.batch_size}")
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
- batch_size=args.batch_size,
- num_workers=args.num_workers,
- pin_memory=args.pin_mem,
- drop_last=False
+ batch_size=args.batch_size, num_workers=args.num_workers,
+ pin_memory=args.pin_mem, drop_last=False,
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, sampler=sampler_test,
- batch_size=args.batch_size,
- num_workers=args.num_workers,
- pin_memory=args.pin_mem,
- drop_last=False
+ batch_size=args.batch_size, num_workers=args.num_workers,
+ pin_memory=args.pin_mem, drop_last=False,
)
+ # ---- Mixup/CutMix
mixup_fn = None
- mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+ mixup_active = (args.mixup > 0) or (args.cutmix > 0.) or (args.cutmix_minmax is not None)
if mixup_active:
print("Mixup is activated!")
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
- label_smoothing=args.smoothing, num_classes=args.nb_classes)
+ label_smoothing=args.smoothing, num_classes=args.nb_classes
+ )
+ # ---- Eval-only: resume weights
if args.resume and args.eval:
- checkpoint = torch.load(args.resume, map_location='cpu')
- print("Load checkpoint from: %s" % args.resume)
- model.load_state_dict(checkpoint['model'])
+ checkpoint = torch.load(args.resume, map_location="cpu")
+ print(f"Load checkpoint for eval from: {args.resume}")
+ model.load_state_dict(checkpoint["model"])
model.to(device)
model_without_ddp = model
+ # ---- Adaptation toggle
+ if args.adaptation == "lp":
+ for name, param in model.named_parameters():
+ param.requires_grad = ("head" in name)
+ print("[Adaptation] Linear probe: training classifier head only.")
+ else:
+ print("[Adaptation] Full fine-tuning: training all parameters.")
+
+ # ---- Count trainable params
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
- print('number of model params (M): %.2f' % (n_parameters / 1.e6))
+ print(f"number of trainable params (M): {n_parameters / 1.e6:.2f}")
+ # ---- LR scaling by effective batch size
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
-
- if args.lr is None: # only base_lr is specified
+ if args.lr is None:
args.lr = args.blr * eff_batch_size / 256
+ print(f"base lr: {args.lr * 256 / eff_batch_size:.2e}")
+ print(f"actual lr: {args.lr:.2e}")
+ print(f"accumulate grad iterations: {args.accum_iter}")
+ print(f"effective batch size: {eff_batch_size}")
- print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
- print("actual lr: %.2e" % args.lr)
-
- print("accumulate grad iterations: %d" % args.accum_iter)
- print("effective batch size: %d" % eff_batch_size)
-
- if args.distributed:
- model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+ # ---- DDP (if available)
+ if args.distributed and torch.cuda.device_count() > 1:
+ ddp_kwargs = {}
+ if args.adaptation == "lp":
+ ddp_kwargs["find_unused_parameters"] = True
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[args.gpu], **ddp_kwargs
+ )
model_without_ddp = model.module
+ else:
+ model_without_ddp = model # single-GPU
+
+ # ---- Optimizer param groups (after freezing)
+ no_weight_decay = (model_without_ddp.no_weight_decay()
+ if hasattr(model_without_ddp, "no_weight_decay") else [])
+
+
+ param_groups = lrd.param_groups_lrd(
+ model_without_ddp,
+ weight_decay=args.weight_decay,
+ no_weight_decay_list=no_weight_decay,
+ layer_decay=args.layer_decay,
+ )
+ for g in param_groups:
+ g["params"] = [p for p in g["params"] if p.requires_grad]
- no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
- param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
- no_weight_decay_list=no_weight_decay,
- layer_decay=args.layer_decay
- )
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = NativeScaler()
+ print(f"criterion = {criterion}")
- print("criterion = %s" % str(criterion))
-
- misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
+ # ---- Load previous full state (optimizer, scaler, etc.)
+ misc.load_model(args=args, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler)
+ # =========================
+ # Eval-only Short Circuit
+ # =========================
if args.eval:
- if 'epoch' in checkpoint:
- print("Test with the best model at epoch = %d" % checkpoint['epoch'])
- test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
- num_class=args.nb_classes, log_writer=log_writer)
- exit(0)
+ if "checkpoint" in locals() and isinstance(checkpoint, dict) and ("epoch" in checkpoint):
+ print(f"Test with the best model at epoch = {checkpoint['epoch']}")
+ test_stats, auc_roc = evaluate(
+ data_loader_test, model, device, args, epoch=0, mode="test",
+ num_class=args.nb_classes, log_writer=log_writer
+ )
+ return
+ # =========================
+ # Train Loop
+ # =========================
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_score = 0.0
best_epoch = 0
+
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
@@ -352,49 +388,55 @@ def main(args, criterion):
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
args.clip_grad, mixup_fn,
- log_writer=log_writer,
- args=args
+ log_writer=log_writer, args=args
+ )
+
+ val_stats, val_score = evaluate(
+ data_loader_val, model, device, args, epoch, mode="val",
+ num_class=args.nb_classes, log_writer=log_writer
)
- val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
- num_class=args.nb_classes, log_writer=log_writer)
if max_score < val_score:
max_score = val_score
best_epoch = epoch
if args.output_dir and args.savemodel:
misc.save_model(
- args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
- loss_scaler=loss_scaler, epoch=epoch, mode='best')
- print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
-
-
- if epoch == (args.epochs - 1):
- checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
- model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
- model.to(device)
- print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
- test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
- num_class=args.nb_classes, log_writer=None)
+ args=args, model=model, model_without_ddp=model_without_ddp,
+ optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode="best"
+ )
+ print(f"Best epoch = {best_epoch}, Best score = {max_score:.4f}")
if log_writer is not None:
- log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
+ log_writer.add_scalar("loss/val", val_stats["loss"], epoch)
+ log_writer.flush()
- log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
- 'epoch': epoch,
- 'n_parameters': n_parameters}
+ log_stats = {**{f"train_{k}": v for k, v in train_stats.items()},
+ "epoch": epoch,
+ "n_parameters": n_parameters}
if args.output_dir and misc.is_main_process():
- if log_writer is not None:
- log_writer.flush()
- with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
+ with open(os.path.join(args.output_dir, args.task, "log.txt"), "a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
+ # =========================
+ # Final Test (Best Ckpt)
+ # =========================
+ ckpt_path = os.path.join(args.output_dir, args.task, "checkpoint-best.pth")
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
+ model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
+ model.to(device)
+ print(f"Test with the best model, epoch = {checkpoint.get('epoch', -1)}:")
+ _test_stats, _auc_roc = evaluate(
+ data_loader_test, model, device, args, -1, mode="test",
+ num_class=args.nb_classes, log_writer=None
+ )
+
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
- print('Training time {}'.format(total_time_str))
+ print(f"Training time {total_time_str}")
-if __name__ == '__main__':
+if __name__ == "__main__":
args = get_args_parser()
args = args.parse_args()
@@ -402,6 +444,5 @@ if __name__ == '__main__':
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
main(args, criterion)
-
-
diff --git a/models_vit.py b/models_vit.py
index 8cea227..82e7fbd 100644
--- a/models_vit.py
+++ b/models_vit.py
@@ -1,7 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-# Partly revised by YZ @UCL&Moorfields
-# --------------------------------------------------------
from functools import partial
@@ -10,7 +6,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
-
+from timm.models.layers import trunc_normal_
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
""" Vision Transformer with support for global average pooling
@@ -56,6 +52,30 @@ def RETFound_mae(**kwargs):
+def Dinov2(args, **kwargs):
+
+ if args.model_arch == 'dinov2_vits14':
+ arch = 'vit_small_patch14_dinov2.lvd142m'
+ elif args.model_arch == 'dinov2_vitb14':
+ arch = 'vit_base_patch14_dinov2.lvd142m'
+ elif args.model_arch == 'dinov2_vitl14':
+ arch = 'vit_large_patch14_dinov2.lvd142m'
+ elif args.model_arch == 'dinov2_vitg14':
+ arch = 'vit_giant_patch14_dinov2.lvd142m'
+ else:
+ raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
+ f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
+
+ model = timm.create_model(
+ arch,
+ pretrained=True,
+ img_size=224,
+ **kwargs
+ )
+ return model
+
+
+
def RETFound_dinov2(args, **kwargs):
model = timm.create_model(
'vit_large_patch14_dinov2.lvd142m',
@@ -66,4 +86,20 @@ def RETFound_dinov2(args, **kwargs):
return model
+def Dinov3(args, **kwargs):
+ # Load ViT-L/16 backbone (hub model has `head = Identity` by default)
+ model = torch.hub.load(
+ repo_or_dir="facebookresearch/dinov3",
+ model=args.model_arch,
+ pretrained=False, # main() will load your checkpoint
+ trust_repo=True,
+ )
+ # Figure out feature dimension for the probe
+ feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
+ model.head = nn.Linear(feat_dim, args.nb_classes)
+ trunc_normal_(model.head.weight, std=2e-5)
+ if model.head.bias is not None:
+ nn.init.zeros_(model.head.bias)
+
+ return model
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000..0a3ee46
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,27 @@
+# ==== Model settings ====
+ADAPTATION="finetune"
+MODEL="Dinov2"
+MODEL_ARCH="dinov2_vitl14"
+FINETUNE="dinov2_vitl14_pretrain.pth"
+
+# ==== Data settings ====
+DATASET="MESSIDOR2"
+NUM_CLASS=5
+
+data_path="/home/jupyter/public_dataset/${DATASET}"
+task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
+
+CUDA_VISIBLE_DEVICES=1 torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
+ --model "${MODEL}" \
+ --model_arch "${MODEL_ARCH}" \
+ --finetune "${FINETUNE}" \
+ --savemodel \
+ --global_pool \
+ --batch_size 24 \
+ --world_size 1 \
+ --epochs 50 \
+ --nb_classes "${NUM_CLASS}" \
+ --data_path "${data_path}" \
+ --input_size 224 \
+ --task "${task}" \
+ --adaptation "${ADAPTATION}"
diff --git a/util/datasets.py b/util/datasets.py
index 0305c27..20f65f2 100644
--- a/util/datasets.py
+++ b/util/datasets.py
@@ -1,29 +1,39 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-# Partly revised by YZ @UCL&Moorfields
-# --------------------------------------------------------
-
import os
+import torch
+from torch.utils.data import Subset
from torchvision import datasets, transforms
from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-
def build_dataset(is_train, args):
transform = build_transform(is_train, args)
root = os.path.join(args.data_path, is_train)
dataset = datasets.ImageFolder(root, transform=transform)
- return dataset
+ if is_train == 'train':
+ ratio = float(getattr(args, "dataratio", 1.0))
+ seed = int(getattr(args, "seed", 0))
+ stratified = bool(getattr(args, "stratified", False))
+ if 0.0 < ratio < 1.0:
+ if stratified:
+ idx = _stratified_indices(dataset.targets, ratio, seed)
+ else:
+ # simple uniform subsample with torch.Generator for reproducibility
+ g = torch.Generator().manual_seed(seed)
+ n = len(dataset)
+ k = max(1, int(n * ratio))
+ idx = torch.randperm(n, generator=g)[:k].tolist()
+ dataset = Subset(dataset, idx)
+
+ return dataset
def build_transform(is_train, args):
mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD
- # train transform
+
if is_train == 'train':
- # this should always dispatch to transforms_imagenet_train
- transform = create_transform(
+ return create_transform(
input_size=args.input_size,
is_training=True,
color_jitter=args.color_jitter,
@@ -35,19 +45,37 @@ def build_transform(is_train, args):
mean=mean,
std=std,
)
- return transform
# eval transform
- t = []
- if args.input_size <= 224:
- crop_pct = 224 / 256
- else:
- crop_pct = 1.0
+ crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
size = int(args.input_size / crop_pct)
- t.append(
+ t = [
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
- )
- t.append(transforms.CenterCrop(args.input_size))
- t.append(transforms.ToTensor())
- t.append(transforms.Normalize(mean, std))
+ transforms.CenterCrop(args.input_size),
+ transforms.ToTensor(),
+ transforms.Normalize(mean, std),
+ ]
return transforms.Compose(t)
+
+# ---- helpers ----
+
+def _stratified_indices(targets, ratio: float, seed: int):
+ """Maintain class proportions. Ensures at least 1 sample per class when possible."""
+ t = torch.as_tensor(targets)
+ classes = torch.unique(t)
+ g = torch.Generator().manual_seed(seed)
+
+ keep = []
+ for c in classes.tolist():
+ cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1)
+ if len(cls_idx) == 0:
+ continue
+ k = max(1, int(round(len(cls_idx) * ratio)))
+ sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]]
+ keep.extend(sel.tolist())
+
+ # shuffle final indices (stable across seed)
+ g2 = torch.Generator().manual_seed(seed + 1)
+ keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist()
+ return keep
+