乐趣区

关于机器学习:焦点损失函数-Focal-Loss-与-GHM

文章来自公众号【机器学习炼丹术】

1 focal loss 的概述

焦点损失函数 Focal Loss(2017 年何凯明大佬的论文)被提出用于密集物体检测工作。

当然,在指标检测中,可能待检测物体有 1000 个类别,然而你想要辨认进去的物体,只是其中的某一个类别,这样其实就是一个样本十分不平衡的一个分类问题。

而 Focal Loss 简略的说,就是解决样本数量极度不均衡的问题的。

说到样本不均衡的解决方案,相比大家是晓得一个混同矩阵的 f1-score 的,然而这个如同不能用在训练中当成损失。而 Focal loss 能够在训练中,让小数量的指标类别减少权重,让分类谬误的样本减少权重

先来看一下简略的二值穿插熵的损失:

  • y’是模型给出的预测类别概率,y 是实在样本。就是说,如果一个样本的实在类别是 1,预测概率是 0.9,那么 $-log(0.9)$ 就是这个损失。
  • 讲道理,个别我不喜爱用二值穿插熵做例子,用多分类穿插熵做例子会更难受。

【而后看 focal loss 的改良】:

这个减少了一个 $(1-y’)^\gamma$ 的权重值,怎么了解呢?就是如果给出的正确类别的概率越大,那么 $(1-y’)^\gamma$ 就会越小,阐明 分类正确的样本的损失权重小 ,反之, 分类谬误的样本的损权重大


【focal loss 的进一步改良】:

这里减少了一个 $\alpha$,这个 alpha 在论文中给出的是 0.25, 这个就是 单纯的升高正样本或者负样本的权重,来解决样本不平衡的问题

两者联合起来,就是一个能够解决样本不均衡问题的损失 focal loss。


【总结】:

  1. $\alpha$ 解决了样本的不均衡问题;
  2. $\beta$ 解决了难易样本不均衡的问题。让样本更器重难样本,漠视易样本。
  3. 总之,Focal loss 会的关注程序为:样本少的、难分类的;样本多的、难分类的;样本少的,易分类的;样本多的,易分类的。

2 GHM

  • GHM 是 Gradient Harmonizing Mechanism。

这个 GHM 是为了解决 Focal loss 存在的一些问题。

【Focal Loss 的弊病 1】
让模型过多的关注特地难分类的样本是会有问题的。样本中有一些异样点、离群点(outliers)。所以模型为了拟合这些十分难拟合的离群点,就会存在过拟合的危险。

2.1 GHM 的方法

Focal Loss 是从置信度 p 的角度动手衰减 loss 的。而 GHM 是肯定范畴内置信度 p 的样本数量来衰减 loss 的。

首先定义了一个变量 g,叫做 梯度模长(gradient norm)

能够看出这个梯度模长,其实就是模型给出的置信度 $p^*$ 与这个样本实在的标签之间的差值(间隔)。 g 越小,阐明预测越准,阐明样本越容易分类。

下图中展现了 g 与样本数量的关系:

【从图中能够看到】

  • 梯度模长靠近于 0 的样本多,也就是易分类样本是十分多的
  • 而后样本数量随着梯度模长的减少迅速缩小
  • 而后当梯度模长靠近 1 的时候,样本的数量又开始减少。

GHM 是这样想的,对于梯度模长小的易分类样本,咱们漠视他们;然而 focal loss 过于关注难分类样本了。要害是难分类样本其实也有很多!, 如果模型始终学习难分类样本,那么可能模型的精确度就会降落。所以 GHM 对于难分类样本也有一个衰减。

那么,GHM 对易分类样本和难分类样本都衰减,那么真正被关注的样本,就是那些不难不易的样本。而克制的水平,能够依据样本的数量来决定。

这里定义一个GD,梯度密度

$$GD(g)=\frac{1}{l(g)}\sum_{k=1}^N{\delta(g_k,g)}$$

  • $GD(g)$ 是计算在梯度 g 地位的梯度密度;
  • $\delta(g_k,g)$ 就是样本 k 的梯度 $g_k$ 是否在 $[g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]$ 这个区间内。
  • $l(g)$ 就是 $[g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]$ 这个区间的长度,也就是 $\epsilon$

总之,$GD(g)$ 就是梯度模长在 $[g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]$ 内的样本总数除以 $\epsilon$.

而后把每一个样本的穿插熵损失除以他们对应的梯度密度就行了。
$$L_{GHM}=\sum^N_{i=1}{\frac{CE(p_i,p_i^*)}{GD(g_i)}}$$

  • $CE(p_i,p_i^*)$ 示意第 i 个样本的穿插熵损失;
  • $GD(g_i)$ 示意第 i 个样本的梯度密度;

