This commit is contained in:
rmaphoh
2024-06-03 11:17:50 +01:00
parent 8593fef1ef
commit f5ab012b71
4 changed files with 328 additions and 6 deletions
+3 -6
View File
@@ -342,11 +342,6 @@ def main(args):
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch)
if epoch==(args.epochs-1):
test_stats,auc_roc = evaluate(data_loader_test, model, device,args.task,epoch, mode='test',num_class=args.nb_classes)
if log_writer is not None:
log_writer.add_scalar('perf/val_acc1', val_stats['acc1'], epoch)
log_writer.add_scalar('perf/val_auc', val_auc_roc, epoch)
@@ -366,7 +361,9 @@ def main(args):
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
state_dict_best = torch.load(args.task+'checkpoint-best.pth', map_location='cpu')
model_without_ddp.load_state_dict(state_dict_best['model'])
test_stats,auc_roc = evaluate(data_loader_test, model_without_ddp, device,args.task,epoch=0, mode='test',num_class=args.nb_classes)
if __name__ == '__main__':
args = get_args_parser()