Files
RETFound/latent_feature.ipynb

197 lines
5.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0ae19951",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"import pandas as pd\n",
"from PIL import Image\n",
"import models_vit as models\n",
"from huggingface_hub import hf_hub_download\n",
"# Print arrays in full (no '...' summarisation) whenever they are displayed.\n",
"np.set_printoptions(threshold=np.inf)\n",
"\n",
"# One named seed shared by every RNG so it can be tuned in a single place.\n",
"SEED = 1\n",
"np.random.seed(SEED)\n",
"# Trailing ';' keeps the returned torch.Generator repr out of the cell output.\n",
"torch.manual_seed(SEED);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "90c3d964",
"metadata": {},
"outputs": [],
"source": [
"def prepare_model(chkpt_dir, arch='RETFound_mae'):\n",
"    \"\"\"Build the backbone selected by `arch` and load checkpoint weights into it.\n",
"\n",
"    Args:\n",
"        chkpt_dir: Checkpoint location passed straight to `torch.load`\n",
"            (in this notebook it is the file path returned by the download cell).\n",
"        arch: Name of a model factory looked up in `models_vit`.\n",
"            'RETFound_mae' is built with img_size=224 and global_pool=True and\n",
"            reads `checkpoint['model']`; any other value is built with\n",
"            `args=None` and reads `checkpoint['teacher']`.\n",
"\n",
"    Returns:\n",
"        The constructed model after a non-strict `load_state_dict`.\n",
"\n",
"    Raises:\n",
"        KeyError: if `arch` is not defined in `models_vit` or the checkpoint\n",
"            lacks the expected top-level key.\n",
"    \"\"\"\n",
"    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python\n",
"    # objects, which can execute code. Only load checkpoints from a trusted source.\n",
"    checkpoint = torch.load(chkpt_dir, map_location='cpu', weights_only=False)\n",
"\n",
"    # Build the model and pick the checkpoint entry that holds its weights.\n",
"    if arch == 'RETFound_mae':\n",
"        model = models.__dict__[arch](\n",
"            img_size=224,\n",
"            num_classes=5,\n",
"            drop_path_rate=0,\n",
"            global_pool=True,\n",
"        )\n",
"        state_dict = checkpoint['model']\n",
"    else:\n",
"        model = models.__dict__[arch](\n",
"            num_classes=5,\n",
"            drop_path_rate=0,\n",
"            args=None,\n",
"        )\n",
"        state_dict = checkpoint['teacher']\n",
"\n",
"    msg = model.load_state_dict(state_dict, strict=False)\n",
"\n",
"    # strict=False hides key mismatches, so surface them: if most backbone keys are\n",
"    # listed as missing, the weights were not actually loaded and the extracted\n",
"    # features would come from freshly initialised parameters.\n",
"    missing = list(getattr(msg, 'missing_keys', []))\n",
"    unexpected = list(getattr(msg, 'unexpected_keys', []))\n",
"    if missing or unexpected:\n",
"        print('load_state_dict: {} missing key(s), {} unexpected key(s)'.format(\n",
"            len(missing), len(unexpected)))\n",
"        print('  missing (first 5):', missing[:5])\n",
"        print('  unexpected (first 5):', unexpected[:5])\n",
"\n",
"    return model\n",
"\n",
"def run_one_image(img, model, arch):\n",
"    \"\"\"Compute the latent feature of a single preprocessed image.\n",
"\n",
"    Args:\n",
"        img: Channel-last image array (H, W, C). A batch axis is added and the\n",
"            layout is rearranged to (1, C, H, W) before the forward pass.\n",
"        model: Model exposing `forward_features`; the input is moved to the\n",
"            device of the model's first parameter.\n",
"        arch: Architecture name. For 'dinov2_large' the first entry along dim 1\n",
"            is dropped (presumably the CLS token - confirm against the model),\n",
"            the rest are averaged and layer-normalised.\n",
"\n",
"    Returns:\n",
"        The latent tensor with every size-1 dimension squeezed out; it carries\n",
"        no autograd history.\n",
"    \"\"\"\n",
"    x = torch.as_tensor(img).unsqueeze(dim=0)\n",
"    # nhwc -> nchw\n",
"    x = x.permute(0, 3, 1, 2)\n",
"\n",
"    # Follow the model's own device rather than a notebook-global `device`\n",
"    # variable, which may be undefined or out of sync with where the model lives.\n",
"    param = next(model.parameters(), None)\n",
"    if param is not None:\n",
"        x = x.to(param.device, non_blocking=True)\n",
"\n",
"    # Inference only: do not build an autograd graph for every image.\n",
"    with torch.no_grad():\n",
"        latent = model.forward_features(x.float())\n",
"\n",
"        if arch == 'dinov2_large':\n",
"            latent = latent[:, 1:, :].mean(dim=1, keepdim=True)\n",
"            # Parameter-free layer norm: same result as a freshly constructed\n",
"            # nn.LayerNorm (weight=1, bias=0) without allocating a module per call.\n",
"            latent = nn.functional.layer_norm(latent, (latent.shape[-1],), eps=1e-6)\n",
"\n",
"    return torch.squeeze(latent)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a250363",
"metadata": {},
"outputs": [],
"source": [
"def get_feature(data_path,\n",
"                chkpt_dir,\n",
"                device,\n",
"                arch='RETFound_mae'):\n",
"    \"\"\"Extract one latent feature vector per image file found in `data_path`.\n",
"\n",
"    Args:\n",
"        data_path: Directory whose files are read as images. Sub-directories\n",
"            are skipped; files are processed in sorted name order.\n",
"        chkpt_dir: Checkpoint location forwarded to `prepare_model`.\n",
"        device: Device the model is moved to before inference.\n",
"        arch: Architecture name forwarded to `prepare_model` and `run_one_image`.\n",
"\n",
"    Returns:\n",
"        `[name_list, feature_list]`: file names and the matching NumPy feature\n",
"        arrays, in the same order.\n",
"    \"\"\"\n",
"    # loading model\n",
"    model_ = prepare_model(chkpt_dir, arch)\n",
"    model_.to(device)\n",
"    model_.eval()\n",
"\n",
"    # os.listdir returns entries in arbitrary order; sort for reproducible output.\n",
"    img_list = sorted(os.listdir(data_path))\n",
"\n",
"    name_list = []\n",
"    feature_list = []\n",
"\n",
"    finished_num = 0\n",
"    # Inference only: no autograd bookkeeping is needed while extracting features.\n",
"    with torch.no_grad():\n",
"        for i in img_list:\n",
"            img_path = os.path.join(data_path, i)\n",
"            if not os.path.isfile(img_path):\n",
"                # A nested directory cannot be opened as an image.\n",
"                continue\n",
"\n",
"            finished_num += 1\n",
"            if finished_num % 1000 == 0:\n",
"                print(str(finished_num)+\"finished\")\n",
"\n",
"            # The context manager closes the file handle; convert('RGB') makes\n",
"            # grayscale / RGBA / palette images satisfy the 3-channel check below.\n",
"            with Image.open(img_path) as pil_img:\n",
"                img = np.array(pil_img.convert('RGB').resize((224, 224))) / 255.\n",
"            assert img.shape == (224, 224, 3)\n",
"\n",
"            # Per-image, per-channel standardisation.\n",
"            channel_mean = img.mean(axis=(0, 1))\n",
"            channel_std = img.std(axis=(0, 1))\n",
"            # A constant channel has std 0; dividing by it would fill the input with NaN.\n",
"            channel_std[channel_std == 0] = 1.0\n",
"            img = (img - channel_mean) / channel_std\n",
"\n",
"            latent_feature = run_one_image(img, model_, arch)\n",
"\n",
"            name_list.append(i)\n",
"            feature_list.append(latent_feature.detach().cpu().numpy())\n",
"\n",
"    return [name_list, feature_list]\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54acfcd7",
"metadata": {},
"outputs": [],
"source": [
"# Extraction settings: backbone name, compute device and image folder.\n",
"arch = 'dinov2_large'\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
"# 'DATA_PATH' looks like a placeholder - presumably to be replaced with the image directory.\n",
"data_path = 'DATA_PATH'\n",
"\n",
"# Fetch the checkpoint from the Hugging Face Hub; the returned value is what\n",
"# gets passed on as `chkpt_dir`.\n",
"chkpt_dir = hf_hub_download(\n",
"    repo_id=\"YukunZhou/RETFound_dinov2_meh\",\n",
"    filename=\"RETFound_dinov2_meh.pth\",\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0296f74e",
"metadata": {},
"outputs": [],
"source": [
"# Run the extractor over data_path; the two returned lists (file names and\n",
"# feature vectors) are unpacked into name_list and feature.\n",
"name_list, feature = get_feature(data_path, chkpt_dir, device, arch=arch)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "925d3994",
"metadata": {},
"outputs": [],
"source": [
"# Save the features: one row per image, a 'name' column followed by feature_0..feature_{d-1}.\n",
"df_feature = pd.DataFrame(feature)\n",
"\n",
"# Derive the number of feature columns from the data instead of assuming 1024,\n",
"# so other embedding sizes (or an empty result) do not raise a length mismatch.\n",
"column_name_list = [\"feature_{}\".format(i) for i in range(df_feature.shape[1])]\n",
"df_feature.columns = column_name_list\n",
"\n",
"df_imgname = pd.DataFrame({\"name\": name_list})\n",
"df_visualization = pd.concat([df_imgname, df_feature], axis=1)\n",
"df_visualization.to_csv(\"Feature.csv\", index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f0d13a7-2b46-40eb-ab48-5f90a6aeecb5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"kernel": "test",
"name": "common-cu121.m123",
"type": "gcloud",
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/base-cu121:m123"
},
"kernelspec": {
"display_name": "Python_test (Local)",
"language": "python",
"name": "test"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}