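This article is translated from the English document Compile TFLite Models.
Author: FrozenGene (Zhao Wu) · GitHub
More TVM documentation in Chinese is available at the Apache TVM Chinese site (Apache TVM is an end-to-end deep learning compiler framework for CPUs, GPUs, and various machine learning accelerator chips).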

This article shows how to deploy TFLite models with Relay.

First, install the TFLite package.

```bash
# Install tflite
pip install tflite==2.1.0 --user
```

Alternatively, generate the TFLite package yourself as follows:

```bash
# Get the flatc compiler.
# See https://github.com/google/flatbuffers for details; make sure flatc is installed correctly
flatc --version

# Get the TFLite schema
wget https://raw.githubusercontent.com/tensorflow/tensorflow/r1.13/tensorflow/lite/schema/schema.fbs

# Generate the TFLite package
flatc --python schema.fbs

# Add the current folder (which contains the generated tflite module) to PYTHONPATH.
export PYTHONPATH=${PYTHONPATH:+$PYTHONPATH:}$(pwd)
```

Run python -c "import tflite" to check that the TFLite package was installed successfully.
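The same check can be done from inside a Python session. The sketch below uses the standard-library `importlib.util.find_spec`, which locates a top-level module without executing it; the helper name `tflite_available` is hypothetical. Because nothing is imported, a package that is found but broken would still fail later at `import tflite`.

```python
import importlib.util


def tflite_available(module_name: str = "tflite") -> bool:
    """Return True if a top-level module named `module_name` can be found on sys.path.

    find_spec only locates the module and does not run it, so a broken
    installation can still fail at import time.
    """
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    print("tflite installed:", tflite_available())
```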

The following example shows how to compile a TFLite model with TVM.

Helper for extracting downloaded archives

```python
import os


def extract(path):
    import tarfile

    # Accepts .tgz and .tar.gz archives and extracts them next to the archive file
    if path.endswith("tgz") or path.endswith("gz"):
        dir_path = os.path.dirname(path)
        tar = tarfile.open(path)
        tar.extractall(path=dir_path)
        tar.close()
    else:
        raise RuntimeError("Could not decompress the file: " + path)
```
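The helper above accepts any path ending in `gz`, which covers both `.tgz` and `.tar.gz`. As a minimal self-contained sketch using only the standard library, the code below builds a tiny `.tgz` in a temporary directory and extracts it the same way. The names `extract_tgz` and `demo` are hypothetical; the one deliberate change from `extract` is a `with` statement, which closes the archive even if extraction raises.

```python
import os
import tarfile
import tempfile


def extract_tgz(path: str) -> str:
    """Extract a .tgz / .tar.gz archive next to itself and return that directory."""
    if not (path.endswith("tgz") or path.endswith("gz")):
        raise RuntimeError("Could not decompress the file: " + path)
    dir_path = os.path.dirname(path)
    # The context manager closes the archive even if extraction fails.
    with tarfile.open(path) as tar:
        tar.extractall(path=dir_path)
    return dir_path


def demo() -> str:
    """Pack one small file into a .tgz, extract it, and return the extracted file's text."""
    workdir = tempfile.mkdtemp()
    src = os.path.join(workdir, "model.txt")
    with open(src, "w") as f:
        f.write("fake model")
    archive = os.path.join(workdir, "model.tgz")
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(src, arcname="extracted/model.txt")
    out_dir = extract_tgz(archive)
    with open(os.path.join(out_dir, "extracted", "model.txt")) as f:
        return f.read()


if __name__ == "__main__":
    print(demo())
```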

Load a pretrained TFLite model

Load the MobileNet V1 TFLite model provided by Google:

```python
from tvm.contrib.download import download_testdata

model_url = "http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224.tgz"

# Download the model tar file and extract it to get mobilenet_v1_1.0_224.tflite
model_path = download_testdata(model_url, "mobilenet_v1_1.0_224.tgz", module=["tf", "official"])
model_dir = os.path.dirname(model_path)
extract(model_path)

# Open mobilenet_v1_1.0_224.tflite and read it as raw bytes
tflite_model_file = os.path.join(model_dir, "mobilenet_v1_1.0_224.tflite")
with open(tflite_model_file, "rb") as f:
    tflite_model_buf = f.read()

# Get the TFLite model from the buffer.
# Depending on how the tflite package was produced, Model is either exposed
# directly on the package or lives in the tflite.Model submodule.
try:
    import tflite

    tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
    import tflite.Model

    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
```
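If `GetRootAsModel` misbehaves on a file, it can help to first confirm that the bytes really are a TFLite flatbuffer. The TFLite schema (`schema.fbs`) declares the file identifier `"TFL3"`, and FlatBuffers stores a file identifier in bytes 4 to 8, right after the 4-byte offset of the root table. The sketch below checks only that header; the helper name `looks_like_tflite` is hypothetical, and passing the check does not prove the rest of the model is valid.

```python
def looks_like_tflite(buf: bytes) -> bool:
    """Heuristic: a TFLite flatbuffer carries the identifier b"TFL3" at bytes 4..8.

    Bytes 0..4 hold the little-endian offset of the root table, so they are skipped.
    """
    return len(buf) >= 8 and buf[4:8] == b"TFL3"


if __name__ == "__main__":
    # Synthetic headers only; with a real model, pass the bytes read from the .tflite file.
    print(looks_like_tflite(b"\x1c\x00\x00\x00TFL3" + b"\x00" * 8))
    print(looks_like_tflite(b"PK\x03\x04not a model"))
```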

Load a test image

Use the familiar cat image again:

```python
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np

image_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
image_path = download_testdata(image_url, "cat.png", module="data")
resized_image = Image.open(image_path).resize((224, 224))
plt.imshow(resized_image)
plt.show()
image_data = np.asarray(resized_image).astype("float32")

# Add a batch dimension to the image to get the NHWC layout
image_data = np.expand_dims(image_data, axis=0)

# Preprocess the image (scale each channel from [0, 255] to [-1, 1]):
# https://github.com/tensorflow/models/blob/edb6ed22a801665946c63d650ab9a0b23d98e1b1/research/slim/preprocessing/inception_preprocessing.py#L243
image_data[:, :, :, 0] = 2.0 / 255.0 * image_data[:, :, :, 0] - 1
image_data[:, :, :, 1] = 2.0 / 255.0 * image_data[:, :, :, 1] - 1
image_data[:, :, :, 2] = 2.0 / 255.0 * image_data[:, :, :, 2] - 1
print("input", image_data.shape)
```
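The three per-channel assignments apply the same Inception-style scaling, x' = 2x/255 - 1, which maps pixel values from [0, 255] to [-1, 1]. Because the formula is identical for every channel, it can be written as one vectorized NumPy expression. The sketch below uses a synthetic array in place of the cat image, and the helper name `preprocess` is hypothetical.

```python
import numpy as np


def preprocess(rgb: np.ndarray) -> np.ndarray:
    """Turn an HxWx3 uint8 image into a 1xHxWx3 float32 NHWC batch scaled to [-1, 1]."""
    data = rgb.astype("float32")
    data = np.expand_dims(data, axis=0)  # add the batch dimension N, giving NHWC
    # Same formula as the per-channel lines, applied to all channels at once
    return 2.0 / 255.0 * data - 1.0


if __name__ == "__main__":
    fake = np.zeros((224, 224, 3), dtype=np.uint8)
    fake[0, 0] = 255  # one white pixel
    out = preprocess(fake)
    print(out.shape, out.dtype, out.min(), out.max())
```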


Output:

input (1, 224, 224, 3)

Compile the model with Relay

```python
# Name, shape, and dtype of the TFLite model's input tensor
input_tensor = "input"
input_shape = (1, 224, 224, 3)
input_dtype = "float32"

# Parse the TFLite model and convert it into a Relay module
from tvm import relay, transform

mod, params = relay.frontend.from_tflite(
    tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype}
)

# Build the module for an x86 CPU
target = "llvm"
with transform.PassContext(opt_level=3):
    lib = relay.build(mod, target, params=params)
```
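`relay.frontend.from_tflite` takes `shape_dict` and `dtype_dict`, both keyed by the model's input tensor names. For a model with several inputs, one way to keep the two dictionaries in sync is to build them from a single table. The sketch below is plain Python with no TVM dependency; the helper `make_input_dicts` is hypothetical, and the tensor names you pass must match the inputs of your `.tflite` file.

```python
from typing import Dict, Iterable, Tuple

InputSpec = Tuple[str, Tuple[int, ...], str]  # (tensor name, shape, dtype)


def make_input_dicts(
    specs: Iterable[InputSpec],
) -> Tuple[Dict[str, Tuple[int, ...]], Dict[str, str]]:
    """Build the shape_dict and dtype_dict arguments from (name, shape, dtype) triples."""
    shape_dict: Dict[str, Tuple[int, ...]] = {}
    dtype_dict: Dict[str, str] = {}
    for name, shape, dtype in specs:
        if name in shape_dict:
            raise ValueError("duplicate input tensor name: " + name)
        shape_dict[name] = tuple(shape)
        dtype_dict[name] = dtype
    return shape_dict, dtype_dict


if __name__ == "__main__":
    shape_dict, dtype_dict = make_input_dicts([("input", (1, 224, 224, 3), "float32")])
    print(shape_dict, dtype_dict)
    # These could then be passed on as:
    # relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict)
```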

Output:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "

Execute on TVM

```python
import tvm
from tvm.contrib import graph_executor as runtime

# Create a runtime executor module on the CPU
module = runtime.GraphModule(lib["default"](tvm.cpu()))

# Set the input data
module.set_input(input_tensor, tvm.nd.array(image_data))

# Run inference
module.run()

# Get the output as a NumPy array
tvm_output = module.get_output(0).numpy()
```

Display the result

```python
# Load the label file
label_file_url = "".join(
    [
        "https://raw.githubusercontent.com/",
        "tensorflow/tensorflow/master/tensorflow/lite/java/demo/",
        "app/src/main/assets/",
        "labels_mobilenet_quant_v1_224.txt",
    ]
)
label_file = "labels_mobilenet_quant_v1_224.txt"
label_path = download_testdata(label_file_url, label_file, module="data")

# List of 1001 classes
with open(label_path) as f:
    labels = f.readlines()

# Flatten the result into a 1-D array
predictions = np.squeeze(tvm_output)

# Get the top-1 prediction (the class with the highest score)
prediction = np.argmax(predictions)

# Convert the id to a class name and print the result
print("The image prediction result is: id " + str(prediction) + " name: " + labels[prediction])
```
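The code above reports only the top-1 class. To also see the model's next-best guesses, the scores can be sorted; the sketch below returns the top-k (id, label, score) triples using `np.argsort`. The helper name `top_k` is hypothetical, and the example uses made-up scores and labels rather than real TVM output. Since `readlines()` keeps the trailing newline on each label, the helper strips it.

```python
import numpy as np


def top_k(scores: np.ndarray, labels, k: int = 5):
    """Return the k highest-scoring (class id, label, score) triples, best first."""
    scores = np.squeeze(scores)  # (1, N) -> (N,), like np.squeeze(tvm_output)
    ids = np.argsort(scores)[::-1][:k]  # indices ordered by descending score
    return [(int(i), labels[i].strip(), float(scores[i])) for i in ids]


if __name__ == "__main__":
    fake_scores = np.array([[0.02, 0.70, 0.08, 0.20]])
    fake_labels = ["background\n", "tiger cat\n", "tabby\n", "Egyptian cat\n"]
    for class_id, name, score in top_k(fake_scores, fake_labels, k=3):
        print(class_id, name, round(score, 2))
```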

Output:

The image prediction result is: id 283 name: tiger cat

Download Python source code: from_tflite.py

Download Jupyter Notebook: from_tflite.ipynb