From a7a9b3a8b71c7db74b4b71e3fd4e874f90af0a6a Mon Sep 17 00:00:00 2001 From: bartsserver Date: Thu, 17 Apr 2025 12:09:07 +0200 Subject: [PATCH] Correctly use arch to reference architecture in models_vit. Set weights_only to False for use with PyTorch>=2.6. Automatically select cpu if cuda is not available. --- latent_feature.ipynb | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/latent_feature.ipynb b/latent_feature.ipynb index 7674dd8..2ed1f9f 100644 --- a/latent_feature.ipynb +++ b/latent_feature.ipynb @@ -22,18 +22,18 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "90c3d964", "metadata": {}, "outputs": [], "source": [ - "def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n", + "def prepare_model(chkpt_dir, arch='RETFound_mae'):\n", " \n", " # load model\n", - " checkpoint = torch.load(chkpt_dir, map_location='cpu')\n", + " checkpoint = torch.load(chkpt_dir, map_location='cpu', weights_only=False)\n", " \n", " # build model\n", - " if arch=='vit_large_patch16':\n", + " if arch=='RETFound_mae':\n", " model = models.__dict__[arch](\n", " img_size=224,\n", " num_classes=5,\n", @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "9a250363", "metadata": {}, "outputs": [], @@ -78,7 +78,7 @@ "def get_feature(data_path,\n", " chkpt_dir,\n", " device,\n", - " arch='vit_large_patch16'):\n", + " arch='RETFound_mae'):\n", " #loading model\n", " model_ = prepare_model(chkpt_dir, arch)\n", " model_.to(device)\n", @@ -121,7 +121,7 @@ "source": [ "chkpt_dir = hf_hub_download(repo_id=\"YukunZhou/RETFound_dinov2_meh\", filename=\"RETFound_dinov2_meh.pth\")\n", "data_path = 'DATA_PATH'\n", - "device = torch.device('cuda')\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "arch='dinov2_large'" ] },