pytorch函数之scatter()和scatter_()
前言这两个函数,其实本来有一个大佬写的比较清楚了,但是说实话,总是给忘具体使用细节。我还是自己写一个更清晰的吧。官方文档scatter_()scatter_(input, dim, index, src) → Tensor其实这样写会造成迷惑,建议这么按下面的理解:理解input.scatter_(dim, index, src) → Tensorinput: 我们需要插入数据的起源tensor;
·
前言
这两个函数,理清楚的人很清楚,不清楚的人很不清楚,建议直接看2.举例
官方文档
scatter_()
'官方定义'
scatter(input, dim, index, src) → Tensor
实际使用:如下面
input.scatter_(dim, index, src) → Tensor
'Or'
input.scatter(dim, index, src) → Tensor
'区别是scatter_函数不会回滚,使用后返回的就是更改后的input。而scatter是在内存中生成另外一个对象,不会覆盖原input'
- input: 我们需要插入数据的起源
tensor;也就是想要改变内部的tensor - dim:我们想要从哪个维度去改
input数据 - index:给出改的元素索引,也就是位置,说在“坐标”可能好理解一点。
- src:准备好的插入到
input中指定位置的数据。
总结:input.scatter_(dim, index, src):从【src源数据】中获取的数据,按照【dim指定的维度】和【index指定的位置】,替换input中的数据。
2. 举例
先看代码
batch_size = 2
hidden_size = 8
src = torch.rand(batch_size, hidden_size)
input_ = torch.zeros(batch_size+1, hidden_size)
index = torch.LongTensor([[0,1,2,0,0,1,1,2],[2,0,0,1,2,1,1,1]])
print('src\n',src)
print('index\n',index)
print('input_\n',input_)
print('ans:\n',input_.scatter_(0, index, src))
'''
src
tensor([[0.3304, 0.5643, 0.2362, 0.1929, 0.2400, 0.6672, 0.5217, 0.4471],
[0.0433, 0.2996, 0.9913, 0.4336, 0.8540, 0.8522, 0.0408, 0.1014]])
index
tensor([[0, 1, 2, 0, 0, 1, 1, 2],
[2, 0, 0, 1, 2, 1, 1, 1]])
input_
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]])
ans:
tensor([[0.3304, 0.2996, 0.9913, 0.1929, 0.2400, 0.0000, 0.0000, 0.0000],
[0.0000, 0.5643, 0.0000, 0.4336, 0.0000, 0.8522, 0.0408, 0.1014],
[0.0433, 0.0000, 0.2362, 0.0000, 0.8540, 0.0000, 0.0000, 0.4471]])
'''
比如上述代码,dim=0代表按行赋值,那么index[1][3]=1,代表更改input中的[1]行;另外,index[1][3]对应的src[1][3]的值是0.4336;index[1][3]的[3]列,因此是把0.4336这个数值放入input中的[1][3]的位置。
如果还是不太清楚,我们把dim=1设定为按列
src = torch.rand(batch_size, hidden_size).transpose(0,1)
input_ = torch.zeros(batch_size+1, hidden_size).transpose(0,1)
index = torch.LongTensor([[0,1,2,0,0,1,1,2],[2,0,0,1,2,1,1,1]]).transpose(0,1)
print('src\n',src)
print('index\n',index)
print('input_\n',input_)
# print('ans:\n',input_.scatter_(0, index, src))
print('ans:\n',input_.scatter_(1, index, src))
'''
src
tensor([[0.3504, 0.3369],
[0.1163, 0.3850],
[0.5554, 0.5531],
[0.0440, 0.2904],
[0.2444, 0.6650],
[0.4698, 0.5640],
[0.1331, 0.5830],
[0.0408, 0.8508]])
index
tensor([[0, 2],
[1, 0],
[2, 0],
[0, 1],
[0, 2],
[1, 1],
[1, 1],
[2, 1]])
input_
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])
ans:
tensor([[0.3504, 0.0000, 0.3369],
[0.3850, 0.1163, 0.0000],
[0.5531, 0.0000, 0.5554],
[0.0440, 0.2904, 0.0000],
[0.2444, 0.0000, 0.6650],
[0.0000, 0.5640, 0.0000],
[0.0000, 0.5830, 0.0000],
[0.0000, 0.8508, 0.0408]])
'''
同上,举例: dim=1代表按列赋值, index[4][1]=2,代表行是[4]列是[2],说明是把src[4][1]的值,赋值给input[4][2]
更多推荐

所有评论(0)