四时宝库

程序员的知识宝库

tensorflow模型权重保存和加载(tensorflow模型保存与加载)

Saver方法

tf.train.Saver().save()导出的checkpoint文件中是模型graph结构和权重。实际部署中都是使用pb格式文件,而这两者是可以相互转化的,重点是冻结需要保存的节点。

模型保存

saver = tf.train.Saver()
saver.save(sess,"model.ckpt")

checkpoint目录下有四个文件:

  • model_ckpt.data 保存权重
  • .meta 保存计算图的结构信息
  • .index 保存表
  • checkpoint 文件的model_checkpoit_path值觉得restore文件路径名

加载模型

saver.restore(sess,"model.ckpt")

Checkpoint文件转pb格式文件

checkpoint保存的graph和权重是分离的文件。可以使用TensoFlowconvert_variables_to_constants()方法固化模型结构,这样可以在不同语言和不同平台间移植训练的模型。

读取checkpoint文件获得graph节点名称

from tensorflow.python import pywrap_tensorflow
import os
checkpoint_path=os.path.join('model.ckpt-path')
reader=pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map=reader.get_variable_to_shape_map()
for key in var_to_shape_map:
 print 'tensor_name: ',key

固化模型得到pb格式文件

import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(model_folder, output_graph):
 """
 -通过传入 CKPT 模型的路径得到模型的图和变量数据
 -通过 import_meta_graph 导入模型中的图
 -通过 saver.restore 从模型中恢复图中各个变量的数据
 -通过 graph_util.convert_variables_to_constants 将模型持久化
 :param input_checkpoint:
 :param output_graph: PB模型保存路径
 :return:
 """
 checkpoint = tf.train.get_checkpoint_state(model_folder) # 检查目录下ckpt文件状态是否可用
 input_checkpoint = checkpoint.model_checkpoint_path # 得ckpt文件路径
 # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
 output_node_names = ["input_x", "keep_prob", "score/ArgMax", "score/Softmax"]
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
 with tf.Session() as sess:
 saver.restore(sess, input_checkpoint) # 恢复图并得到数据
 output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
 sess=sess,
 input_graph_def=sess.graph_def, # 等于:sess.graph_def
 output_node_names=output_node_names # 如果有多个输出节点,以逗号隔开
 )
 with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型
 f.write(output_graph_def.SerializeToString()) # 序列化输出
 print("%d ops in the final graph." % len(output_graph_def.node)) # 得到当前图有几个操作节点

def check_freeze_graph(pb_path):
 with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()
 with open(pb_path, "rb") as f:
 output_graph_def.ParseFromString(f.read())
 tf.import_graph_def(output_graph_def, name="")
 with tf.Session() as sess:
 sess.run(tf.global_variables_initializer())

if __name__ == '__main__':
 # 输入ckpt模型路径
 input_checkpoint = 'checkpoints/textcnn'
 # 输出pb模型的路径
 out_pb_path = "checkpoints/frozen_model.pb"
 # 调用freeze_graph将ckpt转为pb
 freeze_graph(input_checkpoint, out_pb_path)

发表评论:

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言
    友情链接