Incorporate DINOv3, DINOv2

This commit is contained in:
rmaphoh
2025-08-31 18:03:57 +01:00
parent 897d71c8c9
commit 409f7b6167
5 changed files with 521 additions and 327 deletions
+106 -44
View File
@@ -1,14 +1,15 @@
## RETFound - A foundation model for retinal imaging ## RETFound - A foundation model for retinal imaging
Official repo including a series of retinal foundation models.<br> Official repo including a series of foundation models and applications in retinal imaging.<br>
[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae).<br> `[RETFound-MAE]`:[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x).<br>
[New checkpoints](https://huggingface.co/YukunZhou), some of which are based on [DINOV2](https://github.com/facebookresearch/dinov2): `[RETFound-DINOv2]`:[Revealing the Impact of Pre-training Data on Medical Foundation Models](https://www.researchsquare.com/article/rs-6080254/v1).<br>
`[DINOv2]`:[General-purpose vision foundation models DINOv2](https://github.com/facebookresearch/dinov2).<br>
`[DINOv3]`:[General-purpose vision foundation models DINOv3](https://github.com/facebookresearch/dinov3).<br>
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions. Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE)
### 📝Key features ### 📝Key features
@@ -19,13 +20,14 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
### 🎉News ### 🎉News
- 🐉2025/09: **Benchmarking paper for DINOv3, DINOv2, and RETFound will come soon!**
- 🐉2025/09: **We included state-of-the-art DINOv3 into fine-tuning pipeline for retinal applications!**
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!** - 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!** - 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!**
- 🐉2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!** - 🐉2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!**
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) are now online! - 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) are now online!
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online! - 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation! - 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
- 2023/10: change the hyperparameter of [input_size](https://github.com/rmaphoh/RETFound_MAE#:~:text=finetune%20./RETFound_cfp_weights.pth%20%5C-,%2D%2Dinput_size%20224,-For%20evaluation%20only) for any image size
### 🔧Install environment ### 🔧Install environment
@@ -40,9 +42,9 @@ conda activate retfound
2. Install dependencies 2. Install dependencies
``` ```
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121
git clone https://github.com/rmaphoh/RETFound_MAE/ git clone https://github.com/rmaphoh/RETFound/
cd RETFound_MAE cd RETFound
pip install -r requirements.txt pip install -r requirements.txt
``` ```
@@ -71,22 +73,22 @@ To fine tune RETFound on your own data, follow these steps:
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">RETFound_mae_meh</td> <tr><td align="left">RETFound_mae_meh</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
<td align="center">TBD</a></td> <td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr> </tr>
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">RETFound_mae_shanghai</td> <tr><td align="left">RETFound_mae_shanghai</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
<td align="center">TBD</a></td> <td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr> </tr>
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">RETFound_dinov2_meh</td> <tr><td align="left">RETFound_dinov2_meh</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
<td align="center">TBD</a></td> <td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr> </tr>
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">RETFound_dinov2_shanghai</td> <tr><td align="left">RETFound_dinov2_shanghai</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">access</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">access</a></td>
<td align="center">TBD</a></td> <td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr> </tr>
</tbody></table> </tbody></table>
@@ -118,56 +120,116 @@ export HF_ENDPOINT=https://hf-mirror.com
├──class_c ├──class_c
``` ```
4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training. 4. If you would like to use DINOv2 and DINOv3, please visit their GitHub repositories to download the model weights and put them in the RETFound folder.
The model and finetune can be selected: 4. Start fine-tuning by running `sh train.sh`.
The model can be selected by changing the hyperparameters `MODEL`, `MODEL_ARCH`, `FINETUNE` in `train.sh`:
**RETFound**:
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|-----------------|--------------------------|--------------------------|--------------------------|
| RETFound_mae | retfound_mae | RETFound_mae_natureCFP | ~300M |
| RETFound_mae | retfound_mae | RETFound_mae_natureOCT | ~300M |
| RETFound_mae | retfound_mae | RETFound_mae_meh | ~300M |
| RETFound_mae | retfound_mae | RETFound_mae_shanghai | ~300M |
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_meh | ~300M |
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_shanghai | ~300M |
**DINOv3**:
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|-----------------|--------------------------|----------------------------------|--------------------------|
| Dinov3 | dinov3_vits16 | dinov3_vits16_pretrain.pth | ~21M |
| Dinov3 | dinov3_vits16plus | dinov3_vits16plus_pretrain.pth | ~29M |
| Dinov3 | dinov3_vitb16 | dinov3_vitb16_pretrain.pth | ~86M |
| Dinov3 | dinov3_vitl16 | dinov3_vitl16_pretrain.pth | ~300M |
| Dinov3 | dinov3_vith16plus | dinov3_vith16plus_pretrain.pth | ~840M |
| Dinov3 | dinov3_vit7b16 | dinov3_vit7b16_pretrain.pth | ~6.7B |
**DINOv2**:
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|-----------------|--------------------------|------------------------------|--------------------------|
| Dinov2 | dinov2_vits14 | dinov2_vits14_pretrain.pth | ~21M |
| Dinov2 | dinov2_vitb14 | dinov2_vitb14_pretrain.pth | ~86M |
| Dinov2 | dinov2_vitl14 | dinov2_vitl14_pretrain.pth | ~300M |
| Dinov2 | dinov2_vitg14 | dinov2_vitg14_pretrain.pth | ~1.1B |
| model | finetune |
|-----------------|--------------------------|
| RETFound_mae | RETFound_mae_natureCFP |
| RETFound_mae | RETFound_mae_natureOCT |
| RETFound_mae | RETFound_mae_meh |
| RETFound_mae | RETFound_mae_shanghai |
| RETFound_dinov2 | RETFound_dinov2_meh |
| RETFound_dinov2 | RETFound_dinov2_shanghai |
``` ```
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \ # ==== Model settings ====
--model RETFound_mae \ # adaptation {finetune,lp}
ADAPTATION="finetune"
MODEL="RETFound_dinov2"
MODEL_ARCH="retfound_dinov2"
FINETUNE="RETFound_dinov2_meh"
# ==== Data settings ====
# change the dataset name and corresponding class number
DATASET="MESSIDOR2"
NUM_CLASS=5
data_path="./${DATASET}"
task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
--model "${MODEL}" \
--model_arch "${MODEL_ARCH}" \
--finetune "${FINETUNE}" \
--savemodel \ --savemodel \
--global_pool \ --global_pool \
--batch_size 16 \ --batch_size 24 \
--world_size 1 \ --world_size 1 \
--epochs 100 \ --epochs 50 \
--blr 5e-3 --layer_decay 0.65 \ --nb_classes "${NUM_CLASS}" \
--weight_decay 0.05 --drop_path 0.2 \ --data_path "${data_path}" \
--nb_classes 5 \
--data_path ./IDRiD \
--input_size 224 \ --input_size 224 \
--task RETFound_mae_meh-IDRiD \ --task "${task}" \
--finetune RETFound_mae_meh --adaptation "${ADAPTATION}"
``` ```
4. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the path below) 4. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the path below)
``` ```
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \ # ==== Model/settings (match training) ====
--model RETFound_mae \ ADAPTATION="finetune"
MODEL="RETFound_dinov2"
MODEL_ARCH="retfound_dinov2"
FINETUNE="RETFound_dinov2_meh"
# ==== Data/settings (match training) ====
DATASET="MESSIDOR2"
NUM_CLASS=5
DATA_PATH="./${DATASET}"
TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
# Path to the trained checkpoint (adjust if you saved elsewhere)
CKPT="./output_dir/${TASK}/checkpoint-best.pth"
# ==== Evaluation only ====
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
--model "${MODEL}" \
--model_arch "${MODEL_ARCH}" \
--savemodel \ --savemodel \
--eval \
--global_pool \ --global_pool \
--batch_size 16 \ --batch_size 128 \
--world_size 1 \ --world_size 1 \
--epochs 100 \ --nb_classes "${NUM_CLASS}" \
--blr 5e-3 --layer_decay 0.65 \ --data_path "${DATA_PATH}" \
--weight_decay 0.05 --drop_path 0.2 \
--nb_classes 5 \
--data_path ./IDRiD \
--input_size 224 \ --input_size 224 \
--task RETFound_mae_meh-IDRiD \ --task "${TASK}" \
--resume ./RETFound_mae_meh-IDRiD/checkpoint-best.pth --adaptation "${ADAPTATION}" \
--eval \
--resume "${CKPT}"
``` ```
+263 -222
View File
@@ -1,174 +1,168 @@
#!/usr/bin/env python3
# =========================
import argparse import argparse
import datetime import datetime
import json import json
import numpy as np
import os import os
import time import time
from pathlib import Path from pathlib import Path
import warnings
import faulthandler
# =========================
import numpy as np
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from timm.models.layers import trunc_normal_ from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup from timm.data.mixup import Mixup
from huggingface_hub import hf_hub_download, login # login imported as in original
# =========================
import models_vit as models import models_vit as models
import util.lr_decay as lrd import util.lr_decay as lrd
import util.misc as misc import util.misc as misc
from util.datasets import build_dataset from util.datasets import build_dataset
from util.pos_embed import interpolate_pos_embed from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler from util.misc import NativeScalerWithGradNormCount as NativeScaler
from huggingface_hub import hf_hub_download, login
from engine_finetune import train_one_epoch, evaluate from engine_finetune import train_one_epoch, evaluate
import warnings # =========================
import faulthandler
faulthandler.enable() faulthandler.enable()
warnings.simplefilter(action='ignore', category=FutureWarning) warnings.simplefilter(action="ignore", category=FutureWarning)
def get_args_parser(): def get_args_parser():
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) parser = argparse.ArgumentParser(
parser.add_argument('--batch_size', default=128, type=int, "MAE fine-tuning / linear probing for image classification", add_help=False
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') )
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters # ---- Core training
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', parser.add_argument("--batch_size", default=128, type=int,
help='Name of model to train') help="Batch size per GPU (effective batch size = batch_size * accum_iter * #gpus)")
parser.add_argument('--input_size', default=256, type=int, parser.add_argument("--epochs", default=50, type=int)
help='images input size') parser.add_argument("--accum_iter", default=1, type=int,
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT', help="Gradient accumulation steps")
help='Drop path rate (default: 0.1)')
# Optimizer parameters # ---- Model parameters
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', parser.add_argument("--model", default="vit_large_patch16", type=str, metavar="MODEL",
help='Clip gradient norm (default: None, no clipping)') help="Model entry in models_vit.py")
parser.add_argument('--weight_decay', type=float, default=0.05, parser.add_argument("--model_arch", default="dinov3_vits16", type=str, metavar="MODEL_ARCH",
help='weight decay (default: 0.05)') help="Backbone architecture key (e.g., dinov2_vitl14, convnext_base, etc.)")
parser.add_argument('--lr', type=float, default=None, metavar='LR', parser.add_argument("--input_size", default=256, type=int, help="Image size")
help='learning rate (absolute lr)') parser.add_argument("--drop_path", type=float, default=0.2, metavar="PCT", help="Drop path rate")
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR', parser.add_argument("--global_pool", action="store_true"); parser.set_defaults(global_pool=True)
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') parser.add_argument("--cls_token", action="store_false", dest="global_pool",
parser.add_argument('--layer_decay', type=float, default=0.65, help="Use class token instead of global pool for classification")
help='layer-wise lr decay from ELECTRA/BEiT')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
help='epochs to warmup LR')
# Augmentation parameters # ---- Optimizer parameters
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', parser.add_argument("--clip_grad", type=float, default=None, metavar="NORM", help="Clip grad norm")
help='Color jitter factor (enabled only when not using Auto/RandAug)') parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay")
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', parser.add_argument("--lr", type=float, default=None, metavar="LR", help="Absolute LR (overrides blr)")
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), parser.add_argument("--blr", type=float, default=5e-3, metavar="LR",
parser.add_argument('--smoothing', type=float, default=0.1, help="Base LR: lr = blr * total_batch_size / 256")
help='Label smoothing (default: 0.1)') parser.add_argument("--layer_decay", type=float, default=0.65, help="Layer-wise LR decay (ViT)")
parser.add_argument("--min_lr", type=float, default=1e-6, metavar="LR", help="Lower LR bound")
parser.add_argument("--warmup_epochs", type=int, default=10, metavar="N", help="Warmup epochs")
# * Random Erase params # ---- Augmentation
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', parser.add_argument("--color_jitter", type=float, default=None, metavar="PCT")
help='Random erase prob (default: 0.25)') parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1", metavar="NAME")
parser.add_argument('--remode', type=str, default='pixel', parser.add_argument("--smoothing", type=float, default=0.1)
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Mixup params # ---- Random erase
parser.add_argument('--mixup', type=float, default=0, parser.add_argument("--reprob", type=float, default=0.25, metavar="PCT")
help='mixup alpha, mixup enabled if > 0.') parser.add_argument("--remode", type=str, default="pixel")
parser.add_argument('--cutmix', type=float, default=0, parser.add_argument("--recount", type=int, default=1)
help='cutmix alpha, cutmix enabled if > 0.') parser.add_argument("--resplit", action="store_true", default=False)
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup_mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# * Finetuning params # ---- Mixup/Cutmix
parser.add_argument('--finetune', default='', type=str, parser.add_argument("--mixup", type=float, default=0.0)
help='finetune from checkpoint') parser.add_argument("--cutmix", type=float, default=0.0)
parser.add_argument('--task', default='', type=str, parser.add_argument("--cutmix_minmax", type=float, nargs="+", default=None)
help='finetune from checkpoint') parser.add_argument("--mixup_prob", type=float, default=1.0)
parser.add_argument('--global_pool', action='store_true') parser.add_argument("--mixup_switch_prob", type=float, default=0.5)
parser.set_defaults(global_pool=True) parser.add_argument("--mixup_mode", type=str, default="batch")
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
help='Use class token instead of global pool for classification')
# Dataset parameters # ---- Finetuning & adaptation
parser.add_argument('--data_path', default='./data/', type=str, parser.add_argument("--finetune", default="", type=str, help="Checkpoint id/path (see model rules below)")
help='dataset path') parser.add_argument("--task", default="", type=str, help="Task name for logging/output grouping")
parser.add_argument('--nb_classes', default=8, type=int, parser.add_argument("--adaptation", default="finetune", choices=["finetune", "lp"],
help='number of the classification types') help="Adaptation strategy: finetune=full fine-tune, lp=linear probe (train head only)")
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_logs',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--dist_eval', action='store_true', default=False,
help='Enabling distributed evaluation (recommended during training for faster monitor')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.set_defaults(pin_mem=True)
# distributed training parameters # ---- Dataset & paths
parser.add_argument('--world_size', default=1, type=int, parser.add_argument("--data_path", default="./data/", type=str)
help='number of distributed processes') parser.add_argument("--nb_classes", default=8, type=int)
parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument("--output_dir", default="./output_dir")
parser.add_argument('--dist_on_itp', action='store_true') parser.add_argument("--log_dir", default="./output_logs")
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
# fine-tuning parameters # >>> NEW: training data efficiency <<<
parser.add_argument('--savemodel', action='store_true', default=True, parser.add_argument(
help='Save model') "--dataratio", type=str, default="1.0",
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method') help=('Training data ratio(s) for subsampling in build_dataset. '
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data') 'Use a single float in (0,1] (e.g., 0.25) or a comma-separated list '
parser.add_argument('--datasets_seed', default=2026, type=int) '(e.g., "1.0,0.5,0.25") if your build_dataset supports sweeps.')
)
parser.add_argument(
"--stratified", action="store_true",
help="If set, subsample training data in a class-stratified manner (requires support in build_dataset)."
)
# ---- Runtime
parser.add_argument("--device", default="cuda")
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--resume", default="", help="Resume full state (optimizer, scaler, etc.)")
parser.add_argument("--start_epoch", default=0, type=int, metavar="N")
parser.add_argument("--eval", action="store_true", help="Evaluation only")
parser.add_argument("--dist_eval", action="store_true", default=False,
help="Distributed evaluation (faster monitoring during training)")
parser.add_argument("--num_workers", default=10, type=int)
parser.add_argument("--pin_mem", action="store_true"); parser.set_defaults(pin_mem=True)
# ---- Distributed
parser.add_argument("--world_size", default=1, type=int)
parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument("--dist_on_itp", action="store_true")
parser.add_argument("--dist_url", default="env://")
# ---- Misc
parser.add_argument("--savemodel", action="store_true", default=True, help="Save best model")
parser.add_argument("--norm", default="IMAGENET", type=str)
parser.add_argument("--enhance", action="store_true", default=False)
parser.add_argument("--datasets_seed", default=2026, type=int)
return parser return parser
# =========================
# Main
# =========================
def main(args, criterion): def main(args, criterion):
# ---- Optionally load args from resume (when training)
if args.resume and not args.eval: if args.resume and not args.eval:
resume = args.resume resume_path = args.resume
checkpoint = torch.load(args.resume, map_location='cpu') checkpoint = torch.load(args.resume, map_location="cpu")
print("Load checkpoint from: %s" % args.resume) print(f"Load checkpoint (args) from: {args.resume}")
args = checkpoint['args'] args = checkpoint["args"]
args.resume = resume args.resume = resume_path
# ---- Distributed setup
misc.init_distributed_mode(args) misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}")
print("{}".format(args).replace(', ', ',\n')) print(f"{args}".replace(", ", ",\n"))
device = torch.device(args.device) device = torch.device(args.device)
# fix the seed for reproducibility # ---- Reproducibility
seed = args.seed + misc.get_rank() seed = args.seed + misc.get_rank()
torch.manual_seed(seed) torch.manual_seed(seed)
np.random.seed(seed) np.random.seed(seed)
cudnn.benchmark = True cudnn.benchmark = True
if args.model=='RETFound_mae': # ---- Build model
if args.model == "RETFound_mae":
model = models.__dict__[args.model]( model = models.__dict__[args.model](
img_size=args.input_size, img_size=args.input_size,
num_classes=args.nb_classes, num_classes=args.nb_classes,
@@ -182,168 +176,210 @@ def main(args, criterion):
args=args, args=args,
) )
# ---- Load pre-trained weights (if requested and not eval-only)
if args.finetune and not args.eval: if args.finetune and not args.eval:
print(f"Preparing to load pre-trained weights: {args.finetune}")
print(f"Downloading pre-trained weights from: {args.finetune}") if args.model in ["Dinov3", "Dinov2"]:
checkpoint_path = args.finetune # local path
elif args.model in ["RETFound_dinov2", "RETFound_mae"]:
print(f"Downloading pre-trained weights from Hugging Face Hub: {args.finetune}")
checkpoint_path = hf_hub_download( checkpoint_path = hf_hub_download(
repo_id=f'YukunZhou/{args.finetune}', repo_id=f"YukunZhou/{args.finetune}",
filename=f'{args.finetune}.pth', filename=f"{args.finetune}.pth",
)
else:
raise ValueError(
f"Unsupported model '{args.model}'. "
f"Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae"
) )
checkpoint = torch.load(checkpoint_path, map_location='cpu') checkpoint = torch.load(checkpoint_path, map_location="cpu")
print("Load pre-trained checkpoint from: %s" % args.finetune) print(f"Loaded pre-trained checkpoint from: {checkpoint_path}")
if args.model!='RETFound_mae': if args.model in ["Dinov3", "Dinov2"]:
checkpoint_model = checkpoint['teacher'] checkpoint_model = checkpoint
else: elif args.model == "RETFound_dinov2":
checkpoint_model = checkpoint['model'] checkpoint_model = checkpoint["teacher"]
else: # RETFound_mae
checkpoint_model = checkpoint["model"]
# -- Key hygiene
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()} checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()} checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()} checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
# -- Remove classifier if shape mismatched
state_dict = model.state_dict() state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']: for k in ["head.weight", "head.bias"]:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint") print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k] del checkpoint_model[k]
# interpolate position embedding # -- Interpolate pos embed (ViT)
interpolate_pos_embed(model, checkpoint_model) interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model # -- Load backbone weights (non-strict)
msg = model.load_state_dict(checkpoint_model, strict=False) _ = model.load_state_dict(checkpoint_model, strict=False)
# -- Re-init head
if hasattr(model, "head") and hasattr(model.head, "weight"):
trunc_normal_(model.head.weight, std=2e-5) trunc_normal_(model.head.weight, std=2e-5)
dataset_train = build_dataset(is_train='train', args=args) # ---- Datasets & samplers
dataset_val = build_dataset(is_train='val', args=args) dataset_train = build_dataset(is_train="train", args=args)
dataset_test = build_dataset(is_train='test', args=args) dataset_val = build_dataset(is_train="val", args=args)
dataset_test = build_dataset(is_train="test", args=args)
if True: # args.distributed:
num_tasks = misc.get_world_size() num_tasks = misc.get_world_size()
global_rank = misc.get_rank() global_rank = misc.get_rank()
if not args.eval: if not args.eval:
sampler_train = torch.utils.data.DistributedSampler( sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
) )
print("Sampler_train = %s" % str(sampler_train)) print(f"Sampler_train = {sampler_train}")
if args.dist_eval: if args.dist_eval:
if len(dataset_val) % num_tasks != 0: if len(dataset_val) % num_tasks != 0:
print( print("Warning: dist eval with dataset not divisible by #procs; results may differ slightly.")
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler( sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
shuffle=True) # shuffle=True to reduce monitor bias )
else: else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val) sampler_val = torch.utils.data.SequentialSampler(dataset_val)
if args.dist_eval: if args.dist_eval:
if len(dataset_test) % num_tasks != 0: if len(dataset_test) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' print("Warning: dist eval test set not divisible by #procs; results may differ slightly.")
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_test = torch.utils.data.DistributedSampler( sampler_test = torch.utils.data.DistributedSampler(
dataset_test, num_replicas=num_tasks, rank=global_rank, dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
shuffle=True) # shuffle=True to reduce monitor bias )
else: else:
sampler_test = torch.utils.data.SequentialSampler(dataset_test) sampler_test = torch.utils.data.SequentialSampler(dataset_test)
# ---- Logging
if global_rank == 0 and args.log_dir is not None and not args.eval: if global_rank == 0 and args.log_dir is not None and not args.eval:
os.makedirs(args.log_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task)) log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
else: else:
log_writer = None log_writer = None
# ---- DataLoaders
if not args.eval: if not args.eval:
data_loader_train = torch.utils.data.DataLoader( data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train, dataset_train, sampler=sampler_train,
batch_size=args.batch_size, batch_size=args.batch_size, num_workers=args.num_workers,
num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True,
pin_memory=args.pin_mem,
drop_last=True,
) )
print(f"len of train_set: {len(data_loader_train) * args.batch_size}")
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
data_loader_val = torch.utils.data.DataLoader( data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val, dataset_val, sampler=sampler_val,
batch_size=args.batch_size, batch_size=args.batch_size, num_workers=args.num_workers,
num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False,
pin_memory=args.pin_mem,
drop_last=False
) )
data_loader_test = torch.utils.data.DataLoader( data_loader_test = torch.utils.data.DataLoader(
dataset_test, sampler=sampler_test, dataset_test, sampler=sampler_test,
batch_size=args.batch_size, batch_size=args.batch_size, num_workers=args.num_workers,
num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False,
pin_memory=args.pin_mem,
drop_last=False
) )
# ---- Mixup/CutMix
mixup_fn = None mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None mixup_active = (args.mixup > 0) or (args.cutmix > 0.) or (args.cutmix_minmax is not None)
if mixup_active: if mixup_active:
print("Mixup is activated!") print("Mixup is activated!")
mixup_fn = Mixup( mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes) label_smoothing=args.smoothing, num_classes=args.nb_classes
)
# ---- Eval-only: resume weights
if args.resume and args.eval: if args.resume and args.eval:
checkpoint = torch.load(args.resume, map_location='cpu') checkpoint = torch.load(args.resume, map_location="cpu")
print("Load checkpoint from: %s" % args.resume) print(f"Load checkpoint for eval from: {args.resume}")
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint["model"])
model.to(device) model.to(device)
model_without_ddp = model model_without_ddp = model
# ---- Adaptation toggle
if args.adaptation == "lp":
for name, param in model.named_parameters():
param.requires_grad = ("head" in name)
print("[Adaptation] Linear probe: training classifier head only.")
else:
print("[Adaptation] Full fine-tuning: training all parameters.")
# ---- Count trainable params
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of model params (M): %.2f' % (n_parameters / 1.e6)) print(f"number of trainable params (M): {n_parameters / 1.e6:.2f}")
# ---- LR scaling by effective batch size
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None:
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256 args.lr = args.blr * eff_batch_size / 256
print(f"base lr: {args.lr * 256 / eff_batch_size:.2e}")
print(f"actual lr: {args.lr:.2e}")
print(f"accumulate grad iterations: {args.accum_iter}")
print(f"effective batch size: {eff_batch_size}")
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) # ---- DDP (if available)
print("actual lr: %.2e" % args.lr) if args.distributed and torch.cuda.device_count() > 1:
ddp_kwargs = {}
print("accumulate grad iterations: %d" % args.accum_iter) if args.adaptation == "lp":
print("effective batch size: %d" % eff_batch_size) ddp_kwargs["find_unused_parameters"] = True
model = torch.nn.parallel.DistributedDataParallel(
if args.distributed: model, device_ids=[args.gpu], **ddp_kwargs
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
no_weight_decay_list=no_weight_decay,
layer_decay=args.layer_decay
) )
model_without_ddp = model.module
else:
model_without_ddp = model # single-GPU
# ---- Optimizer param groups (after freezing)
no_weight_decay = (model_without_ddp.no_weight_decay()
if hasattr(model_without_ddp, "no_weight_decay") else [])
param_groups = lrd.param_groups_lrd(
model_without_ddp,
weight_decay=args.weight_decay,
no_weight_decay_list=no_weight_decay,
layer_decay=args.layer_decay,
)
for g in param_groups:
g["params"] = [p for p in g["params"] if p.requires_grad]
optimizer = torch.optim.AdamW(param_groups, lr=args.lr) optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = NativeScaler() loss_scaler = NativeScaler()
print(f"criterion = {criterion}")
print("criterion = %s" % str(criterion)) # ---- Load previous full state (optimizer, scaler, etc.)
misc.load_model(args=args, model_without_ddp=model_without_ddp,
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) optimizer=optimizer, loss_scaler=loss_scaler)
# =========================
# Eval-only Short Circuit
# =========================
if args.eval: if args.eval:
if 'epoch' in checkpoint: if "checkpoint" in locals() and isinstance(checkpoint, dict) and ("epoch" in checkpoint):
print("Test with the best model at epoch = %d" % checkpoint['epoch']) print(f"Test with the best model at epoch = {checkpoint['epoch']}")
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test', test_stats, auc_roc = evaluate(
num_class=args.nb_classes, log_writer=log_writer) data_loader_test, model, device, args, epoch=0, mode="test",
exit(0) num_class=args.nb_classes, log_writer=log_writer
)
return
# =========================
# Train Loop
# =========================
print(f"Start training for {args.epochs} epochs") print(f"Start training for {args.epochs} epochs")
start_time = time.time() start_time = time.time()
max_score = 0.0 max_score = 0.0
best_epoch = 0 best_epoch = 0
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
if args.distributed: if args.distributed:
data_loader_train.sampler.set_epoch(epoch) data_loader_train.sampler.set_epoch(epoch)
@@ -352,49 +388,55 @@ def main(args, criterion):
model, criterion, data_loader_train, model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler, optimizer, device, epoch, loss_scaler,
args.clip_grad, mixup_fn, args.clip_grad, mixup_fn,
log_writer=log_writer, log_writer=log_writer, args=args
args=args )
val_stats, val_score = evaluate(
data_loader_val, model, device, args, epoch, mode="val",
num_class=args.nb_classes, log_writer=log_writer
) )
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
num_class=args.nb_classes, log_writer=log_writer)
if max_score < val_score: if max_score < val_score:
max_score = val_score max_score = val_score
best_epoch = epoch best_epoch = epoch
if args.output_dir and args.savemodel: if args.output_dir and args.savemodel:
misc.save_model( misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, args=args, model=model, model_without_ddp=model_without_ddp,
loss_scaler=loss_scaler, epoch=epoch, mode='best') optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode="best"
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score)) )
print(f"Best epoch = {best_epoch}, Best score = {max_score:.4f}")
if epoch == (args.epochs - 1):
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
model.to(device)
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
num_class=args.nb_classes, log_writer=None)
if log_writer is not None: if log_writer is not None:
log_writer.add_scalar('loss/val', val_stats['loss'], epoch) log_writer.add_scalar("loss/val", val_stats["loss"], epoch)
log_writer.flush()
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, log_stats = {**{f"train_{k}": v for k, v in train_stats.items()},
'epoch': epoch, "epoch": epoch,
'n_parameters': n_parameters} "n_parameters": n_parameters}
if args.output_dir and misc.is_main_process(): if args.output_dir and misc.is_main_process():
if log_writer is not None: with open(os.path.join(args.output_dir, args.task, "log.txt"), "a", encoding="utf-8") as f:
log_writer.flush()
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n") f.write(json.dumps(log_stats) + "\n")
# =========================
# Final Test (Best Ckpt)
# =========================
ckpt_path = os.path.join(args.output_dir, args.task, "checkpoint-best.pth")
checkpoint = torch.load(ckpt_path, map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
print(f"Test with the best model, epoch = {checkpoint.get('epoch', -1)}:")
_test_stats, _auc_roc = evaluate(
data_loader_test, model, device, args, -1, mode="test",
num_class=args.nb_classes, log_writer=None
)
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str)) print(f"Training time {total_time_str}")
if __name__ == '__main__': if __name__ == "__main__":
args = get_args_parser() args = get_args_parser()
args = args.parse_args() args = args.parse_args()
@@ -402,6 +444,5 @@ if __name__ == '__main__':
if args.output_dir: if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True) Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args, criterion) main(args, criterion)
+41 -5
View File
@@ -1,7 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------
from functools import partial from functools import partial
@@ -10,7 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from timm.models.layers import trunc_normal_
class VisionTransformer(timm.models.vision_transformer.VisionTransformer): class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
""" Vision Transformer with support for global average pooling """ Vision Transformer with support for global average pooling
@@ -56,6 +52,30 @@ def RETFound_mae(**kwargs):
def Dinov2(args, **kwargs):
if args.model_arch == 'dinov2_vits14':
arch = 'vit_small_patch14_dinov2.lvd142m'
elif args.model_arch == 'dinov2_vitb14':
arch = 'vit_base_patch14_dinov2.lvd142m'
elif args.model_arch == 'dinov2_vitl14':
arch = 'vit_large_patch14_dinov2.lvd142m'
elif args.model_arch == 'dinov2_vitg14':
arch = 'vit_giant_patch14_dinov2.lvd142m'
else:
raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
model = timm.create_model(
arch,
pretrained=True,
img_size=224,
**kwargs
)
return model
def RETFound_dinov2(args, **kwargs): def RETFound_dinov2(args, **kwargs):
model = timm.create_model( model = timm.create_model(
'vit_large_patch14_dinov2.lvd142m', 'vit_large_patch14_dinov2.lvd142m',
@@ -66,4 +86,20 @@ def RETFound_dinov2(args, **kwargs):
return model return model
def Dinov3(args, **kwargs):
# Load ViT-L/16 backbone (hub model has `head = Identity` by default)
model = torch.hub.load(
repo_or_dir="facebookresearch/dinov3",
model=args.model_arch,
pretrained=False, # main() will load your checkpoint
trust_repo=True,
)
# Figure out feature dimension for the probe
feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
model.head = nn.Linear(feat_dim, args.nb_classes)
trunc_normal_(model.head.weight, std=2e-5)
if model.head.bias is not None:
nn.init.zeros_(model.head.bias)
return model
+27
View File
@@ -0,0 +1,27 @@
# ==== Model settings ====
ADAPTATION="finetune"
MODEL="Dinov2"
MODEL_ARCH="dinov2_vitl14"
FINETUNE="dinov2_vitl14_pretrain.pth"
# ==== Data settings ====
DATASET="MESSIDOR2"
NUM_CLASS=5
data_path="/home/jupyter/public_dataset/${DATASET}"
task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
CUDA_VISIBLE_DEVICES=1 torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
--model "${MODEL}" \
--model_arch "${MODEL_ARCH}" \
--finetune "${FINETUNE}" \
--savemodel \
--global_pool \
--batch_size 24 \
--world_size 1 \
--epochs 50 \
--nb_classes "${NUM_CLASS}" \
--data_path "${data_path}" \
--input_size 224 \
--task "${task}" \
--adaptation "${ADAPTATION}"
+49 -21
View File
@@ -1,29 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------
import os import os
import torch
from torch.utils.data import Subset
from torchvision import datasets, transforms from torchvision import datasets, transforms
from timm.data import create_transform from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args): def build_dataset(is_train, args):
transform = build_transform(is_train, args) transform = build_transform(is_train, args)
root = os.path.join(args.data_path, is_train) root = os.path.join(args.data_path, is_train)
dataset = datasets.ImageFolder(root, transform=transform) dataset = datasets.ImageFolder(root, transform=transform)
return dataset if is_train == 'train':
ratio = float(getattr(args, "dataratio", 1.0))
seed = int(getattr(args, "seed", 0))
stratified = bool(getattr(args, "stratified", False))
if 0.0 < ratio < 1.0:
if stratified:
idx = _stratified_indices(dataset.targets, ratio, seed)
else:
# simple uniform subsample with torch.Generator for reproducibility
g = torch.Generator().manual_seed(seed)
n = len(dataset)
k = max(1, int(n * ratio))
idx = torch.randperm(n, generator=g)[:k].tolist()
dataset = Subset(dataset, idx)
return dataset
def build_transform(is_train, args): def build_transform(is_train, args):
mean = IMAGENET_DEFAULT_MEAN mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD std = IMAGENET_DEFAULT_STD
# train transform
if is_train == 'train': if is_train == 'train':
# this should always dispatch to transforms_imagenet_train return create_transform(
transform = create_transform(
input_size=args.input_size, input_size=args.input_size,
is_training=True, is_training=True,
color_jitter=args.color_jitter, color_jitter=args.color_jitter,
@@ -35,19 +45,37 @@ def build_transform(is_train, args):
mean=mean, mean=mean,
std=std, std=std,
) )
return transform
# eval transform # eval transform
t = [] crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
if args.input_size <= 224:
crop_pct = 224 / 256
else:
crop_pct = 1.0
size = int(args.input_size / crop_pct) size = int(args.input_size / crop_pct)
t.append( t = [
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
) transforms.CenterCrop(args.input_size),
t.append(transforms.CenterCrop(args.input_size)) transforms.ToTensor(),
t.append(transforms.ToTensor()) transforms.Normalize(mean, std),
t.append(transforms.Normalize(mean, std)) ]
return transforms.Compose(t) return transforms.Compose(t)
# ---- helpers ----
def _stratified_indices(targets, ratio: float, seed: int):
"""Maintain class proportions. Ensures at least 1 sample per class when possible."""
t = torch.as_tensor(targets)
classes = torch.unique(t)
g = torch.Generator().manual_seed(seed)
keep = []
for c in classes.tolist():
cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1)
if len(cls_idx) == 0:
continue
k = max(1, int(round(len(cls_idx) * ratio)))
sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]]
keep.extend(sel.tolist())
# shuffle final indices (stable across seed)
g2 = torch.Generator().manual_seed(seed + 1)
keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist()
return keep