DataWhale街景字符编码识别项目模型集成

44次阅读

共计 1885 个字符,预计需要花费 5 分钟才能阅读完成。

对于 Baseline 的一点总结

经过了数据准备、数据增强、模型构建以及模型验证和迭代 4 个阶段,基本上得到了一个还不错的模型。

使用改进后的模型得到测试结果,去网站提交,结果只有 0.625 的准确率。发现测试集上的效果和验证集上的效果差距还挺大的。

很容易想到的是测试集和验证集会不会存在分布上的差异。如果是,有没有办法来解决这种差异。

现在的问题是如何提高网络的泛化性能。在加了一些 tricks 之后,训练集的精度已经可以达到 99%,但验证集的精度基本上稳定在 71% 左右,无法再继续提升。当模型没法进一步改进时,就应该回到数据本身,可以检查一下误分类的样本,对其进行一些统计。针对数据进行一些改进。

对于常规的图像分类网络,输入的图像尺寸对于网络最终的结果不会有太大的影响。而这个任务中,需要同时对图片中的所有数字进行分类,并且强制网络的输出分别去识别对应相应的数字。那么其实网络需要去自动的学习,哪个位置对应哪一个字符。因此该分类网络其实是有强烈的位置敏感的。

为了尝试解决这个问题,使用了 YOLO V3 中的一个 trick,每 10 个 epoch,改变一次图像大小,在这里,我只改变了图像的宽度 (因为文字主要是横向分布的)。重新训练了 resnet18 和 mobilenet v2。

模型集成介绍

模型集成是一种有效的涨分手段,主要包括 bagging、stacking 和 boosting3 类。

bagging:该方法通常考虑的是同质弱学习器,相互独立地并行学习这些弱学习器,并按照某种确定性的平均过程将它们组合起来。

boosting,该方法通常考虑的也是同质弱学习器。它以一种高度自适应的方法顺序地学习这些弱学习器(每个基础模型都依赖于前面的模型),并按照某种确定性的策略将它们组合起来。

stacking,该方法通常考虑的是异质弱学习器,并行地学习它们,并通过训练一个「元模型」将它们组合起来,根据不同弱模型的预测结果输出一个最终的预测结果。

视觉任务中常用的集成方法

Dropout

Dropout 也叫丢弃法,即在训练过程中,随机的丢弃一些神经元,相当于每一次都在 train 不同的模型。它同样可以看做是一种集成学习方式。

SnapShot

SnapShot 是一种不需要增加额外的训练成本就能获取多个网络模型的方式。它通过保存训练过程中,学习率处于最低点位置时的网络权重,最后会得到对应不同局部极小值的网络。

snapshot 需要结合学习率周期调整策略,下图为学习率余弦衰减策略,snapShot 会在每个学习率调整的周期结束位置,拍快照,也就是保存当前模型权重。

上图中的学习率的最小值为 5e-4,使用在这样的极端值(0.1 至 5e-4,M 倍)之间波动的学习速率计划的背后理论是,训练模型时存在多个局部极小值。不断降低学习率会迫使模型停留在低于最佳的局部最小值位置。因此,我们使用很高的学习率来逃避当前的局部最小值,并尝试找到另一个可能更好的局部最小值。

最后对保存的多个模型进行 voting 或者 staking。

TTA

TTA(test time augmentation) 是一种后处理方式,测试阶段,对每张图片进行数据增强 n 次,分别得到 n 个结果,最后进行 vote 得到最终的预测结果。
其思路跟集成学习的思路基本一致,同样可以看做是一种集成方法。

总结

  1. 模型只能决定精度的下限,数据才能决定精度的上限。在模型训练前,需要对数据有一定的认识,则可以帮助我们选择合适的模型以及合适的参数;模型的进一步提升遇到瓶颈时,我们还是需要回到数据的本身,分析误分类样本,看是否可以发现一些规律。
  2. 设置随机种子非常重要,我们需要保证结果的可复现。因为网络模型中的随机性太多,每一次训练都会产生完全不一样的模型。
  3. 图形的可视化或者良好的 log 可以很好的帮助到我们的结果分析过程,推荐使用 tensorboard。另外,每一次模型保存都额外的将模型的配置参数同样保存,特别是在模型迭代了很多次之后,保存的模型很多,自己都不一定能分辨清楚。
  4. 调参过程中,控制变量特别重要,每次只改动一个变量。每一次的调整,都能手动的记录下来或者以 log 的形式保存。
  5. 在硬件有限的情况下,推荐使用 Snapshot 和 TTA 模型集成方式。

最终通过训练的 3 个模型进行 stacking,达到了 0.91 的 acc。

这次比赛虽然整体难度不大,但想要取得比较好的成绩还是挺难的,能得到这样的成绩还是有点意外的。但抛开成绩,收获更多的可能是在模型的迭代过程中,对于调参的一些理解和收获以及对数据的认识。

Reference

[1] 常用的模型集成方法介绍:bagging、boosting、stacking

正文完
 0