使用 8 位整数压缩模型的输入/输出以减少网络传输量
目的:
- 项目需要用到实例分割（instance segmentation），因此 `mrcnn_mask` 输出层的数据量特别大；同时输入图片的尺寸也达到 1024×1024×3。
- 如果不压缩就直接进行网络传输，数据量会很大，接口延迟很高。
- 目标：部署后请求推理服务时，减少输入层和输出层的数据传输量。
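- 将数据从 float32（每个值 4 字节）转换为 8 位整数（每个值 1 字节），传输量减少为原来的 1/4。注意取值范围：int8 只能表示 -128~127，而像素值以及 `mask * 255` 的取值都是 0~255，超出 int8 的部分会溢出，因此应使用无符号的 uint8。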
处理:
- 输入层：客户端先把图片转换成 8 位整数再传输，模型内部再转换回 float32。
def build(self, mode, config):
    # Excerpt: only the input-layer change is shown; the rest of build() is unchanged.
    # The image arrives as 8-bit integers to cut transfer size to 1/4 of float32.
    # Use uint8 (0..255), not int8 (-128..127): pixel values above 127 would overflow int8.
    # NOTE(review): this assumes the client sends raw 0..255 pixels. If the client
    # mean-subtracts (molds) the image before sending, the values can be negative and
    # will not fit uint8 either; in that case subtract the mean inside the graph
    # after the cast below instead.
    input_image = KL.Input(
        shape=[None, None, config.IMAGE_SHAPE[2]], name="input_image", dtype=tf.uint8)
    # Convert back to float32 inside the graph so the downstream layers see float input.
    input_image = KL.Input(tensor=K.cast(input_image, dtype=tf.float32))
输出层:
# Excerpt: the mask-head output that the export-only layers below convert to 8-bit.
# The argument list is truncated in this note; see the model source for the full call.
mrcnn_mask = build_fpn_mask_graph(detection_boxes, mrcnn_feature_maps, input_image_meta, config.MASK_POOL_SIZE, config.NUM_CLASSES,
模型导出时，在 `mrcnn_mask` 之后追加以下两层，把输出转换成 8 位整数：
# Export-only: quantize the mask output to 8 bits for transport.
# The x*255 scaling assumes mask values lie in [0, 1] (presumably sigmoid
# probabilities -- confirm). int8 only holds -128..127, so every value above
# 127/255 (~0.5) would overflow; uint8 covers 0..255. Round instead of
# truncating, and clip so the cast can never go out of range.
mrcnn_mask = KL.Lambda(lambda x: tf.clip_by_value(tf.round(x * 255), 0, 255))(mrcnn_mask)
# Keep the cast in its own unnamed Lambda so the exported tensor name
# ("lambda_N/Cast:0") is the same as before this change.
mrcnn_mask = KL.Lambda(lambda x: tf.cast(x, dtype=tf.uint8))(mrcnn_mask)
云函数服务拿到 `mrcnn_mask` 输出层数据后，再转换回 `float32` 并除以 255，恢复到 0~1 的取值范围。
- 模型导出：导出时注意 `output` 张量名称要与图中实际的层名对应（例如 `lambda_5/Cast:0` 中的 Lambda 层编号，会随着 Lambda 层的增删而变化）。
def save_model():
    """Export the trained .h5 weights as a SavedModel in ./saved_model/1/."""
    config = QuesConfig()
    weights_dir = "/xxxxx/weights"
    weights_file = 'xxxxx_0601.h5'
    output_dir = "./saved_model/1/"
    h5_to_saved_model(config, weights_dir, weights_file, output_dir)
def h5_to_saved_model(config, model_dir, model_name, export_dir):
    """Load .h5 weights into an inference-mode MaskRCNN and export a SavedModel.

    Args:
        config: model configuration; printed via ``config.display()`` and
            passed to ``modellib.MaskRCNN``.
        model_dir: directory holding the weights file; also passed to MaskRCNN
            as its ``model_dir``.
        model_name: weights file name inside ``model_dir``.
        export_dir: SavedModel output directory; deleted first if it exists.
    """
    # Start from a clean output directory.
    already_there = tf.gfile.Exists(export_dir)
    if already_there:
        tf.gfile.DeleteRecursively(export_dir)

    config.display()

    # Presumably this builds the graph in the Keras session that save_m()
    # later exports -- the model object itself is not used afterwards.
    net = modellib.MaskRCNN(mode="inference", config=config, model_dir=model_dir)
    weights_path = os.path.join(model_dir, model_name)
    net.load_weights(weights_path, by_name=True)

    # Leaving the with-block closes the Keras session.
    with K.get_session() as sess:
        save_m(export_dir, sess)
def save_m(export_dir, session,
           detection_tensor_name='mrcnn_detection/Reshape_1:0',
           mask_tensor_name='lambda_5/Cast:0'):
    """Write the graph and variables of ``session`` as a TF SavedModel.

    A single ``serving_default`` predict signature is exported. It maps the
    inputs ``input_image``, ``input_image_meta`` and ``input_anchors``, and the
    outputs ``mrcnn_detection`` and ``mrcnn_mask``, to tensors looked up by
    name in ``session.graph``.

    Args:
        export_dir: SavedModel output directory; deleted first if it exists.
        session: TF session holding the built model graph and loaded weights.
        detection_tensor_name: graph name of the tensor exported as
            ``mrcnn_detection``.
        mask_tensor_name: graph name of the tensor exported as ``mrcnn_mask``.
            Auto-numbered Keras layer names such as ``lambda_5`` shift whenever
            a Lambda layer is added or removed, so override this when the
            export graph changes.

    Raises:
        KeyError: if one of the named tensors is not in the graph (raised by
            ``Graph.get_tensor_by_name``).
    """
    # SavedModelBuilder will not write into an existing (non-empty) export directory.
    if tf.gfile.Exists(export_dir):
        tf.gfile.DeleteRecursively(export_dir)

    graph = session.graph
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

    signature_inputs = {
        "input_image": graph.get_tensor_by_name("input_image:0"),
        "input_image_meta": graph.get_tensor_by_name("input_image_meta:0"),
        "input_anchors": graph.get_tensor_by_name("input_anchors:0"),
    }
    signature_outputs = {
        'mrcnn_detection': graph.get_tensor_by_name(detection_tensor_name),
        'mrcnn_mask': graph.get_tensor_by_name(mask_tensor_name),
    }
    sigs = {
        'serving_default': tf.saved_model.signature_def_utils.predict_signature_def(
            inputs=signature_inputs, outputs=signature_outputs),
    }

    builder.add_meta_graph_and_variables(
        session, [tf.saved_model.tag_constants.SERVING], signature_def_map=sigs)
    builder.save()