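Method 1: Call torch.cuda.empty_cache() to release cached GPU memory that is no longer in use. Note that it only frees memory no tensor still references, so delete unneeded variables first. Example code: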
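    import torch

    try:
        output = model(input)
    except RuntimeError as exception:
        if "out of memory" in str(exception):
            print("WARNING: out of memory")
            # empty_cache() is missing in very old PyTorch versions, hence the check
            if hasattr(torch.cuda, 'empty_cache'):
                torch.cuda.empty_cache()
        else:
            # Not an out-of-memory error: re-raise it unchanged
            raise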
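The handler above only logs the error and frees the cache, so `output` is never assigned when the error occurs. One possible follow-up is to retry with a smaller batch. Below is a minimal sketch of that idea, not part of the original post: `run_with_oom_retry`, `fake_step`, and `min_batch_size` are hypothetical names, and `fake_step` merely simulates an out-of-memory error so the example runs without a GPU. In real use you would probably pass `cleanup=torch.cuda.empty_cache` and a `step` function that runs the model.

```python
def run_with_oom_retry(step, batch_size, cleanup=None, min_batch_size=1):
    """Call step(batch_size); on an out-of-memory RuntimeError, halve the batch and retry.

    Returns a tuple (result, batch_size_that_worked).
    """
    while True:
        try:
            return step(batch_size), batch_size
        except RuntimeError as exception:
            # Re-raise unrelated errors, or give up once the batch cannot shrink further.
            if "out of memory" not in str(exception) or batch_size <= min_batch_size:
                raise
        # Recovery happens outside the except block so the exception object
        # no longer keeps the failed call's stack frame (and its tensors) alive.
        if cleanup is not None:
            cleanup()
        batch_size = max(min_batch_size, batch_size // 2)


def fake_step(batch_size):
    # Hypothetical stand-in for model(input): pretends batches above 8 do not fit.
    if batch_size > 8:
        raise RuntimeError("CUDA out of memory (simulated)")
    return f"processed {batch_size} samples"


cleanup_calls = []
result, used = run_with_oom_retry(
    fake_step, 32, cleanup=lambda: cleanup_calls.append(1)
)
print(result, used, len(cleanup_calls))  # processed 8 samples 8 2
```

The cleanup callback is passed in rather than hard-coded, which keeps the helper testable on machines without CUDA.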
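Method 2: If GPU memory runs out during testing, you may have forgotten to wrap inference in torch.no_grad(). Without it, PyTorch keeps the intermediate activations needed for backpropagation. Example code: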
    # test_loader and opt come from the surrounding test script
    with torch.no_grad():
        for ii, (inputs, filelist) in tqdm(enumerate(test_loader), desc='predict'):
            if opt.use_gpu:
                inputs = inputs.cuda()
                if len(inputs.shape) < 4:
                    # Add the missing channel dimension
                    inputs = inputs.unsqueeze(1)
            else:
                if len(inputs.shape) < 4:
                    inputs = torch.transpose(inputs, 1, 2)
                    inputs = inputs.unsqueeze(1)
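To see what torch.no_grad() changes, here is a small sketch that runs on CPU. The `nn.Linear` model and the random input are placeholders chosen for illustration, not the model from the original post. It compares the output of the same forward pass with and without the context manager.

```python
import torch
from torch import nn

# Placeholder model and input, used only for illustration.
model = nn.Linear(4, 2)
model.eval()
inputs = torch.randn(3, 4)

# Normal forward pass: autograd records the graph needed for backward().
out_tracked = model(inputs)

# Forward pass under no_grad: no graph is recorded, so no extra memory
# is held for intermediate results.
with torch.no_grad():
    out_untracked = model(inputs)

print(out_tracked.requires_grad)    # True
print(out_untracked.requires_grad)  # False
print(out_untracked.grad_fn)        # None
```

Note that model.eval() alone does not disable gradient tracking; it only switches layers such as dropout and batch normalization to inference behavior, so both calls are usually wanted at test time.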
Thanks to https://blog.csdn.net/xiaoxifei/article/details/84377204?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param
