乐趣区

关于算法:GIT斯坦福大学提出应对复杂变换的不变性提升方法-ICLR-2022

论文对长尾数据集中的简单变换不变性进行了钻研,发现不变性在很大水平上取决于类别的图片数量,实际上分类器并不能将从大类中学习到的不变性转移到小类中。为此,论文提出了 GIT 生成模型,从数据集中学习到类无关的简单变换,从而在训练时对小类进行无效加强,整体成果不错

起源:晓飞的算法工程笔记 公众号

论文: Do Deep Networks Transfer Invariances Across Classes?

  • 论文地址:https://arxiv.org/abs/2203.09739
  • 论文代码:https://github.com/AllanYangZhou/generative-invariance-transfer

Introduction


  优良的泛化能力须要模型具备疏忽不相干细节的能力,比方分类器应该对图像的指标是猫还是狗进行响应,而不是背景或光照条件。换句话说,泛化能力须要蕴含对简单但不影响预测后果的变换的不变性。在给定足够多的不同图片的状况下,比方训练数据集蕴含在大量不同背景下的猫和狗的图像,深度神经网络确实能够学习到不变性。但如果狗类的所有训练图片都是草地背景,那分类器很可能会误判房子背景中的狗为猫,这种状况往往就是不均衡数据集存在的问题。
  类不均衡在实践中很常见,许多事实世界的数据集遵循长尾散布,除几个头部类有很多图片外,而其余的每个尾部类都有很少的图片。因而,即便长尾数据集中图片总量很大,分类器也可能难以学习尾部类的不变性。尽管罕用的数据加强能够通过减少尾部类中的图片数量和多样性来解决这个问题,但这种策略并不能用于模拟简单变换,如更换图片背景。须要留神的是,像照明变动之类的许多简单变换是类别无关的,可能相似地利用于任何类别的图片。现实状况下,经过训练的模型应该可能主动将这些不变性转为类无关的不变性,兼容尾部类的预测。
  论文通过试验察看分类器跨类迁徙学习到的不变性的能力,从后果中发现即便通过过采样等均衡策略后,神经网络在不同类别之间传递学习到的不变性也很差。例如,在一个长尾数据集上,每个图片都是随机平均旋转的,分类器往往对来自头部类的图片放弃旋转不变,而对来自尾部类的图片则不放弃旋转不变。
  为此,论文提出了一种更无效地跨类传递不变性的简略办法。首先训练一个 input conditioned 但与类无关的生成模型,该模型用于捕捉数据集的简单变换,暗藏了类信息以便鼓励类之间的变换转移。而后应用这个生成模型来转换训练输出,相似于学习数据加强来训练分类器。论文通过试验证实,因为尾部类的不变性失去显著晋升,整体分类器对简单变换更具不变性,从而有更好的测试准确率。

Measuring Invariance Transfer In Class-Imbalanced Datasets


  论文先对不均衡场景中的不变性进行介绍,随后定义一个用于度量不变性的指标,最初再剖析不变性与类别大小之间的关系。

Setup:Classification,Imbalance,and Invariances

  定义输出 $(x,y)$,标签 $y$ 属于 $\{1,\cdots,C\}$,$C$ 为类别数。定义训练后的模型的权值 $w$,用于预测条件概率 $\tilde{P}_w(y=j|x)$,分类器将抉择概率最大的类别 $j$ 作为输入。给定训练集 $\{(x^{(i)}, y^{(i)})\}^N_{i=1}\sim \mathbb{P}_{train}$,通过教训危险最小化(ERM)来最小化训练样本的均匀损失。但在不均衡场景下,因为 $\{y^{(i)}\}$ 的散布不是平均的,导致 ERM 在多数类别上体现不佳。
  在事实场景中,最现实的是模型在所有类别上都体现得不错。为此,论文采纳类别均衡的指标来评估分类器,相当于测试散布 $\mathbb{P}_{test}$ 在 $y$ 上是平均的。
  为了剖析不变性,论文假如 $x$ 的简单变换散布为 $T(\cdot|x)$。对于不影响标签的简单变换,论文心愿分类器是不变的,即预测的概率不会扭转:

