PyTorch中scatter()和scatter

    科技2022-07-14  124

    本文讲的是我对PyTorch中scatter()函数的理解。

    原创,转载请标明来源。

     

    一言以蔽之:修改tensor中的指定位置的值。

    函数

    scatter(dim, index, src) 

    dim: 索引的维度。按照i, j, k, ...的哪个方向进行索引index: 索引。可以是一个tensor,存储需要改的元素的位置的tensorsrc: 用src中的值来修改。可以是tensor;可以是一个数字,用同样的数字写入tensor

    scatter() 和 scatter_() 函数功能相同:只不过带下划线的函数,通常是直接修改原来的tensor

    原理

    self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

    函数的具体实现,如上述代码框所示:使用src中的值,修改self中位置为index[i][j][k]的值。

    举例

    # 这是src #tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945], # [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]]) # index是[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]] # self就是下面的torch.zeros(3, 5) torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) # 这是结果 #tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945], # [0.0000, 0.3340, 0.0000, 0.0943, 0.0000], # [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

    在这个例子中,dim=0,是按照i的方向修改torch.zeros(3, 5)的。可以看出,index实际上表达了一种映射关系:同样都是在第j列上,将src在该列的值 根据index在该列的指示 映射到self的这一列上。

    比如:src在第0列的0.01940和0.2078,被放到self的第0列上,但不是完全一样的放过来,而是经过index变化了上下位置。其他列同理。

    在简单RNN中的应用

    该应用代码如下 [1]:

    def one_hot(x, n_class, dtype=torch.float32): # X shape: (batch), output shape: (batch, n_class) x = x.long() res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) res.scatter_(1, x.view(-1, 1), 1) return res x = torch.tensor([0, 2]) one_hot(x, vocab_size)

     运行结果:

    tensor([[1., 0., 0., ..., 0., 0., 0.], [0., 0., 1., ..., 0., 0., 0.]])

    在本例中,该函数的任务是将输入的文本使用one_hot编码。其中,x是一个vector,代表一个二字词语,其中的0和2代表汉字(在程序上文定义的字典中)所对应的数字。vocab_size是字典大小,即在该程序中所考虑的汉字总个数。n_class是one_hot编码中所考虑的类别数,在本例中等于vocab_size。

    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device) :生成了两行,vocab_size列的零矩阵

    res.scatter_(1, x.view(-1, 1), 1) :在res中,将1,按照dim=1(即不改行改列)的方向,根据[[0],[2]]所指示的位置,放入res中。(比如,x中的0,代表要放入第0列;而0本身处于第0行,所以是第0行中的第0列。)

     

    参考文献

    [1] 循环神经网络的从零开始实现. 原书作者:阿斯顿·张、李沐、扎卡里 C. 立顿、亚历山大 J. 斯莫拉以及其他社区贡献者. 原书名称:动手学深度学习Pytorch版

    [2] PyTorch官方文档

    Processed: 0.016, SQL: 8