int8 进行网络传输
目的:
- 我们项目需要用到 `instance segmentation`, 所以 `mrcnn_mask` 输出层数据量特别大, 同时输入图片尺寸也有 `1024*1024*3` 这么大.
- 如果不压缩一下直接进行网络传输, 数据量会很大, 接口延迟很高.
- 为了部署后, 请求推理服务时, 数据输入层和输出层的传输量减少.
- 从 `float32` 转变成 `int8`, 每个数值由 4 字节变为 1 字节, 传输量减少到原来的 1/4.
处理:
- 输入层: 先转换成 `int8` 进行传输, 然后在模型内部再转成 `float32`
def build(self, mode, config):
# Excerpt: only the input-layer change is shown here.
# Declare the image input as int8 so the request payload carries 1 byte per
# value instead of the 4 bytes a float32 input would need.
# NOTE(review): int8 holds -128..127; confirm the values the client sends fit
# that range (raw 0..255 pixel values would not).
input_image = KL.Input(shape=[None, None, config.IMAGE_SHAPE[2]], name="input_image",dtype=tf.int8)
# Rebind input_image to a float32 cast of that input, wrapped as an Input layer.
input_image = KL.Input(tensor=K.cast(input_image,dtype= tf.float32))
- 输出层:
mrcnn_mask = build_fpn_mask_graph(detection_boxes, mrcnn_feature_maps, input_image_meta, config.MASK_POOL_SIZE, config.NUM_CLASSES,
模型导出时加上去
# Added only when exporting the model: scale the mask output by 255 so it can
# be carried as an 8-bit integer.
mrcnn_mask = KL.Lambda(lambda x: x*255)(mrcnn_mask)
# Cast the scaled mask to int8; presumably this Cast op is the tensor exported
# as the `mrcnn_mask` output ('lambda_5/Cast:0' in save_m) - confirm the name.
# NOTE(review): int8 holds -128..127. If mask values lie in [0, 1] (presumably
# sigmoid outputs - confirm), x*255 exceeds 127 for anything above ~0.5 and the
# result of the cast is then not well defined; uint8 looks like the intended
# type - confirm against how the client decodes the output.
mrcnn_mask = KL.Lambda(lambda x: tf.cast(x,dtype=tf.int8))(mrcnn_mask)
云函数服务拿到 `mrcnn_mask` 输出层数据后, 再转换成 `float32` (导出时乘了 255, 所以还需要除以 255 还原).
- 模型导出:
导出时注意 `output` 名称要对应.
def save_model(pretrained_model_path="/xxxxx/weights",
               model_name="xxxxx_0601.h5",
               export_dir="./saved_model/1/"):
    """Export the trained Mask R-CNN weights as a TensorFlow SavedModel.

    Called with no arguments, this uses the same weights location and export
    directory that used to be hard-coded in the function body.

    Args:
        pretrained_model_path: Directory that contains the ``.h5`` weights file.
        model_name: File name of the weights inside ``pretrained_model_path``.
        export_dir: Target SavedModel directory. It is passed on to
            ``h5_to_saved_model``; an existing directory at that path is
            replaced during the export.
    """
    config = QuesConfig()
    h5_to_saved_model(config, pretrained_model_path, model_name, export_dir)
def h5_to_saved_model(config, model_dir, model_name, export_dir):
    """Load Keras ``.h5`` weights into an inference Mask R-CNN and export it.

    Args:
        config: Model configuration object; ``config.display()`` is called and
            the object is handed to ``modellib.MaskRCNN``.
        model_dir: Directory holding the weights file (also passed to
            ``MaskRCNN`` as its ``model_dir``).
        model_name: Weights file name inside ``model_dir``.
        export_dir: Directory the SavedModel is written to by ``save_m``.
    """
    config.display()

    # Build the graph and load the weights *before* anything under export_dir
    # is touched: a wrong weights path must not cost us the previous export.
    # save_m() removes a stale export_dir itself right before it writes.
    model = modellib.MaskRCNN(mode="inference", config=config, model_dir=model_dir)
    model.load_weights(os.path.join(model_dir, model_name), by_name=True)

    # NOTE: assuming K.get_session() returns a tf.Session, leaving the `with`
    # block closes it, so the loaded model cannot be run afterwards in this
    # process. That is fine for a one-shot export script.
    with K.get_session() as session:
        save_m(export_dir, session)
def save_m(export_dir, session, mask_tensor_name="lambda_5/Cast:0"):
    """Write the graph held by ``session`` to ``export_dir`` as a SavedModel.

    The model is exported under the ``serving_default`` signature with three
    inputs (``input_image``, ``input_image_meta``, ``input_anchors``) and two
    outputs (``mrcnn_detection``, ``mrcnn_mask``).

    Args:
        export_dir: Destination directory. An existing directory at this path
            is deleted first.
        session: Session whose graph and variables are exported.
        mask_tensor_name: Graph name of the tensor exposed as ``mrcnn_mask``.
            The default matches the int8 cast added at export time; Keras
            numbers Lambda layers automatically, so pass the actual name if
            the model contains a different number of Lambda layers.

    Raises:
        KeyError: If one of the tensor names is not present in the graph
            (raised by ``Graph.get_tensor_by_name``).
    """
    # The builder must create the export directory itself, so clear any stale
    # export first.
    if tf.gfile.Exists(export_dir):
        tf.gfile.DeleteRecursively(export_dir)

    graph = session.graph

    # Each input tensor is the first output (":0") of the op with that name.
    input_names = ("input_image", "input_image_meta", "input_anchors")
    signature_inputs = {
        name: graph.get_tensor_by_name(name + ":0") for name in input_names
    }
    signature_outputs = {
        "mrcnn_detection": graph.get_tensor_by_name("mrcnn_detection/Reshape_1:0"),
        "mrcnn_mask": graph.get_tensor_by_name(mask_tensor_name),
    }

    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs=signature_inputs,
        outputs=signature_outputs,
    )

    # Create the builder only after every tensor lookup succeeded, so that a
    # wrong tensor name does not leave an empty export directory behind.
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(
        session,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={"serving_default": signature},
    )
    builder.save()