关于前端:如何将训练好的Python模型给JavaScript使用

34次阅读

共计 6963 个字符,预计需要花费 18 分钟才能阅读完成。

前言

&nbsp&nbsp&nbsp&nbsp&nbsp&nbsp 从后面的 Tensorflow 环境搭建到指标检测模型迁徙学习,曾经实现了一个简答的扑克牌检测器,不论是从图片还是视频都能从画面中辨认出有扑克的指标,并标识出扑克点数。然而,我想在想让他放在浏览器上可能理论应用,那么要如何让 Tensorflow 模型转换成 web 格局的呢?接下来将从实际的角度具体介绍一下部署办法!

环境

Windows10

Anaconda3

TensorFlow.js converter 

converter 介绍

&nbsp&nbsp&nbsp&nbsp&nbsp&nbspconverter 全名是 TensorFlow.js Converter,他能够将 TensorFlow GraphDef 模型(通过 Python API 创立的,能够先了解为 Python 模型) 转换成 Tensorflow.js 可读取的模型格局(json 格局), 用于在浏览器上对指定数据进行推算。
 

converter 装置

&nbsp&nbsp&nbsp&nbsp&nbsp&nbsp 为了不影响后面指标检测训练环境,这里我用 conda 创立了一个新的 Python 虚拟环境,Python 版本 3.6.8。在装置转换器的时候,如果以后环境没有 Tensorflow,默认会装置与 TF 相干的依赖,只须要进入指定虚拟环境,输出以下命令。1pip install tensorflowjs

converter 用法

tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model

1.&nbsp 产生的文件 (生成的 web 格局模型) 转换器命令执行后生产两种文件,别离是 model.json(数据流图和权重清单)和 group1-shard*of*(二进制权重文件)

2.&nbsp 输出的必要条件(命令参数和选项[带 – 为选项])converter 转换指令前面次要携带四个参数,别离是输出模型的格局,输入模型的格局,输出模型的门路,输入模型的门路,更多帮忙信息能够通过以下命令查看,另附命令合成图。1tensorflowjs_converter –help

2.1. –input_format 要转换的模型的格局,SavedModel 为 tf_saved_model, frozen model 为 tf_frozen_model, session bundle 为 tf_session_bundle, TensorFlow Hub module 为 tf_hub,Keras HDF5 为 keras。

2.2. –output_format 输入模型的格局, 别离有 tfjs_graph_model (tensorflow.js 图模型,保留后的 web 模型没有了再训练能力,适宜 SavedModel 输出格局转换),tfjs_layers_model(tensorflow.js 层模型,具备无限的 Keras 性能,不适宜 TensorFlow SavedModels 转换)。

2.3. input_pathsaved model, session bundle 或 frozen model 的残缺的门路,或 TensorFlow Hub 模块的门路。

2.4. output_path 输入文件的保留门路。

2.5. –saved_model_tags 只对 SavedModel 转换用的选项:输出须要加载的 MetaGraphDef 绝对应的 tag,多个 tag 请用逗号分隔。默认为 serve。

2.6. –signature_name 对 TensorFlow Hub module 和 SavedModel 转换用的选项:对应要加载的签名,默认为 default。

2.7. –output_node_names 输入节点的名字,每个名字用逗号拆散。

罕用的两组命令行 1.covert from saved_model tensorflowjs_converter --input_format=tf_saved_model --output_format=tfjs_graph_model --signature_name=serving_default --saved_model_tags=serve ./saved_model ./web_model 
2. convert from frozen_modeltensorflowjs_converter --input_format=tf_frozen_model --output_node_names='num_detections,detection_boxes,detection_scores,detection_classes' ./frozen_inference_graph.pb  ./web_modelk

开始实际

1.&nbsp 找到通过 export_inference_graph.py 导出的模型导出的模型在我的项目的 inference_graph 文件夹 (models\research\object_detection) 里,frozen_inference_graph.pb 是 tf_frozen_model 输出格局须要的,而 saved_model 文件夹就是 tf_saved_model 格局。在当前目录下新建 web_model 目录,用于存储转换后的 web 格局的模型。

2.&nbsp 开始转换在以后虚拟环境下,进入到 inference_graph 目录下,输出以下命令,之后就会在 web_model 生成一个 json 文件和多个权重文件。1tensorflowjs_converter –input_format=tf_saved_model –output_format=tfjs_graph_model –signature_name=serving_default –saved_model_tags=serve ./saved_model ./web_model

3. 浏览器端部署

3.1. 创立一个前端我的项目,将 web_model 放入其中。

3.2. 编写代码

