197 lines
5.6 KiB
Plaintext
197 lines
5.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0ae19951",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"from PIL import Image\n",
|
|
"import models_vit as models\n",
|
|
"from huggingface_hub import hf_hub_download\n",
|
"# Fix the NumPy seed so runs are repeatable.\n",
"np.random.seed(1)\n",
"# Print arrays in full instead of abbreviating them with '...'.\n",
"np.set_printoptions(threshold=np.inf)\n",
"# Fix the PyTorch seed as well (kept last so the cell output is unchanged).\n",
"torch.manual_seed(1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "90c3d964",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def prepare_model(chkpt_dir, arch='RETFound_mae'):\n",
"    \"\"\"Build the requested backbone and load checkpoint weights into it.\n",
"\n",
"    Args:\n",
"        chkpt_dir: Path handed to ``torch.load`` (a checkpoint file, despite\n",
"            the parameter name).\n",
"        arch: Name of a constructor looked up in ``models_vit``.\n",
"            ``'RETFound_mae'`` reads the ``'model'`` entry of the checkpoint;\n",
"            any other value reads the ``'teacher'`` entry.\n",
"\n",
"    Returns:\n",
"        The constructed model after ``load_state_dict`` has been applied.\n",
"\n",
"    Raises:\n",
"        KeyError: If ``arch`` is not defined in ``models_vit`` or the expected\n",
"            checkpoint entry is missing.\n",
"    \"\"\"\n",
"    # SECURITY: weights_only=False lets the pickle payload execute arbitrary\n",
"    # code while loading, so only use checkpoints from a source you trust.\n",
"    checkpoint = torch.load(chkpt_dir, map_location='cpu', weights_only=False)\n",
"\n",
"    # build model\n",
"    if arch == 'RETFound_mae':\n",
"        model = models.__dict__[arch](\n",
"            img_size=224,\n",
"            num_classes=5,\n",
"            drop_path_rate=0,\n",
"            global_pool=True,\n",
"        )\n",
"        state_dict = checkpoint['model']\n",
"    else:\n",
"        model = models.__dict__[arch](\n",
"            num_classes=5,\n",
"            drop_path_rate=0,\n",
"            args=None,\n",
"        )\n",
"        state_dict = checkpoint['teacher']\n",
"\n",
"    # strict=False tolerates missing/unexpected keys, so show the result:\n",
"    # otherwise a checkpoint/arch mismatch would pass silently and leave\n",
"    # layers at their initial values.\n",
"    msg = model.load_state_dict(state_dict, strict=False)\n",
"    print(msg)\n",
"    return model\n",
|
|
"\n",
|
"def run_one_image(img, model, arch):\n",
"    \"\"\"Run one preprocessed image through the model and return its latent.\n",
"\n",
"    Args:\n",
"        img: One image laid out as (height, width, channels); anything\n",
"            ``torch.as_tensor`` accepts.\n",
"        model: Network exposing ``forward_features``.\n",
"        arch: Architecture name; ``'dinov2_large'`` gets the extra pooling\n",
"            and normalisation step below.\n",
"\n",
"    Returns:\n",
"        The latent tensor with every size-1 dimension squeezed out.\n",
"    \"\"\"\n",
"    # Take the device from the model itself instead of a notebook-level\n",
"    # ``device`` global, so the result does not depend on which cells ran.\n",
"    target_device = next(model.parameters()).device\n",
"\n",
"    # Cast to float32 before the transfer (the original cast after it).\n",
"    x = torch.as_tensor(img, dtype=torch.float32)\n",
"    x = x.unsqueeze(dim=0)\n",
"    x = torch.einsum('nhwc->nchw', x)\n",
"\n",
"    x = x.to(target_device, non_blocking=True)\n",
"    latent = model.forward_features(x)\n",
"\n",
"    if arch == 'dinov2_large':\n",
"        # Skip index 0 along dim 1 (presumably the class token - confirm)\n",
"        # and average the remaining entries.\n",
"        latent = latent[:, 1:, :].mean(dim=1, keepdim=True)\n",
"        # Same result as a freshly constructed nn.LayerNorm (weight 1, bias 0)\n",
"        # without allocating a new module for every image.\n",
"        latent = nn.functional.layer_norm(latent, latent.shape[-1:], eps=1e-6)\n",
"\n",
"    latent = torch.squeeze(latent)\n",
"\n",
"    return latent\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9a250363",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"def _load_normalized_image(path, size=224):\n",
"    \"\"\"Read one image as a (size, size, 3) float array, standardised per channel.\"\"\"\n",
"    # The context manager closes the file handle; convert('RGB') makes\n",
"    # grayscale, palette and RGBA files yield the three channels used below.\n",
"    with Image.open(path) as handle:\n",
"        img = handle.convert('RGB').resize((size, size))\n",
"    img = np.array(img) / 255.\n",
"    for channel in range(3):\n",
"        values = img[..., channel]\n",
"        std = values.std()\n",
"        # A constant channel has std == 0; divide by 1 instead of emitting NaN/inf.\n",
"        img[..., channel] = (values - values.mean()) / (std if std > 0 else 1.0)\n",
"    assert img.shape == (size, size, 3)\n",
"    return img\n",
"\n",
"\n",
"def get_feature(data_path,\n",
"                chkpt_dir,\n",
"                device,\n",
"                arch='RETFound_mae'):\n",
"    \"\"\"Extract one feature array per entry of ``data_path``.\n",
"\n",
"    Args:\n",
"        data_path: Directory whose entries are each opened as an image.\n",
"        chkpt_dir: Checkpoint path forwarded to ``prepare_model``.\n",
"        device: Device the model is moved to.\n",
"        arch: Architecture name forwarded to ``prepare_model`` and\n",
"            ``run_one_image``.\n",
"\n",
"    Returns:\n",
"        ``[name_list, feature_list]``: the file names (sorted) and the\n",
"        matching NumPy feature arrays, in the same order.\n",
"    \"\"\"\n",
"    #loading model\n",
"    model_ = prepare_model(chkpt_dir, arch)\n",
"    model_.to(device)\n",
"    model_.eval()\n",
"\n",
"    # os.listdir returns entries in arbitrary order; sort so the output rows\n",
"    # come out in the same order on every run.\n",
"    img_list = sorted(os.listdir(data_path))\n",
"\n",
"    name_list = []\n",
"    feature_list = []\n",
"\n",
"    # Inference only: do not record autograd graphs for the forward passes.\n",
"    with torch.no_grad():\n",
"        for finished_num, file_name in enumerate(img_list, start=1):\n",
"            if finished_num % 1000 == 0:\n",
"                print(str(finished_num)+\"finished\")\n",
"\n",
"            img = _load_normalized_image(os.path.join(data_path, file_name))\n",
"            latent_feature = run_one_image(img, model_, arch)\n",
"\n",
"            name_list.append(file_name)\n",
"            feature_list.append(latent_feature.detach().cpu().numpy())\n",
"\n",
"    return [name_list, feature_list]\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "54acfcd7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Input folder, backbone name and compute device used by the cells below.\n",
"data_path = 'DATA_PATH'\n",
"arch = 'dinov2_large'\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"# Fetch the RETFound_dinov2_meh checkpoint from the Hugging Face Hub.\n",
"chkpt_dir = hf_hub_download(\n",
"    repo_id=\"YukunZhou/RETFound_dinov2_meh\",\n",
"    filename=\"RETFound_dinov2_meh.pth\",\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0296f74e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Run the extraction over every entry of data_path.\n",
"name_list, feature = get_feature(data_path, chkpt_dir, device, arch=arch)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
"execution_count": null,
|
|
"id": "925d3994",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"#save the feature\n",
"df_feature = pd.DataFrame(feature)\n",
"\n",
"# Name the columns from the actual feature width instead of assuming 1024,\n",
"# so a backbone with another embedding size (or an empty run) still saves.\n",
"column_name_list = [\"feature_{}\".format(i) for i in range(df_feature.shape[1])]\n",
"df_feature.columns = column_name_list\n",
"\n",
"df_imgname = pd.DataFrame({\"name\": name_list})\n",
"df_visualization = pd.concat([df_imgname, df_feature], axis=1)\n",
"df_visualization.to_csv(\"Feature.csv\",index=False)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "7f0d13a7-2b46-40eb-ab48-5f90a6aeecb5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"environment": {
|
|
"kernel": "test",
|
|
"name": "common-cu121.m123",
|
|
"type": "gcloud",
|
|
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/base-cu121:m123"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python_test (Local)",
|
|
"language": "python",
|
|
"name": "test"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|