Correctly use arch to reference architecture in models_vit. Set weights_only to False for use with PyTorch>=2.6. Automatically select cpu if cuda is not available.

This commit is contained in:
bartsserver
2025-04-17 12:09:07 +02:00
parent 91915d6a14
commit a7a9b3a8b7
+7 -7
View File
@@ -22,18 +22,18 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": null,
"id": "90c3d964",
"metadata": {},
"outputs": [],
"source": [
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
"def prepare_model(chkpt_dir, arch='RETFound_mae'):\n",
" \n",
" # load model\n",
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
" checkpoint = torch.load(chkpt_dir, map_location='cpu', weights_only=False)\n",
" \n",
" # build model\n",
" if arch=='vit_large_patch16':\n",
" if arch=='RETFound_mae':\n",
" model = models.__dict__[arch](\n",
" img_size=224,\n",
" num_classes=5,\n",
@@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": null,
"id": "9a250363",
"metadata": {},
"outputs": [],
@@ -78,7 +78,7 @@
"def get_feature(data_path,\n",
" chkpt_dir,\n",
" device,\n",
" arch='vit_large_patch16'):\n",
" arch='RETFound_mae'):\n",
" #loading model\n",
" model_ = prepare_model(chkpt_dir, arch)\n",
" model_.to(device)\n",
@@ -121,7 +121,7 @@
"source": [
"chkpt_dir = hf_hub_download(repo_id=\"YukunZhou/RETFound_dinov2_meh\", filename=\"RETFound_dinov2_meh.pth\")\n",
"data_path = 'DATA_PATH'\n",
"device = torch.device('cuda')\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"arch='dinov2_large'"
]
},