major package upgrade&new weights

This commit is contained in:
rmaphoh
2025-02-19 12:37:05 +00:00
parent f5ab012b71
commit f0425f5526
21 changed files with 1931 additions and 2333 deletions
+6
View File
@@ -0,0 +1,6 @@
[Desktop Entry]
Encoding=UTF-8
Name=Link to
Type=Link
URL=file:///home/yukun/Moorfield/git_repo/RETFound_MAE/README.md
Icon=text-markdown
+75 -68
View File
@@ -1,8 +1,9 @@
## RETFound - A foundation model for retinal imaging ## RETFound - A foundation model for retinal imaging
Official repo for [RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae): Official repo including a series of retinal foundation models.
[RETFound: a foundation model for generalizable disease detection from retinal images](https://www.nature.com/articles/s41586-023-06555-x), which is based on [MAE](https://github.com/facebookresearch/mae):
[New checkpoints](https://www.nature.com/articles/s41586-023-06555-x), which is based on [DINOV2](https://github.com/facebookresearch/dinov2):
Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions. Please contact **ykzhoua@gmail.com** or **yukun.zhou.19@ucl.ac.uk** if you have questions.
Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE) Keras version implemented by Yuka Kihara can be found [here](https://github.com/uw-biomedical-ml/RETFound_MAE)
@@ -17,6 +18,9 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
### 🎉News ### 🎉News
- 🐉2025/02: **We organised the model weights on HuggingFace, no more manual downloads needed!**
- 🐉2025/02: **Multiple [pre-trained weights](https://huggingface.co/YukunZhou), including MAE-based and DINOV2-based, are added!**
- 🐉2025/02: **We update the version of packages, such as CUDA12+ and PyTorch 2.3+!**
- 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/RETFound_Feature.ipynb) are now online! - 🐉2024/01: [Feature vector notebook](https://github.com/rmaphoh/RETFound_MAE/blob/main/RETFound_Feature.ipynb) are now online!
- 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online! - 🐉2024/01: [Data split and model checkpoints](BENCHMARK.md) for public datasets are now online!
- 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation! - 🎄2023/12: [Colab notebook](https://colab.research.google.com/drive/1_X19zdMegmAlqPAEY0Ao659fzzzlx2IZ?usp=sharing) is now online - free GPU & simple operation!
@@ -29,16 +33,17 @@ Keras version implemented by Yuka Kihara can be found [here](https://github.com/
1. Create environment with conda: 1. Create environment with conda:
``` ```
conda create -n retfound python=3.7.5 -y conda create -n retfound python=3.11.0 -y
conda activate retfound conda activate retfound
``` ```
2. Install dependencies 2. Install dependencies
``` ```
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=12.1 -c pytorch -c nvidia
git clone https://github.com/rmaphoh/RETFound_MAE/ git clone https://github.com/rmaphoh/RETFound_MAE/
cd RETFound_MAE cd RETFound_MAE
pip install -r requirement.txt pip install -r requirements.txt
``` ```
@@ -46,23 +51,51 @@ pip install -r requirement.txt
To fine tune RETFound on your own data, follow these steps: To fine tune RETFound on your own data, follow these steps:
1. Download the RETFound pre-trained weights 1. Get access to the pre-trained models on HuggingFace (register an account and fill in the form) and go to step 2:
<table><tbody> <table><tbody>
<!-- START TABLE --> <!-- START TABLE -->
<!-- TABLE HEADER --> <!-- TABLE HEADER -->
<th valign="bottom"></th> <th valign="bottom"></th>
<th valign="bottom">ViT-Large</th> <th valign="bottom">ViT-Large</th>
<th valign="bottom">Source</th>
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">Colour fundus image</td> <tr><td align="left">RETFound_mae_natureCFP</td>
<td align="center"><a href="https://drive.google.com/file/d/1l62zbWUFTlp214SvK6eMwPQZAzcwoeBE/view?usp=sharing">download</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_natureCFP">access</a></td>
<td align="center"><a href="https://www.nature.com/articles/s41586-023-06555-x">Nature RETFound paper</a></td>
</tr> </tr>
<!-- TABLE BODY --> <!-- TABLE BODY -->
<tr><td align="left">OCT</td> <tr><td align="left">RETFound_mae_natureOCT</td>
<td align="center"><a href="https://drive.google.com/file/d/1m6s7QYkjyjJDlpEuXm7Xp3PmjN-elfW2/view?usp=sharing">download</a></td> <td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_natureOCT">access</a></td>
<td align="center"><a href="https://www.nature.com/articles/s41586-023-06555-x">Nature RETFound paper</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_mae_meh</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_meh">access</a></td>
<td align="center">TBD</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_mae_shanghai</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_mae_shanghai">access</a></td>
<td align="center">TBD</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_dinov2_meh</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_meh">access</a></td>
<td align="center">TBD</a></td>
</tr>
<!-- TABLE BODY -->
<tr><td align="left">RETFound_dinov2_shanghai</td>
<td align="center"><a href="https://huggingface.co/YukunZhou/RETFound_dinov2_shanghai">download</a></td>
<td align="center">TBD</a></td>
</tr> </tr>
</tbody></table> </tbody></table>
2. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md)) 2. Login in your HuggingFace account, where HuggingFace token can be [created and copied](https://huggingface.co/settings/tokens).
```
huggingface-cli login --token YOUR_HUGGINGFACE_TOKEN
```
3. Organise your data into this directory structure (Public datasets used in this study can be [downloaded here](BENCHMARK.md))
``` ```
├── data folder ├── data folder
@@ -80,23 +113,29 @@ To fine tune RETFound on your own data, follow these steps:
├──class_c ├──class_c
``` ```
3. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be run after training. 4. Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be automatically run after training.
``` ```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \ model can be "RETFound_mae" or "RETFound_dinov2"
```
```
finetune can be "RETFound_mae_natureOCT", "RETFound_mae_natureCFP", "RETFound_mae_meh", "RETFound_mae_shanghai", "RETFound_dinov2_meh", and "RETFound_dinov2_shanghai".
```
```
torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
--model RETFound_mae \
--savemodel \
--global_pool \
--batch_size 16 \ --batch_size 16 \
--world_size 1 \ --world_size 1 \
--model vit_large_patch16 \ --epochs 100 \
--epochs 50 \
--blr 5e-3 --layer_decay 0.65 \ --blr 5e-3 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.2 \ --weight_decay 0.05 --drop_path 0.2 \
--nb_classes 5 \ --nb_classes 5 \
--data_path ./IDRiD_data/ \ --data_path ./IDRiD \
--task ./finetune_IDRiD/ \ --input_size 224 \
--finetune ./RETFound_cfp_weights.pth \ --task RETFound_mae_meh-IDRiD \
--input_size 224 --finetune RETFound_mae_meh
``` ```
@@ -104,64 +143,32 @@ python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_f
``` ```
python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py \ torchrun --nproc_per_node=1 --master_port=48798 main_finetune.py \
--eval --batch_size 16 \ --model RETFound_mae \
--savemodel \
--eval \
--global_pool \
--batch_size 16 \
--world_size 1 \ --world_size 1 \
--model vit_large_patch16 \ --epochs 100 \
--epochs 50 \
--blr 5e-3 --layer_decay 0.65 \ --blr 5e-3 --layer_decay 0.65 \
--weight_decay 0.05 --drop_path 0.2 \ --weight_decay 0.05 --drop_path 0.2 \
--nb_classes 5 \ --nb_classes 5 \
--data_path ./IDRiD_data/ \ --data_path ./IDRiD \
--task ./internal_IDRiD/ \ --input_size 224 \
--resume ./finetune_IDRiD/checkpoint-best.pth \ --task RETFound_mae_meh-IDRiD \
--input_size 224 --resume ./finetune_IDRiD/checkpoint-best.pth
```
### Load the model and weights (if you want to call the model in your code)
```python
import torch
import models_vit
from util.pos_embed import interpolate_pos_embed
from timm.models.layers import trunc_normal_
# call the model
model = models_vit.__dict__['vit_large_patch16'](
num_classes=2,
drop_path_rate=0.2,
global_pool=True,
)
# load RETFound weights
checkpoint = torch.load('RETFound_cfp_weights.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
# manually initialize fc layer
trunc_normal_(model.head.weight, std=2e-5)
print("Model = %s" % str(model))
``` ```
### 📃Citation ### 📃Citation
If you find this repository useful, please consider citing this paper: If you find this repository useful, please consider citing this paper:
```
TBD
```
``` ```
@article{zhou2023foundation, @article{zhou2023foundation,
title={A foundation model for generalizable disease detection from retinal images}, title={A foundation model for generalizable disease detection from retinal images},
-210
View File
@@ -1,210 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1eae7403-f458-4f55-a557-4e045bd6f679",
"metadata": {
"id": "1eae7403-f458-4f55-a557-4e045bd6f679"
},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import models_vit"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4573e6be-935a-4106-8c06-e467552b0e3d",
"metadata": {
"id": "4573e6be-935a-4106-8c06-e467552b0e3d"
},
"outputs": [],
"source": [
"\n",
"imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
"imagenet_std = np.array([0.229, 0.224, 0.225])\n",
"\n",
"\n",
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
" # build model\n",
" model = models_vit.__dict__[arch](\n",
" img_size=224,\n",
" num_classes=5,\n",
" drop_path_rate=0,\n",
" global_pool=True,\n",
" )\n",
" # load model\n",
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
" msg = model.load_state_dict(checkpoint['model'], strict=False)\n",
" return model\n",
"\n",
"def run_one_image(img, model):\n",
" \n",
" x = torch.tensor(img)\n",
" x = x.unsqueeze(dim=0)\n",
" x = torch.einsum('nhwc->nchw', x)\n",
" \n",
" x = x.to(device, non_blocking=True)\n",
" latent = model.forward_features(x.float())\n",
" latent = torch.squeeze(latent)\n",
" \n",
" return latent\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "8b7e691d-93d2-439f-91d6-c22716a897b5",
"metadata": {
"id": "8b7e691d-93d2-439f-91d6-c22716a897b5"
},
"source": [
"### Load a pre-trained model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
"metadata": {
"id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
"outputId": "a1f0dba1-2cae-484b-ad84-8b00bc7628aa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model loaded.\n"
]
}
],
"source": [
"# download pre-trained RETFound \n",
"\n",
"chkpt_dir = './RETFound_cfp.pth'\n",
"model_ = prepare_model(chkpt_dir, 'vit_large_patch16')\n",
"\n",
"device = torch.device('cuda')\n",
"model_.to(device)\n",
"print('Model loaded.')\n"
]
},
{
"cell_type": "markdown",
"id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6",
"metadata": {
"id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6"
},
"source": [
"### Load images and save latent feature"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "27755296-05cc-4344-90de-a8ab3878f485",
"metadata": {
"id": "27755296-05cc-4344-90de-a8ab3878f485",
"outputId": "34c3c12a-0a17-44fe-b72a-cef6eecabc70",
"tags": []
},
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'Your data path'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_16866/3238108902.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# get image list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdata_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'Your data path'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mimg_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlistdir\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mname_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'Your data path'"
]
}
],
"source": [
"# get image list\n",
"data_path = 'Your data path'\n",
"img_list = os.listdir(data_path)\n",
"\n",
"name_list = []\n",
"feature_list = []\n",
"model_.eval()\n",
"\n",
"for i in img_list:\n",
" img = Image.open(os.path.join(data_path, i))\n",
" img = img.resize((224, 224))\n",
" img = np.array(img) / 255.\n",
"\n",
" assert img.shape == (224, 224, 3)\n",
"\n",
" # normalize by mean and sd\n",
" # can use customised mean and sd for your data\n",
" img = img - imagenet_mean\n",
" img = img / imagenet_std\n",
" \n",
" latent_feature = run_one_image(img, model_)\n",
" \n",
" name_list.append(i)\n",
" feature_list.append(latent_feature.detach().cpu().numpy())\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a365ec24-8e29-485e-83b5-5ac1d02945bb",
"metadata": {},
"outputs": [],
"source": [
"latent_csv = pd.DataFrame({'Name':name_list, 'Latent_feature':feature_list})\n",
"latent_csv.to_csv('Feature_latent.csv', index = False, encoding='utf8')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2e8bd5e6-5780-420d-9d4c-96025b265668",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"provenance": []
},
"environment": {
"kernel": "python3",
"name": "common-cu110.m91",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
File diff suppressed because one or more lines are too long
+79 -139
View File
@@ -1,115 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------
import math
import sys
import csv
import os import os
import csv
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Iterable, Optional
from timm.data import Mixup from timm.data import Mixup
from timm.utils import accuracy from timm.utils import accuracy
from typing import Iterable, Optional from sklearn.metrics import (
accuracy_score, roc_auc_score, f1_score, average_precision_score,
hamming_loss, jaccard_score, recall_score, precision_score, cohen_kappa_score
)
from pycm import ConfusionMatrix
import util.misc as misc import util.misc as misc
import util.lr_sched as lr_sched import util.lr_sched as lr_sched
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, average_precision_score,multilabel_confusion_matrix
from pycm import *
import matplotlib.pyplot as plt
import numpy as np
def train_one_epoch(
model: torch.nn.Module,
criterion: torch.nn.Module,
def misc_measures(confusion_matrix): data_loader: Iterable,
optimizer: torch.optim.Optimizer,
acc = [] device: torch.device,
sensitivity = [] epoch: int,
specificity = [] loss_scaler,
precision = [] max_norm: float = 0,
G = [] mixup_fn: Optional[Mixup] = None,
F1_score_2 = [] log_writer=None,
mcc_ = [] args=None
):
for i in range(1, confusion_matrix.shape[0]): """Train the model for one epoch."""
cm1=confusion_matrix[i]
acc.append(1.*(cm1[0,0]+cm1[1,1])/np.sum(cm1))
sensitivity_ = 1.*cm1[1,1]/(cm1[1,0]+cm1[1,1])
sensitivity.append(sensitivity_)
specificity_ = 1.*cm1[0,0]/(cm1[0,1]+cm1[0,0])
specificity.append(specificity_)
precision_ = 1.*cm1[1,1]/(cm1[1,1]+cm1[0,1])
precision.append(precision_)
G.append(np.sqrt(sensitivity_*specificity_))
F1_score_2.append(2*precision_*sensitivity_/(precision_+sensitivity_))
mcc = (cm1[0,0]*cm1[1,1]-cm1[0,1]*cm1[1,0])/np.sqrt((cm1[0,0]+cm1[0,1])*(cm1[0,0]+cm1[1,0])*(cm1[1,1]+cm1[1,0])*(cm1[1,1]+cm1[0,1]))
mcc_.append(mcc)
acc = np.array(acc).mean()
sensitivity = np.array(sensitivity).mean()
specificity = np.array(specificity).mean()
precision = np.array(precision).mean()
G = np.array(G).mean()
F1_score_2 = np.array(F1_score_2).mean()
mcc_ = np.array(mcc_).mean()
return acc, sensitivity, specificity, precision, G, F1_score_2, mcc_
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
mixup_fn: Optional[Mixup] = None, log_writer=None,
args=None):
model.train(True) model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ") metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch) print_freq, accum_iter = 20, args.accum_iter
print_freq = 20
accum_iter = args.accum_iter
optimizer.zero_grad() optimizer.zero_grad()
if log_writer is not None: if log_writer:
print('log_dir: {}'.format(log_writer.log_dir)) print(f'log_dir: {log_writer.log_dir}')
for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, f'Epoch: [{epoch}]')):
# we use a per iteration (instead of per epoch) lr scheduler
if data_iter_step % accum_iter == 0: if data_iter_step % accum_iter == 0:
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
samples = samples.to(device, non_blocking=True) samples, targets = samples.to(device, non_blocking=True), targets.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True) if mixup_fn:
if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets) samples, targets = mixup_fn(samples, targets)
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
outputs = model(samples) outputs = model(samples)
loss = criterion(outputs, targets) loss = criterion(outputs, targets)
loss_value = loss.item() loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
loss /= accum_iter loss /= accum_iter
loss_scaler(loss, optimizer, clip_grad=max_norm,
parameters=model.parameters(), create_graph=False, loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=False,
update_grad=(data_iter_step + 1) % accum_iter == 0) update_grad=(data_iter_step + 1) % accum_iter == 0)
if (data_iter_step + 1) % accum_iter == 0: if (data_iter_step + 1) % accum_iter == 0:
optimizer.zero_grad() optimizer.zero_grad()
torch.cuda.synchronize() torch.cuda.synchronize()
metric_logger.update(loss=loss_value) metric_logger.update(loss=loss_value)
min_lr = 10. min_lr = 10.
max_lr = 0. max_lr = 0.
@@ -125,84 +74,75 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
This calibrates different curves when batch size changes. This calibrates different curves when batch size changes.
""" """
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) log_writer.add_scalar('loss/train', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', max_lr, epoch_1000x) log_writer.add_scalar('lr', max_lr, epoch_1000x)
# gather the stats from all processes
metric_logger.synchronize_between_processes() metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger) print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()} return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad() @torch.no_grad()
def evaluate(data_loader, model, device, task, epoch, mode, num_class): def evaluate(data_loader, model, device, args, epoch, mode, num_class, log_writer):
criterion = torch.nn.CrossEntropyLoss() """Evaluate the model."""
criterion = nn.CrossEntropyLoss()
metric_logger = misc.MetricLogger(delimiter=" ") metric_logger = misc.MetricLogger(delimiter=" ")
header = 'Test:' os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
if not os.path.exists(task):
os.makedirs(task)
prediction_decode_list = []
prediction_list = []
true_label_decode_list = []
true_label_onehot_list = []
# switch to evaluation mode
model.eval() model.eval()
true_onehot, pred_onehot, true_labels, pred_labels, pred_softmax = [], [], [], [], []
for batch in metric_logger.log_every(data_loader, 10, header): for batch in metric_logger.log_every(data_loader, 10, f'{mode}:'):
images = batch[0] images, target = batch[0].to(device, non_blocking=True), batch[-1].to(device, non_blocking=True)
target = batch[-1] target_onehot = F.one_hot(target.to(torch.int64), num_classes=num_class)
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
true_label=F.one_hot(target.to(torch.int64), num_classes=num_class)
# compute output
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
output = model(images) output = model(images)
loss = criterion(output, target) loss = criterion(output, target)
prediction_softmax = nn.Softmax(dim=1)(output) output_ = nn.Softmax(dim=1)(output)
_,prediction_decode = torch.max(prediction_softmax, 1) output_label = output_.argmax(dim=1)
_,true_label_decode = torch.max(true_label, 1) output_onehot = F.one_hot(output_label.to(torch.int64), num_classes=num_class)
prediction_decode_list.extend(prediction_decode.cpu().detach().numpy())
true_label_decode_list.extend(true_label_decode.cpu().detach().numpy())
true_label_onehot_list.extend(true_label.cpu().detach().numpy())
prediction_list.extend(prediction_softmax.cpu().detach().numpy())
acc1,_ = accuracy(output, target, topk=(1,2))
batch_size = images.shape[0]
metric_logger.update(loss=loss.item()) metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) true_onehot.extend(target_onehot.cpu().numpy())
# gather the stats from all processes pred_onehot.extend(output_onehot.detach().cpu().numpy())
true_label_decode_list = np.array(true_label_decode_list) true_labels.extend(target.cpu().numpy())
prediction_decode_list = np.array(prediction_decode_list) pred_labels.extend(output_label.detach().cpu().numpy())
confusion_matrix = multilabel_confusion_matrix(true_label_decode_list, prediction_decode_list,labels=[i for i in range(num_class)]) pred_softmax.extend(output_.detach().cpu().numpy())
acc, sensitivity, specificity, precision, G, F1, mcc = misc_measures(confusion_matrix)
auc_roc = roc_auc_score(true_label_onehot_list, prediction_list,multi_class='ovr',average='macro') accuracy = accuracy_score(true_labels, pred_labels)
auc_pr = average_precision_score(true_label_onehot_list, prediction_list,average='macro') hamming = hamming_loss(true_onehot, pred_onehot)
jaccard = jaccard_score(true_onehot, pred_onehot, average='macro')
average_precision = average_precision_score(true_onehot, pred_softmax, average='macro')
kappa = cohen_kappa_score(true_labels, pred_labels)
f1 = f1_score(true_onehot, pred_onehot, zero_division=0, average='macro')
roc_auc = roc_auc_score(true_onehot, pred_softmax, multi_class='ovr', average='macro')
precision = precision_score(true_onehot, pred_onehot, zero_division=0, average='macro')
recall = recall_score(true_onehot, pred_onehot, zero_division=0, average='macro')
score = (f1 + roc_auc + kappa) / 3
if log_writer:
for metric_name, value in zip(['accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa', 'score'],
[accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa, score]):
log_writer.add_scalar(f'perf/{metric_name}', value, epoch)
print(f'val loss: {metric_logger.meters["loss"].global_avg}')
print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}, Hamming Loss: {hamming:.4f},\n'
f' Jaccard Score: {jaccard:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f},\n'
f' Average Precision: {average_precision:.4f}, Kappa: {kappa:.4f}, Score: {score:.4f}')
metric_logger.synchronize_between_processes() metric_logger.synchronize_between_processes()
print('Sklearn Metrics - Acc: {:.4f} AUC-roc: {:.4f} AUC-pr: {:.4f} F1-score: {:.4f} MCC: {:.4f}'.format(acc, auc_roc, auc_pr, F1, mcc)) results_path = os.path.join(args.output_dir, args.task, f'metrics_{mode}.csv')
results_path = task+'_metrics_{}.csv'.format(mode) file_exists = os.path.isfile(results_path)
with open(results_path,mode='a',newline='',encoding='utf8') as cfa: with open(results_path, 'a', newline='', encoding='utf8') as cfa:
wf = csv.writer(cfa) wf = csv.writer(cfa)
data2=[[acc,sensitivity,specificity,precision,auc_roc,auc_pr,F1,mcc,metric_logger.loss]] if not file_exists:
for i in data2: wf.writerow(['val_loss', 'accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa'])
wf.writerow(i) wf.writerow([metric_logger.meters["loss"].global_avg, accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa])
if mode == 'test':
cm = ConfusionMatrix(actual_vector=true_labels, predict_vector=pred_labels)
cm.plot(cmap=plt.cm.Blues, number_label=True, normalized=True, plot_lib="matplotlib")
plt.savefig(os.path.join(args.output_dir, args.task, 'confusion_matrix_test.jpg'), dpi=600, bbox_inches='tight')
if mode=='test': return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, score
cm = ConfusionMatrix(actual_vector=true_label_decode_list, predict_vector=prediction_decode_list)
cm.plot(cmap=plt.cm.Blues,number_label=True,normalized=True,plot_lib="matplotlib")
plt.savefig(task+'confusion_matrix_test.jpg',dpi=600,bbox_inches ='tight')
return {k: meter.global_avg for k, meter in metric_logger.meters.items()},auc_roc
-82
View File
@@ -1,82 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import sys
from typing import Iterable
import torch
import util.misc as misc
import util.lr_sched as lr_sched
def train_one_epoch(model: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,
log_writer=None,
args=None):
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 20
accum_iter = args.accum_iter
optimizer.zero_grad()
if log_writer is not None:
print('log_dir: {}'.format(log_writer.log_dir))
for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
# we use a per iteration (instead of per epoch) lr scheduler
if data_iter_step % accum_iter == 0:
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
samples = samples.to(device, non_blocking=True)
with torch.cuda.amp.autocast():
loss, _, _ = model(samples, mask_ratio=args.mask_ratio)
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
loss /= accum_iter
loss_scaler(loss, optimizer, parameters=model.parameters(),
update_grad=(data_iter_step + 1) % accum_iter == 0)
if (data_iter_step + 1) % accum_iter == 0:
optimizer.zero_grad()
torch.cuda.synchronize()
metric_logger.update(loss=loss_value)
lr = optimizer.param_groups[0]["lr"]
metric_logger.update(lr=lr)
loss_value_reduce = misc.all_reduce_mean(loss_value)
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
""" We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
"""
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', lr, epoch_1000x)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+194
View File
@@ -0,0 +1,194 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 14,
"id": "0ae19951",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"import pandas as pd\n",
"from PIL import Image\n",
"import models_vit as models\n",
"np.set_printoptions(threshold=np.inf)\n",
"np.random.seed(1)\n",
"torch.manual_seed(1)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "90c3d964",
"metadata": {},
"outputs": [],
"source": [
"def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
" \n",
" # load model\n",
" checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
" \n",
" # build model\n",
" if arch=='vit_large_patch16':\n",
" model = models.__dict__[arch](\n",
" img_size=224,\n",
" num_classes=5,\n",
" drop_path_rate=0,\n",
" global_pool=True,\n",
" )\n",
" msg = model.load_state_dict(checkpoint['model'], strict=False)\n",
" else:\n",
" model = models.__dict__[arch](\n",
" num_classes=5,\n",
" drop_path_rate=0,\n",
" args=None,\n",
" )\n",
" msg = model.load_state_dict(checkpoint['teacher'], strict=False)\n",
" return model\n",
"\n",
"def run_one_image(img, model, arch):\n",
" \n",
" x = torch.tensor(img)\n",
" x = x.unsqueeze(dim=0)\n",
" x = torch.einsum('nhwc->nchw', x)\n",
" \n",
" x = x.to(device, non_blocking=True)\n",
" latent = model.forward_features(x.float())\n",
" \n",
" if arch=='dinov2_large':\n",
" latent = latent[:, 1:, :].mean(dim=1,keepdim=True)\n",
" latent = nn.LayerNorm(latent.shape[-1], eps=1e-6).to(device)(latent)\n",
" \n",
" latent = torch.squeeze(latent)\n",
"\n",
" return latent\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "9a250363",
"metadata": {},
"outputs": [],
"source": [
"def get_feature(data_path,\n",
" chkpt_dir,\n",
" device,\n",
" arch='vit_large_patch16'):\n",
" #loading model\n",
" model_ = prepare_model(chkpt_dir, arch)\n",
" model_.to(device)\n",
"\n",
" img_list = os.listdir(data_path)\n",
" \n",
" name_list = []\n",
" feature_list = []\n",
" model_.eval()\n",
" \n",
" finished_num = 0\n",
" for i in img_list:\n",
" finished_num+=1\n",
" if (finished_num%1000 == 0):\n",
" print(str(finished_num)+\"finished\")\n",
" \n",
" img = Image.open(os.path.join(data_path, i))\n",
" img = img.resize((224, 224))\n",
" img = np.array(img) / 255.\n",
" img[...,0] = (img[...,0] - img[...,0].mean())/img[...,0].std()\n",
" img[...,1] = (img[...,1] - img[...,1].mean())/img[...,1].std()\n",
" img[...,2] = (img[...,2] - img[...,2].mean())/img[...,2].std()\n",
" assert img.shape == (224, 224, 3)\n",
" \n",
" latent_feature = run_one_image(img, model_,arch)\n",
" \n",
" name_list.append(i)\n",
" feature_list.append(latent_feature.detach().cpu().numpy())\n",
" \n",
" return [name_list,feature_list]\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "54acfcd7",
"metadata": {},
"outputs": [],
"source": [
"chkpt_dir = '/home/jupyter/huggingface_repo/RETFound_dinov2_meh.pth'\n",
"data_path = '/home/jupyter/public_dataset/IDRiD_data/val/anoDR'\n",
"device = torch.device('cuda')"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "0296f74e",
"metadata": {},
"outputs": [],
"source": [
"[name_list,feature]=get_feature(data_path,\n",
" chkpt_dir,\n",
" device,\n",
" arch='dinov2_large')"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "925d3994",
"metadata": {},
"outputs": [],
"source": [
"#save the feature\n",
"df_feature = pd.DataFrame(feature)\n",
"df_imgname = pd.DataFrame(name_list)\n",
"df_visualization = pd.concat([df_imgname,df_feature], axis=1)\n",
"column_name_list = []\n",
"\n",
"for i in range(1024):\n",
" column_name_list.append(\"feature_{}\".format(i))\n",
"df_visualization.columns = [\"name\"] + column_name_list\n",
"df_visualization.to_csv(\"Feature.csv\",index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f0d13a7-2b46-40eb-ab48-5f90a6aeecb5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"kernel": "test",
"name": "common-cu121.m123",
"type": "gcloud",
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/base-cu121:m123"
},
"kernelspec": {
"display_name": "Python_test (Local)",
"language": "python",
"name": "test"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+135 -102
View File
@@ -1,11 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Partly revised by YZ @UCL&Moorfields
# --------------------------------------------------------
import argparse import argparse
import datetime import datetime
import json import json
import numpy as np import numpy as np
import os import os
import time import time
@@ -14,28 +10,28 @@ from pathlib import Path
import torch import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
import timm
assert timm.__version__ == "0.3.2" # version check
from timm.models.layers import trunc_normal_ from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup from timm.data.mixup import Mixup
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
import models_vit as models
import util.lr_decay as lrd import util.lr_decay as lrd
import util.misc as misc import util.misc as misc
from util.datasets import build_dataset from util.datasets import build_dataset
from util.pos_embed import interpolate_pos_embed from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler from util.misc import NativeScalerWithGradNormCount as NativeScaler
from huggingface_hub import hf_hub_download, login
import models_vit
from engine_finetune import train_one_epoch, evaluate from engine_finetune import train_one_epoch, evaluate
import warnings
import faulthandler
faulthandler.enable()
warnings.simplefilter(action='ignore', category=FutureWarning)
def get_args_parser(): def get_args_parser():
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
parser.add_argument('--batch_size', default=64, type=int, parser.add_argument('--batch_size', default=128, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
parser.add_argument('--epochs', default=50, type=int) parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--accum_iter', default=1, type=int, parser.add_argument('--accum_iter', default=1, type=int,
@@ -44,11 +40,9 @@ def get_args_parser():
# Model parameters # Model parameters
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
help='Name of model to train') help='Name of model to train')
parser.add_argument('--input_size', default=256, type=int,
parser.add_argument('--input_size', default=224, type=int,
help='images input size') help='images input size')
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
help='Drop path rate (default: 0.1)') help='Drop path rate (default: 0.1)')
# Optimizer parameters # Optimizer parameters
@@ -56,17 +50,14 @@ def get_args_parser():
help='Clip gradient norm (default: None, no clipping)') help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--weight_decay', type=float, default=0.05, parser.add_argument('--weight_decay', type=float, default=0.05,
help='weight decay (default: 0.05)') help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=None, metavar='LR', parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)') help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--layer_decay', type=float, default=0.75, parser.add_argument('--layer_decay', type=float, default=0.65,
help='layer-wise lr decay from ELECTRA/BEiT') help='layer-wise lr decay from ELECTRA/BEiT')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0') help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
help='epochs to warmup LR') help='epochs to warmup LR')
@@ -103,9 +94,9 @@ def get_args_parser():
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# * Finetuning params # * Finetuning params
parser.add_argument('--finetune', default='',type=str, parser.add_argument('--finetune', default='', type=str,
help='finetune from checkpoint') help='finetune from checkpoint')
parser.add_argument('--task', default='',type=str, parser.add_argument('--task', default='', type=str,
help='finetune from checkpoint') help='finetune from checkpoint')
parser.add_argument('--global_pool', action='store_true') parser.add_argument('--global_pool', action='store_true')
parser.set_defaults(global_pool=True) parser.set_defaults(global_pool=True)
@@ -113,21 +104,19 @@ def get_args_parser():
help='Use class token instead of global pool for classification') help='Use class token instead of global pool for classification')
# Dataset parameters # Dataset parameters
parser.add_argument('--data_path', default='/home/jupyter/Mor_DR_data/data/data/IDRID/Disease_Grading/', type=str, parser.add_argument('--data_path', default='./data/', type=str,
help='dataset path') help='dataset path')
parser.add_argument('--nb_classes', default=1000, type=int, parser.add_argument('--nb_classes', default=8, type=int,
help='number of the classification types') help='number of the classification types')
parser.add_argument('--output_dir', default='./output_dir', parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving') help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_dir', parser.add_argument('--log_dir', default='./output_logs',
help='path where to tensorboard log') help='path where to tensorboard log')
parser.add_argument('--device', default='cuda', parser.add_argument('--device', default='cuda',
help='device to use for training / testing') help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int) parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', parser.add_argument('--resume', default='',
help='resume from checkpoint') help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch') help='start epoch')
parser.add_argument('--eval', action='store_true', parser.add_argument('--eval', action='store_true',
@@ -137,7 +126,6 @@ def get_args_parser():
parser.add_argument('--num_workers', default=10, type=int) parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true', parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True) parser.set_defaults(pin_mem=True)
# distributed training parameters # distributed training parameters
@@ -148,10 +136,24 @@ def get_args_parser():
parser.add_argument('--dist_url', default='env://', parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training') help='url used to set up distributed training')
# fine-tuning parameters
parser.add_argument('--savemodel', action='store_true', default=True,
help='Save model')
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
parser.add_argument('--datasets_seed', default=2026, type=int)
return parser return parser
def main(args): def main(args, criterion):
if args.resume and not args.eval:
resume = args.resume
checkpoint = torch.load(args.resume, map_location='cpu')
print("Load checkpoint from: %s" % args.resume)
args = checkpoint['args']
args.resume = resume
misc.init_distributed_mode(args) misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
@@ -166,24 +168,77 @@ def main(args):
cudnn.benchmark = True cudnn.benchmark = True
if args.model=='RETFound_mae':
model = models.__dict__[args.model](
img_size=args.input_size,
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
global_pool=args.global_pool,
)
else:
model = models.__dict__[args.model](
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
args=args,
)
if args.finetune and not args.eval:
print(f"Downloading pre-trained weights from: {args.finetune}")
checkpoint_path = hf_hub_download(
repo_id=f'YukunZhou/{args.finetune}',
filename=f'{args.finetune}.pth',
)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % args.finetune)
if args.model!='RETFound_mae':
checkpoint_model = checkpoint['teacher']
else:
checkpoint_model = checkpoint['model']
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
trunc_normal_(model.head.weight, std=2e-5)
dataset_train = build_dataset(is_train='train', args=args) dataset_train = build_dataset(is_train='train', args=args)
dataset_val = build_dataset(is_train='val', args=args) dataset_val = build_dataset(is_train='val', args=args)
dataset_test = build_dataset(is_train='test', args=args) dataset_test = build_dataset(is_train='test', args=args)
if True: # args.distributed: if True: # args.distributed:
num_tasks = misc.get_world_size() num_tasks = misc.get_world_size()
global_rank = misc.get_rank() global_rank = misc.get_rank()
if not args.eval:
sampler_train = torch.utils.data.DistributedSampler( sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
) )
print("Sampler_train = %s" % str(sampler_train)) print("Sampler_train = %s" % str(sampler_train))
if args.dist_eval: if args.dist_eval:
if len(dataset_val) % num_tasks != 0: if len(dataset_val) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' print(
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.') 'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler( sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias dataset_val, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else: else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val) sampler_val = torch.utils.data.SequentialSampler(dataset_val)
@@ -193,17 +248,18 @@ def main(args):
'This will slightly alter validation results as extra duplicate entries are added to achieve ' 'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.') 'equal num of samples per-process.')
sampler_test = torch.utils.data.DistributedSampler( sampler_test = torch.utils.data.DistributedSampler(
dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias dataset_test, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else: else:
sampler_test = torch.utils.data.SequentialSampler(dataset_test) sampler_test = torch.utils.data.SequentialSampler(dataset_test)
if global_rank == 0 and args.log_dir is not None and not args.eval: if global_rank == 0 and args.log_dir is not None and not args.eval:
os.makedirs(args.log_dir, exist_ok=True) os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=args.log_dir+args.task) log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
else: else:
log_writer = None log_writer = None
if not args.eval:
data_loader_train = torch.utils.data.DataLoader( data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train, dataset_train, sampler=sampler_train,
batch_size=args.batch_size, batch_size=args.batch_size,
@@ -212,6 +268,8 @@ def main(args):
drop_last=True, drop_last=True,
) )
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
data_loader_val = torch.utils.data.DataLoader( data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val, dataset_val, sampler=sampler_val,
batch_size=args.batch_size, batch_size=args.batch_size,
@@ -228,7 +286,6 @@ def main(args):
drop_last=False drop_last=False
) )
mixup_fn = None mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active: if mixup_active:
@@ -238,46 +295,16 @@ def main(args):
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes) label_smoothing=args.smoothing, num_classes=args.nb_classes)
model = models_vit.__dict__[args.model]( if args.resume and args.eval:
img_size=args.input_size, checkpoint = torch.load(args.resume, map_location='cpu')
num_classes=args.nb_classes, print("Load checkpoint from: %s" % args.resume)
drop_path_rate=args.drop_path, model.load_state_dict(checkpoint['model'])
global_pool=args.global_pool,
)
if args.finetune and not args.eval:
checkpoint = torch.load(args.finetune, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % args.finetune)
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg)
if args.global_pool:
assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'}
else:
assert set(msg.missing_keys) == {'head.weight', 'head.bias'}
# manually initialize fc layer
trunc_normal_(model.head.weight, std=2e-5)
model.to(device) model.to(device)
model_without_ddp = model model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Model = %s" % str(model_without_ddp)) n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params (M): %.2f' % (n_parameters / 1.e6)) print('number of model params (M): %.2f' % (n_parameters / 1.e6))
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
@@ -294,37 +321,33 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module model_without_ddp = model.module
# build optimizer with layer-wise lr decay (lrd) no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
no_weight_decay_list=model_without_ddp.no_weight_decay(), no_weight_decay_list=no_weight_decay,
layer_decay=args.layer_decay layer_decay=args.layer_decay
) )
optimizer = torch.optim.AdamW(param_groups, lr=args.lr) optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = NativeScaler() loss_scaler = NativeScaler()
if mixup_fn is not None:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif args.smoothing > 0.:
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
criterion = torch.nn.CrossEntropyLoss()
print("criterion = %s" % str(criterion)) print("criterion = %s" % str(criterion))
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
if args.eval: if args.eval:
test_stats,auc_roc = evaluate(data_loader_test, model, device, args.task, epoch=0, mode='test',num_class=args.nb_classes) if 'epoch' in checkpoint:
print("Test with the best model at epoch = %d" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
num_class=args.nb_classes, log_writer=log_writer)
exit(0) exit(0)
print(f"Start training for {args.epochs} epochs") print(f"Start training for {args.epochs} epochs")
start_time = time.time() start_time = time.time()
max_accuracy = 0.0 max_score = 0.0
max_auc = 0.0 best_epoch = 0
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
if args.distributed: if args.distributed:
data_loader_train.sampler.set_epoch(epoch) data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch( train_stats = train_one_epoch(
model, criterion, data_loader_train, model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler, optimizer, device, epoch, loss_scaler,
@@ -333,19 +356,28 @@ def main(args):
args=args args=args
) )
val_stats,val_auc_roc = evaluate(data_loader_val, model, device,args.task,epoch, mode='val',num_class=args.nb_classes) val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
if max_auc<val_auc_roc: num_class=args.nb_classes, log_writer=log_writer)
max_auc = val_auc_roc if max_score < val_score:
max_score = val_score
if args.output_dir: best_epoch = epoch
if args.output_dir and args.savemodel:
misc.save_model( misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch) loss_scaler=loss_scaler, epoch=epoch, mode='best')
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
if epoch == (args.epochs - 1):
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=False)
model.to(device)
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
num_class=args.nb_classes, log_writer=None)
if log_writer is not None: if log_writer is not None:
log_writer.add_scalar('perf/val_acc1', val_stats['acc1'], epoch) log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
log_writer.add_scalar('perf/val_auc', val_auc_roc, epoch)
log_writer.add_scalar('perf/val_loss', val_stats['loss'], epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch, 'epoch': epoch,
@@ -354,21 +386,22 @@ def main(args):
if args.output_dir and misc.is_main_process(): if args.output_dir and misc.is_main_process():
if log_writer is not None: if log_writer is not None:
log_writer.flush() log_writer.flush()
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n") f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time))) total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str)) print('Training time {}'.format(total_time_str))
state_dict_best = torch.load(args.task+'checkpoint-best.pth', map_location='cpu')
model_without_ddp.load_state_dict(state_dict_best['model'])
test_stats,auc_roc = evaluate(data_loader_test, model_without_ddp, device,args.task,epoch=0, mode='test',num_class=args.nb_classes)
if __name__ == '__main__': if __name__ == '__main__':
args = get_args_parser() args = get_args_parser()
args = args.parse_args() args = args.parse_args()
criterion = torch.nn.CrossEntropyLoss()
if args.output_dir: if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True) Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args) main(args, criterion)
+414
View File
@@ -0,0 +1,414 @@
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from timm.models.layers import trunc_normal_
from timm.data.mixup import Mixup
import models_vit as models
import util.lr_decay as lrd
import util.misc as misc
from util.datasets import build_dataset
from util.pos_embed import interpolate_pos_embed
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from huggingface_hub import hf_hub_download, login
from engine_finetune import train_one_epoch, evaluate
import warnings
import faulthandler
faulthandler.enable()
warnings.simplefilter(action='ignore', category=FutureWarning)
def get_args_parser():
parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False)
parser.add_argument('--batch_size', default=128, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters
parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input_size', default=256, type=int,
help='images input size')
parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
help='Drop path rate (default: 0.1)')
# Optimizer parameters
parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--weight_decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=5e-3, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--layer_decay', type=float, default=0.65,
help='layer-wise lr decay from ELECTRA/BEiT')
parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
help='epochs to warmup LR')
# Augmentation parameters
parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
help='Color jitter factor (enabled only when not using Auto/RandAug)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'),
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
# * Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# * Mixup params
parser.add_argument('--mixup', type=float, default=0,
help='mixup alpha, mixup enabled if > 0.')
parser.add_argument('--cutmix', type=float, default=0,
help='cutmix alpha, cutmix enabled if > 0.')
parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup_prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup_mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# * Finetuning params
parser.add_argument('--finetune', default='', type=str,
help='finetune from checkpoint')
parser.add_argument('--task', default='', type=str,
help='finetune from checkpoint')
parser.add_argument('--global_pool', action='store_true')
parser.set_defaults(global_pool=True)
parser.add_argument('--cls_token', action='store_false', dest='global_pool',
help='Use class token instead of global pool for classification')
# Dataset parameters
parser.add_argument('--data_path', default='./data/', type=str,
help='dataset path')
parser.add_argument('--nb_classes', default=8, type=int,
help='number of the classification types')
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_logs',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--dist_eval', action='store_true', default=False,
help='Enabling distributed evaluation (recommended during training for faster monitor')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
# fine-tuning parameters
parser.add_argument('--savemodel', action='store_true', default=True,
help='Save model')
parser.add_argument('--norm', default='IMAGENET', type=str, help='Normalization method')
parser.add_argument('--enhance', action='store_true', default=False, help='Use enhanced data')
parser.add_argument('--datasets_seed', default=2026, type=int)
return parser
def main(args, criterion):
if args.resume and not args.eval:
resume = args.resume
checkpoint = torch.load(args.resume, map_location='cpu')
print("Load checkpoint from: %s" % args.resume)
args = checkpoint['args']
args.resume = resume
misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
if args.model=='RETFound_mae':
model = models.__dict__[args.model](
img_size=args.input_size,
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
global_pool=args.global_pool,
)
else:
model = models.__dict__[args.model](
num_classes=args.nb_classes,
drop_path_rate=args.drop_path,
args=args,
)
if args.finetune and not args.eval:
print(f"Downloading pre-trained weights from: {args.finetune}")
checkpoint_path = hf_hub_download(
repo_id=f'YukunZhou/{args.finetune}',
filename=f'{args.finetune}.pth',
)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print("Load pre-trained checkpoint from: %s" % args.finetune)
if args.model!='RETFound_mae':
checkpoint_model = checkpoint['teacher']
else:
checkpoint_model = checkpoint['model']
checkpoint_model = {k.replace("backbone.", ""): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w12.", "mlp.fc1."): v for k, v in checkpoint_model.items()}
checkpoint_model = {k.replace("mlp.w3.", "mlp.fc2."): v for k, v in checkpoint_model.items()}
state_dict = model.state_dict()
for k in ['head.weight', 'head.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
# interpolate position embedding
interpolate_pos_embed(model, checkpoint_model)
# load pre-trained model
msg = model.load_state_dict(checkpoint_model, strict=False)
trunc_normal_(model.head.weight, std=2e-5)
dataset_train = build_dataset(is_train='train', args=args)
dataset_val = build_dataset(is_train='val', args=args)
dataset_test = build_dataset(is_train='test', args=args)
if True: # args.distributed:
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
if not args.eval:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print("Sampler_train = %s" % str(sampler_train))
if args.dist_eval:
if len(dataset_val) % num_tasks != 0:
print(
'Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
if args.dist_eval:
if len(dataset_test) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_test = torch.utils.data.DistributedSampler(
dataset_test, num_replicas=num_tasks, rank=global_rank,
shuffle=True) # shuffle=True to reduce monitor bias
else:
sampler_test = torch.utils.data.SequentialSampler(dataset_test)
if global_rank == 0 and args.log_dir is not None and not args.eval:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.task))
else:
log_writer = None
if not args.eval:
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
print(f'len of train_set: {len(data_loader_train) * args.batch_size}')
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test, sampler=sampler_test,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
print("Mixup is activated!")
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)
if args.resume and args.eval:
checkpoint = torch.load(args.resume, map_location='cpu')
print("Load checkpoint from: %s" % args.resume)
model.load_state_dict(checkpoint['model'])
model.to(device)
model_without_ddp = model
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of model params (M): %.2f' % (n_parameters / 1.e6))
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True)
model_without_ddp = model.module
no_weight_decay = model_without_ddp.no_weight_decay() if hasattr(model_without_ddp, 'no_weight_decay') else []
param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay,
no_weight_decay_list=no_weight_decay,
layer_decay=args.layer_decay
)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
loss_scaler = NativeScaler()
print("criterion = %s" % str(criterion))
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
for name, param in model.named_parameters():
if 'head' in name:
param.requires_grad = True
else:
param.requires_grad = False
if args.eval:
if 'epoch' in checkpoint:
print("Test with the best model at epoch = %d" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, epoch=0, mode='test',
num_class=args.nb_classes, log_writer=log_writer)
exit(0)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_score = 0.0
best_epoch = 0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
args.clip_grad, mixup_fn,
log_writer=log_writer,
args=args
)
val_stats, val_score = evaluate(data_loader_val, model, device, args, epoch, mode='val',
num_class=args.nb_classes, log_writer=log_writer)
if max_score < val_score:
max_score = val_score
best_epoch = epoch
if args.output_dir and args.savemodel:
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch, mode='best')
print("Best epoch = %d, Best score = %.4f" % (best_epoch, max_score))
if epoch == (args.epochs - 1):
checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=False)
model.to(device)
print("Test with the best model, epoch = %d:" % checkpoint['epoch'])
test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',
num_class=args.nb_classes, log_writer=None)
if log_writer is not None:
log_writer.add_scalar('loss/val', val_stats['loss'], epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, args.task, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
criterion = torch.nn.CrossEntropyLoss()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args, criterion)
-221
View File
@@ -1,221 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import timm
assert timm.__version__ == "0.3.2" # version check
import timm.optim.optim_factory as optim_factory
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
import models_mae
from engine_pretrain import train_one_epoch
def get_args_parser():
parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
parser.add_argument('--batch_size', default=64, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
parser.add_argument('--epochs', default=400, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters
parser.add_argument('--model', default='mae_vit_large_patch16', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input_size', default=224, type=int,
help='images input size')
parser.add_argument('--mask_ratio', default=0.75, type=float,
help='Masking ratio (percentage of removed patches).')
parser.add_argument('--norm_pix_loss', action='store_true',
help='Use (per-patch) normalized pixels as targets for computing loss')
parser.set_defaults(norm_pix_loss=False)
# Optimizer parameters
parser.add_argument('--weight_decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
help='epochs to warmup LR')
# Dataset parameters
parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str,
help='dataset path')
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_dir',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
return parser
def main(args):
misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
# simple augmentation
transform_train = transforms.Compose([
transforms.RandomResizedCrop(args.input_size, scale=(0.2, 1.0), interpolation=3), # 3 is bicubic
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
print(dataset_train)
if True: # args.distributed:
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print("Sampler_train = %s" % str(sampler_train))
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
if global_rank == 0 and args.log_dir is not None:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=args.log_dir)
else:
log_writer = None
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
# define the model
model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
model.to(device)
model_without_ddp = model
print("Model = %s" % str(model_without_ddp))
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
model_without_ddp = model.module
# following timm: set wd as 0 for bias and norm layers
param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
print(optimizer)
loss_scaler = NativeScaler()
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, data_loader_train,
optimizer, device, epoch, loss_scaler,
log_writer=log_writer,
args=args
)
if args.output_dir and (epoch % 50 == 0 or epoch + 1 == args.epochs):
misc.save_model_pretrain(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,}
if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
-226
View File
@@ -1,226 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block
from util.pos_embed import get_2d_sincos_pos_embed
class MaskedAutoencoderViT(nn.Module):
""" Masked Autoencoder with VisionTransformer backbone
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3,
embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
super().__init__()
# --------------------------------------------------------------------------
# MAE encoder specifics
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# --------------------------------------------------------------------------
# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
self.decoder_blocks = nn.ModuleList([
Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
for i in range(decoder_depth)])
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch
# --------------------------------------------------------------------------
self.norm_pix_loss = norm_pix_loss
self.initialize_weights()
def initialize_weights(self):
# initialization
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=.02)
torch.nn.init.normal_(self.mask_token, std=.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p = self.patch_embed.patch_size[0]
assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum('nchpwq->nhwpqc', x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p = self.patch_embed.patch_size[0]
h = w = int(x.shape[1]**.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
return imgs
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def forward_encoder(self, x, mask_ratio):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore
def forward_decoder(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x)
# remove cls token
x = x[:, 1:, :]
return x
def forward_loss(self, imgs, pred, mask):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0 is keep, 1 is remove,
"""
target = self.patchify(imgs)
if self.norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.e-6)**.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
return loss
def forward(self, imgs, mask_ratio=0.75):
latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
loss = self.forward_loss(imgs, pred, mask)
return loss, pred, mask
def mae_vit_large_patch16_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16, embed_dim=1024, depth=24, num_heads=16,
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model
# set recommended archs
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
+18 -4
View File
@@ -5,10 +5,11 @@
from functools import partial from functools import partial
import timm.models.vision_transformer
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import timm.models.vision_transformer from torch import Tensor
class VisionTransformer(timm.models.vision_transformer.VisionTransformer): class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
@@ -38,7 +39,7 @@ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
x = blk(x) x = blk(x)
if self.global_pool: if self.global_pool:
x = x[:, 1:, :].mean(dim=1) # global pool without cls token x = x[:, 1:, :].mean(dim=1,keepdim=True) # global pool without cls token
outcome = self.fc_norm(x) outcome = self.fc_norm(x)
else: else:
x = self.norm(x) x = self.norm(x)
@@ -47,9 +48,22 @@ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
return outcome return outcome
def vit_large_patch16(**kwargs): def RETFound_mae(**kwargs):
model = VisionTransformer( model = VisionTransformer(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model return model
def RETFound_dinov2(args, **kwargs):
model = timm.create_model(
'vit_large_patch14_dinov2.lvd142m',
pretrained=True,
img_size=224,
**kwargs
)
return model
Binary file not shown.

Before

Width:  |  Height:  |  Size: 66 KiB

-18
View File
@@ -1,18 +0,0 @@
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.8.1+cu111
torchvision==0.9.1+cu111
torchaudio==0.8.1
opencv-python==4.5.3.56
pandas==0.25.3
Pillow==8.3.1
protobuf==3.17.3
pycm==3.2
pydicom==2.3.0
scikit-image==0.17.2
scikit-learn==0.24.2
scipy==1.5.4
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
timm==0.3.2
tqdm==4.62.1
+11
View File
@@ -0,0 +1,11 @@
opencv-python~=4.9.0.80
Pillow~=10.2.0
pycm~=4.0
scikit-learn~=1.4.2
timm~=0.9.2
numpy~=1.26.4
matplotlib~=3.8.4
scikit-multilearn~=0.2.0
huggingface-hub~=0.23.4
tensorboard~=2.17.0
View File
+1 -2
View File
@@ -10,7 +10,6 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def build_dataset(is_train, args): def build_dataset(is_train, args):
transform = build_transform(is_train, args) transform = build_transform(is_train, args)
root = os.path.join(args.data_path, is_train) root = os.path.join(args.data_path, is_train)
dataset = datasets.ImageFolder(root, transform=transform) dataset = datasets.ImageFolder(root, transform=transform)
@@ -22,7 +21,7 @@ def build_transform(is_train, args):
mean = IMAGENET_DEFAULT_MEAN mean = IMAGENET_DEFAULT_MEAN
std = IMAGENET_DEFAULT_STD std = IMAGENET_DEFAULT_STD
# train transform # train transform
if is_train=='train': if is_train == 'train':
# this should always dispatch to transforms_imagenet_train # this should always dispatch to transforms_imagenet_train
transform = create_transform( transform = create_transform(
input_size=args.input_size, input_size=args.input_size,
+4
View File
@@ -14,7 +14,11 @@ def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_de
param_group_names = {} param_group_names = {}
param_groups = {} param_groups = {}
if hasattr(model, 'blocks'):
num_layers = len(model.blocks) + 1 num_layers = len(model.blocks) + 1
else:
# use the number of layers in the ResNet model as a default value
num_layers = len(model.layer1) + len(model.layer2) + len(model.layer3) + len(model.layer4) + 1
layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
+36 -24
View File
@@ -12,7 +12,7 @@ from pathlib import Path
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch._six import inf from math import inf
class SmoothedValue(object): class SmoothedValue(object):
@@ -282,16 +282,32 @@ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
if norm_type == inf: if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else: else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
norm_type)
return total_norm return total_norm
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, mode):
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
epoch_name = str(epoch) epoch_name = str(epoch)
os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
if loss_scaler is not None: if loss_scaler is not None:
checkpoint_paths = [args.task+'checkpoint-best.pth'] if mode == 'best':
checkpoint_paths = [os.path.join(args.output_dir, args.task, 'checkpoint-best.pth')]
else:
checkpoint_paths = [os.path.join(args.output_dir, args.task, 'checkpoint-latest.pth')]
for checkpoint_path in checkpoint_paths: for checkpoint_path in checkpoint_paths:
if mode == 'best':
to_save = {
'model': model_without_ddp.state_dict(),
'epoch': epoch,
'args': args, }
else:
if epoch == args.epochs - 1:
to_save = {
'model': model_without_ddp.state_dict(),
'args': args, }
else:
to_save = { to_save = {
'model': model_without_ddp.state_dict(), 'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
@@ -302,30 +318,23 @@ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
save_on_master(to_save, checkpoint_path) save_on_master(to_save, checkpoint_path)
else: else:
client_state = {'epoch': epoch} if mode == 'best':
model.save_checkpoint(save_dir=args.task, tag="checkpoint-best", client_state=client_state) to_save = {
'model': model_without_ddp.state_dict(),
'epoch': epoch, }
def save_model_pretrain(args, epoch, model, model_without_ddp, optimizer, loss_scaler): torch.save(to_save, os.path.join(args.output_dir, args.task, "checkpoint-best.pth"))
output_dir = Path(args.output_dir) else:
epoch_name = str(epoch) if epoch == args.epochs - 1:
if loss_scaler is not None: to_save = {
print(model_without_ddp.state_dict().keys()) 'model': model_without_ddp.state_dict(), }
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] else:
for checkpoint_path in checkpoint_paths:
to_save = { to_save = {
'model': model_without_ddp.state_dict(), 'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
'epoch': epoch, 'epoch': epoch,
'scaler': loss_scaler.state_dict(),
'args': args, 'args': args,
} }
torch.save(to_save, os.path.join(args.output_dir, args.task, "checkpoint-latest.pth"))
save_on_master(to_save, checkpoint_path)
else:
client_state = {'epoch': epoch}
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
def load_model(args, model_without_ddp, optimizer, loss_scaler): def load_model(args, model_without_ddp, optimizer, loss_scaler):
@@ -335,7 +344,11 @@ def load_model(args, model_without_ddp, optimizer, loss_scaler):
args.resume, map_location='cpu', check_hash=True) args.resume, map_location='cpu', check_hash=True)
else: else:
checkpoint = torch.load(args.resume, map_location='cpu') checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model']) if 'model' in checkpoint:
checkpoint_model = checkpoint['model']
else:
checkpoint_model = checkpoint
model_without_ddp.load_state_dict(checkpoint_model, strict=False)
print("Resume checkpoint %s" % args.resume) print("Resume checkpoint %s" % args.resume)
if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
@@ -354,4 +367,3 @@ def all_reduce_mean(x):
return x_reduce.item() return x_reduce.item()
else: else:
return x return x