关于程序员:使用onnx对pytorch模型进行部署

43次阅读

共计 1952 个字符,预计需要花费 5 分钟才能阅读完成。

1.onnx runtime 装置

# 激活虚拟环境
conda activate env_name # env_name 换成环境名称
# 装置 onnx
pip install onnx 
# 装置 onnx runtime
pip install onnxruntime # 应用 CPU 进行推理
# pip install onnxruntime-gpu # 应用 GPU 进行推理

2. 导出模型

import torch.onnx 
# 转换的 onnx 格局的名称,文件后缀需为.onnx
onnx_file_name = "xxxxxx.onnx"
# 咱们须要转换的模型,将 torch_model 设置为本人的模型
model = torch_model
# 加载权重,将 model.pth 转换为本人的模型权重
# 如果模型的权重是应用多卡训练进去,咱们须要去除权重中多的 module. 具体操作能够见 5.4 节
model = model.load_state_dict(torch.load("model.pth"))
# 导出模型前,必须调用 model.eval()或者 model.train(False)
model.eval()
# dummy_input 就是一个输出的实例,仅提供输出 shape、type 等信息 
batch_size = 1 # 随机的取值,当设置 dynamic_axes 后影响不大
dummy_input = torch.randn(batch_size, 1, 224, 224, requires_grad=True) 
# 这组输出对应的模型输入
output = model(dummy_input)
# 导出模型
torch.onnx.export(model,        # 模型的名称
                  dummy_input,   # 一组实例化输出
                  onnx_file_name,   # 文件保留门路 / 名称
                  export_params=True,        #  如果指定为 True 或默认, 参数也会被导出. 如果你要导出一个没训练过的就设为 False.
                  opset_version=10,          # ONNX 算子集的版本,以后已更新到 15
                  do_constant_folding=True,  # 是否执行常量折叠优化
                  input_names = ['input'],   # 输出模型的张量的名称
                  output_names = ['output'], # 输入模型的张量的名称
                  # dynamic_axes 将 batch_size 的维度指定为动静,# 后续进行推理的数据能够与导出的 dummy_input 的 batch_size 不同
                  dynamic_axes={'input' : {0 : 'batch_size'},    
                                'output' : {0 : 'batch_size'}})

3. 模型校验

import onnx
# 咱们能够应用异样解决的办法进行测验
try:
    # 当咱们的模型不可用时,将会报出异样
    onnx.checker.check_model(self.onnx_model)
except onnx.checker.ValidationError as e:
    print("The model is invalid: %s"%e)
else:
    # 模型可用时,将不会报出异样,并会输入“The model is valid!”print("The model is valid!")

4. 模型可视化
Netron 下载网址:github.com/lutzroeder/…
5. 应用 ONNX Runtime 进行推理

# 导入 onnxruntime
import onnxruntime
# 须要进行推理的 onnx 模型文件名称
onnx_file_name = "xxxxxx.onnx"

# onnxruntime.InferenceSession 用于获取一个 ONNX Runtime 推理器
ort_session = onnxruntime.InferenceSession(onnx_file_name)  

# 构建字典的输出数据,字典的 key 须要与咱们构建 onnx 模型时的 input_names 雷同
# 输出的 input_img 也须要扭转为 ndarray 格局
ort_inputs = {'input': input_img} 
# 咱们更倡议应用上面这种办法, 因为防止了手动输出 key
# ort_inputs = {ort_session.get_inputs()[0].name:input_img}

# run 是进行模型的推理,第一个参数为输入张量名的列表,个别状况能够设置为 None
# 第二个参数为构建的输出值的字典
# 因为返回的后果被列表嵌套,因而咱们须要进行 [0] 的索引
ort_output = ort_session.run(None,ort_inputs)[0]
# output = {ort_session.get_outputs()[0].name}
# ort_output = ort_session.run([output], ort_inputs)[0]

正文完
 0