diff --git a/latent_feature.ipynb b/latent_feature.ipynb index 7674dd8..2ed1f9f 100644 --- a/latent_feature.ipynb +++ b/latent_feature.ipynb @@ -22,18 +22,18 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "id": "90c3d964", "metadata": {}, "outputs": [], "source": [ - "def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n", + "def prepare_model(chkpt_dir, arch='RETFound_mae'):\n", " \n", " # load model\n", - " checkpoint = torch.load(chkpt_dir, map_location='cpu')\n", + " checkpoint = torch.load(chkpt_dir, map_location='cpu', weights_only=False)\n", " \n", " # build model\n", - " if arch=='vit_large_patch16':\n", + " if arch=='RETFound_mae':\n", " model = models.__dict__[arch](\n", " img_size=224,\n", " num_classes=5,\n", @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "id": "9a250363", "metadata": {}, "outputs": [], @@ -78,7 +78,7 @@ "def get_feature(data_path,\n", " chkpt_dir,\n", " device,\n", - " arch='vit_large_patch16'):\n", + " arch='RETFound_mae'):\n", " #loading model\n", " model_ = prepare_model(chkpt_dir, arch)\n", " model_.to(device)\n", @@ -121,7 +121,7 @@ "source": [ "chkpt_dir = hf_hub_download(repo_id=\"YukunZhou/RETFound_dinov2_meh\", filename=\"RETFound_dinov2_meh.pth\")\n", "data_path = 'DATA_PATH'\n", - "device = torch.device('cuda')\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "arch='dinov2_large'" ] },