add latent feature notebook
This commit is contained in:
@@ -0,0 +1,210 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "1eae7403-f458-4f55-a557-4e045bd6f679",
|
||||
"metadata": {
|
||||
"id": "1eae7403-f458-4f55-a557-4e045bd6f679"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from PIL import Image\n",
|
||||
"import models_vit"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "4573e6be-935a-4106-8c06-e467552b0e3d",
|
||||
"metadata": {
|
||||
"id": "4573e6be-935a-4106-8c06-e467552b0e3d"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
|
||||
"imagenet_std = np.array([0.229, 0.224, 0.225])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
|
||||
" # build model\n",
|
||||
" model = models_vit.__dict__[arch](\n",
|
||||
" img_size=224,\n",
|
||||
" num_classes=5,\n",
|
||||
" drop_path_rate=0,\n",
|
||||
" global_pool=True,\n",
|
||||
" )\n",
|
||||
" # load model\n",
|
||||
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
|
||||
" msg = model.load_state_dict(checkpoint['model'], strict=False)\n",
|
||||
" return model\n",
|
||||
"\n",
|
||||
"def run_one_image(img, model):\n",
|
||||
" \n",
|
||||
" x = torch.tensor(img)\n",
|
||||
" x = x.unsqueeze(dim=0)\n",
|
||||
" x = torch.einsum('nhwc->nchw', x)\n",
|
||||
" \n",
|
||||
" x = x.to(device, non_blocking=True)\n",
|
||||
" latent = model.forward_features(x.float())\n",
|
||||
" latent = torch.squeeze(latent)\n",
|
||||
" \n",
|
||||
" return latent\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8b7e691d-93d2-439f-91d6-c22716a897b5",
|
||||
"metadata": {
|
||||
"id": "8b7e691d-93d2-439f-91d6-c22716a897b5"
|
||||
},
|
||||
"source": [
|
||||
"### Load a pre-trained model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
|
||||
"metadata": {
|
||||
"id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
|
||||
"outputId": "a1f0dba1-2cae-484b-ad84-8b00bc7628aa"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Model loaded.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# download pre-trained RETFound \n",
|
||||
"\n",
|
||||
"chkpt_dir = './RETFound_cfp.pth'\n",
|
||||
"model_ = prepare_model(chkpt_dir, 'vit_large_patch16')\n",
|
||||
"\n",
|
||||
"device = torch.device('cuda')\n",
|
||||
"model_.to(device)\n",
|
||||
"print('Model loaded.')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6",
|
||||
"metadata": {
|
||||
"id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6"
|
||||
},
|
||||
"source": [
|
||||
"### Load images and save latent feature"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "27755296-05cc-4344-90de-a8ab3878f485",
|
||||
"metadata": {
|
||||
"id": "27755296-05cc-4344-90de-a8ab3878f485",
|
||||
"outputId": "34c3c12a-0a17-44fe-b72a-cef6eecabc70",
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "FileNotFoundError",
|
||||
"evalue": "[Errno 2] No such file or directory: 'Your data path'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[0;32m/tmp/ipykernel_16866/3238108902.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# get image list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdata_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'Your data path'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mimg_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mname_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||||
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'Your data path'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# get image list\n",
|
||||
"data_path = 'Your data path'\n",
|
||||
"img_list = os.listdir(data_path)\n",
|
||||
"\n",
|
||||
"name_list = []\n",
|
||||
"feature_list = []\n",
|
||||
"model_.eval()\n",
|
||||
"\n",
|
||||
"for i in img_list:\n",
|
||||
" img = Image.open(os.path.join(data_path, i))\n",
|
||||
" img = img.resize((224, 224))\n",
|
||||
" img = np.array(img) / 255.\n",
|
||||
"\n",
|
||||
" assert img.shape == (224, 224, 3)\n",
|
||||
"\n",
|
||||
" # normalize by mean and sd\n",
|
||||
" # can use customised mean and sd for your data\n",
|
||||
" img = img - imagenet_mean\n",
|
||||
" img = img / imagenet_std\n",
|
||||
" \n",
|
||||
" latent_feature = run_one_image(img, model_)\n",
|
||||
" \n",
|
||||
" name_list.append(i)\n",
|
||||
" feature_list.append(latent_feature.detach().cpu().numpy())\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a365ec24-8e29-485e-83b5-5ac1d02945bb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"latent_csv = pd.DataFrame({'Name':name_list, 'Latent_feature':feature_list})\n",
|
||||
"latent_csv.to_csv('Feature_latent.csv', index = False, encoding='utf8')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2e8bd5e6-5780-420d-9d4c-96025b265668",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"environment": {
|
||||
"kernel": "python3",
|
||||
"name": "common-cu110.m91",
|
||||
"type": "gcloud",
|
||||
"uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user