当我们拿到别人训练后保存的模型文件后,如果需要通过C++接口部署模型的话,一般情况下都需要将模型固化并保存为pb格式。Tensorflow提供了相关的固化命令脚本,下面以 .meta 格式的固化为例说明使用的方式。
1、.index、.meta、.data文件固化
(1)文件内容
用tf.train.Saver.save()方式保存下来的checkpoint会产生四个文件:
- checkpoint:记录了已存储(部分)和最近存储的模型:
- model.ckpt.data-00000-of-00001:保存了模型的所有变量的值,TensorBundle集合。
- model.ckpt.index:为一个string-string table,table的key值为tensor名,value为serialized BundleEntryProto。每个BundleEntryProto表述了tensor的metadata,比如那个data文件包含tensor、文件中的偏移量、一些辅助数据等。
- model.ckpt.meta:保存了graph结构,包括 GraphDef,SaverDef等,当存在meta file,我们可以不在文件中定义模型,也可以运行,而如果没有meta file,我们需要定义好模型,再加载data file,得到变量值。
(2)固化命令
当需要将模型文件固化为pb文件时,可以使用以下的方式,其中的freeze_graph.py是tensorflow源码提供的脚本工具:
python tensorflow/python/tools/freeze_graph.py \
--input_meta_graph=model.ckpt.meta \
--input_checkpoint=model.ckpt \
--output_graph=frozen_graph_meta.pb \
--output_node_name=embeddings \
--input_binary=True
可能遇到的错误:
UnicodeDecoderError: 'utf-8' codec can't decode byte 0xd8 in position 1: invalid continuation byte
解决方法是传入参数 --input_binary=True
2、output_node_name的确定
(1)输出所有op的名字
import tensorflow as tf
import os
def dump_op_name(ckpt_path):
saver = tf.train.import_meta_graph(ckpt_path + '.meta', clear_devices=True) # 从.meta加载图结构
graph = tf.get_default_graph() # 设置为session默认的图结构
with tf.Session(graph=graph) as sess:
sess.run(tf.global_variables_initializer()) #全局变量初始化
saver.restore(sess, ckpt_path) # 导入ckpt数据
ops = [op for op in sess.graph.get_operations()] #获取所有的op
for op in ops:
print(op.name)
if __name__ == '__main__':
dump_op_name('./best-m-334000')
只输出了op的名字,对全局的图结构没有直观的体现,如果需要分析整体的图结构,可以使用tensorboard工具。
(2)使用TensorBoard查看
使用TensorBoard之前,需要先从模型文件meta或者pb中提取出graph的信息。
- 从meta提取
import tensorflow as tf
import os
def write_graph_log(meta_file, log_dir):
if not os.path.exists(log_dir):
os.mkdir(log_dir)
g = tf.Graph()
with g.as_default() as g:
tf.train.import_meta_graph(meta_file)
with tf.Session(graph=g) as sess:
tf.summary.FileWriter(logdir=log_dir, graph=g)
if __name__ == '__main__':
write_graph_log('best-m-334000.meta', './log/')
- 从pb提取
tensorflow源码中提供了转换的脚本工具:tensorflow/tensorflow/python/tools/import_pb_to_tensorboard.py,只需要执行
python import_pb_to_tensorboard.py \
--model_dir=model.pb \
--log_dir=log
运行脚本后在log目录下生成events.out.tfevents.xxx文件,通过cmd启动Tensorboard:
cd model_dir
tensorboard --logdir=log
浏览器打开显示的网址,通过可视化图结构可以清楚地看到输入和输出节点的名字。