mask-rcnn-部署小技巧

32次阅读

共计 2117 个字符,预计需要花费 6 分钟才能阅读完成。

int8 进行网络传输

目的:

  1. 我们项目需要用到 instance segmentation, 所以 rcnn_mask 输出层数据量特别大, 同时因为图片尺寸有 1024*1024*3 这么大.
  2. 如果不压缩一下直接进行网络传输, 数据量会很大, 接口延迟很高.
  3. 为了部署后, 请求推理服务时, 数据输入层和输出层的传输量减少.
  4. float32 转变成 int8, 传输量减少了 4 倍.

处理:

  • 输入层:

先转换成 int8, 然后再转成float32

 def build(self, mode, config):
         input_image = KL.Input(shape=[None, None, config.IMAGE_SHAPE[2]], name="input_image",dtype=tf.int8)
        input_image = KL.Input(tensor=K.cast(input_image,dtype= tf.float32))
        
  • 输出层:

       mrcnn_mask = build_fpn_mask_graph(detection_boxes, mrcnn_feature_maps,
                                                 input_image_meta,
                                                 config.MASK_POOL_SIZE,
                                                 config.NUM_CLASSES,

模型导出时加上去

        mrcnn_mask = KL.Lambda(lambda x: x*255)(mrcnn_mask)
        mrcnn_mask = KL.Lambda(lambda x: tf.cast(x,dtype=tf.int8))(mrcnn_mask)

云函数服务 拿到 `mrcnn_mask` 输出层数据后, 再转换成 `float32`


- 模型导出:
导出时注意 `output` 名称要对应.

def save_model():

config = QuesConfig()
PRETRAINED_MODEL_PATH = "/xxxxx/weights"
MODEL_NAME = 'xxxxx_0601.h5'
export_dir = "./saved_model/1/"
h5_to_saved_model(config, PRETRAINED_MODEL_PATH, MODEL_NAME, export_dir)

def h5_to_saved_model(config, model_dir, model_name, export_dir):

if tf.gfile.Exists(export_dir):
    tf.gfile.DeleteRecursively(export_dir)
config.display()

model = modellib.MaskRCNN(mode="inference", config=config, model_dir=model_dir)
model_path = os.path.join(model_dir, model_name)
model.load_weights(model_path, by_name=True)
with K.get_session() as session:
    save_m(export_dir, session)

def save_m(export_dir, session):

if tf.gfile.Exists(export_dir):
    tf.gfile.DeleteRecursively(export_dir)
# legacy_init_op = tf.group(tf.tables_initializer())
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
signature_inputs = {
    "input_image":
        session.graph.get_tensor_by_name("input_image:0"),
    "input_image_meta":
        session.graph.get_tensor_by_name("input_image_meta:0"),
    "input_anchors":
        session.graph.get_tensor_by_name("input_anchors:0"),
}
signature_outputs = {
    'mrcnn_detection':
        session.graph.get_tensor_by_name('mrcnn_detection/Reshape_1:0'),
    'mrcnn_mask':
        session.graph.get_tensor_by_name('lambda_5/Cast:0')
}

sigs = {}
sigs['serving_default'] = tf.saved_model.signature_def_utils.predict_signature_def(
    inputs=signature_inputs,
    outputs=signature_outputs)
builder.add_meta_graph_and_variables(
    session,
    [tf.saved_model.tag_constants.SERVING],
    signature_def_map=sigs

)
builder.save()

正文完
 0