在运行任何程序之前写入下面代码(可以放在主代码的开头)
def seed_torch(seed=666): torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) np.random.seed(seed) # Numpy module. random.seed(seed) # Python random module. torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True seed_torch()对模型和相应的数据进行.cuda()处理。就可以将内存中的数据复制到GPU的显存中去。从而可以通过GPU来进行运算了 判断是否有GPU资源
import torch print(torch.cuda.is_available()) # 判断是否有使用的GPU print(torch.cuda.device_count()) # 返回能够使用GPU的数量将tensor迁移到显存中
import torch from torch.autograd import Variable ten1 = torch.FloatTensor(2) print(ten1) print(ten1.cuda())将Variable迁移到显存中
import torch from torch.autograd import Variable ten1 = torch.FloatTensor(2) ten1 = Variable(ten1) print(ten1.cuda())位于不同GPU显存上的数据也是不能直接进行计算的。torch.FloatTensor是不可以直接与torch.cuda.FloatTensor进行基本运算的 Variable可以进行反向传播来进行自动求导,可以说是一种能够记录操作信息并且能够自动求导的容器
将torch.nn下的基本模型迁移到显存中 对模型.cuda()实际上也相当于将模型使用到的参数存储到了显存上去
import torch.nn linear = torch.nn.Linear(2, 2) linear_cuda = linear.cuda()加一句,如果要将显存中的数据复制到内存中,则对cuda数据类型使用.cpu()方法即可
model.train()和model.eval()分别在训练和测试中都要写,它们的作用如下: (1) model.train() 启用BatchNormalization和 Dropout,将BatchNormalization和Dropout置为True (2) model.eval() 不启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为False 注意:在训练模块中千万不要忘了写model.train();在评估(或测试)模块千万不要忘了写model.eval()
requires_grad=True 要求计算梯度; requires_grad=False 不要求计算梯度; 在pytorch中,tensor有一个 requires_grad参数,如果设置为True,则反向传播时,该tensor就会自动求导。 tensor的requires_grad的属性默认为False,若一个节点(叶子变量:自己创建的tensor)requires_grad被设置为True,那么 所有依赖它的节点requires_grad都为True (即使其他相依赖的tensor的requires_grad = False)
x = torch.randn(10, 5, requires_grad = True) y = torch.randn(10, 5, requires_grad = False) z = torch.randn(10, 5, requires_grad = False) w = x + y + z print(w.requires_grad)torch.no_grad()是一个上下文管理器,被该语句 wrap 起来的部分将不会track 梯度。 with torch.no_grad()或者@torch.no_grad()中的数据不需要计算梯度,也不会进行反向传播。
x = torch.randn(2, 3, requires_grad = True) y = torch.randn(2, 3, requires_grad = False) z = torch.randn(2, 3, requires_grad = False) m=x+y+z with torch.no_grad(): w = x + y + z print(w) print(m) print(w.requires_grad) print(w.grad_fn) print(w.requires_grad)后续更新…