Measuring Learned Invariacnes

  为了度量分类器学习不变性的水平,论文定义了原输出和变换输出之间的冀望 KL 散度(eKLD):

  这是一个非正数,eKLD 越低代表不变性水平就越高,对 $T$ 齐全不变的分类器的 eKLD 为 0。如果有方法采样 $x^{‘}\sim T(\cdot|x)$,就能计算训练后的分类器的 eKLD。此外,为了钻研不变性与类图片数量的关系,能够通过别离计算类特定的 eKLD 进行剖析,行将公式 2 的 $x$ 限定为类别 $j$ 所属。
  计算 eKLD 的难点在于简单变动散布 $T$ 的获取。对于大多数事实世界的数据集而言,其简单变动散布是不可知的。为此,论文通过选定简单散布来生成数据集,如 RotMNIST 数据集。与数据加强不同,这种生成形式是通过变换对数据集进行裁减,而不是在训练过程对同一图片利用多个随机采样的变换。
  论文以 Kuzushiji-49 作为根底,用三种不同的简单变换生成了三个不同的数据集:图片旋转(K49-ROT-LT)、不同背景强度(K49-BG-LT)和图像收缩或侵蚀(K49-DIL-LT)。为了使数据集具备长尾散布(LT),先从大到小随机抉择类别,而后有选择地缩小类别的图片数直到数量散布合乎参数为 2.0 的 Zipf 定律,同时强制起码的类为 5 张图片。反复以上操作 30 次,结构 30 个不同的长尾数据集。每个长尾数据集有 7864 张图片,最多的类有 4828 张图片,最小的类有 5 张图片,而测试集则放弃原先的不变。

  训练方面,采纳规范 ERM 和 CE+DRS 两种办法,其中 CE+DRS 基于穿插熵损失进行提早的类均衡重采样。DRS 在开始阶段跟 ERM 一样随机采样,随后再切换为类均衡采样进行训练。论文为每个训练集进行两种分类器的训练,随后计算每个分类器每个类别的 eKLD 指标。后果如图 1 所示,能够看到两个景象:

  • 在不同变动数据集上,不变性随着类图片数缩小都升高了。这表明尽管简单变换是类无关的,但在不均衡数据集上,模型无奈在类之间传递学习到的不变性。
  • 对于图片数量雷同的类,应用 CE+DRS 训练的分类器往往会有较低的 eKLD,即更好的不变性。但从曲线上看,DRS 仍有较大的晋升空间,还没达到类别之间统一的不变性。

Trasnferring Invariances with Generative Models


  从后面的剖析能够看到,长尾数据集的尾部类对简单变换的不变性较差。上面将介绍如何通过生成式不变性变换 (GIT) 来显式学习数据集中的简单变换散布 $T(\cdot|x)$,进而在类间转移不变性。

Learning Nuisance Transformations from Data

  如果有数据集理论相干的简单变换的办法,能够间接将其用作数据加强来增强所有类的不变性,但在实践中很少呈现这种状况。于是论文提出 GIT,通过训练 input conditioned 的生成模型 $\tilde{T}(\cdot|x)$ 来近似实在的简单变换散布 $T(\cdot|x)$。

  论文参考了多模态图像转换模型 MUNIT 来结构生成模型,该类模型可能从数据中学习到多种简单变换,而后对输出进行变换生成不同的输入。论文对 MUNIT 进行了大量批改,使其可能学习单数据集图片之间的变换,而不是两个不同域数据集之间的变换。从图 2 的生成后果来看,生成模型可能很好地捕获数据集中的简单变换,即便是尾部类也有不错的成果。须要留神的是,MUNIT 是非必须的,也能够尝试其它可能更好的办法。
  在训练好生成模型后,应用 GIT 作为实在简单变换的代理来为分类器进行数据加强,心愿可能进步尾部类对简单变换的不变性。给定训练输出 $\{(x^{(i)}, y^{(i)})\}^{|B|}_{i=1}$,变换输出 $\tilde{x}^{(i)}\gets \tilde{T}(\cdot|x^{(i)})$,放弃标签不变。这样的变换可能进步分类器在训练期间的输出多样性,特地是对于尾部类。须要留神的是,batch 能够搭配任意的采样办法(Batch Sampler),比方类均衡采样器。此外,还能够有选择地进行加强,防止因为生成模型的缺点侵害性能的可能性,比方对数量足够且不变性曾经很好的头部类不进行加强。

  在训练中,论文设置阈值 $K$,仅图片数量少于 $K$ 的类进行数据加强。此外,仅对每个 batch 的 $p$ 比例进行加强。$p$ 个别取 0.5,而 $K$ 依据数据集能够设为 20-500,整体逻辑如算法 1 所示。

GIT Improves Invariance on Smaller Classes

  论文基于算法 1 进行了试验,将 Batch Sampler 设为提早重采样(DRS),Update Classifier 应用穿插熵梯度更新,整体模型标记为 $CE+DRS+GIT(all classes)$。all classes 示意禁用阈值 $K$,仅对 K49 数据集应用。作为比照,Oracle 则是用于结构生成数据集的实在变换。从图 3 的比照后果能够看到,GIT 可能无效地加强尾部类的不变性,但同时也侵害了图片富余的头部类的不变性,这表明了阈值 $K$ 的必要性。

Experiment


  不同训练策略搭配 GIT 的成果比照。

  在 GTSRB 和 CIFAR 数据集上的变换输入。

  CIFAR-10 上每个类的准确率。

  比照试验,包含阈值 $K$ 对性能的影响,GTSRB-LT, CIFAR-10 LT 和 CIFAR-100 LT 别离取 25、500 和 100。这里的最好性能貌似都比 RandAugment 差点,有可能是因为论文还没对试验进行调参,而是间接复用了 RandAugment 的试验参数。这里比拟好奇的是,如果在训练生成模型的时候加上 RandAugment,说不定性能会更好。

Conclusion


  论文对长尾数据集中的简单变换不变性进行了钻研,发现不变性在很大水平上取决于类别的图片数量,实际上分类器并不能将从大类中学习到的不变性转移到小类中。为此,论文提出了 GIT 生成模型,从数据集中学习到类无关的简单变换,从而在训练时对小类进行无效加强,整体成果不错。



如果本文对你有帮忙,简单点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

退出移动版