diff --git a/README.md b/README.md index de5b8cb..b79e753 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ -## RETFound - A foundation model for retinal image +## RETFound - A foundation model for retinal imaging -This is official repo for RETFound, which heavily bases on [MAE](https://github.com/facebookresearch/mae): +This is the official repo for RETFound, which is heavily based on [MAE](https://github.com/facebookresearch/mae): ### Key features -- RETFound was trained on 1.6 million retinal images +- RETFound was self-supervised pre-trained on 1.6 million retinal images - RETFound has been validated in multiple disease detection tasks -- RETFound can be efficiently adapted to customised task +- RETFound can be efficiently adapted to customised tasks ### Install enviroment @@ -48,7 +48,7 @@ pip install -r requirement.txt -- Organise data (use IDRiD as [example](Example.ipynb)) +- Organise data (using IDRiD as an [example](Example.ipynb))
@@ -93,3 +93,41 @@ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_f
```
+### Load the model and weights
+
+```
+import torch
+import models_vit
+from util.pos_embed import interpolate_pos_embed
+from timm.models.layers import trunc_normal_
+
+# call the model
+model = models_vit.__dict__['vit_large_patch16'](
+ num_classes=2,
+ drop_path_rate=0.2,
+ global_pool=True,
+)
+
+# load RETFound weights
+checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
+checkpoint_model = checkpoint['model']
+state_dict = model.state_dict()
+for k in ['head.weight', 'head.bias']:
+ if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
+ print(f"Removing key {k} from pretrained checkpoint")
+ del checkpoint_model[k]
+
+# interpolate position embedding
+interpolate_pos_embed(model, checkpoint_model)
+
+# load pre-trained model
+msg = model.load_state_dict(checkpoint_model, strict=False)
+
+assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
+
+# manually initialize fc layer
+trunc_normal_(model.head.weight, std=2e-5)
+
+print("Model = %s" % str(model))
+```
+