{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "retfound-latent-feature-extraction",
   "metadata": {},
   "source": [
    "# RETFound latent feature extraction\n",
    "\n",
    "Encode every image in a folder with the pre-trained RETFound encoder (`RETFound_cfp.pth`) and save one latent feature vector per image to `Feature_latent.csv`.\n",
    "\n",
    "Before running: place the checkpoint in the working directory, make the `models_vit` module importable, and set `data_path` in the feature-extraction cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eae7403-f458-4f55-a557-4e045bd6f679",
   "metadata": {
    "id": "1eae7403-f458-4f55-a557-4e045bd6f679"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "import models_vit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4573e6be-935a-4106-8c06-e467552b0e3d",
   "metadata": {
    "id": "4573e6be-935a-4106-8c06-e467552b0e3d"
   },
   "outputs": [],
   "source": [
    "# ImageNet per-channel statistics (RGB order) used to normalize inputs\n",
    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "imagenet_std = np.array([0.229, 0.224, 0.225])\n",
    "\n",
    "\n",
    "def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
    "    \"\"\"Build a RETFound ViT encoder and load pre-trained weights into it.\n",
    "\n",
    "    Args:\n",
    "        chkpt_dir: path to the checkpoint file (a file, despite the name).\n",
    "        arch: name of a model constructor in `models_vit`.\n",
    "\n",
    "    Returns:\n",
    "        The model (on CPU) with the checkpoint's 'model' state dict loaded non-strictly.\n",
    "    \"\"\"\n",
    "    model = models_vit.__dict__[arch](\n",
    "        img_size=224,\n",
    "        num_classes=5,\n",
    "        drop_path_rate=0,\n",
    "        global_pool=True,\n",
    "    )\n",
    "    # torch.load unpickles the file: only load checkpoints from a trusted source\n",
    "    checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
    "    # strict=False tolerates key mismatches (presumably e.g. the classification head);\n",
    "    # report them instead of discarding them silently\n",
    "    msg = model.load_state_dict(checkpoint['model'], strict=False)\n",
    "    print('Missing keys:', msg.missing_keys)\n",
    "    print('Unexpected keys:', msg.unexpected_keys)\n",
    "    return model\n",
    "\n",
    "\n",
    "def load_image(path, size=224):\n",
    "    \"\"\"Load an image file as a normalized float array of shape (size, size, 3).\n",
    "\n",
    "    The image is converted to RGB (so grayscale/RGBA/palette files also work),\n",
    "    resized, scaled to [0, 1] and normalized with the ImageNet mean and std.\n",
    "    \"\"\"\n",
    "    # the context manager closes the underlying file handle\n",
    "    with Image.open(path) as raw:\n",
    "        rgb = raw.convert('RGB').resize((size, size))\n",
    "    img = np.asarray(rgb) / 255.\n",
    "    return (img - imagenet_mean) / imagenet_std\n",
    "\n",
    "\n",
    "def run_one_image(img, model, device=None):\n",
    "    \"\"\"Encode one normalized (H, W, C) image with `model.forward_features`.\n",
    "\n",
    "    Args:\n",
    "        img: normalized image array of shape (H, W, C).\n",
    "        model: encoder exposing `forward_features`.\n",
    "        device: device to run on; defaults to the device of the model's parameters.\n",
    "\n",
    "    Returns:\n",
    "        The squeezed output tensor of `forward_features`, left on `device`.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = next(model.parameters()).device\n",
    "    x = torch.as_tensor(img).float()     # (H, W, C)\n",
    "    x = x.permute(2, 0, 1).unsqueeze(0)  # -> (1, C, H, W)\n",
    "    x = x.to(device, non_blocking=True)\n",
    "    # inference only: do not build an autograd graph\n",
    "    with torch.no_grad():\n",
    "        latent = model.forward_features(x)\n",
    "    return torch.squeeze(latent)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b7e691d-93d2-439f-91d6-c22716a897b5",
   "metadata": {
    "id": "8b7e691d-93d2-439f-91d6-c22716a897b5"
   },
   "source": [
    "### Load a pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
   "metadata": {
    "id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
    "outputId": "a1f0dba1-2cae-484b-ad84-8b00bc7628aa"
   },
   "outputs": [],
   "source": [
    "# load the pre-trained RETFound checkpoint (download RETFound_cfp.pth beforehand)\n",
    "chkpt_dir = './RETFound_cfp.pth'\n",
    "model_ = prepare_model(chkpt_dir, 'vit_large_patch16')\n",
    "\n",
    "# use the GPU when available, otherwise fall back to CPU\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model_.to(device)\n",
    "print('Model loaded.')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6",
   "metadata": {
    "id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6"
   },
   "source": [
    "### Load images and extract latent features\n",
    "\n",
    "Images are normalized with the ImageNet mean and std; you can substitute statistics computed on your own data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27755296-05cc-4344-90de-a8ab3878f485",
   "metadata": {
    "id": "27755296-05cc-4344-90de-a8ab3878f485",
    "outputId": "34c3c12a-0a17-44fe-b72a-cef6eecabc70",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# folder containing the images to encode -- set this before running\n",
    "data_path = 'Your data path'\n",
    "# read only image files (skips e.g. .DS_Store); sort for a stable row order\n",
    "image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')\n",
    "img_list = sorted(f for f in os.listdir(data_path) if f.lower().endswith(image_exts))\n",
    "\n",
    "name_list = []\n",
    "feature_list = []\n",
    "model_.eval()\n",
    "\n",
    "for name in img_list:\n",
    "    img = load_image(os.path.join(data_path, name))\n",
    "    assert img.shape == (224, 224, 3), f'{name}: unexpected shape {img.shape}'\n",
    "\n",
    "    latent_feature = run_one_image(img, model_)\n",
    "\n",
    "    name_list.append(name)\n",
    "    feature_list.append(latent_feature.cpu().numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "save-latent-features",
   "metadata": {},
   "source": [
    "### Save latent features\n",
    "\n",
    "Each row holds the image file name and its full latent vector as space-separated values; parse a row back with `np.array(s.split(), dtype=np.float32)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a365ec24-8e29-485e-83b5-5ac1d02945bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# numpy's str() abbreviates arrays longer than its print threshold (1000 by default)\n",
    "# as '[a b c ... x y z]', so storing the arrays directly would write truncated\n",
    "# features to the CSV. Serialize every value explicitly instead.\n",
    "latent_csv = pd.DataFrame({\n",
    "    'Name': name_list,\n",
    "    'Latent_feature': [' '.join(map(str, f.ravel().tolist())) for f in feature_list],\n",
    "})\n",
    "latent_csv.to_csv('Feature_latent.csv', index=False, encoding='utf8')"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "environment": {
   "kernel": "python3",
   "name": "common-cu110.m91",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}