PyTorch之对类别张量进行one-hot编码

本文已受权极市平台, 并首发于极市平台公众号. 未经容许不得二次转载.
  • 原始文档:https://www.yuque.com/lart/ugkv9f/src5w8
  • 代码仓库:https://github.com/lartpang/CodeForArticle/tree/main/OneHotEncoding.PyTorch

前言

one-hot 模式的编码在深度学习工作中十分常见,然而却并不是一种很天然的数据存储形式。所以大多数状况下都须要咱们本人手动转换。尽管思路很间接,就是将类别拆分成一一对应的 0-1 向量,然而具体实现起来的确还是须要思考下的。实际上 pytorch 本身在nn.functional中曾经提供了one_hot办法来疾速利用。然而这并不能影响咱们的思考与实际:>!所以本文尽可能将基于 pytorch 中罕用办法来实现one-hot编码的形式整顿了下,心愿有用。

次要的形式有这么几种:

  • for循环
  • scatter
  • index_select

for循环

这种办法十分直观,说白了就是对一个空白(全零)张量中的指定地位进行赋值(赋 1)操作即可。
关键在于如何设定索引。
上面设计了两种实质雷同但因为指定维度不同而导致些许差别的计划。

def bhw_to_onehot_by_for(bhw_tensor: torch.Tensor, num_classes: int):    """    Args:        bhw_tensor: b,h,w        num_classes:    Returns: b,h,w,num_classes    """    assert bhw_tensor.ndim == 3, bhw_tensor.shape    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)    one_hot = bhw_tensor.new_zeros(size=(num_classes, *bhw_tensor.shape))    for i in range(num_classes):        one_hot[i, bhw_tensor == i] = 1    one_hot = one_hot.permute(1, 2, 3, 0)    return one_hotdef bhw_to_onehot_by_for_V1(bhw_tensor: torch.Tensor, num_classes: int):    """    Args:        bhw_tensor: b,h,w        num_classes:    Returns: b,h,w,num_classes    """    assert bhw_tensor.ndim == 3, bhw_tensor.shape    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)    one_hot = bhw_tensor.new_zeros(size=(*bhw_tensor.shape, num_classes))    for i in range(num_classes):        one_hot[..., i][bhw_tensor == i] = 1    return one_hot

scatter

该办法应该是网上大多数简洁的one_hot写法的罕用模式了。其实际上次要的作用是向 tensor 中指定的地位上赋值。

因为其能够应用专门结构的索引矩阵来作为索引,所以更加灵便。当然,灵便带来的也就是了解上的艰难。官网文档中提供的解释十分直观:

'''https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html* (int dim, Tensor index, Tensor src) * (int dim, Tensor index, Tensor src, *, str reduce) * (int dim, Tensor index, Number value) * (int dim, Tensor index, Number value, *, str reduce)'''self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

文档中应用的是原地置换(in-place)版本,并且基于替换值为src,即 tensor 的状况下来解释。实际上在咱们的利用中次要基于原地置换版本并搭配替换值为标量浮点数value的模式。

上述的模式中,咱们能够看到,通过指定参数 tensor index,咱们就能够将src(i,j,k)的值搁置到办法调用者(这里是self)的指定地位上。该指定地位由index(i,j,k)处的值替换坐标(i,j,k)中的dim地位的值来形成(这里也反映进去了index tensor 的一个要求,就是维度数量要和selfsrc(如果src为 tensor 的话。后文中应用的是具体的标量值 1,即src替换为value)统一)。这倒是和one-hot的概念十分吻合。因为one-hot自身模式上的含意就是对于第i类数据,第i个地位为 1,其余地位为 0。所以对全零 tensor 应用scatter_是能够非常容易的结构出one-hottensor 的,即对对应于类别编号的地位搁置 1 即可。

对于咱们的问题而言,index非常适合应用输出的蕴含类别编号的 tensor(形态为B,H,W)来示意。基于这样的思考,能够构思出两种不同的策略:

