关于人工智能:论文回顾Batch-Augmentation在批次中进行数据扩充可以减少训练时间并提高泛化能力

32次阅读

共计 1281 个字符,预计需要花费 4 分钟才能阅读完成。

Batch Augmentation(BA):提出应用不同的数据加强在同一批次中复制样本实例,通过批次内的加强在达到雷同准确性的前提下缩小了 SGD 更新次数,还能够进步泛化能力。

Batch Augmentation (BA)

没有 BA 的一般 SGD:

一个具备损失函数 ℓ (w, xn, yn) 的模型,{xn, yn} 示意指标对的数据集,n 从 1 到 N(是 N 个数据样本),其中 xn ∈ X 和 T:X → X 是利用于每个示例的一些数据加强变换,例如,图像的随机裁剪。每个批次的通用训练过程包含以下更新规定(为简略起见,这里应用具备学习率 η 和批次大小 B 的 一般 SGD):

其中 k (t) 是从 [N / B] = {1,…, N / B} 中采样的,B (t) 是批次 t 中的样本集。

SGD 和 BA:

BA 倡议通过利用变换 Ti 来引入同一输出样本的 M 个多个实例,这里用下标 i ∈ [M],以示意每个变换的差别。这样学习规定则变为如下公式:

其中 M·B 是由 B 个样本通过 M 个不同的变换进行裁减并进行合并后的一个批次数据,反向流传更新的规定能够通过评估整个 M·B 批次或通过累积原始梯度计算的 M 个实例来计算。应用大批量更新作为批量裁减的一部分不会扭转每个 epoch 执行的 SGD 迭代次数。

BA 也可用于在中间层上进行转换。例如,能够应用常见的 Dropout 在给定层中生成同一样本的多个实例。带有 Dropout 的 BA 能够利用于语言工作或机器翻译工作。

试验后果

上图显示了改良后的验证收敛速度(以 epoch 计),最终验证分类谬误明显降低。随着 M 的减少,这一趋势在很大水平上持续改善,与论文的预期统一。

在试验中,ResNet44 with Cutout 在 Cifar10 上进行训练。ResNet44 仅在 23 个 epoch 中就达到了 94.15% 的准确率,而 baseline 为 93.07%,并且迭代次数超过了四倍(100 个 epoch)。对于 M = 12 的 AmoebaNet,在 14 个 epoch 后达到 94.46% 的验证准确率,而无需应用任何的 LR 调整策略。

Cifar、ImageNet 模型的验证准确度 (Top1) 后果、测试性能后果和 Penn-Tree-Bank (PTB) 和 WMT 数据集上的 BLEU 分数。

图中的两个基线计划:

(1)“Fixed #Steps”– 与 BA 具备雷同训练的原始计划

(2)“Fixed #Samples”– BA 雷同数量的样本(应用 M·B 批大小)。

PTB 和 WMT En-De 为应用 Dropout 的 BA 利用于语言和机器翻译工作,从图上能够看到在 CIFAR、ImageNet、PTB 和 WMT En-De 上应用 BA 都能够进步性能。通过比拟“Fixed #Steps”和“Fixed #Samples”,BA 减少批次中的样本对于进步性能至关重要

论文地址:

[2020 CVPR] [Batch Augment, BA]Augment Your Batch: Improving Generalization Through Instance Repetition

https://www.overfit.cn/post/8c40c9c388664099af15cfe57cd9e0ba

作者:Sik-Ho Tsang

正文完
 0