major package upgrade&new weights

This commit is contained in:
rmaphoh
2025-02-19 12:37:05 +00:00
parent f5ab012b71
commit f0425f5526
21 changed files with 1931 additions and 2333 deletions
+6
View File
@@ -0,0 +1,6 @@
[Desktop Entry]
Encoding=UTF-8
Name=Link to
Type=Link
URL=file:///home/yukun/Moorfield/git_repo/RETFound_MAE/README.md
Icon=text-markdown
+75 -68
View File
@@ -1,8 +1,9 @@
## RETFound - A foundation model for retinal imaging ## RETFound - A foundation model for retinal imaging
Official repo for [RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae): Official repo including a series of retinal foundation models.
[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae):
[New checkpoints](https://www.nature.com/articles/s41586-023-06555-x), which is based on [DINOV2](https://github.com/facebookresearch/dinov2):
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions. Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE) Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE)
@@ -17,6 +18,9 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
### 🎉News ### 🎉News
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!**
- 🐉2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!**
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/RETFound_Feature.ipynb) are now online! - 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/RETFound_Feature.ipynb) are now online!
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online! - 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation! - 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
@@ -29,16 +33,17 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
1. Create environment with conda: 1. Create environment with conda:
``` ```
conda create -n retfound python=3.7.5 -y conda create -n retfound python=3.11.0 -y
conda activate retfound conda activate retfound
``` ```
2. Install dependencies 2. Install dependencies
``` ```
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
git clone https://github.com/rmaphoh/RETFound_MAE/ git clone https://github.com/rmaphoh/RETFound_MAE/
cd RETFound_MAE cd RETFound_MAE
pip install -r requirement.txt pip install -r requirements.txt
``` ```
@@ -46,23 +51,51 @@ pip install -r requirement.txt
To fine tune RETFound on your own data, follow these steps: To fine tune RETFound on your own data, follow these steps:
1. Download the RETFound pre-trained weights 1. Get access to the pre-trained models on HuggingFace (register an account and fill in the form) and go to step 2:
<table><tbody> <table><tbody>
<!-- START TABLE --> <!-- START TABLE -->
<!-- TABLE HEADER --> <!-- TABLE HEADER -->
<th valign="bottom"></th> <th valign="bottom"></th>
<th valign="bottom">ViT-Large</th> <th valign="bottom">ViT-Large</th>
<th valign="bottom">Source</th>
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">Colour fundus image</td> <tr><td align="left">RETFound_mae_natureCFP</td>
<td align="center"><a href="https://drive.google.com/file/d/1l62zbWUFTlp214SvK6eMwPQZAzcwoeBE/view?usp=sharing">download</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_natureCFP">access</a></td>
<td align="center"><a href="https://www.nature.com/articles/s41586-023-06555-x">Nature RETFound paper</a></td>
</tr> </tr>
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">OCT</td> <tr><td align="left">RETFound_mae_natureOCT</td>
<td align="center"><a href="https://drive.google.com/file/d/1m6s7QYkjyjJDlpEuXm7Xp3PmjN-elfW2/view?usp=sharing">download</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_natureOCT">access</a></td>
<td align="center"><a href="https://www.nature.com/articles/s41586-023-06555-x">Nature RETFound paper</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_mae_meh</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
<td align="center">TBD</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_mae_shanghai</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
<td align="center">TBD</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_dinov2_meh</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
<td align="center">TBD</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_dinov2_shanghai</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">download</a></td>
<td align="center">TBD</a></td>
</tr> </tr>
</tbody></table> </tbody></table>
2. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md)) 2. Login in your HuggingFace account, where HuggingFace token can be [created and copied](https://huggingface.co/settings/tokens).
```
huggingface-cli login --token YOUR_HUGGINGFACE_TOKEN
```
3. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md))
``` ```
├── data folder ├── data folder
@@ -80,23 +113,29 @@ To fine tune RETFound on your own data, follow these steps:
├──class_c ├──class_c
``` ```
3. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be run after training. 4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training.
``` ```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \ model can be "RETFound_mae" or "RETFound_dinov2"
```
```
finetune can be "RETFound_mae_natureOCT", "RETFound_mae_natureCFP", "RETFound_mae_meh", "RETFound_mae_shanghai", "RETFound_dinov2_meh", and "RETFound_dinov2_shanghai".
```
```
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
--model RETFound_mae \
--savemodel \
--global_pool \
--batch_size 16 \ --batch_size 16 \
--world_size 1 \ --world_size 1 \
--model vit_large_patch16 \ --epochs 100 \
--epochs 50 \
--blr 5e-3 --layer_decay 0.65 \ --blr 5e-3 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.2 \ --weight_decay 0.05 --drop_path 0.2 \
--nb_classes 5 \ --nb_classes 5 \
--data_path ./IDRiD_data/ \ --data_path ./IDRiD \
--task ./finetune_IDRiD/ \ --input_size 224 \
--finetune ./RETFound_cfp_weights.pth \ --task RETFound_mae_meh-IDRiD \
--input_size 224 --finetune RETFound_mae_meh
``` ```
@@ -104,64 +143,32 @@ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_f
``` ```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \ torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
--eval --batch_size 16 \ --model RETFound_mae \
--savemodel \
--eval \
--global_pool \
--batch_size 16 \
--world_size 1 \ --world_size 1 \
--model vit_large_patch16 \ --epochs 100 \
--epochs 50 \
--blr 5e-3 --layer_decay 0.65 \ --blr 5e-3 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.2 \ --weight_decay 0.05 --drop_path 0.2 \
--nb_classes 5 \ --nb_classes 5 \
--data_path ./IDRiD_data/ \ --data_path ./IDRiD \
--task ./internal_IDRiD/ \ --input_size 224 \
--resume ./finetune_IDRiD/checkpoint-best.pth \ --task RETFound_mae_meh-IDRiD \
--input_size 224 --resume ./finetune_IDRiD/checkpoint-best.pth
```
### Load the model and weights (if you want to call the model in your code)
```python
import torch
import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_
# call the model
model = models_vit.__dict__['vit_large_patch16'](
num_classes=2,
drop_path_rate=0.2,
global_pool=True,
)
# load RETFound weights
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
# manually initialize fc layer
trunc_normal_(model.head.weight, std=2e-5)
print("Model = %s" % str(model))
``` ```
### 📃Citation ### 📃Citation
If you find this repository useful, please consider citing this paper: If you find this repository useful, please consider citing this paper:
```
TBD
```
``` ```
@article{zhou2023foundation, @article{zhou2023foundation,
title={A foundation model for generalizable disease detection from retinal images}, title={A foundation model for generalizable disease detection from retinal images},
-210
View File
@@ -1,210 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1eae7403-f458-4f55-a557-4e045bd6f679",
"metadata": {
"id": "1eae7403-f458-4f55-a557-4e045bd6f679"
},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import models_vit"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4573e6be-935a-4106-8c06-e467552b0e3d",
"metadata": {
"id": "4573e6be-935a-4106-8c06-e467552b0e3d"
},
"outputs": [],
"source": [
"\n",
"imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
"imagenet_std = np.array([0.229, 0.224, 0.225])\n",
"\n",
"\n",
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
" # build model\n",
" model = models_vit.__dict__[arch](\n",
" img_size=224,\n",
" num_classes=5,\n",
" drop_path_rate=0,\n",
" global_pool=True,\n",
" )\n",
" # load model\n",
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
" msg = model.load_state_dict(checkpoint['model'], strict=False)\n",
" return model\n",
"\n",
"def run_one_image(img, model):\n",
" \n",
" x = torch.tensor(img)\n",
" x = x.unsqueeze(dim=0)\n",
" x = torch.einsum('nhwc->nchw', x)\n",
" \n",
" x = x.to(device, non_blocking=True)\n",
" latent = model.forward_features(x.float())\n",
" latent = torch.squeeze(latent)\n",
" \n",
" return latent\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "8b7e691d-93d2-439f-91d6-c22716a897b5",
"metadata": {
"id": "8b7e691d-93d2-439f-91d6-c22716a897b5"
},
"source": [
"### Load a pre-trained model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
"metadata": {
"id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
"outputId": "a1f0dba1-2cae-484b-ad84-8b00bc7628aa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model loaded.\n"
]
}
],
"source": [
"# download pre-trained RETFound \n",
"\n",
"chkpt_dir = './RETFound_cfp.pth'\n",
"model_ = prepare_model(chkpt_dir, 'vit_large_patch16')\n",
"\n",
"device = torch.device('cuda')\n",
"model_.to(device)\n",
"print('Model loaded.')\n"
]
},
{
"cell_type": "markdown",
"id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6",
"metadata": {
"id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6"
},
"source": [
"### Load images and save latent feature"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "27755296-05cc-4344-90de-a8ab3878f485",
"metadata": {
"id": "27755296-05cc-4344-90de-a8ab3878f485",
"outputId": "34c3c12a-0a17-44fe-b72a-cef6eecabc70",
"tags": []
},
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'Your data path'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_16866/3238108902.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# get image list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdata_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'Your data path'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mimg_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mname_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'Your data path'"
]
}
],
"source": [
"# get image list\n",
"data_path = 'Your data path'\n",
"img_list = os.listdir(data_path)\n",
"\n",
"name_list = []\n",
"feature_list = []\n",
"model_.eval()\n",
"\n",
"for i in img_list:\n",
" img = Image.open(os.path.join(data_path, i))\n",
" img = img.resize((224, 224))\n",
" img = np.array(img) / 255.\n",
"\n",
" assert img.shape == (224, 224, 3)\n",
"\n",
" # normalize by mean and sd\n",
" # can use customised mean and sd for your data\n",
" img = img - imagenet_mean\n",
" img = img / imagenet_std\n",
" \n",
" latent_feature = run_one_image(img, model_)\n",
" \n",
" name_list.append(i)\n",
" feature_list.append(latent_feature.detach().cpu().numpy())\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a365ec24-8e29-485e-83b5-5ac1d02945bb",
"metadata": {},
"outputs": [],
"source": [
"latent_csv = pd.DataFrame({'Name':name_list, 'Latent_feature':feature_list})\n",
"latent_csv.to_csv('Feature_latent.csv', index = False, encoding='utf8')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e8bd5e6-5780-420d-9d4c-96025b265668",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"environment": {
"kernel": "python3",
"name": "common-cu110.m91",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
File diff suppressed because one or more lines are too long
+148 -208
View File
@@ -1,208 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. import os
# All rights reserved. import csv
# Partly revised by YZ @UCL&Moorfields import torch
# -------------------------------------------------------- import torch.nn as nn
import torch.nn.functional as F
import math import numpy as np
import sys import matplotlib.pyplot as plt
import csv from typing import Iterable, Optional
import os from timm.data import Mixup
import torch from timm.utils import accuracy
import torch.nn as nn from sklearn.metrics import (
import torch.nn.functional as F accuracy_score, roc_auc_score, f1_score, average_precision_score,
from timm.data import Mixup hamming_loss, jaccard_score, recall_score, precision_score, cohen_kappa_score
from timm.utils import accuracy )
from typing import Iterable, Optional from pycm import ConfusionMatrix
import util.misc as misc import util.misc as misc
import util.lr_sched as lr_sched import util.lr_sched as lr_sched
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, average_precision_score,multilabel_confusion_matrix
from pycm import * def train_one_epoch(
import matplotlib.pyplot as plt model: torch.nn.Module,
import numpy as np criterion: torch.nn.Module,
data_loader: Iterable,
optimizer: torch.optim.Optimizer,
device: torch.device,
epoch: int,
def misc_measures(confusion_matrix): loss_scaler,
max_norm: float = 0,
acc = [] mixup_fn: Optional[Mixup] = None,
sensitivity = [] log_writer=None,
specificity = [] args=None
precision = [] ):
G = [] """Train the model for one epoch."""
F1_score_2 = [] model.train(True)
mcc_ = [] metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
for i in range(1, confusion_matrix.shape[0]): print_freq, accum_iter = 20, args.accum_iter
cm1=confusion_matrix[i] optimizer.zero_grad()
acc.append(1.*(cm1[0,0]+cm1[1,1])/np.sum(cm1))
sensitivity_ = 1.*cm1[1,1]/(cm1[1,0]+cm1[1,1]) if log_writer:
sensitivity.append(sensitivity_) print(f'log_dir: {log_writer.log_dir}')
specificity_ = 1.*cm1[0,0]/(cm1[0,1]+cm1[0,0])
specificity.append(specificity_) for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, f'Epoch: [{epoch}]')):
precision_ = 1.*cm1[1,1]/(cm1[1,1]+cm1[0,1]) if data_iter_step % accum_iter == 0:
precision.append(precision_) lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
G.append(np.sqrt(sensitivity_*specificity_))
F1_score_2.append(2*precision_*sensitivity_/(precision_+sensitivity_)) samples, targets = samples.to(device, non_blocking=True), targets.to(device, non_blocking=True)
mcc = (cm1[0,0]*cm1[1,1]-cm1[0,1]*cm1[1,0])/np.sqrt((cm1[0,0]+cm1[0,1])*(cm1[0,0]+cm1[1,0])*(cm1[1,1]+cm1[1,0])*(cm1[1,1]+cm1[0,1])) if mixup_fn:
mcc_.append(mcc) samples, targets = mixup_fn(samples, targets)
acc = np.array(acc).mean() with torch.cuda.amp.autocast():
sensitivity = np.array(sensitivity).mean() outputs = model(samples)
specificity = np.array(specificity).mean() loss = criterion(outputs, targets)
precision = np.array(precision).mean() loss_value = loss.item()
G = np.array(G).mean() loss /= accum_iter
F1_score_2 = np.array(F1_score_2).mean()
mcc_ = np.array(mcc_).mean() loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=False,
update_grad=(data_iter_step + 1) % accum_iter == 0)
return acc, sensitivity, specificity, precision, G, F1_score_2, mcc_ if (data_iter_step + 1) % accum_iter == 0:
optimizer.zero_grad()
torch.cuda.synchronize()
metric_logger.update(loss=loss_value)
min_lr = 10.
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, max_lr = 0.
data_loader: Iterable, optimizer: torch.optim.Optimizer, for group in optimizer.param_groups:
device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, min_lr = min(min_lr, group["lr"])
mixup_fn: Optional[Mixup] = None, log_writer=None, max_lr = max(max_lr, group["lr"])
args=None):
model.train(True) metric_logger.update(lr=max_lr)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) loss_value_reduce = misc.all_reduce_mean(loss_value)
header = 'Epoch: [{}]'.format(epoch) if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
print_freq = 20 """ We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
accum_iter = args.accum_iter """
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
optimizer.zero_grad() log_writer.add_scalar('loss/train', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', max_lr, epoch_1000x)
if log_writer is not None:
print('log_dir: {}'.format(log_writer.log_dir)) metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
# we use a per iteration (instead of per epoch) lr scheduler @torch.no_grad()
if data_iter_step % accum_iter == 0: def evaluate(data_loader, model, device, args, epoch, mode, num_class, log_writer):
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) """Evaluate the model."""
criterion = nn.CrossEntropyLoss()
samples = samples.to(device, non_blocking=True) metric_logger = misc.MetricLogger(delimiter=" ")
targets = targets.to(device, non_blocking=True) os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
if mixup_fn is not None: model.eval()
samples, targets = mixup_fn(samples, targets) true_onehot, pred_onehot, true_labels, pred_labels, pred_softmax = [], [], [], [], []
with torch.cuda.amp.autocast(): for batch in metric_logger.log_every(data_loader, 10, f'{mode}:'):
outputs = model(samples) images, target = batch[0].to(device, non_blocking=True), batch[-1].to(device, non_blocking=True)
loss = criterion(outputs, targets) target_onehot = F.one_hot(target.to(torch.int64), num_classes=num_class)
loss_value = loss.item() with torch.cuda.amp.autocast():
output = model(images)
if not math.isfinite(loss_value): loss = criterion(output, target)
print("Loss is {}, stopping training".format(loss_value)) output_ = nn.Softmax(dim=1)(output)
sys.exit(1) output_label = output_.argmax(dim=1)
output_onehot = F.one_hot(output_label.to(torch.int64), num_classes=num_class)
loss /= accum_iter
loss_scaler(loss, optimizer, clip_grad=max_norm, metric_logger.update(loss=loss.item())
parameters=model.parameters(), create_graph=False, true_onehot.extend(target_onehot.cpu().numpy())
update_grad=(data_iter_step + 1) % accum_iter == 0) pred_onehot.extend(output_onehot.detach().cpu().numpy())
if (data_iter_step + 1) % accum_iter == 0: true_labels.extend(target.cpu().numpy())
optimizer.zero_grad() pred_labels.extend(output_label.detach().cpu().numpy())
pred_softmax.extend(output_.detach().cpu().numpy())
torch.cuda.synchronize()
accuracy = accuracy_score(true_labels, pred_labels)
metric_logger.update(loss=loss_value) hamming = hamming_loss(true_onehot, pred_onehot)
min_lr = 10. jaccard = jaccard_score(true_onehot, pred_onehot, average='macro')
max_lr = 0. average_precision = average_precision_score(true_onehot, pred_softmax, average='macro')
for group in optimizer.param_groups: kappa = cohen_kappa_score(true_labels, pred_labels)
min_lr = min(min_lr, group["lr"]) f1 = f1_score(true_onehot, pred_onehot, zero_division=0, average='macro')
max_lr = max(max_lr, group["lr"]) roc_auc = roc_auc_score(true_onehot, pred_softmax, multi_class='ovr', average='macro')
precision = precision_score(true_onehot, pred_onehot, zero_division=0, average='macro')
metric_logger.update(lr=max_lr) recall = recall_score(true_onehot, pred_onehot, zero_division=0, average='macro')
loss_value_reduce = misc.all_reduce_mean(loss_value) score = (f1 + roc_auc + kappa) / 3
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: if log_writer:
""" We use epoch_1000x as the x-axis in tensorboard. for metric_name, value in zip(['accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa', 'score'],
This calibrates different curves when batch size changes. [accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa, score]):
""" log_writer.add_scalar(f'perf/{metric_name}', value, epoch)
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) print(f'val loss: {metric_logger.meters["loss"].global_avg}')
log_writer.add_scalar('lr', max_lr, epoch_1000x) print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}, Hamming Loss: {hamming:.4f},\n'
f' Jaccard Score: {jaccard:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f},\n'
# gather the stats from all processes f' Average Precision: {average_precision:.4f}, Kappa: {kappa:.4f}, Score: {score:.4f}')
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger) metric_logger.synchronize_between_processes()
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
results_path = os.path.join(args.output_dir, args.task, f'metrics_{mode}.csv')
file_exists = os.path.isfile(results_path)
with open(results_path, 'a', newline='', encoding='utf8') as cfa:
wf = csv.writer(cfa)
@torch.no_grad() if not file_exists:
def evaluate(data_loader, model, device, task, epoch, mode, num_class): wf.writerow(['val_loss', 'accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa'])
criterion = torch.nn.CrossEntropyLoss() wf.writerow([metric_logger.meters["loss"].global_avg, accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa])
metric_logger = misc.MetricLogger(delimiter=" ") if mode == 'test':
header = 'Test:' cm = ConfusionMatrix(actual_vector=true_labels, predict_vector=pred_labels)
cm.plot(cmap=plt.cm.Blues, number_label=True, normalized=True, plot_lib="matplotlib")
if not os.path.exists(task): plt.savefig(os.path.join(args.output_dir, args.task, 'confusion_matrix_test.jpg'), dpi=600, bbox_inches='tight')
os.makedirs(task)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, score
prediction_decode_list = []
prediction_list = []
true_label_decode_list = []
true_label_onehot_list = []
# switch to evaluation mode
model.eval()
for batch in metric_logger.log_every(data_loader, 10, header):
images = batch[0]
target = batch[-1]
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
true_label=F.one_hot(target.to(torch.int64), num_classes=num_class)
# compute output
with torch.cuda.amp.autocast():
output = model(images)
loss = criterion(output, target)
prediction_softmax = nn.Softmax(dim=1)(output)
_,prediction_decode = torch.max(prediction_softmax, 1)
_,true_label_decode = torch.max(true_label, 1)
prediction_decode_list.extend(prediction_decode.cpu().detach().numpy())
true_label_decode_list.extend(true_label_decode.cpu().detach().numpy())
true_label_onehot_list.extend(true_label.cpu().detach().numpy())
prediction_list.extend(prediction_softmax.cpu().detach().numpy())
acc1,_ = accuracy(output, target, topk=(1,2))
batch_size = images.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
# gather the stats from all processes
true_label_decode_list = np.array(true_label_decode_list)
prediction_decode_list = np.array(prediction_decode_list)
confusion_matrix = multilabel_confusion_matrix(true_label_decode_list, prediction_decode_list,labels=[i for i in range(num_class)])
acc, sensitivity, specificity, precision, G, F1, mcc = misc_measures(confusion_matrix)
auc_roc = roc_auc_score(true_label_onehot_list, prediction_list,multi_class='ovr',average='macro')
auc_pr = average_precision_score(true_label_onehot_list, prediction_list,average='macro')
metric_logger.synchronize_between_processes()
print('Sklearn Metrics - Acc: {:.4f} AUC-roc: {:.4f} AUC-pr: {:.4f} F1-score: {:.4f} MCC: {:.4f}'.format(acc, auc_roc, auc_pr, F1, mcc))
results_path = task+'_metrics_{}.csv'.format(mode)
with open(results_path,mode='a',newline='',encoding='utf8') as cfa:
wf = csv.writer(cfa)
data2=[[acc,sensitivity,specificity,precision,auc_roc,auc_pr,F1,mcc,metric_logger.loss]]
for i in data2:
wf.writerow(i)
if mode=='test':
cm = ConfusionMatrix(actual_vector=true_label_decode_list, predict_vector=prediction_decode_list)
cm.plot(cmap=plt.cm.Blues,number_label=True,normalized=True,plot_lib="matplotlib")
plt.savefig(task+'confusion_matrix_test.jpg',dpi=600,bbox_inches ='tight')
return {k: meter.global_avg for k, meter in metric_logger.meters.items()},auc_roc
-82
View File
@@ -1,82 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import sys
from typing import Iterable
import torch
import util.misc as misc
import util.lr_sched as lr_sched
def train_one_epoch(model: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,
log_writer=None,
args=None):
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 20
accum_iter = args.accum_iter
optimizer.zero_grad()
if log_writer is not None:
print('log_dir: {}'.format(log_writer.log_dir))
for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
# we use a per iteration (instead of per epoch) lr scheduler
if data_iter_step % accum_iter == 0:
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
samples = samples.to(device, non_blocking=True)
with torch.cuda.amp.autocast():
loss, _, _ = model(samples, mask_ratio=args.mask_ratio)
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
loss /= accum_iter
loss_scaler(loss, optimizer, parameters=model.parameters(),
update_grad=(data_iter_step + 1) % accum_iter == 0)
if (data_iter_step + 1) % accum_iter == 0:
optimizer.zero_grad()
torch.cuda.synchronize()
metric_logger.update(loss=loss_value)
lr = optimizer.param_groups[0]["lr"]
metric_logger.update(lr=lr)
loss_value_reduce = misc.all_reduce_mean(loss_value)
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
""" We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
"""
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', lr, epoch_1000x)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+194
View File
@@ -0,0 +1,194 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 14,
"id": "0ae19951",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"import pandas as pd\n",
"from PIL import Image\n",
"import models_vit as models\n",
"np.set_printoptions(threshold=np.inf)\n",
"np.random.seed(1)\n",
"torch.manual_seed(1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "90c3d964",
"metadata": {},
"outputs": [],
"source": [
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
" \n",
" # load model\n",
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
" \n",
" # build model\n",
" if arch=='vit_large_patch16':\n",
" model = models.__dict__[arch](\n",
" img_size=224,\n",
" num_classes=5,\n",
" drop_path_rate=0,\n",
" global_pool=True,\n",
" )\n",
" msg = model.load_state_dict(checkpoint['model'], strict=False)\n",
" else:\n",
" model = models.__dict__[arch](\n",
" num_classes=5,\n",
" drop_path_rate=0,\n",
" args=None,\n",
" )\n",
" msg = model.load_state_dict(checkpoint['teacher'], strict=False)\n",
" return model\n",
"\n",
"def run_one_image(img, model, arch):\n",
" \n",
" x = torch.tensor(img)\n",
" x = x.unsqueeze(dim=0)\n",
" x = torch.einsum('nhwc->nchw', x)\n",
" \n",
" x = x.to(device, non_blocking=True)\n",
" latent = model.forward_features(x.float())\n",
" \n",
" if arch=='dinov2_large':\n",
" latent = latent[:, 1:, :].mean(dim=1,keepdim=True)\n",
" latent = nn.LayerNorm(latent.shape[-1], eps=1e-6).to(device)(latent)\n",
" \n",
" latent = torch.squeeze(latent)\n",
"\n",
" return latent\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "9a250363",
"metadata": {},
"outputs": [],
"source": [
"def get_feature(data_path,\n",
" chkpt_dir,\n",
" device,\n",
" arch='vit_large_patch16'):\n",
" #loading model\n",
" model_ = prepare_model(chkpt_dir, arch)\n",
" model_.to(device)\n",
"\n",
" img_list = os.listdir(data_path)\n",
" \n",
" name_list = []\n",
" feature_list = []\n",
" model_.eval()\n",
" \n",
" finished_num = 0\n",
" for i in img_list:\n",
" finished_num+=1\n",
" if (finished_num%1000 == 0):\n",
" print(str(finished_num)+\"finished\")\n",
" \n",
" img = Image.open(os.path.join(data_path, i))\n",
" img = img.resize((224, 224))\n",
" img = np.array(img) / 255.\n",
" img[...,0] = (img[...,0] - img[...,0].mean())/img[...,0].std()\n",
" img[...,1] = (img[...,1] - img[...,1].mean())/img[...,1].std()\n",
" img[...,2] = (img[...,2] - img[...,2].mean())/img[...,2].std()\n",
" assert img.shape == (224, 224, 3)\n",
" \n",
" latent_feature = run_one_image(img, model_,arch)\n",
" \n",
" name_list.append(i)\n",
" feature_list.append(latent_feature.detach().cpu().numpy())\n",
" \n",
" return [name_list,feature_list]\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "54acfcd7",
"metadata": {},
"outputs": [],
"source": [
"chkpt_dir = '/home/jupyter/huggingface_repo/RETFound_dinov2_meh.pth'\n",
"data_path = '/home/jupyter/public_dataset/IDRiD_data/val/anoDR'\n",
"device = torch.device('cuda')"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "0296f74e",
"metadata": {},
"outputs": [],
"source": [
"[name_list,feature]=get_feature(data_path,\n",
" chkpt_dir,\n",
" device,\n",
" arch='dinov2_large')"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "925d3994",
"metadata": {},
"outputs": [],
"source": [
"#save the feature\n",
"df_feature = pd.DataFrame(feature)\n",
"df_imgname = pd.DataFrame(name_list)\n",
"df_visualization = pd.concat([df_imgname,df_feature], axis=1)\n",
"column_name_list = []\n",
"\n",
"for i in range(1024):\n",
" column_name_list.append(\"feature_{}\".format(i))\n",
"df_visualization.columns = [\"name\"] + column_name_list\n",
"df_visualization.to_csv(\"Feature.csv\",index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f0d13a7-2b46-40eb-ab48-5f90a6aeecb5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"kernel": "test",
"name": "common-cu121.m123",
"type": "gcloud",
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/base-cu121:m123"
},
"kernelspec": {
"display_name": "Python_test (Local)",
"language": "python",
"name": "test"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+407 -374
View File
@@ -1,374 +1,407 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. import argparse
# All rights reserved. import datetime
# Partly revised by YZ @UCL&Moorfields import json
# --------------------------------------------------------
import numpy as np
import argparse import os
import datetime import time
import json from pathlib import Path
import numpy as np
import os import torch
import time import torch.backends.cudnn as cudnn
from pathlib import Path from torch.utils.tensorboard import SummaryWriter
from timm.models.layers import trunc_normal_
import torch from timm.data.mixup import Mixup
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter import models_vit as models
import util.lr_decay as lrd
import timm import util.misc as misc
from util.datasets import build_dataset
assert timm.__version__ == "0.3.2" # version check from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_ from util.misc import NativeScalerWithGradNormCount as NativeScaler
from timm.data.mixup import Mixup from huggingface_hub import hf_hub_download, login
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from engine_finetune import train_one_epoch, evaluate
import util.lr_decay as lrd import warnings
import util.misc as misc import faulthandler
from util.datasets import build_dataset
from util.pos_embed import interpolate_pos_embed faulthandler.enable()
from util.misc import NativeScalerWithGradNormCount as NativeScaler warnings.simplefilter(action='ignore', category=FutureWarning)
import models_vit
def get_args_parser():
from engine_finetune import train_one_epoch, evaluate parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
parser.add_argument('--batch_size', default=128, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
def get_args_parser(): parser.add_argument('--epochs', default=50, type=int)
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) parser.add_argument('--accum_iter', default=1, type=int,
parser.add_argument('--batch_size', default=64, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
parser.add_argument('--epochs', default=50, type=int) # Model parameters
parser.add_argument('--accum_iter', default=1, type=int, parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') help='Name of model to train')
parser.add_argument('--input_size', default=256, type=int,
# Model parameters help='images input size')
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
help='Name of model to train') help='Drop path rate (default: 0.1)')
parser.add_argument('--input_size', default=224, type=int, # Optimizer parameters
help='images input size') parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', parser.add_argument('--weight_decay', type=float, default=0.05,
help='Drop path rate (default: 0.1)') help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
# Optimizer parameters help='learning rate (absolute lr)')
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
help='Clip gradient norm (default: None, no clipping)') help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--weight_decay', type=float, default=0.05, parser.add_argument('--layer_decay', type=float, default=0.65,
help='weight decay (default: 0.05)') help='layer-wise lr decay from ELECTRA/BEiT')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
parser.add_argument('--lr', type=float, default=None, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0')
help='learning rate (absolute lr)') parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='epochs to warmup LR')
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--layer_decay', type=float, default=0.75, # Augmentation parameters
help='layer-wise lr decay from ELECTRA/BEiT') parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
help='Color jitter factor (enabled only when not using Auto/RandAug)')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='lower lr bound for cyclic schedulers that hit 0') help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
parser.add_argument('--smoothing', type=float, default=0.1,
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', help='Label smoothing (default: 0.1)')
help='epochs to warmup LR')
# * Random Erase params
# Augmentation parameters parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', help='Random erase prob (default: 0.25)')
help='Color jitter factor (enabled only when not using Auto/RandAug)') parser.add_argument('--remode', type=str, default='pixel',
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Random erase mode (default: "pixel")')
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), parser.add_argument('--recount', type=int, default=1,
parser.add_argument('--smoothing', type=float, default=0.1, help='Random erase count (default: 1)')
help='Label smoothing (default: 0.1)') parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', # * Mixup params
help='Random erase prob (default: 0.25)') parser.add_argument('--mixup', type=float, default=0,
parser.add_argument('--remode', type=str, default='pixel', help='mixup alpha, mixup enabled if > 0.')
help='Random erase mode (default: "pixel")') parser.add_argument('--cutmix', type=float, default=0,
parser.add_argument('--recount', type=int, default=1, help='cutmix alpha, cutmix enabled if > 0.')
help='Random erase count (default: 1)') parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
parser.add_argument('--resplit', action='store_true', default=False, help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
help='Do not random erase first (clean) augmentation split') parser.add_argument('--mixup_prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
# * Mixup params parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
parser.add_argument('--mixup', type=float, default=0, help='Probability of switching to cutmix when both mixup and cutmix enabled')
help='mixup alpha, mixup enabled if > 0.') parser.add_argument('--mixup_mode', type=str, default='batch',
parser.add_argument('--cutmix', type=float, default=0, help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, # * Finetuning params
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') parser.add_argument('--finetune', default='', type=str,
parser.add_argument('--mixup_prob', type=float, default=1.0, help='finetune from checkpoint')
help='Probability of performing mixup or cutmix when either/both is enabled') parser.add_argument('--task', default='', type=str,
parser.add_argument('--mixup_switch_prob', type=float, default=0.5, help='finetune from checkpoint')
help='Probability of switching to cutmix when both mixup and cutmix enabled') parser.add_argument('--global_pool', action='store_true')
parser.add_argument('--mixup_mode', type=str, default='batch', parser.set_defaults(global_pool=True)
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') parser.add_argument('--cls_token', action='store_false', dest='global_pool',
help='Use class token instead of global pool for classification')
# * Finetuning params
parser.add_argument('--finetune', default='',type=str, # Dataset parameters
help='finetune from checkpoint') parser.add_argument('--data_path', default='./data/', type=str,
parser.add_argument('--task', default='',type=str, help='dataset path')
help='finetune from checkpoint') parser.add_argument('--nb_classes', default=8, type=int,
parser.add_argument('--global_pool', action='store_true') help='number of the classification types')
parser.set_defaults(global_pool=True) parser.add_argument('--output_dir', default='./output_dir',
parser.add_argument('--cls_token', action='store_false', dest='global_pool', help='path where to save, empty for no saving')
help='Use class token instead of global pool for classification') parser.add_argument('--log_dir', default='./output_logs',
help='path where to tensorboard log')
# Dataset parameters parser.add_argument('--device', default='cuda',
parser.add_argument('--data_path', default='/home/jupyter/Mor_DR_data/data/data/IDRID/Disease_Grading/', type=str, help='device to use for training / testing')
help='dataset path') parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--nb_classes', default=1000, type=int, parser.add_argument('--resume', default='',
help='number of the classification types') help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
parser.add_argument('--output_dir', default='./output_dir', help='start epoch')
help='path where to save, empty for no saving') parser.add_argument('--eval', action='store_true',
parser.add_argument('--log_dir', default='./output_dir', help='Perform evaluation only')
help='path where to tensorboard log') parser.add_argument('--dist_eval', action='store_true', default=False,
parser.add_argument('--device', default='cuda', help='Enabling distributed evaluation (recommended during training for faster monitor')
help='device to use for training / testing') parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--seed', default=0, type=int) parser.add_argument('--pin_mem', action='store_true',
parser.add_argument('--resume', default='', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
help='resume from checkpoint') parser.set_defaults(pin_mem=True)
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', # distributed training parameters
help='start epoch') parser.add_argument('--world_size', default=1, type=int,
parser.add_argument('--eval', action='store_true', help='number of distributed processes')
help='Perform evaluation only') parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_eval', action='store_true', default=False, parser.add_argument('--dist_on_itp', action='store_true')
help='Enabling distributed evaluation (recommended during training for faster monitor') parser.add_argument('--dist_url', default='env://',
parser.add_argument('--num_workers', default=10, type=int) help='url used to set up distributed training')
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') # fine-tuning parameters
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') parser.add_argument('--savemodel', action='store_true', default=True,
parser.set_defaults(pin_mem=True) help='Save model')
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
# distributed training parameters parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
parser.add_argument('--world_size', default=1, type=int, parser.add_argument('--datasets_seed', default=2026, type=int)
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int) return parser
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training') def main(args, criterion):
if args.resume and not args.eval:
return parser resume = args.resume
checkpoint = torch.load(args.resume, map_location='cpu')
print("Load checkpoint from: %s" % args.resume)
def main(args): args = checkpoint['args']
misc.init_distributed_mode(args) args.resume = resume
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) misc.init_distributed_mode(args)
print("{}".format(args).replace(', ', ',\n'))
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
device = torch.device(args.device) print("{}".format(args).replace(', ', ',\n'))
# fix the seed for reproducibility device = torch.device(args.device)
seed = args.seed + misc.get_rank()
torch.manual_seed(seed) # fix the seed for reproducibility
np.random.seed(seed) seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
cudnn.benchmark = True np.random.seed(seed)
dataset_train = build_dataset(is_train='train', args=args) cudnn.benchmark = True
dataset_val = build_dataset(is_train='val', args=args)
dataset_test = build_dataset(is_train='test', args=args) if args.model=='RETFound_mae':
model = models.__dict__[args.model](
if True: # args.distributed: img_size=args.input_size,
num_tasks = misc.get_world_size() num_classes=args.nb_classes,
global_rank = misc.get_rank() drop_path_rate=args.drop_path,
sampler_train = torch.utils.data.DistributedSampler( global_pool=args.global_pool,
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True )
) else:
print("Sampler_train = %s" % str(sampler_train)) model = models.__dict__[args.model](
if args.dist_eval: num_classes=args.nb_classes,
if len(dataset_val) % num_tasks != 0: drop_path_rate=args.drop_path,
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' args=args,
'This will slightly alter validation results as extra duplicate entries are added to achieve ' )
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler( if args.finetune and not args.eval:
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
else: print(f"Downloading pre-trained weights from: {args.finetune}")
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
checkpoint_path = hf_hub_download(
if args.dist_eval: repo_id=f'YukunZhou/{args.finetune}',
if len(dataset_test) % num_tasks != 0: filename=f'{args.finetune}.pth',
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' )
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.') checkpoint = torch.load(checkpoint_path, map_location='cpu')
sampler_test = torch.utils.data.DistributedSampler( print("Load pre-trained checkpoint from: %s" % args.finetune)
dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
else: if args.model!='RETFound_mae':
sampler_test = torch.utils.data.SequentialSampler(dataset_test) checkpoint_model = checkpoint['teacher']
else:
checkpoint_model = checkpoint['model']
if global_rank == 0 and args.log_dir is not None and not args.eval:
os.makedirs(args.log_dir, exist_ok=True) checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
log_writer = SummaryWriter(log_dir=args.log_dir+args.task) checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
else: checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
log_writer = None
state_dict = model.state_dict()
data_loader_train = torch.utils.data.DataLoader( for k in ['head.weight', 'head.bias']:
dataset_train, sampler=sampler_train, if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
batch_size=args.batch_size, print(f"Removing key {k} from pretrained checkpoint")
num_workers=args.num_workers, del checkpoint_model[k]
pin_memory=args.pin_mem,
drop_last=True, # interpolate position embedding
) interpolate_pos_embed(model, checkpoint_model)
data_loader_val = torch.utils.data.DataLoader( # load pre-trained model
dataset_val, sampler=sampler_val, msg = model.load_state_dict(checkpoint_model, strict=False)
batch_size=args.batch_size,
num_workers=args.num_workers, trunc_normal_(model.head.weight, std=2e-5)
pin_memory=args.pin_mem,
drop_last=False dataset_train = build_dataset(is_train='train', args=args)
) dataset_val = build_dataset(is_train='val', args=args)
dataset_test = build_dataset(is_train='test', args=args)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, sampler=sampler_test,
batch_size=args.batch_size, if True: # args.distributed:
num_workers=args.num_workers, num_tasks = misc.get_world_size()
pin_memory=args.pin_mem, global_rank = misc.get_rank()
drop_last=False if not args.eval:
) sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
mixup_fn = None print("Sampler_train = %s" % str(sampler_train))
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None if args.dist_eval:
if mixup_active: if len(dataset_val) % num_tasks != 0:
print("Mixup is activated!") print(
mixup_fn = Mixup( 'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 'This will slightly alter validation results as extra duplicate entries are added to achieve '
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 'equal num of samples per-process.')
label_smoothing=args.smoothing, num_classes=args.nb_classes) sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank,
model = models_vit.__dict__[args.model]( shuffle=True) # shuffle=True to reduce monitor bias
img_size=args.input_size, else:
num_classes=args.nb_classes, sampler_val = torch.utils.data.SequentialSampler(dataset_val)
drop_path_rate=args.drop_path,
global_pool=args.global_pool, if args.dist_eval:
) if len(dataset_test) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
if args.finetune and not args.eval: 'This will slightly alter validation results as extra duplicate entries are added to achieve '
checkpoint = torch.load(args.finetune, map_location='cpu') 'equal num of samples per-process.')
sampler_test = torch.utils.data.DistributedSampler(
print("Load pre-trained checkpoint from: %s" % args.finetune) dataset_test, num_replicas=num_tasks, rank=global_rank,
checkpoint_model = checkpoint['model'] shuffle=True) # shuffle=True to reduce monitor bias
state_dict = model.state_dict() else:
for k in ['head.weight', 'head.bias']: sampler_test = torch.utils.data.SequentialSampler(dataset_test)
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint") if global_rank == 0 and args.log_dir is not None and not args.eval:
del checkpoint_model[k] os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
# interpolate position embedding else:
interpolate_pos_embed(model, checkpoint_model) log_writer = None
# load pre-trained model if not args.eval:
msg = model.load_state_dict(checkpoint_model, strict=False) data_loader_train = torch.utils.data.DataLoader(
print(msg) dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
if args.global_pool: num_workers=args.num_workers,
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} pin_memory=args.pin_mem,
else: drop_last=True,
assert set(msg.missing_keys) == {'head.weight', 'head.bias'} )
# manually initialize fc layer print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
trunc_normal_(model.head.weight, std=2e-5)
data_loader_val = torch.utils.data.DataLoader(
model.to(device) dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
model_without_ddp = model num_workers=args.num_workers,
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) pin_memory=args.pin_mem,
drop_last=False
print("Model = %s" % str(model_without_ddp)) )
print('number of params (M): %.2f' % (n_parameters / 1.e6))
data_loader_test = torch.utils.data.DataLoader(
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() dataset_test, sampler=sampler_test,
batch_size=args.batch_size,
if args.lr is None: # only base_lr is specified num_workers=args.num_workers,
args.lr = args.blr * eff_batch_size / 256 pin_memory=args.pin_mem,
drop_last=False
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) )
print("actual lr: %.2e" % args.lr)
mixup_fn = None
print("accumulate grad iterations: %d" % args.accum_iter) mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
print("effective batch size: %d" % eff_batch_size) if mixup_active:
print("Mixup is activated!")
if args.distributed: mixup_fn = Mixup(
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
model_without_ddp = model.module prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)
# build optimizer with layer-wise lr decay (lrd)
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, if args.resume and args.eval:
no_weight_decay_list=model_without_ddp.no_weight_decay(), checkpoint = torch.load(args.resume, map_location='cpu')
layer_decay=args.layer_decay print("Load checkpoint from: %s" % args.resume)
) model.load_state_dict(checkpoint['model'])
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = NativeScaler() model.to(device)
model_without_ddp = model
if mixup_fn is not None:
# smoothing is handled with mixup label transform n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
criterion = SoftTargetCrossEntropy() print('number of model params (M): %.2f' % (n_parameters / 1.e6))
elif args.smoothing > 0.:
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
else:
criterion = torch.nn.CrossEntropyLoss() if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("criterion = %s" % str(criterion))
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) print("actual lr: %.2e" % args.lr)
if args.eval: print("accumulate grad iterations: %d" % args.accum_iter)
test_stats,auc_roc = evaluate(data_loader_test, model, device, args.task, epoch=0, mode='test',num_class=args.nb_classes) print("effective batch size: %d" % eff_batch_size)
exit(0)
if args.distributed:
print(f"Start training for {args.epochs} epochs") model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
start_time = time.time() model_without_ddp = model.module
max_accuracy = 0.0
max_auc = 0.0 no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
for epoch in range(args.start_epoch, args.epochs): param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
if args.distributed: no_weight_decay_list=no_weight_decay,
data_loader_train.sampler.set_epoch(epoch) layer_decay=args.layer_decay
train_stats = train_one_epoch( )
model, criterion, data_loader_train, optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
optimizer, device, epoch, loss_scaler, loss_scaler = NativeScaler()
args.clip_grad, mixup_fn,
log_writer=log_writer, print("criterion = %s" % str(criterion))
args=args
) misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
val_stats,val_auc_roc = evaluate(data_loader_val, model, device,args.task,epoch, mode='val',num_class=args.nb_classes) if args.eval:
if max_auc<val_auc_roc: if 'epoch' in checkpoint:
max_auc = val_auc_roc print("Test with the best model at epoch = %d" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
if args.output_dir: num_class=args.nb_classes, log_writer=log_writer)
misc.save_model( exit(0)
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch) print(f"Start training for {args.epochs} epochs")
start_time = time.time()
if log_writer is not None: max_score = 0.0
log_writer.add_scalar('perf/val_acc1', val_stats['acc1'], epoch) best_epoch = 0
log_writer.add_scalar('perf/val_auc', val_auc_roc, epoch) for epoch in range(args.start_epoch, args.epochs):
log_writer.add_scalar('perf/val_loss', val_stats['loss'], epoch) if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch, train_stats = train_one_epoch(
'n_parameters': n_parameters} model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
if args.output_dir and misc.is_main_process(): args.clip_grad, mixup_fn,
if log_writer is not None: log_writer=log_writer,
log_writer.flush() args=args
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: )
f.write(json.dumps(log_stats) + "\n")
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
num_class=args.nb_classes, log_writer=log_writer)
total_time = time.time() - start_time if max_score < val_score:
total_time_str = str(datetime.timedelta(seconds=int(total_time))) max_score = val_score
print('Training time {}'.format(total_time_str)) best_epoch = epoch
state_dict_best = torch.load(args.task+'checkpoint-best.pth', map_location='cpu') if args.output_dir and args.savemodel:
model_without_ddp.load_state_dict(state_dict_best['model']) misc.save_model(
test_stats,auc_roc = evaluate(data_loader_test, model_without_ddp, device,args.task,epoch=0, mode='test',num_class=args.nb_classes) args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch, mode='best')
if __name__ == '__main__': print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
args = get_args_parser()
args = args.parse_args()
if epoch == (args.epochs - 1):
if args.output_dir: checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
Path(args.output_dir).mkdir(parents=True, exist_ok=True) model.load_state_dict(checkpoint['model'], strict=False)
main(args) model.to(device)
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
num_class=args.nb_classes, log_writer=None)
if log_writer is not None:
log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
criterion = torch.nn.CrossEntropyLoss()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args, criterion)
+414
View File
@@ -0,0 +1,414 @@
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup
import models_vit as models
import util.lr_decay as lrd
import util.misc as misc
from util.datasets import build_dataset
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from huggingface_hub import hf_hub_download, login
from engine_finetune import train_one_epoch, evaluate
import warnings
import faulthandler
faulthandler.enable()
warnings.simplefilter(action='ignore', category=FutureWarning)
def get_args_parser():
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
parser.add_argument('--batch_size', default=128, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input_size', default=256, type=int,
help='images input size')
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
help='Drop path rate (default: 0.1)')
# Optimizer parameters
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--weight_decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--layer_decay', type=float, default=0.65,
help='layer-wise lr decay from ELECTRA/BEiT')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
help='epochs to warmup LR')
# Augmentation parameters
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
help='Color jitter factor (enabled only when not using Auto/RandAug)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
# * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Mixup params
parser.add_argument('--mixup', type=float, default=0,
help='mixup alpha, mixup enabled if > 0.')
parser.add_argument('--cutmix', type=float, default=0,
help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup_mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# * Finetuning params
parser.add_argument('--finetune', default='', type=str,
help='finetune from checkpoint')
parser.add_argument('--task', default='', type=str,
help='finetune from checkpoint')
parser.add_argument('--global_pool', action='store_true')
parser.set_defaults(global_pool=True)
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
help='Use class token instead of global pool for classification')
# Dataset parameters
parser.add_argument('--data_path', default='./data/', type=str,
help='dataset path')
parser.add_argument('--nb_classes', default=8, type=int,
help='number of the classification types')
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_logs',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--dist_eval', action='store_true', default=False,
help='Enabling distributed evaluation (recommended during training for faster monitor')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
# fine-tuning parameters
parser.add_argument('--savemodel', action='store_true', default=True,
help='Save model')
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
parser.add_argument('--datasets_seed', default=2026, type=int)
return parser
def main(args, criterion):
if args.resume and not args.eval:
resume = args.resume
checkpoint = torch.load(args.resume, map_location='cpu')
print("Load checkpoint from: %s" % args.resume)
args = checkpoint['args']
args.resume = resume
misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
if args.model=='RETFound_mae':
model = models.__dict__[args.model](
img_size=args.input_size,
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
global_pool=args.global_pool,
)
else:
model = models.__dict__[args.model](
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
args=args,
)
if args.finetune and not args.eval:
print(f"Downloading pre-trained weights from: {args.finetune}")
checkpoint_path = hf_hub_download(
repo_id=f'YukunZhou/{args.finetune}',
filename=f'{args.finetune}.pth',
)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % args.finetune)
if args.model!='RETFound_mae':
checkpoint_model = checkpoint['teacher']
else:
checkpoint_model = checkpoint['model']
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
trunc_normal_(model.head.weight, std=2e-5)
dataset_train = build_dataset(is_train='train', args=args)
dataset_val = build_dataset(is_train='val', args=args)
dataset_test = build_dataset(is_train='test', args=args)
if True: # args.distributed:
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
if not args.eval:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print("Sampler_train = %s" % str(sampler_train))
if args.dist_eval:
if len(dataset_val) % num_tasks != 0:
print(
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
if args.dist_eval:
if len(dataset_test) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_test = torch.utils.data.DistributedSampler(
dataset_test, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else:
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
if global_rank == 0 and args.log_dir is not None and not args.eval:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
else:
log_writer = None
if not args.eval:
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, sampler=sampler_test,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
print("Mixup is activated!")
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)
if args.resume and args.eval:
checkpoint = torch.load(args.resume, map_location='cpu')
print("Load checkpoint from: %s" % args.resume)
model.load_state_dict(checkpoint['model'])
model.to(device)
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of model params (M): %.2f' % (n_parameters / 1.e6))
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True)
model_without_ddp = model.module
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
no_weight_decay_list=no_weight_decay,
layer_decay=args.layer_decay
)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = NativeScaler()
print("criterion = %s" % str(criterion))
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
for name, param in model.named_parameters():
if 'head' in name:
param.requires_grad = True
else:
param.requires_grad = False
if args.eval:
if 'epoch' in checkpoint:
print("Test with the best model at epoch = %d" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
num_class=args.nb_classes, log_writer=log_writer)
exit(0)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_score = 0.0
best_epoch = 0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
args.clip_grad, mixup_fn,
log_writer=log_writer,
args=args
)
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
num_class=args.nb_classes, log_writer=log_writer)
if max_score < val_score:
max_score = val_score
best_epoch = epoch
if args.output_dir and args.savemodel:
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch, mode='best')
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
if epoch == (args.epochs - 1):
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=False)
model.to(device)
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
num_class=args.nb_classes, log_writer=None)
if log_writer is not None:
log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
criterion = torch.nn.CrossEntropyLoss()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args, criterion)
-221
View File
@@ -1,221 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import timm
assert timm.__version__ == "0.3.2" # version check
import timm.optim.optim_factory as optim_factory
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
import models_mae
from engine_pretrain import train_one_epoch
def get_args_parser():
parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
parser.add_argument('--batch_size', default=64, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
parser.add_argument('--epochs', default=400, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters
parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input_size', default=224, type=int,
help='images input size')
parser.add_argument('--mask_ratio', default=0.75, type=float,
help='Masking ratio (percentage of removed patches).')
parser.add_argument('--norm_pix_loss', action='store_true',
help='Use (per-patch) normalized pixels as targets for computing loss')
parser.set_defaults(norm_pix_loss=False)
# Optimizer parameters
parser.add_argument('--weight_decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
help='epochs to warmup LR')
# Dataset parameters
parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
help='dataset path')
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_dir',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
return parser
def main(args):
misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
# simple augmentation
transform_train = transforms.Compose([
transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
print(dataset_train)
if True: # args.distributed:
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print("Sampler_train = %s" % str(sampler_train))
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
if global_rank == 0 and args.log_dir is not None:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=args.log_dir)
else:
log_writer = None
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
# define the model
model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
model.to(device)
model_without_ddp = model
print("Model = %s" % str(model_without_ddp))
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
model_without_ddp = model.module
# following timm: set wd as 0 for bias and norm layers
param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
print(optimizer)
loss_scaler = NativeScaler()
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, data_loader_train,
optimizer, device, epoch, loss_scaler,
log_writer=log_writer,
args=args
)
if args.output_dir and (epoch % 50 == 0 or epoch + 1 == args.epochs):
misc.save_model_pretrain(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,}
if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
-226
View File
@@ -1,226 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block
from util.pos_embed import get_2d_sincos_pos_embed
class MaskedAutoencoderViT(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3,
embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
super().__init__()
# --------------------------------------------------------------------------
# MAE encoder specifics
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(decoder_depth)])
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
# --------------------------------------------------------------------------
self.norm_pix_loss = norm_pix_loss
self.initialize_weights()
def initialize_weights(self):
# initialization
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=.02)
torch.nn.init.normal_(self.mask_token, std=.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = self.patch_embed.patch_size[0]
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_embed.patch_size[0]
h = w = int(x.shape[1]**.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward_encoder(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
x = x[:, 1:, :]
return x
def forward_loss(self, imgs, pred, mask):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0 is keep, 1 is remove,
"""
target = self.patchify(imgs)
if self.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6)**.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
return loss
def forward(self, imgs, mask_ratio=0.75):
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
loss = self.forward_loss(imgs, pred, mask)
return loss, pred, mask
def mae_vit_large_patch16_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
# set recommended archs
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
+69 -55
View File
@@ -1,55 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# Partly revised by YZ @UCL&Moorfields # Partly revised by YZ @UCL&Moorfields
# -------------------------------------------------------- # --------------------------------------------------------
from functools import partial from functools import partial
import torch import timm.models.vision_transformer
import torch.nn as nn import torch
import torch.nn as nn
import timm.models.vision_transformer import torch.nn.functional as F
from torch import Tensor
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
""" Vision Transformer with support for global average pooling class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
""" """ Vision Transformer with support for global average pooling
def __init__(self, global_pool=False, **kwargs): """
super(VisionTransformer, self).__init__(**kwargs) def __init__(self, global_pool=False, **kwargs):
super(VisionTransformer, self).__init__(**kwargs)
self.global_pool = global_pool
if self.global_pool: self.global_pool = global_pool
norm_layer = kwargs['norm_layer'] if self.global_pool:
embed_dim = kwargs['embed_dim'] norm_layer = kwargs['norm_layer']
self.fc_norm = norm_layer(embed_dim) embed_dim = kwargs['embed_dim']
self.fc_norm = norm_layer(embed_dim)
del self.norm # remove the original norm
del self.norm # remove the original norm
def forward_features(self, x):
B = x.shape[0] def forward_features(self, x):
x = self.patch_embed(x) B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1) cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = x + self.pos_embed x = torch.cat((cls_tokens, x), dim=1)
x = self.pos_drop(x) x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x) for blk in self.blocks:
x = blk(x)
if self.global_pool:
x = x[:, 1:, :].mean(dim=1) # global pool without cls token if self.global_pool:
outcome = self.fc_norm(x) x = x[:, 1:, :].mean(dim=1,keepdim=True) # global pool without cls token
else: outcome = self.fc_norm(x)
x = self.norm(x) else:
outcome = x[:, 0] x = self.norm(x)
outcome = x[:, 0]
return outcome
return outcome
def vit_large_patch16(**kwargs):
model = VisionTransformer( def RETFound_mae(**kwargs):
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, model = VisionTransformer(
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
return model norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
def RETFound_dinov2(args, **kwargs):
model = timm.create_model(
'vit_large_patch14_dinov2.lvd142m',
pretrained=True,
img_size=224,
**kwargs
)
return model
Binary file not shown.

Before

Width:  |  Height:  |  Size: 66 KiB

-18
View File
@@ -1,18 +0,0 @@
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.8.1+cu111
torchvision==0.9.1+cu111
torchaudio==0.8.1
opencv-python==4.5.3.56
pandas==0.25.3
Pillow==8.3.1
protobuf==3.17.3
pycm==3.2
pydicom==2.3.0
scikit-image==0.17.2
scikit-learn==0.24.2
scipy==1.5.4
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
timm==0.3.2
tqdm==4.62.1
+11
View File
@@ -0,0 +1,11 @@
opencv-python~=4.9.0.80
Pillow~=10.2.0
pycm~=4.0
scikit-learn~=1.4.2
timm~=0.9.2
numpy~=1.26.4
matplotlib~=3.8.4
scikit-multilearn~=0.2.0
huggingface-hub~=0.23.4
tensorboard~=2.17.0
View File
+53 -54
View File
@@ -1,54 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# Partly revised by YZ @UCL&Moorfields # Partly revised by YZ @UCL&Moorfields
# -------------------------------------------------------- # --------------------------------------------------------
import os import os
from torchvision import datasets, transforms from torchvision import datasets, transforms
from timm.data import create_transform from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args): def build_dataset(is_train, args):
transform = build_transform(is_train, args)
transform = build_transform(is_train, args) root = os.path.join(args.data_path, is_train)
root = os.path.join(args.data_path, is_train) dataset = datasets.ImageFolder(root, transform=transform)
dataset = datasets.ImageFolder(root, transform=transform)
return dataset
return dataset
def build_transform(is_train, args):
def build_transform(is_train, args): mean = IMAGENET_DEFAULT_MEAN
mean = IMAGENET_DEFAULT_MEAN std = IMAGENET_DEFAULT_STD
std = IMAGENET_DEFAULT_STD # train transform
# train transform if is_train == 'train':
if is_train=='train': # this should always dispatch to transforms_imagenet_train
# this should always dispatch to transforms_imagenet_train transform = create_transform(
transform = create_transform( input_size=args.input_size,
input_size=args.input_size, is_training=True,
is_training=True, color_jitter=args.color_jitter,
color_jitter=args.color_jitter, auto_augment=args.aa,
auto_augment=args.aa, interpolation='bicubic',
interpolation='bicubic', re_prob=args.reprob,
re_prob=args.reprob, re_mode=args.remode,
re_mode=args.remode, re_count=args.recount,
re_count=args.recount, mean=mean,
mean=mean, std=std,
std=std, )
) return transform
return transform
# eval transform
# eval transform t = []
t = [] if args.input_size <= 224:
if args.input_size <= 224: crop_pct = 224 / 256
crop_pct = 224 / 256 else:
else: crop_pct = 1.0
crop_pct = 1.0 size = int(args.input_size / crop_pct)
size = int(args.input_size / crop_pct) t.append(
t.append( transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), )
) t.append(transforms.CenterCrop(args.input_size))
t.append(transforms.CenterCrop(args.input_size)) t.append(transforms.ToTensor())
t.append(transforms.ToTensor()) t.append(transforms.Normalize(mean, std))
t.append(transforms.Normalize(mean, std)) return transforms.Compose(t)
return transforms.Compose(t)
+73 -69
View File
@@ -1,70 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# Partly revised by YZ @UCL&Moorfields # Partly revised by YZ @UCL&Moorfields
# -------------------------------------------------------- # --------------------------------------------------------
import json import json
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
""" """
Parameter groups for layer-wise lr decay Parameter groups for layer-wise lr decay
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
""" """
param_group_names = {} param_group_names = {}
param_groups = {} param_groups = {}
num_layers = len(model.blocks) + 1 if hasattr(model, 'blocks'):
num_layers = len(model.blocks) + 1
layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) else:
# use the number of layers in the ResNet model as a default value
for n, p in model.named_parameters(): num_layers = len(model.layer1) + len(model.layer2) + len(model.layer3) + len(model.layer4) + 1
if not p.requires_grad:
continue layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
# no decay: all 1D parameters and model specific ones for n, p in model.named_parameters():
if p.ndim == 1 or n in no_weight_decay_list: if not p.requires_grad:
g_decay = "no_decay" continue
this_decay = 0.
else: # no decay: all 1D parameters and model specific ones
g_decay = "decay" if p.ndim == 1 or n in no_weight_decay_list:
this_decay = weight_decay g_decay = "no_decay"
this_decay = 0.
layer_id = get_layer_id_for_vit(n, num_layers) else:
group_name = "layer_%d_%s" % (layer_id, g_decay) g_decay = "decay"
this_decay = weight_decay
if group_name not in param_group_names:
this_scale = layer_scales[layer_id] layer_id = get_layer_id_for_vit(n, num_layers)
group_name = "layer_%d_%s" % (layer_id, g_decay)
param_group_names[group_name] = {
"lr_scale": this_scale, if group_name not in param_group_names:
"weight_decay": this_decay, this_scale = layer_scales[layer_id]
"params": [],
} param_group_names[group_name] = {
param_groups[group_name] = { "lr_scale": this_scale,
"lr_scale": this_scale, "weight_decay": this_decay,
"weight_decay": this_decay, "params": [],
"params": [], }
} param_groups[group_name] = {
"lr_scale": this_scale,
param_group_names[group_name]["params"].append(n) "weight_decay": this_decay,
param_groups[group_name]["params"].append(p) "params": [],
}
# print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
param_group_names[group_name]["params"].append(n)
return list(param_groups.values()) param_groups[group_name]["params"].append(p)
# print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
def get_layer_id_for_vit(name, num_layers):
""" return list(param_groups.values())
Assign a parameter with its layer id
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
""" def get_layer_id_for_vit(name, num_layers):
if name in ['cls_token', 'pos_embed']: """
return 0 Assign a parameter with its layer id
elif name.startswith('patch_embed'): Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
return 0 """
elif name.startswith('blocks'): if name in ['cls_token', 'pos_embed']:
return int(name.split('.')[1]) + 1 return 0
else: elif name.startswith('patch_embed'):
return 0
elif name.startswith('blocks'):
return int(name.split('.')[1]) + 1
else:
return num_layers return num_layers
+20 -20
View File
@@ -1,20 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# Partly revised by YZ @UCL&Moorfields # Partly revised by YZ @UCL&Moorfields
# -------------------------------------------------------- # --------------------------------------------------------
import math import math
def adjust_learning_rate(optimizer, epoch, args): def adjust_learning_rate(optimizer, epoch, args):
"""Decay the learning rate with half-cycle cosine after warmup""" """Decay the learning rate with half-cycle cosine after warmup"""
if epoch < args.warmup_epochs: if epoch < args.warmup_epochs:
lr = args.lr * epoch / args.warmup_epochs lr = args.lr * epoch / args.warmup_epochs
else: else:
lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
(1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
if "lr_scale" in param_group: if "lr_scale" in param_group:
param_group["lr"] = lr * param_group["lr_scale"] param_group["lr"] = lr * param_group["lr_scale"]
else: else:
param_group["lr"] = lr param_group["lr"] = lr
return lr return lr
+369 -357
View File
@@ -1,357 +1,369 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# Partly revised by YZ @UCL&Moorfields # Partly revised by YZ @UCL&Moorfields
# -------------------------------------------------------- # --------------------------------------------------------
import builtins import builtins
import datetime import datetime
import os import os
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
from pathlib import Path from pathlib import Path
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch._six import inf from math import inf
class SmoothedValue(object): class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a """Track a series of values and provide access to smoothed values over a
window or the global series average. window or the global series average.
""" """
def __init__(self, window_size=20, fmt=None): def __init__(self, window_size=20, fmt=None):
if fmt is None: if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})" fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size) self.deque = deque(maxlen=window_size)
self.total = 0.0 self.total = 0.0
self.count = 0 self.count = 0
self.fmt = fmt self.fmt = fmt
def update(self, value, n=1): def update(self, value, n=1):
self.deque.append(value) self.deque.append(value)
self.count += n self.count += n
self.total += value * n self.total += value * n
def synchronize_between_processes(self): def synchronize_between_processes(self):
""" """
Warning: does not synchronize the deque! Warning: does not synchronize the deque!
""" """
if not is_dist_avail_and_initialized(): if not is_dist_avail_and_initialized():
return return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier() dist.barrier()
dist.all_reduce(t) dist.all_reduce(t)
t = t.tolist() t = t.tolist()
self.count = int(t[0]) self.count = int(t[0])
self.total = t[1] self.total = t[1]
@property @property
def median(self): def median(self):
d = torch.tensor(list(self.deque)) d = torch.tensor(list(self.deque))
return d.median().item() return d.median().item()
@property @property
def avg(self): def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32) d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item() return d.mean().item()
@property @property
def global_avg(self): def global_avg(self):
return self.total / self.count return self.total / self.count
@property @property
def max(self): def max(self):
return max(self.deque) return max(self.deque)
@property @property
def value(self): def value(self):
return self.deque[-1] return self.deque[-1]
def __str__(self): def __str__(self):
return self.fmt.format( return self.fmt.format(
median=self.median, median=self.median,
avg=self.avg, avg=self.avg,
global_avg=self.global_avg, global_avg=self.global_avg,
max=self.max, max=self.max,
value=self.value) value=self.value)
class MetricLogger(object): class MetricLogger(object):
def __init__(self, delimiter="\t"): def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue) self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter self.delimiter = delimiter
def update(self, **kwargs): def update(self, **kwargs):
for k, v in kwargs.items(): for k, v in kwargs.items():
if v is None: if v is None:
continue continue
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
v = v.item() v = v.item()
assert isinstance(v, (float, int)) assert isinstance(v, (float, int))
self.meters[k].update(v) self.meters[k].update(v)
def __getattr__(self, attr): def __getattr__(self, attr):
if attr in self.meters: if attr in self.meters:
return self.meters[attr] return self.meters[attr]
if attr in self.__dict__: if attr in self.__dict__:
return self.__dict__[attr] return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format( raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr)) type(self).__name__, attr))
def __str__(self): def __str__(self):
loss_str = [] loss_str = []
for name, meter in self.meters.items(): for name, meter in self.meters.items():
loss_str.append( loss_str.append(
"{}: {}".format(name, str(meter)) "{}: {}".format(name, str(meter))
) )
return self.delimiter.join(loss_str) return self.delimiter.join(loss_str)
def synchronize_between_processes(self): def synchronize_between_processes(self):
for meter in self.meters.values(): for meter in self.meters.values():
meter.synchronize_between_processes() meter.synchronize_between_processes()
def add_meter(self, name, meter): def add_meter(self, name, meter):
self.meters[name] = meter self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None): def log_every(self, iterable, print_freq, header=None):
i = 0 i = 0
if not header: if not header:
header = '' header = ''
start_time = time.time() start_time = time.time()
end = time.time() end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}') iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}') data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd' space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [ log_msg = [
header, header,
'[{0' + space_fmt + '}/{1}]', '[{0' + space_fmt + '}/{1}]',
'eta: {eta}', 'eta: {eta}',
'{meters}', '{meters}',
'time: {time}', 'time: {time}',
'data: {data}' 'data: {data}'
] ]
if torch.cuda.is_available(): if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}') log_msg.append('max mem: {memory:.0f}')
log_msg = self.delimiter.join(log_msg) log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0 MB = 1024.0 * 1024.0
for obj in iterable: for obj in iterable:
data_time.update(time.time() - end) data_time.update(time.time() - end)
yield obj yield obj
iter_time.update(time.time() - end) iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1: if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available(): if torch.cuda.is_available():
print(log_msg.format( print(log_msg.format(
i, len(iterable), eta=eta_string, i, len(iterable), eta=eta_string,
meters=str(self), meters=str(self),
time=str(iter_time), data=str(data_time), time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB)) memory=torch.cuda.max_memory_allocated() / MB))
else: else:
print(log_msg.format( print(log_msg.format(
i, len(iterable), eta=eta_string, i, len(iterable), eta=eta_string,
meters=str(self), meters=str(self),
time=str(iter_time), data=str(data_time))) time=str(iter_time), data=str(data_time)))
i += 1 i += 1
end = time.time() end = time.time()
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format( print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable))) header, total_time_str, total_time / len(iterable)))
def setup_for_distributed(is_master): def setup_for_distributed(is_master):
""" """
This function disables printing when not in master process This function disables printing when not in master process
""" """
builtin_print = builtins.print builtin_print = builtins.print
def print(*args, **kwargs): def print(*args, **kwargs):
force = kwargs.pop('force', False) force = kwargs.pop('force', False)
force = force or (get_world_size() > 8) force = force or (get_world_size() > 8)
if is_master or force: if is_master or force:
now = datetime.datetime.now().time() now = datetime.datetime.now().time()
builtin_print('[{}] '.format(now), end='') # print with time stamp builtin_print('[{}] '.format(now), end='') # print with time stamp
builtin_print(*args, **kwargs) builtin_print(*args, **kwargs)
builtins.print = print builtins.print = print
def is_dist_avail_and_initialized(): def is_dist_avail_and_initialized():
if not dist.is_available(): if not dist.is_available():
return False return False
if not dist.is_initialized(): if not dist.is_initialized():
return False return False
return True return True
def get_world_size(): def get_world_size():
if not is_dist_avail_and_initialized(): if not is_dist_avail_and_initialized():
return 1 return 1
return dist.get_world_size() return dist.get_world_size()
def get_rank(): def get_rank():
if not is_dist_avail_and_initialized(): if not is_dist_avail_and_initialized():
return 0 return 0
return dist.get_rank() return dist.get_rank()
def is_main_process(): def is_main_process():
return get_rank() == 0 return get_rank() == 0
def save_on_master(*args, **kwargs): def save_on_master(*args, **kwargs):
if is_main_process(): if is_main_process():
torch.save(*args, **kwargs) torch.save(*args, **kwargs)
def init_distributed_mode(args): def init_distributed_mode(args):
if args.dist_on_itp: if args.dist_on_itp:
args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
os.environ['LOCAL_RANK'] = str(args.gpu) os.environ['LOCAL_RANK'] = str(args.gpu)
os.environ['RANK'] = str(args.rank) os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size) os.environ['WORLD_SIZE'] = str(args.world_size)
# ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"]) args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE']) args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK']) args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ: elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID']) args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count() args.gpu = args.rank % torch.cuda.device_count()
else: else:
print('Not using distributed mode') print('Not using distributed mode')
setup_for_distributed(is_master=True) # hack setup_for_distributed(is_master=True) # hack
args.distributed = False args.distributed = False
return return
args.distributed = True args.distributed = True
torch.cuda.set_device(args.gpu) torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl' args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format( print('| distributed init (rank {}): {}, gpu {}'.format(
args.rank, args.dist_url, args.gpu), flush=True) args.rank, args.dist_url, args.gpu), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank) world_size=args.world_size, rank=args.rank)
torch.distributed.barrier() torch.distributed.barrier()
setup_for_distributed(args.rank == 0) setup_for_distributed(args.rank == 0)
class NativeScalerWithGradNormCount: class NativeScalerWithGradNormCount:
state_dict_key = "amp_scaler" state_dict_key = "amp_scaler"
def __init__(self): def __init__(self):
self._scaler = torch.cuda.amp.GradScaler() self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
self._scaler.scale(loss).backward(create_graph=create_graph) self._scaler.scale(loss).backward(create_graph=create_graph)
if update_grad: if update_grad:
if clip_grad is not None: if clip_grad is not None:
assert parameters is not None assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
else: else:
self._scaler.unscale_(optimizer) self._scaler.unscale_(optimizer)
norm = get_grad_norm_(parameters) norm = get_grad_norm_(parameters)
self._scaler.step(optimizer) self._scaler.step(optimizer)
self._scaler.update() self._scaler.update()
else: else:
norm = None norm = None
return norm return norm
def state_dict(self): def state_dict(self):
return self._scaler.state_dict() return self._scaler.state_dict()
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict) self._scaler.load_state_dict(state_dict)
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
if isinstance(parameters, torch.Tensor): if isinstance(parameters, torch.Tensor):
parameters = [parameters] parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None] parameters = [p for p in parameters if p.grad is not None]
norm_type = float(norm_type) norm_type = float(norm_type)
if len(parameters) == 0: if len(parameters) == 0:
return torch.tensor(0.) return torch.tensor(0.)
device = parameters[0].grad.device device = parameters[0].grad.device
if norm_type == inf: if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else: else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
return total_norm norm_type)
return total_norm
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
output_dir = Path(args.output_dir) def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, mode):
epoch_name = str(epoch) output_dir = Path(args.output_dir)
if loss_scaler is not None: epoch_name = str(epoch)
checkpoint_paths = [args.task+'checkpoint-best.pth'] os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
for checkpoint_path in checkpoint_paths: if loss_scaler is not None:
to_save = { if mode == 'best':
'model': model_without_ddp.state_dict(), checkpoint_paths = [os.path.join(args.output_dir, args.task, 'checkpoint-best.pth')]
'optimizer': optimizer.state_dict(), else:
'epoch': epoch, checkpoint_paths = [os.path.join(args.output_dir, args.task, 'checkpoint-latest.pth')]
'scaler': loss_scaler.state_dict(), for checkpoint_path in checkpoint_paths:
'args': args, if mode == 'best':
} to_save = {
'model': model_without_ddp.state_dict(),
save_on_master(to_save, checkpoint_path) 'epoch': epoch,
else: 'args': args, }
client_state = {'epoch': epoch} else:
model.save_checkpoint(save_dir=args.task, tag="checkpoint-best", client_state=client_state) if epoch == args.epochs - 1:
to_save = {
'model': model_without_ddp.state_dict(),
def save_model_pretrain(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 'args': args, }
output_dir = Path(args.output_dir) else:
epoch_name = str(epoch) to_save = {
if loss_scaler is not None: 'model': model_without_ddp.state_dict(),
print(model_without_ddp.state_dict().keys()) 'optimizer': optimizer.state_dict(),
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 'epoch': epoch,
for checkpoint_path in checkpoint_paths: 'scaler': loss_scaler.state_dict(),
to_save = { 'args': args,
'model': model_without_ddp.state_dict(), }
'optimizer': optimizer.state_dict(),
'epoch': epoch, save_on_master(to_save, checkpoint_path)
'scaler': loss_scaler.state_dict(), else:
'args': args, if mode == 'best':
} to_save = {
'model': model_without_ddp.state_dict(),
save_on_master(to_save, checkpoint_path) 'epoch': epoch, }
else: torch.save(to_save, os.path.join(args.output_dir, args.task, "checkpoint-best.pth"))
client_state = {'epoch': epoch} else:
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) if epoch == args.epochs - 1:
to_save = {
'model': model_without_ddp.state_dict(), }
else:
def load_model(args, model_without_ddp, optimizer, loss_scaler): to_save = {
if args.resume: 'model': model_without_ddp.state_dict(),
if args.resume.startswith('https'): 'optimizer': optimizer.state_dict(),
checkpoint = torch.hub.load_state_dict_from_url( 'epoch': epoch,
args.resume, map_location='cpu', check_hash=True) 'args': args,
else: }
checkpoint = torch.load(args.resume, map_location='cpu') torch.save(to_save, os.path.join(args.output_dir, args.task, "checkpoint-latest.pth"))
model_without_ddp.load_state_dict(checkpoint['model'])
print("Resume checkpoint %s" % args.resume)
if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): def load_model(args, model_without_ddp, optimizer, loss_scaler):
optimizer.load_state_dict(checkpoint['optimizer']) if args.resume:
args.start_epoch = checkpoint['epoch'] + 1 if args.resume.startswith('https'):
if 'scaler' in checkpoint: checkpoint = torch.hub.load_state_dict_from_url(
loss_scaler.load_state_dict(checkpoint['scaler']) args.resume, map_location='cpu', check_hash=True)
print("With optim & sched!") else:
checkpoint = torch.load(args.resume, map_location='cpu')
if 'model' in checkpoint:
def all_reduce_mean(x): checkpoint_model = checkpoint['model']
world_size = get_world_size() else:
if world_size > 1: checkpoint_model = checkpoint
x_reduce = torch.tensor(x).cuda() model_without_ddp.load_state_dict(checkpoint_model, strict=False)
dist.all_reduce(x_reduce) print("Resume checkpoint %s" % args.resume)
x_reduce /= world_size if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
return x_reduce.item() optimizer.load_state_dict(checkpoint['optimizer'])
else: args.start_epoch = checkpoint['epoch'] + 1
return x if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
print("With optim & sched!")
def all_reduce_mean(x):
world_size = get_world_size()
if world_size > 1:
x_reduce = torch.tensor(x).cuda()
dist.all_reduce(x_reduce)
x_reduce /= world_size
return x_reduce.item()
else:
return x
+92 -92
View File
@@ -1,92 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# Partly revised by YZ @UCL&Moorfields # Partly revised by YZ @UCL&Moorfields
# -------------------------------------------------------- # --------------------------------------------------------
import numpy as np import numpy as np
import torch import torch
# -------------------------------------------------------- # --------------------------------------------------------
# 2D sine-cosine position embedding # 2D sine-cosine position embedding
# References: # References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3 # MoCo v3: https://github.com/facebookresearch/moco-v3
# -------------------------------------------------------- # --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
""" """
grid_size: int of the grid height and width grid_size: int of the grid height and width
return: return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
""" """
grid_h = np.arange(grid_size, dtype=np.float32) grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0) grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size]) grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token: if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h # use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
""" """
embed_dim: output dimension for each position embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) pos: a list of positions to be encoded: size (M,)
out: (M, D) out: (M, D)
""" """
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float) omega = np.arange(embed_dim // 2, dtype=np.float)
omega /= embed_dim / 2. omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,) omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,) pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2) emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb return emb
# -------------------------------------------------------- # --------------------------------------------------------
# Interpolate position embeddings for high-resolution # Interpolate position embeddings for high-resolution
# References: # References:
# DeiT: https://github.com/facebookresearch/deit # DeiT: https://github.com/facebookresearch/deit
# -------------------------------------------------------- # --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model): def interpolate_pos_embed(model, checkpoint_model):
if 'pos_embed' in checkpoint_model: if 'pos_embed' in checkpoint_model:
pos_embed_checkpoint = checkpoint_model['pos_embed'] pos_embed_checkpoint = checkpoint_model['pos_embed']
embedding_size = pos_embed_checkpoint.shape[-1] embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding # height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding # height (== width) for the new position embedding
new_size = int(num_patches ** 0.5) new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged # class_token and dist_token are kept unchanged
if orig_size != new_size: if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated # only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate( pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed checkpoint_model['pos_embed'] = new_pos_embed