Saver方法
tf.train.Saver().save()导出的checkpoint文件中是模型graph结构和权重。实际部署中都是使用pb格式文件,而这两者是可以相互转化的,重点是冻结需要保存的节点。
模型保存
saver = tf.train.Saver() saver.save(sess,"model.ckpt")
checkpoint目录下有四个文件:
- model_ckpt.data 保存权重
- .meta 保存计算图的结构信息
- .index 保存表
- checkpoint 文件的model_checkpoit_path值觉得restore文件路径名
加载模型
saver.restore(sess,"model.ckpt")
Checkpoint文件转pb格式文件
checkpoint保存的graph和权重是分离的文件。可以使用TensoFlowconvert_variables_to_constants()方法固化模型结构,这样可以在不同语言和不同平台间移植训练的模型。
读取checkpoint文件获得graph节点名称
from tensorflow.python import pywrap_tensorflow import os checkpoint_path=os.path.join('model.ckpt-path') reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path) var_to_shape_map=reader.get_variable_to_shape_map() for key in var_to_shape_map: print 'tensor_name: ',key
固化模型得到pb格式文件
import tensorflow as tf from tensorflow.python.framework import graph_util def freeze_graph(model_folder, output_graph): """ -通过传入 CKPT 模型的路径得到模型的图和变量数据 -通过 import_meta_graph 导入模型中的图 -通过 saver.restore 从模型中恢复图中各个变量的数据 -通过 graph_util.convert_variables_to_constants 将模型持久化 :param input_checkpoint: :param output_graph: PB模型保存路径 :return: """ checkpoint = tf.train.get_checkpoint_state(model_folder) # 检查目录下ckpt文件状态是否可用 input_checkpoint = checkpoint.model_checkpoint_path # 得ckpt文件路径 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点 output_node_names = ["input_x", "keep_prob", "score/ArgMax", "score/Softmax"] saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True) with tf.Session() as sess: saver.restore(sess, input_checkpoint) # 恢复图并得到数据 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定 sess=sess, input_graph_def=sess.graph_def, # 等于:sess.graph_def output_node_names=output_node_names # 如果有多个输出节点,以逗号隔开 ) with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型 f.write(output_graph_def.SerializeToString()) # 序列化输出 print("%d ops in the final graph." % len(output_graph_def.node)) # 得到当前图有几个操作节点 def check_freeze_graph(pb_path): with tf.Graph().as_default(): output_graph_def = tf.GraphDef() with open(pb_path, "rb") as f: output_graph_def.ParseFromString(f.read()) tf.import_graph_def(output_graph_def, name="") with tf.Session() as sess: sess.run(tf.global_variables_initializer()) if __name__ == '__main__': # 输入ckpt模型路径 input_checkpoint = 'checkpoints/textcnn' # 输出pb模型的路径 out_pb_path = "checkpoints/frozen_model.pb" # 调用freeze_graph将ckpt转为pb freeze_graph(input_checkpoint, out_pb_path)