Files
RETFound/engine_finetune.py
2025-02-19 12:37:05 +00:00

149 lines
6.9 KiB
Python

import os
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Iterable, Optional
from timm.data import Mixup
from timm.utils import accuracy
from sklearn.metrics import (
accuracy_score, roc_auc_score, f1_score, average_precision_score,
hamming_loss, jaccard_score, recall_score, precision_score, cohen_kappa_score
)
from pycm import ConfusionMatrix
import util.misc as misc
import util.lr_sched as lr_sched
def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    max_norm: float = 0,
    mixup_fn: Optional[Mixup] = None,
    log_writer=None,
    args=None
):
    """Train the model for one epoch.

    Args:
        model: Network to train; it is put into training mode here.
        criterion: Loss applied to ``model(samples)`` and the (possibly mixed) targets.
        data_loader: Iterable of ``(samples, targets)`` batches. It must support
            ``len()`` because the LR schedule and the TensorBoard x-axis use it.
        optimizer: Optimizer whose param-group learning rates are adjusted by
            ``lr_sched.adjust_learning_rate`` and reported as ``lr``.
        device: Device each batch is moved to.
        epoch: Current epoch index (used for the LR schedule and logging).
        loss_scaler: Callable invoked as ``loss_scaler(loss, optimizer, clip_grad=...,
            parameters=..., create_graph=False, update_grad=...)``. Presumably it runs
            backward, optional clipping and the optimizer step when ``update_grad``
            is true (confirm in ``util.misc``).
        max_norm: Gradient-clipping threshold forwarded to ``loss_scaler``.
        mixup_fn: Optional timm ``Mixup`` applied to each batch.
        log_writer: Optional writer exposing ``log_dir`` and ``add_scalar``;
            ``None`` disables TensorBoard logging.
        args: Run configuration. ``args.accum_iter`` is read here, and the whole
            object is passed to ``lr_sched.adjust_learning_rate``.

    Returns:
        Dict mapping each metric-logger meter name (``loss``, ``lr``) to its
        global average over the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    print_freq, accum_iter = 20, args.accum_iter
    # Start from clean gradients; a partial accumulation window left over from
    # the previous epoch is discarded.
    optimizer.zero_grad()
    if log_writer is not None:
        print(f'log_dir: {log_writer.log_dir}')

    num_steps = len(data_loader)
    header = f'Epoch: [{epoch}]'
    for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # Per-iteration LR schedule, updated once per accumulation window.
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / num_steps + epoch, args)

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(outputs, targets)
        loss_value = loss.item()
        # Divide so gradients accumulated over accum_iter steps average rather than sum.
        loss /= accum_iter

        # True on the last micro-batch of each accumulation window.
        is_update_step = (data_iter_step + 1) % accum_iter == 0
        loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(),
                    create_graph=False, update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()

        # Block until queued CUDA work finishes. torch.cuda.synchronize() raises
        # when CUDA is unavailable, so skip it on CPU-only runs.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        max_lr = max((group["lr"] for group in optimizer.param_groups), default=0.)
        metric_logger.update(lr=max_lr)

        # Collective call: every rank must execute it on every step, so it stays
        # outside the logging condition below.
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and is_update_step:
            # epoch_1000x is the TensorBoard x-axis; it calibrates different
            # curves when the batch size changes.
            epoch_1000x = int((data_iter_step / num_steps + epoch) * 1000)
            log_writer.add_scalar('loss/train', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device, args, epoch, mode, num_class, log_writer):
    """Evaluate ``model`` on ``data_loader`` and record classification metrics.

    Runs inference without gradients and accumulates labels, predictions and
    softmax probabilities. It then computes sklearn metrics (macro-averaged where
    applicable) and reports them to stdout, to the optional ``log_writer``, and
    to an appended CSV ``metrics_{mode}.csv`` under ``args.output_dir/args.task``.
    When ``mode == 'test'`` it also saves a normalized confusion-matrix plot.

    Args:
        data_loader: Iterable of batches; ``batch[0]`` holds the inputs and
            ``batch[-1]`` the integer class labels.
        model: Classifier producing per-class logits (assumed shape
            ``(batch, num_class)`` from the ``softmax(dim=1)``/``one_hot`` usage).
        device: Device inputs and labels are moved to.
        args: Config providing ``output_dir`` and ``task``.
        epoch: Step value used for the ``perf/*`` scalars.
        mode: Split name; used in the progress header and the CSV file name,
            and ``'test'`` triggers the confusion-matrix plot.
        num_class: Number of classes for one-hot encoding.
        log_writer: Optional writer with ``add_scalar``; ``None`` disables it.

    Returns:
        Tuple ``(stats, score)``. ``stats`` maps metric-logger meter names
        (``loss``) to their global averages. ``score`` is the mean of macro F1,
        macro ROC AUC and Cohen's kappa.
    """
    criterion = nn.CrossEntropyLoss()
    metric_logger = misc.MetricLogger(delimiter=" ")
    task_dir = os.path.join(args.output_dir, args.task)
    os.makedirs(task_dir, exist_ok=True)
    model.eval()

    true_onehot, pred_onehot, true_labels, pred_labels, pred_softmax = [], [], [], [], []
    for batch in metric_logger.log_every(data_loader, 10, f'{mode}:'):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)
        target_onehot = F.one_hot(target.to(torch.int64), num_classes=num_class)

        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)
        # Softmax in float32: under autocast the logits may be float16. Its coarse
        # resolution produces artificial ties in the ROC AUC / AP inputs.
        probs = output.float().softmax(dim=1)
        output_label = probs.argmax(dim=1)  # argmax already yields int64
        output_onehot = F.one_hot(output_label, num_classes=num_class)

        metric_logger.update(loss=loss.item())
        true_onehot.extend(target_onehot.cpu().numpy())
        pred_onehot.extend(output_onehot.cpu().numpy())
        true_labels.extend(target.cpu().numpy())
        pred_labels.extend(output_label.cpu().numpy())
        pred_softmax.extend(probs.cpu().numpy())

    # Insertion order fixes both the CSV column order and the TensorBoard tags,
    # so names and values cannot drift apart.
    metrics = {
        'accuracy': accuracy_score(true_labels, pred_labels),
        'f1': f1_score(true_onehot, pred_onehot, zero_division=0, average='macro'),
        'roc_auc': roc_auc_score(true_onehot, pred_softmax, multi_class='ovr', average='macro'),
        'hamming': hamming_loss(true_onehot, pred_onehot),
        'jaccard': jaccard_score(true_onehot, pred_onehot, average='macro'),
        'precision': precision_score(true_onehot, pred_onehot, zero_division=0, average='macro'),
        'recall': recall_score(true_onehot, pred_onehot, zero_division=0, average='macro'),
        'average_precision': average_precision_score(true_onehot, pred_softmax, average='macro'),
        'kappa': cohen_kappa_score(true_labels, pred_labels),
    }
    score = (metrics['f1'] + metrics['roc_auc'] + metrics['kappa']) / 3

    # Reduce the loss meter across processes *before* reporting, so the printed
    # val loss equals the value written to the CSV and returned.
    metric_logger.synchronize_between_processes()
    val_loss = metric_logger.meters["loss"].global_avg

    if log_writer is not None:
        for metric_name, value in {**metrics, 'score': score}.items():
            log_writer.add_scalar(f'perf/{metric_name}', value, epoch)

    print(f'val loss: {val_loss}')
    print(f'Accuracy: {metrics["accuracy"]:.4f}, F1 Score: {metrics["f1"]:.4f}, '
          f'ROC AUC: {metrics["roc_auc"]:.4f}, Hamming Loss: {metrics["hamming"]:.4f},\n'
          f' Jaccard Score: {metrics["jaccard"]:.4f}, Precision: {metrics["precision"]:.4f}, '
          f'Recall: {metrics["recall"]:.4f},\n'
          f' Average Precision: {metrics["average_precision"]:.4f}, Kappa: {metrics["kappa"]:.4f}, '
          f'Score: {score:.4f}')

    results_path = os.path.join(task_dir, f'metrics_{mode}.csv')
    file_exists = os.path.isfile(results_path)
    with open(results_path, 'a', newline='', encoding='utf8') as cfa:
        wf = csv.writer(cfa)
        if not file_exists:
            wf.writerow(['val_loss', *metrics])
        wf.writerow([val_loss, *metrics.values()])

    if mode == 'test':
        cm = ConfusionMatrix(actual_vector=true_labels, predict_vector=pred_labels)
        cm.plot(cmap=plt.cm.Blues, number_label=True, normalized=True, plot_lib="matplotlib")
        plt.savefig(os.path.join(task_dir, 'confusion_matrix_test.jpg'), dpi=600, bbox_inches='tight')
        # Release the figure; otherwise pyplot keeps it alive for the whole process.
        plt.close()

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, score