def bhw_to_onehot_by_scatter(bhw_tensor: torch.Tensor, num_classes: int):    """    Args:        bhw_tensor: b,h,w        num_classes:    Returns: b,h,w,num_classes    """    assert bhw_tensor.ndim == 3, bhw_tensor.shape    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)    one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes))    one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-1, 1), value=1)    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)    return one_hotdef bhw_to_onehot_by_scatter_V1(bhw_tensor: torch.Tensor, num_classes: int):    """    Args:        bhw_tensor: b,h,w        num_classes:    Returns: b,h,w,num_classes    """    assert bhw_tensor.ndim == 3, bhw_tensor.shape    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)    one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes))    one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)    return one_hot

这两种模式的差别的本源在于对形态的解决上。由此带来了scatter不同的利用模式。

对于第一种模式,将B,H,W三个维度合并,这样的益处是对通道(类别)的索引的了解变得直观起来。

    one_hot = torch.zeros(size=(math.prod(bhw_tensor.shape), num_classes))    one_hot.scatter_(dim=1, index=bhw_tensor.reshape(-1, 1), value=1)

这里将类别维度和其余维度间接拆散,移到了末位。通过dim指定该维度,于是就有了这样的对应关系:

zero_tensor[abc, index[abc][d]] = value  # d=0

而在第二种状况下依然保留了后面的三个维度,类别维度仍然挪动到最初一位。

    one_hot = torch.zeros(size=(*bhw_tensor.shape, num_classes))    one_hot.scatter_(dim=-1, index=bhw_tensor[..., None], value=1)

此时的对应关系是这样的:

zero_tensor[a,b,c, index[a][b][c][d]] = value # d=0

另外在 pytorch 分类模型库 timm 中,也应用了相似的办法:

# https://github.com/rwightman/pytorch-image-models/blob/2c33ca6d8ce5d9257edf8cab5ab7ece81780aaf7/timm/data/mixup.py#L17-L19def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):    x = x.long().view(-1, 1)    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)

index_select

torch.index_select(input, dim, index, *, out=None) → Tensor- input (Tensor) – the input tensor.- dim (int) – the dimension in which we index- index (IntTensor or LongTensor) – the 1-D tensor containing the indices to index

该函数如其名,就是用索引来抉择 tensor 的指定维度的子 tensor 的。

想要了解这一办法的动机,实际上须要反过来,从类别标签的角度对待one-hot编码。

对于原始从小到大排布的类别序号对应的one-hot编码成的矩阵就是一个单位矩阵。所以每个类别对应的就是该单位矩阵的特定的列(或者行)。这一需要恰好合乎index_select的性能。所以咱们能够应用其实现one_hot编码,只须要应用类别序号索引特定的列或者行即可。上面就是一个例子:

def bhw_to_onehot_by_index_select(bhw_tensor: torch.Tensor, num_classes: int):    """    Args:        bhw_tensor: b,h,w        num_classes:    Returns: b,h,w,num_classes    """    assert bhw_tensor.ndim == 3, bhw_tensor.shape    assert num_classes > bhw_tensor.max(), torch.unique(bhw_tensor)    one_hot = torch.eye(num_classes).index_select(dim=0, index=bhw_tensor.reshape(-1))    one_hot = one_hot.reshape(*bhw_tensor.shape, num_classes)    return one_hot

性能比照

整体代码可见GitHub。

上面展现了不同办法的大抵的绝对性能(因为后盾在跑程序,可能并不是非常精确,倡议大家自行测试)。能够看到,pytorch 自带的函数在 CPU 上效率并不是很高,然而在 GPU 上体现良好。其中乏味的是,基于index_select的模式体现十分亮眼。

1.10.0 GeForce RTX 2080 Ticpu('bhw_to_onehot_by_for', 0.5411529541015625)('bhw_to_onehot_by_for_V1', 0.4515676498413086)('bhw_to_onehot_by_scatter', 0.0686192512512207)('bhw_to_onehot_by_scatter_V1', 0.08529376983642578)('bhw_to_onehot_by_index_select', 0.05156970024108887)('F.one_hot', 0.07366824150085449)gpu('bhw_to_onehot_by_for', 0.005235433578491211)('bhw_to_onehot_by_for_V1', 0.045584678649902344)('bhw_to_onehot_by_scatter', 0.0025513172149658203)('bhw_to_onehot_by_scatter_V1', 0.0024869441986083984)('bhw_to_onehot_by_index_select', 0.002012014389038086)('F.one_hot', 0.0024051666259765625)