add example Jupyter Notebook
This commit is contained in:
@@ -46,6 +46,8 @@ pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorc
|
||||
git clone https://github.com/rmaphoh/RETFound/
|
||||
cd RETFound
|
||||
pip install -r requirements.txt
|
||||
pip install ipykernel
|
||||
python -m ipykernel install --user --name retfound --display-name "Python (retfound)"
|
||||
```
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,431 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "76b39fb1",
|
||||
"metadata": {
|
||||
"jp-MarkdownHeadingCollapsed": true
|
||||
},
|
||||
"source": [
|
||||
"## Jupyter notebook example - Classification task\n",
|
||||
"### Example using [MESSIDOR2](https://www.adcis.net/en/third-party/messidor2/) dataset\n",
|
||||
"**Application**: Using RETFound for five-category diabetic retinopathy classification\n",
|
||||
"\n",
|
||||
"**Author**: Yukun Zhou\n",
|
||||
"\n",
|
||||
"**Date**: 30 Nov 2025\n",
|
||||
"\n",
|
||||
"**Performance**:\n",
|
||||
"\n",
|
||||
"<table align=\"left\">\n",
|
||||
"<tr>\n",
|
||||
" <th>Accuracy</th>\n",
|
||||
" <th>Recall</th>\n",
|
||||
" <th>F1 Score</th>\n",
|
||||
" <th>ROC AUC</th>\n",
|
||||
" <th>PR AUC</th>\n",
|
||||
"</tr>\n",
|
||||
"<tr>\n",
|
||||
" <td>0.7091</td>\n",
|
||||
" <td>0.5616</td>\n",
|
||||
" <td>0.6078</td>\n",
|
||||
" <td>0.9037</td>\n",
|
||||
" <td>0.6863</td>\n",
|
||||
"</tr>\n",
|
||||
"</table>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7ec435a7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Install environment\n",
|
||||
"1. Follow [RETFound README](https://github.com/rmaphoh/RETFound) to install environment\n",
|
||||
"2. Restart this Jupyter Notebook\n",
|
||||
"3. Select Kernel retfound"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "7cbf5e93-6ca0-4401-88e6-64e39968e7cd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Project root: /home/jupyter/RETFound\n",
|
||||
"sys.executable: /opt/conda/envs/retfound/bin/python\n",
|
||||
"torch version: 2.4.1+cu118\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sys, torch\n",
|
||||
"from pathlib import Path\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"PROJECT_ROOT = Path.cwd().resolve()\n",
|
||||
"\n",
|
||||
"if PROJECT_ROOT.name == 'examples': PROJECT_ROOT = PROJECT_ROOT.parent\n",
|
||||
"os.chdir(PROJECT_ROOT)\n",
|
||||
"\n",
|
||||
"print('Project root:', PROJECT_ROOT)\n",
|
||||
"print(\"sys.executable:\", sys.executable)\n",
|
||||
"print(\"torch version:\", torch.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ed67953f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Prepare MESSIDOR2 dataset\n",
|
||||
"1. Download from the [shared data pool](https://github.com/rmaphoh/RETFound/blob/main/BENCHMARK.md).\n",
|
||||
"2. Put the data folder under the project directory, e.g. \"RETFound/MESSIDOR2\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "357be2fa-a914-4d1f-8759-76b2b1c3f20f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. Hyperparameter and path settings\n",
|
||||
"1. Can choose finetune or lp (linear probe)\n",
|
||||
"2. Model selection [info](https://github.com/rmaphoh/RETFound#:~:text=In%20train.sh%2C%20the%20model%20can%20be%20selected%20by%20changing%20the%20hyperparameters%20MODEL%2C%20MODEL_ARCH%2C%20FINETUNE%3A)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "5f675843",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"DATA_PATH: /home/jupyter/RETFound/MESSIDOR2\n",
|
||||
"TASK: retfound_dinov2_MESSIDOR2_finetune\n",
|
||||
"OUTPUT_DIR: /home/jupyter/RETFound/output_dir/retfound_dinov2_MESSIDOR2_finetune\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"ADAPTATION='finetune'\n",
|
||||
"MODEL='RETFound_dinov2'\n",
|
||||
"MODEL_ARCH='retfound_dinov2'\n",
|
||||
"FINETUNE='RETFound_dinov2_meh'\n",
|
||||
"DATASET='MESSIDOR2'\n",
|
||||
"NUM_CLASS=5\n",
|
||||
"DATA_PATH=PROJECT_ROOT/DATASET\n",
|
||||
"BATCH_SIZE=24\n",
|
||||
"EPOCHS=50\n",
|
||||
"INPUT_SIZE=224\n",
|
||||
"WORLD_SIZE=1\n",
|
||||
"TASK=f\"{MODEL_ARCH}_{DATASET}_{ADAPTATION}\"\n",
|
||||
"OUTPUT_DIR=PROJECT_ROOT/'output_dir'/TASK\n",
|
||||
"print('DATA_PATH:',DATA_PATH)\n",
|
||||
"print('TASK:',TASK)\n",
|
||||
"print('OUTPUT_DIR:',OUTPUT_DIR)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6ac04845",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Fine-tuning and testing RETFound on MESSIDOR2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "d23ff751",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Not using distributed mode\n",
|
||||
"[12:55:12.463648] job dir: /home/jupyter/RETFound\n",
|
||||
"[12:55:12.463731] Namespace(batch_size=24,\n",
|
||||
"epochs=50,\n",
|
||||
"accum_iter=1,\n",
|
||||
"model='RETFound_dinov2',\n",
|
||||
"model_arch='retfound_dinov2',\n",
|
||||
"input_size=224,\n",
|
||||
"drop_path=0.2,\n",
|
||||
"global_pool=True,\n",
|
||||
"clip_grad=None,\n",
|
||||
"weight_decay=0.05,\n",
|
||||
"lr=None,\n",
|
||||
"blr=0.005,\n",
|
||||
"layer_decay=0.65,\n",
|
||||
"min_lr=1e-06,\n",
|
||||
"warmup_epochs=10,\n",
|
||||
"color_jitter=None,\n",
|
||||
"aa='rand-m9-mstd0.5-inc1',\n",
|
||||
"smoothing=0.1,\n",
|
||||
"reprob=0.25,\n",
|
||||
"remode='pixel',\n",
|
||||
"recount=1,\n",
|
||||
"resplit=False,\n",
|
||||
"mixup=0.0,\n",
|
||||
"cutmix=0.0,\n",
|
||||
"cutmix_minmax=None,\n",
|
||||
"mixup_prob=1.0,\n",
|
||||
"mixup_switch_prob=0.5,\n",
|
||||
"mixup_mode='batch',\n",
|
||||
"finetune='RETFound_dinov2_meh',\n",
|
||||
"task='retfound_dinov2_MESSIDOR2_finetune',\n",
|
||||
"adaptation='finetune',\n",
|
||||
"data_path='/home/jupyter/RETFound/MESSIDOR2',\n",
|
||||
"nb_classes=5,\n",
|
||||
"output_dir='./output_dir',\n",
|
||||
"log_dir='./output_logs',\n",
|
||||
"dataratio='1.0',\n",
|
||||
"stratified=False,\n",
|
||||
"device='cuda',\n",
|
||||
"seed=0,\n",
|
||||
"resume='',\n",
|
||||
"start_epoch=0,\n",
|
||||
"eval=False,\n",
|
||||
"dist_eval=False,\n",
|
||||
"num_workers=10,\n",
|
||||
"pin_mem=True,\n",
|
||||
"world_size=1,\n",
|
||||
"local_rank=-1,\n",
|
||||
"dist_on_itp=False,\n",
|
||||
"dist_url='env://',\n",
|
||||
"savemodel=True,\n",
|
||||
"norm='IMAGENET',\n",
|
||||
"enhance=False,\n",
|
||||
"datasets_seed=2026,\n",
|
||||
"distributed=False)\n",
|
||||
"^C\n",
|
||||
"Traceback (most recent call last):\n",
|
||||
" File \"/home/jupyter/RETFound/main_finetune.py\", line 448, in <module>\n",
|
||||
" main(args, criterion)\n",
|
||||
" File \"/home/jupyter/RETFound/main_finetune.py\", line 173, in main\n",
|
||||
" model = models.__dict__[args.model](\n",
|
||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
||||
" File \"/home/jupyter/RETFound/models_vit.py\", line 80, in RETFound_dinov2\n",
|
||||
" model = timm.create_model(\n",
|
||||
" ^^^^^^^^^^^^^^^^^^\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/_factory.py\", line 117, in create_model\n",
|
||||
" model = create_fn(\n",
|
||||
" ^^^^^^^^^^\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/vision_transformer.py\", line 2462, in vit_large_patch14_dinov2\n",
|
||||
" model = _create_vision_transformer(\n",
|
||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/vision_transformer.py\", line 1781, in _create_vision_transformer\n",
|
||||
" return build_model_with_cfg(\n",
|
||||
" ^^^^^^^^^^^^^^^^^^^^^\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/_builder.py\", line 398, in build_model_with_cfg\n",
|
||||
" model = model_cls(**kwargs)\n",
|
||||
" ^^^^^^^^^^^^^^^^^^^\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/vision_transformer.py\", line 540, in __init__\n",
|
||||
" self.init_weights(weight_init)\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/vision_transformer.py\", line 558, in init_weights\n",
|
||||
" named_apply(get_init_weights_vit(mode, head_bias), self)\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/_manipulate.py\", line 34, in named_apply\n",
|
||||
" named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/_manipulate.py\", line 34, in named_apply\n",
|
||||
" named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/_manipulate.py\", line 34, in named_apply\n",
|
||||
" named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)\n",
|
||||
" [Previous line repeated 1 more time]\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/_manipulate.py\", line 36, in named_apply\n",
|
||||
" fn(module=module, name=name)\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/models/vision_transformer.py\", line 712, in init_weights_vit_timm\n",
|
||||
" trunc_normal_(module.weight, std=.02)\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/layers/weight_init.py\", line 67, in trunc_normal_\n",
|
||||
" return _trunc_normal_(tensor, mean, std, a, b)\n",
|
||||
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
|
||||
" File \"/opt/conda/envs/retfound/lib/python3.11/site-packages/timm/layers/weight_init.py\", line 28, in _trunc_normal_\n",
|
||||
" tensor.uniform_(2 * l - 1, 2 * u - 1)\n",
|
||||
"KeyboardInterrupt\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"!{sys.executable} main_finetune.py \\\n",
|
||||
" --model {MODEL} \\\n",
|
||||
" --model_arch {MODEL_ARCH} \\\n",
|
||||
" --finetune {FINETUNE} \\\n",
|
||||
" --savemodel \\\n",
|
||||
" --global_pool \\\n",
|
||||
" --batch_size {BATCH_SIZE} \\\n",
|
||||
" --epochs {EPOCHS} \\\n",
|
||||
" --nb_classes {NUM_CLASS} \\\n",
|
||||
" --data_path {DATA_PATH} \\\n",
|
||||
" --input_size {INPUT_SIZE} \\\n",
|
||||
" --task {TASK} \\\n",
|
||||
" --adaptation {ADAPTATION}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "84ce93ac",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Evaluation-only"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "0af0f8a7",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Not using distributed mode\n",
|
||||
"[13:01:56.741693] job dir: /home/jupyter/RETFound\n",
|
||||
"[13:01:56.741757] Namespace(batch_size=128,\n",
|
||||
"epochs=50,\n",
|
||||
"accum_iter=1,\n",
|
||||
"model='RETFound_dinov2',\n",
|
||||
"model_arch='retfound_dinov2',\n",
|
||||
"input_size=224,\n",
|
||||
"drop_path=0.2,\n",
|
||||
"global_pool=True,\n",
|
||||
"clip_grad=None,\n",
|
||||
"weight_decay=0.05,\n",
|
||||
"lr=None,\n",
|
||||
"blr=0.005,\n",
|
||||
"layer_decay=0.65,\n",
|
||||
"min_lr=1e-06,\n",
|
||||
"warmup_epochs=10,\n",
|
||||
"color_jitter=None,\n",
|
||||
"aa='rand-m9-mstd0.5-inc1',\n",
|
||||
"smoothing=0.1,\n",
|
||||
"reprob=0.25,\n",
|
||||
"remode='pixel',\n",
|
||||
"recount=1,\n",
|
||||
"resplit=False,\n",
|
||||
"mixup=0.0,\n",
|
||||
"cutmix=0.0,\n",
|
||||
"cutmix_minmax=None,\n",
|
||||
"mixup_prob=1.0,\n",
|
||||
"mixup_switch_prob=0.5,\n",
|
||||
"mixup_mode='batch',\n",
|
||||
"finetune='RETFound_dinov2_meh',\n",
|
||||
"task='retfound_dinov2_MESSIDOR2_finetune',\n",
|
||||
"adaptation='finetune',\n",
|
||||
"data_path='/home/jupyter/RETFound/MESSIDOR2',\n",
|
||||
"nb_classes=5,\n",
|
||||
"output_dir='./output_dir',\n",
|
||||
"log_dir='./output_logs',\n",
|
||||
"dataratio='1.0',\n",
|
||||
"stratified=False,\n",
|
||||
"device='cuda',\n",
|
||||
"seed=0,\n",
|
||||
"resume='/home/jupyter/RETFound/output_dir/retfound_dinov2_MESSIDOR2_finetune/checkpoint-best.pth',\n",
|
||||
"start_epoch=0,\n",
|
||||
"eval=True,\n",
|
||||
"dist_eval=False,\n",
|
||||
"num_workers=10,\n",
|
||||
"pin_mem=True,\n",
|
||||
"world_size=1,\n",
|
||||
"local_rank=-1,\n",
|
||||
"dist_on_itp=False,\n",
|
||||
"dist_url='env://',\n",
|
||||
"savemodel=True,\n",
|
||||
"norm='IMAGENET',\n",
|
||||
"enhance=False,\n",
|
||||
"datasets_seed=2026,\n",
|
||||
"distributed=False)\n",
|
||||
"/opt/conda/envs/retfound/lib/python3.11/site-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 10 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
|
||||
" warnings.warn(_create_warning_msg(\n",
|
||||
"[13:02:01.795301] Load checkpoint for eval from: /home/jupyter/RETFound/output_dir/retfound_dinov2_MESSIDOR2_finetune/checkpoint-best.pth\n",
|
||||
"[13:02:02.683863] [Adaptation] Full fine-tuning: training all parameters.\n",
|
||||
"[13:02:02.685617] number of trainable params (M): 303.23\n",
|
||||
"[13:02:02.685682] base lr: 5.00e-03\n",
|
||||
"[13:02:02.685697] actual lr: 2.50e-03\n",
|
||||
"[13:02:02.685710] accumulate grad iterations: 1\n",
|
||||
"[13:02:02.685722] effective batch size: 128\n",
|
||||
"[13:02:02.690046] criterion = CrossEntropyLoss()\n",
|
||||
"[13:02:03.513582] Resume checkpoint /home/jupyter/RETFound/output_dir/retfound_dinov2_MESSIDOR2_finetune/checkpoint-best.pth\n",
|
||||
"[13:02:03.514238] Test with the best model at epoch = 12\n",
|
||||
"[13:02:11.667583] test: [0/5] eta: 0:00:40 loss: 0.6644 (0.6644) time: 8.1504 data: 5.8343 max mem: 2841\n",
|
||||
"[13:02:15.546902] test: [4/5] eta: 0:00:02 loss: 0.7874 (0.9476) time: 2.4058 data: 1.2244 max mem: 2841\n",
|
||||
"[13:02:15.598723] test: Total time: 0:00:12 (2.4165 s / it)\n",
|
||||
"[13:02:15.622190] val loss: 0.9476207017898559\n",
|
||||
"[13:02:15.622238] Accuracy: 0.7091, F1 Score: 0.6078, ROC AUC: 0.9037, Hamming Loss: 0.1163,\n",
|
||||
" Jaccard Score: 0.4613, Precision: 0.7244, Recall: 0.5616,\n",
|
||||
" Average Precision: 0.6863, Kappa: 0.5141, Score: 0.6752\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"\n",
|
||||
"CKPT = OUTPUT_DIR / \"checkpoint-best.pth\"\n",
|
||||
"\n",
|
||||
"!{sys.executable} main_finetune.py \\\n",
|
||||
" --model {MODEL} \\\n",
|
||||
" --model_arch {MODEL_ARCH} \\\n",
|
||||
" --finetune {FINETUNE} \\\n",
|
||||
" --savemodel \\\n",
|
||||
" --global_pool \\\n",
|
||||
" --batch_size 128 \\\n",
|
||||
" --nb_classes {NUM_CLASS} \\\n",
|
||||
" --data_path {DATA_PATH} \\\n",
|
||||
" --input_size {INPUT_SIZE} \\\n",
|
||||
" --task {TASK} \\\n",
|
||||
" --adaptation {ADAPTATION} \\\n",
|
||||
" --eval \\\n",
|
||||
" --resume {CKPT}\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "02d2dce7-31c2-48e2-87ce-9223b74cf94e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"environment": {
|
||||
"kernel": "retfound",
|
||||
"name": "workbench-notebooks.m128",
|
||||
"type": "gcloud",
|
||||
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m128"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "retfound_jupyter (Local)",
|
||||
"language": "python",
|
||||
"name": "retfound"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user