欲乎其上 得乎其中 欲乎其中 得乎其下
在实际生产环境中,一般将模型训练和模型使用分离开,先从一套环境中训练好模型,再导出到另一套环境中使用。模型训练是参数不断更新收敛过程,是数据密集型和计算密集型过程,数据量大、计算能力要求高、训练时间长。模型应用不涉及参数不断更新收敛过程,对计算能力要求减弱、输入实际数据、可以快速出结果,可以跑在DC服务器集群中,也可以跑在手机上。
一、模型保存
TF通过convert_variables_to_constants函数将数据流图中的变量值以常量形式保存,其中output_node_names参数指定要保存的node列表。通过write_graph函数将数据流图保存到pb文件中。
test-mnist-saver.py |
1 from tensorflow.examples.tutorials.mnist import input_data 2 from tensorflow.python.framework.graph_util import convert_variables_to_constants 3 4 import tensorflow as tf 5 6 mnist = input_data.read_data_sets("MNIST/", one_hot=True) 7 8 x = tf.placeholder(tf.float32, [None, 784], name='x') 9 W = tf.Variable(tf.zeros([784, 10]), name='w') 10 b = tf.Variable(tf.zeros([10]), name='b') 11 y = tf.add(tf.matmul(x, W), b, name='y') 12 y_ = tf.placeholder(tf.float32, [None, 10], name='y_') 13 14 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 15 labels=y_, logits=y)) 16 train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) 17 18 session = tf.Session() 19 session.run(tf.global_variables_initializer()) 20 21 for idx in range(12000): 22 batch_x, batch_y = mnist.train.next_batch(100) 23 session.run(train_step, feed_dict={x: batch_x, y_: batch_y}) 24 25 graph = convert_variables_to_constants(session, session.graph_def, 26 output_node_names=['y']) 27 tf.train.write_graph(graph, '.', 'GRAPH/graph.pb', as_text=False) 28 29 session.close() |
保存模型文件路径:GRAPH/graph.pb |
二、模型恢复:
TF通过import_graph_def函数从pb模型文件中恢复模型,其中input_map参数指定placeholder(node:0),return_elements参数指定恢复的tensor(node:0)。
test-mnist-restorer.py |
1 from tensorflow.examples.tutorials.mnist import input_data 2 import tensorflow as tf 3 4 mnist = input_data.read_data_sets("MNIST/", one_hot=True) 5 6 session = tf.Session() 7 session.run(tf.global_variables_initializer()) 8 9 f = open('GRAPH/graph.pb', 'rb') 10 graph = tf.GraphDef() 11 graph.ParseFromString(f.read()) 12 x = tf.placeholder(tf.float32, [None, 784]) 13 y = tf.import_graph_def(graph, input_map={'x:0': x}, return_elements=['y:0']) 14 y = y[0] 15 y_ = tf.placeholder(tf.float32, [None, 10]) 16 17 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 18 19 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 20 print session.run(accuracy, 21 feed_dict={x: mnist.test.images, y_: mnist.test.labels}) 22 23 session.close() |
测试准确率: 0.9125 |
未名小宇宙,一直在您身边,欢迎关注!