<!doctype html><head>  <link rel="stylesheet" href="tfjs-examples.css" />  <style>  canvas {outline: orange 2px solid; margin: 10px 0;}  </style></head> <body>  <div class="tfjs-example-container centered-container">    <section class='title-area'>      <h1> 赌圣 2023</h1>    </section>    <p class='section-head'> 模型形容 </p>    <p> 我看你怎么出老千!</p>    <p class='section-head'> 模型状态 </p>    <div id="status"> 加载模型中...</div>    <div>      <p class='section-head'> 成果展现 </p>      <p></button><input type="file" accept="image/*" id="test"/></p>      <canvas id="data-canvas" width="300" height="1100"></canvas>    </div>  </div> </body> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script> <script>  const canvas = document.getElementById('data-canvas');  const status = document.getElementById('status');  const testModel = document.getElementById('test');   const BOUNDING_BOX_LINE_WIDTH = 3;  const BOUNDING_BOX_STYLE1 = 'rgb(0,0,255)';  const BOUNDING_BOX_STYLE2 = 'rgb(0,255,0)';   async function init() {     const LOCAL_MODEL_PATH = './web_model/model.json';     // 将本地模型保留到浏览器    // tf.sequential().save     // 加载本地模型    let model;    try {model = await tf.loadGraphModel(LOCAL_MODEL_PATH);      testModel.disabled = false;      status.textContent = '胜利加载本地模型!请亮出你的牌吧';             // 默认扑克牌      runAndVisualizeInference('./cam_image39.jpg', model)           } catch (err) {console.log('加载本地模型谬误:', err);      status.textContent = '加载本地模型失败';    }     testModel.addEventListener('change', (e) => {runAndVisualizeInference(e, model)    });} async function runAndVisualizeInference(e, model) {if (typeof e === 'string') {await new Promise((resolve, reject) => {// 图片显示在 canvas 中      var img = new Image;      img.src = e;      img.onload = function () {// 必须 onload 之后再画        let w = 500;        let h = img.height/img.width*500;        canvas.width = w;        canvas.height = h;        var ctx = canvas.getContext('2d');        ctx.drawImage(img,0,0,w,h);        resolve();}    })  } else {// 上传图片并显示在 canvas 中    var file = e.target.files[0];     if (!/image\/\w+/.test(file.type)) {alert("请确保文件为图像类型");      return false;    }    var reader = new FileReader();    reader.readAsDataURL(file); // 转化成 base64 数据类型    await new Promise((resolve, reject) => {reader.onload = function (e) {// 图片显示在 canvas 中        var img = new Image;        img.src = this.result;        img.onload = function () {// 必须 onload 之后再画          let w = 500;          let h = img.height/img.width*500;          canvas.width = w;          canvas.height = h;          var ctx = canvas.getContext('2d');          ctx.drawImage(img,0,0,w,h);          resolve();}      }    })  }   // 模型输出解决  let image = tf.browser.fromPixels(canvas);  const t4d = image.expandDims(0);   const outputDim = ['num_detections', 'detection_boxes', 'detection_scores',    'detection_classes'];     const labelMap = {1: '九点',    2: '十点',    3: 'Jack',    4: 'Queen',    5: 'King',    6: 'Ace'}     let modelOut = {}, boxes = [], w = canvas.width, h = canvas.height;  console.log(model)     for (const dim of outputDim) {let tensor = await model.executeAsync({      'image_tensor': t4d}, `${dim}:0`);    modelOut[dim] = await tensor.data();}  console.log(modelOut)     for (let i=0; i<modelOut['detection_scores'].length; i++) {const score = modelOut['detection_scores'][i];       if (score < 0.5) break; // 置信度过滤       boxes.push({ymin: modelOut['detection_boxes'][i*4]*h,      xmin: modelOut['detection_boxes'][i*4+1]*w,      ymax: modelOut['detection_boxes'][i*4+2]*h,      xmax: modelOut['detection_boxes'][i*4+3]*w,      label: labelMap[modelOut['detection_classes'][i]],    })  }     console.log(boxes)   // 可视化检测框  drawBoundingBoxes(canvas, boxes);   // 张量运行内存革除  tf.dispose([image, modelOut]);} function drawBoundingBoxes(canvas, predictBoundingBoxArr) {for (const box of predictBoundingBoxArr) {let left = box.xmin;    let right = box.xmax;    let top = box.ymin;    let bottom = box.ymax;     const ctx = canvas.getContext('2d');    ctx.beginPath();    ctx.strokeStyle = box.label==='ZERO_DEV'?BOUNDING_BOX_STYLE1:BOUNDING_BOX_STYLE2;    ctx.lineWidth = BOUNDING_BOX_LINE_WIDTH;    ctx.moveTo(left, top);    ctx.lineTo(right, top);    ctx.lineTo(right, bottom);    ctx.lineTo(left, bottom);    ctx.lineTo(left, top);    ctx.stroke();     ctx.font = '24px Arial bold';    ctx.fillStyle = box.label==='zfc'?BOUNDING_BOX_STYLE2:BOUNDING_BOX_STYLE1;    ctx.fillText(box.label, left+8, top+8);  }} init(); </script>

3.3. 运行后果

  

正文完
 0