2.2 论文中的 GHM

论文中呢,是把梯度模长划分成了 10 个区域,因为置信度 p 是从 0~1 的,所以梯度密度的区域长度就是 0.1,比方是 0~0.1 为一个区域。

下图是论文中给出的比照图:

【从图中能够失去】

  • 绿色的示意穿插熵损失;
  • 蓝色的是 focal loss 的损失,发现梯度模长小的损失衰减很无效;
  • 红色是 GHM 的穿插熵损失,发现梯度模长在 0 左近和 1 左近存在显著的衰减。

当然能够想到的是,GHM 看起来是须要整个样本的模型估计值,能力计算出梯度密度,能力进行更新。也就是说 mini-batch 看起来仿佛不能用 GHM。

在 GHM 原文中也提到了这个问题,如果光应用 mini-batch 的话,那么很可能呈现不平衡的状况。

【我集体感觉的解决办法】

  1. 能够应用上一个 epoch 的梯度密度,来作为这一个 epoch 来应用;
  2. 或者一开始先应用 mini-batch 计算梯度密度,而后模型收敛速度降落之后,再应用第一种形式进行更新。

3 python 实现

下面讲述的关键在于 focal loss 实现的性能:

  1. 分类正确的样本的损失权重小,分类谬误的样本的损权重大
  2. 样本过多的类别的权重较小

在 CenterNet 中预测中心点地位的时候,也是应用了 Focal Loss,然而稍有改变。

3.1 概述


这外面和下面讲的比拟相似,咱们漠视脚标。

  • 假如 $Y=1$, 那么预测的 $\hat{Y}$ 越凑近 1,阐明预测的约正确,而后 $(1-\hat{Y})^\alpha$ 就会越小,从而体现 分类正确的样本的损失权重小;otherwize 的状况也是这样。
  • 然而这里的 otherwize 中多了一个 $(1-Y)^\beta$, 这个是用来均衡样本不平衡问题的,在前面的代码局部会提到 CenterNet 的热力求。就会明确这个了。

3.2 代码解说

上面通过代码来了解:

class FocalLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.neg_loss = _neg_loss

    def forward(self, output, target, mask):
        output = torch.sigmoid(output)
        loss = self.neg_loss(output, target, mask)
        return loss

这外面的 output 能够了解为是一个 1 通道的特色图,每一个 pixel 的值都是模型给出的置信度,而后通过 sigmoid 函数转换成 0~1 区间的置信度。

而 target 是 CenterNet 的热力求,这一点可能比拟难了解。打个比方,一个 10*10 的全都是 0 的特色图,而后这个特色图中只有一个 pixel 是 1,那么这个 pixel 的地位就是一个指标检测物体的中心点。有几个 1 就阐明这个图中有几个要检测的指标物体。

而后,如果一个特色图上,全都是 0,只有几个孤零零的 1,未免显得过于稠密了,直观上也十分的不平滑。所以 CenterNet 的热力求还须要对这些 1 为核心做一个高斯

能够看作是一种平滑:

能够看到,数字 1 的周围是同样的数字。这是一个以 1 为核心的高斯平滑。


这里咱们回到下面说到的 $(1-Y)^\beta$:

对于数字 1 来说,咱们计算 loss 天然是用第一行来计算,然而对于 1 左近的其余点来说,就要思考 $(1-Y)^\beta$ 了。越凑近 1 的点的 $Y$ 越大,那么 $(1-Y)^\beta$ 就会越小,这样从而升高 1 左近的权重值。其实这里我也讲不太明确,就是依据间隔 1 的间隔升高负样本的权重值,从而能够实现 样本过多的类别的权重较小


咱们回到主题,对 output 进行 sigmoid 之后,与 output 一起放到了 neg_loss 中。咱们来看什么是 neg_loss:

def _neg_loss(pred, gt, mask):
    pos_inds = gt.eq(1).float() * mask
    neg_inds = gt.lt(1).float() * mask

    neg_weights = torch.pow(1 - gt, 4)

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * \
               neg_weights * neg_inds

    num_pos = pos_inds.float().sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos
    return loss

先说一下,这外面的 mask 是依据特定工作中加上的一个小性能,就是在该工作中,一张图片中有一部分是不须要计算 loss 的,所以先用过 mask 把那个局部过滤掉。这里间接漠视 mask 就好了。

neg_weights = torch.pow(1 - gt, 4) 能够得悉 $\beta=4$, 从上面的代码中也不难推出,$\alpha=2$,剩下的内容就都一样了。

把每一个 pixel 的损失都加起来,除以指标物体的数量即可。

退出移动版