Correctly use arch to reference architecture in models_vit. Set weights_only to False for use with PyTorch>=2.6. Automatically select cpu if cuda is not available.
This commit is contained in:
@@ -22,18 +22,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": null,
|
||||
"id": "90c3d964",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
|
||||
"def prepare_model(chkpt_dir, arch='RETFound_mae'):\n",
|
||||
" \n",
|
||||
" # load model\n",
|
||||
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
|
||||
" checkpoint = torch.load(chkpt_dir, map_location='cpu', weights_only=False)\n",
|
||||
" \n",
|
||||
" # build model\n",
|
||||
" if arch=='vit_large_patch16':\n",
|
||||
" if arch=='RETFound_mae':\n",
|
||||
" model = models.__dict__[arch](\n",
|
||||
" img_size=224,\n",
|
||||
" num_classes=5,\n",
|
||||
@@ -70,7 +70,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": null,
|
||||
"id": "9a250363",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -78,7 +78,7 @@
|
||||
"def get_feature(data_path,\n",
|
||||
" chkpt_dir,\n",
|
||||
" device,\n",
|
||||
" arch='vit_large_patch16'):\n",
|
||||
" arch='RETFound_mae'):\n",
|
||||
" #loading model\n",
|
||||
" model_ = prepare_model(chkpt_dir, arch)\n",
|
||||
" model_.to(device)\n",
|
||||
@@ -121,7 +121,7 @@
|
||||
"source": [
|
||||
"chkpt_dir = hf_hub_download(repo_id=\"YukunZhou/RETFound_dinov2_meh\", filename=\"RETFound_dinov2_meh.pth\")\n",
|
||||
"data_path = 'DATA_PATH'\n",
|
||||
"device = torch.device('cuda')\n",
|
||||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
"arch='dinov2_large'"
|
||||
]
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user