在某些状况下,咱们须要用 Pytorch 做一些高级的索引 / 抉择,所以在这篇文章中,咱们将介绍这类工作的三种最常见的办法:torch.index_select, torch.gather and torch.take
咱们首先从一个 2D 示例开始,并将抉择后果可视化,而后延申到 3D 和更简单场景。最初以表格的模式总结了这些函数及其区别。
torch.index_select
torch.index_select
是 PyTorch 中用于按索引抉择张量元素的函数。它的作用是从输出张量中依照给定的索引值,选取对应的元素造成一个新的张量。它沿着一个维度抉择元素,同时放弃其余维度不变。也就是说: 保留所有其余维度的元素,但在索引张量之后的指标维度中抉择元素。
num_picks = 2
values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# [len_dim_0, num_picks]
picked = torch.index_select(values, 1, indices)
下面代码将失去的张量形态为[len_dim_0, num_picks]: 对于沿维度 0 的每个元素,咱们从维度 1 中抉择了雷同的元素。
当初咱们应用 3D 张量,一个形态为 [batch_size, num_elements, num_features] 的张量: 这样咱们就有了 num_elements 元素和 num_feature 特色,并且是一个批次进行解决的。咱们为每个批处理 / 个性组合抉择雷同的元素:
import torch
batch_size = 16
num_elements = 64
num_features = 1024
num_picks = 2
values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, num_elements, size=(num_picks,))
# [batch_size, num_picks, num_features]
picked = torch.index_select(values, 1, indices)
上面是如何应用简略的 for 循环从新实现这个函数的办法:
picked_manual = torch.zeros_like(picked)
for i in range(batch_size):
for j in range(num_picks):
for k in range(num_features):
picked_manual[i, j, k] = values[i, indices[j], k]
assert torch.all(torch.eq(picked, picked_manual))
这样比照能够对 index_select 有一个更深刻的理解
torch.gather
torch.gather
是 PyTorch 中用于依照指定索引从输出张量中收集值的函数。它容许你依据指定的索引从输出张量中取出对应地位的元素,并组成一个新的张量。它的行为相似于 index_select,然而当初所需维度中的元素抉择依赖于其余维度——也就是说对于每个批次索引,对于每个特色,咱们能够从“元素”维度中抉择不同的元素——咱们将从一个张量作为另一个张量的索引。
num_picks = 2
values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(len_dim_0, num_picks))
# [len_dim_0, num_picks]
picked = torch.gather(values, 1, indices)
当初的抉择不再以直线为特色,而是对于沿着维度 0 的每个索引,在维度 1 中抉择一个不同的元素:
咱们持续扩大为 3D 的张量,并展现 Python 代码来从新实现这个抉择:
import torch
batch_size = 16
num_elements = 64
num_features = 1024
num_picks = 5
values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, num_elements, size=(batch_size, num_picks, num_features))
picked = torch.gather(values, 1, indices)
picked_manual = torch.zeros_like(picked)
for i in range(batch_size):
for j in range(num_picks):
for k in range(num_features):
picked_manual[i, j, k] = values[i, indices[i, j, k], k]
assert torch.all(torch.eq(picked, picked_manual))
torch.gather
是一个灵便且弱小的函数,能够在许多状况下用于数据收集和操作,尤其在须要依照指定索引收集数据的状况下十分有用。
torch.take
torch.take
是 PyTorch 中用于从输出张量中依照给定索引取值的函数。它相似于
torch.index_select
和
torch.gather
,然而更简略,只须要一个索引张量即可。它实质上是将输出张量视为扁平的,而后从这个列表中抉择元素。例如: 当对形态为 [4,5] 的输出张量利用 take,并抉择指标 6 和 19 时,咱们将取得扁平张量的第 6 和第 19 个元素——即来自第 2 行的第 2 个元素,以及最初一个元素。
num_picks = 2
values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_0 * len_dim_1, size=(num_picks,))
# [num_picks]
picked = torch.take(values, indices)
咱们当初只失去两个元素:
3D 张量也是一样的这里索引张量能够是任意形态的,只有最大索引不超过张量的总数即可:
import torch
batch_size = 16
num_elements = 64
num_features = 1024
num_picks = (2, 5, 3)
values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, batch_size * num_elements * num_features, size=num_picks)
# [2, 5, 3]
picked = torch.take(values, indices)
picked_manual = torch.zeros(num_picks)
for i in range(num_picks[0]):
for j in range(num_picks[1]):
for k in range(num_picks[2]):
picked_manual[i, j, k] = values.flatten()[indices[i, j, k]]
assert torch.all(torch.eq(picked, picked_manual))
总结
为了总结这篇文章,咱们在一个表格中总结了这些函数之间的区别——蕴含简短的形容和示例形态。样本形态是针对后面提到的 3D ML 示例量身定制的,并将列出索引张量的必要形态,以及由此产生的输入形态:
当你想要从一个张量中依照索引选取子集时能够应用
torch.index_select
,它通常用于在给定维度上抉择元素。实用于较为简单的索引选取操作。
torch.gather
实用于依据索引从输出张量中收集元素并造成新张量的状况。能够依据须要在不同维度上进行收集操作。
torch.take
实用于一维索引,从输出张量中取出对应索引地位的元素。当只须要依照一维索引取值时,十分不便。
https://avoid.overfit.cn/post/e4844e899c4d4600813be7d09e91b9ef
作者:Oliver S