关于算法:在数据增强蒸馏剪枝下ERNIE30分类模型性能提升

7次阅读

共计 9439 个字符,预计需要花费 24 分钟才能阅读完成。

在数据加强、蒸馏剪枝下 ERNIE3.0 模型性能晋升

我的项目链接:
https://aistudio.baidu.com/aistudio/projectdetail/4436131?contributionType=1

以 CBLUE 数据集中医疗搜寻检索词用意分类为例:

本我的项目首先解说了数据加强和数据蒸馏的计划,并在前面章节进行成果展现,后果预览:

模型 ACC Precision Recall F1 average_of_acc_and_f1
ERNIE 3.0 Base 0.80255 0.9317147 0.908284 0.919850 0.86120
ERNIE 3.0 Base+ 数据加强 0.7979539 0.901004 0.92899 0.91478 0.8563
ERNIE 3.0 Base+ 剪裁保留比 0.5 0.79846 0.951257 0.89497 0.92225 0.8603
ERNIE 3.0 Base + 剪裁保留比 2 /3 0.8092071 0.9415384 0.905325 0.923076 0.86614

gensim 装置最新版本:pip install gensim

tqdm 装置:pip install tqdm

LAC 装置最新版本:pip install lac


Gensim 库介绍

Gensim 是在做自然语言解决时较为常常用到的一个工具库,次要用来以无监督的形式从原始的非结构化文本当中来学习到文本暗藏层的主题向量表白。

次要包含 TF-IDF,LSA,LDA,word2vec,doc2vec 等多种模型。

Tqdm

是一个疾速,可扩大的 Python 进度条,能够在 Python 长循环中增加一个进度提示信息,用户只须要封装任意的迭代器 tqdm(iterator)。目标为了程序显示的好看

中文词法剖析 -LAC

LAC 是一个联结的词法分析模型,整体性地实现中文分词、词性标注、专名辨认工作。LAC 既能够认为是 Lexical Analysis of Chinese 的首字母缩写,也能够认为是 LAC Analyzes Chinese 的递归缩写。

LAC 基于一个重叠的双向 GRU 构造,在长文本上精确复刻了百度 AI 开放平台上的词法剖析算法。成果方面,分词、词性、专名辨认的整体准确率 95.5%;独自评估专名辨认工作,F 值 87.1%(精确 90.3,召回 85.4%),总体略优于开放平台版本。在成果优化的根底上,LAC 的模型简洁高效,内存开销不到 100M,而速度则比百度 AI 开放平台进步了 57%

LAC 链接:https://www.paddlepaddle.org….

!pip install –upgrade paddlenlp
!pip install gensim
!pip install tqdm
!pip install lac

2. 数据加强计划介绍

数据加强工具提供 4 种加强策略:遮蔽、删除、同词性词替换、词向量近义词替换

!unzip ERNIE-.zip -d ./ERNIE

增加 ERNIE 工具包

如果程序报错:能够发现提醒有一个.ipynb_checkpoints 的文件。但当我去对应的文件夹找时基本看不到这个文件,所以猜想是一个暗藏文件。所以通过终端进入对应的目录:输出 cd coco 进入对应目录,输出 ls - a 显示所有文件。而后输出 rm -rf .ipynb_checkpoints 删除该文件。再次输出 ls - a 查看文件是否被删除。

下载词表, 词表有 1.7G 会花点工夫。上面以情感剖析数据样例展现 demo,看看数据加强的成果。

!wget -q --no-check-certificate http://bj.bcebos.com/wenxin-models/vec2.txt

python data_aug.py “ 输出文件夹的目录 ” “ 输入文件夹的目录 ”

  • data_aug.py 脚本传参阐明
shell 输出:python data_aug.py -h

shell 输入:usage: data_aug.py [-h] [-n AUG_TIMES] [-c COLUMN_NUMBER] [-u UNK]
                       [-t TRUNCATE] [-r POS_REPLACE] [-w W2V_REPLACE]
                       [-e ERNIE_REPLACE] [--unk_token UNK_TOKEN]
                       input output
    
    main
    
    positional arguments:
      input                                                #原始待加强数据文件所在文件夹,带 label 的,一个或多个文本列
      output                                               #输入文件门路
    
    optional arguments:
      -h, --help            show this help message and exit
      -n AUG_TIMES, --aug_times AUG_TIMES                  #数据集数目放大 n 倍,output 行数为 input 的 n + 1 倍      
      -c COLUMN_NUMBER, --column_number COLUMN_NUMBER      #明文文件中所要加强列的列序号,多列用逗号宰割,如:1,2
      -u UNK, --unk UNK                                    #unk 加强策略的概率
      -t TRUNCATE, --truncate TRUNCATE                     #truncate 加强策略的概率
      -r POS_REPLACE, --pos_replace POS_REPLACE            #pos_replace 加强策略的概率
      -w W2V_REPLACE, --w2v_replace W2V_REPLACE            #w2v_replace 加强策略的概率
      --unk_token UNK_TOKEN                    

分类问题中:举荐应用前三种即可,w2v 词向量近义词替换能够不必,破费工夫太长。

!python data_aug.py --unk 0.25 --truncate 0.25 --pos 0.5 --w2v 0 ./train ./output
demo 后果展现:机器 反面 仿佛 被 撕 了 张 什么 标签,残 胶 还在。然而 又 看 不 出 是 什么 标签 不见 了,该 有 的 都 在,怪    0
机器 反面 仿佛 被 撕 了 张 什么 标签,胶 还在。然而 又 看 不 出 是 什么 标签 不见 了,该 有 的 都 在,怪    0
机器 反面 了 张 什么 标签,残 胶 还在。然而 又 看 不 出 是 什么 标签  了,该在,怪    0
呵呵,尽管 表皮 看上去 不错 很 粗劣,然而 我 还是 能 看得出来 是 盗 的。然而 外面 的 内容 真 的 不错,我 妈 爱 看,我本人 也 学 着 找 一些 穴位。0
呵呵,尽管 表皮 看上去 不错 很 粗劣,然而 我 还是 能 看得出来 是 盗 的。然而 外面 的 内容 真 的 不错,我😄妈 爱 看,我本人 也 学 着 找 一些 穴位 😄    0
呵呵,尽管 表皮 看上去 不错 很 粗劣,然而 我 还😄 能 看得出来 是 盗😄😄😄。然而 外面 的 内容 真 的 不错,我 妈 爱 看,😄😄😄😄😄😄😄学 着 找 😄😄😄😄😄😄😄    0
😄😄😄😄😄尽管 表皮 看上去 不错 很 粗劣,然而 我 还是 能 看得出来 是 盗 的。然而 外面 的 内容 真 的 不错,我 妈 爱 看,我本人 也 学 着 找 一些 穴位。0
😄😄😄😄😄😄😄 表皮 看上去 不错 很 粗劣,然而 我 还是 能 看得出来 是 盗 的。然而 外面 的 内容 真 的 不错,我 妈 爱 看,我本人 也 学 着 找 一些 穴位。0
天文 地位 佳,在 市中心。酒店 服务 好、早餐 种类 丰盛。我 住 的 商务 数码 房 电脑 宽带 速度 称心 , 房间 还算 洁净,离 湖南路小吃街 近。1
天文 地位 佳,在 市中心。酒店 服务 好、早餐 种类 丰盛。我 住 的 商务 数码 房 电脑 宽带 速度 称心 , 房间 还算 洁净,离 湖南路小吃街 近。。1
天文 地位 佳,在 市中心。酒店 服务 好、早餐 种类 丰盛。我 住 的 商务 数码 房 电脑 宽带 速度 称心 , 机器 还算 洁净,离 湖南路小吃街 近。1
天文 地位 佳,在 市中心。酒店 服务 好、早餐 种类 丰盛。我 住 的 商务 数码 房 电脑 宽带 速度 称心 , 房间 还算 洁净,离 湖南路小吃街 近。1
天文 地位 佳,在 市中心。酒店 服务 好、早餐 种类 丰盛。我 住 的 商务 数码 房 电脑 宽
我 看 是 书 的 还 能够,然而 我 订 的 书 迟迟 还 到 能 半个月,都 没有 收到 打电话 也 没

2.0 补充 nlpcda 一键中文数据加强工具(NLP Chinese Data Augmentation)

一键中文数据加强工具,反对:

1. 随机实体替换
2. 近义词
3. 近义近音字替换
4. 随机字删除(外部细节:数字工夫日期片段,内容不会删)
5.NER 类 BIO 数据加强
6. 随机置换邻近的字:研表究明,汉字序顺并不定一影响文字的浏览了解 << 是乱序的
7. 中文等价字替换(1 一 壹 ①,2 二 贰 ②)
8. 翻译互转实现的加强
9. 应用 simbert 做生成式类似句生成

参考链接:
一键中文数据加强包;NLP 数据加强、bert 数据加强、EDA:pip install nlpcda
nlpcda 一键中文数据加强工具

3. 数据蒸馏技术

ERNIE 数据蒸馏三步

Step 1. 应用 ERNIE 模型对输出标注数据对进行 fine-tune,失去 Teacher Model

Step 2. 应用 ERNIE Service 对以下无监督数据进行预测:

  • 用户提供的大规模无标注数据,需与标注数据同源
  • 对标注数据进行数据加强,具体加强策略
  • 对无标注数据和数据加强数据进行肯定比例混合

Step 3. 应用步骤 2 的数据训练出 Student Model

数据加强

目前采纳三种数据加强策略策略,对于不必的工作能够特定的比例混合。三种数据加强策略包含:

增加噪声:对原始样本中的词,以肯定的概率(如 0.1)替换为”UNK”标签

同词性词替换:对原始样本中的所有词,以肯定的概率(如 0.1)替换为本数据集钟随机一个同词性的词

N-sampling:从原始样本中,随机选取地位截取长度为 m 的片段作为新的样本,其中片段的长度 m 为 0 到原始样本长度之间的随机值
模型剪裁,基于 PaddleNLP 的 Trainer API 公布提供了模型裁剪 API。裁剪 API 反对用户对 ERNIE 等 Transformers 类上游工作微调模型进行裁剪。

具体成果在下一节展示,先装置好 paddleslim 库

4. 基于 ERNIR3.0 文本模型微调

加载已有数据集:CBLUE 数据集中医疗搜寻检索词用意分类(训练)

数据集定义:
以公开数据集 CBLUE 数据集中医疗搜寻检索词用意分类 (KUAKE-QIC) 工作为示例,在训练集上进行模型微调,并在开发集上应用准确率 Accuracy 评估模型体现。

数据集默认为:默认为 ”cblue”。

save_dir:保留训练模型的目录;默认保留在当前目录 checkpoint 文件夹下。

dataset:训练数据集; 默认为 ”cblue”。

<font color=”red”>dataset_dir:本地数据集门路,数据集门路中应蕴含 train.txt,dev.txt 和 label.txt 文件; 默认为 None。</font>

task_name:训练数据集; 默认为 ”KUAKE-QIC”。

max_seq_length:ERNIE 模型应用的最大序列长度,最大不能超过 512, 若呈现显存有余,请适当调低这一参数;默认为 128。

<font color=”red”>model_name:抉择预训练模型;默认为 ”ernie-3.0-base-zh”。</font>

<font color=”red”>device: 选用什么设施进行训练,可选 cpu、gpu、xpu、npu。如应用 gpu 训练,可应用参数 gpus 指定 GPU 卡号。</font>

batch_size:批处理大小,请联合显存状况进行调整,若呈现显存有余,请适当调低这一参数;默认为 32。

learning_rate:Fine-tune 的最大学习率;默认为 6e-5。

weight_decay:管制正则项力度的参数,用于避免过拟合,默认为 0.01。

early_stop:抉择是否应用早停法(EarlyStopping);默认为 False。

<font color=”red”>early_stop_nums:在设定的早停训练轮次内,模型在开发集上体现不再回升,训练终止;默认为 4。
epochs: 训练轮次,默认为 100。</font>

warmup:是否应用学习率 warmup 策略;默认为 False。

warmup_proportion:学习率 warmup 策略的比例数,如果设为 0.1,则学习率会在前 10%steps 数从 0 缓缓增长到 learning_rate, 而后再迟缓衰减;默认为 0.1。

logging_steps: 日志打印的距离 steps 数,默认 5。

init_from_ckpt: 模型初始 checkpoint 参数地址,默认 None。

seed:随机种子,默认为 3。

# 批改后的训练文件 train_new2.py,次要应用了 paddlenlp.metrics.glue 的 AccuracyAndF1:准确率及 F1-score,可用于 GLUE 中的 MRPC 和 QQP 工作
#不过吐槽一下:return (acc,precision,recall,f1,(acc + f1) / 2,) 最初一个指标居然是加权均匀.....
!python train_new2.py --warmup --early_stop --epochs 10 --save_dir "./checkpoint2" --batch_size 16 --model_name ernie-3.0-base-zh

训练后果局部展现:

[2022-08-16 19:58:36,834] [INFO] - global step 1280, epoch: 3, batch: 412, loss: 0.23292, acc: 0.87106, speed: 16.54 step/s
[2022-08-16 19:58:37,392] [INFO] - global step 1290, epoch: 3, batch: 422, loss: 0.22339, acc: 0.87130, speed: 17.94 step/s
[2022-08-16 19:58:37,960] [INFO] - global step 1300, epoch: 3, batch: 432, loss: 0.22791, acc: 0.87182, speed: 17.68 step/s
(acc, precision, recall, f1, average_of_acc_and_f1):(0.8025575447570332, 0.9317147192716236, 0.908284023668639, 0.9198501872659175, 0.8612038660114754)

[2022-08-16 20:01:36,060] [INFO] – Early stop!
[2022-08-16 20:01:36,060] [INFO] – Save best accuracy text classification model in ./checkpoint2

4.1 加载自定义数据集(并通过数据加强训练)

从本地文件创建数据集

应用本地数据集来训练咱们的文本分类模型,本我的项目反对应用固定格局本地数据集文件进行训练
如果须要对本地数据集进行数据标注,能够参考文本分类工作 doccano 数据标注使用指南进行文本分类数据标注。[这个放到下个我的项目解说]

本我的项目将以 CBLUE 数据集中医疗搜寻检索词用意分类 (KUAKE-QIC) 工作为例进行介绍如何加载本地固定格局数据集进行训练:

本地数据集目录构造如下:

data/
├── train.txt # 训练数据集文件
├── dev.txt # 开发数据集文件
├── label.txt # 分类标签文件
└── data.txt # 可选,待预测数据文件

局部后果展现

[2022-08-16 23:43:18,093] [INFO] - global step 2400, epoch: 2, batch: 234, loss: 0.60859, acc: 0.84437, speed: 19.27 step/s
(acc, precision, recall, f1, average_of_acc_and_f1):(0.7979539641943734, 0.9010043041606887, 0.9289940828402367, 0.9147851420247632, 0.8563695531095683)
[2022-08-16 23:43:24,522] [INFO] - Save best F1 text classification model in ./checkpoint3
[2022-08-16 23:43:24,523] [INFO] - best F1 performence has been updated: 0.91450 --> 0.91479

4.2 数据蒸馏

!unset CUDA_VISIBLE_DEVICES
!python -m paddle.distributed.launch --gpus "0" prune.py \
    --device "gpu" \
    --output_dir "./prune" \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --learning_rate 3e-5 \
    --num_train_epochs 5 \
    --logging_steps 10 \
    --save_steps 50 \
    --seed 3 \
    --dataset_dir "KUAKE_QIC" \
    --max_seq_length 128 \
    --params_dir "./checkpoint3" \
    --width_mult '0.5'

局部后果展现:

[2022-08-17 14:22:30,954] [INFO] - width_mult: 0.5, eval loss: 0.63535, acc: 0.79847
(acc, precision, recall, f1, average_of_acc_and_f1):(0.7984654731457801, 0.9512578616352201, 0.8949704142011834, 0.9222560975609755, 0.8603607853533778)
[2022-08-17 14:22:35,870] [INFO] - Save best F1 text classification model in ./prune/0.5
[2022-08-17 14:22:35,870] [INFO] - best F1 performence has been updated: 0.92226 --> 0.92226
!unset CUDA_VISIBLE_DEVICES
!python -m paddle.distributed.launch --gpus "0" prune.py \
    --device "gpu" \
    --output_dir "./prune" \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --learning_rate 3e-5 \
    --num_train_epochs 5 \
    --logging_steps 10 \
    --save_steps 50 \
    --seed 3 \
    --dataset_dir "KUAKE_QIC" \
    --max_seq_length 128 \
    --params_dir "./checkpoint3" \
    --width_mult '2/3'
