关于机器学习:使用-C-API-部署-TVM-模块

5次阅读

共计 1755 个字符,预计需要花费 5 分钟才能阅读完成。

更多 TVM 中文文档可拜访 →Apache TVM 是一个端到端的深度学习编译框架,实用于 CPU、GPU 和各种机器学习减速芯片。| Apache TVM 中文站

应用 C++ API 部署 TVM 模块

apps/howto_deploy 中给出了部署 TVM 模块的示例,执行上面的命令运行该示例:

cd apps/howto_deploy
./run_example.sh

获取 TVM Runtime 库​

惟一要做的是链接到 target 平台中的 TVM runtime。TVM 给出了一个最小 runtime,它的开销大概在 300K 到 600K 之间,具体值取决于应用模块的数量。大多数状况下,可用 libtvm_runtime.so 文件去构建。

若构建 libtvm_runtime 有艰难,可查看 tvm_runtime_pack.cc(集成了 TVM runtime 的所有示例)。用构建零碎来编译这个文件,而后将它蕴含到我的项目中。

查看 apps 获取在 iOS、Android 和其余平台上,用 TVM 构建的利用示例。

动静库 vs. 零碎模块​

TVM 有两种应用编译库的办法,查看 prepare_test_libs.py 理解如何生成库,查看 cpp_deploy.cc 理解如何应用它们。

  • 把库存储为共享库,并动静加载到我的项目中。
  • 将编译好的库以零碎模块模式绑定到我的项目中。

动静加载更加灵便,能疾速加载新模块。零碎模块是一种更 static 的办法,可用在动静库加载不可用的中央。


部署到 Android

为 Android Target 构建模型​

针对 Android target 的 Relay 模型编译遵循和 android_rpc 雷同的办法,以下代码会保留 Android target 所需的编译输入:

lib.export_library("deploy_lib.so", ndk.create_shared)
with open("deploy_graph.json", "w") as fo:
    fo.write(graph.json())
with open("deploy_param.params", "wb") as fo:
    fo.write(runtime.save_param_dict(params))

deploy_lib.so、deploy_graph.json、deploy_param.params 将转到 Android target。

实用于 Android Target 的 TVM Runtime​

参考 此处 为 Android target 构建 CPU/OpenCL 版本的 TVM runtime。参考这个 Java 示例来理解 Android Java TVM API,以及如何加载和执行模型。


将 TVM 集成到我的项目中

TVM runtime 具备轻量级和可移植性的特点,有几种办法可将 TVM 集成到我的项目中。

下文介绍如何将 TVM 作为 JIT 编译器集成到我的项目中,从而用它在零碎上生成函数的办法

DLPack 反对​

TVM 的生成函数遵循 PackedFunc 约定,它是一个能够承受地位参数(包含规范类型,如浮点、整数、字符串)的函数。PackedFunc 采纳 DLPack 约定中的 DLTensor 指针。惟一要做的是创立一个对应的 DLTensor 对象。

集成用户自定义的 C++ 数组​

在 C++ 中惟一要做的就是将你的数组转换为 DLTensor,并将其地址作为 DLTensor* 传递给生成的函数。

集成用户自定义的 Python 数组​

针对 Python 对象 MyArray,须要做:

  • 将 _tvm_tcode 字段增加到返回 tvm.TypeCode.ARRAY_HANDLE 的数组中
  • 在对象中反对
    _tvm_handle 属性(以 Python 整数模式返回 DLTensor 的地址)
  • 用 tvm.register_extension 注册这个类
# 示例代码
import tvm

class MyArray(object):
    _tvm_tcode = tvm.TypeCode.ARRAY_HANDLE

 @property
 def _tvm_handle(self):
        dltensor_addr = self.get_dltensor_addr()
 return dltensor_addr

# 将注册的步骤放在独自的文件 mypkg.tvm.py 中
# 依据须要选择性地导入依赖
tvm.register_extension(MyArray)
正文完
 0