乐趣区

关于深度学习:一分钟带你认识深度学习中的知识蒸馏

摘要: 常识蒸馏(knowledge distillation)是模型压缩的一种罕用的办法

一、常识蒸馏入门

1.1 概念介绍

常识蒸馏(knowledge distillation)是模型压缩的一种罕用的办法,不同于模型压缩中的剪枝和量化,常识蒸馏是通过构建一个轻量化的小模型,利用性能更好的大模型的监督信息,来训练这个小模型,以期达到更好的性能和精度。最早是由 Hinton 在 2015 年首次提出并利用在分类工作下面,这个大模型咱们称之为 teacher(老师模型),小模型咱们称之为 Student(学生模型)。来自 Teacher 模型输入的监督信息称之为 knowledge(常识),而 student 学习迁徙来自 teacher 的监督信息的过程称之为 Distillation(蒸馏)。

1.2 常识蒸馏的品种

图 1 常识蒸馏的品种

1、离线蒸馏

离线蒸馏形式即为传统的常识蒸馏,如上图(a)。用户须要在已知数据集下面提前训练好一个 teacher 模型,而后在对 student 模型进行训练的时候,利用所获取的 teacher 模型进行监督训练来达到蒸馏的目标,而且这个 teacher 的训练精度要比 student 模型精度要高,差值越大,蒸馏成果也就越显著。一般来讲,teacher 的模型参数在蒸馏训练的过程中放弃不变,达到训练 student 模型的目标。蒸馏的损失函数 distillation loss 计算 teacher 和 student 之前输入预测值的差异,和 student 的 loss 加在一起作为整个训练 loss,来进行梯度更新,最终失去一个更高性能和精度的 student 模型。

2、半监督蒸馏

半监督形式的蒸馏利用了 teacher 模型的预测信息作为标签,来对 student 网络进行监督学习,如上图(b)。那么不同于传统离线蒸馏的形式,在对 student 模型训练之前,先输出局部的未标记的数据,利用 teacher 网络输入标签作为监督信息再输出到 student 网络中,来实现蒸馏过程,这样就能够应用更少标注量的数据集,达到晋升模型精度的目标。

3、自监督蒸馏

自监督蒸馏相比于传统的离线蒸馏的形式是不须要提前训练一个 teacher 网络模型,而是 student 网络自身的训练实现一个蒸馏过程,如上图(c)。具体实现形式 有多种,例如先开始训练 student 模型,在整个训练过程的最初几个 epoch 的时候,利用后面训练的 student 作为监督模型,在剩下的 epoch 中,对模型进行蒸馏。这样做的益处是不须要提前训练好 teacher 模型,就能够变训练边蒸馏,节俭整个蒸馏过程的训练工夫。

1.3 常识蒸馏的性能

1、晋升模型精度

用户如果对目前的网络模型 A 的精度不是很称心,那么能够先训练一个更高精度的 teacher 模型 B(通常参数量更多,时延更大),而后用这个训练好的 teacher 模型 B 对 student 模型 A 进行常识蒸馏,失去一个更高精度的模型。

2、升高模型时延,压缩网络参数

用户如果对目前的网络模型 A 的时延不称心,能够先找到一个时延更低,参数量更小的模型 B,通常来讲,这种模型精度也会比拟低,而后通过训练一个更高精度的 teacher 模型 C 来对这个参数量小的模型 B 进行常识蒸馏,使得该模型 B 的精度靠近最原始的模型 A,从而达到升高时延的目标。

3、图片标签之间的域迁徙

用户应用狗和猫的数据集训练了一个 teacher 模型 A,应用香蕉和苹果训练了一个 teacher 模型 B,那么就能够用这两个模型同时蒸馏出一个能够辨认狗,猫,香蕉以及苹果的模型,将两个不同与的数据集进行集成和迁徙。

图 2 图像域迁徙训练

4、升高标注量

该性能能够通过半监督的蒸馏形式来实现,用户利用训练好的 teacher 网络模型来对未标注的数据集进行蒸馏,达到升高标注量的目标。

1.4 常识蒸馏的原理

图 3 常识蒸馏原理介绍

