Compare commits
10 Commits
91915d6a14...ae9a9ecf37
| Author | SHA1 | Date |
|---|---|---|
| | ae9a9ecf37 | |
| | 8f5b2ce5e7 | |
| | dbbddb8936 | |
| | ed8b469a0f | |
| | 17768be893 | |
| | bda7a6c69f | |
| | 7489af0620 | |
| | 409f7b6167 | |
| | 897d71c8c9 | |
| | a7a9b3a8b7 | |
@@ -1,14 +1,15 @@
|
|||||||
## RETFound - A foundation model for retinal imaging
|
## RETFound - A foundation model for retinal images
|
||||||
|
|
||||||
|
|
||||||
Official repo including a series of retinal foundation models.<br>
|
Official repo including a series of foundation models and applications for retinal images.<br>
|
||||||
[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae).<br>
|
`[RETFound-MAE]`:[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x).<br>
|
||||||
[New checkpoints](https://huggingface.co/YukunZhou), some of which are based on [DINOV2](https://github.com/facebookresearch/dinov2):
|
`[RETFound-DINOv2]`:[Revealing the Impact of Pre-training Data on Medical Foundation Models](https://www.researchsquare.com/article/rs-6080254/v1).<br>
|
||||||
|
`[DINOv2]`:[General-purpose vision foundation models DINOv2 by Meta](https://github.com/facebookresearch/dinov2).<br>
|
||||||
|
`[DINOv3]`:[General-purpose vision foundation models DINOv3 by Meta](https://github.com/facebookresearch/dinov3).<br>
|
||||||
|
|
||||||
|
|
||||||
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
|
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
|
||||||
|
|
||||||
Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE)
|
|
||||||
|
|
||||||
|
|
||||||
### 📝Key features
|
### 📝Key features
|
||||||
|
|
||||||
@@ -19,13 +20,14 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
|
|||||||
|
|
||||||
### 🎉News
|
### 🎉News
|
||||||
|
|
||||||
|
- 🐉2025/09: **Preprint benchmarking DINOv3, DINOv2, and RETFound is [available](https://arxiv.org/abs/2509.03421)!**
|
||||||
|
- 🐉2025/09: **We added the state-of-the-art DINOv3 to the fine-tuning pipeline for retinal applications!**
|
||||||
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
|
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
|
||||||
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOv2-based, are added!**
|
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOv2-based, are added!**
|
||||||
- 🐉2025/02: **We updated the package versions, such as CUDA 12+ and PyTorch 2.3+!**
|
- 🐉2025/02: **We updated the package versions, such as CUDA 12+ and PyTorch 2.3+!**
|
||||||
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) is now online!
|
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) is now online!
|
||||||
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
|
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
|
||||||
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
|
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
|
||||||
- 2023/10: change the hyperparameter of [input_size](https://github.com/rmaphoh/RETFound_MAE#:~:text=finetune%20./RETFound_cfp_weights.pth%20%5C-,%2D%2Dinput_size%20224,-For%20evaluation%20only) for any image size
|
|
||||||
|
|
||||||
|
|
||||||
### 🔧Install environment
|
### 🔧Install environment
|
||||||
@@ -40,17 +42,17 @@ conda activate retfound
|
|||||||
2. Install dependencies
|
2. Install dependencies
|
||||||
|
|
||||||
```
|
```
|
||||||
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
|
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121
|
||||||
git clone https://github.com/rmaphoh/RETFound_MAE/
|
git clone https://github.com/rmaphoh/RETFound/
|
||||||
cd RETFound_MAE
|
cd RETFound
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
pip install ipykernel
|
||||||
|
python -m ipykernel install --user --name retfound --display-name "Python (retfound)"
|
||||||
```
|
```
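
After installing, it can help to confirm that the pinned versions were actually picked up inside the `retfound` environment. The following is a minimal sketch using only the standard library. The pins are copied from the `pip install` line above; the helper name `find_mismatches` and its injectable `lookup` parameter are illustrative assumptions and not part of the repository.

```python
from importlib import metadata

# Pins copied from the install command above.
PINS = {"torch": "2.5.1", "torchvision": "0.20.1"}


def find_mismatches(pins, lookup=metadata.version):
    """Return {package: (expected, found)} for every pin that is not satisfied.

    `found` is None when the package is not installed. Local build tags such
    as "+cu121" are ignored when comparing versions.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            found = None
        if found is None or found.split("+")[0] != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(PINS)
    for name, (expected, found) in problems.items():
        print(f"{name}: expected {expected}, found {found}")
    if not problems:
        print("All pinned packages match.")
```

Running it with the environment's own interpreter (the one registered by `ipykernel`) checks the same packages the notebook kernel will see.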
|
||||||
|
|
||||||
|
|
||||||
### 🌱Fine-tuning with RETFound weights
|
### 🌱Fine-tuning with RETFound weights
|
||||||
|
|
||||||
To fine tune RETFound on your own data, follow these steps:
|
|
||||||
|
|
||||||
1. Get access to the pre-trained models on HuggingFace (register an account and fill in the form) and go to step 2:
|
1. Get access to the pre-trained models on HuggingFace (register an account and fill in the form) and go to step 2:
|
||||||
<table><tbody>
|
<table><tbody>
|
||||||
<!-- START TABLE -->
|
<!-- START TABLE -->
|
||||||
@@ -71,22 +73,22 @@ To fine tune RETFound on your own data, follow these steps:
|
|||||||
<!-- TABLE BODY -->
|
<!-- TABLE BODY -->
|
||||||
<tr><td align="left">RETFound_mae_meh</td>
|
<tr><td align="left">RETFound_mae_meh</td>
|
||||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
|
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
|
||||||
<td align="center">TBD</a></td>
|
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||||
</tr>
|
</tr>
|
||||||
<!-- TABLE BODY -->
|
<!-- TABLE BODY -->
|
||||||
<tr><td align="left">RETFound_mae_shanghai</td>
|
<tr><td align="left">RETFound_mae_shanghai</td>
|
||||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
|
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
|
||||||
<td align="center">TBD</a></td>
|
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||||
</tr>
|
</tr>
|
||||||
<!-- TABLE BODY -->
|
<!-- TABLE BODY -->
|
||||||
<tr><td align="left">RETFound_dinov2_meh</td>
|
<tr><td align="left">RETFound_dinov2_meh</td>
|
||||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
|
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
|
||||||
<td align="center">TBD</a></td>
|
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||||
</tr>
|
</tr>
|
||||||
<!-- TABLE BODY -->
|
<!-- TABLE BODY -->
|
||||||
<tr><td align="left">RETFound_dinov2_shanghai</td>
|
<tr><td align="left">RETFound_dinov2_shanghai</td>
|
||||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">access</a></td>
|
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">access</a></td>
|
||||||
<td align="center">TBD</a></td>
|
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||||
</tr>
|
</tr>
|
||||||
</tbody></table>
|
</tbody></table>
|
||||||
|
|
||||||
@@ -100,7 +102,9 @@ huggingface-cli login --token YOUR_HUGGINGFACE_TOKEN
|
|||||||
export HF_ENDPOINT=https://hf-mirror.com
|
export HF_ENDPOINT=https://hf-mirror.com
|
||||||
```
|
```
|
||||||
|
|
||||||
3. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md))
|
3. To fine-tune [DINOv2](https://github.com/facebookresearch/dinov2) or [DINOv3](https://github.com/facebookresearch/dinov3), download the model weights from their GitHub repositories and place them in the RETFound folder.
|
||||||
|
|
||||||
|
4. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md))
|
||||||
|
|
||||||
```
|
```
|
||||||
├── data folder
|
├── data folder
|
||||||
@@ -118,56 +122,122 @@ export HF_ENDPOINT=https://hf-mirror.com
|
|||||||
├──class_c
|
├──class_c
|
||||||
```
|
```
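
Before launching training, the folder layout can be sanity-checked with a short script. This is a sketch under an assumption: it treats every sub-folder of the data folder as a split and every sub-folder of a split as a class. The tree above is partly elided in this diff, so split names are deliberately not hard-coded. `summarise_layout`, `inconsistent_splits`, and the list of image suffixes are hypothetical and not part of the repository.

```python
from pathlib import Path

# Assumed image formats; extend this set if your dataset uses others.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}


def summarise_layout(data_root):
    """Return {split: {class_name: image_count}} for a split/class/image tree."""
    summary = {}
    for split_dir in sorted(p for p in Path(data_root).iterdir() if p.is_dir()):
        counts = {}
        for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
        summary[split_dir.name] = counts
    return summary


def inconsistent_splits(summary):
    """Return the splits whose class folders differ from the first split's."""
    splits = list(summary)
    if not splits:
        return []
    reference = set(summary[splits[0]])
    return [s for s in splits[1:] if set(summary[s]) != reference]
```

A split that is missing a class folder, or a class with zero images, is usually worth fixing before a long training run.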
|
||||||
|
|
||||||
4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training.
|
|
||||||
|
|
||||||
The model and finetune can be selected:
|
|
||||||
|
|
||||||
| model | finetune |
|
5. Start fine-tuning by running `sh train.sh`.
|
||||||
|-----------------|--------------------------|
|
|
||||||
| RETFound_mae | RETFound_mae_natureCFP |
|
|
||||||
| RETFound_mae | RETFound_mae_natureOCT |
|
In `train.sh`, the model can be selected by changing the hyperparameters `MODEL`, `MODEL_ARCH`, `FINETUNE`:
|
||||||
| RETFound_mae | RETFound_mae_meh |
|
|
||||||
| RETFound_mae | RETFound_mae_shanghai |
|
**RETFound**:
|
||||||
| RETFound_dinov2 | RETFound_dinov2_meh |
|
|
||||||
| RETFound_dinov2 | RETFound_dinov2_shanghai |
|
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|
||||||
|
|-----------------|--------------------------|--------------------------|--------------------------|
|
||||||
|
| RETFound_mae | retfound_mae | RETFound_mae_natureCFP | ~300M |
|
||||||
|
| RETFound_mae | retfound_mae | RETFound_mae_natureOCT | ~300M |
|
||||||
|
| RETFound_mae | retfound_mae | RETFound_mae_meh | ~300M |
|
||||||
|
| RETFound_mae | retfound_mae | RETFound_mae_shanghai | ~300M |
|
||||||
|
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_meh | ~300M |
|
||||||
|
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_shanghai | ~300M |
|
||||||
|
|
||||||
|
|
||||||
|
**DINOv3**:
|
||||||
|
|
||||||
|
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|
||||||
|
|-----------------|--------------------------|----------------------------------|--------------------------|
|
||||||
|
| Dinov3 | dinov3_vits16 | dinov3_vits16_pretrain.pth | ~21M |
|
||||||
|
| Dinov3 | dinov3_vits16plus | dinov3_vits16plus_pretrain.pth | ~29M |
|
||||||
|
| Dinov3 | dinov3_vitb16 | dinov3_vitb16_pretrain.pth | ~86M |
|
||||||
|
| Dinov3 | dinov3_vitl16 | dinov3_vitl16_pretrain.pth | ~300M |
|
||||||
|
| Dinov3 | dinov3_vith16plus | dinov3_vith16plus_pretrain.pth | ~840M |
|
||||||
|
| Dinov3 | dinov3_vit7b16 | dinov3_vit7b16_pretrain.pth | ~6.7B |
|
||||||
|
|
||||||
|
|
||||||
|
**DINOv2**:
|
||||||
|
|
||||||
|
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|
||||||
|
|-----------------|--------------------------|------------------------------|--------------------------|
|
||||||
|
| Dinov2 | dinov2_vits14 | dinov2_vits14_pretrain.pth | ~21M |
|
||||||
|
| Dinov2 | dinov2_vitb14 | dinov2_vitb14_pretrain.pth | ~86M |
|
||||||
|
| Dinov2 | dinov2_vitl14 | dinov2_vitl14_pretrain.pth | ~300M |
|
||||||
|
| Dinov2 | dinov2_vitg14 | dinov2_vitg14_pretrain.pth | ~1.1B |
|
||||||
|
|
||||||
|
|
||||||
|
Change the DATA_PATH to your dataset directory.
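
`MODEL`, `MODEL_ARCH`, and `FINETUNE` have to be changed together. One way to keep them consistent is a small lookup built from the tables above. This is only a sketch: the dictionary layout and the `resolve` helper are illustrative assumptions, and only a few rows are copied here.

```python
# FINETUNE value -> (MODEL, MODEL_ARCH), copied from a few rows of the tables above.
CHECKPOINTS = {
    "RETFound_mae_natureCFP": ("RETFound_mae", "retfound_mae"),
    "RETFound_mae_meh": ("RETFound_mae", "retfound_mae"),
    "RETFound_dinov2_meh": ("RETFound_dinov2", "retfound_dinov2"),
    "dinov3_vitl16_pretrain.pth": ("Dinov3", "dinov3_vitl16"),
    "dinov2_vitl14_pretrain.pth": ("Dinov2", "dinov2_vitl14"),
}


def resolve(finetune):
    """Return the shell assignments that match one FINETUNE entry."""
    try:
        model, arch = CHECKPOINTS[finetune]
    except KeyError:
        raise ValueError(f"unknown FINETUNE value: {finetune!r}") from None
    return [f'MODEL="{model}"', f'MODEL_ARCH="{arch}"', f'FINETUNE="{finetune}"']


if __name__ == "__main__":
    # Prints three lines that can be pasted into the model settings of train.sh.
    print("\n".join(resolve("RETFound_dinov2_meh")))
```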
|
||||||
|
|
||||||
```
|
```
|
||||||
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
# ==== Model settings ====
|
||||||
--model RETFound_mae \
|
# adaptation {finetune,lp}
|
||||||
--savemodel \
|
ADAPTATION="finetune"
|
||||||
--global_pool \
|
MODEL="RETFound_dinov2"
|
||||||
--batch_size 16 \
|
MODEL_ARCH="retfound_dinov2"
|
||||||
--world_size 1 \
|
FINETUNE="RETFound_dinov2_meh"
|
||||||
--epochs 100 \
|
|
||||||
--blr 5e-3 --layer_decay 0.65 \
|
# ==== Data settings ====
|
||||||
--weight_decay 0.05 --drop_path 0.2 \
|
# change the dataset name and corresponding class number
|
||||||
--nb_classes 5 \
|
DATASET="MESSIDOR2"
|
||||||
--data_path ./IDRiD \
|
NUM_CLASS=5
|
||||||
--input_size 224 \
|
|
||||||
--task RETFound_mae_meh-IDRiD \
|
# =======================
|
||||||
--finetune RETFound_mae_meh
|
DATA_PATH="PATH TO THE DATASET"
|
||||||
|
TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
|
||||||
|
|
||||||
|
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
|
||||||
|
--model "${MODEL}" \
|
||||||
|
--model_arch "${MODEL_ARCH}" \
|
||||||
|
--finetune "${FINETUNE}" \
|
||||||
|
--savemodel \
|
||||||
|
--global_pool \
|
||||||
|
--batch_size 24 \
|
||||||
|
--world_size 1 \
|
||||||
|
--epochs 50 \
|
||||||
|
--nb_classes "${NUM_CLASS}" \
|
||||||
|
--data_path "${DATA_PATH}" \
|
||||||
|
--input_size 224 \
|
||||||
|
--task "${TASK}" \
|
||||||
|
--adaptation "${ADAPTATION}"
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
4. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the path below)
|
|
||||||
|
6. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the DATA_PATH below)
|
||||||
|
|
||||||
|
|
||||||
```
|
```
|
||||||
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
# ==== Model/settings (match training) ====
|
||||||
--model RETFound_mae \
|
ADAPTATION="finetune"
|
||||||
--savemodel \
|
MODEL="RETFound_dinov2"
|
||||||
--eval \
|
MODEL_ARCH="retfound_dinov2"
|
||||||
--global_pool \
|
FINETUNE="RETFound_dinov2_meh"
|
||||||
--batch_size 16 \
|
|
||||||
--world_size 1 \
|
# ==== Data/settings (match training) ====
|
||||||
--epochs 100 \
|
DATASET="MESSIDOR2"
|
||||||
--blr 5e-3 --layer_decay 0.65 \
|
NUM_CLASS=5
|
||||||
--weight_decay 0.05 --drop_path 0.2 \
|
|
||||||
--nb_classes 5 \
|
# =======================
|
||||||
--data_path ./IDRiD \
|
DATA_PATH="PATH TO THE DATASET"
|
||||||
--input_size 224 \
|
TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
|
||||||
--task RETFound_mae_meh-IDRiD \
|
|
||||||
--resume ./RETFound_mae_meh-IDRiD/checkpoint-best.pth
|
# Path to the trained checkpoint (adjust if you saved elsewhere)
|
||||||
|
CKPT="./output_dir/${TASK}/checkpoint-best.pth"
|
||||||
|
|
||||||
|
# ==== Evaluation only ====
|
||||||
|
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
|
||||||
|
--model "${MODEL}" \
|
||||||
|
--model_arch "${MODEL_ARCH}" \
|
||||||
|
--savemodel \
|
||||||
|
--global_pool \
|
||||||
|
--batch_size 128 \
|
||||||
|
--world_size 1 \
|
||||||
|
--nb_classes "${NUM_CLASS}" \
|
||||||
|
--data_path "${DATA_PATH}" \
|
||||||
|
--input_size 224 \
|
||||||
|
--task "${TASK}" \
|
||||||
|
--adaptation "${ADAPTATION}" \
|
||||||
|
--eval \
|
||||||
|
--resume "${CKPT}"
|
||||||
|
|
||||||
```
|
```
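
The evaluation script locates the checkpoint through the same naming rule that training used for `TASK`, so the settings must match. A sketch of that rule in Python, mirroring the shell variables above; `task_name` and `best_checkpoint` are illustrative helper names, and `output_dir` is the location shown in the script (adjust it if you saved elsewhere).

```python
from pathlib import Path


def task_name(model_arch, dataset, adaptation):
    """Mirror TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}" from the script."""
    return f"{model_arch}_{dataset}_{adaptation}"


def best_checkpoint(task, output_root="output_dir"):
    """Mirror CKPT="./output_dir/${TASK}/checkpoint-best.pth" from the script."""
    return Path(output_root) / task / "checkpoint-best.pth"


if __name__ == "__main__":
    task = task_name("retfound_dinov2", "MESSIDOR2", "finetune")
    ckpt = best_checkpoint(task)
    print(ckpt.as_posix())
    if not ckpt.exists():
        print("Checkpoint not found; train first or adjust the path.")
```

Changing any of the three components (for example switching `ADAPTATION` from `finetune` to `lp`) points evaluation at a different folder.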
|
||||||
|
|
||||||
|
|
||||||
@@ -175,9 +245,6 @@ torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
|||||||
|
|
||||||
If you find this repository useful, please consider citing this paper:
|
If you find this repository useful, please consider citing this paper:
|
||||||
|
|
||||||
```
|
|
||||||
TBD
|
|
||||||
```
|
|
||||||
|
|
||||||
```
|
```
|
||||||
@article{zhou2023foundation,
|
@article{zhou2023foundation,
|
||||||
@@ -192,4 +259,14 @@ TBD
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
```
@misc{zhou2025generalistversusspecialistvision,
  title={Generalist versus Specialist Vision Foundation Models for Ocular Disease and Oculomics},
  author={Yukun Zhou and Paul Nderitu and Jocelyn Hui Lin Goh and Justin Engelmann and Siegfried K. Wagner and Anran Ran and Hongyang Jiang and Lie Ju and Ke Zou and Sahana Srinivasan and Hyunmin Kim and Takahiro Ninomiya and Zheyuan Wang and Gabriel Dawei Yang and Eden Ruffell and Dominic Williamson and Rui Santos and Gabor Mark Somfai and Carol Y. Cheung and Tien Yin Wong and Daniel C. Alexander and Yih Chung Tham and Pearse A. Keane},
  year={2025},
  eprint={2509.03421},
  archivePrefix={arXiv},
  primaryClass={eess.IV},
  url={https://arxiv.org/abs/2509.03421},
}
```
||||||
|
|||||||
@@ -0,0 +1,223 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "76b39fb1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Jupyter notebook example - Classification task\n",
    "### Example using [MESSIDOR2](https://www.adcis.net/en/third-party/messidor2/) dataset\n",
    "**Application**: Using RETFound for five-category diabetic retinopathy classification\n",
    "\n",
    "**Author**: Yukun Zhou\n",
    "\n",
    "**Date**: 30 Nov 2025\n",
    "\n",
    "**Performance**:\n",
    "\n",
    "<table align=\"left\">\n",
    "<tr>\n",
    " <th>Accuracy</th>\n",
    " <th>Recall</th>\n",
    " <th>F1 Score</th>\n",
    " <th>ROC AUC</th>\n",
    " <th>PR AUC</th>\n",
    "</tr>\n",
    "<tr>\n",
    " <td>0.7091</td>\n",
    " <td>0.5616</td>\n",
    " <td>0.6078</td>\n",
    " <td>0.9037</td>\n",
    " <td>0.6863</td>\n",
    "</tr>\n",
    "</table>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ec435a7",
   "metadata": {},
   "source": [
    "## 1. Install environment\n",
    "1. Follow [RETFound README](https://github.com/rmaphoh/RETFound) to install environment\n",
    "2. Restart this Jupyter Notebook\n",
    "3. Select Kernel retfound"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cbf5e93-6ca0-4401-88e6-64e39968e7cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, torch\n",
    "from pathlib import Path\n",
    "import os\n",
    "\n",
    "PROJECT_ROOT = Path.cwd().resolve()\n",
    "\n",
    "if PROJECT_ROOT.name == 'examples': PROJECT_ROOT = PROJECT_ROOT.parent\n",
    "os.chdir(PROJECT_ROOT)\n",
    "\n",
    "print('Project root:', PROJECT_ROOT)\n",
    "print(\"sys.executable:\", sys.executable)\n",
    "print(\"torch version:\", torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed67953f",
   "metadata": {},
   "source": [
    "## 2. Prepare MESSIDOR2 dataset\n",
    "1. Download from the [shared data pool](https://github.com/rmaphoh/RETFound/blob/main/BENCHMARK.md).\n",
    "2. Put the data folder under the project directory, e.g. \"RETFound/MESSIDOR2\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "357be2fa-a914-4d1f-8759-76b2b1c3f20f",
   "metadata": {},
   "source": [
    "## 3. Hyperparameter and path settings\n",
    "1. Can choose finetune or lp (linear probe)\n",
    "2. Model selection [info](https://github.com/rmaphoh/RETFound#:~:text=In%20train.sh%2C%20the%20model%20can%20be%20selected%20by%20changing%20the%20hyperparameters%20MODEL%2C%20MODEL_ARCH%2C%20FINETUNE%3A)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f675843",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "ADAPTATION='finetune'\n",
    "MODEL='RETFound_dinov2'\n",
    "MODEL_ARCH='retfound_dinov2'\n",
    "FINETUNE='RETFound_dinov2_meh'\n",
    "DATASET='MESSIDOR2'\n",
    "NUM_CLASS=5\n",
    "DATA_PATH=PROJECT_ROOT/DATASET\n",
    "BATCH_SIZE=24\n",
    "EPOCHS=50\n",
    "INPUT_SIZE=224\n",
    "WORLD_SIZE=1\n",
    "TASK=f\"{MODEL_ARCH}_{DATASET}_{ADAPTATION}\"\n",
    "OUTPUT_DIR=PROJECT_ROOT/'output_dir'/TASK\n",
    "print('DATA_PATH:',DATA_PATH)\n",
    "print('TASK:',TASK)\n",
    "print('OUTPUT_DIR:',OUTPUT_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ac04845",
   "metadata": {},
   "source": [
    "## 4. Fine-tuning and testing RETFound on MESSIDOR2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d23ff751",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "!{sys.executable} main_finetune.py \\\n",
    " --model {MODEL} \\\n",
    " --model_arch {MODEL_ARCH} \\\n",
    " --finetune {FINETUNE} \\\n",
    " --savemodel \\\n",
    " --global_pool \\\n",
    " --batch_size {BATCH_SIZE} \\\n",
    " --epochs {EPOCHS} \\\n",
    " --nb_classes {NUM_CLASS} \\\n",
    " --data_path {DATA_PATH} \\\n",
    " --input_size {INPUT_SIZE} \\\n",
    " --task {TASK} \\\n",
    " --adaptation {ADAPTATION}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84ce93ac",
   "metadata": {},
   "source": [
    "## 5. Evaluation-only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0af0f8a7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "CKPT = OUTPUT_DIR / \"checkpoint-best.pth\"\n",
    "\n",
    "!{sys.executable} main_finetune.py \\\n",
    " --model {MODEL} \\\n",
    " --model_arch {MODEL_ARCH} \\\n",
    " --finetune {FINETUNE} \\\n",
    " --savemodel \\\n",
    " --global_pool \\\n",
    " --batch_size 128 \\\n",
    " --nb_classes {NUM_CLASS} \\\n",
    " --data_path {DATA_PATH} \\\n",
    " --input_size {INPUT_SIZE} \\\n",
    " --task {TASK} \\\n",
    " --adaptation {ADAPTATION} \\\n",
    " --eval \\\n",
    " --resume {CKPT}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d2dce7-31c2-48e2-87ce-9223b74cf94e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "retfound",
   "name": "workbench-notebooks.m128",
   "type": "gcloud",
   "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m128"
  },
  "kernelspec": {
   "display_name": "retfound_jupyter (Local)",
   "language": "python",
   "name": "retfound"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -22,18 +22,18 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": null,
|
||||||
"id": "90c3d964",
|
"id": "90c3d964",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
|
"def prepare_model(chkpt_dir, arch='RETFound_mae'):\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # load model\n",
|
" # load model\n",
|
||||||
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
|
" checkpoint = torch.load(chkpt_dir, map_location='cpu', weights_only=False)\n",
|
||||||
" \n",
|
" \n",
|
||||||
" # build model\n",
|
" # build model\n",
|
||||||
" if arch=='vit_large_patch16':\n",
|
" if arch=='RETFound_mae':\n",
|
||||||
" model = models.__dict__[arch](\n",
|
" model = models.__dict__[arch](\n",
|
||||||
" img_size=224,\n",
|
" img_size=224,\n",
|
||||||
" num_classes=5,\n",
|
" num_classes=5,\n",
|
||||||
@@ -70,7 +70,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": null,
|
||||||
"id": "9a250363",
|
"id": "9a250363",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -78,7 +78,7 @@
|
|||||||
"def get_feature(data_path,\n",
|
"def get_feature(data_path,\n",
|
||||||
" chkpt_dir,\n",
|
" chkpt_dir,\n",
|
||||||
" device,\n",
|
" device,\n",
|
||||||
" arch='vit_large_patch16'):\n",
|
" arch='RETFound_mae'):\n",
|
||||||
" #loading model\n",
|
" #loading model\n",
|
||||||
" model_ = prepare_model(chkpt_dir, arch)\n",
|
" model_ = prepare_model(chkpt_dir, arch)\n",
|
||||||
" model_.to(device)\n",
|
" model_.to(device)\n",
|
||||||
@@ -121,7 +121,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"chkpt_dir = hf_hub_download(repo_id=\"YukunZhou/RETFound_dinov2_meh\", filename=\"RETFound_dinov2_meh.pth\")\n",
|
"chkpt_dir = hf_hub_download(repo_id=\"YukunZhou/RETFound_dinov2_meh\", filename=\"RETFound_dinov2_meh.pth\")\n",
|
||||||
"data_path = 'DATA_PATH'\n",
|
"data_path = 'DATA_PATH'\n",
|
||||||
"device = torch.device('cuda')\n",
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||||
"arch='dinov2_large'"
|
"arch='dinov2_large'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|||||||
+290
-249
@@ -1,349 +1,385 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
# =========================
|
||||||
import argparse
|
import argparse
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import warnings
|
||||||
|
import faulthandler
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.backends.cudnn as cudnn
|
import torch.backends.cudnn as cudnn
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from timm.models.layers import trunc_normal_
|
from timm.models.layers import trunc_normal_
|
||||||
from timm.data.mixup import Mixup
|
from timm.data.mixup import Mixup
|
||||||
|
from huggingface_hub import hf_hub_download, login # login imported as in original
|
||||||
|
|
||||||
|
# =========================
|
||||||
import models_vit as models
|
import models_vit as models
|
||||||
import util.lr_decay as lrd
|
import util.lr_decay as lrd
|
||||||
import util.misc as misc
|
import util.misc as misc
|
||||||
from util.datasets import build_dataset
|
from util.datasets import build_dataset
|
||||||
from util.pos_embed import interpolate_pos_embed
|
from util.pos_embed import interpolate_pos_embed
|
||||||
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
||||||
from huggingface_hub import hf_hub_download, login
|
|
||||||
from engine_finetune import train_one_epoch, evaluate
|
from engine_finetune import train_one_epoch, evaluate
|
||||||
|
|
||||||
import warnings
|
# =========================
|
||||||
import faulthandler
|
|
||||||
|
|
||||||
faulthandler.enable()
|
faulthandler.enable()
|
||||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||||
|
|
||||||
|
|
||||||
def get_args_parser():
|
def get_args_parser():
|
||||||
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
|
parser = argparse.ArgumentParser(
|
||||||
parser.add_argument('--batch_size', default=128, type=int,
|
"MAE fine-tuning / linear probing for image classification", add_help=False
|
||||||
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
|
)
|
||||||
parser.add_argument('--epochs', default=50, type=int)
|
|
||||||
parser.add_argument('--accum_iter', default=1, type=int,
|
|
||||||
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
|
|
||||||
|
|
||||||
# Model parameters
|
# ---- Core training
|
||||||
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
|
parser.add_argument("--batch_size", default=128, type=int,
|
||||||
help='Name of model to train')
|
help="Batch size per GPU (effective batch size = batch_size * accum_iter * #gpus)")
|
||||||
parser.add_argument('--input_size', default=256, type=int,
|
parser.add_argument("--epochs", default=50, type=int)
|
||||||
help='images input size')
|
parser.add_argument("--accum_iter", default=1, type=int,
|
||||||
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
|
help="Gradient accumulation steps")
|
||||||
help='Drop path rate (default: 0.1)')
|
|
||||||
|
|
||||||
# Optimizer parameters
|
# ---- Model parameters
|
||||||
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
|
parser.add_argument("--model", default="vit_large_patch16", type=str, metavar="MODEL",
|
||||||
help='Clip gradient norm (default: None, no clipping)')
|
help="Model entry in models_vit.py")
|
||||||
parser.add_argument('--weight_decay', type=float, default=0.05,
|
parser.add_argument("--model_arch", default="dinov3_vits16", type=str, metavar="MODEL_ARCH",
|
||||||
help='weight decay (default: 0.05)')
|
help="Backbone architecture key (e.g., dinov2_vitl14, convnext_base, etc.)")
|
||||||
parser.add_argument('--lr', type=float, default=None, metavar='LR',
|
parser.add_argument("--input_size", default=256, type=int, help="Image size")
|
||||||
help='learning rate (absolute lr)')
|
parser.add_argument("--drop_path", type=float, default=0.2, metavar="PCT", help="Drop path rate")
|
||||||
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
|
parser.add_argument("--global_pool", action="store_true"); parser.set_defaults(global_pool=True)
|
||||||
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
|
parser.add_argument("--cls_token", action="store_false", dest="global_pool",
|
||||||
parser.add_argument('--layer_decay', type=float, default=0.65,
|
help="Use class token instead of global pool for classification")
|
||||||
help='layer-wise lr decay from ELECTRA/BEiT')
|
|
||||||
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
|
|
||||||
help='lower lr bound for cyclic schedulers that hit 0')
|
|
||||||
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
|
|
||||||
help='epochs to warmup LR')
|
|
||||||
|
|
||||||
# Augmentation parameters
|
# ---- Optimizer parameters
|
||||||
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
|
parser.add_argument("--clip_grad", type=float, default=None, metavar="NORM", help="Clip grad norm")
|
||||||
help='Color jitter factor (enabled only when not using Auto/RandAug)')
|
parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay")
|
||||||
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
|
parser.add_argument("--lr", type=float, default=None, metavar="LR", help="Absolute LR (overrides blr)")
|
||||||
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
|
parser.add_argument("--blr", type=float, default=5e-3, metavar="LR",
|
||||||
parser.add_argument('--smoothing', type=float, default=0.1,
|
help="Base LR: lr = blr * total_batch_size / 256")
|
||||||
help='Label smoothing (default: 0.1)')
|
parser.add_argument("--layer_decay", type=float, default=0.65, help="Layer-wise LR decay (ViT)")
|
||||||
|
parser.add_argument("--min_lr", type=float, default=1e-6, metavar="LR", help="Lower LR bound")
|
||||||
|
parser.add_argument("--warmup_epochs", type=int, default=10, metavar="N", help="Warmup epochs")
|
||||||
|
|
||||||
# * Random Erase params
|
# ---- Augmentation
|
||||||
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
|
parser.add_argument("--color_jitter", type=float, default=None, metavar="PCT")
|
||||||
help='Random erase prob (default: 0.25)')
|
parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1", metavar="NAME")
|
||||||
parser.add_argument('--remode', type=str, default='pixel',
|
parser.add_argument("--smoothing", type=float, default=0.1)
|
||||||
help='Random erase mode (default: "pixel")')
|
|
||||||
parser.add_argument('--recount', type=int, default=1,
|
|
||||||
help='Random erase count (default: 1)')
|
|
||||||
parser.add_argument('--resplit', action='store_true', default=False,
|
|
||||||
help='Do not random erase first (clean) augmentation split')
|
|
||||||
|
|
||||||
# * Mixup params
|
# ---- Random erase
|
||||||
parser.add_argument('--mixup', type=float, default=0,
|
parser.add_argument("--reprob", type=float, default=0.25, metavar="PCT")
|
||||||
help='mixup alpha, mixup enabled if > 0.')
|
parser.add_argument("--remode", type=str, default="pixel")
|
||||||
parser.add_argument('--cutmix', type=float, default=0,
|
parser.add_argument("--recount", type=int, default=1)
|
||||||
help='cutmix alpha, cutmix enabled if > 0.')
|
parser.add_argument("--resplit", action="store_true", default=False)
|
||||||
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
|
|
||||||
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
|
|
||||||
parser.add_argument('--mixup_prob', type=float, default=1.0,
|
|
||||||
help='Probability of performing mixup or cutmix when either/both is enabled')
|
|
||||||
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
|
|
||||||
help='Probability of switching to cutmix when both mixup and cutmix enabled')
|
|
||||||
parser.add_argument('--mixup_mode', type=str, default='batch',
|
|
||||||
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
|
|
||||||
|
|
||||||
# * Finetuning params
|
# ---- Mixup/Cutmix
|
||||||
parser.add_argument('--finetune', default='', type=str,
|
parser.add_argument("--mixup", type=float, default=0.0)
|
||||||
help='finetune from checkpoint')
|
parser.add_argument("--cutmix", type=float, default=0.0)
|
||||||
parser.add_argument('--task', default='', type=str,
|
parser.add_argument("--cutmix_minmax", type=float, nargs="+", default=None)
|
||||||
help='finetune from checkpoint')
|
parser.add_argument("--mixup_prob", type=float, default=1.0)
|
||||||
parser.add_argument('--global_pool', action='store_true')
|
parser.add_argument("--mixup_switch_prob", type=float, default=0.5)
|
||||||
parser.set_defaults(global_pool=True)
|
parser.add_argument("--mixup_mode", type=str, default="batch")
|
||||||
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
|
|
||||||
help='Use class token instead of global pool for classification')
|
|
||||||
|
|
||||||
# Dataset parameters
|
# ---- Finetuning & adaptation
|
||||||
parser.add_argument('--data_path', default='./data/', type=str,
|
parser.add_argument("--finetune", default="", type=str, help="Checkpoint id/path (see model rules below)")
|
||||||
help='dataset path')
|
parser.add_argument("--task", default="", type=str, help="Task name for logging/output grouping")
|
||||||
parser.add_argument('--nb_classes', default=8, type=int,
|
parser.add_argument("--adaptation", default="finetune", choices=["finetune", "lp"],
|
||||||
help='number of the classification types')
|
help="Adaptation strategy: finetune=full fine-tune, lp=linear probe (train head only)")
|
||||||
parser.add_argument('--output_dir', default='./output_dir',
|
|
||||||
help='path where to save, empty for no saving')
|
|
||||||
parser.add_argument('--log_dir', default='./output_logs',
|
|
||||||
help='path where to tensorboard log')
|
|
||||||
parser.add_argument('--device', default='cuda',
|
|
||||||
help='device to use for training / testing')
|
|
||||||
parser.add_argument('--seed', default=0, type=int)
|
|
||||||
parser.add_argument('--resume', default='',
|
|
||||||
help='resume from checkpoint')
|
|
||||||
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
|
|
||||||
help='start epoch')
|
|
||||||
parser.add_argument('--eval', action='store_true',
|
|
||||||
help='Perform evaluation only')
|
|
||||||
parser.add_argument('--dist_eval', action='store_true', default=False,
|
|
||||||
help='Enabling distributed evaluation (recommended during training for faster monitor')
|
|
||||||
parser.add_argument('--num_workers', default=10, type=int)
|
|
||||||
parser.add_argument('--pin_mem', action='store_true',
|
|
||||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
|
||||||
parser.set_defaults(pin_mem=True)
|
|
||||||
|
|
||||||
# distributed training parameters
|
# ---- Dataset & paths
|
||||||
parser.add_argument('--world_size', default=1, type=int,
|
parser.add_argument("--data_path", default="./data/", type=str)
|
||||||
help='number of distributed processes')
|
parser.add_argument("--nb_classes", default=8, type=int)
|
||||||
parser.add_argument('--local_rank', default=-1, type=int)
|
parser.add_argument("--output_dir", default="./output_dir")
|
||||||
parser.add_argument('--dist_on_itp', action='store_true')
|
parser.add_argument("--log_dir", default="./output_logs")
|
||||||
parser.add_argument('--dist_url', default='env://',
|
|
||||||
help='url used to set up distributed training')
|
|
||||||
|
|
||||||
# fine-tuning parameters
|
# >>> NEW: training data efficiency <<<
|
||||||
parser.add_argument('--savemodel', action='store_true', default=True,
|
parser.add_argument(
|
||||||
help='Save model')
|
"--dataratio", type=str, default="1.0",
|
||||||
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
|
help=('Training data ratio(s) for subsampling in build_dataset. '
|
||||||
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
|
'Use a single float in (0,1] (e.g., 0.25) or a comma-separated list '
|
||||||
parser.add_argument('--datasets_seed', default=2026, type=int)
|
'(e.g., "1.0,0.5,0.25") if your build_dataset supports sweeps.')
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stratified", action="store_true",
|
||||||
|
help="If set, subsample training data in a class-stratified manner (requires support in build_dataset)."
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Runtime
|
||||||
|
parser.add_argument("--device", default="cuda")
|
||||||
|
parser.add_argument("--seed", default=0, type=int)
|
||||||
|
parser.add_argument("--resume", default="", help="Resume full state (optimizer, scaler, etc.)")
|
||||||
|
parser.add_argument("--start_epoch", default=0, type=int, metavar="N")
|
||||||
|
parser.add_argument("--eval", action="store_true", help="Evaluation only")
|
||||||
|
parser.add_argument("--dist_eval", action="store_true", default=False,
|
||||||
|
help="Distributed evaluation (faster monitoring during training)")
|
||||||
|
parser.add_argument("--num_workers", default=10, type=int)
|
||||||
|
parser.add_argument("--pin_mem", action="store_true"); parser.set_defaults(pin_mem=True)
|
||||||
|
|
||||||
|
# ---- Distributed
|
||||||
|
parser.add_argument("--world_size", default=1, type=int)
|
||||||
|
parser.add_argument("--local_rank", default=-1, type=int)
|
||||||
|
parser.add_argument("--dist_on_itp", action="store_true")
|
||||||
|
parser.add_argument("--dist_url", default="env://")
|
||||||
|
|
||||||
|
# ---- Misc
|
||||||
|
parser.add_argument("--savemodel", action="store_true", default=True, help="Save best model")
|
||||||
|
parser.add_argument("--norm", default="IMAGENET", type=str)
|
||||||
|
parser.add_argument("--enhance", action="store_true", default=False)
|
||||||
|
parser.add_argument("--datasets_seed", default=2026, type=int)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Main
|
||||||
|
# =========================
|
||||||
def main(args, criterion):
|
def main(args, criterion):
|
||||||
|
# ---- Optionally load args from resume (when training)
|
||||||
if args.resume and not args.eval:
|
if args.resume and not args.eval:
|
||||||
resume = args.resume
|
resume_path = args.resume
|
||||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
checkpoint = torch.load(args.resume, map_location="cpu")
|
||||||
print("Load checkpoint from: %s" % args.resume)
|
print(f"Load checkpoint (args) from: {args.resume}")
|
||||||
args = checkpoint['args']
|
args = checkpoint["args"]
|
||||||
args.resume = resume
|
args.resume = resume_path
|
||||||
|
|
||||||
|
# ---- Distributed setup
|
||||||
misc.init_distributed_mode(args)
|
misc.init_distributed_mode(args)
|
||||||
|
|
||||||
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
|
print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}")
|
||||||
print("{}".format(args).replace(', ', ',\n'))
|
print(f"{args}".replace(", ", ",\n"))
|
||||||
|
|
||||||
device = torch.device(args.device)
|
device = torch.device(args.device)
|
||||||
|
|
||||||
# fix the seed for reproducibility
|
# ---- Reproducibility
|
||||||
seed = args.seed + misc.get_rank()
|
seed = args.seed + misc.get_rank()
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
|
|
||||||
if args.model=='RETFound_mae':
|
# ---- Build model
|
||||||
|
if args.model == "RETFound_mae":
|
||||||
model = models.__dict__[args.model](
|
model = models.__dict__[args.model](
|
||||||
img_size=args.input_size,
|
img_size=args.input_size,
|
||||||
num_classes=args.nb_classes,
|
num_classes=args.nb_classes,
|
||||||
drop_path_rate=args.drop_path,
|
drop_path_rate=args.drop_path,
|
||||||
global_pool=args.global_pool,
|
global_pool=args.global_pool,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = models.__dict__[args.model](
|
model = models.__dict__[args.model](
|
||||||
num_classes=args.nb_classes,
|
num_classes=args.nb_classes,
|
||||||
drop_path_rate=args.drop_path,
|
drop_path_rate=args.drop_path,
|
||||||
args=args,
|
args=args,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.finetune and not args.eval:
|
|
||||||
|
|
||||||
print(f"Downloading pre-trained weights from: {args.finetune}")
|
|
||||||
|
|
||||||
checkpoint_path = hf_hub_download(
|
|
||||||
repo_id=f'YukunZhou/{args.finetune}',
|
|
||||||
filename=f'{args.finetune}.pth',
|
|
||||||
)
|
|
||||||
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
|
||||||
print("Load pre-trained checkpoint from: %s" % args.finetune)
|
|
||||||
|
|
||||||
if args.model!='RETFound_mae':
|
|
||||||
checkpoint_model = checkpoint['teacher']
|
|
||||||
else:
|
|
||||||
checkpoint_model = checkpoint['model']
|
|
||||||
|
|
||||||
|
# ---- Load pre-trained weights (if requested and not eval-only)
|
||||||
|
if args.finetune and not args.eval:
|
||||||
|
print(f"Preparing to load pre-trained weights: {args.finetune}")
|
||||||
|
|
||||||
|
if args.model in ["Dinov3", "Dinov2"]:
|
||||||
|
checkpoint_path = args.finetune # local path
|
||||||
|
elif args.model in ["RETFound_dinov2", "RETFound_mae"]:
|
||||||
|
print(f"Downloading pre-trained weights from Hugging Face Hub: {args.finetune}")
|
||||||
|
checkpoint_path = hf_hub_download(
|
||||||
|
repo_id=f"YukunZhou/{args.finetune}",
|
||||||
|
filename=f"{args.finetune}.pth",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported model '{args.model}'. "
|
||||||
|
f"Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae"
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||||
|
print(f"Loaded pre-trained checkpoint from: {checkpoint_path}")
|
||||||
|
|
||||||
|
if args.model in ["Dinov3", "Dinov2"]:
|
||||||
|
checkpoint_model = checkpoint
|
||||||
|
elif args.model == "RETFound_dinov2":
|
||||||
|
checkpoint_model = checkpoint["teacher"]
|
||||||
|
else: # RETFound_mae
|
||||||
|
checkpoint_model = checkpoint["model"]
|
||||||
|
|
||||||
|
# -- Key hygiene
|
||||||
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
|
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
|
||||||
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
|
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
|
||||||
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
|
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
|
||||||
|
|
||||||
|
# -- Remove classifier if shape mismatched
|
||||||
state_dict = model.state_dict()
|
state_dict = model.state_dict()
|
||||||
for k in ['head.weight', 'head.bias']:
|
for k in ["head.weight", "head.bias"]:
|
||||||
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
||||||
print(f"Removing key {k} from pretrained checkpoint")
|
print(f"Removing key {k} from pretrained checkpoint")
|
||||||
del checkpoint_model[k]
|
del checkpoint_model[k]
|
||||||
|
|
||||||
# interpolate position embedding
|
# -- Interpolate pos embed (ViT)
|
||||||
interpolate_pos_embed(model, checkpoint_model)
|
interpolate_pos_embed(model, checkpoint_model)
|
||||||
|
|
||||||
# load pre-trained model
|
# -- Load backbone weights (non-strict)
|
||||||
msg = model.load_state_dict(checkpoint_model, strict=False)
|
_ = model.load_state_dict(checkpoint_model, strict=False)
|
||||||
|
|
||||||
trunc_normal_(model.head.weight, std=2e-5)
|
# -- Re-init head
|
||||||
|
if hasattr(model, "head") and hasattr(model.head, "weight"):
|
||||||
|
trunc_normal_(model.head.weight, std=2e-5)
|
||||||
|
|
||||||
dataset_train = build_dataset(is_train='train', args=args)
|
# ---- Datasets & samplers
|
||||||
dataset_val = build_dataset(is_train='val', args=args)
|
dataset_train = build_dataset(is_train="train", args=args)
|
||||||
dataset_test = build_dataset(is_train='test', args=args)
|
dataset_val = build_dataset(is_train="val", args=args)
|
||||||
|
dataset_test = build_dataset(is_train="test", args=args)
|
||||||
|
|
||||||
|
num_tasks = misc.get_world_size()
|
||||||
|
global_rank = misc.get_rank()
|
||||||
|
|
||||||
if True: # args.distributed:
|
if not args.eval:
|
||||||
num_tasks = misc.get_world_size()
|
sampler_train = torch.utils.data.DistributedSampler(
|
||||||
global_rank = misc.get_rank()
|
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||||
if not args.eval:
|
)
|
||||||
sampler_train = torch.utils.data.DistributedSampler(
|
print(f"Sampler_train = {sampler_train}")
|
||||||
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
|
||||||
)
|
|
||||||
print("Sampler_train = %s" % str(sampler_train))
|
|
||||||
if args.dist_eval:
|
|
||||||
if len(dataset_val) % num_tasks != 0:
|
|
||||||
print(
|
|
||||||
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
|
||||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
|
||||||
'equal num of samples per-process.')
|
|
||||||
sampler_val = torch.utils.data.DistributedSampler(
|
|
||||||
dataset_val, num_replicas=num_tasks, rank=global_rank,
|
|
||||||
shuffle=True) # shuffle=True to reduce monitor bias
|
|
||||||
else:
|
|
||||||
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
|
||||||
|
|
||||||
if args.dist_eval:
|
if args.dist_eval:
|
||||||
if len(dataset_test) % num_tasks != 0:
|
if len(dataset_val) % num_tasks != 0:
|
||||||
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
print("Warning: dist eval with dataset not divisible by #procs; results may differ slightly.")
|
||||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
sampler_val = torch.utils.data.DistributedSampler(
|
||||||
'equal num of samples per-process.')
|
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||||
sampler_test = torch.utils.data.DistributedSampler(
|
)
|
||||||
dataset_test, num_replicas=num_tasks, rank=global_rank,
|
|
||||||
shuffle=True) # shuffle=True to reduce monitor bias
|
|
||||||
else:
|
else:
|
||||||
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
|
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
||||||
|
|
||||||
|
if args.dist_eval:
|
||||||
|
if len(dataset_test) % num_tasks != 0:
|
||||||
|
print("Warning: dist eval test set not divisible by #procs; results may differ slightly.")
|
||||||
|
sampler_test = torch.utils.data.DistributedSampler(
|
||||||
|
dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
|
||||||
|
|
||||||
|
# ---- Logging
|
||||||
if global_rank == 0 and args.log_dir is not None and not args.eval:
|
if global_rank == 0 and args.log_dir is not None and not args.eval:
|
||||||
os.makedirs(args.log_dir, exist_ok=True)
|
os.makedirs(args.log_dir, exist_ok=True)
|
||||||
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
|
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
|
||||||
else:
|
else:
|
||||||
log_writer = None
|
log_writer = None
|
||||||
|
|
||||||
|
# ---- DataLoaders
|
||||||
if not args.eval:
|
if not args.eval:
|
||||||
data_loader_train = torch.utils.data.DataLoader(
|
data_loader_train = torch.utils.data.DataLoader(
|
||||||
dataset_train, sampler=sampler_train,
|
dataset_train, sampler=sampler_train,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size, num_workers=args.num_workers,
|
||||||
num_workers=args.num_workers,
|
pin_memory=args.pin_mem, drop_last=True,
|
||||||
pin_memory=args.pin_mem,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
)
|
||||||
|
print(f"len of train_set: {len(data_loader_train) * args.batch_size}")
|
||||||
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
|
|
||||||
|
|
||||||
data_loader_val = torch.utils.data.DataLoader(
|
data_loader_val = torch.utils.data.DataLoader(
|
||||||
dataset_val, sampler=sampler_val,
|
dataset_val, sampler=sampler_val,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size, num_workers=args.num_workers,
|
||||||
num_workers=args.num_workers,
|
pin_memory=args.pin_mem, drop_last=False,
|
||||||
pin_memory=args.pin_mem,
|
|
||||||
drop_last=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
data_loader_test = torch.utils.data.DataLoader(
|
data_loader_test = torch.utils.data.DataLoader(
|
||||||
dataset_test, sampler=sampler_test,
|
dataset_test, sampler=sampler_test,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size, num_workers=args.num_workers,
|
||||||
num_workers=args.num_workers,
|
pin_memory=args.pin_mem, drop_last=False,
|
||||||
pin_memory=args.pin_mem,
|
|
||||||
drop_last=False
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ---- Mixup/CutMix
|
||||||
mixup_fn = None
|
mixup_fn = None
|
||||||
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
|
mixup_active = (args.mixup > 0) or (args.cutmix > 0.) or (args.cutmix_minmax is not None)
|
||||||
if mixup_active:
|
if mixup_active:
|
||||||
print("Mixup is activated!")
|
print("Mixup is activated!")
|
||||||
mixup_fn = Mixup(
|
mixup_fn = Mixup(
|
||||||
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
|
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
|
||||||
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
|
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
|
||||||
label_smoothing=args.smoothing, num_classes=args.nb_classes)
|
label_smoothing=args.smoothing, num_classes=args.nb_classes
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- Eval-only: resume weights
|
||||||
if args.resume and args.eval:
|
if args.resume and args.eval:
|
||||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
checkpoint = torch.load(args.resume, map_location="cpu")
|
||||||
print("Load checkpoint from: %s" % args.resume)
|
print(f"Load checkpoint for eval from: {args.resume}")
|
||||||
model.load_state_dict(checkpoint['model'])
|
model.load_state_dict(checkpoint["model"])
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model_without_ddp = model
|
model_without_ddp = model
|
||||||
|
|
||||||
|
# ---- Adaptation toggle
|
||||||
|
if args.adaptation == "lp":
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
param.requires_grad = ("head" in name)
|
||||||
|
print("[Adaptation] Linear probe: training classifier head only.")
|
||||||
|
else:
|
||||||
|
print("[Adaptation] Full fine-tuning: training all parameters.")
|
||||||
|
|
||||||
|
# ---- Count trainable params
|
||||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
print('number of model params (M): %.2f' % (n_parameters / 1.e6))
|
print(f"number of trainable params (M): {n_parameters / 1.e6:.2f}")
|
||||||
|
|
||||||
|
# ---- LR scaling by effective batch size
|
||||||
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
|
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
|
||||||
|
if args.lr is None:
|
||||||
if args.lr is None: # only base_lr is specified
|
|
||||||
args.lr = args.blr * eff_batch_size / 256
|
args.lr = args.blr * eff_batch_size / 256
|
||||||
|
print(f"base lr: {args.lr * 256 / eff_batch_size:.2e}")
|
||||||
|
print(f"actual lr: {args.lr:.2e}")
|
||||||
|
print(f"accumulate grad iterations: {args.accum_iter}")
|
||||||
|
print(f"effective batch size: {eff_batch_size}")
|
||||||
|
|
||||||
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
|
# ---- DDP (if available)
|
||||||
print("actual lr: %.2e" % args.lr)
|
if args.distributed and torch.cuda.device_count() > 1:
|
||||||
|
ddp_kwargs = {}
|
||||||
print("accumulate grad iterations: %d" % args.accum_iter)
|
if args.adaptation == "lp":
|
||||||
print("effective batch size: %d" % eff_batch_size)
|
ddp_kwargs["find_unused_parameters"] = True
|
||||||
|
model = torch.nn.parallel.DistributedDataParallel(
|
||||||
if args.distributed:
|
model, device_ids=[args.gpu], **ddp_kwargs
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
)
|
||||||
model_without_ddp = model.module
|
model_without_ddp = model.module
|
||||||
|
else:
|
||||||
|
model_without_ddp = model # single-GPU
|
||||||
|
|
||||||
|
# ---- Optimizer param groups (after freezing)
|
||||||
|
no_weight_decay = (model_without_ddp.no_weight_decay()
|
||||||
|
if hasattr(model_without_ddp, "no_weight_decay") else [])
|
||||||
|
|
||||||
|
|
||||||
|
param_groups = lrd.param_groups_lrd(
|
||||||
|
model_without_ddp,
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
no_weight_decay_list=no_weight_decay,
|
||||||
|
layer_decay=args.layer_decay,
|
||||||
|
)
|
||||||
|
for g in param_groups:
|
||||||
|
g["params"] = [p for p in g["params"] if p.requires_grad]
|
||||||
|
|
||||||
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
|
|
||||||
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
|
|
||||||
no_weight_decay_list=no_weight_decay,
|
|
||||||
layer_decay=args.layer_decay
|
|
||||||
)
|
|
||||||
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
|
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
|
||||||
loss_scaler = NativeScaler()
|
loss_scaler = NativeScaler()
|
||||||
|
print(f"criterion = {criterion}")
|
||||||
|
|
||||||
print("criterion = %s" % str(criterion))
|
# ---- Load previous full state (optimizer, scaler, etc.)
|
||||||
|
misc.load_model(args=args, model_without_ddp=model_without_ddp,
|
||||||
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
|
optimizer=optimizer, loss_scaler=loss_scaler)
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Eval-only Short Circuit
|
||||||
|
# =========================
|
||||||
if args.eval:
|
if args.eval:
|
||||||
if 'epoch' in checkpoint:
|
if "checkpoint" in locals() and isinstance(checkpoint, dict) and ("epoch" in checkpoint):
|
||||||
print("Test with the best model at epoch = %d" % checkpoint['epoch'])
|
print(f"Test with the best model at epoch = {checkpoint['epoch']}")
|
||||||
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
|
test_stats, auc_roc = evaluate(
|
||||||
num_class=args.nb_classes, log_writer=log_writer)
|
data_loader_test, model, device, args, epoch=0, mode="test",
|
||||||
exit(0)
|
num_class=args.nb_classes, log_writer=log_writer
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Train Loop
|
||||||
|
# =========================
|
||||||
print(f"Start training for {args.epochs} epochs")
|
print(f"Start training for {args.epochs} epochs")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
max_score = 0.0
|
max_score = 0.0
|
||||||
best_epoch = 0
|
best_epoch = 0
|
||||||
|
|
||||||
for epoch in range(args.start_epoch, args.epochs):
|
for epoch in range(args.start_epoch, args.epochs):
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
data_loader_train.sampler.set_epoch(epoch)
|
data_loader_train.sampler.set_epoch(epoch)
|
||||||
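The learning-rate scaling in the hunk above (`lr = blr * eff_batch_size / 256`, with `eff_batch_size = batch_size * accum_iter * world_size`) can be checked by hand with a small sketch. The function name `scaled_lr` is hypothetical and not part of the repository. The value `batch_size=128` is the default from the removed script and is only assumed to still apply.

```python
def scaled_lr(blr: float, batch_size: int, accum_iter: int, world_size: int) -> float:
    """Absolute LR from a base LR: lr = blr * eff_batch_size / 256."""
    # Effective batch size counts every sample seen per optimizer step.
    eff_batch_size = batch_size * accum_iter * world_size
    return blr * eff_batch_size / 256


# Parser default blr=5e-3 with an assumed batch_size=128 on a single process.
lr_single = scaled_lr(5e-3, batch_size=128, accum_iter=1, world_size=1)

# Hypothetical run: 4 processes, 2 gradient-accumulation steps each.
lr_multi = scaled_lr(5e-3, batch_size=128, accum_iter=2, world_size=4)

print(f"single process: {lr_single:.2e}, 4 processes x 2 accum: {lr_multi:.2e}")
```

Passing `--lr` explicitly skips this scaling, since the script only computes it when `args.lr is None`.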
@@ -352,49 +388,55 @@ def main(args, criterion):
|
|||||||
model, criterion, data_loader_train,
|
model, criterion, data_loader_train,
|
||||||
optimizer, device, epoch, loss_scaler,
|
optimizer, device, epoch, loss_scaler,
|
||||||
args.clip_grad, mixup_fn,
|
args.clip_grad, mixup_fn,
|
||||||
log_writer=log_writer,
|
log_writer=log_writer, args=args
|
||||||
args=args
|
)
|
||||||
|
|
||||||
|
val_stats, val_score = evaluate(
|
||||||
|
data_loader_val, model, device, args, epoch, mode="val",
|
||||||
|
num_class=args.nb_classes, log_writer=log_writer
|
||||||
)
|
)
|
||||||
|
|
||||||
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
|
|
||||||
num_class=args.nb_classes, log_writer=log_writer)
|
|
||||||
if max_score < val_score:
|
if max_score < val_score:
|
||||||
max_score = val_score
|
max_score = val_score
|
||||||
best_epoch = epoch
|
best_epoch = epoch
|
||||||
if args.output_dir and args.savemodel:
|
if args.output_dir and args.savemodel:
|
||||||
misc.save_model(
|
misc.save_model(
|
||||||
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
args=args, model=model, model_without_ddp=model_without_ddp,
|
||||||
loss_scaler=loss_scaler, epoch=epoch, mode='best')
|
optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode="best"
|
||||||
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
|
)
|
||||||
|
print(f"Best epoch = {best_epoch}, Best score = {max_score:.4f}")
|
||||||
|
|
||||||
if epoch == (args.epochs - 1):
|
|
||||||
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
|
|
||||||
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
|
|
||||||
model.to(device)
|
|
||||||
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
|
|
||||||
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
|
|
||||||
num_class=args.nb_classes, log_writer=None)
|
|
||||||
|
|
||||||
if log_writer is not None:
|
if log_writer is not None:
|
||||||
log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
|
log_writer.add_scalar("loss/val", val_stats["loss"], epoch)
|
||||||
|
log_writer.flush()
|
||||||
|
|
||||||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
log_stats = {**{f"train_{k}": v for k, v in train_stats.items()},
|
||||||
'epoch': epoch,
|
"epoch": epoch,
|
||||||
'n_parameters': n_parameters}
|
"n_parameters": n_parameters}
|
||||||
|
|
||||||
if args.output_dir and misc.is_main_process():
|
if args.output_dir and misc.is_main_process():
|
||||||
if log_writer is not None:
|
with open(os.path.join(args.output_dir, args.task, "log.txt"), "a", encoding="utf-8") as f:
|
||||||
log_writer.flush()
|
|
||||||
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(log_stats) + "\n")
|
f.write(json.dumps(log_stats) + "\n")
|
||||||
|
|
||||||
|
# =========================
|
||||||
|
# Final Test (Best Ckpt)
|
||||||
|
# =========================
|
||||||
|
ckpt_path = os.path.join(args.output_dir, args.task, "checkpoint-best.pth")
|
||||||
|
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
|
||||||
|
model.to(device)
|
||||||
|
print(f"Test with the best model, epoch = {checkpoint.get('epoch', -1)}:")
|
||||||
|
_test_stats, _auc_roc = evaluate(
|
||||||
|
data_loader_test, model, device, args, -1, mode="test",
|
||||||
|
num_class=args.nb_classes, log_writer=None
|
||||||
|
)
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
print('Training time {}'.format(total_time_str))
|
print(f"Training time {total_time_str}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
args = get_args_parser()
|
args = get_args_parser()
|
||||||
args = args.parse_args()
|
args = args.parse_args()
|
||||||
|
|
||||||
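The new `--dataratio` flag is declared as a string that holds either a single float in (0, 1] or a comma-separated list. How `build_dataset` consumes it is not shown in this diff, so the following is only a sketch of one way such a value could be validated. `parse_dataratio` is a hypothetical helper, not a function from the repository.

```python
import argparse
from typing import List


def parse_dataratio(value: str) -> List[float]:
    """Hypothetical helper: turn "0.25" or "1.0,0.5,0.25" into floats in (0, 1]."""
    # Split on commas and ignore empty pieces such as a trailing comma.
    ratios = [float(part) for part in value.split(",") if part.strip()]
    if not ratios:
        raise argparse.ArgumentTypeError("expected at least one ratio")
    for ratio in ratios:
        if not 0.0 < ratio <= 1.0:
            raise argparse.ArgumentTypeError(f"ratio {ratio} is outside (0, 1]")
    return ratios


# Minimal parser that mirrors only the --dataratio declaration from the diff.
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--dataratio", type=str, default="1.0")

args = parser.parse_args(["--dataratio", "1.0,0.5,0.25"])
ratios = parse_dataratio(args.dataratio)
print(ratios)
```

Keeping the flag as a string and validating it afterwards leaves room for the sweep form mentioned in the help text, at the cost of reporting bad values later than an argparse `type=` callable would.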
@@ -402,6 +444,5 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
if args.output_dir:
|
if args.output_dir:
|
||||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
main(args, criterion)
|
main(args, criterion)
|
||||||
|
|
||||||
|
|
||||||
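The "key hygiene" step in the weight-loading code strips a `backbone.` prefix and renames `mlp.w12.` and `mlp.w3.` to `mlp.fc1.` and `mlp.fc2.` before calling `load_state_dict`. The sketch below applies the same three renames in one pass over a toy dictionary. The name `remap_keys` and the toy keys are illustrative assumptions, and plain integers stand in for tensors.

```python
from typing import Any, Dict

# Rename rules copied from the diff, applied in the same order.
RENAMES = (
    ("backbone.", ""),
    ("mlp.w12.", "mlp.fc1."),
    ("mlp.w3.", "mlp.fc2."),
)


def remap_keys(state: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `state` with checkpoint keys renamed to the model's names."""
    remapped = {}
    for key, value in state.items():
        for old, new in RENAMES:
            key = key.replace(old, new)
        remapped[key] = value
    return remapped


# Toy checkpoint; real checkpoints map these keys to torch tensors.
toy_checkpoint = {
    "backbone.blocks.0.mlp.w12.weight": 1,
    "backbone.blocks.0.mlp.w3.bias": 2,
    "cls_token": 3,
}
remapped = remap_keys(toy_checkpoint)
print(sorted(remapped))
```

Keys that match none of the rules pass through unchanged, which is why the script can run the same renames for every supported model.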
|
|||||||
@@ -1,414 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.backends.cudnn as cudnn
|
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
|
||||||
from timm.models.layers import trunc_normal_
|
|
||||||
from timm.data.mixup import Mixup
|
|
||||||
|
|
||||||
import models_vit as models
|
|
||||||
import util.lr_decay as lrd
|
|
||||||
import util.misc as misc
|
|
||||||
from util.datasets import build_dataset
|
|
||||||
from util.pos_embed import interpolate_pos_embed
|
|
||||||
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
|
||||||
from huggingface_hub import hf_hub_download, login
|
|
||||||
from engine_finetune import train_one_epoch, evaluate
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
import faulthandler
|
|
||||||
|
|
||||||
faulthandler.enable()
|
|
||||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
|
||||||
|
|
||||||
|
|
||||||
def get_args_parser():
|
|
||||||
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
|
|
||||||
parser.add_argument('--batch_size', default=128, type=int,
|
|
||||||
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
|
|
||||||
parser.add_argument('--epochs', default=50, type=int)
|
|
||||||
parser.add_argument('--accum_iter', default=1, type=int,
|
|
||||||
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
|
|
||||||
|
|
||||||
# Model parameters
|
|
||||||
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
|
|
||||||
help='Name of model to train')
|
|
||||||
parser.add_argument('--input_size', default=256, type=int,
|
|
||||||
help='images input size')
|
|
||||||
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
|
|
||||||
help='Drop path rate (default: 0.2)')
|
|
||||||
|
|
||||||
# Optimizer parameters
|
|
||||||
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
|
|
||||||
help='Clip gradient norm (default: None, no clipping)')
|
|
||||||
parser.add_argument('--weight_decay', type=float, default=0.05,
|
|
||||||
help='weight decay (default: 0.05)')
|
|
||||||
parser.add_argument('--lr', type=float, default=None, metavar='LR',
|
|
||||||
help='learning rate (absolute lr)')
|
|
||||||
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
|
|
||||||
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
|
|
||||||
parser.add_argument('--layer_decay', type=float, default=0.65,
|
|
||||||
help='layer-wise lr decay from ELECTRA/BEiT')
|
|
||||||
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
|
|
||||||
help='lower lr bound for cyclic schedulers that hit 0')
|
|
||||||
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
|
|
||||||
help='epochs to warmup LR')
|
|
||||||
|
|
||||||
# Augmentation parameters
|
|
||||||
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
|
|
||||||
help='Color jitter factor (enabled only when not using Auto/RandAug)')
|
|
||||||
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
|
|
||||||
help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
|
|
||||||
parser.add_argument('--smoothing', type=float, default=0.1,
|
|
||||||
help='Label smoothing (default: 0.1)')
|
|
||||||
|
|
||||||
# * Random Erase params
|
|
||||||
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
|
|
||||||
help='Random erase prob (default: 0.25)')
|
|
||||||
parser.add_argument('--remode', type=str, default='pixel',
|
|
||||||
help='Random erase mode (default: "pixel")')
|
|
||||||
parser.add_argument('--recount', type=int, default=1,
|
|
||||||
help='Random erase count (default: 1)')
|
|
||||||
parser.add_argument('--resplit', action='store_true', default=False,
|
|
||||||
help='Do not random erase first (clean) augmentation split')
|
|
||||||
|
|
||||||
# * Mixup params
|
|
||||||
parser.add_argument('--mixup', type=float, default=0,
|
|
||||||
help='mixup alpha, mixup enabled if > 0.')
|
|
||||||
parser.add_argument('--cutmix', type=float, default=0,
|
|
||||||
help='cutmix alpha, cutmix enabled if > 0.')
|
|
||||||
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
|
|
||||||
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
|
|
||||||
parser.add_argument('--mixup_prob', type=float, default=1.0,
|
|
||||||
help='Probability of performing mixup or cutmix when either/both is enabled')
|
|
||||||
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
|
|
||||||
help='Probability of switching to cutmix when both mixup and cutmix enabled')
|
|
||||||
parser.add_argument('--mixup_mode', type=str, default='batch',
|
|
||||||
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
|
|
||||||
|
|
||||||
# * Finetuning params
|
|
||||||
parser.add_argument('--finetune', default='', type=str,
|
|
||||||
help='finetune from checkpoint')
|
|
||||||
parser.add_argument('--task', default='', type=str,
|
|
||||||
help='task name, used as the sub-directory for logs and checkpoints')
|
|
||||||
parser.add_argument('--global_pool', action='store_true')
|
|
||||||
parser.set_defaults(global_pool=True)
|
|
||||||
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
|
|
||||||
help='Use class token instead of global pool for classification')
|
|
||||||
|
|
||||||
# Dataset parameters
|
|
||||||
parser.add_argument('--data_path', default='./data/', type=str,
|
|
||||||
help='dataset path')
|
|
||||||
parser.add_argument('--nb_classes', default=8, type=int,
|
|
||||||
help='number of the classification types')
|
|
||||||
parser.add_argument('--output_dir', default='./output_dir',
|
|
||||||
help='path where to save, empty for no saving')
|
|
||||||
parser.add_argument('--log_dir', default='./output_logs',
|
|
||||||
help='path where to tensorboard log')
|
|
||||||
parser.add_argument('--device', default='cuda',
|
|
||||||
help='device to use for training / testing')
|
|
||||||
parser.add_argument('--seed', default=0, type=int)
|
|
||||||
parser.add_argument('--resume', default='',
|
|
||||||
help='resume from checkpoint')
|
|
||||||
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
|
|
||||||
help='start epoch')
|
|
||||||
parser.add_argument('--eval', action='store_true',
|
|
||||||
help='Perform evaluation only')
|
|
||||||
parser.add_argument('--dist_eval', action='store_true', default=False,
|
|
||||||
help='Enable distributed evaluation (recommended during training for faster monitoring)')
|
|
||||||
parser.add_argument('--num_workers', default=10, type=int)
|
|
||||||
parser.add_argument('--pin_mem', action='store_true',
|
|
||||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
|
||||||
parser.set_defaults(pin_mem=True)
|
|
||||||
|
|
||||||
# distributed training parameters
|
|
||||||
parser.add_argument('--world_size', default=1, type=int,
|
|
||||||
help='number of distributed processes')
|
|
||||||
parser.add_argument('--local_rank', default=-1, type=int)
|
|
||||||
parser.add_argument('--dist_on_itp', action='store_true')
|
|
||||||
parser.add_argument('--dist_url', default='env://',
|
|
||||||
help='url used to set up distributed training')
|
|
||||||
|
|
||||||
# fine-tuning parameters
|
|
||||||
parser.add_argument('--savemodel', action='store_true', default=True,
|
|
||||||
help='Save model')
|
|
||||||
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
|
|
||||||
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
|
|
||||||
parser.add_argument('--datasets_seed', default=2026, type=int)
|
|
||||||
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
def main(args, criterion):
|
|
||||||
if args.resume and not args.eval:
|
|
||||||
resume = args.resume
|
|
||||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
|
||||||
print("Load checkpoint from: %s" % args.resume)
|
|
||||||
args = checkpoint['args']
|
|
||||||
args.resume = resume
|
|
||||||
|
|
||||||
misc.init_distributed_mode(args)
|
|
||||||
|
|
||||||
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
|
|
||||||
print("{}".format(args).replace(', ', ',\n'))
|
|
||||||
|
|
||||||
device = torch.device(args.device)
|
|
||||||
|
|
||||||
# fix the seed for reproducibility
|
|
||||||
seed = args.seed + misc.get_rank()
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
|
|
||||||
cudnn.benchmark = True
|
|
||||||
|
|
||||||
if args.model=='RETFound_mae':
|
|
||||||
model = models.__dict__[args.model](
|
|
||||||
img_size=args.input_size,
|
|
||||||
num_classes=args.nb_classes,
|
|
||||||
drop_path_rate=args.drop_path,
|
|
||||||
global_pool=args.global_pool,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model = models.__dict__[args.model](
|
|
||||||
num_classes=args.nb_classes,
|
|
||||||
drop_path_rate=args.drop_path,
|
|
||||||
args=args,
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.finetune and not args.eval:
|
|
||||||
|
|
||||||
print(f"Downloading pre-trained weights from: {args.finetune}")
|
|
||||||
|
|
||||||
checkpoint_path = hf_hub_download(
|
|
||||||
repo_id=f'YukunZhou/{args.finetune}',
|
|
||||||
filename=f'{args.finetune}.pth',
|
|
||||||
)
|
|
||||||
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
|
||||||
print("Load pre-trained checkpoint from: %s" % args.finetune)
|
|
||||||
|
|
||||||
if args.model!='RETFound_mae':
|
|
||||||
checkpoint_model = checkpoint['teacher']
|
|
||||||
else:
|
|
||||||
checkpoint_model = checkpoint['model']
|
|
||||||
|
|
||||||
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
|
|
||||||
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
|
|
||||||
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
|
|
||||||
|
|
||||||
state_dict = model.state_dict()
|
|
||||||
for k in ['head.weight', 'head.bias']:
|
|
||||||
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
|
||||||
print(f"Removing key {k} from pretrained checkpoint")
|
|
||||||
del checkpoint_model[k]
|
|
||||||
|
|
||||||
# interpolate position embedding
|
|
||||||
interpolate_pos_embed(model, checkpoint_model)
|
|
||||||
|
|
||||||
# load pre-trained model
|
|
||||||
msg = model.load_state_dict(checkpoint_model, strict=False)
|
|
||||||
|
|
||||||
trunc_normal_(model.head.weight, std=2e-5)
|
|
||||||
|
|
||||||
dataset_train = build_dataset(is_train='train', args=args)
|
|
||||||
dataset_val = build_dataset(is_train='val', args=args)
|
|
||||||
dataset_test = build_dataset(is_train='test', args=args)
|
|
||||||
|
|
||||||
|
|
||||||
if True: # args.distributed:
|
|
||||||
num_tasks = misc.get_world_size()
|
|
||||||
global_rank = misc.get_rank()
|
|
||||||
if not args.eval:
|
|
||||||
sampler_train = torch.utils.data.DistributedSampler(
|
|
||||||
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
|
||||||
)
|
|
||||||
print("Sampler_train = %s" % str(sampler_train))
|
|
||||||
if args.dist_eval:
|
|
||||||
if len(dataset_val) % num_tasks != 0:
|
|
||||||
print(
|
|
||||||
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
|
||||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
|
||||||
'equal num of samples per-process.')
|
|
||||||
sampler_val = torch.utils.data.DistributedSampler(
|
|
||||||
dataset_val, num_replicas=num_tasks, rank=global_rank,
|
|
||||||
shuffle=True) # shuffle=True to reduce monitor bias
|
|
||||||
else:
|
|
||||||
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
|
||||||
|
|
||||||
if args.dist_eval:
|
|
||||||
if len(dataset_test) % num_tasks != 0:
|
|
||||||
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
|
||||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
|
||||||
'equal num of samples per-process.')
|
|
||||||
sampler_test = torch.utils.data.DistributedSampler(
|
|
||||||
dataset_test, num_replicas=num_tasks, rank=global_rank,
|
|
||||||
shuffle=True) # shuffle=True to reduce monitor bias
|
|
||||||
else:
|
|
||||||
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
|
|
||||||
|
|
||||||
if global_rank == 0 and args.log_dir is not None and not args.eval:
|
|
||||||
os.makedirs(args.log_dir, exist_ok=True)
|
|
||||||
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
|
|
||||||
else:
|
|
||||||
log_writer = None
|
|
||||||
|
|
||||||
if not args.eval:
|
|
||||||
data_loader_train = torch.utils.data.DataLoader(
|
|
||||||
dataset_train, sampler=sampler_train,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
num_workers=args.num_workers,
|
|
||||||
pin_memory=args.pin_mem,
|
|
||||||
drop_last=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
|
|
||||||
|
|
||||||
data_loader_val = torch.utils.data.DataLoader(
|
|
||||||
dataset_val, sampler=sampler_val,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
num_workers=args.num_workers,
|
|
||||||
pin_memory=args.pin_mem,
|
|
||||||
drop_last=False
|
|
||||||
)
|
|
||||||
|
|
||||||
data_loader_test = torch.utils.data.DataLoader(
|
|
||||||
dataset_test, sampler=sampler_test,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
num_workers=args.num_workers,
|
|
||||||
pin_memory=args.pin_mem,
|
|
||||||
drop_last=False
|
|
||||||
)
|
|
||||||
|
|
||||||
mixup_fn = None
|
|
||||||
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
|
|
||||||
if mixup_active:
|
|
||||||
print("Mixup is activated!")
|
|
||||||
mixup_fn = Mixup(
|
|
||||||
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
|
|
||||||
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
|
|
||||||
label_smoothing=args.smoothing, num_classes=args.nb_classes)
|
|
||||||
|
|
||||||
if args.resume and args.eval:
|
|
||||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
|
||||||
print("Load checkpoint from: %s" % args.resume)
|
|
||||||
model.load_state_dict(checkpoint['model'])
|
|
||||||
|
|
||||||
model.to(device)
|
|
||||||
model_without_ddp = model
|
|
||||||
|
|
||||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
||||||
print('number of model params (M): %.2f' % (n_parameters / 1.e6))
|
|
||||||
|
|
||||||
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
|
|
||||||
|
|
||||||
if args.lr is None: # only base_lr is specified
|
|
||||||
args.lr = args.blr * eff_batch_size / 256
|
|
||||||
|
|
||||||
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
|
|
||||||
print("actual lr: %.2e" % args.lr)
|
|
||||||
|
|
||||||
print("accumulate grad iterations: %d" % args.accum_iter)
|
|
||||||
print("effective batch size: %d" % eff_batch_size)
|
|
||||||
|
|
||||||
if args.distributed:
|
|
||||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True)
|
|
||||||
model_without_ddp = model.module
|
|
||||||
|
|
||||||
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
|
|
||||||
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
|
|
||||||
no_weight_decay_list=no_weight_decay,
|
|
||||||
layer_decay=args.layer_decay
|
|
||||||
)
|
|
||||||
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
|
|
||||||
loss_scaler = NativeScaler()
|
|
||||||
|
|
||||||
print("criterion = %s" % str(criterion))
|
|
||||||
|
|
||||||
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
|
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if 'head' in name:
|
|
||||||
param.requires_grad = True
|
|
||||||
else:
|
|
||||||
param.requires_grad = False
|
|
||||||
|
|
||||||
|
|
||||||
if args.eval:
|
|
||||||
if 'epoch' in checkpoint:
|
|
||||||
print("Test with the best model at epoch = %d" % checkpoint['epoch'])
|
|
||||||
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
|
|
||||||
num_class=args.nb_classes, log_writer=log_writer)
|
|
||||||
exit(0)
|
|
||||||
|
|
||||||
print(f"Start training for {args.epochs} epochs")
|
|
||||||
start_time = time.time()
|
|
||||||
max_score = 0.0
|
|
||||||
best_epoch = 0
|
|
||||||
for epoch in range(args.start_epoch, args.epochs):
|
|
||||||
if args.distributed:
|
|
||||||
data_loader_train.sampler.set_epoch(epoch)
|
|
||||||
|
|
||||||
train_stats = train_one_epoch(
|
|
||||||
model, criterion, data_loader_train,
|
|
||||||
optimizer, device, epoch, loss_scaler,
|
|
||||||
args.clip_grad, mixup_fn,
|
|
||||||
log_writer=log_writer,
|
|
||||||
args=args
|
|
||||||
)
|
|
||||||
|
|
||||||
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
|
|
||||||
num_class=args.nb_classes, log_writer=log_writer)
|
|
||||||
if max_score < val_score:
|
|
||||||
max_score = val_score
|
|
||||||
best_epoch = epoch
|
|
||||||
if args.output_dir and args.savemodel:
|
|
||||||
misc.save_model(
|
|
||||||
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
|
||||||
loss_scaler=loss_scaler, epoch=epoch, mode='best')
|
|
||||||
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
|
|
||||||
|
|
||||||
|
|
||||||
if epoch == (args.epochs - 1):
|
|
||||||
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
|
|
||||||
model.load_state_dict(checkpoint['model'], strict=False)
|
|
||||||
model.to(device)
|
|
||||||
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
|
|
||||||
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
|
|
||||||
num_class=args.nb_classes, log_writer=None)
|
|
||||||
|
|
||||||
if log_writer is not None:
|
|
||||||
log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
|
|
||||||
|
|
||||||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
|
||||||
'epoch': epoch,
|
|
||||||
'n_parameters': n_parameters}
|
|
||||||
|
|
||||||
if args.output_dir and misc.is_main_process():
|
|
||||||
if log_writer is not None:
|
|
||||||
log_writer.flush()
|
|
||||||
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(log_stats) + "\n")
|
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
|
||||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
|
||||||
print('Training time {}'.format(total_time_str))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
args = get_args_parser()
|
|
||||||
args = args.parse_args()
|
|
||||||
|
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
if args.output_dir:
|
|
||||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
main(args, criterion)
|
|
||||||
|
|
||||||
|
|
||||||
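When `--lr` is not given, the removed script derives the absolute learning rate from `--blr` with the rule `absolute_lr = base_lr * total_batch_size / 256`, where the effective batch size is `batch_size * accum_iter * world_size`. The sketch below isolates that arithmetic so it can be checked without PyTorch. The helper name `resolve_lr` is hypothetical and does not exist in the repository.

```python
from typing import Optional


def resolve_lr(blr: float, batch_size: int, accum_iter: int = 1,
               world_size: int = 1, lr: Optional[float] = None) -> float:
    """Return the absolute learning rate (hypothetical helper).

    An explicit lr wins; otherwise the base lr is scaled linearly
    with the effective batch size, as in the script.
    """
    if lr is not None:
        return lr
    eff_batch_size = batch_size * accum_iter * world_size
    return blr * eff_batch_size / 256


# Script defaults: blr=5e-3, batch_size=128, one process, no accumulation.
print(resolve_lr(blr=5e-3, batch_size=128))
```

With the smaller batch size used in the example launch script (24), the same base rate yields a proportionally smaller absolute rate, which is worth keeping in mind when comparing runs.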
+41
-5
@@ -1,7 +1,3 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
# Partly revised by YZ @UCL&Moorfields
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
@@ -10,7 +6,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from timm.models.layers import trunc_normal_
|
||||||
|
|
||||||
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
|
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
|
||||||
""" Vision Transformer with support for global average pooling
|
""" Vision Transformer with support for global average pooling
|
||||||
@@ -56,6 +52,30 @@ def RETFound_mae(**kwargs):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def Dinov2(args, **kwargs):
|
||||||
|
|
||||||
|
if args.model_arch == 'dinov2_vits14':
|
||||||
|
arch = 'vit_small_patch14_dinov2.lvd142m'
|
||||||
|
elif args.model_arch == 'dinov2_vitb14':
|
||||||
|
arch = 'vit_base_patch14_dinov2.lvd142m'
|
||||||
|
elif args.model_arch == 'dinov2_vitl14':
|
||||||
|
arch = 'vit_large_patch14_dinov2.lvd142m'
|
||||||
|
elif args.model_arch == 'dinov2_vitg14':
|
||||||
|
arch = 'vit_giant_patch14_dinov2.lvd142m'
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
|
||||||
|
f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
|
||||||
|
|
||||||
|
model = timm.create_model(
|
||||||
|
arch,
|
||||||
|
pretrained=True,
|
||||||
|
img_size=224,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def RETFound_dinov2(args, **kwargs):
|
def RETFound_dinov2(args, **kwargs):
|
||||||
model = timm.create_model(
|
model = timm.create_model(
|
||||||
'vit_large_patch14_dinov2.lvd142m',
|
'vit_large_patch14_dinov2.lvd142m',
|
||||||
@@ -66,4 +86,20 @@ def RETFound_dinov2(args, **kwargs):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def Dinov3(args, **kwargs):
|
||||||
|
# Load ViT-L/16 backbone (hub model has `head = Identity` by default)
|
||||||
|
model = torch.hub.load(
|
||||||
|
repo_or_dir="facebookresearch/dinov3",
|
||||||
|
model=args.model_arch,
|
||||||
|
pretrained=False, # main() will load your checkpoint
|
||||||
|
trust_repo=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Figure out feature dimension for the probe
|
||||||
|
feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
|
||||||
|
model.head = nn.Linear(feat_dim, args.nb_classes)
|
||||||
|
trunc_normal_(model.head.weight, std=2e-5)
|
||||||
|
if model.head.bias is not None:
|
||||||
|
nn.init.zeros_(model.head.bias)
|
||||||
|
|
||||||
|
return model
|
||||||
|
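The `Dinov2` factory added in this file maps `args.model_arch` to a timm model name through an `if`/`elif` chain. As an alternative sketch, and not the author's code, the same mapping can be written as a dictionary lookup, which keeps the error message in sync with the supported names. `DINOV2_TIMM_NAMES` and `resolve_dinov2_arch` are hypothetical names; the timm identifiers are copied from the diff.

```python
# Mapping copied from the if/elif chain in the diff.
DINOV2_TIMM_NAMES = {
    "dinov2_vits14": "vit_small_patch14_dinov2.lvd142m",
    "dinov2_vitb14": "vit_base_patch14_dinov2.lvd142m",
    "dinov2_vitl14": "vit_large_patch14_dinov2.lvd142m",
    "dinov2_vitg14": "vit_giant_patch14_dinov2.lvd142m",
}


def resolve_dinov2_arch(model_arch: str) -> str:
    """Translate a --model_arch value into a timm model name (hypothetical helper)."""
    try:
        return DINOV2_TIMM_NAMES[model_arch]
    except KeyError:
        expected = ", ".join(DINOV2_TIMM_NAMES)
        raise ValueError(
            f"Unknown model_arch '{model_arch}'. Expected one of: {expected}"
        ) from None


print(resolve_dinov2_arch("dinov2_vitl14"))
```

The returned string would then be passed to `timm.create_model` exactly as in the diff; that call is omitted here so the sketch runs without timm installed.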
|||||||
@@ -0,0 +1,28 @@
|
|||||||
|
# ==== Model settings ====
|
||||||
|
# adaptation {finetune,lp}
|
||||||
|
ADAPTATION="finetune"
|
||||||
|
MODEL="RETFound_dinov2"
|
||||||
|
MODEL_ARCH="retfound_dinov2"
|
||||||
|
FINETUNE="RETFound_dinov2_meh"
|
||||||
|
|
||||||
|
# ==== Data settings ====
|
||||||
|
# change the dataset name and corresponding class number
|
||||||
|
DATASET="MESSIDOR2"
|
||||||
|
NUM_CLASS=5
|
||||||
|
data_path="./${DATASET}"
|
||||||
|
task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
|
||||||
|
|
||||||
|
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
|
||||||
|
--model "${MODEL}" \
|
||||||
|
--model_arch "${MODEL_ARCH}" \
|
||||||
|
--finetune "${FINETUNE}" \
|
||||||
|
--savemodel \
|
||||||
|
--global_pool \
|
||||||
|
--batch_size 24 \
|
||||||
|
--world_size 1 \
|
||||||
|
--epochs 50 \
|
||||||
|
--nb_classes "${NUM_CLASS}" \
|
||||||
|
--data_path "${data_path}" \
|
||||||
|
--input_size 224 \
|
||||||
|
--task "${task}" \
|
||||||
|
--adaptation "${ADAPTATION}"
|
||||||
+49
-21
@@ -1,29 +1,39 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
# Partly revised by YZ @UCL&Moorfields
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Subset
|
||||||
from torchvision import datasets, transforms
|
from torchvision import datasets, transforms
|
||||||
from timm.data import create_transform
|
from timm.data import create_transform
|
||||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
|
||||||
|
|
||||||
def build_dataset(is_train, args):
|
def build_dataset(is_train, args):
|
||||||
transform = build_transform(is_train, args)
|
transform = build_transform(is_train, args)
|
||||||
root = os.path.join(args.data_path, is_train)
|
root = os.path.join(args.data_path, is_train)
|
||||||
dataset = datasets.ImageFolder(root, transform=transform)
|
dataset = datasets.ImageFolder(root, transform=transform)
|
||||||
|
|
||||||
return dataset
|
if is_train == 'train':
|
||||||
|
ratio = float(getattr(args, "dataratio", 1.0))
|
||||||
|
seed = int(getattr(args, "seed", 0))
|
||||||
|
stratified = bool(getattr(args, "stratified", False))
|
||||||
|
|
||||||
|
if 0.0 < ratio < 1.0:
|
||||||
|
if stratified:
|
||||||
|
idx = _stratified_indices(dataset.targets, ratio, seed)
|
||||||
|
else:
|
||||||
|
# simple uniform subsample with torch.Generator for reproducibility
|
||||||
|
g = torch.Generator().manual_seed(seed)
|
||||||
|
n = len(dataset)
|
||||||
|
k = max(1, int(n * ratio))
|
||||||
|
idx = torch.randperm(n, generator=g)[:k].tolist()
|
||||||
|
dataset = Subset(dataset, idx)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
def build_transform(is_train, args):
|
def build_transform(is_train, args):
|
||||||
mean = IMAGENET_DEFAULT_MEAN
|
mean = IMAGENET_DEFAULT_MEAN
|
||||||
std = IMAGENET_DEFAULT_STD
|
std = IMAGENET_DEFAULT_STD
|
||||||
# train transform
|
|
||||||
if is_train == 'train':
|
if is_train == 'train':
|
||||||
# this should always dispatch to transforms_imagenet_train
|
return create_transform(
|
||||||
transform = create_transform(
|
|
||||||
input_size=args.input_size,
|
input_size=args.input_size,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
color_jitter=args.color_jitter,
|
color_jitter=args.color_jitter,
|
||||||
@@ -35,19 +45,37 @@ def build_transform(is_train, args):
|
|||||||
mean=mean,
|
mean=mean,
|
||||||
std=std,
|
std=std,
|
||||||
)
|
)
|
||||||
return transform
|
|
||||||
|
|
||||||
# eval transform
|
# eval transform
|
||||||
t = []
|
crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
|
||||||
if args.input_size <= 224:
|
|
||||||
crop_pct = 224 / 256
|
|
||||||
else:
|
|
||||||
crop_pct = 1.0
|
|
||||||
size = int(args.input_size / crop_pct)
|
size = int(args.input_size / crop_pct)
|
||||||
t.append(
|
t = [
|
||||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
|
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
|
||||||
)
|
transforms.CenterCrop(args.input_size),
|
||||||
t.append(transforms.CenterCrop(args.input_size))
|
transforms.ToTensor(),
|
||||||
t.append(transforms.ToTensor())
|
transforms.Normalize(mean, std),
|
||||||
t.append(transforms.Normalize(mean, std))
|
]
|
||||||
return transforms.Compose(t)
|
return transforms.Compose(t)
|
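The evaluation transform resizes each image to `int(input_size / crop_pct)` before the center crop, with `crop_pct = 224 / 256` for inputs up to 224 pixels and `1.0` otherwise. The sketch below reproduces only that arithmetic, so it runs without torchvision. The helper name `eval_resize_size` is hypothetical.

```python
def eval_resize_size(input_size: int) -> int:
    """Side length passed to Resize before CenterCrop (hypothetical helper)."""
    crop_pct = 224 / 256 if input_size <= 224 else 1.0
    return int(input_size / crop_pct)


# Small inputs are resized larger and then cropped; larger inputs are not.
for side in (224, 256):
    print(side, eval_resize_size(side))
```

For inputs above 224 pixels the resize target equals the crop size, so the center crop only trims the longer side of non-square images.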
||||||
|
|
||||||
|
# ---- helpers ----
|
||||||
|
|
||||||
|
def _stratified_indices(targets, ratio: float, seed: int):
|
||||||
|
"""Maintain class proportions. Ensures at least 1 sample per class when possible."""
|
||||||
|
t = torch.as_tensor(targets)
|
||||||
|
classes = torch.unique(t)
|
||||||
|
g = torch.Generator().manual_seed(seed)
|
||||||
|
|
||||||
|
keep = []
|
||||||
|
for c in classes.tolist():
|
||||||
|
cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1)
|
||||||
|
if len(cls_idx) == 0:
|
||||||
|
continue
|
||||||
|
k = max(1, int(round(len(cls_idx) * ratio)))
|
||||||
|
sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]]
|
||||||
|
keep.extend(sel.tolist())
|
||||||
|
|
||||||
|
# shuffle final indices (stable across seed)
|
||||||
|
g2 = torch.Generator().manual_seed(seed + 1)
|
||||||
|
keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist()
|
||||||
|
return keep
|
||||||
|
|
||||||
|
|||||||
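The `_stratified_indices` helper added above keeps roughly `ratio` of the samples from every class, with at least one sample per class. The sketch below illustrates the same idea using only the standard library. It is an assumption-laden stand-in, not the repository's code: it uses `random.Random` instead of `torch.Generator`, so the selected indices will differ from those produced by the diff, and the name `stratified_indices` is hypothetical.

```python
import random
from collections import defaultdict
from typing import List, Sequence


def stratified_indices(targets: Sequence[int], ratio: float, seed: int) -> List[int]:
    """Pick about `ratio` of the indices from each class, at least one per class."""
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)

    rng = random.Random(seed)
    keep: List[int] = []
    for label in sorted(by_class):
        cls_idx = by_class[label]
        # Same rounding rule as the diff: round, then keep at least one sample.
        k = max(1, int(round(len(cls_idx) * ratio)))
        keep.extend(rng.sample(cls_idx, k))

    # Shuffle with a second seeded generator so classes are interleaved.
    random.Random(seed + 1).shuffle(keep)
    return keep


targets = [0] * 80 + [1] * 20
subset = stratified_indices(targets, ratio=0.25, seed=0)
print(len(subset), sum(targets[i] for i in subset))
```

Because each class is sampled separately, a rare class cannot vanish from a small subset, which a uniform `randperm` subsample does not guarantee.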