2020-10-22

it2026-04-03  6

方法一:使用torch.cuda.empty_cache()删除一些不需要的变量代码示例如下:

try:     output = model(input) except RuntimeError as exception:     if "out of memory" in str(exception):         print("WARNING: out of memory")         if hasattr(torch.cuda, 'empty_cache'):             torch.cuda.empty_cache()     else:         raise exception

 

方法二:测试的时候爆显存有可能是忘记设置no_grad, 示例代码如下:

    with torch.no_grad():         for ii,(inputs,filelist) in tqdm(enumerate(test_loader), desc='predict'):             if opt.use_gpu:                 inputs = inputs.cuda()                 if len(inputs.shape) < 4:                     inputs = inputs.unsqueeze(1)               else:                 if len(inputs.shape) < 4:                     inputs = torch.transpose(inputs, 1, 2)                     inputs = inputs.unsqueeze(1)  

感谢 https://blog.csdn.net/xiaoxifei/article/details/84377204?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param

最新回复(0)