2022-08-17 14:53:45,544] [INFO] - global step 3070, epoch: 2, batch: 904, loss: 0.709566, speed: 9.93 step/s
[2022-08-17 14:53:46,550] [INFO] - global step 3080, epoch: 2, batch: 914, loss: 0.607238, speed: 9.94 step/s
[2022-08-17 14:53:47,558] [INFO] - global step 3090, epoch: 2, batch: 924, loss: 0.718484, speed: 9.93 step/s
[2022-08-17 14:53:48,563] [INFO] - global step 3100, epoch: 2, batch: 934, loss: 0.546288, speed: 9.95 step/s
[2022-08-17 14:53:50,206] [INFO] - teacher model, eval loss: 0.66438, acc: 0.80358
[2022-08-17 14:53:50,207] [INFO] - eval done total : 1.6434180736541748 s
[2022-08-17 14:53:53,568] [INFO] - width_mult: 0.6666666666666666, eval loss: 0.60219, acc: 0.80921
(acc, precision, recall, f1, average_of_acc_and_f1):(0.8092071611253197, 0.9415384615384615, 0.9053254437869822, 0.923076923076923, 0.8661420421011213)
[2022-08-17 14:53:58,489] [INFO] - Save best F1 text classification model in ./prune/0.6666666666666666
[2022-08-17 14:53:58,489] [INFO] - best F1 performence has been updated: 0.92308 --> 0.92308

4.3 模型预测

输出待预测数据和数据标签对照列表,模型预测数据对应的标签

应用默认数据进行预测:

# 也能够抉择应用本地数据文件 data/data.txt 进行预测:!python predict.py --params_path ./checkpoint3/ --dataset_dir ./KUAKE_QIC --device "cpu"
黑苦荞茶的效用与作用及食用方法 效用作用
接壤痣会凸起吗 疾病表述
查看是否能怀孕挂什么科 就医倡议
鱼油怎么吃咬破吃还是间接咽下去 其余
幼儿挑食的生理起因是 病因剖析
!python predict.py \
    --device "cpu" \
    --dataset_dir ./KUAKE_QIC \
    --params_path "./prune/0.5" \

5. 总结

本我的项目首先解说了数据加强和数据蒸馏的计划,并在前面章节进行成果展现,当初进行汇总

模型 ACC Precision Recall F1 average_of_acc_and_f1
ERNIE 3.0 Base 0.80255 0.9317147 0.908284 0.919850 0.86120
ERNIE 3.0 Base+ 数据加强 0.7979539 0.901004 0.92899 0.91478 0.8563
ERNIE 3.0 Base+ 剪裁保留比 0.5 0.79846 0.951257 0.89497 0.92225 0.8603
ERNIE 3.0 Base + 剪裁保留比 2 /3 0.8092071 0.9415384 0.905325 0.923076 0.86614

剖析可得,

  • 首先数据加强后导致性能局部降落局部和预期的起因:
    随机 mask、删除会产生过多噪声样本影响后果,举荐只应用同义词替换,本次样本数据量足够,且 ERNIE 性能本就优越,数据加强对后果晋升在较大样本集能够疏忽。
  • 其次,能够看到通过数据蒸馏后,模型性能变动不大,甚至在剪裁 1 / 3 之后,性能有小幅度晋升

本次次要对分类模型退出数据加强、数据蒸馏,曾经对性能指标进行细化,不只是 ACC,集体比拟关注 F1 状况,并作为保留模型根据。

瞻望: 后续将欠缺动态图和动态图转化局部,让蒸馏下来模型能够持续线上加载应用;其次将会思考小样本学习在分类模型利用状况;最初将实现模型交融环节晋升性能,并做可解释性剖析。

自己博客:https://blog.csdn.net/sinat_39620217?type=blog

正文完
 0