pytroch中的scatter

    科技2024-06-18  69

    scatter_add_(dim,index,src )

    将张量src中的所有值添加到张量中self指定的索引index中。

    self的更新公式

    二维 self[index[i][j]][j] += other[i][j] # 如果 dim == 0 self[i][index[i][j]] += other[i][j] # 如果 dim == 1 三维 self[index[i][j][k]][j][k] += other[i][j][k] # 如果 dim == 0 self[i][index[i][j][k]][k] += other[i][j][k] # 如果 dim == 1 self[i][j][index[i][j][k]] += other[i][j][k] # 如果 dim == 2

    self,index并且src应具有相同数量的尺寸。还要求所有尺寸,以及所有尺寸 。index.size(d) <= src.size(d)dindex.size(d) <= self.size(d)d != dim 只看官网和公式还是有点模糊,下面我们借助一个二维的例子推导一下具体的实现过程。

    二维

    import torch src = torch.rand(2, 5) output_0 = torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), src) output_1 = torch.ones(3, 5).scatter_add_(1, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), src) print(torch.ones(3,5)) print("src:{}".format(src)) print("output_0:{}".format(output_0)) print("output_1:{}".format(output_1)) tensor([[1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.], [1., 1., 1., 1., 1.]]) src:tensor([[0.0278, 0.3649, 0.0442, 0.4839, 0.5509], [0.4008, 0.4180, 0.4243, 0.8407, 0.0208]]) output_0:tensor([[1.0278, 1.4180, 1.4243, 1.4839, 1.5509], [1.0000, 1.3649, 1.0000, 1.8407, 1.0000], [1.4008, 1.0000, 1.0442, 1.0000, 1.0208]]) output_1:tensor([[2.0626, 1.3649, 1.0442, 1.0000, 1.0000], [1.8423, 1.8407, 1.4216, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]])

    在上述代码中:

    self = torch.ones(3,5) index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) other = src

    dim=0:

    self[index[i][j]][j] += other[i][j] # 如果 dim == 0

    i = 0: j= 0:self [index[0][0]] [0] += other[0][0] = self[0][0] += other[0][0] 所以:self[0][0] = 1 + 0.0278 = 1.0278 j= 1:self [index[0][1]] [1] += other[0][1] = self[1][1] += other[0][1] 所以:self[1][1] = 1 + 0.3649 = 1.3649 j= 2:self [index[0][2]] [2] += other[0][2] = self[2][2] += other[0][2] 所以:self[2][2] = 1 + 0.0442 = 1.0442 j= 3:self [index[0][3]] [3] += other[0][3] = self[0][3] += other[0][3] 所以:self[0][3] = 1 + 0.4839 = 1.4839 j= 4:self [index[0][4]] [0] += other[0][4] = self[0][4] += other[0][4] 所以:self[0][4] = 1 + 0.5509 = 1.5509 i=1 同理
    dim=1:

    self[i][index[i][j]] += other[i][j] # 如果 dim == 1

    i = 0: j=0: self[0][index[0][0]] += other[0][0] = self[0][0] += other[0][0] 所以:self[0][0] = 1 + 0.0278 = 1.0278 j=1: self[0][index[0][1]] += other[0][1] = self[0][1] += other[0][1] 所以:self[0][1] = 1 + 0.3649 = 1.3649 j=2: self[0][index[0][2]] += other[0][2] = self[0][2] += other[0][2] 所以:self[0][2] = 1 + 0.0442 = 1.0442 j=3: self[0][index[0][3]] += other[0][3] = self[0][0] += other[0][3] 所以:self[0][0] = 1.0278 + 0.4839 = 1.5117 j=4: self[0][index[0][4]] += other[0][4] = self[0][0] += other[0][4] 所以:self[0][0] = 1.5117 + 0.5509= 2.0626 i =2 同理 三维上也是如此推算的
    Processed: 0.023, SQL: 8