个别应用蒸馏的时候,往往会找一个参数量更小的 student 网络,那么相比于 teacher 来说,这个轻量级的网络不能很好的学习到数据集之前暗藏的潜在关系,如上图所示,相比于 one hot 的输入,teacher 网络是将输入的 logits 进行了 softmax,更加平滑的解决了标签,行将数字 1 输入成了 0.6(对 1 的预测)和 0.4(对 0 的预测)而后输出到 student 网络中,相比于 1 来说,这种 softmax 含有更多的信息。好模型的指标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的指标是让 student 学习到 teacher 的泛化能力,实践上失去的后果会比单纯拟合训练数据的 student 要好。另外,对于分类工作,如果 soft targets 的熵比 hard targets 高,那显然 student 会学习到更多的信息。最终 student 模型学习的是 teacher 模型的泛化能力,而不是“过拟合训练数据”

二、入手实际常识蒸馏

ModelArts 模型市场中的 efficientDet 指标检测算法目前曾经反对常识蒸馏,用户能够通过上面的一个案例,来入门和相熟常识蒸馏在检测网络中的应用流程。

2.1 筹备数据集

数据集应用 kaggle 公开的 Images of Canine Coccidiosis Parasite 的辨认工作,下载地址:https://www.kaggle.com/kvinic…。用户下载数据集之后,公布到 ModelArts 的数据集治理中,同时进行数据集切分,默认依照 8:2 的比例切分成 train 和 eval 两种。

2.2 订阅市场算法 efficientDet

进到模型市场算法界面,找到 efficientDet 算法,点击“订阅”按钮

图 4 市场订阅 efficientDet 算法

而后到算法治理界面,找到曾经订阅的 efficientDet,点击同步,就能够进行算法训练

图 5 算法治理同步订阅算法

2.3 训练 student 网络模型

起一个 efficientDet 的训练作业,model_name=efficientdet-d0,数据集选用 2.1 公布的曾经切分好的数据集,抉择好输入门路,点击创立,具体创立参数如下:

图 6 创立 student 网络的训练作业

失去训练的模型精度信息在评估后果界面,如下:

图 7 student 模型训练后果

能够看到 student 的模型精度在 0.8473。

2.4 训练 teacher 网络模型

下一步就是训练一个 teacher 模型,依照 efficientDet 文档的形容,这里抉择 efficientdet-d3,同时须要增加一个参数,表明该训练作业生成的模型是用来作为常识蒸馏的 teacher 模型,新起一个训练作业,具体参数如下:

图 8 teacher 模型训练作业参数

失去的模型精度在评估后果一栏,具体如下:

图 9 teacher 模型训练后果

能够看到 teacher 的模型精度在 0.875。

2.5 应用常识蒸馏晋升 student 模型精度

有了 teacher 网络,下一步就是进行常识蒸馏了,依照官网文档,须要填写 teacher model url,具体填写的内容就是 2.4 训练输入门路上面的 model 目录,留神须要选到 model 目录的那一层级,同时须要增加参数 use_offline_kd=True,具体模型参数如下所示:

图 10 采纳常识蒸馏的 student 模型训练作业参数

失去模型精度在评估后果一栏,具体如下:

图 11 应用常识蒸馏之后的 student 模型训练后果

能够看到通过常识蒸馏之后的 student 的模型精度晋升到了 0.863,精度相比于之前的 student 网络晋升了 1.6% 百分点。

2.6 在线推理部署

训练之后的模型就能够进行模型部署了,具体点击“创立模型”

图 12 创立模型

界面会主动读取模型训练的保留门路,点击创立:

图 13 导入模型

模型部署胜利之后,点击创立在线服务:

图 14 部署在线服务

部署胜利就能够进行在线预测了:

图 15 模型推理后果展现

三、常识蒸馏目前的应用领域

目前常识蒸馏的算法曾经广泛应用到图像语义辨认,指标检测等场景中,并且针对不同的钻研场景,蒸馏办法都做了局部的定制化批改,同时,在行人检测,人脸识别,姿势检测,图像域迁徙,视频检测等方面,常识蒸馏也是作为一种晋升模型性能和精度的重要办法,随着深度学习的倒退,这种技术也会更加的成熟和稳固。

参考文献:

[1]Data Distillation: Towards Omni-Supervised Learning

[2]On the Efficacy of Knowledge Distillation

[3]Knowledge Distillation and Student-Teacher Learning for Visual Intelligence: A Review and New Outlooks

[4]Towards Understanding Knowledge Distillation

[5]Model Compression via Distillation and Quantization

**[点击关注,第一工夫理解华为云陈腐技术~](https://bbs.huaweicloud.com/b…
)**

退出移动版