NLP十八利用ALBERT提升模型预测速度的一次尝试

9次阅读

共计 6389 个字符,预计需要花费 16 分钟才能阅读完成。

前沿

  在文章 NLP(十七)利用 tensorflow-serving 部署 kashgari 模型中,笔者介绍了如何利用 tensorflow-serving 部署来部署深度模型模型,在那篇文章中,笔者利用 kashgari 模块实现了经典的 BERT+Bi-LSTM+CRF 模型结构,在标注了时间的文本语料(大约 2000 多个训练句子)中也达到了很好的识别效果,但是也存在着不足之处,那就是模型的预测时间过长,平均预测一个句子中的时间耗时约 400 毫秒,这种预测速度在生产环境或实际应用中是不能忍受的。
  查看该模型的耗时原因,很大一部分原因在于 BERT 的调用。BERT 是当下最火,知名度最高的预训练模型,虽然会使得模型的训练、预测耗时增加,但也是小样本语料下的最佳模型工具之一,因此,BERT 在模型的架构上是不可缺少的。那么,该如何避免使用预训练模型带来的模型预测耗时过长的问题呢?
  本文决定尝试使用 ALBERT,来验证 ALBERT 在提升模型预测速度方面的应用,同时,也算是本人对于使用 ALBERT 的一次实战吧~

ALBERT 简介

  我们不妨花一些时间来简单地了解一下 ALBERT。ALBERT 是最近一周才开源的预训练模型,其 Github 的网址为:https://github.com/brightmart…,其论文可以参考网址:https://arxiv.org/pdf/1909.11…。
  根据 ALBERT 的 Github 介绍,ALBERT 在海量中文语料上进行了预训练,模型的参数更少,效果更好。以 albert_tiny_zh 为例,其文件大小 16M、参数为 1.8M,模型大小仅为 BERT 的 1 /25,效果仅比 BERT 略差或者在某些 NLP 任务上更好。在本文的预训练模型中,将采用 albert_tiny_zh。

利用 ALBERT 训练时间识别模型

  我们以 Github 中的 bertNER 为本次项目的代码模板,在该项目中,实现的模型为 BERT+Bi-LSTM+CRF,我们将 BERT 替换为 ALBERT,也就是说笔者的项目中模型为 ALBERT+Bi-LSTM+CRF,同时替换 bert 文件夹的代码为 alert_zh,替换预训练模型文件夹 chinese_L-12_H-768_A-12(BERT 中文预训练模型文件)为 albert_tiny。当然,也需要修改一部分的项目源代码,来适应 ALBERT 的模型训练。
  数据集采用笔者自己标注的时间语料,即标注了时间的句子,大概 2000+ 句子,其中 75% 作为训练集(time.train 文件),10% 作为验证集(time.dev 文件),15% 作为测试集(time.test 文件)。在这里笔者不打算给出具体的 Python 代码,因为工程比较复杂,有兴趣的额读者可以去查看该项目的 Github 地址:。
  一些模型的参数可以如下:

  • 预训练模型:ALBERT(tiny)
  • 训练样本的最大字符长度:128
  • batch_size: 8
  • epoch: 100
  • 双向 LSTM 的个数:100

  ALBERT 的模型训练时间也会显著提高,我们耐心地等待模型训练完毕。在 time.dev 和 time.test 数据集上的表现如下表:

数据集 precision recall f1
time.dev 81.41% 84.95% 83.14%
time.test 83.03% 86.38% 84.67%

  接着笔者利用训练好的模型,用 tornado 封装了一个模型预测的 HTTP 服务,具体的代码如下:

# -*- coding: utf-8 -*-

import os
import json
import time
import pickle
import traceback

import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options

import tensorflow as tf
from utils import create_model, get_logger
from model import Model
from loader import input_from_line
from train import FLAGS, load_config, train

# 定义端口为 12306
define("port", default=12306, help="run on the given port", type=int)
# 导入模型
config = load_config(FLAGS.config_file)
logger = get_logger(FLAGS.log_file)
# limit GPU memory
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = False
with open(FLAGS.map_file, "rb") as f:
    tag_to_id, id_to_tag = pickle.load(f)

sess = tf.Session(config=tf_config)
model = create_model(sess, Model, FLAGS.ckpt_path, config, logger)

# 模型预测的 HTTP 接口
class ResultHandler(tornado.web.RequestHandler):
    # post 函数
    def post(self):
        event = self.get_argument('event')
        result = model.evaluate_line(sess, input_from_line(event, FLAGS.max_seq_len, tag_to_id), id_to_tag)
        self.write(json.dumps(result, ensure_ascii=False))

# 主函数
def main():
    # 开启 tornado 服务
    tornado.options.parse_command_line()
    # 定义 app
    app = tornado.web.Application(
            handlers=[(r'/subj_extract', ResultHandler)
                     ], #网页路径控制
           )
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)
    tornado.ioloop.IOLoop.instance().start()

