major package upgrade&new weights
This commit is contained in:
+148
-208
@@ -1,208 +1,148 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
# Partly revised by YZ @UCL&Moorfields
|
||||
# --------------------------------------------------------
|
||||
|
||||
import math
|
||||
import sys
|
||||
import csv
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from timm.data import Mixup
|
||||
from timm.utils import accuracy
|
||||
from typing import Iterable, Optional
|
||||
import util.misc as misc
|
||||
import util.lr_sched as lr_sched
|
||||
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score, average_precision_score,multilabel_confusion_matrix
|
||||
from pycm import *
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
|
||||
def misc_measures(confusion_matrix):
|
||||
|
||||
acc = []
|
||||
sensitivity = []
|
||||
specificity = []
|
||||
precision = []
|
||||
G = []
|
||||
F1_score_2 = []
|
||||
mcc_ = []
|
||||
|
||||
for i in range(1, confusion_matrix.shape[0]):
|
||||
cm1=confusion_matrix[i]
|
||||
acc.append(1.*(cm1[0,0]+cm1[1,1])/np.sum(cm1))
|
||||
sensitivity_ = 1.*cm1[1,1]/(cm1[1,0]+cm1[1,1])
|
||||
sensitivity.append(sensitivity_)
|
||||
specificity_ = 1.*cm1[0,0]/(cm1[0,1]+cm1[0,0])
|
||||
specificity.append(specificity_)
|
||||
precision_ = 1.*cm1[1,1]/(cm1[1,1]+cm1[0,1])
|
||||
precision.append(precision_)
|
||||
G.append(np.sqrt(sensitivity_*specificity_))
|
||||
F1_score_2.append(2*precision_*sensitivity_/(precision_+sensitivity_))
|
||||
mcc = (cm1[0,0]*cm1[1,1]-cm1[0,1]*cm1[1,0])/np.sqrt((cm1[0,0]+cm1[0,1])*(cm1[0,0]+cm1[1,0])*(cm1[1,1]+cm1[1,0])*(cm1[1,1]+cm1[0,1]))
|
||||
mcc_.append(mcc)
|
||||
|
||||
acc = np.array(acc).mean()
|
||||
sensitivity = np.array(sensitivity).mean()
|
||||
specificity = np.array(specificity).mean()
|
||||
precision = np.array(precision).mean()
|
||||
G = np.array(G).mean()
|
||||
F1_score_2 = np.array(F1_score_2).mean()
|
||||
mcc_ = np.array(mcc_).mean()
|
||||
|
||||
return acc, sensitivity, specificity, precision, G, F1_score_2, mcc_
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
|
||||
data_loader: Iterable, optimizer: torch.optim.Optimizer,
|
||||
device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
|
||||
mixup_fn: Optional[Mixup] = None, log_writer=None,
|
||||
args=None):
|
||||
model.train(True)
|
||||
metric_logger = misc.MetricLogger(delimiter=" ")
|
||||
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||
header = 'Epoch: [{}]'.format(epoch)
|
||||
print_freq = 20
|
||||
|
||||
accum_iter = args.accum_iter
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if log_writer is not None:
|
||||
print('log_dir: {}'.format(log_writer.log_dir))
|
||||
|
||||
for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
|
||||
|
||||
# we use a per iteration (instead of per epoch) lr scheduler
|
||||
if data_iter_step % accum_iter == 0:
|
||||
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
|
||||
|
||||
samples = samples.to(device, non_blocking=True)
|
||||
targets = targets.to(device, non_blocking=True)
|
||||
|
||||
if mixup_fn is not None:
|
||||
samples, targets = mixup_fn(samples, targets)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
outputs = model(samples)
|
||||
loss = criterion(outputs, targets)
|
||||
|
||||
loss_value = loss.item()
|
||||
|
||||
if not math.isfinite(loss_value):
|
||||
print("Loss is {}, stopping training".format(loss_value))
|
||||
sys.exit(1)
|
||||
|
||||
loss /= accum_iter
|
||||
loss_scaler(loss, optimizer, clip_grad=max_norm,
|
||||
parameters=model.parameters(), create_graph=False,
|
||||
update_grad=(data_iter_step + 1) % accum_iter == 0)
|
||||
if (data_iter_step + 1) % accum_iter == 0:
|
||||
optimizer.zero_grad()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
metric_logger.update(loss=loss_value)
|
||||
min_lr = 10.
|
||||
max_lr = 0.
|
||||
for group in optimizer.param_groups:
|
||||
min_lr = min(min_lr, group["lr"])
|
||||
max_lr = max(max_lr, group["lr"])
|
||||
|
||||
metric_logger.update(lr=max_lr)
|
||||
|
||||
loss_value_reduce = misc.all_reduce_mean(loss_value)
|
||||
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
|
||||
""" We use epoch_1000x as the x-axis in tensorboard.
|
||||
This calibrates different curves when batch size changes.
|
||||
"""
|
||||
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
|
||||
log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
|
||||
log_writer.add_scalar('lr', max_lr, epoch_1000x)
|
||||
|
||||
# gather the stats from all processes
|
||||
metric_logger.synchronize_between_processes()
|
||||
print("Averaged stats:", metric_logger)
|
||||
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
||||
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(data_loader, model, device, task, epoch, mode, num_class):
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
metric_logger = misc.MetricLogger(delimiter=" ")
|
||||
header = 'Test:'
|
||||
|
||||
if not os.path.exists(task):
|
||||
os.makedirs(task)
|
||||
|
||||
prediction_decode_list = []
|
||||
prediction_list = []
|
||||
true_label_decode_list = []
|
||||
true_label_onehot_list = []
|
||||
|
||||
# switch to evaluation mode
|
||||
model.eval()
|
||||
|
||||
for batch in metric_logger.log_every(data_loader, 10, header):
|
||||
images = batch[0]
|
||||
target = batch[-1]
|
||||
images = images.to(device, non_blocking=True)
|
||||
target = target.to(device, non_blocking=True)
|
||||
true_label=F.one_hot(target.to(torch.int64), num_classes=num_class)
|
||||
|
||||
# compute output
|
||||
with torch.cuda.amp.autocast():
|
||||
output = model(images)
|
||||
loss = criterion(output, target)
|
||||
prediction_softmax = nn.Softmax(dim=1)(output)
|
||||
_,prediction_decode = torch.max(prediction_softmax, 1)
|
||||
_,true_label_decode = torch.max(true_label, 1)
|
||||
|
||||
prediction_decode_list.extend(prediction_decode.cpu().detach().numpy())
|
||||
true_label_decode_list.extend(true_label_decode.cpu().detach().numpy())
|
||||
true_label_onehot_list.extend(true_label.cpu().detach().numpy())
|
||||
prediction_list.extend(prediction_softmax.cpu().detach().numpy())
|
||||
|
||||
acc1,_ = accuracy(output, target, topk=(1,2))
|
||||
|
||||
batch_size = images.shape[0]
|
||||
metric_logger.update(loss=loss.item())
|
||||
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
|
||||
# gather the stats from all processes
|
||||
true_label_decode_list = np.array(true_label_decode_list)
|
||||
prediction_decode_list = np.array(prediction_decode_list)
|
||||
confusion_matrix = multilabel_confusion_matrix(true_label_decode_list, prediction_decode_list,labels=[i for i in range(num_class)])
|
||||
acc, sensitivity, specificity, precision, G, F1, mcc = misc_measures(confusion_matrix)
|
||||
|
||||
auc_roc = roc_auc_score(true_label_onehot_list, prediction_list,multi_class='ovr',average='macro')
|
||||
auc_pr = average_precision_score(true_label_onehot_list, prediction_list,average='macro')
|
||||
|
||||
metric_logger.synchronize_between_processes()
|
||||
|
||||
print('Sklearn Metrics - Acc: {:.4f} AUC-roc: {:.4f} AUC-pr: {:.4f} F1-score: {:.4f} MCC: {:.4f}'.format(acc, auc_roc, auc_pr, F1, mcc))
|
||||
results_path = task+'_metrics_{}.csv'.format(mode)
|
||||
with open(results_path,mode='a',newline='',encoding='utf8') as cfa:
|
||||
wf = csv.writer(cfa)
|
||||
data2=[[acc,sensitivity,specificity,precision,auc_roc,auc_pr,F1,mcc,metric_logger.loss]]
|
||||
for i in data2:
|
||||
wf.writerow(i)
|
||||
|
||||
|
||||
if mode=='test':
|
||||
cm = ConfusionMatrix(actual_vector=true_label_decode_list, predict_vector=prediction_decode_list)
|
||||
cm.plot(cmap=plt.cm.Blues,number_label=True,normalized=True,plot_lib="matplotlib")
|
||||
plt.savefig(task+'confusion_matrix_test.jpg',dpi=600,bbox_inches ='tight')
|
||||
|
||||
return {k: meter.global_avg for k, meter in metric_logger.meters.items()},auc_roc
|
||||
|
||||
import os
|
||||
import csv
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Iterable, Optional
|
||||
from timm.data import Mixup
|
||||
from timm.utils import accuracy
|
||||
from sklearn.metrics import (
|
||||
accuracy_score, roc_auc_score, f1_score, average_precision_score,
|
||||
hamming_loss, jaccard_score, recall_score, precision_score, cohen_kappa_score
|
||||
)
|
||||
from pycm import ConfusionMatrix
|
||||
import util.misc as misc
|
||||
import util.lr_sched as lr_sched
|
||||
|
||||
def train_one_epoch(
|
||||
model: torch.nn.Module,
|
||||
criterion: torch.nn.Module,
|
||||
data_loader: Iterable,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
device: torch.device,
|
||||
epoch: int,
|
||||
loss_scaler,
|
||||
max_norm: float = 0,
|
||||
mixup_fn: Optional[Mixup] = None,
|
||||
log_writer=None,
|
||||
args=None
|
||||
):
|
||||
"""Train the model for one epoch."""
|
||||
model.train(True)
|
||||
metric_logger = misc.MetricLogger(delimiter=" ")
|
||||
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||
print_freq, accum_iter = 20, args.accum_iter
|
||||
optimizer.zero_grad()
|
||||
|
||||
if log_writer:
|
||||
print(f'log_dir: {log_writer.log_dir}')
|
||||
|
||||
for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, f'Epoch: [{epoch}]')):
|
||||
if data_iter_step % accum_iter == 0:
|
||||
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
|
||||
|
||||
samples, targets = samples.to(device, non_blocking=True), targets.to(device, non_blocking=True)
|
||||
if mixup_fn:
|
||||
samples, targets = mixup_fn(samples, targets)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
outputs = model(samples)
|
||||
loss = criterion(outputs, targets)
|
||||
loss_value = loss.item()
|
||||
loss /= accum_iter
|
||||
|
||||
loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=False,
|
||||
update_grad=(data_iter_step + 1) % accum_iter == 0)
|
||||
if (data_iter_step + 1) % accum_iter == 0:
|
||||
optimizer.zero_grad()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
metric_logger.update(loss=loss_value)
|
||||
min_lr = 10.
|
||||
max_lr = 0.
|
||||
for group in optimizer.param_groups:
|
||||
min_lr = min(min_lr, group["lr"])
|
||||
max_lr = max(max_lr, group["lr"])
|
||||
|
||||
metric_logger.update(lr=max_lr)
|
||||
|
||||
loss_value_reduce = misc.all_reduce_mean(loss_value)
|
||||
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
|
||||
""" We use epoch_1000x as the x-axis in tensorboard.
|
||||
This calibrates different curves when batch size changes.
|
||||
"""
|
||||
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
|
||||
log_writer.add_scalar('loss/train', loss_value_reduce, epoch_1000x)
|
||||
log_writer.add_scalar('lr', max_lr, epoch_1000x)
|
||||
|
||||
metric_logger.synchronize_between_processes()
|
||||
print("Averaged stats:", metric_logger)
|
||||
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(data_loader, model, device, args, epoch, mode, num_class, log_writer):
|
||||
"""Evaluate the model."""
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
metric_logger = misc.MetricLogger(delimiter=" ")
|
||||
os.makedirs(os.path.join(args.output_dir, args.task), exist_ok=True)
|
||||
|
||||
model.eval()
|
||||
true_onehot, pred_onehot, true_labels, pred_labels, pred_softmax = [], [], [], [], []
|
||||
|
||||
for batch in metric_logger.log_every(data_loader, 10, f'{mode}:'):
|
||||
images, target = batch[0].to(device, non_blocking=True), batch[-1].to(device, non_blocking=True)
|
||||
target_onehot = F.one_hot(target.to(torch.int64), num_classes=num_class)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
output = model(images)
|
||||
loss = criterion(output, target)
|
||||
output_ = nn.Softmax(dim=1)(output)
|
||||
output_label = output_.argmax(dim=1)
|
||||
output_onehot = F.one_hot(output_label.to(torch.int64), num_classes=num_class)
|
||||
|
||||
metric_logger.update(loss=loss.item())
|
||||
true_onehot.extend(target_onehot.cpu().numpy())
|
||||
pred_onehot.extend(output_onehot.detach().cpu().numpy())
|
||||
true_labels.extend(target.cpu().numpy())
|
||||
pred_labels.extend(output_label.detach().cpu().numpy())
|
||||
pred_softmax.extend(output_.detach().cpu().numpy())
|
||||
|
||||
accuracy = accuracy_score(true_labels, pred_labels)
|
||||
hamming = hamming_loss(true_onehot, pred_onehot)
|
||||
jaccard = jaccard_score(true_onehot, pred_onehot, average='macro')
|
||||
average_precision = average_precision_score(true_onehot, pred_softmax, average='macro')
|
||||
kappa = cohen_kappa_score(true_labels, pred_labels)
|
||||
f1 = f1_score(true_onehot, pred_onehot, zero_division=0, average='macro')
|
||||
roc_auc = roc_auc_score(true_onehot, pred_softmax, multi_class='ovr', average='macro')
|
||||
precision = precision_score(true_onehot, pred_onehot, zero_division=0, average='macro')
|
||||
recall = recall_score(true_onehot, pred_onehot, zero_division=0, average='macro')
|
||||
|
||||
score = (f1 + roc_auc + kappa) / 3
|
||||
if log_writer:
|
||||
for metric_name, value in zip(['accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa', 'score'],
|
||||
[accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa, score]):
|
||||
log_writer.add_scalar(f'perf/{metric_name}', value, epoch)
|
||||
|
||||
print(f'val loss: {metric_logger.meters["loss"].global_avg}')
|
||||
print(f'Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}, ROC AUC: {roc_auc:.4f}, Hamming Loss: {hamming:.4f},\n'
|
||||
f' Jaccard Score: {jaccard:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f},\n'
|
||||
f' Average Precision: {average_precision:.4f}, Kappa: {kappa:.4f}, Score: {score:.4f}')
|
||||
|
||||
metric_logger.synchronize_between_processes()
|
||||
|
||||
results_path = os.path.join(args.output_dir, args.task, f'metrics_{mode}.csv')
|
||||
file_exists = os.path.isfile(results_path)
|
||||
with open(results_path, 'a', newline='', encoding='utf8') as cfa:
|
||||
wf = csv.writer(cfa)
|
||||
if not file_exists:
|
||||
wf.writerow(['val_loss', 'accuracy', 'f1', 'roc_auc', 'hamming', 'jaccard', 'precision', 'recall', 'average_precision', 'kappa'])
|
||||
wf.writerow([metric_logger.meters["loss"].global_avg, accuracy, f1, roc_auc, hamming, jaccard, precision, recall, average_precision, kappa])
|
||||
|
||||
if mode == 'test':
|
||||
cm = ConfusionMatrix(actual_vector=true_labels, predict_vector=pred_labels)
|
||||
cm.plot(cmap=plt.cm.Blues, number_label=True, normalized=True, plot_lib="matplotlib")
|
||||
plt.savefig(os.path.join(args.output_dir, args.task, 'confusion_matrix_test.jpg'), dpi=600, bbox_inches='tight')
|
||||
|
||||
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, score
|
||||
|
||||
Reference in New Issue
Block a user