关于python:关于pytorch中scatteradd函数的分析理解与实现

40次阅读

共计 2164 个字符,预计需要花费 6 分钟才能阅读完成。

import torch
import numpy as np
from torch import Tensor
“””
@overload
def scatter_add(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: …
@overload
def scatter_add(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: …
def scatter_add_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: …
对 pytorch 中的 scatter_add 函数的了解和简略测试:

参数:tensor,dim,index,tensor

返回:tensor

性能:将 other_tensor 的值累加到 self_tensor 的相应地位,用 index_tensor 对应地位的值替换掉 self_tensor 下标的 dim 维

举例:

self_tensor  = [[1, 2], [3, 4]] shape=(2,2)
other_tensor = [[5, 6], [7, 8]] shape=(2,2)
index_tensor = [[0, 0], [1, 1]] shape=(2,2)
dim = 1
以上三个 tensor 的 shape 必须统一,下标为:[0,0] [0,1] [1,0] [1,1]
dim=1,那么,self_tensor 的第 1 维下标由 index_tensor 示意,[0,0] [0,0] [1,1] [1,1]
则:
    self_tensor[0,0] = 1 + 5 + 6 = 12
    self_tensor[0,1] = 2
    self_tensor[1,0] = 3
    self_tensor[1,1] = 4 + 7 + 8 = 19

“””
def scatter_add(input_tensor: torch.Tensor, dim: int, index: torch.Tensor, other: torch.Tensor) -> torch.Tensor:

# tensor 的维数是不确定的,因而无奈用 for 循环的形式
# 如果 tensor 是 2 维,[金属期货](https://www.gendan5.com/cf/mf.html) 那么 dim= 0 或 1,两层 for 循环,用 other 对 self 进行填充
# 如果 tensor 是 3 维,那么 dim=0、1、2,须要三层 for 循环来遍历 other
if input_tensor.dim() == 2:
    for i in range(index_tensor.size()[0]):
        for j in range(index_tensor.size()[1]):
            if dim == 0:  # self 矩阵的第 0 维索引
                self_tensor[index_tensor[i][j]][j] += other_tensor[i][j]
            elif dim == 1:  # self 矩阵的第 1 维索引
                self_tensor[i][index_tensor[i][j]] += other_tensor[i][j]
elif input_tensor.dim() == 3:
    pass
return self_tensor

if name == ‘__main__’:

index_tensor = torch.tensor([[0, 0], [1, 1]])
print('index_tensor: \n', index_tensor.dim())
self_tensor = torch.arange(1, 5).view(2, 2)
print('self_tensor: \n', self_tensor)
other_tensor = torch.arange(5, 9).view(2, 2)
print('other_tensor: \n', other_tensor)
dim = 1
for i in range(index_tensor.size()[0]):
    for j in range(index_tensor.size()[1]):
        replace_index = index_tensor[i][j]
        print(i, j, replace_index)
        if dim == 0:
            # self 矩阵的第 0 维索引
            self_tensor[replace_index][j] += other_tensor[i][j]
        elif dim == 1:
            # self 矩阵的第 1 维索引
            self_tensor[i][replace_index] += other_tensor[i][j]
print(self_tensor)
index_tensor = torch.tensor([[0, 1], [1, 1]])
print('index_tensor: \n', index_tensor)
self_tensor = torch.arange(0, 4).view(2, 2)
print('self_tensor: \n', self_tensor)
other_tensor = torch.arange(5, 9).view(2, 2)
print('other_tensor: \n', other_tensor)
self_tensor.scatter_add_(dim=0, index=index_tensor, src=other_tensor)
print(self_tensor)

正文完
 0