major package upgrade&new weights
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
## RETFound - A foundation model for retinal imaging
|
||||
|
||||
|
||||
Official repo for [RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae):
|
||||
|
||||
Official repo including a series of retinal foundation models.
|
||||
[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae):
|
||||
[New checkpoints](https://www.nature.com/articles/s41586-023-06555-x), which is based on [DINOV2](https://github.com/facebookresearch/dinov2):
|
||||
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
|
||||
|
||||
Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE)
|
||||
@@ -17,6 +18,9 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
|
||||
|
||||
### 🎉News
|
||||
|
||||
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
|
||||
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!**
|
||||
- 🐉2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!**
|
||||
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/RETFound_Feature.ipynb) are now online!
|
||||
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
|
||||
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
|
||||
@@ -29,16 +33,17 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
|
||||
1. Create environment with conda:
|
||||
|
||||
```
|
||||
conda create -n retfound python=3.7.5 -y
|
||||
conda create -n retfound python=3.11.0 -y
|
||||
conda activate retfound
|
||||
```
|
||||
|
||||
2. Install dependencies
|
||||
|
||||
```
|
||||
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
|
||||
git clone https://github.com/rmaphoh/RETFound_MAE/
|
||||
cd RETFound_MAE
|
||||
pip install -r requirement.txt
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
@@ -46,23 +51,51 @@ pip install -r requirement.txt
|
||||
|
||||
To fine tune RETFound on your own data, follow these steps:
|
||||
|
||||
1. Download the RETFound pre-trained weights
|
||||
1. Get access to the pre-trained models on HuggingFace (register an account and fill in the form) and go to step 2:
|
||||
<table><tbody>
|
||||
<!-- START TABLE -->
|
||||
<!-- TABLE HEADER -->
|
||||
<th valign="bottom"></th>
|
||||
<th valign="bottom">ViT-Large</th>
|
||||
<th valign="bottom">Source</th>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left">Colour fundus image</td>
|
||||
<td align="center"><a href="https://drive.google.com/file/d/1l62zbWUFTlp214SvK6eMwPQZAzcwoeBE/view?usp=sharing">download</a></td>
|
||||
<tr><td align="left">RETFound_mae_natureCFP</td>
|
||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_natureCFP">access</a></td>
|
||||
<td align="center"><a href="https://www.nature.com/articles/s41586-023-06555-x">Nature RETFound paper</a></td>
|
||||
</tr>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left">OCT</td>
|
||||
<td align="center"><a href="https://drive.google.com/file/d/1m6s7QYkjyjJDlpEuXm7Xp3PmjN-elfW2/view?usp=sharing">download</a></td>
|
||||
<tr><td align="left">RETFound_mae_natureOCT</td>
|
||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_natureOCT">access</a></td>
|
||||
<td align="center"><a href="https://www.nature.com/articles/s41586-023-06555-x">Nature RETFound paper</a></td>
|
||||
</tr>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left">RETFound_mae_meh</td>
|
||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
|
||||
<td align="center">TBD</a></td>
|
||||
</tr>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left">RETFound_mae_shanghai</td>
|
||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
|
||||
<td align="center">TBD</a></td>
|
||||
</tr>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left">RETFound_dinov2_meh</td>
|
||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
|
||||
<td align="center">TBD</a></td>
|
||||
</tr>
|
||||
<!-- TABLE BODY -->
|
||||
<tr><td align="left">RETFound_dinov2_shanghai</td>
|
||||
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">download</a></td>
|
||||
<td align="center">TBD</a></td>
|
||||
</tr>
|
||||
</tbody></table>
|
||||
|
||||
2. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md))
|
||||
2. Login in your HuggingFace account, where HuggingFace token can be [created and copied](https://huggingface.co/settings/tokens).
|
||||
```
|
||||
huggingface-cli login --token YOUR_HUGGINGFACE_TOKEN
|
||||
```
|
||||
|
||||
3. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md))
|
||||
|
||||
```
|
||||
├── data folder
|
||||
@@ -80,23 +113,29 @@ To fine tune RETFound on your own data, follow these steps:
|
||||
├──class_c
|
||||
```
|
||||
|
||||
3. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be run after training.
|
||||
|
||||
4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training.
|
||||
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
||||
model can be "RETFound_mae" or "RETFound_dinov2"
|
||||
```
|
||||
```
|
||||
finetune can be "RETFound_mae_natureOCT", "RETFound_mae_natureCFP", "RETFound_mae_meh", "RETFound_mae_shanghai", "RETFound_dinov2_meh", and "RETFound_dinov2_shanghai".
|
||||
```
|
||||
```
|
||||
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
||||
--model RETFound_mae \
|
||||
--savemodel \
|
||||
--global_pool \
|
||||
--batch_size 16 \
|
||||
--world_size 1 \
|
||||
--model vit_large_patch16 \
|
||||
--epochs 50 \
|
||||
--epochs 100 \
|
||||
--blr 5e-3 --layer_decay 0.65 \
|
||||
--weight_decay 0.05 --drop_path 0.2 \
|
||||
--nb_classes 5 \
|
||||
--data_path ./IDRiD_data/ \
|
||||
--task ./finetune_IDRiD/ \
|
||||
--finetune ./RETFound_cfp_weights.pth \
|
||||
--input_size 224
|
||||
|
||||
--data_path ./IDRiD \
|
||||
--input_size 224 \
|
||||
--task RETFound_mae_meh-IDRiD \
|
||||
--finetune RETFound_mae_meh
|
||||
```
|
||||
|
||||
|
||||
@@ -104,64 +143,32 @@ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_f
|
||||
|
||||
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
||||
--eval --batch_size 16 \
|
||||
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
|
||||
--model RETFound_mae \
|
||||
--savemodel \
|
||||
--eval \
|
||||
--global_pool \
|
||||
--batch_size 16 \
|
||||
--world_size 1 \
|
||||
--model vit_large_patch16 \
|
||||
--epochs 50 \
|
||||
--epochs 100 \
|
||||
--blr 5e-3 --layer_decay 0.65 \
|
||||
--weight_decay 0.05 --drop_path 0.2 \
|
||||
--nb_classes 5 \
|
||||
--data_path ./IDRiD_data/ \
|
||||
--task ./internal_IDRiD/ \
|
||||
--resume ./finetune_IDRiD/checkpoint-best.pth \
|
||||
--input_size 224
|
||||
|
||||
```
|
||||
|
||||
|
||||
### Load the model and weights (if you want to call the model in your code)
|
||||
|
||||
```python
|
||||
import torch
|
||||
import models_vit
|
||||
from util.pos_embed import interpolate_pos_embed
|
||||
from timm.models.layers import trunc_normal_
|
||||
|
||||
# call the model
|
||||
model = models_vit.__dict__['vit_large_patch16'](
|
||||
num_classes=2,
|
||||
drop_path_rate=0.2,
|
||||
global_pool=True,
|
||||
)
|
||||
|
||||
# load RETFound weights
|
||||
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
|
||||
checkpoint_model = checkpoint['model']
|
||||
state_dict = model.state_dict()
|
||||
for k in ['head.weight', 'head.bias']:
|
||||
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
||||
print(f"Removing key {k} from pretrained checkpoint")
|
||||
del checkpoint_model[k]
|
||||
|
||||
# interpolate position embedding
|
||||
interpolate_pos_embed(model, checkpoint_model)
|
||||
|
||||
# load pre-trained model
|
||||
msg = model.load_state_dict(checkpoint_model, strict=False)
|
||||
|
||||
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
|
||||
|
||||
# manually initialize fc layer
|
||||
trunc_normal_(model.head.weight, std=2e-5)
|
||||
|
||||
print("Model = %s" % str(model))
|
||||
--data_path ./IDRiD \
|
||||
--input_size 224 \
|
||||
--task RETFound_mae_meh-IDRiD \
|
||||
--resume ./finetune_IDRiD/checkpoint-best.pth
|
||||
```
|
||||
|
||||
|
||||
### 📃Citation
|
||||
|
||||
If you find this repository useful, please consider citing this paper:
|
||||
|
||||
```
|
||||
TBD
|
||||
```
|
||||
|
||||
```
|
||||
@article{zhou2023foundation,
|
||||
title={A foundation model for generalizable disease detection from retinal images},
|
||||
|
||||
Reference in New Issue
Block a user