From 547ac47153a4d6188e37f33d3955d5c829ebf1bc Mon Sep 17 00:00:00 2001 From: rmaphoh Date: Sat, 27 Jan 2024 19:33:15 +0000 Subject: [PATCH] add latent feature notebook --- RETFound_Feature.ipynb | 210 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 RETFound_Feature.ipynb diff --git a/RETFound_Feature.ipynb b/RETFound_Feature.ipynb new file mode 100644 index 0000000..52f6337 --- /dev/null +++ b/RETFound_Feature.ipynb @@ -0,0 +1,210 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "1eae7403-f458-4f55-a557-4e045bd6f679", + "metadata": { + "id": "1eae7403-f458-4f55-a557-4e045bd6f679" + }, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "import numpy as np\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "from PIL import Image\n", + "import models_vit" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4573e6be-935a-4106-8c06-e467552b0e3d", + "metadata": { + "id": "4573e6be-935a-4106-8c06-e467552b0e3d" + }, + "outputs": [], + "source": [ + "\n", + "imagenet_mean = np.array([0.485, 0.456, 0.406])\n", + "imagenet_std = np.array([0.229, 0.224, 0.225])\n", + "\n", + "\n", + "def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n", + " # build model\n", + " model = models_vit.__dict__[arch](\n", + " img_size=224,\n", + " num_classes=5,\n", + " drop_path_rate=0,\n", + " global_pool=True,\n", + " )\n", + " # load model\n", + " checkpoint = torch.load(chkpt_dir, map_location='cpu')\n", + " msg = model.load_state_dict(checkpoint['model'], strict=False)\n", + " return model\n", + "\n", + "def run_one_image(img, model):\n", + " \n", + " x = torch.tensor(img)\n", + " x = x.unsqueeze(dim=0)\n", + " x = torch.einsum('nhwc->nchw', x)\n", + " \n", + " x = x.to(device, non_blocking=True)\n", + " latent = model.forward_features(x.float())\n", + " latent = torch.squeeze(latent)\n", + " \n", + " return latent\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "8b7e691d-93d2-439f-91d6-c22716a897b5", + "metadata": { + "id": "8b7e691d-93d2-439f-91d6-c22716a897b5" + }, + "source": [ + "### Load a pre-trained model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d", + "metadata": { + "id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d", + "outputId": "a1f0dba1-2cae-484b-ad84-8b00bc7628aa" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model loaded.\n" + ] + } + ], + "source": [ + "# download pre-trained RETFound \n", + "\n", + "chkpt_dir = './RETFound_cfp.pth'\n", + "model_ = prepare_model(chkpt_dir, 'vit_large_patch16')\n", + "\n", + "device = torch.device('cuda')\n", + "model_.to(device)\n", + "print('Model loaded.')\n" + ] + }, + { + "cell_type": "markdown", + "id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6", + "metadata": { + "id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6" + }, + "source": [ + "### Load images and save latent feature" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "27755296-05cc-4344-90de-a8ab3878f485", + "metadata": { + "id": "27755296-05cc-4344-90de-a8ab3878f485", + "outputId": "34c3c12a-0a17-44fe-b72a-cef6eecabc70", + "tags": [] + }, + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: 'Your data path'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_16866/3238108902.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# get image list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdata_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'Your data path'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mimg_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mname_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'Your data path'" + ] + } + ], + "source": [ + "# get image list\n", + "data_path = 'Your data path'\n", + "img_list = os.listdir(data_path)\n", + "\n", + "name_list = []\n", + "feature_list = []\n", + "model_.eval()\n", + "\n", + "for i in img_list:\n", + " img = Image.open(os.path.join(data_path, i))\n", + " img = img.resize((224, 224))\n", + " img = np.array(img) / 255.\n", + "\n", + " assert img.shape == (224, 224, 3)\n", + "\n", + " # normalize by mean and sd\n", + " # can use customised mean and sd for your data\n", + " img = img - imagenet_mean\n", + " img = img / imagenet_std\n", + " \n", + " latent_feature = run_one_image(img, model_)\n", + " \n", + " name_list.append(i)\n", + " feature_list.append(latent_feature.detach().cpu().numpy())\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a365ec24-8e29-485e-83b5-5ac1d02945bb", + "metadata": {}, + "outputs": [], + "source": [ + "latent_csv = pd.DataFrame({'Name':name_list, 'Latent_feature':feature_list})\n", + "latent_csv.to_csv('Feature_latent.csv', index = False, encoding='utf8')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e8bd5e6-5780-420d-9d4c-96025b265668", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "environment": { + "kernel": "python3", + "name": "common-cu110.m91", + "type": "gcloud", + "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}