pytorch相关

it2025-10-30  2

函数

torch.bmm(a,b)——矩阵乘法a*b -tensor维度为3,a的size为(b,h,w),b的size为(b,w,h)

x.permute(0,2,1)——交换 -x的第二列与第一列交换位置

torch.matmul()——矩阵乘法 针对高维

调用gpu训练

device = torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”) model.to(device) input = Variable().cuda() 将所有最开始读取数据时的tersor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行。 对所有的输入都要进行这步复制操作
最新回复(0)