diff --git a/main_finetune.py b/main_finetune.py index 3a60cf9..75c822d 100644 --- a/main_finetune.py +++ b/main_finetune.py @@ -370,7 +370,7 @@ def main(args, criterion): if epoch == (args.epochs - 1): checkpoint = torch.load(os.path.join(args.output_dir, args.task, 'checkpoint-best.pth'), map_location='cpu') - model.load_state_dict(checkpoint['model'], strict=False) + model_without_ddp.load_state_dict(checkpoint['model'], strict=False) model.to(device) print("Test with the best model, epoch = %d:" % checkpoint['epoch']) test_stats, auc_roc = evaluate(data_loader_test, model, device, args, -1, mode='test',