Compare commits

...

10 Commits

Author SHA1 Message Date
rmaphoh ae9a9ecf37 add example Jupyter Notebook 2025-11-30 13:57:36 +00:00
rmaphoh 8f5b2ce5e7 add example Jupyter Notebook 2025-11-30 13:54:40 +00:00
rmaphoh dbbddb8936 update readme 2025-09-04 08:19:51 +01:00
rmaphoh ed8b469a0f update readme 2025-09-02 16:35:11 +01:00
rmaphoh 17768be893 update readme 2025-09-02 16:33:57 +01:00
rmaphoh bda7a6c69f remove unused file 2025-08-31 18:10:18 +01:00
rmaphoh 7489af0620 Incorporate DINOv3, DINOv2 2025-08-31 18:07:38 +01:00
rmaphoh 409f7b6167 Incorporate DINOv3, DINOv2 2025-08-31 18:03:57 +01:00
Yukun Zhou 897d71c8c9 Merge pull request #43 from BartvanderWoude/main
Fixes for latent_feature.ipynb
2025-04-29 10:42:20 +01:00
bartsserver a7a9b3a8b7 Correctly use arch to reference architecture in models_vit. Set weights_only to False for use with PyTorch>=2.6. Automatically select cpu if cuda is not available. 2025-04-17 12:09:07 +02:00
8 changed files with 776 additions and 757 deletions
+138 -61
View File
@@ -1,14 +1,15 @@
## RETFound - A foundation model for retinal imaging ## RETFound - A foundation model for retinal images
Official repo including a series of retinal foundation models.<br> Official repo including a series of foundation models and applications for retinal images.<br>
[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae).<br> `[RETFound-MAE]`:[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x).<br>
[New checkpoints](https://huggingface.co/YukunZhou), some of which are based on [DINOV2](https://github.com/facebookresearch/dinov2): `[RETFound-DINOv2]`:[Revealing the Impact of Pre-training Data on Medical Foundation Models](https://www.researchsquare.com/article/rs-6080254/v1).<br>
`[DINOv2]`:[General-purpose vision foundation models DINOv2 by Meta](https://github.com/facebookresearch/dinov2).<br>
`[DINOv3]`:[General-purpose vision foundation models DINOv3 by Meta](https://github.com/facebookresearch/dinov3).<br>
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions. Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE)
### 📝Key features ### 📝Key features
@@ -19,13 +20,14 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
### 🎉News ### 🎉News
- 🐉2025/09: **Preprint benchmarking DINOv3, DINOv2, and RETFound is [available](https://arxiv.org/abs/2509.03421)!**
- 🐉2025/09: **We included state-of-the-art DINOv3 into fine-tuning pipeline for retinal applications!**
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!** - 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!** - 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!**
- 🐉2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!** - 🐉2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!**
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) are now online! - 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/latent_feature.ipynb) are now online!
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online! - 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation! - 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
- 2023/10: change the hyperparameter of [input_size](https://github.com/rmaphoh/RETFound_MAE#:~:text=finetune%20./RETFound_cfp_weights.pth%20%5C-,%2D%2Dinput_size%20224,-For%20evaluation%20only) for any image size
### 🔧Install environment ### 🔧Install environment
@@ -40,17 +42,17 @@ conda activate retfound
2. Install dependencies 2. Install dependencies
``` ```
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu121
git clone https://github.com/rmaphoh/RETFound_MAE/ git clone https://github.com/rmaphoh/RETFound/
cd RETFound_MAE cd RETFound
pip install -r requirements.txt pip install -r requirements.txt
pip install ipykernel
python -m ipykernel install --user --name retfound --display-name "Python (retfound)"
``` ```
### 🌱Fine-tuning with RETFound weights ### 🌱Fine-tuning with RETFound weights
To fine tune RETFound on your own data, follow these steps:
1. Get access to the pre-trained models on HuggingFace (register an account and fill in the form) and go to step 2: 1. Get access to the pre-trained models on HuggingFace (register an account and fill in the form) and go to step 2:
<table><tbody> <table><tbody>
<!-- START TABLE --> <!-- START TABLE -->
@@ -71,22 +73,22 @@ To fine tune RETFound on your own data, follow these steps:
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">RETFound_mae_meh</td> <tr><td align="left">RETFound_mae_meh</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
<td align="center">TBD</a></td> <td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr> </tr>
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">RETFound_mae_shanghai</td> <tr><td align="left">RETFound_mae_shanghai</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
<td align="center">TBD</a></td> <td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr> </tr>
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">RETFound_dinov2_meh</td> <tr><td align="left">RETFound_dinov2_meh</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
<td align="center">TBD</a></td> <td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr> </tr>
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">RETFound_dinov2_shanghai</td> <tr><td align="left">RETFound_dinov2_shanghai</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">access</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">access</a></td>
<td align="center">TBD</a></td> <td align="center"><a href="https://www.researchsquare.com/article/rs-6080254/v1">FM data paper</a></td>
</tr> </tr>
</tbody></table> </tbody></table>
@@ -100,7 +102,9 @@ huggingface-cli login --token YOUR_HUGGINGFACE_TOKEN
export HF_ENDPOINT=https://hf-mirror.com export HF_ENDPOINT=https://hf-mirror.com
``` ```
3. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md)) 3. If you would like to fine-tune [DINOv2](https://github.com/facebookresearch/dinov2) and [DINOv3](https://github.com/facebookresearch/dinov3), please visit their GitHub repositories to download the model weights and put them in the RETFound folder.
4. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md))
``` ```
├── data folder ├── data folder
@@ -118,56 +122,122 @@ export HF_ENDPOINT=https://hf-mirror.com
├──class_c ├──class_c
``` ```
4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training.
The model and finetune can be selected:
| model | finetune | 5. Start fine-tuning by running `sh train.sh`.
|-----------------|--------------------------|
| RETFound_mae | RETFound_mae_natureCFP |
| RETFound_mae | RETFound_mae_natureOCT | In `train.sh`, the model can be selected by changing the hyperparameters `MODEL`, `MODEL_ARCH`, `FINETUNE`:
| RETFound_mae | RETFound_mae_meh |
| RETFound_mae | RETFound_mae_shanghai | **RETFound**:
| RETFound_dinov2 | RETFound_dinov2_meh |
| RETFound_dinov2 | RETFound_dinov2_shanghai | | MODEL | MODEL_ARCH | FINETUNE | SIZE |
|-----------------|--------------------------|--------------------------|--------------------------|
| RETFound_mae | retfound_mae | RETFound_mae_natureCFP | ~300M |
| RETFound_mae | retfound_mae | RETFound_mae_natureOCT | ~300M |
| RETFound_mae | retfound_mae | RETFound_mae_meh | ~300M |
| RETFound_mae | retfound_mae | RETFound_mae_shanghai | ~300M |
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_meh | ~300M |
| RETFound_dinov2 | retfound_dinov2 | RETFound_dinov2_shanghai | ~300M |
**DINOv3**:
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|-----------------|--------------------------|----------------------------------|--------------------------|
| Dinov3 | dinov3_vits16 | dinov3_vits16_pretrain.pth | ~21M |
| Dinov3 | dinov3_vits16plus | dinov3_vits16plus_pretrain.pth | ~29M |
| Dinov3 | dinov3_vitb16 | dinov3_vitb16_pretrain.pth | ~86M |
| Dinov3 | dinov3_vitl16 | dinov3_vitl16_pretrain.pth | ~300M |
| Dinov3 | dinov3_vith16plus | dinov3_vith16plus_pretrain.pth | ~840M |
| Dinov3 | dinov3_vit7b16 | dinov3_vit7b16_pretrain.pth | ~6.7B |
**DINOv2**:
| MODEL | MODEL_ARCH | FINETUNE | SIZE |
|-----------------|--------------------------|------------------------------|--------------------------|
| Dinov2 | dinov2_vits14 | dinov2_vits14_pretrain.pth | ~21M |
| Dinov2 | dinov2_vitb14 | dinov2_vitb14_pretrain.pth | ~86M |
| Dinov2 | dinov2_vitl14 | dinov2_vitl14_pretrain.pth | ~300M |
| Dinov2 | dinov2_vitg14 | dinov2_vitg14_pretrain.pth | ~1.1B |
Change the DATA_PATH to your dataset directory.
``` ```
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \ # ==== Model settings ====
--model RETFound_mae \ # adaptation {finetune,lp}
--savemodel \ ADAPTATION="finetune"
--global_pool \ MODEL="RETFound_dinov2"
--batch_size 16 \ MODEL_ARCH="retfound_dinov2"
--world_size 1 \ FINETUNE="RETFound_dinov2_meh"
--epochs 100 \
--blr 5e-3 --layer_decay 0.65 \ # ==== Data settings ====
--weight_decay 0.05 --drop_path 0.2 \ # change the dataset name and corresponding class number
--nb_classes 5 \ DATASET="MESSIDOR2"
--data_path ./IDRiD \ NUM_CLASS=5
--input_size 224 \
--task RETFound_mae_meh-IDRiD \ # =======================
--finetune RETFound_mae_meh DATA_PATH="PATH TO THE DATASET"
TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
--model "${MODEL}" \
--model_arch "${MODEL_ARCH}" \
--finetune "${FINETUNE}" \
--savemodel \
--global_pool \
--batch_size 24 \
--world_size 1 \
--epochs 50 \
--nb_classes "${NUM_CLASS}" \
--data_path "${DATA_PATH}" \
--input_size 224 \
--task "${TASK}" \
--adaptation "${ADAPTATION}"
``` ```
4. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the path below)
6. For evaluation only (download data and model checkpoints [here](BENCHMARK.md); change the DATA_PATH below)
``` ```
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \ # ==== Model/settings (match training) ====
--model RETFound_mae \ ADAPTATION="finetune"
--savemodel \ MODEL="RETFound_dinov2"
--eval \ MODEL_ARCH="retfound_dinov2"
--global_pool \ FINETUNE="RETFound_dinov2_meh"
--batch_size 16 \
--world_size 1 \ # ==== Data/settings (match training) ====
--epochs 100 \ DATASET="MESSIDOR2"
--blr 5e-3 --layer_decay 0.65 \ NUM_CLASS=5
--weight_decay 0.05 --drop_path 0.2 \
--nb_classes 5 \ # =======================
--data_path ./IDRiD \ DATA_PATH="PATH TO THE DATASET"
--input_size 224 \ TASK="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
--task RETFound_mae_meh-IDRiD \
--resume ./RETFound_mae_meh-IDRiD/checkpoint-best.pth # Path to the trained checkpoint (adjust if you saved elsewhere)
CKPT="./output_dir/${TASK}/checkpoint-best.pth"
# ==== Evaluation only ====
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
--model "${MODEL}" \
--model_arch "${MODEL_ARCH}" \
--savemodel \
--global_pool \
--batch_size 128 \
--world_size 1 \
--nb_classes "${NUM_CLASS}" \
--data_path "${DATA_PATH}" \
--input_size 224 \
--task "${TASK}" \
--adaptation "${ADAPTATION}" \
--eval \
--resume "${CKPT}"
``` ```
@@ -175,9 +245,6 @@ torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
If you find this repository useful, please consider citing this paper: If you find this repository useful, please consider citing this paper:
```
TBD
```
``` ```
@article{zhou2023foundation, @article{zhou2023foundation,
@@ -192,4 +259,14 @@ TBD
} }
``` ```
```
@misc{zhou2025generalistversusspecialistvision,
title={Generalist versus Specialist Vision Foundation Models for Ocular Disease and Oculomics},
author={Yukun Zhou and Paul Nderitu and Jocelyn Hui Lin Goh and Justin Engelmann and Siegfried K. Wagner and Anran Ran and Hongyang Jiang and Lie Ju and Ke Zou and Sahana Srinivasan and Hyunmin Kim and Takahiro Ninomiya and Zheyuan Wang and Gabriel Dawei Yang and Eden Ruffell and Dominic Williamson and Rui Santos and Gabor Mark Somfai and Carol Y. Cheung and Tien Yin Wong and Daniel C. Alexander and Yih Chung Tham and Pearse A. Keane},
year={2025},
eprint={2509.03421},
archivePrefix={arXiv},
primaryClass={eess.IV},
url={https://arxiv.org/abs/2509.03421},
}
```
+223
View File
@@ -0,0 +1,223 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "76b39fb1",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Jupyter notebook example - Classification task\n",
"### Example using [MESSIDOR2](https://www.adcis.net/en/third-party/messidor2/) dataset\n",
"**Application**: Using RETFound for five-category diabetic retinopathy classification\n",
"\n",
"**Author**: Yukun Zhou\n",
"\n",
"**Date**: 30 Nov 2025\n",
"\n",
"**Performance**:\n",
"\n",
"<table align=\"left\">\n",
"<tr>\n",
" <th>Accuracy</th>\n",
" <th>Recall</th>\n",
" <th>F1 Score</th>\n",
" <th>ROC AUC</th>\n",
" <th>PR AUC</th>\n",
"</tr>\n",
"<tr>\n",
" <td>0.7091</td>\n",
" <td>0.5616</td>\n",
" <td>0.6078</td>\n",
" <td>0.9037</td>\n",
" <td>0.6863</td>\n",
"</tr>\n",
"</table>\n"
]
},
{
"cell_type": "markdown",
"id": "7ec435a7",
"metadata": {},
"source": [
"## 1. Install environment\n",
"1. Follow [RETFound README](https://github.com/rmaphoh/RETFound) to install environment\n",
"2. Restart this Jupyter Notebook\n",
"3. Select Kernel retfound"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7cbf5e93-6ca0-4401-88e6-64e39968e7cd",
"metadata": {},
"outputs": [],
"source": [
"import sys, torch\n",
"from pathlib import Path\n",
"import os\n",
"\n",
"PROJECT_ROOT = Path.cwd().resolve()\n",
"\n",
"if PROJECT_ROOT.name == 'examples': PROJECT_ROOT = PROJECT_ROOT.parent\n",
"os.chdir(PROJECT_ROOT)\n",
"\n",
"print('Project root:', PROJECT_ROOT)\n",
"print(\"sys.executable:\", sys.executable)\n",
"print(\"torch version:\", torch.__version__)"
]
},
{
"cell_type": "markdown",
"id": "ed67953f",
"metadata": {},
"source": [
"## 2. Prepare MESSIDOR2 dataset\n",
"1. Download from the [shared data pool](https://github.com/rmaphoh/RETFound/blob/main/BENCHMARK.md).\n",
"2. Put the data folder under the project directory, e.g. \"RETFound/MESSIDOR2\"\n"
]
},
{
"cell_type": "markdown",
"id": "357be2fa-a914-4d1f-8759-76b2b1c3f20f",
"metadata": {},
"source": [
"## 3. Hyperparameter and path settings\n",
"1. Can choose finetune or lp (linear probe)\n",
"2. Model selection [info](https://github.com/rmaphoh/RETFound#:~:text=In%20train.sh%2C%20the%20model%20can%20be%20selected%20by%20changing%20the%20hyperparameters%20MODEL%2C%20MODEL_ARCH%2C%20FINETUNE%3A)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f675843",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"ADAPTATION='finetune'\n",
"MODEL='RETFound_dinov2'\n",
"MODEL_ARCH='retfound_dinov2'\n",
"FINETUNE='RETFound_dinov2_meh'\n",
"DATASET='MESSIDOR2'\n",
"NUM_CLASS=5\n",
"DATA_PATH=PROJECT_ROOT/DATASET\n",
"BATCH_SIZE=24\n",
"EPOCHS=50\n",
"INPUT_SIZE=224\n",
"WORLD_SIZE=1\n",
"TASK=f\"{MODEL_ARCH}_{DATASET}_{ADAPTATION}\"\n",
"OUTPUT_DIR=PROJECT_ROOT/'output_dir'/TASK\n",
"print('DATA_PATH:',DATA_PATH)\n",
"print('TASK:',TASK)\n",
"print('OUTPUT_DIR:',OUTPUT_DIR)"
]
},
{
"cell_type": "markdown",
"id": "6ac04845",
"metadata": {},
"source": [
"## 4. Fine-tuning and testing RETFound on MESSIDOR2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d23ff751",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"!{sys.executable} main_finetune.py \\\n",
" --model {MODEL} \\\n",
" --model_arch {MODEL_ARCH} \\\n",
" --finetune {FINETUNE} \\\n",
" --savemodel \\\n",
" --global_pool \\\n",
" --batch_size {BATCH_SIZE} \\\n",
" --epochs {EPOCHS} \\\n",
" --nb_classes {NUM_CLASS} \\\n",
" --data_path {DATA_PATH} \\\n",
" --input_size {INPUT_SIZE} \\\n",
" --task {TASK} \\\n",
" --adaptation {ADAPTATION}"
]
},
{
"cell_type": "markdown",
"id": "84ce93ac",
"metadata": {},
"source": [
"## 5. Evaluation-only"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0af0f8a7",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"CKPT = OUTPUT_DIR / \"checkpoint-best.pth\"\n",
"\n",
"!{sys.executable} main_finetune.py \\\n",
" --model {MODEL} \\\n",
" --model_arch {MODEL_ARCH} \\\n",
" --finetune {FINETUNE} \\\n",
" --savemodel \\\n",
" --global_pool \\\n",
" --batch_size 128 \\\n",
" --nb_classes {NUM_CLASS} \\\n",
" --data_path {DATA_PATH} \\\n",
" --input_size {INPUT_SIZE} \\\n",
" --task {TASK} \\\n",
" --adaptation {ADAPTATION} \\\n",
" --eval \\\n",
" --resume {CKPT}\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "02d2dce7-31c2-48e2-87ce-9223b74cf94e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"kernel": "retfound",
"name": "workbench-notebooks.m128",
"type": "gcloud",
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m128"
},
"kernelspec": {
"display_name": "retfound_jupyter (Local)",
"language": "python",
"name": "retfound"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+7 -7
View File
@@ -22,18 +22,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": null,
"id": "90c3d964", "id": "90c3d964",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n", "def prepare_model(chkpt_dir, arch='RETFound_mae'):\n",
" \n", " \n",
" # load model\n", " # load model\n",
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n", " checkpoint = torch.load(chkpt_dir, map_location='cpu', weights_only=False)\n",
" \n", " \n",
" # build model\n", " # build model\n",
" if arch=='vit_large_patch16':\n", " if arch=='RETFound_mae':\n",
" model = models.__dict__[arch](\n", " model = models.__dict__[arch](\n",
" img_size=224,\n", " img_size=224,\n",
" num_classes=5,\n", " num_classes=5,\n",
@@ -70,7 +70,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": null,
"id": "9a250363", "id": "9a250363",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@@ -78,7 +78,7 @@
"def get_feature(data_path,\n", "def get_feature(data_path,\n",
" chkpt_dir,\n", " chkpt_dir,\n",
" device,\n", " device,\n",
" arch='vit_large_patch16'):\n", " arch='RETFound_mae'):\n",
" #loading model\n", " #loading model\n",
" model_ = prepare_model(chkpt_dir, arch)\n", " model_ = prepare_model(chkpt_dir, arch)\n",
" model_.to(device)\n", " model_.to(device)\n",
@@ -121,7 +121,7 @@
"source": [ "source": [
"chkpt_dir = hf_hub_download(repo_id=\"YukunZhou/RETFound_dinov2_meh\", filename=\"RETFound_dinov2_meh.pth\")\n", "chkpt_dir = hf_hub_download(repo_id=\"YukunZhou/RETFound_dinov2_meh\", filename=\"RETFound_dinov2_meh.pth\")\n",
"data_path = 'DATA_PATH'\n", "data_path = 'DATA_PATH'\n",
"device = torch.device('cuda')\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"arch='dinov2_large'" "arch='dinov2_large'"
] ]
}, },
+290 -249
View File
@@ -1,349 +1,385 @@
#!/usr/bin/env python3
# =========================
import argparse import argparse
import datetime import datetime
import json import json
import numpy as np
import os import os
import time import time
from pathlib import Path from pathlib import Path
import warnings
import faulthandler
# =========================
import numpy as np
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from timm.models.layers import trunc_normal_ from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup from timm.data.mixup import Mixup
from huggingface_hub import hf_hub_download, login # login imported as in original
# =========================
import models_vit as models import models_vit as models
import util.lr_decay as lrd import util.lr_decay as lrd
import util.misc as misc import util.misc as misc
from util.datasets import build_dataset from util.datasets import build_dataset
from util.pos_embed import interpolate_pos_embed from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler from util.misc import NativeScalerWithGradNormCount as NativeScaler
from huggingface_hub import hf_hub_download, login
from engine_finetune import train_one_epoch, evaluate from engine_finetune import train_one_epoch, evaluate
import warnings # =========================
import faulthandler
faulthandler.enable() faulthandler.enable()
warnings.simplefilter(action='ignore', category=FutureWarning) warnings.simplefilter(action="ignore", category=FutureWarning)
def get_args_parser(): def get_args_parser():
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) parser = argparse.ArgumentParser(
parser.add_argument('--batch_size', default=128, type=int, "MAE fine-tuning / linear probing for image classification", add_help=False
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') )
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters # ---- Core training
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', parser.add_argument("--batch_size", default=128, type=int,
help='Name of model to train') help="Batch size per GPU (effective batch size = batch_size * accum_iter * #gpus)")
parser.add_argument('--input_size', default=256, type=int, parser.add_argument("--epochs", default=50, type=int)
help='images input size') parser.add_argument("--accum_iter", default=1, type=int,
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT', help="Gradient accumulation steps")
help='Drop path rate (default: 0.1)')
# Optimizer parameters # ---- Model parameters
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', parser.add_argument("--model", default="vit_large_patch16", type=str, metavar="MODEL",
help='Clip gradient norm (default: None, no clipping)') help="Model entry in models_vit.py")
parser.add_argument('--weight_decay', type=float, default=0.05, parser.add_argument("--model_arch", default="dinov3_vits16", type=str, metavar="MODEL_ARCH",
help='weight decay (default: 0.05)') help="Backbone architecture key (e.g., dinov2_vitl14, convnext_base, etc.)")
parser.add_argument('--lr', type=float, default=None, metavar='LR', parser.add_argument("--input_size", default=256, type=int, help="Image size")
help='learning rate (absolute lr)') parser.add_argument("--drop_path", type=float, default=0.2, metavar="PCT", help="Drop path rate")
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR', parser.add_argument("--global_pool", action="store_true"); parser.set_defaults(global_pool=True)
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') parser.add_argument("--cls_token", action="store_false", dest="global_pool",
parser.add_argument('--layer_decay', type=float, default=0.65, help="Use class token instead of global pool for classification")
help='layer-wise lr decay from ELECTRA/BEiT')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
help='epochs to warmup LR')
# Augmentation parameters # ---- Optimizer parameters
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', parser.add_argument("--clip_grad", type=float, default=None, metavar="NORM", help="Clip grad norm")
help='Color jitter factor (enabled only when not using Auto/RandAug)') parser.add_argument("--weight_decay", type=float, default=0.05, help="Weight decay")
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', parser.add_argument("--lr", type=float, default=None, metavar="LR", help="Absolute LR (overrides blr)")
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), parser.add_argument("--blr", type=float, default=5e-3, metavar="LR",
parser.add_argument('--smoothing', type=float, default=0.1, help="Base LR: lr = blr * total_batch_size / 256")
help='Label smoothing (default: 0.1)') parser.add_argument("--layer_decay", type=float, default=0.65, help="Layer-wise LR decay (ViT)")
parser.add_argument("--min_lr", type=float, default=1e-6, metavar="LR", help="Lower LR bound")
parser.add_argument("--warmup_epochs", type=int, default=10, metavar="N", help="Warmup epochs")
# * Random Erase params # ---- Augmentation
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', parser.add_argument("--color_jitter", type=float, default=None, metavar="PCT")
help='Random erase prob (default: 0.25)') parser.add_argument("--aa", type=str, default="rand-m9-mstd0.5-inc1", metavar="NAME")
parser.add_argument('--remode', type=str, default='pixel', parser.add_argument("--smoothing", type=float, default=0.1)
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Mixup params # ---- Random erase
parser.add_argument('--mixup', type=float, default=0, parser.add_argument("--reprob", type=float, default=0.25, metavar="PCT")
help='mixup alpha, mixup enabled if > 0.') parser.add_argument("--remode", type=str, default="pixel")
parser.add_argument('--cutmix', type=float, default=0, parser.add_argument("--recount", type=int, default=1)
help='cutmix alpha, cutmix enabled if > 0.') parser.add_argument("--resplit", action="store_true", default=False)
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup_mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# * Finetuning params # ---- Mixup/Cutmix
parser.add_argument('--finetune', default='', type=str, parser.add_argument("--mixup", type=float, default=0.0)
help='finetune from checkpoint') parser.add_argument("--cutmix", type=float, default=0.0)
parser.add_argument('--task', default='', type=str, parser.add_argument("--cutmix_minmax", type=float, nargs="+", default=None)
help='finetune from checkpoint') parser.add_argument("--mixup_prob", type=float, default=1.0)
parser.add_argument('--global_pool', action='store_true') parser.add_argument("--mixup_switch_prob", type=float, default=0.5)
parser.set_defaults(global_pool=True) parser.add_argument("--mixup_mode", type=str, default="batch")
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
help='Use class token instead of global pool for classification')
# Dataset parameters # ---- Finetuning & adaptation
parser.add_argument('--data_path', default='./data/', type=str, parser.add_argument("--finetune", default="", type=str, help="Checkpoint id/path (see model rules below)")
help='dataset path') parser.add_argument("--task", default="", type=str, help="Task name for logging/output grouping")
parser.add_argument('--nb_classes', default=8, type=int, parser.add_argument("--adaptation", default="finetune", choices=["finetune", "lp"],
help='number of the classification types') help="Adaptation strategy: finetune=full fine-tune, lp=linear probe (train head only)")
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_logs',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--dist_eval', action='store_true', default=False,
help='Enabling distributed evaluation (recommended during training for faster monitor')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.set_defaults(pin_mem=True)
# distributed training parameters # ---- Dataset & paths
parser.add_argument('--world_size', default=1, type=int, parser.add_argument("--data_path", default="./data/", type=str)
help='number of distributed processes') parser.add_argument("--nb_classes", default=8, type=int)
parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument("--output_dir", default="./output_dir")
parser.add_argument('--dist_on_itp', action='store_true') parser.add_argument("--log_dir", default="./output_logs")
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
# fine-tuning parameters # >>> NEW: training data efficiency <<<
parser.add_argument('--savemodel', action='store_true', default=True, parser.add_argument(
help='Save model') "--dataratio", type=str, default="1.0",
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method') help=('Training data ratio(s) for subsampling in build_dataset. '
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data') 'Use a single float in (0,1] (e.g., 0.25) or a comma-separated list '
parser.add_argument('--datasets_seed', default=2026, type=int) '(e.g., "1.0,0.5,0.25") if your build_dataset supports sweeps.')
)
parser.add_argument(
"--stratified", action="store_true",
help="If set, subsample training data in a class-stratified manner (requires support in build_dataset)."
)
# ---- Runtime
parser.add_argument("--device", default="cuda")
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--resume", default="", help="Resume full state (optimizer, scaler, etc.)")
parser.add_argument("--start_epoch", default=0, type=int, metavar="N")
parser.add_argument("--eval", action="store_true", help="Evaluation only")
parser.add_argument("--dist_eval", action="store_true", default=False,
help="Distributed evaluation (faster monitoring during training)")
parser.add_argument("--num_workers", default=10, type=int)
parser.add_argument("--pin_mem", action="store_true"); parser.set_defaults(pin_mem=True)
# ---- Distributed
parser.add_argument("--world_size", default=1, type=int)
parser.add_argument("--local_rank", default=-1, type=int)
parser.add_argument("--dist_on_itp", action="store_true")
parser.add_argument("--dist_url", default="env://")
# ---- Misc
parser.add_argument("--savemodel", action="store_true", default=True, help="Save best model")
parser.add_argument("--norm", default="IMAGENET", type=str)
parser.add_argument("--enhance", action="store_true", default=False)
parser.add_argument("--datasets_seed", default=2026, type=int)
return parser return parser
# =========================
# Main
# =========================
def main(args, criterion): def main(args, criterion):
# ---- Optionally load args from resume (when training)
if args.resume and not args.eval: if args.resume and not args.eval:
resume = args.resume resume_path = args.resume
checkpoint = torch.load(args.resume, map_location='cpu') checkpoint = torch.load(args.resume, map_location="cpu")
print("Load checkpoint from: %s" % args.resume) print(f"Load checkpoint (args) from: {args.resume}")
args = checkpoint['args'] args = checkpoint["args"]
args.resume = resume args.resume = resume_path
# ---- Distributed setup
misc.init_distributed_mode(args) misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}")
print("{}".format(args).replace(', ', ',\n')) print(f"{args}".replace(", ", ",\n"))
device = torch.device(args.device) device = torch.device(args.device)
# fix the seed for reproducibility # ---- Reproducibility
seed = args.seed + misc.get_rank() seed = args.seed + misc.get_rank()
torch.manual_seed(seed) torch.manual_seed(seed)
np.random.seed(seed) np.random.seed(seed)
cudnn.benchmark = True cudnn.benchmark = True
if args.model=='RETFound_mae': # ---- Build model
if args.model == "RETFound_mae":
model = models.__dict__[args.model]( model = models.__dict__[args.model](
img_size=args.input_size, img_size=args.input_size,
num_classes=args.nb_classes, num_classes=args.nb_classes,
drop_path_rate=args.drop_path, drop_path_rate=args.drop_path,
global_pool=args.global_pool, global_pool=args.global_pool,
) )
else: else:
model = models.__dict__[args.model]( model = models.__dict__[args.model](
num_classes=args.nb_classes, num_classes=args.nb_classes,
drop_path_rate=args.drop_path, drop_path_rate=args.drop_path,
args=args, args=args,
) )
if args.finetune and not args.eval:
print(f"Downloading pre-trained weights from: {args.finetune}")
checkpoint_path = hf_hub_download(
repo_id=f'YukunZhou/{args.finetune}',
filename=f'{args.finetune}.pth',
)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % args.finetune)
if args.model!='RETFound_mae':
checkpoint_model = checkpoint['teacher']
else:
checkpoint_model = checkpoint['model']
# ---- Load pre-trained weights (if requested and not eval-only)
if args.finetune and not args.eval:
print(f"Preparing to load pre-trained weights: {args.finetune}")
if args.model in ["Dinov3", "Dinov2"]:
checkpoint_path = args.finetune # local path
elif args.model in ["RETFound_dinov2", "RETFound_mae"]:
print(f"Downloading pre-trained weights from Hugging Face Hub: {args.finetune}")
checkpoint_path = hf_hub_download(
repo_id=f"YukunZhou/{args.finetune}",
filename=f"{args.finetune}.pth",
)
else:
raise ValueError(
f"Unsupported model '{args.model}'. "
f"Expected one of: Dinov3, Dinov2, RETFound_dinov2, RETFound_mae"
)
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(f"Loaded pre-trained checkpoint from: {checkpoint_path}")
if args.model in ["Dinov3", "Dinov2"]:
checkpoint_model = checkpoint
elif args.model == "RETFound_dinov2":
checkpoint_model = checkpoint["teacher"]
else: # RETFound_mae
checkpoint_model = checkpoint["model"]
# -- Key hygiene
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()} checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()} checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()} checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
# -- Remove classifier if shape mismatched
state_dict = model.state_dict() state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']: for k in ["head.weight", "head.bias"]:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint") print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k] del checkpoint_model[k]
# interpolate position embedding # -- Interpolate pos embed (ViT)
interpolate_pos_embed(model, checkpoint_model) interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model # -- Load backbone weights (non-strict)
msg = model.load_state_dict(checkpoint_model, strict=False) _ = model.load_state_dict(checkpoint_model, strict=False)
trunc_normal_(model.head.weight, std=2e-5) # -- Re-init head
if hasattr(model, "head") and hasattr(model.head, "weight"):
trunc_normal_(model.head.weight, std=2e-5)
dataset_train = build_dataset(is_train='train', args=args) # ---- Datasets & samplers
dataset_val = build_dataset(is_train='val', args=args) dataset_train = build_dataset(is_train="train", args=args)
dataset_test = build_dataset(is_train='test', args=args) dataset_val = build_dataset(is_train="val", args=args)
dataset_test = build_dataset(is_train="test", args=args)
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
if True: # args.distributed: if not args.eval:
num_tasks = misc.get_world_size() sampler_train = torch.utils.data.DistributedSampler(
global_rank = misc.get_rank() dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
if not args.eval: )
sampler_train = torch.utils.data.DistributedSampler( print(f"Sampler_train = {sampler_train}")
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print("Sampler_train = %s" % str(sampler_train))
if args.dist_eval:
if len(dataset_val) % num_tasks != 0:
print(
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
if args.dist_eval: if args.dist_eval:
if len(dataset_test) % num_tasks != 0: if len(dataset_val) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' print("Warning: dist eval with dataset not divisible by #procs; results may differ slightly.")
'This will slightly alter validation results as extra duplicate entries are added to achieve ' sampler_val = torch.utils.data.DistributedSampler(
'equal num of samples per-process.') dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True
sampler_test = torch.utils.data.DistributedSampler( )
dataset_test, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else: else:
sampler_test = torch.utils.data.SequentialSampler(dataset_test) sampler_val = torch.utils.data.SequentialSampler(dataset_val)
if args.dist_eval:
if len(dataset_test) % num_tasks != 0:
print("Warning: dist eval test set not divisible by #procs; results may differ slightly.")
sampler_test = torch.utils.data.DistributedSampler(
dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
else:
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
# ---- Logging
if global_rank == 0 and args.log_dir is not None and not args.eval: if global_rank == 0 and args.log_dir is not None and not args.eval:
os.makedirs(args.log_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task)) log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
else: else:
log_writer = None log_writer = None
# ---- DataLoaders
if not args.eval: if not args.eval:
data_loader_train = torch.utils.data.DataLoader( data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train, dataset_train, sampler=sampler_train,
batch_size=args.batch_size, batch_size=args.batch_size, num_workers=args.num_workers,
num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True,
pin_memory=args.pin_mem,
drop_last=True,
) )
print(f"len of train_set: {len(data_loader_train) * args.batch_size}")
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
data_loader_val = torch.utils.data.DataLoader( data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val, dataset_val, sampler=sampler_val,
batch_size=args.batch_size, batch_size=args.batch_size, num_workers=args.num_workers,
num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False,
pin_memory=args.pin_mem,
drop_last=False
) )
data_loader_test = torch.utils.data.DataLoader( data_loader_test = torch.utils.data.DataLoader(
dataset_test, sampler=sampler_test, dataset_test, sampler=sampler_test,
batch_size=args.batch_size, batch_size=args.batch_size, num_workers=args.num_workers,
num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False,
pin_memory=args.pin_mem,
drop_last=False
) )
# ---- Mixup/CutMix
mixup_fn = None mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None mixup_active = (args.mixup > 0) or (args.cutmix > 0.) or (args.cutmix_minmax is not None)
if mixup_active: if mixup_active:
print("Mixup is activated!") print("Mixup is activated!")
mixup_fn = Mixup( mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes) label_smoothing=args.smoothing, num_classes=args.nb_classes
)
# ---- Eval-only: resume weights
if args.resume and args.eval: if args.resume and args.eval:
checkpoint = torch.load(args.resume, map_location='cpu') checkpoint = torch.load(args.resume, map_location="cpu")
print("Load checkpoint from: %s" % args.resume) print(f"Load checkpoint for eval from: {args.resume}")
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint["model"])
model.to(device) model.to(device)
model_without_ddp = model model_without_ddp = model
# ---- Adaptation toggle
if args.adaptation == "lp":
for name, param in model.named_parameters():
param.requires_grad = ("head" in name)
print("[Adaptation] Linear probe: training classifier head only.")
else:
print("[Adaptation] Full fine-tuning: training all parameters.")
# ---- Count trainable params
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of model params (M): %.2f' % (n_parameters / 1.e6)) print(f"number of trainable params (M): {n_parameters / 1.e6:.2f}")
# ---- LR scaling by effective batch size
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None:
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256 args.lr = args.blr * eff_batch_size / 256
print(f"base lr: {args.lr * 256 / eff_batch_size:.2e}")
print(f"actual lr: {args.lr:.2e}")
print(f"accumulate grad iterations: {args.accum_iter}")
print(f"effective batch size: {eff_batch_size}")
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) # ---- DDP (if available)
print("actual lr: %.2e" % args.lr) if args.distributed and torch.cuda.device_count() > 1:
ddp_kwargs = {}
print("accumulate grad iterations: %d" % args.accum_iter) if args.adaptation == "lp":
print("effective batch size: %d" % eff_batch_size) ddp_kwargs["find_unused_parameters"] = True
model = torch.nn.parallel.DistributedDataParallel(
if args.distributed: model, device_ids=[args.gpu], **ddp_kwargs
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) )
model_without_ddp = model.module model_without_ddp = model.module
else:
model_without_ddp = model # single-GPU
# ---- Optimizer param groups (after freezing)
no_weight_decay = (model_without_ddp.no_weight_decay()
if hasattr(model_without_ddp, "no_weight_decay") else [])
param_groups = lrd.param_groups_lrd(
model_without_ddp,
weight_decay=args.weight_decay,
no_weight_decay_list=no_weight_decay,
layer_decay=args.layer_decay,
)
for g in param_groups:
g["params"] = [p for p in g["params"] if p.requires_grad]
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
no_weight_decay_list=no_weight_decay,
layer_decay=args.layer_decay
)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr) optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = NativeScaler() loss_scaler = NativeScaler()
print(f"criterion = {criterion}")
print("criterion = %s" % str(criterion)) # ---- Load previous full state (optimizer, scaler, etc.)
misc.load_model(args=args, model_without_ddp=model_without_ddp,
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) optimizer=optimizer, loss_scaler=loss_scaler)
# =========================
# Eval-only Short Circuit
# =========================
if args.eval: if args.eval:
if 'epoch' in checkpoint: if "checkpoint" in locals() and isinstance(checkpoint, dict) and ("epoch" in checkpoint):
print("Test with the best model at epoch = %d" % checkpoint['epoch']) print(f"Test with the best model at epoch = {checkpoint['epoch']}")
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test', test_stats, auc_roc = evaluate(
num_class=args.nb_classes, log_writer=log_writer) data_loader_test, model, device, args, epoch=0, mode="test",
exit(0) num_class=args.nb_classes, log_writer=log_writer
)
return
# =========================
# Train Loop
# =========================
print(f"Start training for {args.epochs} epochs") print(f"Start training for {args.epochs} epochs")
start_time = time.time() start_time = time.time()
max_score = 0.0 max_score = 0.0
best_epoch = 0 best_epoch = 0
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
if args.distributed: if args.distributed:
data_loader_train.sampler.set_epoch(epoch) data_loader_train.sampler.set_epoch(epoch)
@@ -352,49 +388,55 @@ def main(args, criterion):
model, criterion, data_loader_train, model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler, optimizer, device, epoch, loss_scaler,
args.clip_grad, mixup_fn, args.clip_grad, mixup_fn,
log_writer=log_writer, log_writer=log_writer, args=args
args=args )
val_stats, val_score = evaluate(
data_loader_val, model, device, args, epoch, mode="val",
num_class=args.nb_classes, log_writer=log_writer
) )
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
num_class=args.nb_classes, log_writer=log_writer)
if max_score < val_score: if max_score < val_score:
max_score = val_score max_score = val_score
best_epoch = epoch best_epoch = epoch
if args.output_dir and args.savemodel: if args.output_dir and args.savemodel:
misc.save_model( misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, args=args, model=model, model_without_ddp=model_without_ddp,
loss_scaler=loss_scaler, epoch=epoch, mode='best') optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, mode="best"
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score)) )
print(f"Best epoch = {best_epoch}, Best score = {max_score:.4f}")
if epoch == (args.epochs - 1):
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
model.to(device)
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
num_class=args.nb_classes, log_writer=None)
if log_writer is not None: if log_writer is not None:
log_writer.add_scalar('loss/val', val_stats['loss'], epoch) log_writer.add_scalar("loss/val", val_stats["loss"], epoch)
log_writer.flush()
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, log_stats = {**{f"train_{k}": v for k, v in train_stats.items()},
'epoch': epoch, "epoch": epoch,
'n_parameters': n_parameters} "n_parameters": n_parameters}
if args.output_dir and misc.is_main_process(): if args.output_dir and misc.is_main_process():
if log_writer is not None: with open(os.path.join(args.output_dir, args.task, "log.txt"), "a", encoding="utf-8") as f:
log_writer.flush()
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n") f.write(json.dumps(log_stats) + "\n")
# =========================
# Final Test (Best Ckpt)
# =========================
ckpt_path = os.path.join(args.output_dir, args.task, "checkpoint-best.pth")
checkpoint = torch.load(ckpt_path, map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
print(f"Test with the best model, epoch = {checkpoint.get('epoch', -1)}:")
_test_stats, _auc_roc = evaluate(
data_loader_test, model, device, args, -1, mode="test",
num_class=args.nb_classes, log_writer=None
)
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str)) print(f"Training time {total_time_str}")
if __name__ == '__main__': if __name__ == "__main__":
args = get_args_parser() args = get_args_parser()
args = args.parse_args() args = args.parse_args()
@@ -402,6 +444,5 @@ if __name__ == '__main__':
if args.output_dir: if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True) Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args, criterion) main(args, criterion)
-414
View File
@@ -1,414 +0,0 @@
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup
import models_vit as models
import util.lr_decay as lrd
import util.misc as misc
from util.datasets import build_dataset
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from huggingface_hub import hf_hub_download, login
from engine_finetune import train_one_epoch, evaluate
import warnings
import faulthandler
faulthandler.enable()
warnings.simplefilter(action='ignore', category=FutureWarning)
def get_args_parser():
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
parser.add_argument('--batch_size', default=128, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input_size', default=256, type=int,
help='images input size')
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
help='Drop path rate (default: 0.1)')
# Optimizer parameters
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--weight_decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--layer_decay', type=float, default=0.65,
help='layer-wise lr decay from ELECTRA/BEiT')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
help='epochs to warmup LR')
# Augmentation parameters
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
help='Color jitter factor (enabled only when not using Auto/RandAug)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
# * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Mixup params
parser.add_argument('--mixup', type=float, default=0,
help='mixup alpha, mixup enabled if > 0.')
parser.add_argument('--cutmix', type=float, default=0,
help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup_mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# * Finetuning params
parser.add_argument('--finetune', default='', type=str,
help='finetune from checkpoint')
parser.add_argument('--task', default='', type=str,
help='finetune from checkpoint')
parser.add_argument('--global_pool', action='store_true')
parser.set_defaults(global_pool=True)
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
help='Use class token instead of global pool for classification')
# Dataset parameters
parser.add_argument('--data_path', default='./data/', type=str,
help='dataset path')
parser.add_argument('--nb_classes', default=8, type=int,
help='number of the classification types')
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_logs',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--dist_eval', action='store_true', default=False,
help='Enabling distributed evaluation (recommended during training for faster monitor')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
# fine-tuning parameters
parser.add_argument('--savemodel', action='store_true', default=True,
help='Save model')
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
parser.add_argument('--datasets_seed', default=2026, type=int)
return parser
def main(args, criterion):
if args.resume and not args.eval:
resume = args.resume
checkpoint = torch.load(args.resume, map_location='cpu')
print("Load checkpoint from: %s" % args.resume)
args = checkpoint['args']
args.resume = resume
misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
if args.model=='RETFound_mae':
model = models.__dict__[args.model](
img_size=args.input_size,
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
global_pool=args.global_pool,
)
else:
model = models.__dict__[args.model](
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
args=args,
)
if args.finetune and not args.eval:
print(f"Downloading pre-trained weights from: {args.finetune}")
checkpoint_path = hf_hub_download(
repo_id=f'YukunZhou/{args.finetune}',
filename=f'{args.finetune}.pth',
)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % args.finetune)
if args.model!='RETFound_mae':
checkpoint_model = checkpoint['teacher']
else:
checkpoint_model = checkpoint['model']
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
trunc_normal_(model.head.weight, std=2e-5)
dataset_train = build_dataset(is_train='train', args=args)
dataset_val = build_dataset(is_train='val', args=args)
dataset_test = build_dataset(is_train='test', args=args)
if True: # args.distributed:
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
if not args.eval:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print("Sampler_train = %s" % str(sampler_train))
if args.dist_eval:
if len(dataset_val) % num_tasks != 0:
print(
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
if args.dist_eval:
if len(dataset_test) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_test = torch.utils.data.DistributedSampler(
dataset_test, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else:
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
if global_rank == 0 and args.log_dir is not None and not args.eval:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
else:
log_writer = None
if not args.eval:
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, sampler=sampler_test,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
print("Mixup is activated!")
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)
if args.resume and args.eval:
checkpoint = torch.load(args.resume, map_location='cpu')
print("Load checkpoint from: %s" % args.resume)
model.load_state_dict(checkpoint['model'])
model.to(device)
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of model params (M): %.2f' % (n_parameters / 1.e6))
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True)
model_without_ddp = model.module
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
no_weight_decay_list=no_weight_decay,
layer_decay=args.layer_decay
)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = NativeScaler()
print("criterion = %s" % str(criterion))
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
for name, param in model.named_parameters():
if 'head' in name:
param.requires_grad = True
else:
param.requires_grad = False
if args.eval:
if 'epoch' in checkpoint:
print("Test with the best model at epoch = %d" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
num_class=args.nb_classes, log_writer=log_writer)
exit(0)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_score = 0.0
best_epoch = 0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
args.clip_grad, mixup_fn,
log_writer=log_writer,
args=args
)
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
num_class=args.nb_classes, log_writer=log_writer)
if max_score < val_score:
max_score = val_score
best_epoch = epoch
if args.output_dir and args.savemodel:
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch, mode='best')
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
if epoch == (args.epochs - 1):
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=False)
model.to(device)
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
num_class=args.nb_classes, log_writer=None)
if log_writer is not None:
log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
criterion = torch.nn.CrossEntropyLoss()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args, criterion)
+41 -5
View File
@@ -1,7 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------
from functools import partial from functools import partial
@@ -10,7 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from timm.models.layers import trunc_normal_
class VisionTransformer(timm.models.vision_transformer.VisionTransformer): class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
""" Vision Transformer with support for global average pooling """ Vision Transformer with support for global average pooling
@@ -56,6 +52,30 @@ def RETFound_mae(**kwargs):
def Dinov2(args, **kwargs):
if args.model_arch == 'dinov2_vits14':
arch = 'vit_small_patch14_dinov2.lvd142m'
elif args.model_arch == 'dinov2_vitb14':
arch = 'vit_base_patch14_dinov2.lvd142m'
elif args.model_arch == 'dinov2_vitl14':
arch = 'vit_large_patch14_dinov2.lvd142m'
elif args.model_arch == 'dinov2_vitg14':
arch = 'vit_giant_patch14_dinov2.lvd142m'
else:
raise ValueError(f"Unknown model_arch '{args.model_arch}'. "
f"Expected one of: dinov2_vits14, dinov2_vitb14, dinov2_vitl14, dinov2_vitg14")
model = timm.create_model(
arch,
pretrained=True,
img_size=224,
**kwargs
)
return model
def RETFound_dinov2(args, **kwargs): def RETFound_dinov2(args, **kwargs):
model = timm.create_model( model = timm.create_model(
'vit_large_patch14_dinov2.lvd142m', 'vit_large_patch14_dinov2.lvd142m',
@@ -66,4 +86,20 @@ def RETFound_dinov2(args, **kwargs):
return model return model
def Dinov3(args, **kwargs):
# Load ViT-L/16 backbone (hub model has `head = Identity` by default)
model = torch.hub.load(
repo_or_dir="facebookresearch/dinov3",
model=args.model_arch,
pretrained=False, # main() will load your checkpoint
trust_repo=True,
)
# Figure out feature dimension for the probe
feat_dim = getattr(model, "embed_dim", None) or getattr(model, "num_features", None)
model.head = nn.Linear(feat_dim, args.nb_classes)
trunc_normal_(model.head.weight, std=2e-5)
if model.head.bias is not None:
nn.init.zeros_(model.head.bias)
return model
+28
View File
@@ -0,0 +1,28 @@
# ==== Model settings ====
# adaptation {finetune,lp}
ADAPTATION="finetune"
MODEL="RETFound_dinov2"
MODEL_ARCH="retfound_dinov2"
FINETUNE="RETFound_dinov2_meh"
# ==== Data settings ====
# change the dataset name and corresponding class number
DATASET="MESSIDOR2"
NUM_CLASS=5
data_path="./${DATASET}"
task="${MODEL_ARCH}_${DATASET}_${ADAPTATION}"
torchrun --nproc_per_node=1 --master_port=48766 main_finetune.py \
--model "${MODEL}" \
--model_arch "${MODEL_ARCH}" \
--finetune "${FINETUNE}" \
--savemodel \
--global_pool \
--batch_size 24 \
--world_size 1 \
--epochs 50 \
--nb_classes "${NUM_CLASS}" \
--data_path "${data_path}" \
--input_size 224 \
--task "${task}" \
--adaptation "${ADAPTATION}"
+49 -21
View File
@@ -1,29 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------
import os import os
import torch
from torch.utils.data import Subset
from torchvision import datasets, transforms from torchvision import datasets, transforms
from timm.data import create_transform from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args): def build_dataset(is_train, args):
transform = build_transform(is_train, args) transform = build_transform(is_train, args)
root = os.path.join(args.data_path, is_train) root = os.path.join(args.data_path, is_train)
dataset = datasets.ImageFolder(root, transform=transform) dataset = datasets.ImageFolder(root, transform=transform)
return dataset if is_train == 'train':
ratio = float(getattr(args, "dataratio", 1.0))
seed = int(getattr(args, "seed", 0))
stratified = bool(getattr(args, "stratified", False))
if 0.0 < ratio < 1.0:
if stratified:
idx = _stratified_indices(dataset.targets, ratio, seed)
else:
# simple uniform subsample with torch.Generator for reproducibility
g = torch.Generator().manual_seed(seed)
n = len(dataset)
k = max(1, int(n * ratio))
idx = torch.randperm(n, generator=g)[:k].tolist()
dataset = Subset(dataset, idx)
return dataset
def build_transform(is_train, args): def build_transform(is_train, args):
mean = IMAGENET_DEFAULT_MEAN mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD std = IMAGENET_DEFAULT_STD
# train transform
if is_train == 'train': if is_train == 'train':
# this should always dispatch to transforms_imagenet_train return create_transform(
transform = create_transform(
input_size=args.input_size, input_size=args.input_size,
is_training=True, is_training=True,
color_jitter=args.color_jitter, color_jitter=args.color_jitter,
@@ -35,19 +45,37 @@ def build_transform(is_train, args):
mean=mean, mean=mean,
std=std, std=std,
) )
return transform
# eval transform # eval transform
t = [] crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
if args.input_size <= 224:
crop_pct = 224 / 256
else:
crop_pct = 1.0
size = int(args.input_size / crop_pct) size = int(args.input_size / crop_pct)
t.append( t = [
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
) transforms.CenterCrop(args.input_size),
t.append(transforms.CenterCrop(args.input_size)) transforms.ToTensor(),
t.append(transforms.ToTensor()) transforms.Normalize(mean, std),
t.append(transforms.Normalize(mean, std)) ]
return transforms.Compose(t) return transforms.Compose(t)
# ---- helpers ----
def _stratified_indices(targets, ratio: float, seed: int):
"""Maintain class proportions. Ensures at least 1 sample per class when possible."""
t = torch.as_tensor(targets)
classes = torch.unique(t)
g = torch.Generator().manual_seed(seed)
keep = []
for c in classes.tolist():
cls_idx = torch.nonzero(t == c, as_tuple=False).view(-1)
if len(cls_idx) == 0:
continue
k = max(1, int(round(len(cls_idx) * ratio)))
sel = cls_idx[torch.randperm(len(cls_idx), generator=g)[:k]]
keep.extend(sel.tolist())
# shuffle final indices (stable across seed)
g2 = torch.Generator().manual_seed(seed + 1)
keep = torch.tensor(keep)[torch.randperm(len(keep), generator=g2)].tolist()
return keep