diff --git a/README.md b/README.md index 6d64bf2..7deb1a8 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,95 @@ -# RETFound_MAE -RETFound - A foundation model for retinal image +## RETFound - A foundation model for retinal image + + +This is official repo for RETFound, which heavily bases on [MAE](https://github.com/facebookresearch/mae): + + +### Key features + +- RETFound was trained on 1.6 million retinal images +- RETFound has been validated in multiple disease detection tasks +- RETFound can be efficiently adapted to customised task + + +### Install enviroment + +Create enviroment with conda: + +``` +conda create -n retfound python=3.6.15 -y +``` + +Install Pytorch 1.81 (cuda 11.1) +``` +pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html +``` + +Install others +``` +pip install -r requirement.txt +``` + + +### Fine-tuning with RETFound weights + +- RETFound pre-trained weights + + + + + + + + + + + + + +
ViT-Large
Colour fundus imagedownload
OCTdownload
+ +- Organise data (use IDRiD as example) + +

+ +

+ + +- Start fine-tuning (use IDRiD as example). A fine-tuned checkpoint will be saved during training. Evaluation will be run after training. + + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py + --batch_size 16 \ + --world_size 1 \ + --model vit_large_patch16 \ + --epochs 50 \ + --blr 5e-3 --layer_decay 0.65 \ + --weight_decay 0.05 --drop_path 0.2 \ + --nb_classes 5 \ + --data_path ./IDRiD_data/ \ + --task ./finetune_IDRiD/ \ + --finetune ./RETFound_cfp_weights.pth + +``` + + +- For evaluation only + + +``` +python -m torch.distributed.launch --nproc_per_node=1 --master_port=48798 main_finetune.py + --eval --batch_size 16 \ + --world_size 1 \ + --model vit_large_patch16 \ + --epochs 40 \ + --blr 5e-3 --layer_decay 0.65 \ + --weight_decay 0.05 --drop_path 0.2 \ + --nb_classes 5 \ + --data_path ./IDRiD_data/ \ + --task ./internal_IDRiD/ \ + --resume ./finetune_IDRiD/checkpoint-best.pth + +``` + + diff --git a/engine_finetune.py b/engine_finetune.py new file mode 100644 index 0000000..f501738 --- /dev/null +++ b/engine_finetune.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import math +import sys +import csv +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.data import Mixup +from timm.utils import accuracy +from typing import Iterable, Optional +import util.misc as misc +import util.lr_sched as lr_sched +from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, average_precision_score,multilabel_confusion_matrix +from pycm import * +import matplotlib.pyplot as plt +import numpy as np + + + + +def misc_measures(confusion_matrix): + + acc = [] + sensitivity = [] + specificity = [] + precision = [] + G = [] + F1_score_2 = [] + mcc_ = [] + + for i in range(1, confusion_matrix.shape[0]): + cm1=confusion_matrix[i] + acc.append(1.*(cm1[0,0]+cm1[1,1])/np.sum(cm1)) + sensitivity_ = 1.*cm1[1,1]/(cm1[1,0]+cm1[1,1]) + sensitivity.append(sensitivity_) + specificity_ = 1.*cm1[0,0]/(cm1[0,1]+cm1[0,0]) + specificity.append(specificity_) + precision_ = 1.*cm1[1,1]/(cm1[1,1]+cm1[0,1]) + precision.append(precision_) + G.append(np.sqrt(sensitivity_*specificity_)) + F1_score_2.append(2*precision_*sensitivity_/(precision_+sensitivity_)) + mcc = (cm1[0,0]*cm1[1,1]-cm1[0,1]*cm1[1,0])/np.sqrt((cm1[0,0]+cm1[0,1])*(cm1[0,0]+cm1[1,0])*(cm1[1,1]+cm1[1,0])*(cm1[1,1]+cm1[0,1])) + mcc_.append(mcc) + + acc = np.array(acc).mean() + sensitivity = np.array(sensitivity).mean() + specificity = np.array(specificity).mean() + precision = np.array(precision).mean() + G = np.array(G).mean() + F1_score_2 = np.array(F1_score_2).mean() + mcc_ = np.array(mcc_).mean() + + return acc, sensitivity, specificity, precision, G, F1_score_2, mcc_ + + + + + +def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, + data_loader: Iterable, optimizer: torch.optim.Optimizer, + device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, + mixup_fn: Optional[Mixup] = None, log_writer=None, + args=None): + model.train(True) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + print_freq = 20 + + accum_iter = args.accum_iter + + optimizer.zero_grad() + + if log_writer is not None: + print('log_dir: {}'.format(log_writer.log_dir)) + + for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + + # we use a per iteration (instead of per epoch) lr scheduler + if data_iter_step % accum_iter == 0: + lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) + + samples = samples.to(device, non_blocking=True) + targets = targets.to(device, non_blocking=True) + + if mixup_fn is not None: + samples, targets = mixup_fn(samples, targets) + + with torch.cuda.amp.autocast(): + outputs = model(samples) + loss = criterion(outputs, targets) + + loss_value = loss.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, clip_grad=max_norm, + parameters=model.parameters(), create_graph=False, + update_grad=(data_iter_step + 1) % accum_iter == 0) + if (data_iter_step + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + min_lr = 10. + max_lr = 0. + for group in optimizer.param_groups: + min_lr = min(min_lr, group["lr"]) + max_lr = max(max_lr, group["lr"]) + + metric_logger.update(lr=max_lr) + + loss_value_reduce = misc.all_reduce_mean(loss_value) + if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: + """ We use epoch_1000x as the x-axis in tensorboard. + This calibrates different curves when batch size changes. + """ + epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) + log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) + log_writer.add_scalar('lr', max_lr, epoch_1000x) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + + + +@torch.no_grad() +def evaluate(data_loader, model, device, task, epoch, mode, num_class): + criterion = torch.nn.CrossEntropyLoss() + + metric_logger = misc.MetricLogger(delimiter=" ") + header = 'Test:' + + if not os.path.exists(task): + os.makedirs(task) + + prediction_decode_list = [] + prediction_list = [] + true_label_decode_list = [] + true_label_onehot_list = [] + + # switch to evaluation mode + model.eval() + + for batch in metric_logger.log_every(data_loader, 10, header): + images = batch[0] + target = batch[-1] + images = images.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + true_label=F.one_hot(target.to(torch.int64), num_classes=num_class) + + # compute output + with torch.cuda.amp.autocast(): + output = model(images) + loss = criterion(output, target) + prediction_softmax = nn.Softmax(dim=1)(output) + _,prediction_decode = torch.max(prediction_softmax, 1) + _,true_label_decode = torch.max(true_label, 1) + + prediction_decode_list.extend(prediction_decode.cpu().detach().numpy()) + true_label_decode_list.extend(true_label_decode.cpu().detach().numpy()) + true_label_onehot_list.extend(true_label.cpu().detach().numpy()) + prediction_list.extend(prediction_softmax.cpu().detach().numpy()) + + acc1,_ = accuracy(output, target, topk=(1,2)) + + batch_size = images.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + # gather the stats from all processes + true_label_decode_list = np.array(true_label_decode_list) + prediction_decode_list = np.array(prediction_decode_list) + confusion_matrix = multilabel_confusion_matrix(true_label_decode_list, prediction_decode_list,labels=[i for i in range(num_class)]) + acc, sensitivity, specificity, precision, G, F1, mcc = misc_measures(confusion_matrix) + + auc_roc = roc_auc_score(true_label_onehot_list, prediction_list,multi_class='ovr',average='macro') + auc_pr = average_precision_score(true_label_onehot_list, prediction_list,average='macro') + + metric_logger.synchronize_between_processes() + + print('Sklearn Metrics - Acc: {:.4f} AUC-roc: {:.4f} AUC-pr: {:.4f} F1-score: {:.4f} MCC: {:.4f}'.format(acc, auc_roc, auc_pr, F1, mcc)) + results_path = task+'_metrics_{}.csv'.format(mode) + with open(results_path,mode='a',newline='',encoding='utf8') as cfa: + wf = csv.writer(cfa) + data2=[[acc,sensitivity,specificity,precision,auc_roc,auc_pr,F1,mcc,metric_logger.loss]] + for i in data2: + wf.writerow(i) + + + if mode=='test': + cm = ConfusionMatrix(actual_vector=true_label_decode_list, predict_vector=prediction_decode_list) + cm.plot(cmap=plt.cm.Blues,number_label=True,normalized=True,plot_lib="matplotlib") + plt.savefig(task+'confusion_matrix_test.jpg',dpi=600,bbox_inches ='tight') + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()},auc_roc + diff --git a/main_finetune.py b/main_finetune.py new file mode 100644 index 0000000..6c174d4 --- /dev/null +++ b/main_finetune.py @@ -0,0 +1,387 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import argparse +import datetime +import json +import numpy as np +import os +import time +from pathlib import Path + +import torch +import torch.backends.cudnn as cudnn +from torch.utils.tensorboard import SummaryWriter + +import timm + +assert timm.__version__ == "0.3.2" # version check +from timm.models.layers import trunc_normal_ +from timm.data.mixup import Mixup +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy + +import util.lr_decay as lrd +import util.misc as misc +from util.datasets import build_dataset +from util.pos_embed import interpolate_pos_embed +from util.misc import NativeScalerWithGradNormCount as NativeScaler + +import models_vit + +from engine_finetune import train_one_epoch, evaluate + + +def get_args_parser(): + parser = argparse.ArgumentParser('MAE fine-tuning for image classification', add_help=False) + parser.add_argument('--batch_size', default=64, type=int, + help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') + parser.add_argument('--epochs', default=50, type=int) + parser.add_argument('--accum_iter', default=1, type=int, + help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') + + # Model parameters + parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', + help='Name of model to train') + + parser.add_argument('--input_size', default=224, type=int, + help='images input size') + + parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', + help='Drop path rate (default: 0.1)') + + # Optimizer parameters + parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--weight_decay', type=float, default=0.05, + help='weight decay (default: 0.05)') + + parser.add_argument('--lr', type=float, default=None, metavar='LR', + help='learning rate (absolute lr)') + parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', + help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') + parser.add_argument('--layer_decay', type=float, default=0.75, + help='layer-wise lr decay from ELECTRA/BEiT') + + parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0') + + parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N', + help='epochs to warmup LR') + + # Augmentation parameters + parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', + help='Color jitter factor (enabled only when not using Auto/RandAug)') + parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', + help='Use AutoAugment policy. "v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), + parser.add_argument('--smoothing', type=float, default=0.1, + help='Label smoothing (default: 0.1)') + + # * Random Erase params + parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', + help='Random erase prob (default: 0.25)') + parser.add_argument('--remode', type=str, default='pixel', + help='Random erase mode (default: "pixel")') + parser.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') + parser.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') + + # * Mixup params + parser.add_argument('--mixup', type=float, default=0, + help='mixup alpha, mixup enabled if > 0.') + parser.add_argument('--cutmix', type=float, default=0, + help='cutmix alpha, cutmix enabled if > 0.') + parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') + parser.add_argument('--mixup_prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') + parser.add_argument('--mixup_switch_prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') + parser.add_argument('--mixup_mode', type=str, default='batch', + help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') + + # * Finetuning params + parser.add_argument('--finetune', default='',type=str, + help='finetune from checkpoint') + parser.add_argument('--task', default='',type=str, + help='finetune from checkpoint') + parser.add_argument('--global_pool', action='store_true') + parser.set_defaults(global_pool=True) + parser.add_argument('--cls_token', action='store_false', dest='global_pool', + help='Use class token instead of global pool for classification') + + # Dataset parameters + parser.add_argument('--data_path', default='/home/jupyter/Mor_DR_data/data/data/IDRID/Disease_Grading/', type=str, + help='dataset path') + parser.add_argument('--nb_classes', default=1000, type=int, + help='number of the classification types') + + parser.add_argument('--output_dir', default='./output_dir', + help='path where to save, empty for no saving') + parser.add_argument('--log_dir', default='./output_dir', + help='path where to tensorboard log') + parser.add_argument('--device', default='cuda', + help='device to use for training / testing') + parser.add_argument('--seed', default=0, type=int) + parser.add_argument('--resume', default='', + help='resume from checkpoint') + + parser.add_argument('--start_epoch', default=0, type=int, metavar='N', + help='start epoch') + parser.add_argument('--eval', action='store_true', + help='Perform evaluation only') + parser.add_argument('--dist_eval', action='store_true', default=False, + help='Enabling distributed evaluation (recommended during training for faster monitor') + parser.add_argument('--num_workers', default=10, type=int) + parser.add_argument('--pin_mem', action='store_true', + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') + parser.set_defaults(pin_mem=True) + + # distributed training parameters + parser.add_argument('--world_size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--local_rank', default=-1, type=int) + parser.add_argument('--dist_on_itp', action='store_true') + parser.add_argument('--dist_url', default='env://', + help='url used to set up distributed training') + + return parser + + +def main(args): + misc.init_distributed_mode(args) + + print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) + print("{}".format(args).replace(', ', ',\n')) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + misc.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + + cudnn.benchmark = True + + dataset_train = build_dataset(is_train='train', args=args) + dataset_val = build_dataset(is_train='val', args=args) + dataset_test = build_dataset(is_train='test', args=args) + + if True: # args.distributed: + num_tasks = misc.get_world_size() + global_rank = misc.get_rank() + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + print("Sampler_train = %s" % str(sampler_train)) + if args.dist_eval: + if len(dataset_val) % num_tasks != 0: + print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' + 'This will slightly alter validation results as extra duplicate entries are added to achieve ' + 'equal num of samples per-process.') + sampler_val = torch.utils.data.DistributedSampler( + dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias + else: + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + + if args.dist_eval: + if len(dataset_test) % num_tasks != 0: + print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' + 'This will slightly alter validation results as extra duplicate entries are added to achieve ' + 'equal num of samples per-process.') + sampler_test = torch.utils.data.DistributedSampler( + dataset_test, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias + else: + sampler_test = torch.utils.data.SequentialSampler(dataset_test) + + + else: + sampler_train = torch.utils.data.RandomSampler(dataset_train) + sampler_val = torch.utils.data.SequentialSampler(dataset_val) + sampler_test = torch.utils.data.SequentialSampler(dataset_test) + + if global_rank == 0 and args.log_dir is not None and not args.eval: + os.makedirs(args.log_dir, exist_ok=True) + log_writer = SummaryWriter(log_dir=args.log_dir+args.task) + else: + log_writer = None + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=True, + ) + + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False + ) + + data_loader_test = torch.utils.data.DataLoader( + dataset_test, sampler=sampler_test, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=args.pin_mem, + drop_last=False + ) + + + mixup_fn = None + mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + if mixup_active: + print("Mixup is activated!") + mixup_fn = Mixup( + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, + label_smoothing=args.smoothing, num_classes=args.nb_classes) + + model = models_vit.__dict__[args.model]( + num_classes=args.nb_classes, + drop_path_rate=args.drop_path, + global_pool=args.global_pool, + ) + + if args.finetune and not args.eval: + checkpoint = torch.load(args.finetune, map_location='cpu') + + print("Load pre-trained checkpoint from: %s" % args.finetune) + checkpoint_model = checkpoint['model'] + state_dict = model.state_dict() + for k in ['head.weight', 'head.bias']: + if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: + print(f"Removing key {k} from pretrained checkpoint") + del checkpoint_model[k] + + # interpolate position embedding + interpolate_pos_embed(model, checkpoint_model) + + # load pre-trained model + msg = model.load_state_dict(checkpoint_model, strict=False) + print(msg) + + if args.global_pool: + assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} + else: + assert set(msg.missing_keys) == {'head.weight', 'head.bias'} + + # manually initialize fc layer + trunc_normal_(model.head.weight, std=2e-5) + + model.to(device) + + model_without_ddp = model + n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + + print("Model = %s" % str(model_without_ddp)) + print('number of params (M): %.2f' % (n_parameters / 1.e6)) + + eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() + + if args.lr is None: # only base_lr is specified + args.lr = args.blr * eff_batch_size / 256 + + print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) + print("actual lr: %.2e" % args.lr) + + print("accumulate grad iterations: %d" % args.accum_iter) + print("effective batch size: %d" % eff_batch_size) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + # build optimizer with layer-wise lr decay (lrd) + param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, + no_weight_decay_list=model_without_ddp.no_weight_decay(), + layer_decay=args.layer_decay + ) + optimizer = torch.optim.AdamW(param_groups, lr=args.lr) + loss_scaler = NativeScaler() + + if mixup_fn is not None: + # smoothing is handled with mixup label transform + criterion = SoftTargetCrossEntropy() + elif args.smoothing > 0.: + criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) + else: + criterion = torch.nn.CrossEntropyLoss() + + print("criterion = %s" % str(criterion)) + + misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) + + if args.eval: + test_stats,auc_roc = evaluate(data_loader_test, model, device, args.task, epoch=0, mode='test',num_class=args.nb_classes) + exit(0) + + print(f"Start training for {args.epochs} epochs") + start_time = time.time() + max_accuracy = 0.0 + max_auc = 0.0 + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + data_loader_train.sampler.set_epoch(epoch) + train_stats = train_one_epoch( + model, criterion, data_loader_train, + optimizer, device, epoch, loss_scaler, + args.clip_grad, mixup_fn, + log_writer=log_writer, + args=args + ) + + val_stats,val_auc_roc = evaluate(data_loader_val, model, device,args.task,epoch, mode='val',num_class=args.nb_classes) + if max_auc 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [args.task+'checkpoint-best.pth'] + for checkpoint_path in checkpoint_paths: + to_save = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'scaler': loss_scaler.state_dict(), + 'args': args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {'epoch': epoch} + model.save_checkpoint(save_dir=args.task, tag="checkpoint-best", client_state=client_state) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + print("Resume checkpoint %s" % args.resume) + if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): + optimizer.load_state_dict(checkpoint['optimizer']) + args.start_epoch = checkpoint['epoch'] + 1 + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + print("With optim & sched!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + \ No newline at end of file diff --git a/util/pos_embed.py b/util/pos_embed.py new file mode 100644 index 0000000..6acf8bd --- /dev/null +++ b/util/pos_embed.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed