The gather function in PyTorch

2023-08-21

First, here is the explanation from the official documentation:

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

    out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
    out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
    out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

If input is an n-dimensional tensor with size (x_0, x_1, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1}) and dim = i, then index must be an n-dimensional tensor with size (x_0, x_1, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1}) where y >= 1, and out will have the same size as index.
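To make the three indexing rules concrete, here is a minimal pure-Python sketch that applies them to nested lists. The helper name gather_3d and the sample data are made up for illustration; this is not how PyTorch implements gather, only a direct transcription of the rule quoted above.

```python
def gather_3d(inp, dim, index):
    """Reference loop for the 3-D rule, operating on nested lists.

    Hypothetical helper for illustration; the output has the shape of `index`.
    """
    out = []
    for i, plane in enumerate(index):
        out_plane = []
        for j, row in enumerate(plane):
            out_row = []
            for k, idx in enumerate(row):
                if dim == 0:
                    out_row.append(inp[idx][j][k])
                elif dim == 1:
                    out_row.append(inp[i][idx][k])
                elif dim == 2:
                    out_row.append(inp[i][j][idx])
                else:
                    raise ValueError("dim must be 0, 1, or 2")
            out_plane.append(out_row)
        out.append(out_plane)
    return out


# A 2x2x2 input and an index of the same shape (sample values only).
inp = [[[1, 2], [3, 4]],
       [[5, 6], [7, 8]]]
index = [[[1, 0], [0, 1]],
         [[0, 0], [1, 1]]]

result_dim0 = gather_3d(inp, 0, index)
result_dim2 = gather_3d(inp, 2, index)
print(result_dim0)  # [[[5, 2], [3, 8]], [[1, 2], [7, 8]]]
print(result_dim2)  # [[[2, 1], [3, 4]], [[5, 5], [8, 8]]]
```

In each case only the coordinate along dim is replaced by the value read from index; the other two coordinates are carried over unchanged.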

Next, an example:

    import torch

    b = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    index_1 = torch.LongTensor([[0, 1], [2, 0]])
    print(torch.gather(b, dim=1, index=index_1))
    # Output:
    # tensor([[1., 2.],
    #         [6., 4.]])

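Now compute the output by hand from the documentation. For a 2-D input with dim=1 the rule reduces to out[i][j] = input[i][index[i][j]], so:

    out[0][0] = input[0][index[0][0]] = input[0][0] = 1
    out[0][1] = input[0][index[0][1]] = input[0][1] = 2
    out[1][0] = input[1][index[1][0]] = input[1][2] = 6
    out[1][1] = input[1][index[1][1]] = input[1][0] = 4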
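The hand calculation above can also be checked in code. The sketch below rebuilds the same b and index_1, repeats the dim=1 rule with an explicit loop, and compares it with torch.gather. It then gathers along dim=0 on the same input. The index_0 tensor and the variable names are assumptions chosen for this illustration and are not part of the original example.

```python
import torch

b = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
index_1 = torch.tensor([[0, 1], [2, 0]])  # int64, as torch.gather requires

# dim=1: out[i][j] = b[i][index_1[i][j]]
out_dim1 = torch.gather(b, dim=1, index=index_1)

# The same rule written as an explicit loop, for comparison.
manual = torch.empty(index_1.shape)
for i in range(index_1.size(0)):
    for j in range(index_1.size(1)):
        manual[i][j] = b[i][index_1[i][j]]

# dim=0: out[i][j] = b[index_0[i][j]][j]
# index_0 is a made-up index; its values must be valid row numbers of b (0 or 1).
index_0 = torch.tensor([[0, 1, 0], [1, 0, 1]])
out_dim0 = torch.gather(b, dim=0, index=index_0)

print(out_dim1)  # tensor([[1., 2.], [6., 4.]])
print(out_dim0)  # tensor([[1., 5., 3.], [4., 2., 6.]])
```

With dim=0 the index picks a row for each column position, whereas with dim=1 it picks a column for each row position. In both cases the output takes the shape of the index tensor.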

 

 