main()

模型预测提速了吗?

  将模型预测封装成 HTTP 服务后,我们利用 Postman 来测试模型预测的效果和时间,如下图所示:


可以看到,模型预测的结果正确,且耗时仅为 38ms。
  接着我们尝试多测试几个句子的测试,测试代码如下:

# Daxing, Beijing
import requests
import json
import time

url = 'http://localhost:12306/subj_extract'

texts = ['据《新闻联播》报道,9 月 9 日至 11 日,中央纪委书记赵乐际到河北调研。',
         '记者从国家发展改革委、商务部相关方面获悉,日前美方已决定对拟于 10 月 1 日实施的中国输美商品加征关税措施做出调整,中方支持相关企业从即日起按照市场化原则和 WTO 规则,自美采购一定数量大豆、猪肉等农产品,国务院关税税则委员会将对上述采购予以加征关税排除。',
         '据印度 Zee 新闻网站 12 日报道,亚洲新闻国际通讯社援引印度军方消息人士的话说,9 月 11 日的对峙事件发生在靠近班公错北岸的实际控制线一带。',
         '儋州市决定,从 9 月开始,对城市低保、农村低保、特困供养人员、优抚对象、领取失业保险金人员、建档立卡未脱贫人口等低收入群体共 3 万多人,发放猪肉价格补贴,每人每月发放不低于 100 元补贴,以后发放标准,将根据猪肉价波动情况进行动态调整。',
         '9 月 11 日,华为心声社区发布美国经济学家托马斯. 弗里德曼在《纽约时报》上的专栏内容,弗里德曼透露,在与华为创始人任正非最近一次采访中,任正非表示华为愿意与美国司法部展开话题不设限的讨论。',
         '造血干细胞移植治疗白血病技术已日益成熟,然而,通过该方法同时治愈艾滋病目前还是一道全球尚在攻克的难题。',
         '英国航空事故调查局(AAIB)近日披露,今年 2 月 6 日一趟由德国法兰克福飞往墨西哥坎昆的航班上,因飞行员打翻咖啡使操作面板冒烟,导致飞机折返迫降爱尔兰。',
         '当地时间周四(9 月 12 日),印度尼西亚财政部长英卓华(Sri Mulyani Indrawati)明确表示:特朗普的推特是风险之一。',
         '华中科技大学 9 月 12 日通过其官方网站发布通报称,9 月 2 日,我校一硕士研究生不幸坠楼身亡。',
         '微博用户 @ooooviki 9 月 12 日下午公布发生在自己身上的惊悚遭遇:一个自称网警、名叫郑洋的人利用职务之便,查到她的完备的个人信息,包括但不限于身份证号、家庭地址、电话号码、户籍变动情况等,要求她做他女朋友。',
         '今天,贵阳取消了汽车限购,成为目前全国实行限购政策的 9 个省市中,首个取消限购的城市。',
         '据悉,与全球同步,中国区此次将于 9 月 13 日于 iPhone 官方渠道和京东正式开启预售,京东成 Apple 中国区唯一官方授权预售渠道。',
         '根据央行公布的数据,截至 2019 年 6 月末,存款类金融机构住户部门短期消费贷款规模为 9.11 万亿元,2019 年上半年该项净增 3293.19 亿元,上半年增量看起来并不乐观。',
         '9 月 11 日,一段拍摄浙江万里学院学生食堂的视频走红网络,视频显示该学校食堂不仅在用餐区域设置了可以看电影、比赛的大屏幕,还推出了“一人食”餐位。',
         '当日,在北京举行的 2019 年国际篮联篮球世界杯半决赛中,西班牙队对阵澳大利亚队。',
         ]

t1 = time.time()
for text in texts:
    data = {'event': text.replace('','')}
    req = requests.post(url, data)
    if req.status_code == 200:
        print('原文:%s' % text)
        res = json.loads(req.content)['entities']
        print('抽取结果:%s' % str([_['word'] for _ in res]))


t2 = time.time()
print('一共耗时:%ss.' % str(round(t2-t1, 4)))

输出结果如下:

 原文:据《新闻联播》报道,9 月 9 日至 11 日,中央纪委书记赵乐际到河北调研。抽取结果:['9 月 9 日至 11 日']
