Incorporate DINOv3, DINOv2
This commit is contained in:
@@ -1,14 +1,15 @@
|
|||||||
## RETFound - A foundation model for retinal imaging
|
## RETFound - A foundation model for retinal imaging
|
||||||
|
|
||||||
|
|
||||||
Official repo including a series of retinal foundation models.<br>
|
Official repo including a series of foundation models and applications in retinal imaging.<br>
|
||||||
[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae).<br>
|
`[RETFound-MAE]`:[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x).<br>
|
||||||
[New checkpoints](https://huggingface.co/YukunZhou), some of which are based on [DINOV2](https://github.com/facebookresearch/dinov2):
|
`[RETFound-DINOv2]`:[Revealing the Impact of Pre-training Data on Medical Foundation Models](https://www.researchsquare.com/article/rs-6080254/v1).<br>
|
||||||
|
`[DINOv2]`:[General-purpose vision foundation models DINOv2](https://github.com/facebookresearch/dinov2).<br>
|
||||||
|
`[DINOv3]`:[General-purpose vision foundation models DINOv3](https://github.com/facebookresearch/dinov3).<br>
|
||||||
|
|
||||||
|
|
||||||
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
|
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
|
||||||
|
|
||||||
Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE)
|
|
||||||
|
|
||||||
|
|
||||||
### 📝Key features
|
### 📝Key features
|
||||||
|
|
||||||
@@ -19,13 +20,14 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
|
|||||||
|
|
||||||
### 🎉News
|
### 🎉News
|
||||||
|
|
||||||
|
- 🐉2025/09: **Benchmarking paper for DINOv3, DINOv2, and RETFound will come soon!**
|
||||||
|
- 🐉2025/09: **We included state-of-the-art DINOv3 into fine-tuning pipeline for retinal applications!**
|
||||||
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
|
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
|
||||||
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!**
|
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!**
|
||||||
- 🐉2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!**
|
- 🐉2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!**
|
||||||
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) are now online!
|
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) are now online!
|
||||||
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
|
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
|
||||||
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
|
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
|
||||||
- 2023/10: change the hyperparameter of [input_size](https://github.com/rmaphoh/RETFound_MAE#:~:text=finetune%20./RETFound_cfp_weights.pth%20%5C-,%2D%2Dinput_size%20224,-For%20evaluation%20only) for any image size
|
|
||||||
|
|
||||||
|
|
||||||
### 🔧Install environment
|
### 🔧Install environment
|
||||||
@@ -40,9 +42,9 @@ conda activate retfound
|
|||||||
2. Install dependencies
|
2. Install dependencies
|
||||||
|
|
||||||
```
|
```
|
||||||
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
|
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121
|
||||||
git clone https://github.com/rmaphoh/RETFound_MAE/
|
git clone https://github.com/rmaphoh/RETFound/
|
||||||
cd RETFound_MAE
|
cd RETFound
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -71,22 +73,22 @@ To fine tune RETFound on your own data, follow these steps:
|
|||||||
<!-- TABLE BODY -->
|
<!-- TABLE BODY -->
|
||||||
<tr><td align="left">RETFound_mae_meh</td>
|
<tr><td align="left">RETFound_mae_meh</td>
|
||||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
|
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
|
||||||
<td align="center">TBD</a></td>
|
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||||
</tr>
|
</tr>
|
||||||
<!-- TABLE BODY -->
|
<!-- TABLE BODY -->
|
||||||
<tr><td align="left">RETFound_mae_shanghai</td>
|
<tr><td align="left">RETFound_mae_shanghai</td>
|
||||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
|
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
|
||||||
<td align="center">TBD</a></td>
|
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||||
</tr>
|
</tr>
|
||||||
<!-- TABLE BODY -->
|
<!-- TABLE BODY -->
|
||||||
<tr><td align="left">RETFound_dinov2_meh</td>
|
<tr><td align="left">RETFound_dinov2_meh</td>
|
||||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
|
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
|
||||||
<td align="center">TBD</a></td>
|
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||||
</tr>
|
</tr>
|
||||||
<!-- TABLE BODY -->
|
<!-- TABLE BODY -->
|
||||||
<tr><td align="left">RETFound_dinov2_shanghai</td>
|
<tr><td align="left">RETFound_dinov2_shanghai</td>
|
||||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">access</a></td>
|
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">access</a></td>
|
||||||
<td align="center">TBD</a></td>
|
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||||
</tr>
|
</tr>
|
||||||
</tbody></table>
|
</tbody></table>
|
||||||
|
|
||||||
@@ -118,56 +120,116 @@ export HF_ENDPOINT=https://hf-mirror.com
|
|||||||
├──class_c
|
├──class_c
|
||||||
```
|
```
|
||||||
|
|
||||||
4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training.
|
4. If you would like to use DINOv2 and DINOv3, please visit their GitHub repositories to download the model weights and put them in the RETFound folder.
|
||||||
|
|
||||||
The model and finetune can be selected:
|
4. Start fine-tuning by running `sh train.sh`.
|
||||||
|
|
||||||
|
|
||||||
|
The model can be selected by changing the hyperparameters `MODEL`, `MODEL_ARCH`, `FINETUNE` in `train.sh`:
|
||||||
|
|
||||||
|
**RETFound**:
|
||||||
|
|
||||||
|
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|
||||||
|
|-----------------|--------------------------|--------------------------|--------------------------|
|
||||||
|
| RETFound_mae | retfound_mae | RETFound_mae_natureCFP | ~300M |
|
||||||
|
| RETFound_mae | retfound_mae | RETFound_mae_natureOCT | ~300M |
|
||||||
|
| RETFound_mae | retfound_mae | RETFound_mae_meh | ~300M |
|
||||||
|
| RETFound_mae | retfound_mae | RETFound_mae_shanghai | ~300M |
|
||||||
|
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_meh | ~300M |
|
||||||
|
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_shanghai | ~300M |
|
||||||
|
|
||||||
|
|
||||||
|
**DINOv3**:
|
||||||
|
|
||||||
|
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|
||||||
|
|-----------------|--------------------------|----------------------------------|--------------------------|
|
||||||
|
| Dinov3 | dinov3_vits16 | dinov3_vits16_pretrain.pth | ~21M |
|
||||||
|
| Dinov3 | dinov3_vits16plus | dinov3_vits16plus_pretrain.pth | ~29M |
|
||||||
|
| Dinov3 | dinov3_vitb16 | dinov3_vitb16_pretrain.pth | ~86M |
|
||||||
|
| Dinov3 | dinov3_vitl16 | dinov3_vitl16_pretrain.pth | ~300M |
|
||||||
|
| Dinov3 | dinov3_vith16plus | dinov3_vith16plus_pretrain.pth | ~840M |
|
||||||
|
| Dinov3 | dinov3_vit7b16 | dinov3_vit7b16_pretrain.pth | ~6.7B |
|
||||||
|
|
||||||
|
|
||||||
|
**DINOv2**:
|
||||||
|
|
||||||
|
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|
||||||
|
|-----------------|--------------------------|------------------------------|--------------------------|
|
||||||
|
| Dinov2 | dinov2_vits14 | dinov2_vits14_pretrain.pth | ~21M |
|
||||||
|
| Dinov2 | dinov2_vitb14 | dinov2_vitb14_pretrain.pth | ~86M |
|
||||||
|
| Dinov2 | dinov2_vitl14 | dinov2_vitl14_pretrain.pth | ~300M |
|
||||||
|
| Dinov2 | dinov2_vitg14 | dinov2_vitg14_pretrain.pth | ~1.1B |
|
||||||
|
|
||||||
| model | finetune |
|
|
||||||
|-----------------|--------------------------|
|
|
||||||
| RETFound_mae | RETFound_mae_natureCFP |
|
|
||||||
| RETFound_mae | RETFound_mae_natureOCT |
|
|
||||||
| RETFound_mae | RETFound_mae_meh |
|
|
||||||
| RETFound_mae | RETFound_mae_shanghai |
|
|
||||||
| RETFound_dinov2 | RETFound_dinov2_meh |
|
|
||||||
| RETFound_dinov2 | RETFound_dinov2_shanghai |
|
|
||||||
|
|
||||||
```
|
```
|
||||||
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
# ==== Model settings ====
|
||||||
--model RETFound_mae \
|
# adaptation {finetune,lp}
|
||||||
|
ADAPTATION="finetune"
|
||||||
|
MODEL="RETFound_dinov2"
|
||||||
|
MODEL_ARCH="retfound_dinov2"
|
||||||
|
FINETUNE="RETFound_dinov2_meh"
|
||||||
|
|
||||||
|
# ==== Data settings ====
|
||||||
|
# change the dataset name and corresponding class number
|
||||||
|
DATASET="MESSIDOR2"
|
||||||
|
NUM_CLASS=5
|
||||||
|
data_path="./${DATASET}"
|
||||||
|
task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
|
||||||
|
|
||||||
|
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
|
||||||
|
--model "${MODEL}" \
|
||||||
|
--model_arch "${MODEL_ARCH}" \
|
||||||
|
--finetune "${FINETUNE}" \
|
||||||
--savemodel \
|
--savemodel \
|
||||||
--global_pool \
|
--global_pool \
|
||||||
--batch_size 16 \
|
--batch_size 24 \
|
||||||
--world_size 1 \
|
--world_size 1 \
|
||||||
--epochs 100 \
|
--epochs 50 \
|
||||||
--blr 5e-3 --layer_decay 0.65 \
|
--nb_classes "${NUM_CLASS}" \
|
||||||
--weight_decay 0.05 --drop_path 0.2 \
|
--data_path "${data_path}" \
|
||||||
--nb_classes 5 \
|
|
||||||
--data_path ./IDRiD \
|
|
||||||
--input_size 224 \
|
--input_size 224 \
|
||||||
--task RETFound_mae_meh-IDRiD \
|
--task "${task}" \
|
||||||
--finetune RETFound_mae_meh
|
--adaptation "${ADAPTATION}"
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
4. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the path below)
|
4. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the path below)
|
||||||
|
|
||||||
|
|
||||||
```
|
```
|
||||||
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
# ==== Model/settings (match training) ====
|
||||||
--model RETFound_mae \
|
ADAPTATION="finetune"
|
||||||
|
MODEL="RETFound_dinov2"
|
||||||
|
MODEL_ARCH="retfound_dinov2"
|
||||||
|
FINETUNE="RETFound_dinov2_meh"
|
||||||
|
|
||||||
|
# ==== Data/settings (match training) ====
|
||||||
|
DATASET="MESSIDOR2"
|
||||||
|
NUM_CLASS=5
|
||||||
|
DATA_PATH="./${DATASET}"
|
||||||
|
TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
|
||||||
|
|
||||||
|
# Path to the trained checkpoint (adjust if you saved elsewhere)
|
||||||
|
CKPT="./output_dir/${TASK}/checkpoint-best.pth"
|
||||||
|
|
||||||
|
# ==== Evaluation only ====
|
||||||
|
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
|
||||||
|
--model "${MODEL}" \
|
||||||
|
--model_arch "${MODEL_ARCH}" \
|
||||||
--savemodel \
|
--savemodel \
|
||||||
--eval \
|
|
||||||
--global_pool \
|
--global_pool \
|
||||||
--batch_size 16 \
|
--batch_size 128 \
|
||||||
--world_size 1 \
|
--world_size 1 \
|
||||||
--epochs 100 \
|
--nb_classes "${NUM_CLASS}" \
|
||||||
--blr 5e-3 --layer_decay 0.65 \
|
--data_path "${DATA_PATH}" \
|
||||||
--weight_decay 0.05 --drop_path 0.2 \
|
|
||||||
--nb_classes 5 \
|
|
||||||
--data_path ./IDRiD \
|
|
||||||
--input_size 224 \
|
--input_size 224 \
|
||||||
--task RETFound_mae_meh-IDRiD \
|
--task "${TASK}" \
|
||||||
--resume ./RETFound_mae_meh-IDRiD/checkpoint-best.pth
|
--adaptation "${ADAPTATION}" \
|
||||||
|
--eval \
|
||||||
|
--resume "${CKPT}"
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+263
-222
@@ -1,174 +1,168 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# =========================
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import warnings
|
||||||
|
import faulthandler
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from timm.models.layers import trunc_normal_
|
from timm.models.layers import trunc_normal_
|
||||||
from timm.data.mixup import Mixup
|
from timm.data.mixup import Mixup
|
||||||
|
from huggingface_hub import hf_hub_download, login # login imported as in original
|
||||||
|
|
||||||
|
# =========================
|
||||||
import models_vit as models
|
import models_vit as models
|
||||||
import util.lr_decay as lrd
|
import util.lr_decay as lrd
|
||||||
import util.misc as misc
|
import util.misc as misc
|
||||||
from util.datasets import build_dataset
|
from util.datasets import build_dataset
|
||||||
from util.pos_embed import interpolate_pos_embed
|
from util.pos_embed import interpolate_pos_embed
|
||||||
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
||||||
from huggingface_hub import hf_hub_download, login
|
|
||||||
from engine_finetune import train_one_epoch, evaluate
|
from engine_finetune import train_one_epoch, evaluate
|
||||||
|
|
||||||
import warnings
|
# =========================
|
||||||
import faulthandler
|
|
||||||
|
|
||||||
faulthandler.enable()
|
faulthandler.enable()
|
||||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||||
|
|
||||||
|
|
||||||
def get_args_parser():
|
def get_args_parser():
|
||||||
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--batch_size', default=128, type=int,
|
"MAE fine-tuning / linear probing for image classification", add_help=False
|
||||||
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
|
)
|
||||||
parser.add_argument('--epochs', default=50, type=int)
|
|
||||||
parser.add_argument('--accum_iter', default=1, type=int,
|
|
||||||
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
|
|
||||||
|
|
||||||
# Model parameters
|
# ---- Core training
|
||||||
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
|
parser.add_argument("--batch_size", default=128, type=int,
|
||||||
help='Name of model to train')
|
help="Batch size per GPU (effective batch size = batch_size * accum_iter * #gpus)")
|
||||||
parser.add_argument('--input_size', default=256, type=int,
|
parser.add_argument("--epochs", default=50, type=int)
|
||||||
help='images input size')
|
parser.add_argument("--accum_iter", default=1, type=int,
|
||||||
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
|
help="Gradient accumulation steps")
|
||||||
help='Drop path rate (default: 0.1)')
|
|
||||||
|
|
||||||
# Optimizer parameters
|
# ---- Model parameters
|
||||||
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
|
parser.add_argument("--model", default="vit_large_patch16", type=str, metavar="MODEL",
|
||||||
help='Clip gradient norm (default: None, no clipping)')
|
help="Model entry in models_vit.py")
|
||||||
parser.add_argument('--weight_decay', type=float, default=0.05,
|
parser.add_argument("--model_arch", default="dinov3_vits16", type=str, metavar="MODEL_ARCH",
|
||||||
help='weight decay (default: 0.05)')
|
help="Backbone architecture key (e.g., dinov2_vitl14, convnext_base, etc.)")
|
||||||
parser.add_argument('--lr', type=float, default=None, metavar='LR',
|
parser.add_argument("--input_size", default=256, type=int, help="Image size")
|
||||||
help='learning rate (absolute lr)')
|
parser.add_argument("--drop_path", type=float, default=0.2, metavar="PCT", help="Drop path rate")
|
||||||
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
|
parser.add_argument("--global_pool", action="store_true"); parser.set_defaults(global_pool=True)
|
||||||
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
|
parser.add_argument("--cls_token", action="store_false", dest="global_pool",
|
||||||
parser.add_argument('--layer_decay', type=float, default=0.65,
|
help="Use class token instead of global pool for classification")
|
||||||
help='layer-wise lr decay from ELECTRA/BEiT')
|
|
||||||
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
|
|
||||||
help='lower lr bound for cyclic schedulers that hit 0')
|
|
||||||
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
|
|
||||||
help='epochs to warmup LR')
|
|
||||||
|
|
||||||
# Augmentation parameters
|
# ---- Optimizer parameters
|
||||||
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
|
parser.add_argument("--clip_grad", type=float, default=None, metavar="NORM", help="Clip grad norm")
|
||||||
help='Color jitter factor (enabled only when not using Auto/RandAug)')
|
parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay")
|
||||||
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
|
parser.add_argument("--lr", type=float, default=None, metavar="LR", help="Absolute LR (overrides blr)")
|
||||||
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
|
parser.add_argument("--blr", type=float, default=5e-3, metavar="LR",
|
||||||
parser.add_argument('--smoothing', type=float, default=0.1,
|
help="Base LR: lr = blr * total_batch_size / 256")
|
||||||
help='Label smoothing (default: 0.1)')
|
parser.add_argument("--layer_decay", type=float, default=0.65, help="Layer-wise LR decay (ViT)")
|
||||||
|
parser.add_argument("--min_lr", type=float, default=1e-6, metavar="LR", help="Lower LR bound")
|
||||||
|
parser.add_argument("--warmup_epochs", type=int, default=10, metavar="N", help="Warmup epochs")
|
||||||
|
|
||||||
# * Random Erase params
|
# ---- Augmentation
|
||||||
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
|
parser.add_argument("--color_jitter", type=float, default=None, metavar="PCT")
|
||||||
help='Random erase prob (default: 0.25)')
|
parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1", metavar="NAME")
|
||||||
parser.add_argument('--remode', type=str, default='pixel',
|
parser.add_argument("--smoothing", type=float, default=0.1)
|
||||||
help='Random erase mode (default: "pixel")')
|
|
||||||
parser.add_argument('--recount', type=int, default=1,
|
|
||||||
help='Random erase count (default: 1)')
|
|
||||||
parser.add_argument('--resplit', action='store_true', default=False,
|
|
||||||
help='Do not random erase first (clean) augmentation split')
|
|
||||||
|
|
||||||
# * Mixup params
|
# ---- Random erase
|
||||||
parser.add_argument('--mixup', type=float, default=0,
|
parser.add_argument("--reprob", type=float, default=0.25, metavar="PCT")
|
||||||
help='mixup alpha, mixup enabled if > 0.')
|
parser.add_argument("--remode", type=str, default="pixel")
|
||||||
parser.add_argument('--cutmix', type=float, default=0,
|
parser.add_argument("--recount", type=int, default=1)
|
||||||
help='cutmix alpha, cutmix enabled if > 0.')
|
parser.add_argument("--resplit", action="store_true", default=False)
|
||||||
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
|
|
||||||
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
|
|
||||||
parser.add_argument('--mixup_prob', type=float, default=1.0,
|
|
||||||
help='Probability of performing mixup or cutmix when either/both is enabled')
|
|
||||||
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
|
|
||||||
help='Probability of switching to cutmix when both mixup and cutmix enabled')
|
|
||||||
parser.add_argument('--mixup_mode', type=str, default='batch',
|
|
||||||
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
|
|
||||||
|
|
||||||
# * Finetuning params
|
# ---- Mixup/Cutmix
|
||||||
parser.add_argument('--finetune', default='', type=str,
|
parser.add_argument("--mixup", type=float, default=0.0)
|
||||||
help='finetune from checkpoint')
|
parser.add_argument("--cutmix", type=float, default=0.0)
|
||||||
parser.add_argument('--task', default='', type=str,
|
parser.add_argument("--cutmix_minmax", type=float, nargs="+", default=None)
|
||||||
help='finetune from checkpoint')
|
parser.add_argument("--mixup_prob", type=float, default=1.0)
|
||||||
parser.add_argument('--global_pool', action='store_true')
|
parser.add_argument("--mixup_switch_prob", type=float, default=0.5)
|
||||||
parser.set_defaults(global_pool=True)
|
parser.add_argument("--mixup_mode", type=str, default="batch")
|
||||||
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
|
|
||||||
help='Use class token instead of global pool for classification')
|
|
||||||
|
|
||||||
# Dataset parameters
|
# ---- Finetuning & adaptation
|
||||||
parser.add_argument('--data_path', default='./data/', type=str,
|
parser.add_argument("--finetune", default="", type=str, help="Checkpoint id/path (see model rules below)")
|
||||||
help='dataset path')
|
parser.add_argument("--task", default="", type=str, help="Task name for logging/output grouping")
|
||||||
parser.add_argument('--nb_classes', default=8, type=int,
|
parser.add_argument("--adaptation", default="finetune", choices=["finetune", "lp"],
|
||||||
help='number of the classification types')
|
help="Adaptation strategy: finetune=full fine-tune, lp=linear probe (train head only)")
|
||||||
parser.add_argument('--output_dir', default='./output_dir',
|
|
||||||
help='path where to save, empty for no saving')
|
|
||||||
parser.add_argument('--log_dir', default='./output_logs',
|
|
||||||
help='path where to tensorboard log')
|
|
||||||
parser.add_argument('--device', default='cuda',
|
|
||||||
help='device to use for training / testing')
|
|
||||||
parser.add_argument('--seed', default=0, type=int)
|
|
||||||
parser.add_argument('--resume', default='',
|
|
||||||
help='resume from checkpoint')
|
|
||||||
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
|
|
||||||
help='start epoch')
|
|
||||||
parser.add_argument('--eval', action='store_true',
|
|
||||||
help='Perform evaluation only')
|
|
||||||
parser.add_argument('--dist_eval', action='store_true', default=False,
|
|
||||||
help='Enabling distributed evaluation (recommended during training for faster monitor')
|
|
||||||
parser.add_argument('--num_workers', default=10, type=int)
|
|
||||||
parser.add_argument('--pin_mem', action='store_true',
|
|
||||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
|
||||||
parser.set_defaults(pin_mem=True)
|
|
||||||
|
|
||||||
# distributed training parameters
|
# ---- Dataset & paths
|
||||||
parser.add_argument('--world_size', default=1, type=int,
|
parser.add_argument("--data_path", default="./data/", type=str)
|
||||||
help='number of distributed processes')
|
parser.add_argument("--nb_classes", default=8, type=int)
|
||||||
parser.add_argument('--local_rank', default=-1, type=int)
|
parser.add_argument("--output_dir", default="./output_dir")
|
||||||
parser.add_argument('--dist_on_itp', action='store_true')
|
parser.add_argument("--log_dir", default="./output_logs")
|
||||||
parser.add_argument('--dist_url', default='env://',
|
|
||||||
help='url used to set up distributed training')
|
|
||||||
|
|
||||||
# fine-tuning parameters
|
# >>> NEW: training data efficiency <<<
|
||||||
parser.add_argument('--savemodel', action='store_true', default=True,
|
parser.add_argument(
|
||||||
help='Save model')
|
"--dataratio", type=str, default="1.0",
|
||||||
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
|
help=('Training data ratio(s) for subsampling in build_dataset. '
|
||||||
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
|
'Use a single float in (0,1] (e.g., 0.25) or a comma-separated list '
|
||||||
parser.add_argument('--datasets_seed', default=2026, type=int)
|
'(e.g., "1.0,0.5,0.25") if your build_dataset supports sweeps.')
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stratified", action="store_true",
|
||||||
|
help="If set, subsample training data in a class-stratified manner (requires support in build_dataset)."
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Runtime
|
||||||
|
parser.add_argument("--device", default="cuda")
|
||||||
|
parser.add_argument("--seed", default=0, type=int)
|
||||||
|
parser.add_argument("--resume", default="", help="Resume full state (optimizer, scaler, etc.)")
|
||||||
|
parser.add_argument("--start_epoch", default=0, type=int, metavar="N")
|
||||||
|
parser.add_argument("--eval", action="store_true", help="Evaluation only")
|
||||||
|
parser.add_argument("--dist_eval", action="store_true", default=False,
|
||||||
|
help="Distributed evaluation (faster monitoring during training)")
|
||||||
|
parser.add_argument("--num_workers", default=10, type=int)
|
||||||
|
parser.add_argument("--pin_mem", action="store_true"); parser.set_defaults(pin_mem=True)
|
||||||
|
|
||||||
|
# ---- Distributed
|
||||||
|
parser.add_argument("--world_size", default=1, type=int)
|
||||||
|
parser.add_argument("--local_rank", default=-1, type=int)
|
||||||
|
parser.add_argument("--dist_on_itp", action="store_true")
|
||||||
|
parser.add_argument("--dist_url", default="env://")
|
||||||
|
|
||||||
|
# ---- Misc
|
||||||
|
parser.add_argument("--savemodel", action="store_true", default=True, help="Save best model")
|
||||||
|
parser.add_argument("--norm", default="IMAGENET", type=str)
|
||||||
|
parser.add_argument("--enhance", action="store_true", default=False)
|
||||||
|
parser.add_argument("--datasets_seed", default=2026, type=int)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Main
|
||||||
|
# =========================
|
||||||
def main(args, criterion):
|
def main(args, criterion):
|
||||||
|
# ---- Optionally load args from resume (when training)
|
||||||
if args.resume and not args.eval:
|
if args.resume and not args.eval:
|
||||||
resume = args.resume
|
resume_path = args.resume
|
||||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
checkpoint = torch.load(args.resume, map_location="cpu")
|
||||||
print("Load checkpoint from: %s" % args.resume)
|
print(f"Load checkpoint (args) from: {args.resume}")
|
||||||
args = checkpoint['args']
|
args = checkpoint["args"]
|
||||||
args.resume = resume
|
args.resume = resume_path
|
||||||
|
|
||||||
|
# ---- Distributed setup
|
||||||
misc.init_distributed_mode(args)
|
misc.init_distributed_mode(args)
|
||||||
|
|
||||||
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
|
print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}")
|
||||||
print("{}".format(args).replace(', ', ',\n'))
|
print(f"{args}".replace(", ", ",\n"))
|
||||||
|
|
||||||
device = torch.device(args.device)
|
device = torch.device(args.device)
|
||||||
|
|
||||||
# fix the seed for reproducibility
|
# ---- Reproducibility
|
||||||
seed = args.seed + misc.get_rank()
|
seed = args.seed + misc.get_rank()
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
|
|
||||||
if args.model=='RETFound_mae':
|
# ---- Build model
|
||||||
|
if args.model == "RETFound_mae":
|
||||||
model = models.__dict__[args.model](
|
model = models.__dict__[args.model](
|
||||||
img_size=args.input_size,
|
img_size=args.input_size,
|
||||||
num_classes=args.nb_classes,
|
num_classes=args.nb_classes,
|
||||||
@@ -182,168 +176,210 @@ def main(args, criterion):
|
|||||||
args=args,
|
args=args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ---- Load pre-trained weights (if requested and not eval-only)
|
||||||
if args.finetune and not args.eval:
|
if args.finetune and not args.eval:
|
||||||
|
print(f"Preparing to load pre-trained weights: {args.finetune}")
|
||||||
|
|
||||||
print(f"Downloading pre-trained weights from: {args.finetune}")
|
if args.model in ["Dinov3", "Dinov2"]:
|
||||||
|
checkpoint_path = args.finetune # local path
|
||||||
|
elif args.model in ["RETFound_dinov2", "RETFound_mae"]:
|
||||||
|
print(f"Downloading pre-trained weights from Hugging Face Hub: {args.finetune}")
|
||||||
checkpoint_path = hf_hub_download(
|
checkpoint_path = hf_hub_download(
|
||||||
repo_id=f'YukunZhou/{args.finetune}',
|
repo_id=f"YukunZhou/{args.finetune}",
|
||||||
filename=f'{args.finetune}.pth',
|
filename=f"{args.finetune}.pth",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported model '{args.model}'. "
|
||||||
|
f"Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae"
|
||||||
)
|
)
|
||||||
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||||
print("Load pre-trained checkpoint from: %s" % args.finetune)
|
print(f"Loaded pre-trained checkpoint from: {checkpoint_path}")
|
||||||
|
|
||||||
if args.model!='RETFound_mae':
|
if args.model in ["Dinov3", "Dinov2"]:
|
||||||
checkpoint_model = checkpoint['teacher']
|
checkpoint_model = checkpoint
|
||||||
else:
|
elif args.model == "RETFound_dinov2":
|
||||||
checkpoint_model = checkpoint['model']
|
checkpoint_model = checkpoint["teacher"]
|
||||||
|
else: # RETFound_mae
|
||||||
|
checkpoint_model = checkpoint["model"]
|
||||||
|
|
||||||
|
# -- Key hygiene
|
||||||
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
|
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
|
||||||
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
|
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
|
||||||
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
|
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
|
||||||
|
|
||||||
|
# -- Remove classifier if shape mismatched
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
for k in ['head.weight', 'head.bias']:
|
for k in ["head.weight", "head.bias"]:
|
||||||
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
||||||
print(f"Removing key {k} from pretrained checkpoint")
|
print(f"Removing key {k} from pretrained checkpoint")
|
||||||
del checkpoint_model[k]
|
del checkpoint_model[k]
|
||||||
|
|
||||||
# interpolate position embedding
|
# -- Interpolate pos embed (ViT)
|
||||||
interpolate_pos_embed(model, checkpoint_model)
|
interpolate_pos_embed(model, checkpoint_model)
|
||||||
|
|
||||||
# load pre-trained model
|
# -- Load backbone weights (non-strict)
|
||||||
msg = model.load_state_dict(checkpoint_model, strict=False)
|
_ = model.load_state_dict(checkpoint_model, strict=False)
|
||||||
|
|
||||||
|
# -- Re-init head
|
||||||
|
if hasattr(model, "head") and hasattr(model.head, "weight"):
|
||||||
trunc_normal_(model.head.weight, std=2e-5)
|
trunc_normal_(model.head.weight, std=2e-5)
|
||||||
|
|
||||||
dataset_train = build_dataset(is_train='train', args=args)
|
# ---- Datasets & samplers
|
||||||
dataset_val = build_dataset(is_train='val', args=args)
|
dataset_train = build_dataset(is_train="train", args=args)
|
||||||
dataset_test = build_dataset(is_train='test', args=args)
|
dataset_val = build_dataset(is_train="val", args=args)
|
||||||
|
dataset_test = build_dataset(is_train="test", args=args)
|
||||||
|
|
||||||
|
|
||||||
if True: # args.distributed:
|
|
||||||
num_tasks = misc.get_world_size()
|
num_tasks = misc.get_world_size()
|
||||||
global_rank = misc.get_rank()
|
global_rank = misc.get_rank()
|
||||||
|
|
||||||
if not args.eval:
|
if not args.eval:
|
||||||
sampler_train = torch.utils.data.DistributedSampler(
|
sampler_train = torch.utils.data.DistributedSampler(
|
||||||
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||||
)
|
)
|
||||||
print("Sampler_train = %s" % str(sampler_train))
|
print(f"Sampler_train = {sampler_train}")
|
||||||
if args.dist_eval:
|
if args.dist_eval:
|
||||||
if len(dataset_val) % num_tasks != 0:
|
if len(dataset_val) % num_tasks != 0:
|
||||||
print(
|
print("Warning: dist eval with dataset not divisible by #procs; results may differ slightly.")
|
||||||
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
|
||||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
|
||||||
'equal num of samples per-process.')
|
|
||||||
sampler_val = torch.utils.data.DistributedSampler(
|
sampler_val = torch.utils.data.DistributedSampler(
|
||||||
dataset_val, num_replicas=num_tasks, rank=global_rank,
|
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||||
shuffle=True) # shuffle=True to reduce monitor bias
|
)
|
||||||
else:
|
else:
|
||||||
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
||||||
|
|
||||||
if args.dist_eval:
|
if args.dist_eval:
|
||||||
if len(dataset_test) % num_tasks != 0:
|
if len(dataset_test) % num_tasks != 0:
|
||||||
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
print("Warning: dist eval test set not divisible by #procs; results may differ slightly.")
|
||||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
|
||||||
'equal num of samples per-process.')
|
|
||||||
sampler_test = torch.utils.data.DistributedSampler(
|
sampler_test = torch.utils.data.DistributedSampler(
|
||||||
dataset_test, num_replicas=num_tasks, rank=global_rank,
|
dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||||
shuffle=True) # shuffle=True to reduce monitor bias
|
)
|
||||||
else:
|
else:
|
||||||
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
|
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
|
||||||
|
|
||||||
|
# ---- Logging
|
||||||
if global_rank == 0 and args.log_dir is not None and not args.eval:
|
if global_rank == 0 and args.log_dir is not None and not args.eval:
|
||||||
os.makedirs(args.log_dir, exist_ok=True)
|
os.makedirs(args.log_dir, exist_ok=True)
|
||||||
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
|
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
|
||||||
else:
|
else:
|
||||||
log_writer = None
|
log_writer = None
|
||||||
|
|
||||||
|
# ---- DataLoaders
|
||||||
if not args.eval:
|
if not args.eval:
|
||||||
data_loader_train = torch.utils.data.DataLoader(
|
data_loader_train = torch.utils.data.DataLoader(
|
||||||
dataset_train, sampler=sampler_train,
|
dataset_train, sampler=sampler_train,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size, num_workers=args.num_workers,
|
||||||
num_workers=args.num_workers,
|
pin_memory=args.pin_mem, drop_last=True,
|
||||||
pin_memory=args.pin_mem,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
)
|
||||||
|
print(f"len of train_set: {len(data_loader_train) * args.batch_size}")
|
||||||
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
|
|
||||||
|
|
||||||
data_loader_val = torch.utils.data.DataLoader(
|
data_loader_val = torch.utils.data.DataLoader(
|
||||||
dataset_val, sampler=sampler_val,
|
dataset_val, sampler=sampler_val,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size, num_workers=args.num_workers,
|
||||||
num_workers=args.num_workers,
|
pin_memory=args.pin_mem, drop_last=False,
|
||||||
pin_memory=args.pin_mem,
|
|
||||||
drop_last=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
data_loader_test = torch.utils.data.DataLoader(
|
data_loader_test = torch.utils.data.DataLoader(
|
||||||
dataset_test, sampler=sampler_test,
|
dataset_test, sampler=sampler_test,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size, num_workers=args.num_workers,
|
||||||
num_workers=args.num_workers,
|
pin_memory=args.pin_mem, drop_last=False,
|
||||||
pin_memory=args.pin_mem,
|
|
||||||
drop_last=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ---- Mixup/CutMix
|
||||||
mixup_fn = None
|
mixup_fn = None
|
||||||
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
|
mixup_active = (args.mixup > 0) or (args.cutmix > 0.) or (args.cutmix_minmax is not None)
|
||||||
if mixup_active:
|
if mixup_active:
|
||||||
print("Mixup is activated!")
|
print("Mixup is activated!")
|
||||||
mixup_fn = Mixup(
|
mixup_fn = Mixup(
|
||||||
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
|
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
|
||||||
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
|
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
|
||||||
label_smoothing=args.smoothing, num_classes=args.nb_classes)
|
label_smoothing=args.smoothing, num_classes=args.nb_classes
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Eval-only: resume weights
|
||||||
if args.resume and args.eval:
|
if args.resume and args.eval:
|
||||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
checkpoint = torch.load(args.resume, map_location="cpu")
|
||||||
print("Load checkpoint from: %s" % args.resume)
|
print(f"Load checkpoint for eval from: {args.resume}")
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint["model"])
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model_without_ddp = model
|
model_without_ddp = model
|
||||||
|
|
||||||
|
# ---- Adaptation toggle
|
||||||
|
if args.adaptation == "lp":
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
param.requires_grad = ("head" in name)
|
||||||
|
print("[Adaptation] Linear probe: training classifier head only.")
|
||||||
|
else:
|
||||||
|
print("[Adaptation] Full fine-tuning: training all parameters.")
|
||||||
|
|
||||||
|
# ---- Count trainable params
|
||||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
print('number of model params (M): %.2f' % (n_parameters / 1.e6))
|
print(f"number of trainable params (M): {n_parameters / 1.e6:.2f}")
|
||||||
|
|
||||||
|
# ---- LR scaling by effective batch size
|
||||||
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
|
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
|
||||||
|
if args.lr is None:
|
||||||
if args.lr is None: # only base_lr is specified
|
|
||||||
args.lr = args.blr * eff_batch_size / 256
|
args.lr = args.blr * eff_batch_size / 256
|
||||||
|
print(f"base lr: {args.lr * 256 / eff_batch_size:.2e}")
|
||||||
|
print(f"actual lr: {args.lr:.2e}")
|
||||||
|
print(f"accumulate grad iterations: {args.accum_iter}")
|
||||||
|
print(f"effective batch size: {eff_batch_size}")
|
||||||
|
|
||||||
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
|
# ---- DDP (if available)
|
||||||
print("actual lr: %.2e" % args.lr)
|
if args.distributed and torch.cuda.device_count() > 1:
|
||||||
|
ddp_kwargs = {}
|
||||||
print("accumulate grad iterations: %d" % args.accum_iter)
|
if args.adaptation == "lp":
|
||||||
print("effective batch size: %d" % eff_batch_size)
|
ddp_kwargs["find_unused_parameters"] = True
|
||||||
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
if args.distributed:
|
model, device_ids=[args.gpu], **ddp_kwargs
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
|
||||||
model_without_ddp = model.module
|
|
||||||
|
|
||||||
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
|
|
||||||
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
|
|
||||||
no_weight_decay_list=no_weight_decay,
|
|
||||||
layer_decay=args.layer_decay
|
|
||||||
)
|
)
|
||||||
|
model_without_ddp = model.module
|
||||||
|
else:
|
||||||
|
model_without_ddp = model # single-GPU
|
||||||
|
|
||||||
|
# ---- Optimizer param groups (after freezing)
|
||||||
|
no_weight_decay = (model_without_ddp.no_weight_decay()
|
||||||
|
if hasattr(model_without_ddp, "no_weight_decay") else [])
|
||||||
|
|
||||||
|
|
||||||
|
param_groups = lrd.param_groups_lrd(
|
||||||
|
model_without_ddp,
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
no_weight_decay_list=no_weight_decay,
|
||||||
|
layer_decay=args.layer_decay,
|
||||||
|
)
|
||||||
|
for g in param_groups:
|
||||||
|
g["params"] = [p for p in g["params"] if p.requires_grad]
|
||||||
|
|
||||||
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
|
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
|
||||||
loss_scaler = NativeScaler()
|
loss_scaler = NativeScaler()
|
||||||
|
print(f"criterion = {criterion}")
|
||||||
|
|
||||||
print("criterion = %s" % str(criterion))
|
# ---- Load previous full state (optimizer, scaler, etc.)
|
||||||
|
misc.load_model(args=args, model_without_ddp=model_without_ddp,
|
||||||
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
|
optimizer=optimizer, loss_scaler=loss_scaler)
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Eval-only Short Circuit
|
||||||
|
# =========================
|
||||||
if args.eval:
|
if args.eval:
|
||||||
if 'epoch' in checkpoint:
|
if "checkpoint" in locals() and isinstance(checkpoint, dict) and ("epoch" in checkpoint):
|
||||||
print("Test with the best model at epoch = %d" % checkpoint['epoch'])
|
print(f"Test with the best model at epoch = {checkpoint['epoch']}")
|
||||||
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
|
test_stats, auc_roc = evaluate(
|
||||||
num_class=args.nb_classes, log_writer=log_writer)
|
data_loader_test, model, device, args, epoch=0, mode="test",
|
||||||
exit(0)
|
num_class=args.nb_classes, log_writer=log_writer
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Train Loop
|
||||||
|
# =========================
|
||||||
print(f"Start training for {args.epochs} epochs")
|
print(f"Start training for {args.epochs} epochs")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
max_score = 0.0
|
max_score = 0.0
|
||||||
best_epoch = 0
|
best_epoch = 0
|
||||||
|
|
||||||
for epoch in range(args.start_epoch, args.epochs):
|
for epoch in range(args.start_epoch, args.epochs):
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
data_loader_train.sampler.set_epoch(epoch)
|
data_loader_train.sampler.set_epoch(epoch)
|
||||||
@@ -352,49 +388,55 @@ def main(args, criterion):
|
|||||||
model, criterion, data_loader_train,
|
model, criterion, data_loader_train,
|
||||||
optimizer, device, epoch, loss_scaler,
|
optimizer, device, epoch, loss_scaler,
|
||||||
args.clip_grad, mixup_fn,
|
args.clip_grad, mixup_fn,
|
||||||
log_writer=log_writer,
|
log_writer=log_writer, args=args
|
||||||
args=args
|
)
|
||||||
|
|
||||||
|
val_stats, val_score = evaluate(
|
||||||
|
data_loader_val, model, device, args, epoch, mode="val",
|
||||||
|
num_class=args.nb_classes, log_writer=log_writer
|
||||||
)
|
)
|
||||||
|
|
||||||
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
|
|
||||||
num_class=args.nb_classes, log_writer=log_writer)
|
|
||||||
if max_score < val_score:
|
if max_score < val_score:
|
||||||
max_score = val_score
|
max_score = val_score
|
||||||
best_epoch = epoch
|
best_epoch = epoch
|
||||||
if args.output_dir and args.savemodel:
|
if args.output_dir and args.savemodel:
|
||||||
misc.save_model(
|
misc.save_model(
|
||||||
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
args=args, model=model, model_without_ddp=model_without_ddp,
|
||||||
loss_scaler=loss_scaler, epoch=epoch, mode='best')
|
optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode="best"
|
||||||
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
|
)
|
||||||
|
print(f"Best epoch = {best_epoch}, Best score = {max_score:.4f}")
|
||||||
|
|
||||||
if epoch == (args.epochs - 1):
|
|
||||||
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
|
|
||||||
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
|
|
||||||
model.to(device)
|
|
||||||
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
|
|
||||||
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
|
|
||||||
num_class=args.nb_classes, log_writer=None)
|
|
||||||
|
|
||||||
if log_writer is not None:
|
if log_writer is not None:
|
||||||
log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
|
log_writer.add_scalar("loss/val", val_stats["loss"], epoch)
|
||||||
|
log_writer.flush()
|
||||||
|
|
||||||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
log_stats = {**{f"train_{k}": v for k, v in train_stats.items()},
|
||||||
'epoch': epoch,
|
"epoch": epoch,
|
||||||
'n_parameters': n_parameters}
|
"n_parameters": n_parameters}
|
||||||
|
|
||||||
if args.output_dir and misc.is_main_process():
|
if args.output_dir and misc.is_main_process():
|
||||||
if log_writer is not None:
|
with open(os.path.join(args.output_dir, args.task, "log.txt"), "a", encoding="utf-8") as f:
|
||||||
log_writer.flush()
|
|
||||||
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(log_stats) + "\n")
|
f.write(json.dumps(log_stats) + "\n")
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Final Test (Best Ckpt)
|
||||||
|
# =========================
|
||||||
|
ckpt_path = os.path.join(args.output_dir, args.task, "checkpoint-best.pth")
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
|
||||||
|
model.to(device)
|
||||||
|
print(f"Test with the best model, epoch = {checkpoint.get('epoch', -1)}:")
|
||||||
|
_test_stats, _auc_roc = evaluate(
|
||||||
|
data_loader_test, model, device, args, -1, mode="test",
|
||||||
|
num_class=args.nb_classes, log_writer=None
|
||||||
|
)
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
print('Training time {}'.format(total_time_str))
|
print(f"Training time {total_time_str}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
args = get_args_parser()
|
args = get_args_parser()
|
||||||
args = args.parse_args()
|
args = args.parse_args()
|
||||||
|
|
||||||
@@ -402,6 +444,5 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
if args.output_dir:
|
if args.output_dir:
|
||||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
main(args, criterion)
|
main(args, criterion)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+41
-5
@@ -1,7 +1,3 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
# Partly revised by YZ @UCL&Moorfields
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@@ -10,7 +6,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from timm.models.layers import trunc_normal_
|
||||||
|
|
||||||
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
|
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
|
||||||
""" Vision Transformer with support for global average pooling
|
""" Vision Transformer with support for global average pooling
|
||||||
@@ -56,6 +52,30 @@ def RETFound_mae(**kwargs):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def Dinov2(args, **kwargs):
|
||||||
|
|
||||||
|
if args.model_arch == 'dinov2_vits14':
|
||||||
|
arch = 'vit_small_patch14_dinov2.lvd142m'
|
||||||
|
elif args.model_arch == 'dinov2_vitb14':
|
||||||
|
arch = 'vit_base_patch14_dinov2.lvd142m'
|
||||||
|
elif args.model_arch == 'dinov2_vitl14':
|
||||||
|
arch = 'vit_large_patch14_dinov2.lvd142m'
|
||||||
|
elif args.model_arch == 'dinov2_vitg14':
|
||||||
|
arch = 'vit_giant_patch14_dinov2.lvd142m'
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
|
||||||
|
f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
|
||||||
|
|
||||||
|
model = timm.create_model(
|
||||||
|
arch,
|
||||||
|
pretrained=True,
|
||||||
|
img_size=224,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def RETFound_dinov2(args, **kwargs):
|
def RETFound_dinov2(args, **kwargs):
|
||||||
model = timm.create_model(
|
model = timm.create_model(
|
||||||
'vit_large_patch14_dinov2.lvd142m',
|
'vit_large_patch14_dinov2.lvd142m',
|
||||||
@@ -66,4 +86,20 @@ def RETFound_dinov2(args, **kwargs):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def Dinov3(args, **kwargs):
|
||||||
|
# Load ViT-L/16 backbone (hub model has `head = Identity` by default)
|
||||||
|
model = torch.hub.load(
|
||||||
|
repo_or_dir="facebookresearch/dinov3",
|
||||||
|
model=args.model_arch,
|
||||||
|
pretrained=False, # main() will load your checkpoint
|
||||||
|
trust_repo=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Figure out feature dimension for the probe
|
||||||
|
feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
|
||||||
|
model.head = nn.Linear(feat_dim, args.nb_classes)
|
||||||
|
trunc_normal_(model.head.weight, std=2e-5)
|
||||||
|
if model.head.bias is not None:
|
||||||
|
nn.init.zeros_(model.head.bias)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|||||||
@@ -0,0 +1,27 @@
|
|||||||
|
# ==== Model settings ====
|
||||||
|
ADAPTATION="finetune"
|
||||||
|
MODEL="Dinov2"
|
||||||
|
MODEL_ARCH="dinov2_vitl14"
|
||||||
|
FINETUNE="dinov2_vitl14_pretrain.pth"
|
||||||
|
|
||||||
|
# ==== Data settings ====
|
||||||
|
DATASET="MESSIDOR2"
|
||||||
|
NUM_CLASS=5
|
||||||
|
|
||||||
|
data_path="/home/jupyter/public_dataset/${DATASET}"
|
||||||
|
task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=1 torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
|
||||||
|
--model "${MODEL}" \
|
||||||
|
--model_arch "${MODEL_ARCH}" \
|
||||||
|
--finetune "${FINETUNE}" \
|
||||||
|
--savemodel \
|
||||||
|
--global_pool \
|
||||||
|
--batch_size 24 \
|
||||||
|
--world_size 1 \
|
||||||
|
--epochs 50 \
|
||||||
|
--nb_classes "${NUM_CLASS}" \
|
||||||
|
--data_path "${data_path}" \
|
||||||
|
--input_size 224 \
|
||||||
|
--task "${task}" \
|
||||||
|
--adaptation "${ADAPTATION}"
|
||||||
+49
-21
@@ -1,29 +1,39 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
# Partly revised by YZ @UCL&Moorfields
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Subset
|
||||||
from torchvision import datasets, transforms
|
from torchvision import datasets, transforms
|
||||||
from timm.data import create_transform
|
from timm.data import create_transform
|
||||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
|
||||||
|
|
||||||
def build_dataset(is_train, args):
|
def build_dataset(is_train, args):
|
||||||
transform = build_transform(is_train, args)
|
transform = build_transform(is_train, args)
|
||||||
root = os.path.join(args.data_path, is_train)
|
root = os.path.join(args.data_path, is_train)
|
||||||
dataset = datasets.ImageFolder(root, transform=transform)
|
dataset = datasets.ImageFolder(root, transform=transform)
|
||||||
|
|
||||||
return dataset
|
if is_train == 'train':
|
||||||
|
ratio = float(getattr(args, "dataratio", 1.0))
|
||||||
|
seed = int(getattr(args, "seed", 0))
|
||||||
|
stratified = bool(getattr(args, "stratified", False))
|
||||||
|
|
||||||
|
if 0.0 < ratio < 1.0:
|
||||||
|
if stratified:
|
||||||
|
idx = _stratified_indices(dataset.targets, ratio, seed)
|
||||||
|
else:
|
||||||
|
# simple uniform subsample with torch.Generator for reproducibility
|
||||||
|
g = torch.Generator().manual_seed(seed)
|
||||||
|
n = len(dataset)
|
||||||
|
k = max(1, int(n * ratio))
|
||||||
|
idx = torch.randperm(n, generator=g)[:k].tolist()
|
||||||
|
dataset = Subset(dataset, idx)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
def build_transform(is_train, args):
|
def build_transform(is_train, args):
|
||||||
mean = IMAGENET_DEFAULT_MEAN
|
mean = IMAGENET_DEFAULT_MEAN
|
||||||
std = IMAGENET_DEFAULT_STD
|
std = IMAGENET_DEFAULT_STD
|
||||||
# train transform
|
|
||||||
if is_train == 'train':
|
if is_train == 'train':
|
||||||
# this should always dispatch to transforms_imagenet_train
|
return create_transform(
|
||||||
transform = create_transform(
|
|
||||||
input_size=args.input_size,
|
input_size=args.input_size,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
color_jitter=args.color_jitter,
|
color_jitter=args.color_jitter,
|
||||||
@@ -35,19 +45,37 @@ def build_transform(is_train, args):
|
|||||||
mean=mean,
|
mean=mean,
|
||||||
std=std,
|
std=std,
|
||||||
)
|
)
|
||||||
return transform
|
|
||||||
|
|
||||||
# eval transform
|
# eval transform
|
||||||
t = []
|
crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
|
||||||
if args.input_size <= 224:
|
|
||||||
crop_pct = 224 / 256
|
|
||||||
else:
|
|
||||||
crop_pct = 1.0
|
|
||||||
size = int(args.input_size / crop_pct)
|
size = int(args.input_size / crop_pct)
|
||||||
t.append(
|
t = [
|
||||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
|
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
|
||||||
)
|
transforms.CenterCrop(args.input_size),
|
||||||
t.append(transforms.CenterCrop(args.input_size))
|
transforms.ToTensor(),
|
||||||
t.append(transforms.ToTensor())
|
transforms.Normalize(mean, std),
|
||||||
t.append(transforms.Normalize(mean, std))
|
]
|
||||||
return transforms.Compose(t)
|
return transforms.Compose(t)
|
||||||
|
|
||||||
|
# ---- helpers ----
|
||||||
|
|
||||||
|
def _stratified_indices(targets, ratio: float, seed: int):
|
||||||
|
"""Maintain class proportions. Ensures at least 1 sample per class when possible."""
|
||||||
|
t = torch.as_tensor(targets)
|
||||||
|
classes = torch.unique(t)
|
||||||
|
g = torch.Generator().manual_seed(seed)
|
||||||
|
|
||||||
|
keep = []
|
||||||
|
for c in classes.tolist():
|
||||||
|
cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1)
|
||||||
|
if len(cls_idx) == 0:
|
||||||
|
continue
|
||||||
|
k = max(1, int(round(len(cls_idx) * ratio)))
|
||||||
|
sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]]
|
||||||
|
keep.extend(sel.tolist())
|
||||||
|
|
||||||
|
# shuffle final indices (stable across seed)
|
||||||
|
g2 = torch.Generator().manual_seed(seed + 1)
|
||||||
|
keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist()
|
||||||
|
return keep
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user