PyTorch: resuming interrupted training from a checkpoint


Saving a training checkpoint:

import os

import torch


def save_checkpoint_state(epoch, model, optimizer, dir, scheduler):
    # Bundle everything needed to resume training into a single dict:
    # the epoch index plus the model, optimizer, and scheduler states
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }
    # Create the output directory on the first save
    if not os.path.isdir(dir):
        os.mkdir(dir)
    torch.save(checkpoint, os.path.join(dir, 'check_point.pth'))
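For context, a minimal sketch of calling this helper at the end of each epoch. The toy model, optimizer, scheduler, and epoch count below are placeholder assumptions for illustration, not part of the original post.

import torch
import torch.nn as nn

# Placeholder components, assumed only for this example
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

for epoch in range(100):
    # ... one epoch of training would run here ...
    scheduler.step()
    # Save after every epoch so an interrupted run can be resumed
    save_checkpoint_state(epoch, model, optimizer, './checkpoints', scheduler)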

Loading a training checkpoint:

import logging
import os

import torch

logger = logging.getLogger(__name__)


def get_checkpoint_state(dir, model, optimizer, scheduler):
    # Restore the training state from the last saved checkpoint
    logger.info("Resume from checkpoint...")
    checkpoint = torch.load(os.path.join(dir, 'check_point.pth'))
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch = checkpoint['epoch']
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    logger.info('successfully recovered the last training state')
    return model, epoch, optimizer, scheduler
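Putting the two helpers together, a sketch of a resume-capable training loop, reusing the placeholder model, optimizer, and scheduler from the earlier example: if a checkpoint file exists we restore from it and continue with the following epoch, otherwise we train from scratch. The checkpoint_dir path and num_epochs value are illustrative assumptions.

import os

checkpoint_dir = './checkpoints'  # assumed path for this example
num_epochs = 100                  # assumed total epoch budget
start_epoch = 0

# Resume only if an earlier run already saved a checkpoint
if os.path.isfile(os.path.join(checkpoint_dir, 'check_point.pth')):
    model, last_epoch, optimizer, scheduler = get_checkpoint_state(
        checkpoint_dir, model, optimizer, scheduler)
    start_epoch = last_epoch + 1  # the saved epoch already finished

for epoch in range(start_epoch, num_epochs):
    # ... train one epoch ...
    scheduler.step()
    save_checkpoint_state(epoch, model, optimizer, checkpoint_dir, scheduler)

One detail worth noting: if the checkpoint was saved on a GPU and later loaded on a machine without one, torch.load accepts a map_location argument (e.g. map_location='cpu') to remap the stored tensors.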