Compare commits
10 Commits
91915d6a14
...
ae9a9ecf37
| Author | SHA1 | Date | |
|---|---|---|---|
| ae9a9ecf37 | |||
| 8f5b2ce5e7 | |||
| dbbddb8936 | |||
| ed8b469a0f | |||
| 17768be893 | |||
| bda7a6c69f | |||
| 7489af0620 | |||
| 409f7b6167 | |||
| 897d71c8c9 | |||
| a7a9b3a8b7 |
@@ -1,14 +1,15 @@
|
||||
## RETFound - A foundation model for retinal imaging
|
||||
## RETFound - A foundation model for retinal images
|
||||
|
||||
|
||||
Official repo including a series of retinal foundation models.<br>
|
||||
[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae).<br>
|
||||
[New checkpoints](https://huggingface.co/YukunZhou), some of which are based on [DINOV2](https://github.com/facebookresearch/dinov2):
|
||||
Official repo including a series of foundation models and applications for retinal images.<br>
|
||||
`[RETFound-MAE]`:[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x).<br>
|
||||
`[RETFound-DINOv2]`:[Revealing the Impact of Pre-training Data on Medical Foundation Models](https://www.researchsquare.com/article/rs-6080254/v1).<br>
|
||||
`[DINOv2]`:[General-purpose vision foundation models DINOv2 by Meta](https://github.com/facebookresearch/dinov2).<br>
|
||||
`[DINOv3]`:[General-purpose vision foundation models DINOv3 by Meta](https://github.com/facebookresearch/dinov3).<br>
|
||||
|
||||
|
||||
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
|
||||
|
||||
Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE)
|
||||
|
||||
|
||||
### 📝Key features
|
||||
|
||||
@@ -19,13 +20,14 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
|
||||
|
||||
### 🎉News
|
||||
|
||||
- 🐉2025/09: **Preprint benchmarking DINOv3, DINOv2, and RETFound is [available](https://arxiv.org/abs/2509.03421)!**
|
||||
- 🐉2025/09: **We included state-of-the-art DINOv3 into fine-tuning pipeline for retinal applications!**
|
||||
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
|
||||
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!**
|
||||
- 🐉2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!**
|
||||
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) are now online!
|
||||
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
|
||||
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
|
||||
- 2023/10: change the hyperparameter of [input_size](https://github.com/rmaphoh/RETFound_MAE#:~:text=finetune%20./RETFound_cfp_weights.pth%20%5C-,%2D%2Dinput_size%20224,-For%20evaluation%20only) for any image size
|
||||
|
||||
|
||||
### 🔧Install environment
|
||||
@@ -40,17 +42,17 @@ conda activate retfound
|
||||
2. Install dependencies
|
||||
|
||||
```
|
||||
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
|
||||
git clone https://github.com/rmaphoh/RETFound_MAE/
|
||||
cd RETFound_MAE
|
||||
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121
|
||||
git clone https://github.com/rmaphoh/RETFound/
|
||||
cd RETFound
|
||||
pip install -r requirements.txt
|
||||
pip install ipykernel
|
||||
python -m ipykernel install --user --name retfound --display-name "Python (retfound)"
|
||||
```
|
||||
|
||||
|
||||
### 🌱Fine-tuning with RETFound weights
|
||||
|
||||
To fine tune RETFound on your own data, follow these steps:
|
||||
|
||||
1. Get access to the pre-trained models on HuggingFace (register an account and fill in the form) and go to step 2:
|
||||
<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
@@ -71,22 +73,22 @@ To fine tune RETFound on your own data, follow these steps:
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left">RETFound_mae_meh</td>
|
||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
|
||||
<td align="center">TBD</a></td>
|
||||
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||
</tr>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left">RETFound_mae_shanghai</td>
|
||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
|
||||
<td align="center">TBD</a></td>
|
||||
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||
</tr>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left">RETFound_dinov2_meh</td>
|
||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
|
||||
<td align="center">TBD</a></td>
|
||||
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||
</tr>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left">RETFound_dinov2_shanghai</td>
|
||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">access</a></td>
|
||||
<td align="center">TBD</a></td>
|
||||
<td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
|
||||
</tr>
|
||||
</tbody></table>
|
||||
|
||||
@@ -100,7 +102,9 @@ huggingface-cli login --token YOUR_HUGGINGFACE_TOKEN
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
```
|
||||
|
||||
3. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md))
|
||||
3. If you would like to fine-tune [DINOv2](https://github.com/facebookresearch/dinov2) and [DINOv3](https://github.com/facebookresearch/dinov3), please visit their GitHub repositories to download the model weights and put them in the RETFound folder.
|
||||
|
||||
4. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md))
|
||||
|
||||
```
|
||||
├── data folder
|
||||
@@ -118,56 +122,122 @@ export HF_ENDPOINT=https://hf-mirror.com
|
||||
├──class_c
|
||||
```
|
||||
|
||||
4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training.
|
||||
|
||||
The model and finetune can be selected:
|
||||
|
||||
| model | finetune |
|
||||
|-----------------|--------------------------|
|
||||
| RETFound_mae | RETFound_mae_natureCFP |
|
||||
| RETFound_mae | RETFound_mae_natureOCT |
|
||||
| RETFound_mae | RETFound_mae_meh |
|
||||
| RETFound_mae | RETFound_mae_shanghai |
|
||||
| RETFound_dinov2 | RETFound_dinov2_meh |
|
||||
| RETFound_dinov2 | RETFound_dinov2_shanghai |
|
||||
5. Start fine-tuning by running `sh train.sh`.
|
||||
|
||||
|
||||
In `train.sh`, the model can be selected by changing the hyperparameters `MODEL`, `MODEL_ARCH`, `FINETUNE`:
|
||||
|
||||
**RETFound**:
|
||||
|
||||
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|
||||
|-----------------|--------------------------|--------------------------|--------------------------|
|
||||
| RETFound_mae | retfound_mae | RETFound_mae_natureCFP | ~300M |
|
||||
| RETFound_mae | retfound_mae | RETFound_mae_natureOCT | ~300M |
|
||||
| RETFound_mae | retfound_mae | RETFound_mae_meh | ~300M |
|
||||
| RETFound_mae | retfound_mae | RETFound_mae_shanghai | ~300M |
|
||||
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_meh | ~300M |
|
||||
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_shanghai | ~300M |
|
||||
|
||||
|
||||
**DINOv3**:
|
||||
|
||||
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|
||||
|-----------------|--------------------------|----------------------------------|--------------------------|
|
||||
| Dinov3 | dinov3_vits16 | dinov3_vits16_pretrain.pth | ~21M |
|
||||
| Dinov3 | dinov3_vits16plus | dinov3_vits16plus_pretrain.pth | ~29M |
|
||||
| Dinov3 | dinov3_vitb16 | dinov3_vitb16_pretrain.pth | ~86M |
|
||||
| Dinov3 | dinov3_vitl16 | dinov3_vitl16_pretrain.pth | ~300M |
|
||||
| Dinov3 | dinov3_vith16plus | dinov3_vith16plus_pretrain.pth | ~840M |
|
||||
| Dinov3 | dinov3_vit7b16 | dinov3_vit7b16_pretrain.pth | ~6.7B |
|
||||
|
||||
|
||||
**DINOv2**:
|
||||
|
||||
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|
||||
|-----------------|--------------------------|------------------------------|--------------------------|
|
||||
| Dinov2 | dinov2_vits14 | dinov2_vits14_pretrain.pth | ~21M |
|
||||
| Dinov2 | dinov2_vitb14 | dinov2_vitb14_pretrain.pth | ~86M |
|
||||
| Dinov2 | dinov2_vitl14 | dinov2_vitl14_pretrain.pth | ~300M |
|
||||
| Dinov2 | dinov2_vitg14 | dinov2_vitg14_pretrain.pth | ~1.1B |
|
||||
|
||||
|
||||
Change the DATA_PATH to your dataset directory.
|
||||
|
||||
```
|
||||
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
||||
--model RETFound_mae \
|
||||
--savemodel \
|
||||
--global_pool \
|
||||
--batch_size 16 \
|
||||
--world_size 1 \
|
||||
--epochs 100 \
|
||||
--blr 5e-3 --layer_decay 0.65 \
|
||||
--weight_decay 0.05 --drop_path 0.2 \
|
||||
--nb_classes 5 \
|
||||
--data_path ./IDRiD \
|
||||
--input_size 224 \
|
||||
--task RETFound_mae_meh-IDRiD \
|
||||
--finetune RETFound_mae_meh
|
||||
# ==== Model settings ====
|
||||
# adaptation {finetune,lp}
|
||||
ADAPTATION="finetune"
|
||||
MODEL="RETFound_dinov2"
|
||||
MODEL_ARCH="retfound_dinov2"
|
||||
FINETUNE="RETFound_dinov2_meh"
|
||||
|
||||
# ==== Data settings ====
|
||||
# change the dataset name and corresponding class number
|
||||
DATASET="MESSIDOR2"
|
||||
NUM_CLASS=5
|
||||
|
||||
# =======================
|
||||
DATA_PATH="PATH TO THE DATASET"
|
||||
TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
|
||||
|
||||
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
|
||||
--model "${MODEL}" \
|
||||
--model_arch "${MODEL_ARCH}" \
|
||||
--finetune "${FINETUNE}" \
|
||||
--savemodel \
|
||||
--global_pool \
|
||||
--batch_size 24 \
|
||||
--world_size 1 \
|
||||
--epochs 50 \
|
||||
--nb_classes "${NUM_CLASS}" \
|
||||
--data_path "${DATA_PATH}" \
|
||||
--input_size 224 \
|
||||
--task "${TASK}" \
|
||||
--adaptation "${ADAPTATION}"
|
||||
|
||||
```
|
||||
|
||||
|
||||
4. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the path below)
|
||||
|
||||
6. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the DATA_PATH below)
|
||||
|
||||
|
||||
```
|
||||
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
||||
--model RETFound_mae \
|
||||
--savemodel \
|
||||
--eval \
|
||||
--global_pool \
|
||||
--batch_size 16 \
|
||||
--world_size 1 \
|
||||
--epochs 100 \
|
||||
--blr 5e-3 --layer_decay 0.65 \
|
||||
--weight_decay 0.05 --drop_path 0.2 \
|
||||
--nb_classes 5 \
|
||||
--data_path ./IDRiD \
|
||||
--input_size 224 \
|
||||
--task RETFound_mae_meh-IDRiD \
|
||||
--resume ./RETFound_mae_meh-IDRiD/checkpoint-best.pth
|
||||
# ==== Model/settings (match training) ====
|
||||
ADAPTATION="finetune"
|
||||
MODEL="RETFound_dinov2"
|
||||
MODEL_ARCH="retfound_dinov2"
|
||||
FINETUNE="RETFound_dinov2_meh"
|
||||
|
||||
# ==== Data/settings (match training) ====
|
||||
DATASET="MESSIDOR2"
|
||||
NUM_CLASS=5
|
||||
|
||||
# =======================
|
||||
DATA_PATH="PATH TO THE DATASET"
|
||||
TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
|
||||
|
||||
# Path to the trained checkpoint (adjust if you saved elsewhere)
|
||||
CKPT="./output_dir/${TASK}/checkpoint-best.pth"
|
||||
|
||||
# ==== Evaluation only ====
|
||||
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
|
||||
--model "${MODEL}" \
|
||||
--model_arch "${MODEL_ARCH}" \
|
||||
--savemodel \
|
||||
--global_pool \
|
||||
--batch_size 128 \
|
||||
--world_size 1 \
|
||||
--nb_classes "${NUM_CLASS}" \
|
||||
--data_path "${DATA_PATH}" \
|
||||
--input_size 224 \
|
||||
--task "${TASK}" \
|
||||
--adaptation "${ADAPTATION}" \
|
||||
--eval \
|
||||
--resume "${CKPT}"
|
||||
|
||||
```
|
||||
|
||||
|
||||
@@ -175,9 +245,6 @@ torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
||||
|
||||
If you find this repository useful, please consider citing this paper:
|
||||
|
||||
```
|
||||
TBD
|
||||
```
|
||||
|
||||
```
|
||||
@article{zhou2023foundation,
|
||||
@@ -192,4 +259,14 @@ TBD
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
```
|
||||
@misc{zhou2025generalistversusspecialistvision,
|
||||
title={Generalist versus Specialist Vision Foundation Models for Ocular Disease and Oculomics},
|
||||
author={Yukun Zhou and Paul Nderitu and Jocelyn Hui Lin Goh and Justin Engelmann and Siegfried K. Wagner and Anran Ran and Hongyang Jiang and Lie Ju and Ke Zou and Sahana Srinivasan and Hyunmin Kim and Takahiro Ninomiya and Zheyuan Wang and Gabriel Dawei Yang and Eden Ruffell and Dominic Williamson and Rui Santos and Gabor Mark Somfai and Carol Y. Cheung and Tien Yin Wong and Daniel C. Alexander and Yih Chung Tham and Pearse A. Keane},
|
||||
year={2025},
|
||||
eprint={2509.03421},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={eess.IV},
|
||||
url={https://arxiv.org/abs/2509.03421},
|
||||
}
|
||||
```
|
||||
|
||||
@@ -0,0 +1,223 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "76b39fb1",
|
||||
"metadata": {
|
||||
"jp-MarkdownHeadingCollapsed": true
|
||||
},
|
||||
"source": [
|
||||
"## Jupyter notebook example - Classification task\n",
|
||||
"### Example using [MESSIDOR2](https://www.adcis.net/en/third-party/messidor2/) dataset\n",
|
||||
"**Application**: Using RETFound for five-category diabetic retinopathy classification\n",
|
||||
"\n",
|
||||
"**Author**: Yukun Zhou\n",
|
||||
"\n",
|
||||
"**Date**: 30 Nov 2025\n",
|
||||
"\n",
|
||||
"**Performance**:\n",
|
||||
"\n",
|
||||
"<table align=\"left\">\n",
|
||||
"<tr>\n",
|
||||
" <th>Accuracy</th>\n",
|
||||
" <th>Recall</th>\n",
|
||||
" <th>F1 Score</th>\n",
|
||||
" <th>ROC AUC</th>\n",
|
||||
" <th>PR AUC</th>\n",
|
||||
"</tr>\n",
|
||||
"<tr>\n",
|
||||
" <td>0.7091</td>\n",
|
||||
" <td>0.5616</td>\n",
|
||||
" <td>0.6078</td>\n",
|
||||
" <td>0.9037</td>\n",
|
||||
" <td>0.6863</td>\n",
|
||||
"</tr>\n",
|
||||
"</table>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7ec435a7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Install environment\n",
|
||||
"1. Follow [RETFound README](https://github.com/rmaphoh/RETFound) to install environment\n",
|
||||
"2. Restart this Jupyter Notebook\n",
|
||||
"3. Select Kernel retfound"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7cbf5e93-6ca0-4401-88e6-64e39968e7cd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys, torch\n",
|
||||
"from pathlib import Path\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"PROJECT_ROOT = Path.cwd().resolve()\n",
|
||||
"\n",
|
||||
"if PROJECT_ROOT.name == 'examples': PROJECT_ROOT = PROJECT_ROOT.parent\n",
|
||||
"os.chdir(PROJECT_ROOT)\n",
|
||||
"\n",
|
||||
"print('Project root:', PROJECT_ROOT)\n",
|
||||
"print(\"sys.executable:\", sys.executable)\n",
|
||||
"print(\"torch version:\", torch.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ed67953f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Prepare MESSIDOR2 dataset\n",
|
||||
"1. Download from the [shared data pool](https://github.com/rmaphoh/RETFound/blob/main/BENCHMARK.md).\n",
|
||||
"2. Put the data folder under the project directory, e.g. \"RETFound/MESSIDOR2\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "357be2fa-a914-4d1f-8759-76b2b1c3f20f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. Hyperparameter and path settings\n",
|
||||
"1. Can choose finetune or lp (linear probe)\n",
|
||||
"2. Model selection [info](https://github.com/rmaphoh/RETFound#:~:text=In%20train.sh%2C%20the%20model%20can%20be%20selected%20by%20changing%20the%20hyperparameters%20MODEL%2C%20MODEL_ARCH%2C%20FINETUNE%3A)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5f675843",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"ADAPTATION='finetune'\n",
|
||||
"MODEL='RETFound_dinov2'\n",
|
||||
"MODEL_ARCH='retfound_dinov2'\n",
|
||||
"FINETUNE='RETFound_dinov2_meh'\n",
|
||||
"DATASET='MESSIDOR2'\n",
|
||||
"NUM_CLASS=5\n",
|
||||
"DATA_PATH=PROJECT_ROOT/DATASET\n",
|
||||
"BATCH_SIZE=24\n",
|
||||
"EPOCHS=50\n",
|
||||
"INPUT_SIZE=224\n",
|
||||
"WORLD_SIZE=1\n",
|
||||
"TASK=f\"{MODEL_ARCH}_{DATASET}_{ADAPTATION}\"\n",
|
||||
"OUTPUT_DIR=PROJECT_ROOT/'output_dir'/TASK\n",
|
||||
"print('DATA_PATH:',DATA_PATH)\n",
|
||||
"print('TASK:',TASK)\n",
|
||||
"print('OUTPUT_DIR:',OUTPUT_DIR)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6ac04845",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Fine-tuning and testing RETFound on MESSIDOR2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d23ff751",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"!{sys.executable} main_finetune.py \\\n",
|
||||
" --model {MODEL} \\\n",
|
||||
" --model_arch {MODEL_ARCH} \\\n",
|
||||
" --finetune {FINETUNE} \\\n",
|
||||
" --savemodel \\\n",
|
||||
" --global_pool \\\n",
|
||||
" --batch_size {BATCH_SIZE} \\\n",
|
||||
" --epochs {EPOCHS} \\\n",
|
||||
" --nb_classes {NUM_CLASS} \\\n",
|
||||
" --data_path {DATA_PATH} \\\n",
|
||||
" --input_size {INPUT_SIZE} \\\n",
|
||||
" --task {TASK} \\\n",
|
||||
" --adaptation {ADAPTATION}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "84ce93ac",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Evaluation-only"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0af0f8a7",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"CKPT = OUTPUT_DIR / \"checkpoint-best.pth\"\n",
|
||||
"\n",
|
||||
"!{sys.executable} main_finetune.py \\\n",
|
||||
" --model {MODEL} \\\n",
|
||||
" --model_arch {MODEL_ARCH} \\\n",
|
||||
" --finetune {FINETUNE} \\\n",
|
||||
" --savemodel \\\n",
|
||||
" --global_pool \\\n",
|
||||
" --batch_size 128 \\\n",
|
||||
" --nb_classes {NUM_CLASS} \\\n",
|
||||
" --data_path {DATA_PATH} \\\n",
|
||||
" --input_size {INPUT_SIZE} \\\n",
|
||||
" --task {TASK} \\\n",
|
||||
" --adaptation {ADAPTATION} \\\n",
|
||||
" --eval \\\n",
|
||||
" --resume {CKPT}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "02d2dce7-31c2-48e2-87ce-9223b74cf94e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"environment": {
|
||||
"kernel": "retfound",
|
||||
"name": "workbench-notebooks.m128",
|
||||
"type": "gcloud",
|
||||
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m128"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "retfound_jupyter (Local)",
|
||||
"language": "python",
|
||||
"name": "retfound"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -22,18 +22,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": null,
|
||||
"id": "90c3d964",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
|
||||
"def prepare_model(chkpt_dir, arch='RETFound_mae'):\n",
|
||||
" \n",
|
||||
" # load model\n",
|
||||
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
|
||||
" checkpoint = torch.load(chkpt_dir, map_location='cpu', weights_only=False)\n",
|
||||
" \n",
|
||||
" # build model\n",
|
||||
" if arch=='vit_large_patch16':\n",
|
||||
" if arch=='RETFound_mae':\n",
|
||||
" model = models.__dict__[arch](\n",
|
||||
" img_size=224,\n",
|
||||
" num_classes=5,\n",
|
||||
@@ -70,7 +70,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": null,
|
||||
"id": "9a250363",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -78,7 +78,7 @@
|
||||
"def get_feature(data_path,\n",
|
||||
" chkpt_dir,\n",
|
||||
" device,\n",
|
||||
" arch='vit_large_patch16'):\n",
|
||||
" arch='RETFound_mae'):\n",
|
||||
" #loading model\n",
|
||||
" model_ = prepare_model(chkpt_dir, arch)\n",
|
||||
" model_.to(device)\n",
|
||||
@@ -121,7 +121,7 @@
|
||||
"source": [
|
||||
"chkpt_dir = hf_hub_download(repo_id=\"YukunZhou/RETFound_dinov2_meh\", filename=\"RETFound_dinov2_meh.pth\")\n",
|
||||
"data_path = 'DATA_PATH'\n",
|
||||
"device = torch.device('cuda')\n",
|
||||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
"arch='dinov2_large'"
|
||||
]
|
||||
},
|
||||
|
||||
+285
-244
@@ -1,180 +1,174 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# =========================
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import faulthandler
|
||||
|
||||
# =========================
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from timm.models.layers import trunc_normal_
|
||||
from timm.data.mixup import Mixup
|
||||
from huggingface_hub import hf_hub_download, login # login imported as in original
|
||||
|
||||
# =========================
|
||||
import models_vit as models
|
||||
import util.lr_decay as lrd
|
||||
import util.misc as misc
|
||||
from util.datasets import build_dataset
|
||||
from util.pos_embed import interpolate_pos_embed
|
||||
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
||||
from huggingface_hub import hf_hub_download, login
|
||||
from engine_finetune import train_one_epoch, evaluate
|
||||
|
||||
import warnings
|
||||
import faulthandler
|
||||
|
||||
# =========================
|
||||
faulthandler.enable()
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
|
||||
parser.add_argument('--batch_size', default=128, type=int,
|
||||
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
|
||||
parser.add_argument('--epochs', default=50, type=int)
|
||||
parser.add_argument('--accum_iter', default=1, type=int,
|
||||
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
|
||||
parser = argparse.ArgumentParser(
|
||||
"MAE fine-tuning / linear probing for image classification", add_help=False
|
||||
)
|
||||
|
||||
# Model parameters
|
||||
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
|
||||
help='Name of model to train')
|
||||
parser.add_argument('--input_size', default=256, type=int,
|
||||
help='images input size')
|
||||
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
|
||||
help='Drop path rate (default: 0.1)')
|
||||
# ---- Core training
|
||||
parser.add_argument("--batch_size", default=128, type=int,
|
||||
help="Batch size per GPU (effective batch size = batch_size * accum_iter * #gpus)")
|
||||
parser.add_argument("--epochs", default=50, type=int)
|
||||
parser.add_argument("--accum_iter", default=1, type=int,
|
||||
help="Gradient accumulation steps")
|
||||
|
||||
# Optimizer parameters
|
||||
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
|
||||
help='Clip gradient norm (default: None, no clipping)')
|
||||
parser.add_argument('--weight_decay', type=float, default=0.05,
|
||||
help='weight decay (default: 0.05)')
|
||||
parser.add_argument('--lr', type=float, default=None, metavar='LR',
|
||||
help='learning rate (absolute lr)')
|
||||
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
|
||||
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
|
||||
parser.add_argument('--layer_decay', type=float, default=0.65,
|
||||
help='layer-wise lr decay from ELECTRA/BEiT')
|
||||
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
|
||||
help='lower lr bound for cyclic schedulers that hit 0')
|
||||
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
|
||||
help='epochs to warmup LR')
|
||||
# ---- Model parameters
|
||||
parser.add_argument("--model", default="vit_large_patch16", type=str, metavar="MODEL",
|
||||
help="Model entry in models_vit.py")
|
||||
parser.add_argument("--model_arch", default="dinov3_vits16", type=str, metavar="MODEL_ARCH",
|
||||
help="Backbone architecture key (e.g., dinov2_vitl14, convnext_base, etc.)")
|
||||
parser.add_argument("--input_size", default=256, type=int, help="Image size")
|
||||
parser.add_argument("--drop_path", type=float, default=0.2, metavar="PCT", help="Drop path rate")
|
||||
parser.add_argument("--global_pool", action="store_true"); parser.set_defaults(global_pool=True)
|
||||
parser.add_argument("--cls_token", action="store_false", dest="global_pool",
|
||||
help="Use class token instead of global pool for classification")
|
||||
|
||||
# Augmentation parameters
|
||||
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
|
||||
help='Color jitter factor (enabled only when not using Auto/RandAug)')
|
||||
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
|
||||
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
|
||||
parser.add_argument('--smoothing', type=float, default=0.1,
|
||||
help='Label smoothing (default: 0.1)')
|
||||
# ---- Optimizer parameters
|
||||
parser.add_argument("--clip_grad", type=float, default=None, metavar="NORM", help="Clip grad norm")
|
||||
parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay")
|
||||
parser.add_argument("--lr", type=float, default=None, metavar="LR", help="Absolute LR (overrides blr)")
|
||||
parser.add_argument("--blr", type=float, default=5e-3, metavar="LR",
|
||||
help="Base LR: lr = blr * total_batch_size / 256")
|
||||
parser.add_argument("--layer_decay", type=float, default=0.65, help="Layer-wise LR decay (ViT)")
|
||||
parser.add_argument("--min_lr", type=float, default=1e-6, metavar="LR", help="Lower LR bound")
|
||||
parser.add_argument("--warmup_epochs", type=int, default=10, metavar="N", help="Warmup epochs")
|
||||
|
||||
# * Random Erase params
|
||||
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
|
||||
help='Random erase prob (default: 0.25)')
|
||||
parser.add_argument('--remode', type=str, default='pixel',
|
||||
help='Random erase mode (default: "pixel")')
|
||||
parser.add_argument('--recount', type=int, default=1,
|
||||
help='Random erase count (default: 1)')
|
||||
parser.add_argument('--resplit', action='store_true', default=False,
|
||||
help='Do not random erase first (clean) augmentation split')
|
||||
# ---- Augmentation
|
||||
parser.add_argument("--color_jitter", type=float, default=None, metavar="PCT")
|
||||
parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1", metavar="NAME")
|
||||
parser.add_argument("--smoothing", type=float, default=0.1)
|
||||
|
||||
# * Mixup params
|
||||
parser.add_argument('--mixup', type=float, default=0,
|
||||
help='mixup alpha, mixup enabled if > 0.')
|
||||
parser.add_argument('--cutmix', type=float, default=0,
|
||||
help='cutmix alpha, cutmix enabled if > 0.')
|
||||
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
|
||||
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
|
||||
parser.add_argument('--mixup_prob', type=float, default=1.0,
|
||||
help='Probability of performing mixup or cutmix when either/both is enabled')
|
||||
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
|
||||
help='Probability of switching to cutmix when both mixup and cutmix enabled')
|
||||
parser.add_argument('--mixup_mode', type=str, default='batch',
|
||||
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
|
||||
# ---- Random erase
|
||||
parser.add_argument("--reprob", type=float, default=0.25, metavar="PCT")
|
||||
parser.add_argument("--remode", type=str, default="pixel")
|
||||
parser.add_argument("--recount", type=int, default=1)
|
||||
parser.add_argument("--resplit", action="store_true", default=False)
|
||||
|
||||
# * Finetuning params
|
||||
parser.add_argument('--finetune', default='', type=str,
|
||||
help='finetune from checkpoint')
|
||||
parser.add_argument('--task', default='', type=str,
|
||||
help='finetune from checkpoint')
|
||||
parser.add_argument('--global_pool', action='store_true')
|
||||
parser.set_defaults(global_pool=True)
|
||||
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
|
||||
help='Use class token instead of global pool for classification')
|
||||
# ---- Mixup/Cutmix
|
||||
parser.add_argument("--mixup", type=float, default=0.0)
|
||||
parser.add_argument("--cutmix", type=float, default=0.0)
|
||||
parser.add_argument("--cutmix_minmax", type=float, nargs="+", default=None)
|
||||
parser.add_argument("--mixup_prob", type=float, default=1.0)
|
||||
parser.add_argument("--mixup_switch_prob", type=float, default=0.5)
|
||||
parser.add_argument("--mixup_mode", type=str, default="batch")
|
||||
|
||||
# Dataset parameters
|
||||
parser.add_argument('--data_path', default='./data/', type=str,
|
||||
help='dataset path')
|
||||
parser.add_argument('--nb_classes', default=8, type=int,
|
||||
help='number of the classification types')
|
||||
parser.add_argument('--output_dir', default='./output_dir',
|
||||
help='path where to save, empty for no saving')
|
||||
parser.add_argument('--log_dir', default='./output_logs',
|
||||
help='path where to tensorboard log')
|
||||
parser.add_argument('--device', default='cuda',
|
||||
help='device to use for training / testing')
|
||||
parser.add_argument('--seed', default=0, type=int)
|
||||
parser.add_argument('--resume', default='',
|
||||
help='resume from checkpoint')
|
||||
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
|
||||
help='start epoch')
|
||||
parser.add_argument('--eval', action='store_true',
|
||||
help='Perform evaluation only')
|
||||
parser.add_argument('--dist_eval', action='store_true', default=False,
|
||||
help='Enabling distributed evaluation (recommended during training for faster monitor')
|
||||
parser.add_argument('--num_workers', default=10, type=int)
|
||||
parser.add_argument('--pin_mem', action='store_true',
|
||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
||||
parser.set_defaults(pin_mem=True)
|
||||
# ---- Finetuning & adaptation
|
||||
parser.add_argument("--finetune", default="", type=str, help="Checkpoint id/path (see model rules below)")
|
||||
parser.add_argument("--task", default="", type=str, help="Task name for logging/output grouping")
|
||||
parser.add_argument("--adaptation", default="finetune", choices=["finetune", "lp"],
|
||||
help="Adaptation strategy: finetune=full fine-tune, lp=linear probe (train head only)")
|
||||
|
||||
# distributed training parameters
|
||||
parser.add_argument('--world_size', default=1, type=int,
|
||||
help='number of distributed processes')
|
||||
parser.add_argument('--local_rank', default=-1, type=int)
|
||||
parser.add_argument('--dist_on_itp', action='store_true')
|
||||
parser.add_argument('--dist_url', default='env://',
|
||||
help='url used to set up distributed training')
|
||||
# ---- Dataset & paths
|
||||
parser.add_argument("--data_path", default="./data/", type=str)
|
||||
parser.add_argument("--nb_classes", default=8, type=int)
|
||||
parser.add_argument("--output_dir", default="./output_dir")
|
||||
parser.add_argument("--log_dir", default="./output_logs")
|
||||
|
||||
# fine-tuning parameters
|
||||
parser.add_argument('--savemodel', action='store_true', default=True,
|
||||
help='Save model')
|
||||
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
|
||||
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
|
||||
parser.add_argument('--datasets_seed', default=2026, type=int)
|
||||
# >>> NEW: training data efficiency <<<
|
||||
parser.add_argument(
|
||||
"--dataratio", type=str, default="1.0",
|
||||
help=('Training data ratio(s) for subsampling in build_dataset. '
|
||||
'Use a single float in (0,1] (e.g., 0.25) or a comma-separated list '
|
||||
'(e.g., "1.0,0.5,0.25") if your build_dataset supports sweeps.')
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stratified", action="store_true",
|
||||
help="If set, subsample training data in a class-stratified manner (requires support in build_dataset)."
|
||||
)
|
||||
|
||||
# ---- Runtime
|
||||
parser.add_argument("--device", default="cuda")
|
||||
parser.add_argument("--seed", default=0, type=int)
|
||||
parser.add_argument("--resume", default="", help="Resume full state (optimizer, scaler, etc.)")
|
||||
parser.add_argument("--start_epoch", default=0, type=int, metavar="N")
|
||||
parser.add_argument("--eval", action="store_true", help="Evaluation only")
|
||||
parser.add_argument("--dist_eval", action="store_true", default=False,
|
||||
help="Distributed evaluation (faster monitoring during training)")
|
||||
parser.add_argument("--num_workers", default=10, type=int)
|
||||
parser.add_argument("--pin_mem", action="store_true"); parser.set_defaults(pin_mem=True)
|
||||
|
||||
# ---- Distributed
|
||||
parser.add_argument("--world_size", default=1, type=int)
|
||||
parser.add_argument("--local_rank", default=-1, type=int)
|
||||
parser.add_argument("--dist_on_itp", action="store_true")
|
||||
parser.add_argument("--dist_url", default="env://")
|
||||
|
||||
# ---- Misc
|
||||
parser.add_argument("--savemodel", action="store_true", default=True, help="Save best model")
|
||||
parser.add_argument("--norm", default="IMAGENET", type=str)
|
||||
parser.add_argument("--enhance", action="store_true", default=False)
|
||||
parser.add_argument("--datasets_seed", default=2026, type=int)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
# =========================
|
||||
# Main
|
||||
# =========================
|
||||
def main(args, criterion):
|
||||
# ---- Optionally load args from resume (when training)
|
||||
if args.resume and not args.eval:
|
||||
resume = args.resume
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
print("Load checkpoint from: %s" % args.resume)
|
||||
args = checkpoint['args']
|
||||
args.resume = resume
|
||||
resume_path = args.resume
|
||||
checkpoint = torch.load(args.resume, map_location="cpu")
|
||||
print(f"Load checkpoint (args) from: {args.resume}")
|
||||
args = checkpoint["args"]
|
||||
args.resume = resume_path
|
||||
|
||||
# ---- Distributed setup
|
||||
misc.init_distributed_mode(args)
|
||||
|
||||
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
|
||||
print("{}".format(args).replace(', ', ',\n'))
|
||||
print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}")
|
||||
print(f"{args}".replace(", ", ",\n"))
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
# fix the seed for reproducibility
|
||||
# ---- Reproducibility
|
||||
seed = args.seed + misc.get_rank()
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
cudnn.benchmark = True
|
||||
|
||||
if args.model=='RETFound_mae':
|
||||
# ---- Build model
|
||||
if args.model == "RETFound_mae":
|
||||
model = models.__dict__[args.model](
|
||||
img_size=args.input_size,
|
||||
num_classes=args.nb_classes,
|
||||
drop_path_rate=args.drop_path,
|
||||
global_pool=args.global_pool,
|
||||
)
|
||||
img_size=args.input_size,
|
||||
num_classes=args.nb_classes,
|
||||
drop_path_rate=args.drop_path,
|
||||
global_pool=args.global_pool,
|
||||
)
|
||||
else:
|
||||
model = models.__dict__[args.model](
|
||||
num_classes=args.nb_classes,
|
||||
@@ -182,168 +176,210 @@ def main(args, criterion):
|
||||
args=args,
|
||||
)
|
||||
|
||||
# ---- Load pre-trained weights (if requested and not eval-only)
|
||||
if args.finetune and not args.eval:
|
||||
print(f"Preparing to load pre-trained weights: {args.finetune}")
|
||||
|
||||
print(f"Downloading pre-trained weights from: {args.finetune}")
|
||||
|
||||
checkpoint_path = hf_hub_download(
|
||||
repo_id=f'YukunZhou/{args.finetune}',
|
||||
filename=f'{args.finetune}.pth',
|
||||
)
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
print("Load pre-trained checkpoint from: %s" % args.finetune)
|
||||
|
||||
if args.model!='RETFound_mae':
|
||||
checkpoint_model = checkpoint['teacher']
|
||||
if args.model in ["Dinov3", "Dinov2"]:
|
||||
checkpoint_path = args.finetune # local path
|
||||
elif args.model in ["RETFound_dinov2", "RETFound_mae"]:
|
||||
print(f"Downloading pre-trained weights from Hugging Face Hub: {args.finetune}")
|
||||
checkpoint_path = hf_hub_download(
|
||||
repo_id=f"YukunZhou/{args.finetune}",
|
||||
filename=f"{args.finetune}.pth",
|
||||
)
|
||||
else:
|
||||
checkpoint_model = checkpoint['model']
|
||||
raise ValueError(
|
||||
f"Unsupported model '{args.model}'. "
|
||||
f"Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae"
|
||||
)
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
print(f"Loaded pre-trained checkpoint from: {checkpoint_path}")
|
||||
|
||||
if args.model in ["Dinov3", "Dinov2"]:
|
||||
checkpoint_model = checkpoint
|
||||
elif args.model == "RETFound_dinov2":
|
||||
checkpoint_model = checkpoint["teacher"]
|
||||
else: # RETFound_mae
|
||||
checkpoint_model = checkpoint["model"]
|
||||
|
||||
# -- Key hygiene
|
||||
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
|
||||
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
|
||||
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
|
||||
|
||||
# -- Remove classifier if shape mismatched
|
||||
state_dict = model.state_dict()
|
||||
for k in ['head.weight', 'head.bias']:
|
||||
for k in ["head.weight", "head.bias"]:
|
||||
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
||||
print(f"Removing key {k} from pretrained checkpoint")
|
||||
del checkpoint_model[k]
|
||||
|
||||
# interpolate position embedding
|
||||
# -- Interpolate pos embed (ViT)
|
||||
interpolate_pos_embed(model, checkpoint_model)
|
||||
|
||||
# load pre-trained model
|
||||
msg = model.load_state_dict(checkpoint_model, strict=False)
|
||||
# -- Load backbone weights (non-strict)
|
||||
_ = model.load_state_dict(checkpoint_model, strict=False)
|
||||
|
||||
trunc_normal_(model.head.weight, std=2e-5)
|
||||
# -- Re-init head
|
||||
if hasattr(model, "head") and hasattr(model.head, "weight"):
|
||||
trunc_normal_(model.head.weight, std=2e-5)
|
||||
|
||||
dataset_train = build_dataset(is_train='train', args=args)
|
||||
dataset_val = build_dataset(is_train='val', args=args)
|
||||
dataset_test = build_dataset(is_train='test', args=args)
|
||||
# ---- Datasets & samplers
|
||||
dataset_train = build_dataset(is_train="train", args=args)
|
||||
dataset_val = build_dataset(is_train="val", args=args)
|
||||
dataset_test = build_dataset(is_train="test", args=args)
|
||||
|
||||
num_tasks = misc.get_world_size()
|
||||
global_rank = misc.get_rank()
|
||||
|
||||
if True: # args.distributed:
|
||||
num_tasks = misc.get_world_size()
|
||||
global_rank = misc.get_rank()
|
||||
if not args.eval:
|
||||
sampler_train = torch.utils.data.DistributedSampler(
|
||||
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||
)
|
||||
print("Sampler_train = %s" % str(sampler_train))
|
||||
if args.dist_eval:
|
||||
if len(dataset_val) % num_tasks != 0:
|
||||
print(
|
||||
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
||||
'equal num of samples per-process.')
|
||||
sampler_val = torch.utils.data.DistributedSampler(
|
||||
dataset_val, num_replicas=num_tasks, rank=global_rank,
|
||||
shuffle=True) # shuffle=True to reduce monitor bias
|
||||
else:
|
||||
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
||||
|
||||
if not args.eval:
|
||||
sampler_train = torch.utils.data.DistributedSampler(
|
||||
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||
)
|
||||
print(f"Sampler_train = {sampler_train}")
|
||||
if args.dist_eval:
|
||||
if len(dataset_test) % num_tasks != 0:
|
||||
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
||||
'equal num of samples per-process.')
|
||||
sampler_test = torch.utils.data.DistributedSampler(
|
||||
dataset_test, num_replicas=num_tasks, rank=global_rank,
|
||||
shuffle=True) # shuffle=True to reduce monitor bias
|
||||
if len(dataset_val) % num_tasks != 0:
|
||||
print("Warning: dist eval with dataset not divisible by #procs; results may differ slightly.")
|
||||
sampler_val = torch.utils.data.DistributedSampler(
|
||||
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||
)
|
||||
else:
|
||||
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
|
||||
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
||||
|
||||
if args.dist_eval:
|
||||
if len(dataset_test) % num_tasks != 0:
|
||||
print("Warning: dist eval test set not divisible by #procs; results may differ slightly.")
|
||||
sampler_test = torch.utils.data.DistributedSampler(
|
||||
dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||
)
|
||||
else:
|
||||
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
|
||||
|
||||
# ---- Logging
|
||||
if global_rank == 0 and args.log_dir is not None and not args.eval:
|
||||
os.makedirs(args.log_dir, exist_ok=True)
|
||||
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
|
||||
else:
|
||||
log_writer = None
|
||||
|
||||
# ---- DataLoaders
|
||||
if not args.eval:
|
||||
data_loader_train = torch.utils.data.DataLoader(
|
||||
dataset_train, sampler=sampler_train,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem,
|
||||
drop_last=True,
|
||||
batch_size=args.batch_size, num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem, drop_last=True,
|
||||
)
|
||||
|
||||
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
|
||||
print(f"len of train_set: {len(data_loader_train) * args.batch_size}")
|
||||
|
||||
data_loader_val = torch.utils.data.DataLoader(
|
||||
dataset_val, sampler=sampler_val,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem,
|
||||
drop_last=False
|
||||
batch_size=args.batch_size, num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem, drop_last=False,
|
||||
)
|
||||
|
||||
data_loader_test = torch.utils.data.DataLoader(
|
||||
dataset_test, sampler=sampler_test,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem,
|
||||
drop_last=False
|
||||
batch_size=args.batch_size, num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem, drop_last=False,
|
||||
)
|
||||
|
||||
# ---- Mixup/CutMix
|
||||
mixup_fn = None
|
||||
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
|
||||
mixup_active = (args.mixup > 0) or (args.cutmix > 0.) or (args.cutmix_minmax is not None)
|
||||
if mixup_active:
|
||||
print("Mixup is activated!")
|
||||
mixup_fn = Mixup(
|
||||
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
|
||||
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
|
||||
label_smoothing=args.smoothing, num_classes=args.nb_classes)
|
||||
label_smoothing=args.smoothing, num_classes=args.nb_classes
|
||||
)
|
||||
|
||||
# ---- Eval-only: resume weights
|
||||
if args.resume and args.eval:
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
print("Load checkpoint from: %s" % args.resume)
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
checkpoint = torch.load(args.resume, map_location="cpu")
|
||||
print(f"Load checkpoint for eval from: {args.resume}")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
|
||||
model.to(device)
|
||||
model_without_ddp = model
|
||||
|
||||
# ---- Adaptation toggle
|
||||
if args.adaptation == "lp":
|
||||
for name, param in model.named_parameters():
|
||||
param.requires_grad = ("head" in name)
|
||||
print("[Adaptation] Linear probe: training classifier head only.")
|
||||
else:
|
||||
print("[Adaptation] Full fine-tuning: training all parameters.")
|
||||
|
||||
# ---- Count trainable params
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print('number of model params (M): %.2f' % (n_parameters / 1.e6))
|
||||
print(f"number of trainable params (M): {n_parameters / 1.e6:.2f}")
|
||||
|
||||
# ---- LR scaling by effective batch size
|
||||
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
|
||||
|
||||
if args.lr is None: # only base_lr is specified
|
||||
if args.lr is None:
|
||||
args.lr = args.blr * eff_batch_size / 256
|
||||
print(f"base lr: {args.lr * 256 / eff_batch_size:.2e}")
|
||||
print(f"actual lr: {args.lr:.2e}")
|
||||
print(f"accumulate grad iterations: {args.accum_iter}")
|
||||
print(f"effective batch size: {eff_batch_size}")
|
||||
|
||||
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
|
||||
print("actual lr: %.2e" % args.lr)
|
||||
|
||||
print("accumulate grad iterations: %d" % args.accum_iter)
|
||||
print("effective batch size: %d" % eff_batch_size)
|
||||
|
||||
if args.distributed:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
||||
# ---- DDP (if available)
|
||||
if args.distributed and torch.cuda.device_count() > 1:
|
||||
ddp_kwargs = {}
|
||||
if args.adaptation == "lp":
|
||||
ddp_kwargs["find_unused_parameters"] = True
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.gpu], **ddp_kwargs
|
||||
)
|
||||
model_without_ddp = model.module
|
||||
else:
|
||||
model_without_ddp = model # single-GPU
|
||||
|
||||
# ---- Optimizer param groups (after freezing)
|
||||
no_weight_decay = (model_without_ddp.no_weight_decay()
|
||||
if hasattr(model_without_ddp, "no_weight_decay") else [])
|
||||
|
||||
|
||||
param_groups = lrd.param_groups_lrd(
|
||||
model_without_ddp,
|
||||
weight_decay=args.weight_decay,
|
||||
no_weight_decay_list=no_weight_decay,
|
||||
layer_decay=args.layer_decay,
|
||||
)
|
||||
for g in param_groups:
|
||||
g["params"] = [p for p in g["params"] if p.requires_grad]
|
||||
|
||||
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
|
||||
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
|
||||
no_weight_decay_list=no_weight_decay,
|
||||
layer_decay=args.layer_decay
|
||||
)
|
||||
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
|
||||
loss_scaler = NativeScaler()
|
||||
print(f"criterion = {criterion}")
|
||||
|
||||
print("criterion = %s" % str(criterion))
|
||||
|
||||
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
|
||||
# ---- Load previous full state (optimizer, scaler, etc.)
|
||||
misc.load_model(args=args, model_without_ddp=model_without_ddp,
|
||||
optimizer=optimizer, loss_scaler=loss_scaler)
|
||||
|
||||
# =========================
|
||||
# Eval-only Short Circuit
|
||||
# =========================
|
||||
if args.eval:
|
||||
if 'epoch' in checkpoint:
|
||||
print("Test with the best model at epoch = %d" % checkpoint['epoch'])
|
||||
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
|
||||
num_class=args.nb_classes, log_writer=log_writer)
|
||||
exit(0)
|
||||
if "checkpoint" in locals() and isinstance(checkpoint, dict) and ("epoch" in checkpoint):
|
||||
print(f"Test with the best model at epoch = {checkpoint['epoch']}")
|
||||
test_stats, auc_roc = evaluate(
|
||||
data_loader_test, model, device, args, epoch=0, mode="test",
|
||||
num_class=args.nb_classes, log_writer=log_writer
|
||||
)
|
||||
return
|
||||
|
||||
# =========================
|
||||
# Train Loop
|
||||
# =========================
|
||||
print(f"Start training for {args.epochs} epochs")
|
||||
start_time = time.time()
|
||||
max_score = 0.0
|
||||
best_epoch = 0
|
||||
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
if args.distributed:
|
||||
data_loader_train.sampler.set_epoch(epoch)
|
||||
@@ -352,49 +388,55 @@ def main(args, criterion):
|
||||
model, criterion, data_loader_train,
|
||||
optimizer, device, epoch, loss_scaler,
|
||||
args.clip_grad, mixup_fn,
|
||||
log_writer=log_writer,
|
||||
args=args
|
||||
log_writer=log_writer, args=args
|
||||
)
|
||||
|
||||
val_stats, val_score = evaluate(
|
||||
data_loader_val, model, device, args, epoch, mode="val",
|
||||
num_class=args.nb_classes, log_writer=log_writer
|
||||
)
|
||||
|
||||
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
|
||||
num_class=args.nb_classes, log_writer=log_writer)
|
||||
if max_score < val_score:
|
||||
max_score = val_score
|
||||
best_epoch = epoch
|
||||
if args.output_dir and args.savemodel:
|
||||
misc.save_model(
|
||||
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
||||
loss_scaler=loss_scaler, epoch=epoch, mode='best')
|
||||
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
|
||||
|
||||
|
||||
if epoch == (args.epochs - 1):
|
||||
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
|
||||
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
|
||||
model.to(device)
|
||||
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
|
||||
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
|
||||
num_class=args.nb_classes, log_writer=None)
|
||||
args=args, model=model, model_without_ddp=model_without_ddp,
|
||||
optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode="best"
|
||||
)
|
||||
print(f"Best epoch = {best_epoch}, Best score = {max_score:.4f}")
|
||||
|
||||
if log_writer is not None:
|
||||
log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
|
||||
log_writer.add_scalar("loss/val", val_stats["loss"], epoch)
|
||||
log_writer.flush()
|
||||
|
||||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
||||
'epoch': epoch,
|
||||
'n_parameters': n_parameters}
|
||||
log_stats = {**{f"train_{k}": v for k, v in train_stats.items()},
|
||||
"epoch": epoch,
|
||||
"n_parameters": n_parameters}
|
||||
|
||||
if args.output_dir and misc.is_main_process():
|
||||
if log_writer is not None:
|
||||
log_writer.flush()
|
||||
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
|
||||
with open(os.path.join(args.output_dir, args.task, "log.txt"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(log_stats) + "\n")
|
||||
|
||||
# =========================
|
||||
# Final Test (Best Ckpt)
|
||||
# =========================
|
||||
ckpt_path = os.path.join(args.output_dir, args.task, "checkpoint-best.pth")
|
||||
checkpoint = torch.load(ckpt_path, map_location="cpu")
|
||||
model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
print(f"Test with the best model, epoch = {checkpoint.get('epoch', -1)}:")
|
||||
_test_stats, _auc_roc = evaluate(
|
||||
data_loader_test, model, device, args, -1, mode="test",
|
||||
num_class=args.nb_classes, log_writer=None
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('Training time {}'.format(total_time_str))
|
||||
print(f"Training time {total_time_str}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = get_args_parser()
|
||||
args = args.parse_args()
|
||||
|
||||
@@ -402,6 +444,5 @@ if __name__ == '__main__':
|
||||
|
||||
if args.output_dir:
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
main(args, criterion)
|
||||
|
||||
|
||||
|
||||
@@ -1,414 +0,0 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from timm.models.layers import trunc_normal_
|
||||
from timm.data.mixup import Mixup
|
||||
|
||||
import models_vit as models
|
||||
import util.lr_decay as lrd
|
||||
import util.misc as misc
|
||||
from util.datasets import build_dataset
|
||||
from util.pos_embed import interpolate_pos_embed
|
||||
from util.misc import NativeScalerWithGradNormCount as NativeScaler
|
||||
from huggingface_hub import hf_hub_download, login
|
||||
from engine_finetune import train_one_epoch, evaluate
|
||||
|
||||
import warnings
|
||||
import faulthandler
|
||||
|
||||
faulthandler.enable()
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
|
||||
parser.add_argument('--batch_size', default=128, type=int,
|
||||
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
|
||||
parser.add_argument('--epochs', default=50, type=int)
|
||||
parser.add_argument('--accum_iter', default=1, type=int,
|
||||
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
|
||||
|
||||
# Model parameters
|
||||
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
|
||||
help='Name of model to train')
|
||||
parser.add_argument('--input_size', default=256, type=int,
|
||||
help='images input size')
|
||||
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
|
||||
help='Drop path rate (default: 0.1)')
|
||||
|
||||
# Optimizer parameters
|
||||
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
|
||||
help='Clip gradient norm (default: None, no clipping)')
|
||||
parser.add_argument('--weight_decay', type=float, default=0.05,
|
||||
help='weight decay (default: 0.05)')
|
||||
parser.add_argument('--lr', type=float, default=None, metavar='LR',
|
||||
help='learning rate (absolute lr)')
|
||||
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
|
||||
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
|
||||
parser.add_argument('--layer_decay', type=float, default=0.65,
|
||||
help='layer-wise lr decay from ELECTRA/BEiT')
|
||||
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
|
||||
help='lower lr bound for cyclic schedulers that hit 0')
|
||||
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
|
||||
help='epochs to warmup LR')
|
||||
|
||||
# Augmentation parameters
|
||||
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
|
||||
help='Color jitter factor (enabled only when not using Auto/RandAug)')
|
||||
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
|
||||
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
|
||||
parser.add_argument('--smoothing', type=float, default=0.1,
|
||||
help='Label smoothing (default: 0.1)')
|
||||
|
||||
# * Random Erase params
|
||||
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
|
||||
help='Random erase prob (default: 0.25)')
|
||||
parser.add_argument('--remode', type=str, default='pixel',
|
||||
help='Random erase mode (default: "pixel")')
|
||||
parser.add_argument('--recount', type=int, default=1,
|
||||
help='Random erase count (default: 1)')
|
||||
parser.add_argument('--resplit', action='store_true', default=False,
|
||||
help='Do not random erase first (clean) augmentation split')
|
||||
|
||||
# * Mixup params
|
||||
parser.add_argument('--mixup', type=float, default=0,
|
||||
help='mixup alpha, mixup enabled if > 0.')
|
||||
parser.add_argument('--cutmix', type=float, default=0,
|
||||
help='cutmix alpha, cutmix enabled if > 0.')
|
||||
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
|
||||
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
|
||||
parser.add_argument('--mixup_prob', type=float, default=1.0,
|
||||
help='Probability of performing mixup or cutmix when either/both is enabled')
|
||||
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
|
||||
help='Probability of switching to cutmix when both mixup and cutmix enabled')
|
||||
parser.add_argument('--mixup_mode', type=str, default='batch',
|
||||
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
|
||||
|
||||
# * Finetuning params
|
||||
parser.add_argument('--finetune', default='', type=str,
|
||||
help='finetune from checkpoint')
|
||||
parser.add_argument('--task', default='', type=str,
|
||||
help='finetune from checkpoint')
|
||||
parser.add_argument('--global_pool', action='store_true')
|
||||
parser.set_defaults(global_pool=True)
|
||||
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
|
||||
help='Use class token instead of global pool for classification')
|
||||
|
||||
# Dataset parameters
|
||||
parser.add_argument('--data_path', default='./data/', type=str,
|
||||
help='dataset path')
|
||||
parser.add_argument('--nb_classes', default=8, type=int,
|
||||
help='number of the classification types')
|
||||
parser.add_argument('--output_dir', default='./output_dir',
|
||||
help='path where to save, empty for no saving')
|
||||
parser.add_argument('--log_dir', default='./output_logs',
|
||||
help='path where to tensorboard log')
|
||||
parser.add_argument('--device', default='cuda',
|
||||
help='device to use for training / testing')
|
||||
parser.add_argument('--seed', default=0, type=int)
|
||||
parser.add_argument('--resume', default='',
|
||||
help='resume from checkpoint')
|
||||
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
|
||||
help='start epoch')
|
||||
parser.add_argument('--eval', action='store_true',
|
||||
help='Perform evaluation only')
|
||||
parser.add_argument('--dist_eval', action='store_true', default=False,
|
||||
help='Enabling distributed evaluation (recommended during training for faster monitor')
|
||||
parser.add_argument('--num_workers', default=10, type=int)
|
||||
parser.add_argument('--pin_mem', action='store_true',
|
||||
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
||||
parser.set_defaults(pin_mem=True)
|
||||
|
||||
# distributed training parameters
|
||||
parser.add_argument('--world_size', default=1, type=int,
|
||||
help='number of distributed processes')
|
||||
parser.add_argument('--local_rank', default=-1, type=int)
|
||||
parser.add_argument('--dist_on_itp', action='store_true')
|
||||
parser.add_argument('--dist_url', default='env://',
|
||||
help='url used to set up distributed training')
|
||||
|
||||
# fine-tuning parameters
|
||||
parser.add_argument('--savemodel', action='store_true', default=True,
|
||||
help='Save model')
|
||||
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
|
||||
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
|
||||
parser.add_argument('--datasets_seed', default=2026, type=int)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(args, criterion):
|
||||
if args.resume and not args.eval:
|
||||
resume = args.resume
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
print("Load checkpoint from: %s" % args.resume)
|
||||
args = checkpoint['args']
|
||||
args.resume = resume
|
||||
|
||||
misc.init_distributed_mode(args)
|
||||
|
||||
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
|
||||
print("{}".format(args).replace(', ', ',\n'))
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
# fix the seed for reproducibility
|
||||
seed = args.seed + misc.get_rank()
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
cudnn.benchmark = True
|
||||
|
||||
if args.model=='RETFound_mae':
|
||||
model = models.__dict__[args.model](
|
||||
img_size=args.input_size,
|
||||
num_classes=args.nb_classes,
|
||||
drop_path_rate=args.drop_path,
|
||||
global_pool=args.global_pool,
|
||||
)
|
||||
else:
|
||||
model = models.__dict__[args.model](
|
||||
num_classes=args.nb_classes,
|
||||
drop_path_rate=args.drop_path,
|
||||
args=args,
|
||||
)
|
||||
|
||||
if args.finetune and not args.eval:
|
||||
|
||||
print(f"Downloading pre-trained weights from: {args.finetune}")
|
||||
|
||||
checkpoint_path = hf_hub_download(
|
||||
repo_id=f'YukunZhou/{args.finetune}',
|
||||
filename=f'{args.finetune}.pth',
|
||||
)
|
||||
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
print("Load pre-trained checkpoint from: %s" % args.finetune)
|
||||
|
||||
if args.model!='RETFound_mae':
|
||||
checkpoint_model = checkpoint['teacher']
|
||||
else:
|
||||
checkpoint_model = checkpoint['model']
|
||||
|
||||
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
|
||||
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
|
||||
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
|
||||
|
||||
state_dict = model.state_dict()
|
||||
for k in ['head.weight', 'head.bias']:
|
||||
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
||||
print(f"Removing key {k} from pretrained checkpoint")
|
||||
del checkpoint_model[k]
|
||||
|
||||
# interpolate position embedding
|
||||
interpolate_pos_embed(model, checkpoint_model)
|
||||
|
||||
# load pre-trained model
|
||||
msg = model.load_state_dict(checkpoint_model, strict=False)
|
||||
|
||||
trunc_normal_(model.head.weight, std=2e-5)
|
||||
|
||||
dataset_train = build_dataset(is_train='train', args=args)
|
||||
dataset_val = build_dataset(is_train='val', args=args)
|
||||
dataset_test = build_dataset(is_train='test', args=args)
|
||||
|
||||
|
||||
if True: # args.distributed:
|
||||
num_tasks = misc.get_world_size()
|
||||
global_rank = misc.get_rank()
|
||||
if not args.eval:
|
||||
sampler_train = torch.utils.data.DistributedSampler(
|
||||
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
|
||||
)
|
||||
print("Sampler_train = %s" % str(sampler_train))
|
||||
if args.dist_eval:
|
||||
if len(dataset_val) % num_tasks != 0:
|
||||
print(
|
||||
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
||||
'equal num of samples per-process.')
|
||||
sampler_val = torch.utils.data.DistributedSampler(
|
||||
dataset_val, num_replicas=num_tasks, rank=global_rank,
|
||||
shuffle=True) # shuffle=True to reduce monitor bias
|
||||
else:
|
||||
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
|
||||
|
||||
if args.dist_eval:
|
||||
if len(dataset_test) % num_tasks != 0:
|
||||
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
|
||||
'This will slightly alter validation results as extra duplicate entries are added to achieve '
|
||||
'equal num of samples per-process.')
|
||||
sampler_test = torch.utils.data.DistributedSampler(
|
||||
dataset_test, num_replicas=num_tasks, rank=global_rank,
|
||||
shuffle=True) # shuffle=True to reduce monitor bias
|
||||
else:
|
||||
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
|
||||
|
||||
if global_rank == 0 and args.log_dir is not None and not args.eval:
|
||||
os.makedirs(args.log_dir, exist_ok=True)
|
||||
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
|
||||
else:
|
||||
log_writer = None
|
||||
|
||||
if not args.eval:
|
||||
data_loader_train = torch.utils.data.DataLoader(
|
||||
dataset_train, sampler=sampler_train,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
|
||||
|
||||
data_loader_val = torch.utils.data.DataLoader(
|
||||
dataset_val, sampler=sampler_val,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem,
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
data_loader_test = torch.utils.data.DataLoader(
|
||||
dataset_test, sampler=sampler_test,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem,
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
mixup_fn = None
|
||||
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
|
||||
if mixup_active:
|
||||
print("Mixup is activated!")
|
||||
mixup_fn = Mixup(
|
||||
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
|
||||
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
|
||||
label_smoothing=args.smoothing, num_classes=args.nb_classes)
|
||||
|
||||
if args.resume and args.eval:
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
print("Load checkpoint from: %s" % args.resume)
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
|
||||
model.to(device)
|
||||
model_without_ddp = model
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print('number of model params (M): %.2f' % (n_parameters / 1.e6))
|
||||
|
||||
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
|
||||
|
||||
if args.lr is None: # only base_lr is specified
|
||||
args.lr = args.blr * eff_batch_size / 256
|
||||
|
||||
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
|
||||
print("actual lr: %.2e" % args.lr)
|
||||
|
||||
print("accumulate grad iterations: %d" % args.accum_iter)
|
||||
print("effective batch size: %d" % eff_batch_size)
|
||||
|
||||
if args.distributed:
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True)
|
||||
model_without_ddp = model.module
|
||||
|
||||
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
|
||||
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
|
||||
no_weight_decay_list=no_weight_decay,
|
||||
layer_decay=args.layer_decay
|
||||
)
|
||||
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
|
||||
loss_scaler = NativeScaler()
|
||||
|
||||
print("criterion = %s" % str(criterion))
|
||||
|
||||
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if 'head' in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
if args.eval:
|
||||
if 'epoch' in checkpoint:
|
||||
print("Test with the best model at epoch = %d" % checkpoint['epoch'])
|
||||
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
|
||||
num_class=args.nb_classes, log_writer=log_writer)
|
||||
exit(0)
|
||||
|
||||
print(f"Start training for {args.epochs} epochs")
|
||||
start_time = time.time()
|
||||
max_score = 0.0
|
||||
best_epoch = 0
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
if args.distributed:
|
||||
data_loader_train.sampler.set_epoch(epoch)
|
||||
|
||||
train_stats = train_one_epoch(
|
||||
model, criterion, data_loader_train,
|
||||
optimizer, device, epoch, loss_scaler,
|
||||
args.clip_grad, mixup_fn,
|
||||
log_writer=log_writer,
|
||||
args=args
|
||||
)
|
||||
|
||||
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
|
||||
num_class=args.nb_classes, log_writer=log_writer)
|
||||
if max_score < val_score:
|
||||
max_score = val_score
|
||||
best_epoch = epoch
|
||||
if args.output_dir and args.savemodel:
|
||||
misc.save_model(
|
||||
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
||||
loss_scaler=loss_scaler, epoch=epoch, mode='best')
|
||||
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
|
||||
|
||||
|
||||
if epoch == (args.epochs - 1):
|
||||
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
|
||||
model.load_state_dict(checkpoint['model'], strict=False)
|
||||
model.to(device)
|
||||
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
|
||||
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
|
||||
num_class=args.nb_classes, log_writer=None)
|
||||
|
||||
if log_writer is not None:
|
||||
log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
|
||||
|
||||
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
|
||||
'epoch': epoch,
|
||||
'n_parameters': n_parameters}
|
||||
|
||||
if args.output_dir and misc.is_main_process():
|
||||
if log_writer is not None:
|
||||
log_writer.flush()
|
||||
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(log_stats) + "\n")
|
||||
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('Training time {}'.format(total_time_str))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args_parser()
|
||||
args = args.parse_args()
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
if args.output_dir:
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
main(args, criterion)
|
||||
|
||||
|
||||
+41
-5
@@ -1,7 +1,3 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# Partly revised by YZ @UCL&Moorfields
|
||||
# --------------------------------------------------------
|
||||
|
||||
from functools import partial
|
||||
|
||||
@@ -10,7 +6,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from timm.models.layers import trunc_normal_
|
||||
|
||||
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
|
||||
""" Vision Transformer with support for global average pooling
|
||||
@@ -56,6 +52,30 @@ def RETFound_mae(**kwargs):
|
||||
|
||||
|
||||
|
||||
def Dinov2(args, **kwargs):
|
||||
|
||||
if args.model_arch == 'dinov2_vits14':
|
||||
arch = 'vit_small_patch14_dinov2.lvd142m'
|
||||
elif args.model_arch == 'dinov2_vitb14':
|
||||
arch = 'vit_base_patch14_dinov2.lvd142m'
|
||||
elif args.model_arch == 'dinov2_vitl14':
|
||||
arch = 'vit_large_patch14_dinov2.lvd142m'
|
||||
elif args.model_arch == 'dinov2_vitg14':
|
||||
arch = 'vit_giant_patch14_dinov2.lvd142m'
|
||||
else:
|
||||
raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
|
||||
f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
|
||||
|
||||
model = timm.create_model(
|
||||
arch,
|
||||
pretrained=True,
|
||||
img_size=224,
|
||||
**kwargs
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
def RETFound_dinov2(args, **kwargs):
|
||||
model = timm.create_model(
|
||||
'vit_large_patch14_dinov2.lvd142m',
|
||||
@@ -66,4 +86,20 @@ def RETFound_dinov2(args, **kwargs):
|
||||
return model
|
||||
|
||||
|
||||
def Dinov3(args, **kwargs):
|
||||
# Load ViT-L/16 backbone (hub model has `head = Identity` by default)
|
||||
model = torch.hub.load(
|
||||
repo_or_dir="facebookresearch/dinov3",
|
||||
model=args.model_arch,
|
||||
pretrained=False, # main() will load your checkpoint
|
||||
trust_repo=True,
|
||||
)
|
||||
|
||||
# Figure out feature dimension for the probe
|
||||
feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
|
||||
model.head = nn.Linear(feat_dim, args.nb_classes)
|
||||
trunc_normal_(model.head.weight, std=2e-5)
|
||||
if model.head.bias is not None:
|
||||
nn.init.zeros_(model.head.bias)
|
||||
|
||||
return model
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# ==== Model settings ====
|
||||
# adaptation {finetune,lp}
|
||||
ADAPTATION="finetune"
|
||||
MODEL="RETFound_dinov2"
|
||||
MODEL_ARCH="retfound_dinov2"
|
||||
FINETUNE="RETFound_dinov2_meh"
|
||||
|
||||
# ==== Data settings ====
|
||||
# change the dataset name and corresponding class number
|
||||
DATASET="MESSIDOR2"
|
||||
NUM_CLASS=5
|
||||
data_path="./${DATASET}"
|
||||
task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
|
||||
|
||||
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
|
||||
--model "${MODEL}" \
|
||||
--model_arch "${MODEL_ARCH}" \
|
||||
--finetune "${FINETUNE}" \
|
||||
--savemodel \
|
||||
--global_pool \
|
||||
--batch_size 24 \
|
||||
--world_size 1 \
|
||||
--epochs 50 \
|
||||
--nb_classes "${NUM_CLASS}" \
|
||||
--data_path "${data_path}" \
|
||||
--input_size 224 \
|
||||
--task "${task}" \
|
||||
--adaptation "${ADAPTATION}"
|
||||
+49
-21
@@ -1,29 +1,39 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# Partly revised by YZ @UCL&Moorfields
|
||||
# --------------------------------------------------------
|
||||
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import Subset
|
||||
from torchvision import datasets, transforms
|
||||
from timm.data import create_transform
|
||||
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
|
||||
def build_dataset(is_train, args):
|
||||
transform = build_transform(is_train, args)
|
||||
root = os.path.join(args.data_path, is_train)
|
||||
dataset = datasets.ImageFolder(root, transform=transform)
|
||||
|
||||
return dataset
|
||||
if is_train == 'train':
|
||||
ratio = float(getattr(args, "dataratio", 1.0))
|
||||
seed = int(getattr(args, "seed", 0))
|
||||
stratified = bool(getattr(args, "stratified", False))
|
||||
|
||||
if 0.0 < ratio < 1.0:
|
||||
if stratified:
|
||||
idx = _stratified_indices(dataset.targets, ratio, seed)
|
||||
else:
|
||||
# simple uniform subsample with torch.Generator for reproducibility
|
||||
g = torch.Generator().manual_seed(seed)
|
||||
n = len(dataset)
|
||||
k = max(1, int(n * ratio))
|
||||
idx = torch.randperm(n, generator=g)[:k].tolist()
|
||||
dataset = Subset(dataset, idx)
|
||||
|
||||
return dataset
|
||||
|
||||
def build_transform(is_train, args):
|
||||
mean = IMAGENET_DEFAULT_MEAN
|
||||
std = IMAGENET_DEFAULT_STD
|
||||
# train transform
|
||||
|
||||
if is_train == 'train':
|
||||
# this should always dispatch to transforms_imagenet_train
|
||||
transform = create_transform(
|
||||
return create_transform(
|
||||
input_size=args.input_size,
|
||||
is_training=True,
|
||||
color_jitter=args.color_jitter,
|
||||
@@ -35,19 +45,37 @@ def build_transform(is_train, args):
|
||||
mean=mean,
|
||||
std=std,
|
||||
)
|
||||
return transform
|
||||
|
||||
# eval transform
|
||||
t = []
|
||||
if args.input_size <= 224:
|
||||
crop_pct = 224 / 256
|
||||
else:
|
||||
crop_pct = 1.0
|
||||
crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
|
||||
size = int(args.input_size / crop_pct)
|
||||
t.append(
|
||||
t = [
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
)
|
||||
t.append(transforms.CenterCrop(args.input_size))
|
||||
t.append(transforms.ToTensor())
|
||||
t.append(transforms.Normalize(mean, std))
|
||||
transforms.CenterCrop(args.input_size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean, std),
|
||||
]
|
||||
return transforms.Compose(t)
|
||||
|
||||
# ---- helpers ----
|
||||
|
||||
def _stratified_indices(targets, ratio: float, seed: int):
|
||||
"""Maintain class proportions. Ensures at least 1 sample per class when possible."""
|
||||
t = torch.as_tensor(targets)
|
||||
classes = torch.unique(t)
|
||||
g = torch.Generator().manual_seed(seed)
|
||||
|
||||
keep = []
|
||||
for c in classes.tolist():
|
||||
cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1)
|
||||
if len(cls_idx) == 0:
|
||||
continue
|
||||
k = max(1, int(round(len(cls_idx) * ratio)))
|
||||
sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]]
|
||||
keep.extend(sel.tolist())
|
||||
|
||||
# shuffle final indices (stable across seed)
|
||||
g2 = torch.Generator().manual_seed(seed + 1)
|
||||
keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist()
|
||||
return keep
|
||||
|
||||
|
||||
Reference in New Issue
Block a user