Let's start with the explanation in the official documentation:
Gathers values along an axis specified by dim. For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

If input is an n-dimensional tensor with size $(x_0, x_1, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})$ and dim = i, then index must be an n-dimensional tensor with size $(x_0, x_1, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})$ where $y \geq 1$ and out will have the same size as index.

Here is an example:
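    import torch

    # 2x3 float input tensor
    b = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    # 2x2 index tensor; each entry is a column index into the same row of b
    index_1 = torch.LongTensor([[0, 1], [2, 0]])
    print(torch.gather(b, dim=1, index=index_1))
    # Output:
    # tensor([[1., 2.],
    #         [6., 4.]])

Now let's work out the output by hand, following the documentation. For a 2-D tensor with dim=1, the rule reduces to out[i][j] = input[i][index[i][j]], so:

out[0][0] = input[0][index[0][0]] = input[0][0] = 1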
out[0][1] = input[0][index[0][1]] = input[0][1] = 2
out[1][0] = input[1][index[1][0]] = input[1][2] = 6
out[1][1] = input[1][index[1][1]] = input[1][0] = 4
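
The hand calculation above can also be checked in code. The following is a minimal sketch, not part of the official documentation: `gather_2d_manual` is a hypothetical helper written here purely for illustration, and it applies the documented indexing rule with explicit loops, then compares the result against `torch.gather`. The second index tensor, `index_0`, is also an assumption added to show the dim=0 case.

```python
import torch

def gather_2d_manual(inp, dim, index):
    """Illustrative loop version of gather for 2-D tensors (hypothetical helper)."""
    out = torch.empty(index.shape, dtype=inp.dtype)
    for i in range(index.shape[0]):
        for j in range(index.shape[1]):
            k = index[i, j].item()
            if dim == 0:
                out[i, j] = inp[k, j]  # out[i][j] = input[index[i][j]][j]
            else:
                out[i, j] = inp[i, k]  # out[i][j] = input[i][index[i][j]]
    return out

b = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

# dim=1: same data as the example above; index entries pick columns
index_1 = torch.tensor([[0, 1], [2, 0]])
manual_dim1 = gather_2d_manual(b, 1, index_1)
builtin_dim1 = torch.gather(b, dim=1, index=index_1)
print(torch.equal(manual_dim1, builtin_dim1))  # True

# dim=0: index entries pick rows, so every entry must be < b.size(0) == 2
index_0 = torch.tensor([[0, 1, 0], [1, 0, 1]])
manual_dim0 = gather_2d_manual(b, 0, index_0)
builtin_dim0 = torch.gather(b, dim=0, index=index_0)
print(torch.equal(manual_dim0, builtin_dim0))  # True
```

In both cases the output has the same shape as the index tensor, which matches the last sentence of the documentation excerpt.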