If the model will later be deployed on mobile, i.e. converted to tflite, the trained model needs to be exported by calling export_tflite_ssd_graph.py under the object_detection directory:

python export_tflite_ssd_graph.py \
    --input_type image_tensor \
    --pipeline_config_path training/ssdlite_mobilenet_v3_small_320x320_coco.config \
    --trained_checkpoint_prefix training/model.ckpt-80000 \
    --output_directory zuanjing_tflite_inference_graph
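If the export succeeds, the output directory should contain the frozen graph that the conversion step below consumes. As a quick sanity check (assuming the default behavior of export_tflite_ssd_graph.py, which writes tflite_graph.pb and tflite_graph.pbtxt):

ls zuanjing_tflite_inference_graph
# expected: tflite_graph.pb  tflite_graph.pbtxt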
Next, install bazel. I first downloaded version 3.1.0 (I tried 3.5.0 and it errored, not sure why), from https://github.com/bazelbuild/bazel/releases/download/3.1.0/bazel-3.1.0-installer-linux-x86_64.sh. After downloading, cd to the directory containing bazel-3.1.0-installer-linux-x86_64.sh and run, in order:

chmod +x bazel-3.1.0-installer-linux-x86_64.sh
./bazel-3.1.0-installer-linux-x86_64.sh --user
sudo gedit ~/.bashrc
# add this line at the end of the file: export PATH="$PATH:$HOME/bin"
source ~/.bashrc
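To confirm the PATH entry took effect, you can print the version in a fresh shell; output mentioning 3.1.0 means the install is usable:

bazel version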
Once bazel is installed, it is used to build a few tools needed when converting a tensorflow model to tflite: freeze_graph, toco, and summarize_graph. These tools all live in the tensorflow source tree (cloned from github). Build with the following command, run from the root of the cloned tensorflow directory (this builds summarize_graph; the other tools are built the same way from their respective bazel targets):

bazel build tensorflow/tools/graph_transforms:summarize_graph
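summarize_graph is useful for confirming the input and output node names that the conversion script below relies on. A sketch of how I would inspect the exported graph (the binary path follows from the build target above; the --in_graph path is the output directory from the export step):

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
    --in_graph=zuanjing_tflite_inference_graph/tflite_graph.pb
# look for normalized_input_image_tensor (input) and the
# TFLite_Detection_PostProcess outputs in the printed summary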
Then set up the pb_to_tflite.py script as follows:

# -*- coding:utf-8 -*-
import tensorflow as tf

in_path = "/home/jiaoda/PycharmProjects/tensorflow/models/research/object_detection/tflite_inference_graph/tflite_graph.pb"
# model input node
input_tensor_name = ["normalized_input_image_tensor"]
input_tensor_shape = {"normalized_input_image_tensor": [1, 320, 320, 3]}
# model output nodes
classes_tensor_name = ['TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
                       'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3']

converter = tf.lite.TFLiteConverter.from_frozen_graph(in_path, input_tensor_name,
                                                      classes_tensor_name, input_tensor_shape)
converter.allow_custom_ops = True
# converter.post_training_quantize = True  # uncomment to output a quantized tflite model
tflite_model = converter.convert()
open("ssdlite_mobilenet_v3_small_320x320.tflite", "wb").write(tflite_model)
print("done")
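Before shipping the file to Android, you can verify it loads and has the expected tensors by running it through the TFLite interpreter. A minimal sketch, assuming the output file name from the script above:

# -*- coding:utf-8 -*-
import tensorflow as tf

# load the converted model and allocate its tensors
interpreter = tf.lite.Interpreter(model_path="ssdlite_mobilenet_v3_small_320x320.tflite")
interpreter.allocate_tensors()

# the input should be 1x320x320x3; the four outputs correspond to the
# boxes/classes/scores/num_detections of TFLite_Detection_PostProcess
print(interpreter.get_input_details())
print(interpreter.get_output_details())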
Reference posts:
- Tensorflow 1.13训练模型.pb文件转换成Tensorflowlite可以使用的.tflite文件过程记录
- linux源码安装bazel
- Tensorflow 模型转 tflite,在安卓端使用