From 409f7b6167c2febd3cde29f8d6338d9137e59b47 Mon Sep 17 00:00:00 2001 From: rmaphoh Date: Sun, 31 Aug 2025 18:03:57 +0100 Subject: [PATCH] Incorporate DINOv3, DINOv2 --- README.md | 166 ++++++++++----- main_finetune.py | 539 +++++++++++++++++++++++++---------------------- models_vit.py | 46 +++- train.sh | 27 +++ util/datasets.py | 70 ++++-- 5 files changed, 521 insertions(+), 327 deletions(-) create mode 100644 train.sh diff --git a/README.md b/README.md index 78cbd54..4aa64d6 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,15 @@ ## RETFound - A foundation model for retinal imaging -Official repo including a series of retinal foundation models.
-[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae).
-[New checkpoints](https://huggingface.co/YukunZhou), some of which are based on [DINOV2](https://github.com/facebookresearch/dinov2): +Official repo including a series of foundation models and applications in retinal imaging.
+`[RETFound-MAE]`:[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x).
+`[RETFound-DINOv2]`:[Revealing the Impact of Pre-training Data on Medical Foundation Models](https://www.researchsquare.com/article/rs-6080254/v1).
+`[DINOv2]`:[General-purpose vision foundation models DINOv2](https://github.com/facebookresearch/dinov2).
+`[DINOv3]`:[General-purpose vision foundation models DINOv3](https://github.com/facebookresearch/dinov3).
+ Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions. -Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE) - ### πŸ“Key features @@ -19,13 +20,14 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/ ### πŸŽ‰News +- πŸ‰2025/09: **Benchmarking paper for DINOv3, DINOv2, and RETFound will come soon!** +- πŸ‰2025/09: **We included state-of-the-art DINOv3 into fine-tuning pipeline for retinal applications!** - πŸ‰2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!** - πŸ‰2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!** - πŸ‰2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!** - πŸ‰2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) are now online! - πŸ‰2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online! - πŸŽ„2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation! -- 2023/10: change the hyperparameter of [input_size](https://github.com/rmaphoh/RETFound_MAE#:~:text=finetune%20./RETFound_cfp_weights.pth%20%5C-,%2D%2Dinput_size%20224,-For%20evaluation%20only) for any image size ### πŸ”§Install environment @@ -40,9 +42,9 @@ conda activate retfound 2. Install dependencies ``` -conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia -git clone https://github.com/rmaphoh/RETFound_MAE/ -cd RETFound_MAE +pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121 +git clone https://github.com/rmaphoh/RETFound/ +cd RETFound pip install -r requirements.txt ``` @@ -71,22 +73,22 @@ To fine tune RETFound on your own data, follow these steps: RETFound_mae_meh access -TBD +FM data paper RETFound_mae_shanghai access -TBD +FM data paper RETFound_dinov2_meh access -TBD +FM data paper RETFound_dinov2_shanghai access -TBD +FM data paper @@ -118,56 +120,116 @@ export HF_ENDPOINT=https://hf-mirror.com β”œβ”€β”€class_c ``` -4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training. +4. If you would like to use DINOv2 and DINOv3, please visit their GitHub repositories to download the model weights and put them in the RETFound folder. -The model and finetune can be selected: +4. Start fine-tuning by running `sh train.sh`. + + +The model can be selected by changing the hyperparameters `MODEL`, `MODEL_ARCH`, `FINETUNE` in `train.sh`: + +**RETFound**: + +| MODEL | MODEL_ARCH | FINETUNE | SIZE | +|-----------------|--------------------------|--------------------------|--------------------------| +| RETFound_mae | retfound_mae | RETFound_mae_natureCFP | ~300M | +| RETFound_mae | retfound_mae | RETFound_mae_natureOCT | ~300M | +| RETFound_mae | retfound_mae | RETFound_mae_meh | ~300M | +| RETFound_mae | retfound_mae | RETFound_mae_shanghai | ~300M | +| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_meh | ~300M | +| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_shanghai | ~300M | + + +**DINOv3**: + +| MODEL | MODEL_ARCH | FINETUNE | SIZE | +|-----------------|--------------------------|----------------------------------|--------------------------| +| Dinov3 | dinov3_vits16 | dinov3_vits16_pretrain.pth | ~21M | +| Dinov3 | dinov3_vits16plus | dinov3_vits16plus_pretrain.pth | ~29M | +| Dinov3 | dinov3_vitb16 | dinov3_vitb16_pretrain.pth | ~86M | +| Dinov3 | dinov3_vitl16 | dinov3_vitl16_pretrain.pth | ~300M | +| Dinov3 | dinov3_vith16plus | dinov3_vith16plus_pretrain.pth | ~840M | +| Dinov3 | dinov3_vit7b16 | dinov3_vit7b16_pretrain.pth | ~6.7B | + + +**DINOv2**: + +| MODEL | MODEL_ARCH | FINETUNE | SIZE | +|-----------------|--------------------------|------------------------------|--------------------------| +| Dinov2 | dinov2_vits14 | dinov2_vits14_pretrain.pth | ~21M | +| Dinov2 | dinov2_vitb14 | dinov2_vitb14_pretrain.pth | ~86M | +| Dinov2 | dinov2_vitl14 | dinov2_vitl14_pretrain.pth | ~300M | +| Dinov2 | dinov2_vitg14 | dinov2_vitg14_pretrain.pth | ~1.1B | -| model | finetune | -|-----------------|--------------------------| -| RETFound_mae | RETFound_mae_natureCFP | -| RETFound_mae | RETFound_mae_natureOCT | -| RETFound_mae | RETFound_mae_meh | -| RETFound_mae | RETFound_mae_shanghai | -| RETFound_dinov2 | RETFound_dinov2_meh | -| RETFound_dinov2 | RETFound_dinov2_shanghai | ``` -torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \ - --model RETFound_mae \ - --savemodel \ - --global_pool \ - --batch_size 16 \ - --world_size 1 \ - --epochs 100 \ - --blr 5e-3 --layer_decay 0.65 \ - --weight_decay 0.05 --drop_path 0.2 \ - --nb_classes 5 \ - --data_path ./IDRiD \ - --input_size 224 \ - --task RETFound_mae_meh-IDRiD \ - --finetune RETFound_mae_meh +# ==== Model settings ==== +# adaptation {finetune,lp} +ADAPTATION="finetune" +MODEL="RETFound_dinov2" +MODEL_ARCH="retfound_dinov2" +FINETUNE="RETFound_dinov2_meh" + +# ==== Data settings ==== +# change the dataset name and corresponding class number +DATASET="MESSIDOR2" +NUM_CLASS=5 +data_path="./${DATASET}" +task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}" + +torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \ + --model "${MODEL}" \ + --model_arch "${MODEL_ARCH}" \ + --finetune "${FINETUNE}" \ + --savemodel \ + --global_pool \ + --batch_size 24 \ + --world_size 1 \ + --epochs 50 \ + --nb_classes "${NUM_CLASS}" \ + --data_path "${data_path}" \ + --input_size 224 \ + --task "${task}" \ + --adaptation "${ADAPTATION}" + ``` + 4. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the path below) ``` -torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \ - --model RETFound_mae \ - --savemodel \ - --eval \ - --global_pool \ - --batch_size 16 \ - --world_size 1 \ - --epochs 100 \ - --blr 5e-3 --layer_decay 0.65 \ - --weight_decay 0.05 --drop_path 0.2 \ - --nb_classes 5 \ - --data_path ./IDRiD \ - --input_size 224 \ - --task RETFound_mae_meh-IDRiD \ - --resume ./RETFound_mae_meh-IDRiD/checkpoint-best.pth +# ==== Model/settings (match training) ==== +ADAPTATION="finetune" +MODEL="RETFound_dinov2" +MODEL_ARCH="retfound_dinov2" +FINETUNE="RETFound_dinov2_meh" + +# ==== Data/settings (match training) ==== +DATASET="MESSIDOR2" +NUM_CLASS=5 +DATA_PATH="./${DATASET}" +TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}" + +# Path to the trained checkpoint (adjust if you saved elsewhere) +CKPT="./output_dir/${TASK}/checkpoint-best.pth" + +# ==== Evaluation only ==== +torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \ + --model "${MODEL}" \ + --model_arch "${MODEL_ARCH}" \ + --savemodel \ + --global_pool \ + --batch_size 128 \ + --world_size 1 \ + --nb_classes "${NUM_CLASS}" \ + --data_path "${DATA_PATH}" \ + --input_size 224 \ + --task "${TASK}" \ + --adaptation "${ADAPTATION}" \ + --eval \ + --resume "${CKPT}" + ``` diff --git a/main_finetune.py b/main_finetune.py index 75c822d..0f4513f 100644 --- a/main_finetune.py +++ b/main_finetune.py @@ -1,349 +1,385 @@ +#!/usr/bin/env python3 + +# ========================= import argparse import datetime import json - -import numpy as np import os import time from pathlib import Path +import warnings +import faulthandler +# ========================= +import numpy as np import torch import torch.backends.cudnn as cudnn from torch.utils.tensorboard import SummaryWriter from timm.models.layers import trunc_normal_ from timm.data.mixup import Mixup +from huggingface_hub import hf_hub_download, login # login imported as in original +# ========================= import models_vit as models import util.lr_decay as lrd import util.misc as misc from util.datasets import build_dataset from util.pos_embed import interpolate_pos_embed from util.misc import NativeScalerWithGradNormCount as NativeScaler -from huggingface_hub import hf_hub_download, login from engine_finetune import train_one_epoch, evaluate -import warnings -import faulthandler - +# ========================= faulthandler.enable() -warnings.simplefilter(action='ignore', category=FutureWarning) +warnings.simplefilter(action="ignore", category=FutureWarning) def get_args_parser(): - parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) - parser.add_argument('--batch_size', default=128, type=int, - help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') - parser.add_argument('--epochs', default=50, type=int) - parser.add_argument('--accum_iter', default=1, type=int, - help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') + parser = argparse.ArgumentParser( + "MAE fine-tuning / linear probing for image classification", add_help=False + ) - # Model parameters - parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', - help='Name of model to train') - parser.add_argument('--input_size', default=256, type=int, - help='images input size') - parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT', - help='Drop path rate (default: 0.1)') + # ---- Core training + parser.add_argument("--batch_size", default=128, type=int, + help="Batch size per GPU (effective batch size = batch_size * accum_iter * #gpus)") + parser.add_argument("--epochs", default=50, type=int) + parser.add_argument("--accum_iter", default=1, type=int, + help="Gradient accumulation steps") - # Optimizer parameters - parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', - help='Clip gradient norm (default: None, no clipping)') - parser.add_argument('--weight_decay', type=float, default=0.05, - help='weight decay (default: 0.05)') - parser.add_argument('--lr', type=float, default=None, metavar='LR', - help='learning rate (absolute lr)') - parser.add_argument('--blr', type=float, default=5e-3, metavar='LR', - help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') - parser.add_argument('--layer_decay', type=float, default=0.65, - help='layer-wise lr decay from ELECTRA/BEiT') - parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', - help='lower lr bound for cyclic schedulers that hit 0') - parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', - help='epochs to warmup LR') + # ---- Model parameters + parser.add_argument("--model", default="vit_large_patch16", type=str, metavar="MODEL", + help="Model entry in models_vit.py") + parser.add_argument("--model_arch", default="dinov3_vits16", type=str, metavar="MODEL_ARCH", + help="Backbone architecture key (e.g., dinov2_vitl14, convnext_base, etc.)") + parser.add_argument("--input_size", default=256, type=int, help="Image size") + parser.add_argument("--drop_path", type=float, default=0.2, metavar="PCT", help="Drop path rate") + parser.add_argument("--global_pool", action="store_true"); parser.set_defaults(global_pool=True) + parser.add_argument("--cls_token", action="store_false", dest="global_pool", + help="Use class token instead of global pool for classification") - # Augmentation parameters - parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', - help='Color jitter factor (enabled only when not using Auto/RandAug)') - parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', - help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), - parser.add_argument('--smoothing', type=float, default=0.1, - help='Label smoothing (default: 0.1)') + # ---- Optimizer parameters + parser.add_argument("--clip_grad", type=float, default=None, metavar="NORM", help="Clip grad norm") + parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay") + parser.add_argument("--lr", type=float, default=None, metavar="LR", help="Absolute LR (overrides blr)") + parser.add_argument("--blr", type=float, default=5e-3, metavar="LR", + help="Base LR: lr = blr * total_batch_size / 256") + parser.add_argument("--layer_decay", type=float, default=0.65, help="Layer-wise LR decay (ViT)") + parser.add_argument("--min_lr", type=float, default=1e-6, metavar="LR", help="Lower LR bound") + parser.add_argument("--warmup_epochs", type=int, default=10, metavar="N", help="Warmup epochs") - # * Random Erase params - parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', - help='Random erase prob (default: 0.25)') - parser.add_argument('--remode', type=str, default='pixel', - help='Random erase mode (default: "pixel")') - parser.add_argument('--recount', type=int, default=1, - help='Random erase count (default: 1)') - parser.add_argument('--resplit', action='store_true', default=False, - help='Do not random erase first (clean) augmentation split') + # ---- Augmentation + parser.add_argument("--color_jitter", type=float, default=None, metavar="PCT") + parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1", metavar="NAME") + parser.add_argument("--smoothing", type=float, default=0.1) - # * Mixup params - parser.add_argument('--mixup', type=float, default=0, - help='mixup alpha, mixup enabled if > 0.') - parser.add_argument('--cutmix', type=float, default=0, - help='cutmix alpha, cutmix enabled if > 0.') - parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, - help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') - parser.add_argument('--mixup_prob', type=float, default=1.0, - help='Probability of performing mixup or cutmix when either/both is enabled') - parser.add_argument('--mixup_switch_prob', type=float, default=0.5, - help='Probability of switching to cutmix when both mixup and cutmix enabled') - parser.add_argument('--mixup_mode', type=str, default='batch', - help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') + # ---- Random erase + parser.add_argument("--reprob", type=float, default=0.25, metavar="PCT") + parser.add_argument("--remode", type=str, default="pixel") + parser.add_argument("--recount", type=int, default=1) + parser.add_argument("--resplit", action="store_true", default=False) - # * Finetuning params - parser.add_argument('--finetune', default='', type=str, - help='finetune from checkpoint') - parser.add_argument('--task', default='', type=str, - help='finetune from checkpoint') - parser.add_argument('--global_pool', action='store_true') - parser.set_defaults(global_pool=True) - parser.add_argument('--cls_token', action='store_false', dest='global_pool', - help='Use class token instead of global pool for classification') + # ---- Mixup/Cutmix + parser.add_argument("--mixup", type=float, default=0.0) + parser.add_argument("--cutmix", type=float, default=0.0) + parser.add_argument("--cutmix_minmax", type=float, nargs="+", default=None) + parser.add_argument("--mixup_prob", type=float, default=1.0) + parser.add_argument("--mixup_switch_prob", type=float, default=0.5) + parser.add_argument("--mixup_mode", type=str, default="batch") - # Dataset parameters - parser.add_argument('--data_path', default='./data/', type=str, - help='dataset path') - parser.add_argument('--nb_classes', default=8, type=int, - help='number of the classification types') - parser.add_argument('--output_dir', default='./output_dir', - help='path where to save, empty for no saving') - parser.add_argument('--log_dir', default='./output_logs', - help='path where to tensorboard log') - parser.add_argument('--device', default='cuda', - help='device to use for training / testing') - parser.add_argument('--seed', default=0, type=int) - parser.add_argument('--resume', default='', - help='resume from checkpoint') - parser.add_argument('--start_epoch', default=0, type=int, metavar='N', - help='start epoch') - parser.add_argument('--eval', action='store_true', - help='Perform evaluation only') - parser.add_argument('--dist_eval', action='store_true', default=False, - help='Enabling distributed evaluation (recommended during training for faster monitor') - parser.add_argument('--num_workers', default=10, type=int) - parser.add_argument('--pin_mem', action='store_true', - help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') - parser.set_defaults(pin_mem=True) + # ---- Finetuning & adaptation + parser.add_argument("--finetune", default="", type=str, help="Checkpoint id/path (see model rules below)") + parser.add_argument("--task", default="", type=str, help="Task name for logging/output grouping") + parser.add_argument("--adaptation", default="finetune", choices=["finetune", "lp"], + help="Adaptation strategy: finetune=full fine-tune, lp=linear probe (train head only)") - # distributed training parameters - parser.add_argument('--world_size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--local_rank', default=-1, type=int) - parser.add_argument('--dist_on_itp', action='store_true') - parser.add_argument('--dist_url', default='env://', - help='url used to set up distributed training') + # ---- Dataset & paths + parser.add_argument("--data_path", default="./data/", type=str) + parser.add_argument("--nb_classes", default=8, type=int) + parser.add_argument("--output_dir", default="./output_dir") + parser.add_argument("--log_dir", default="./output_logs") - # fine-tuning parameters - parser.add_argument('--savemodel', action='store_true', default=True, - help='Save model') - parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method') - parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data') - parser.add_argument('--datasets_seed', default=2026, type=int) + # >>> NEW: training data efficiency <<< + parser.add_argument( + "--dataratio", type=str, default="1.0", + help=('Training data ratio(s) for subsampling in build_dataset. ' + 'Use a single float in (0,1] (e.g., 0.25) or a comma-separated list ' + '(e.g., "1.0,0.5,0.25") if your build_dataset supports sweeps.') + ) + parser.add_argument( + "--stratified", action="store_true", + help="If set, subsample training data in a class-stratified manner (requires support in build_dataset)." + ) + + # ---- Runtime + parser.add_argument("--device", default="cuda") + parser.add_argument("--seed", default=0, type=int) + parser.add_argument("--resume", default="", help="Resume full state (optimizer, scaler, etc.)") + parser.add_argument("--start_epoch", default=0, type=int, metavar="N") + parser.add_argument("--eval", action="store_true", help="Evaluation only") + parser.add_argument("--dist_eval", action="store_true", default=False, + help="Distributed evaluation (faster monitoring during training)") + parser.add_argument("--num_workers", default=10, type=int) + parser.add_argument("--pin_mem", action="store_true"); parser.set_defaults(pin_mem=True) + + # ---- Distributed + parser.add_argument("--world_size", default=1, type=int) + parser.add_argument("--local_rank", default=-1, type=int) + parser.add_argument("--dist_on_itp", action="store_true") + parser.add_argument("--dist_url", default="env://") + + # ---- Misc + parser.add_argument("--savemodel", action="store_true", default=True, help="Save best model") + parser.add_argument("--norm", default="IMAGENET", type=str) + parser.add_argument("--enhance", action="store_true", default=False) + parser.add_argument("--datasets_seed", default=2026, type=int) return parser +# ========================= +# Main +# ========================= def main(args, criterion): + # ---- Optionally load args from resume (when training) if args.resume and not args.eval: - resume = args.resume - checkpoint = torch.load(args.resume, map_location='cpu') - print("Load checkpoint from: %s" % args.resume) - args = checkpoint['args'] - args.resume = resume + resume_path = args.resume + checkpoint = torch.load(args.resume, map_location="cpu") + print(f"Load checkpoint (args) from: {args.resume}") + args = checkpoint["args"] + args.resume = resume_path + # ---- Distributed setup misc.init_distributed_mode(args) - print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) - print("{}".format(args).replace(', ', ',\n')) + print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}") + print(f"{args}".replace(", ", ",\n")) device = torch.device(args.device) - # fix the seed for reproducibility + # ---- Reproducibility seed = args.seed + misc.get_rank() torch.manual_seed(seed) np.random.seed(seed) - cudnn.benchmark = True - if args.model=='RETFound_mae': + # ---- Build model + if args.model == "RETFound_mae": model = models.__dict__[args.model]( - img_size=args.input_size, - num_classes=args.nb_classes, - drop_path_rate=args.drop_path, - global_pool=args.global_pool, - ) + img_size=args.input_size, + num_classes=args.nb_classes, + drop_path_rate=args.drop_path, + global_pool=args.global_pool, + ) else: model = models.__dict__[args.model]( num_classes=args.nb_classes, drop_path_rate=args.drop_path, args=args, ) - - if args.finetune and not args.eval: - - print(f"Downloading pre-trained weights from: {args.finetune}") - - checkpoint_path = hf_hub_download( - repo_id=f'YukunZhou/{args.finetune}', - filename=f'{args.finetune}.pth', - ) - - checkpoint = torch.load(checkpoint_path, map_location='cpu') - print("Load pre-trained checkpoint from: %s" % args.finetune) - - if args.model!='RETFound_mae': - checkpoint_model = checkpoint['teacher'] - else: - checkpoint_model = checkpoint['model'] + # ---- Load pre-trained weights (if requested and not eval-only) + if args.finetune and not args.eval: + print(f"Preparing to load pre-trained weights: {args.finetune}") + + if args.model in ["Dinov3", "Dinov2"]: + checkpoint_path = args.finetune # local path + elif args.model in ["RETFound_dinov2", "RETFound_mae"]: + print(f"Downloading pre-trained weights from Hugging Face Hub: {args.finetune}") + checkpoint_path = hf_hub_download( + repo_id=f"YukunZhou/{args.finetune}", + filename=f"{args.finetune}.pth", + ) + else: + raise ValueError( + f"Unsupported model '{args.model}'. " + f"Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae" + ) + + checkpoint = torch.load(checkpoint_path, map_location="cpu") + print(f"Loaded pre-trained checkpoint from: {checkpoint_path}") + + if args.model in ["Dinov3", "Dinov2"]: + checkpoint_model = checkpoint + elif args.model == "RETFound_dinov2": + checkpoint_model = checkpoint["teacher"] + else: # RETFound_mae + checkpoint_model = checkpoint["model"] + + # -- Key hygiene checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()} checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()} checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()} - + + # -- Remove classifier if shape mismatched state_dict = model.state_dict() - for k in ['head.weight', 'head.bias']: + for k in ["head.weight", "head.bias"]: if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: print(f"Removing key {k} from pretrained checkpoint") del checkpoint_model[k] - # interpolate position embedding + # -- Interpolate pos embed (ViT) interpolate_pos_embed(model, checkpoint_model) - # load pre-trained model - msg = model.load_state_dict(checkpoint_model, strict=False) + # -- Load backbone weights (non-strict) + _ = model.load_state_dict(checkpoint_model, strict=False) - trunc_normal_(model.head.weight, std=2e-5) + # -- Re-init head + if hasattr(model, "head") and hasattr(model.head, "weight"): + trunc_normal_(model.head.weight, std=2e-5) - dataset_train = build_dataset(is_train='train', args=args) - dataset_val = build_dataset(is_train='val', args=args) - dataset_test = build_dataset(is_train='test', args=args) + # ---- Datasets & samplers + dataset_train = build_dataset(is_train="train", args=args) + dataset_val = build_dataset(is_train="val", args=args) + dataset_test = build_dataset(is_train="test", args=args) + num_tasks = misc.get_world_size() + global_rank = misc.get_rank() - if True: # args.distributed: - num_tasks = misc.get_world_size() - global_rank = misc.get_rank() - if not args.eval: - sampler_train = torch.utils.data.DistributedSampler( - dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True - ) - print("Sampler_train = %s" % str(sampler_train)) - if args.dist_eval: - if len(dataset_val) % num_tasks != 0: - print( - 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' - 'This will slightly alter validation results as extra duplicate entries are added to achieve ' - 'equal num of samples per-process.') - sampler_val = torch.utils.data.DistributedSampler( - dataset_val, num_replicas=num_tasks, rank=global_rank, - shuffle=True) # shuffle=True to reduce monitor bias - else: - sampler_val = torch.utils.data.SequentialSampler(dataset_val) - + if not args.eval: + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + print(f"Sampler_train = {sampler_train}") if args.dist_eval: - if len(dataset_test) % num_tasks != 0: - print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' - 'This will slightly alter validation results as extra duplicate entries are added to achieve ' - 'equal num of samples per-process.') - sampler_test = torch.utils.data.DistributedSampler( - dataset_test, num_replicas=num_tasks, rank=global_rank, - shuffle=True) # shuffle=True to reduce monitor bias + if len(dataset_val) % num_tasks != 0: + print("Warning: dist eval with dataset not divisible by #procs; results may differ slightly.") + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) else: - sampler_test = torch.utils.data.SequentialSampler(dataset_test) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + if args.dist_eval: + if len(dataset_test) % num_tasks != 0: + print("Warning: dist eval test set not divisible by #procs; results may differ slightly.") + sampler_test = torch.utils.data.DistributedSampler( + dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + else: + sampler_test = torch.utils.data.SequentialSampler(dataset_test) + + # ---- Logging if global_rank == 0 and args.log_dir is not None and not args.eval: os.makedirs(args.log_dir, exist_ok=True) log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task)) else: log_writer = None + # ---- DataLoaders if not args.eval: data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, - batch_size=args.batch_size, - num_workers=args.num_workers, - pin_memory=args.pin_mem, - drop_last=True, + batch_size=args.batch_size, num_workers=args.num_workers, + pin_memory=args.pin_mem, drop_last=True, ) - - print(f'len of train_set: {len(data_loader_train) * args.batch_size}') + print(f"len of train_set: {len(data_loader_train) * args.batch_size}") data_loader_val = torch.utils.data.DataLoader( dataset_val, sampler=sampler_val, - batch_size=args.batch_size, - num_workers=args.num_workers, - pin_memory=args.pin_mem, - drop_last=False + batch_size=args.batch_size, num_workers=args.num_workers, + pin_memory=args.pin_mem, drop_last=False, ) data_loader_test = torch.utils.data.DataLoader( dataset_test, sampler=sampler_test, - batch_size=args.batch_size, - num_workers=args.num_workers, - pin_memory=args.pin_mem, - drop_last=False + batch_size=args.batch_size, num_workers=args.num_workers, + pin_memory=args.pin_mem, drop_last=False, ) + # ---- Mixup/CutMix mixup_fn = None - mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + mixup_active = (args.mixup > 0) or (args.cutmix > 0.) or (args.cutmix_minmax is not None) if mixup_active: print("Mixup is activated!") mixup_fn = Mixup( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, - label_smoothing=args.smoothing, num_classes=args.nb_classes) + label_smoothing=args.smoothing, num_classes=args.nb_classes + ) + # ---- Eval-only: resume weights if args.resume and args.eval: - checkpoint = torch.load(args.resume, map_location='cpu') - print("Load checkpoint from: %s" % args.resume) - model.load_state_dict(checkpoint['model']) + checkpoint = torch.load(args.resume, map_location="cpu") + print(f"Load checkpoint for eval from: {args.resume}") + model.load_state_dict(checkpoint["model"]) model.to(device) model_without_ddp = model + # ---- Adaptation toggle + if args.adaptation == "lp": + for name, param in model.named_parameters(): + param.requires_grad = ("head" in name) + print("[Adaptation] Linear probe: training classifier head only.") + else: + print("[Adaptation] Full fine-tuning: training all parameters.") + + # ---- Count trainable params n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) - print('number of model params (M): %.2f' % (n_parameters / 1.e6)) + print(f"number of trainable params (M): {n_parameters / 1.e6:.2f}") + # ---- LR scaling by effective batch size eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() - - if args.lr is None: # only base_lr is specified + if args.lr is None: args.lr = args.blr * eff_batch_size / 256 + print(f"base lr: {args.lr * 256 / eff_batch_size:.2e}") + print(f"actual lr: {args.lr:.2e}") + print(f"accumulate grad iterations: {args.accum_iter}") + print(f"effective batch size: {eff_batch_size}") - print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) - print("actual lr: %.2e" % args.lr) - - print("accumulate grad iterations: %d" % args.accum_iter) - print("effective batch size: %d" % eff_batch_size) - - if args.distributed: - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + # ---- DDP (if available) + if args.distributed and torch.cuda.device_count() > 1: + ddp_kwargs = {} + if args.adaptation == "lp": + ddp_kwargs["find_unused_parameters"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args.gpu], **ddp_kwargs + ) model_without_ddp = model.module + else: + model_without_ddp = model # single-GPU + + # ---- Optimizer param groups (after freezing) + no_weight_decay = (model_without_ddp.no_weight_decay() + if hasattr(model_without_ddp, "no_weight_decay") else []) + + + param_groups = lrd.param_groups_lrd( + model_without_ddp, + weight_decay=args.weight_decay, + no_weight_decay_list=no_weight_decay, + layer_decay=args.layer_decay, + ) + for g in param_groups: + g["params"] = [p for p in g["params"] if p.requires_grad] - no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else [] - param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, - no_weight_decay_list=no_weight_decay, - layer_decay=args.layer_decay - ) optimizer = torch.optim.AdamW(param_groups, lr=args.lr) loss_scaler = NativeScaler() + print(f"criterion = {criterion}") - print("criterion = %s" % str(criterion)) - - misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + # ---- Load previous full state (optimizer, scaler, etc.) + misc.load_model(args=args, model_without_ddp=model_without_ddp, + optimizer=optimizer, loss_scaler=loss_scaler) + # ========================= + # Eval-only Short Circuit + # ========================= if args.eval: - if 'epoch' in checkpoint: - print("Test with the best model at epoch = %d" % checkpoint['epoch']) - test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test', - num_class=args.nb_classes, log_writer=log_writer) - exit(0) + if "checkpoint" in locals() and isinstance(checkpoint, dict) and ("epoch" in checkpoint): + print(f"Test with the best model at epoch = {checkpoint['epoch']}") + test_stats, auc_roc = evaluate( + data_loader_test, model, device, args, epoch=0, mode="test", + num_class=args.nb_classes, log_writer=log_writer + ) + return + # ========================= + # Train Loop + # ========================= print(f"Start training for {args.epochs} epochs") start_time = time.time() max_score = 0.0 best_epoch = 0 + for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) @@ -352,49 +388,55 @@ def main(args, criterion): model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, args.clip_grad, mixup_fn, - log_writer=log_writer, - args=args + log_writer=log_writer, args=args + ) + + val_stats, val_score = evaluate( + data_loader_val, model, device, args, epoch, mode="val", + num_class=args.nb_classes, log_writer=log_writer ) - val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val', - num_class=args.nb_classes, log_writer=log_writer) if max_score < val_score: max_score = val_score best_epoch = epoch if args.output_dir and args.savemodel: misc.save_model( - args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, - loss_scaler=loss_scaler, epoch=epoch, mode='best') - print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score)) - - - if epoch == (args.epochs - 1): - checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model'], strict=False) - model.to(device) - print("Test with the best model, epoch = %d:" % checkpoint['epoch']) - test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test', - num_class=args.nb_classes, log_writer=None) + args=args, model=model, model_without_ddp=model_without_ddp, + optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode="best" + ) + print(f"Best epoch = {best_epoch}, Best score = {max_score:.4f}") if log_writer is not None: - log_writer.add_scalar('loss/val', val_stats['loss'], epoch) + log_writer.add_scalar("loss/val", val_stats["loss"], epoch) + log_writer.flush() - log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, - 'epoch': epoch, - 'n_parameters': n_parameters} + log_stats = {**{f"train_{k}": v for k, v in train_stats.items()}, + "epoch": epoch, + "n_parameters": n_parameters} if args.output_dir and misc.is_main_process(): - if log_writer is not None: - log_writer.flush() - with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f: + with open(os.path.join(args.output_dir, args.task, "log.txt"), "a", encoding="utf-8") as f: f.write(json.dumps(log_stats) + "\n") + # ========================= + # Final Test (Best Ckpt) + # ========================= + ckpt_path = os.path.join(args.output_dir, args.task, "checkpoint-best.pth") + checkpoint = torch.load(ckpt_path, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + print(f"Test with the best model, epoch = {checkpoint.get('epoch', -1)}:") + _test_stats, _auc_roc = evaluate( + data_loader_test, model, device, args, -1, mode="test", + num_class=args.nb_classes, log_writer=None + ) + total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print(f"Training time {total_time_str}") -if __name__ == '__main__': +if __name__ == "__main__": args = get_args_parser() args = args.parse_args() @@ -402,6 +444,5 @@ if __name__ == '__main__': if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) + main(args, criterion) - - diff --git a/models_vit.py b/models_vit.py index 8cea227..82e7fbd 100644 --- a/models_vit.py +++ b/models_vit.py @@ -1,7 +1,3 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# Partly revised by YZ @UCL&Moorfields -# -------------------------------------------------------- from functools import partial @@ -10,7 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor - +from timm.models.layers import trunc_normal_ class VisionTransformer(timm.models.vision_transformer.VisionTransformer): """ Vision Transformer with support for global average pooling @@ -56,6 +52,30 @@ def RETFound_mae(**kwargs): +def Dinov2(args, **kwargs): + + if args.model_arch == 'dinov2_vits14': + arch = 'vit_small_patch14_dinov2.lvd142m' + elif args.model_arch == 'dinov2_vitb14': + arch = 'vit_base_patch14_dinov2.lvd142m' + elif args.model_arch == 'dinov2_vitl14': + arch = 'vit_large_patch14_dinov2.lvd142m' + elif args.model_arch == 'dinov2_vitg14': + arch = 'vit_giant_patch14_dinov2.lvd142m' + else: + raise ValueError(f"Unknown model_arch '{args.model_arch}'. " + f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14") + + model = timm.create_model( + arch, + pretrained=True, + img_size=224, + **kwargs + ) + return model + + + def RETFound_dinov2(args, **kwargs): model = timm.create_model( 'vit_large_patch14_dinov2.lvd142m', @@ -66,4 +86,20 @@ def RETFound_dinov2(args, **kwargs): return model +def Dinov3(args, **kwargs): + # Load ViT-L/16 backbone (hub model has `head = Identity` by default) + model = torch.hub.load( + repo_or_dir="facebookresearch/dinov3", + model=args.model_arch, + pretrained=False, # main() will load your checkpoint + trust_repo=True, + ) + # Figure out feature dimension for the probe + feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None) + model.head = nn.Linear(feat_dim, args.nb_classes) + trunc_normal_(model.head.weight, std=2e-5) + if model.head.bias is not None: + nn.init.zeros_(model.head.bias) + + return model diff --git a/train.sh b/train.sh new file mode 100644 index 0000000..0a3ee46 --- /dev/null +++ b/train.sh @@ -0,0 +1,27 @@ +# ==== Model settings ==== +ADAPTATION="finetune" +MODEL="Dinov2" +MODEL_ARCH="dinov2_vitl14" +FINETUNE="dinov2_vitl14_pretrain.pth" + +# ==== Data settings ==== +DATASET="MESSIDOR2" +NUM_CLASS=5 + +data_path="/home/jupyter/public_dataset/${DATASET}" +task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}" + +CUDA_VISIBLE_DEVICES=1 torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \ + --model "${MODEL}" \ + --model_arch "${MODEL_ARCH}" \ + --finetune "${FINETUNE}" \ + --savemodel \ + --global_pool \ + --batch_size 24 \ + --world_size 1 \ + --epochs 50 \ + --nb_classes "${NUM_CLASS}" \ + --data_path "${data_path}" \ + --input_size 224 \ + --task "${task}" \ + --adaptation "${ADAPTATION}" diff --git a/util/datasets.py b/util/datasets.py index 0305c27..20f65f2 100644 --- a/util/datasets.py +++ b/util/datasets.py @@ -1,29 +1,39 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# Partly revised by YZ @UCL&Moorfields -# -------------------------------------------------------- - import os +import torch +from torch.utils.data import Subset from torchvision import datasets, transforms from timm.data import create_transform from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD - def build_dataset(is_train, args): transform = build_transform(is_train, args) root = os.path.join(args.data_path, is_train) dataset = datasets.ImageFolder(root, transform=transform) - return dataset + if is_train == 'train': + ratio = float(getattr(args, "dataratio", 1.0)) + seed = int(getattr(args, "seed", 0)) + stratified = bool(getattr(args, "stratified", False)) + if 0.0 < ratio < 1.0: + if stratified: + idx = _stratified_indices(dataset.targets, ratio, seed) + else: + # simple uniform subsample with torch.Generator for reproducibility + g = torch.Generator().manual_seed(seed) + n = len(dataset) + k = max(1, int(n * ratio)) + idx = torch.randperm(n, generator=g)[:k].tolist() + dataset = Subset(dataset, idx) + + return dataset def build_transform(is_train, args): mean = IMAGENET_DEFAULT_MEAN std = IMAGENET_DEFAULT_STD - # train transform + if is_train == 'train': - # this should always dispatch to transforms_imagenet_train - transform = create_transform( + return create_transform( input_size=args.input_size, is_training=True, color_jitter=args.color_jitter, @@ -35,19 +45,37 @@ def build_transform(is_train, args): mean=mean, std=std, ) - return transform # eval transform - t = [] - if args.input_size <= 224: - crop_pct = 224 / 256 - else: - crop_pct = 1.0 + crop_pct = 224 / 256 if args.input_size <= 224 else 1.0 size = int(args.input_size / crop_pct) - t.append( + t = [ transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), - ) - t.append(transforms.CenterCrop(args.input_size)) - t.append(transforms.ToTensor()) - t.append(transforms.Normalize(mean, std)) + transforms.CenterCrop(args.input_size), + transforms.ToTensor(), + transforms.Normalize(mean, std), + ] return transforms.Compose(t) + +# ---- helpers ---- + +def _stratified_indices(targets, ratio: float, seed: int): + """Maintain class proportions. Ensures at least 1 sample per class when possible.""" + t = torch.as_tensor(targets) + classes = torch.unique(t) + g = torch.Generator().manual_seed(seed) + + keep = [] + for c in classes.tolist(): + cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1) + if len(cls_idx) == 0: + continue + k = max(1, int(round(len(cls_idx) * ratio))) + sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]] + keep.extend(sel.tolist()) + + # shuffle final indices (stable across seed) + g2 = torch.Generator().manual_seed(seed + 1) + keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist() + return keep +