## RETFound - A foundation model for retinal imaging

This is the official repository for RETFound, which is heavily based on [MAE](https://github.com/facebookresearch/mae).

### Key features

- RETFound was pre-trained with self-supervised learning on 1.6 million retinal images.
- RETFound has been validated on multiple disease detection tasks.
- RETFound can be efficiently adapted to customised tasks.

### Install environment

Create and activate the environment with conda:

```
conda create -n retfound python=3.6.15 -y
conda activate retfound
```

Install PyTorch 1.8.1 (CUDA 11.1):

```
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
```

Install the remaining dependencies:

```
pip install -r requirement.txt
```

### Fine-tuning with RETFound weights

- Download the RETFound pre-trained weights:

<table><tbody>
<!-- START TABLE -->
<!-- TABLE HEADER -->
<tr>
<th valign="bottom"></th>
<th valign="bottom">ViT-Large</th>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">Colour fundus image</td>
<td align="center"><a href="https://drive.google.com/file/d/1l62zbWUFTlp214SvK6eMwPQZAzcwoeBE/view?usp=sharing">download</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">OCT</td>
<td align="center"><a href="https://drive.google.com/file/d/1m6s7QYkjyjJDlpEuXm7Xp3PmjN-elfW2/view?usp=sharing">download</a></td>
</tr>
</tbody></table>

- Organise the data as shown below (using IDRiD as an [example](Example.ipynb)):

<p align="left">
<img src="./pic/file_index.jpg" width="160">
</p>
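
Before launching fine-tuning, it can help to check the folder layout programmatically. The sketch below assumes an ImageFolder-style structure, i.e. `train/`, `val/` and `test/` split folders under `--data_path`, each holding one sub-folder per class (torchvision's `ImageFolder` assigns class indices from the sorted sub-folder names). The split names, the extension list and the `check_layout` helper are all assumptions for illustration, not part of the RETFound repo; adjust them to match the layout in the figure.

```python
import os
from typing import Dict

# Assumed split names and image extensions; adjust both to your dataset.
SPLITS = ("train", "val", "test")
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}


def check_layout(data_path: str, splits=SPLITS) -> Dict[str, Dict[str, int]]:
    """Hypothetical helper (not part of the RETFound repo).

    Returns {split: {class_folder: number_of_images}} and raises if a split
    folder is missing or if the splits disagree on their class folders.
    """
    counts = {}
    for split in splits:
        split_dir = os.path.join(data_path, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError("missing split folder: " + split_dir)
        class_names = sorted(
            d for d in os.listdir(split_dir)
            if os.path.isdir(os.path.join(split_dir, d))
        )
        counts[split] = {}
        for name in class_names:
            files = os.listdir(os.path.join(split_dir, name))
            # Count only files whose extension looks like an image.
            counts[split][name] = sum(
                1 for f in files if os.path.splitext(f)[1].lower() in IMAGE_EXTS
            )

    # ImageFolder-style loaders derive class indices from sorted folder names,
    # so every split must contain exactly the same class folders.
    reference = set(counts[splits[0]])
    for split in splits[1:]:
        if set(counts[split]) != reference:
            raise ValueError(
                "class folders in '%s' differ from those in '%s'" % (split, splits[0])
            )
    return counts
```

For example, `check_layout('./IDRiD_data/')` should report one entry per class in every split, and the number of class folders should equal `--nb_classes` (5 in the IDRiD fine-tuning command).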

- Start fine-tuning (using IDRiD as an example). The best fine-tuned checkpoint is saved during training as `checkpoint-best.pth` in the `--task` folder, and evaluation runs once training finishes.

```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
    --batch_size 16 \
    --world_size 1 \
    --model vit_large_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes 5 \
    --data_path ./IDRiD_data/ \
    --task ./finetune_IDRiD/ \
    --finetune ./RETFound_cfp_weights.pth
```
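
The `--blr` and `--layer_decay` flags come from MAE's fine-tuning script. Assuming RETFound keeps MAE's rules unchanged (and gradient accumulation stays at 1), the peak learning rate is `blr × effective batch size / 256`, and layer-wise decay multiplies each transformer block's learning rate by `layer_decay` raised to its distance from the top of the network. The sketch below reproduces that arithmetic for the command above; `absolute_lr` and `layer_lr_scales` are hypothetical helpers, not the repo's API.

```python
# Sketch of MAE-style learning-rate arithmetic (assumed to carry over to RETFound).


def absolute_lr(blr: float, batch_size: int, world_size: int = 1, accum_iter: int = 1) -> float:
    """Peak learning rate: blr scaled linearly by effective batch size / 256."""
    eff_batch_size = batch_size * accum_iter * world_size
    return blr * eff_batch_size / 256


def layer_lr_scales(num_blocks: int, layer_decay: float) -> list:
    """Per-layer multipliers for layer-wise learning-rate decay on a ViT.

    Index 0: patch embedding, class token and position embedding.
    Index i (1..num_blocks): transformer block i - 1.
    Index num_blocks + 1: everything above the blocks (fc_norm, head).
    """
    num_layers = num_blocks + 1
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]


if __name__ == "__main__":
    lr = absolute_lr(blr=5e-3, batch_size=16, world_size=1)
    scales = layer_lr_scales(num_blocks=24, layer_decay=0.65)  # ViT-Large has 24 blocks
    print("peak lr:", lr)
    print("head lr:", lr * scales[-1])
    print("last block lr:", lr * scales[-2])
    print("patch-embed lr:", lr * scales[0])
```

With the values in the command, the peak learning rate works out to 5e-3 × 16 / 256 = 3.125e-4; the head keeps the full rate, the last block gets 0.65 of it, and the patch embedding gets 0.65^25 (about 2.1e-5) of it, so the early layers stay close to their pre-trained values.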

- To run evaluation only, add `--eval` and point `--resume` at the fine-tuned checkpoint:

```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \
    --eval --batch_size 16 \
    --world_size 1 \
    --model vit_large_patch16 \
    --epochs 50 \
    --blr 5e-3 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.2 \
    --nb_classes 5 \
    --data_path ./IDRiD_data/ \
    --task ./internal_IDRiD/ \
    --resume ./finetune_IDRiD/checkpoint-best.pth
```

### Load the model and weights

To use RETFound in your own code, build the ViT-Large model and load the pre-trained weights:

```
import torch
import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_

# build ViT-Large/16; num_classes sets the size of the new classification head,
# and global_pool=True classifies from the average of the patch tokens
model = models_vit.__dict__['vit_large_patch16'](
    num_classes=2,
    drop_path_rate=0.2,
    global_pool=True,
)

# load RETFound weights (here, the colour fundus checkpoint)
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()

# drop the checkpoint's head if its shape does not match the new head
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

# resize position embeddings if the input resolution differs from pre-training
interpolate_pos_embed(model, checkpoint_model)

# load pre-trained weights; strict=False because the new head and fc_norm
# are not in the pre-trained checkpoint
msg = model.load_state_dict(checkpoint_model, strict=False)
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}

# manually initialise the new classification head
trunc_normal_(model.head.weight, std=2e-5)

print("Model = %s" % str(model))
```
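
The `interpolate_pos_embed` step only changes anything when the fine-tuning image size differs from the pre-training size. Assuming it follows MAE's implementation, it treats the checkpoint's position embeddings (minus the class token) as a square grid and, only when that grid's side differs from the new model's, resizes the grid with bicubic interpolation while keeping the class token unchanged. The hypothetical helper below reproduces just the size bookkeeping for a ViT with 16×16 patches, assuming 224×224 pre-training inputs, so you can check in advance whether interpolation will be triggered.

```python
def pos_embed_grid(num_positions_ckpt, img_size, patch_size=16, num_extra_tokens=1):
    """Hypothetical helper mirroring the size checks in MAE's interpolate_pos_embed.

    num_positions_ckpt: sequence length of the checkpoint's pos_embed
                        (grid positions plus extra tokens).
    num_extra_tokens:   non-grid tokens, i.e. the class token for a standard ViT.
    Returns (checkpoint grid side, new grid side, whether interpolation is needed).
    """
    # The checkpoint grid is assumed square: side = sqrt(number of patch positions).
    orig_side = int((num_positions_ckpt - num_extra_tokens) ** 0.5)
    # The new grid side is the number of patches along one image edge.
    new_side = img_size // patch_size
    return orig_side, new_side, orig_side != new_side


# Assuming 224x224 pre-training: 14 x 14 patches + 1 class token = 197 positions.
print(pos_embed_grid(197, img_size=224))  # (14, 14, False): weights load as-is
print(pos_embed_grid(197, img_size=384))  # (14, 24, True): grid resized 14x14 -> 24x24
```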