Files
RETFound/RETFound_Feature.ipynb
T
2024-01-27 19:33:15 +00:00

211 lines
6.3 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1eae7403-f458-4f55-a557-4e045bd6f679",
"metadata": {
"id": "1eae7403-f458-4f55-a557-4e045bd6f679"
},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import models_vit"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4573e6be-935a-4106-8c06-e467552b0e3d",
"metadata": {
"id": "4573e6be-935a-4106-8c06-e467552b0e3d"
},
"outputs": [],
"source": [
"\n",
"# Per-channel ImageNet statistics used to normalise every image before it is\n",
"# fed to the encoder; they broadcast over the trailing channel axis.\n",
"imagenet_mean = np.asarray((0.485, 0.456, 0.406))\n",
"imagenet_std = np.asarray((0.229, 0.224, 0.225))\n",
"\n",
"\n",
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
"    \"\"\"Build a RETFound ViT and load encoder weights from a checkpoint.\n",
"\n",
"    Args:\n",
"        chkpt_dir: path to the checkpoint file (a dict holding the weights\n",
"            under 'model', or a bare state dict).\n",
"        arch: name of a model constructor exported by ``models_vit``.\n",
"\n",
"    Returns:\n",
"        The model on CPU. Head weights are optional in the checkpoint; only\n",
"        ``forward_features`` is used for feature extraction in this notebook.\n",
"\n",
"    Raises:\n",
"        RuntimeError: if the checkpoint lacks non-head weights, which would\n",
"            otherwise silently leave a randomly initialised feature extractor.\n",
"    \"\"\"\n",
"    model = models_vit.__dict__[arch](\n",
"        img_size=224,\n",
"        num_classes=5,  # head is not used for feature extraction\n",
"        drop_path_rate=0,\n",
"        global_pool=True,\n",
"    )\n",
"\n",
"    # NOTE: torch.load unpickles arbitrary objects - only load trusted checkpoints.\n",
"    checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
"    state_dict = checkpoint.get('model', checkpoint)\n",
"\n",
"    # A head trained for a different number of classes makes load_state_dict\n",
"    # raise a size-mismatch error even with strict=False; the head is not\n",
"    # needed for features, so drop head weights whose shapes do not match.\n",
"    model_state = model.state_dict()\n",
"    state_dict = {\n",
"        k: v for k, v in state_dict.items()\n",
"        if not (k.startswith('head.') and k in model_state\n",
"                and v.shape != model_state[k].shape)\n",
"    }\n",
"\n",
"    msg = model.load_state_dict(state_dict, strict=False)\n",
"    # head.* and fc_norm.* may legitimately be absent from a pre-trained\n",
"    # checkpoint; anything else missing means the weights did not match.\n",
"    missing = [k for k in msg.missing_keys\n",
"               if not k.startswith(('head.', 'fc_norm.'))]\n",
"    if missing:\n",
"        raise RuntimeError(\n",
"            f'Checkpoint {chkpt_dir} is missing {len(missing)} encoder '\n",
"            f'weights, e.g. {missing[:5]}')\n",
"    return model\n",
"\n",
"def run_one_image(img, model):\n",
"    \"\"\"Return the encoder's latent feature vector for a single image.\n",
"\n",
"    Args:\n",
"        img: normalised image array laid out (H, W, C).\n",
"        model: a model exposing ``forward_features``.\n",
"\n",
"    Returns:\n",
"        The squeezed output of ``model.forward_features`` for a batch of one.\n",
"    \"\"\"\n",
"    # Use whatever device the model lives on instead of reading the global\n",
"    # `device`, which is only defined in a later cell.\n",
"    model_device = next(model.parameters()).device\n",
"\n",
"    x = torch.tensor(img).unsqueeze(dim=0)   # (1, H, W, C)\n",
"    x = torch.einsum('nhwc->nchw', x)        # (1, C, H, W)\n",
"    # Cast before the transfer so float64 input is not copied at double size.\n",
"    x = x.float().to(model_device, non_blocking=True)\n",
"\n",
"    # Inference only: skip building an autograd graph.\n",
"    with torch.no_grad():\n",
"        latent = model.forward_features(x)\n",
"    return torch.squeeze(latent)\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "8b7e691d-93d2-439f-91d6-c22716a897b5",
"metadata": {
"id": "8b7e691d-93d2-439f-91d6-c22716a897b5"
},
"source": [
"### Load a pre-trained model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
"metadata": {
"id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
"outputId": "a1f0dba1-2cae-484b-ad84-8b00bc7628aa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model loaded.\n"
]
}
],
"source": [
"# Load the pre-trained RETFound weights from a local checkpoint file\n",
"# (download RETFound_cfp.pth beforehand into the current working directory).\n",
"chkpt_dir = './RETFound_cfp.pth'\n",
"model_ = prepare_model(chkpt_dir, 'vit_large_patch16')\n",
"\n",
"# Fall back to CPU so the notebook also runs on machines without a GPU.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"model_.to(device)\n",
"print('Model loaded.')\n"
]
},
{
"cell_type": "markdown",
"id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6",
"metadata": {
"id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6"
},
"source": [
"### Load images, extract latent features and save them"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "27755296-05cc-4344-90de-a8ab3878f485",
"metadata": {
"id": "27755296-05cc-4344-90de-a8ab3878f485",
"outputId": "34c3c12a-0a17-44fe-b72a-cef6eecabc70",
"tags": []
},
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'Your data path'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_16866/3238108902.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# get image list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdata_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'Your data path'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mimg_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mname_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'Your data path'"
]
}
],
"source": [
"# get image list\n",
"data_path = 'Your data path'\n",
"# Sort for a reproducible row order and skip sub-directories and hidden\n",
"# files (e.g. .DS_Store) that PIL cannot open.\n",
"img_list = sorted(\n",
"    name for name in os.listdir(data_path)\n",
"    if not name.startswith('.') and os.path.isfile(os.path.join(data_path, name))\n",
")\n",
"\n",
"name_list = []\n",
"feature_list = []\n",
"model_.eval()\n",
"\n",
"with torch.no_grad():  # inference only - no autograd graph needed\n",
"    for i in img_list:\n",
"        # convert('RGB') maps grayscale / RGBA / palette images to 3 channels;\n",
"        # the context manager closes the file handle promptly.\n",
"        with Image.open(os.path.join(data_path, i)) as im:\n",
"            img = im.convert('RGB').resize((224, 224))\n",
"        img = np.array(img) / 255.\n",
"\n",
"        assert img.shape == (224, 224, 3)\n",
"\n",
"        # normalize by mean and sd\n",
"        # can use customised mean and sd for your data\n",
"        img = (img - imagenet_mean) / imagenet_std\n",
"\n",
"        latent_feature = run_one_image(img, model_)\n",
"\n",
"        name_list.append(i)\n",
"        feature_list.append(latent_feature.detach().cpu().numpy())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a365ec24-8e29-485e-83b5-5ac1d02945bb",
"metadata": {},
"outputs": [],
"source": [
"# Store each feature vector as a plain Python list: str() of a numpy array\n",
"# longer than numpy's print threshold (1000 elements by default) is\n",
"# summarised with '...', which would silently truncate features in the CSV.\n",
"latent_csv = pd.DataFrame({\n",
"    'Name': name_list,\n",
"    'Latent_feature': [feature.tolist() for feature in feature_list],\n",
"})\n",
"latent_csv.to_csv('Feature_latent.csv', index=False, encoding='utf8')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e8bd5e6-5780-420d-9d4c-96025b265668",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"environment": {
"kernel": "python3",
"name": "common-cu110.m91",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}