原文:记者从国家发展改革委、商务部相关方面获悉,日前美方已决定对拟于 10 月 1 日实施的中国输美商品加征关税措施做出调整,中方支持相关企业从即日起按照市场化原则和 WTO 规则,自美采购一定数量大豆、猪肉等农产品,国务院关税税则委员会将对上述采购予以加征关税排除。抽取结果:['日前', '10 月 1 日']
原文:据印度 Zee 新闻网站 12 日报道,亚洲新闻国际通讯社援引印度军方消息人士的话说,9 月 11 日的对峙事件发生在靠近班公错北岸的实际控制线一带。抽取结果:['12 日', '9 月 11 日']
原文:儋州市决定,从 9 月开始,对城市低保、农村低保、特困供养人员、优抚对象、领取失业保险金人员、建档立卡未脱贫人口等低收入群体共 3 万多人,发放猪肉价格补贴,每人每月发放不低于 100 元补贴,以后发放标准,将根据猪肉价波动情况进行动态调整。抽取结果:['9 月']
原文:9 月 11 日,华为心声社区发布美国经济学家托马斯. 弗里德曼在《纽约时报》上的专栏内容,弗里德曼透露,在与华为创始人任正非最近一次采访中,任正非表示华为愿意与美国司法部展开话题不设限的讨论。抽取结果:['9 月 11 日']
原文:造血干细胞移植治疗白血病技术已日益成熟,然而,通过该方法同时治愈艾滋病目前还是一道全球尚在攻克的难题。抽取结果:[]
原文:英国航空事故调查局(AAIB)近日披露,今年 2 月 6 日一趟由德国法兰克福飞往墨西哥坎昆的航班上,因飞行员打翻咖啡使操作面板冒烟,导致飞机折返迫降爱尔兰。抽取结果:['近日', '今年 2 月 6 日']
原文:当地时间周四(9 月 12 日),印度尼西亚财政部长英卓华(Sri Mulyani Indrawati)明确表示:特朗普的推特是风险之一。抽取结果:['当地时间周四(9 月 12 日)']
原文:华中科技大学 9 月 12 日通过其官方网站发布通报称,9 月 2 日,我校一硕士研究生不幸坠楼身亡。抽取结果:['9 月 12 日', '9 月 2 日']
原文:微博用户 @ooooviki 9 月 12 日下午公布发生在自己身上的惊悚遭遇:一个自称网警、名叫郑洋的人利用职务之便,查到她的完备的个人信息,包括但不限于身份证号、家庭地址、电话号码、户籍变动情况等,要求她做他女朋友。抽取结果:['9 月 12 日下午']
原文:今天,贵阳取消了汽车限购,成为目前全国实行限购政策的 9 个省市中,首个取消限购的城市。抽取结果:['今天', '目前']
原文:据悉,与全球同步,中国区此次将于 9 月 13 日于 iPhone 官方渠道和京东正式开启预售,京东成 Apple 中国区唯一官方授权预售渠道。抽取结果:['9 月 13 日']
原文:根据央行公布的数据,截至 2019 年 6 月末,存款类金融机构住户部门短期消费贷款规模为 9.11 万亿元,2019 年上半年该项净增 3293.19 亿元,上半年增量看起来并不乐观。抽取结果:['2019 年 6 月末', '2019 年上半年', '上半年']
原文:9 月 11 日,一段拍摄浙江万里学院学生食堂的视频走红网络,视频显示该学校食堂不仅在用餐区域设置了可以看电影、比赛的大屏幕,还推出了“一人食”餐位。抽取结果:['9 月 11 日']
原文:当日,在北京举行的 2019 年国际篮联篮球世界杯半决赛中,西班牙队对阵澳大利亚队。抽取结果:['当日', '2019 年']
一共耗时:0.5314s.

可以看到,对于测试的 15 个句子,识别的准确率很高,且预测耗时为 531ms,平均每个话的预测时间不超过 40ms。相比较而言,文章 NLP(十七)利用 tensorflow-serving 部署 kashgari 模型中的模型,该模型的预测时间为每句话 1 秒多,模型预测的速度为带 ALBERT 模型的 25 倍多。
  因此,ALBERT 模型确实提升了模型预测的时间,而且效 & 果非常显著。

总结

  由于 ALBERT 开源不到一周,而且笔者的学识、才能有限,因此,在代码方面可能会存在不足。但是,作为一次使用 ALBERT 的历经,希望能够与大家分享。
  本文绝不是上述项目代码的抄袭和堆砌,该项目融入了笔者自己的思考,希望不要被误解为是抄袭。笔者使用上述的 bertNER 和 ALBERT,只是为了验证 ALBERT 在模型预测耗时方面的提速效果,而事实是,ALBERT 确实给我带来了很大惊喜,感受源代码作者们~
  最后,附上本文中笔者项目的 Github 地址:https://github.com/percent4/A…。
  众里寻他千百度。蓦然回首,那人却在,灯火阑珊处。

参考文献

  1. 超小型 BERT 中文版横空出世!模型只有 16M,训练速度提升 10 倍:https://mp.weixin.qq.com/s/eV…
  2. ALBERT 的 Github 地址:https://github.com/brightmart…
  3. bertNER 项目的 Github 地址:https://github.com/yumath/ber…
  4. NLP(十七)利用 tensorflow-serving 部署 kashgari 模型:https://www.cnblogs.com/jclia…
正文完
 0