20240603
This commit is contained in:
+3
-6
@@ -342,11 +342,6 @@ def main(args):
|
||||
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
|
||||
loss_scaler=loss_scaler, epoch=epoch)
|
||||
|
||||
|
||||
if epoch==(args.epochs-1):
|
||||
test_stats,auc_roc = evaluate(data_loader_test, model, device,args.task,epoch, mode='test',num_class=args.nb_classes)
|
||||
|
||||
|
||||
if log_writer is not None:
|
||||
log_writer.add_scalar('perf/val_acc1', val_stats['acc1'], epoch)
|
||||
log_writer.add_scalar('perf/val_auc', val_auc_roc, epoch)
|
||||
@@ -366,7 +361,9 @@ def main(args):
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('Training time {}'.format(total_time_str))
|
||||
|
||||
state_dict_best = torch.load(args.task+'checkpoint-best.pth', map_location='cpu')
|
||||
model_without_ddp.load_state_dict(state_dict_best['model'])
|
||||
test_stats,auc_roc = evaluate(data_loader_test, model_without_ddp, device,args.task,epoch=0, mode='test',num_class=args.nb_classes)
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args_parser()
|
||||
|
||||
Reference in New Issue
Block a user