四时宝库

程序员的知识宝库

TensorFlow模型保存与恢复(tensorflow模型保存与加载)

欲乎其上 得乎其中 欲乎其中 得乎其下


在实际生产环境中,一般将模型训练和模型使用分离开,先从一套环境中训练好模型,再导出到另一套环境中使用。模型训练是参数不断更新收敛过程,是数据密集型和计算密集型过程,数据量大、计算能力要求高、训练时间长。模型应用不涉及参数不断更新收敛过程,对计算能力要求减弱、输入实际数据、可以快速出结果,可以跑在DC服务器集群中,也可以跑在手机上。

一、模型保存

TF通过convert_variables_to_constants函数将数据流图中的变量值以常量形式保存,其中output_node_names参数指定要保存的node列表。通过write_graph函数将数据流图保存到pb文件中。

test-mnist-saver.py

1 from tensorflow.examples.tutorials.mnist import input_data

2 from tensorflow.python.framework.graph_util import convert_variables_to_constants

3

4 import tensorflow as tf

5

6 mnist = input_data.read_data_sets("MNIST/", one_hot=True)

7

8 x = tf.placeholder(tf.float32, [None, 784], name='x')

9 W = tf.Variable(tf.zeros([784, 10]), name='w')

10 b = tf.Variable(tf.zeros([10]), name='b')

11 y = tf.add(tf.matmul(x, W), b, name='y')

12 y_ = tf.placeholder(tf.float32, [None, 10], name='y_')

13

14 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(

15 labels=y_, logits=y))

16 train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

17

18 session = tf.Session()

19 session.run(tf.global_variables_initializer())

20

21 for idx in range(12000):

22 batch_x, batch_y = mnist.train.next_batch(100)

23 session.run(train_step, feed_dict={x: batch_x, y_: batch_y})

24

25 graph = convert_variables_to_constants(session, session.graph_def,

26 output_node_names=['y'])

27 tf.train.write_graph(graph, '.', 'GRAPH/graph.pb', as_text=False)

28

29 session.close()

保存模型文件路径:GRAPH/graph.pb

二、模型恢复

TF通过import_graph_def函数从pb模型文件中恢复模型,其中input_map参数指定placeholder(node:0),return_elements参数指定恢复的tensor(node:0)。

test-mnist-restorer.py

1 from tensorflow.examples.tutorials.mnist import input_data

2 import tensorflow as tf

3

4 mnist = input_data.read_data_sets("MNIST/", one_hot=True)

5

6 session = tf.Session()

7 session.run(tf.global_variables_initializer())

8

9 f = open('GRAPH/graph.pb', 'rb')

10 graph = tf.GraphDef()

11 graph.ParseFromString(f.read())

12 x = tf.placeholder(tf.float32, [None, 784])

13 y = tf.import_graph_def(graph, input_map={'x:0': x}, return_elements=['y:0'])

14 y = y[0]

15 y_ = tf.placeholder(tf.float32, [None, 10])

16

17 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

18

19 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

20 print session.run(accuracy,

21 feed_dict={x: mnist.test.images, y_: mnist.test.labels})

22

23 session.close()

测试准确率: 0.9125

未名小宇宙,一直在您身边,欢迎关注!

发表评论:

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言
    友情链接