Pytorch .masked

it2026-01-24  5

masked_fill_(mask, value) 掩码操作 用value填充tensor中与mask中值为1位置相对应的元素。mask的形状必须与要填充的tensor形状一致。

a = torch.randn(5,6) x = [5,4,3,2,1] mask = torch.zeros(5,6,dtype=torch.float) for e_id, src_len in enumerate(x): mask[e_id, src_len:] = 1 mask = mask.to(device = 'cpu') print(mask) a.data.masked_fill_(mask.byte(),-float('inf')) print(a) ----------------------------输出 tensor([[0., 0., 0., 0., 0., 1.], [0., 0., 0., 0., 1., 1.], [0., 0., 0., 1., 1., 1.], [0., 0., 1., 1., 1., 1.], [0., 1., 1., 1., 1., 1.]]) tensor([[-0.1053, -0.0352, 1.4759, 0.8849, -0.7233, -inf], [-0.0529, 0.6663, -0.1082, -0.7243, -inf, -inf], [-0.0364, -1.0657, 0.8359, -inf, -inf, -inf], [ 1.4160, 1.1594, -inf, -inf, -inf, -inf], [ 0.4163, -inf, -inf, -inf, -inf, -inf]])
最新回复(0)