This commit is contained in:
rmaphoh
2024-06-03 11:17:50 +01:00
parent 8593fef1ef
commit f5ab012b71
4 changed files with 328 additions and 6 deletions
+22
View File
@@ -306,6 +306,28 @@ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
model.save_checkpoint(save_dir=args.task, tag="checkpoint-best", client_state=client_state)
def save_model_pretrain(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
output_dir = Path(args.output_dir)
epoch_name = str(epoch)
if loss_scaler is not None:
print(model_without_ddp.state_dict().keys())
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
for checkpoint_path in checkpoint_paths:
to_save = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'scaler': loss_scaler.state_dict(),
'args': args,
}
save_on_master(to_save, checkpoint_path)
else:
client_state = {'epoch': epoch}
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
def load_model(args, model_without_ddp, optimizer, loss_scaler):
if args.resume:
if args.resume.startswith('https'):