v1.0
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
## RETFound - A foundation model for retinal image
|
||||
## RETFound - A foundation model for retinal imaging
|
||||
|
||||
|
||||
This is official repo for RETFound, which heavily bases on [MAE](https://github.com/facebookresearch/mae):
|
||||
This is the official repo for RETFound, which is heavily based on [MAE](https://github.com/facebookresearch/mae):
|
||||
|
||||
|
||||
### Key features
|
||||
|
||||
- RETFound was trained on 1.6 million retinal images
|
||||
- RETFound was self-supervised pre-trained on 1.6 million retinal images
|
||||
- RETFound has been validated in multiple disease detection tasks
|
||||
- RETFound can be efficiently adapted to customised task
|
||||
- RETFound can be efficiently adapted to customised tasks
|
||||
|
||||
|
||||
### Install enviroment
|
||||
@@ -48,7 +48,7 @@ pip install -r requirement.txt
|
||||
</tr>
|
||||
</tbody></table>
|
||||
|
||||
- Organise data (use IDRiD as [example](Example.ipynb))
|
||||
- Organise data (using IDRiD as an [example](Example.ipynb))
|
||||
|
||||
<p align="left">
|
||||
<img src="./pic/file_index.jpg" width="160">
|
||||
@@ -93,3 +93,41 @@ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_f
|
||||
```
|
||||
|
||||
|
||||
### Load the model and weights
|
||||
|
||||
```
|
||||
import torch
|
||||
import models_vit
|
||||
from util.pos_embed import interpolate_pos_embed
|
||||
from timm.models.layers import trunc_normal_
|
||||
|
||||
# call the model
|
||||
model = models_vit.__dict__['vit_large_patch16'](
|
||||
num_classes=2,
|
||||
drop_path_rate=0.2,
|
||||
global_pool=True,
|
||||
)
|
||||
|
||||
# load RETFound weights
|
||||
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
|
||||
checkpoint_model = checkpoint['model']
|
||||
state_dict = model.state_dict()
|
||||
for k in ['head.weight', 'head.bias']:
|
||||
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
||||
print(f"Removing key {k} from pretrained checkpoint")
|
||||
del checkpoint_model[k]
|
||||
|
||||
# interpolate position embedding
|
||||
interpolate_pos_embed(model, checkpoint_model)
|
||||
|
||||
# load pre-trained model
|
||||
msg = model.load_state_dict(checkpoint_model, strict=False)
|
||||
|
||||
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
|
||||
|
||||
# manually initialize fc layer
|
||||
trunc_normal_(model.head.weight, std=2e-5)
|
||||
|
||||
print("Model = %s" % str(model))
|
||||
```
|
||||
|
||||
|
||||
Reference in New Issue
Block a user