关于深度学习:为什么基于树的模型在表格数据上仍然优于深度学习

3次阅读

共计 1967 个字符,预计需要花费 5 分钟才能阅读完成。

在这篇文章中,我将具体解释这篇论文《Why do tree-based models still outperform deep learning on tabular data》这篇论文解释了一个被世界各地的机器学习从业者在各种畛域察看到的景象——基于树的模型在剖析表格数据方面比深度学习 / 神经网络好得多。

论文的注意事项

这篇论文进行了大量的预处理。例如像删除失落的数据会妨碍树的性能,然而随机森林非常适合短少数据的状况,如果你的数据十分芜杂:蕴含大量的特色和维度。RF 的鲁棒性和长处使其优于更“先进”的解决方案,因为后者很容易呈现问题。

其余的大部分工作都很规范。我集体不太喜爱利用太多的预处理技术,因为这可能会导致失去数据集的许多细微差别,但论文中所采取的步骤基本上会产生雷同的数据集。然而须要阐明的是,在评估最终后果时要应用雷同的解决办法。

论文还应用随机搜寻来进行超参数调优。这也是行业标准,但依据我的教训,贝叶斯搜寻更适宜在更宽泛的搜寻空间中进行搜寻。

理解了这些就能够深刻咱们的次要问题了——为什么基于树的办法胜过深度学习?

1、神经网络偏差过于平滑的解决方案

这是作者分享深度学习神经网络无奈与随机森林竞争的第一个起因。简而言之,当波及到非平滑函数 / 决策边界时,神经网络很难创立最适宜的函数。随机森林在怪异 / 锯齿 / 不规则模式下做得更好。

如果我来猜想起因的话,可能是在神经网络中应用了梯度,而梯度依赖于可微的搜寻空间,依据定义这些空间是平滑的,所以无奈辨别尖利点和一些随机函数。所以我举荐学习诸如进化算法、传统搜寻等更根本的概念等 AI 概念,因为这些概念能够在 NN 失败时的各种状况下获得很好的后果。

无关基于树的办法(RandomForests)和深度学习者之间决策边界差别的更具体示例,请查看下图 –

在附录中,作者对上述可视化进行了上面阐明:

在这一部分中,咱们能够看到 RandomForest 可能学习 MLP 无奈学习的 x 轴(对应日期特色)上的不规则模式。咱们展现了默认超参数的这种差别,这是神经网络的典型行为,然而实际上很难(只管并非不可能)找到胜利学习这些模式的超参数。

2、无信息个性会影响相似 mlp 的神经网络

另一个重要因素,特地是对于那些同时编码多个关系的大型数据集的状况。如果向神经网络输出不相干的特色后果会很蹩脚(而且你会节约更多的资源训练你的模型)。这就是为什么花大量工夫在 EDA/ 畛域摸索上是如此重要。这将有助于了解个性,并确保一切顺利运行。

论文的作者测试了模型在增加随机和删除无用个性时的性能。基于他们的后果,发现了 2 个很乏味的后果

  1. 删除大量个性缩小了模型之间的性能差距。这分明地表明,树型模型的一大劣势是它们可能判断特色是否有用并且可能防止无用特色的影响。
  2. 与基于树的办法相比,向数据集增加随机特色表明神经网络的消退要重大得多。ResNet 尤其受到这些无用个性的影响。transformer 的晋升可能是因为其中的注意力机制在肯定水平上会有一些帮忙。

对这种景象的一种可能解释是决策树的设计形式。任何学习过 AI 课程的人都会晓得决策树中的信息增益和熵的概念。这使得决策树可能通过比拟剩下的个性来抉择最佳的门路。

回到正题,在表格数据方面,还有最初一件事使 RF 比 NN 体现更好。那就是旋转不变性。

3、NNs 是旋转不变性的,然而理论数据却不是

神经网络是旋转不变的。这意味着如果对数据集进行旋转操作,它不会扭转它们的性能。旋转数据集后,不同模型的性能和排名产生了很大的变动,尽管 ResNets 始终是最差的,然而旋转后放弃原来的体现,而所有其余模型的变动却很大。

这很景象十分乏味:旋转数据集到底意味着什么? 整个论文中也没有具体的细节阐明(我曾经分割了作者,并将持续跟进这个景象)。如果有任何想法,也请在评论中分享。

然而这个操作让咱们看到为什么旋转方差很重要。依据作者的说法,采纳特色的线性组合 (这就是使 ResNets 不变的起因) 实际上可能会谬误地示意特色及其关系。

通过对原始数据的编码获得最佳的数据偏差,这些最佳的偏差可能会混合具备十分不同的统计个性的特色并且不能通过旋转不变的模型来复原,会为模型提供更好的性能。

总结

这是一篇十分乏味的论文,尽管深度学习在文本和图像数据集上获得了巨大进步,但它在表格数据上的根本没有劣势可言。论文应用了 45 个来自不同畛域的数据集进行测试,结果表明即便不思考其卓越的速度,基于树的模型在中等数据(~10K 样本)上依然是最先进的,如果你对表格数据感兴趣,倡议间接浏览:

Why do tree-based models still outperform deep learning on tabular data

https://avoid.overfit.cn/post/e4682d6810d7427caf9aae6f6d1f3734

作者:Devansh

正文完
 0