Merge pull request #43 from BartvanderWoude/main

Fixes for latent_feature.ipynb
This commit is contained in:
Yukun Zhou
2025-04-29 10:42:20 +01:00
committed by GitHub
+7 -7
View File
@@ -22,18 +22,18 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"id": "90c3d964",
"metadata": {},
"outputs": [],
"source": [
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
"def prepare_model(chkpt_dir, arch='RETFound_mae'):\n",
" \n",
" # load model\n",
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
" checkpoint = torch.load(chkpt_dir, map_location='cpu', weights_only=False)\n",
" \n",
" # build model\n",
" if arch=='vit_large_patch16':\n",
" if arch=='RETFound_mae':\n",
" model = models.__dict__[arch](\n",
" img_size=224,\n",
" num_classes=5,\n",
@@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"id": "9a250363",
"metadata": {},
"outputs": [],
@@ -78,7 +78,7 @@
"def get_feature(data_path,\n",
" chkpt_dir,\n",
" device,\n",
" arch='vit_large_patch16'):\n",
" arch='RETFound_mae'):\n",
" #loading model\n",
" model_ = prepare_model(chkpt_dir, arch)\n",
" model_.to(device)\n",
@@ -121,7 +121,7 @@
"source": [
"chkpt_dir = hf_hub_download(repo_id=\"YukunZhou/RETFound_dinov2_meh\", filename=\"RETFound_dinov2_meh.pth\")\n",
"data_path = 'DATA_PATH'\n",
"device = torch.device('cuda')\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"arch='dinov2_large'"
]
},