学习 Pytorch1:What is Pytorch?

it2024-04-14  57

学习 Pytorch1:What is Pytorch?

创建张量命令

x1 = torch.empty(5,3) x2 = torch.zeros(5,3) x3 = torch.ones(5,3) x4 = torch.tensor([5.5,3.2]) x5 = torch.rand(5,3) x6 = torch.randn_like(x5, dtype = torch.float)

获取张量形状

print(x6.size()) print(x6.shape)

改变张量形状

x = torch. rand(4,4) y = x.view(16)

获取张量元素值

x = torch.rand(1) print(x.item())

Torch Tensor 与 Numpy Array 之间的转换

Torch Tensor 转换为 Numpy Array

a = torch.ones(5) b = a.numpy() a.add_(1)

Numpy Array 转换为 Torch Tensor

b = np.ones(5) a = torch.from_numpy(b) np.add(b, 1, out = b)

CUDA Tensor

张量可以使用 .to 方法移到任何设备上

if torch.cuda.is_available(): device = torch.device("cuda") # a CUDA device object y = torch.ones_like(x, device=device) # directly create a tensor on GPU x = x.to(device) # or just use strings ``.to("cuda")`` z = x + y print(z) print(z.to("cpu", torch.double)) # ``.to`` can also change dtype together!
最新回复(0)