说明:本文依据《Sklearn 与 TensorFlow 机器学习实用指南》完成,所有版权和解释权均归作者和翻译成员所有,我只是搬运和做注解。
进入第二部分深度学习
第十一章训练深层神经网络
在第十章以及之前tf练习中,训练的深度神经网络都只是简单的demo,如果增大数据量或是面对更多的特征,遇到的问题就会棘手起来。
- 梯度消失(梯度爆炸),这会影响深度神经网络,并使较低层难以训练
- 训练效率
- 包含数百万参数的模型将会有严重的过拟合训练集的风险
- 本章中,教程从解释梯度消失问题开始,并探讨解决这个问题的一些最流行的解决方案。
- 接下来讨论各种优化器,与普通梯度下降相比,它们可以加速大型模型的训练。
- 介绍大型神经网络正则化技术。
7.复用预训练层
教材里不提倡从零开始训练DNN,可以寻找现有的神经网络来解决类似的任务,然后复用该网络较低层数据,这就是迁移学习。
这样做的好处是可以加快训练速率。
如果新任务的输入图像与原始任务中使用的输入图像的大小不一致,则必须添加预处理步骤以将其大小调整为原始模型的预期大小。一般来说如果输入具有类似的低级层次的特征,则迁移学习将很好地工作。
8.复用Tensorflow模型
如果原始模型使用了Tensorflow进行训练,则可以较好地进行迁移。
- 载入图结构(Graph's structure)。使用import_meta_graph()函数完成,该函数能够将图操作载入默认图中,返回一个saver,用户可以使用并重载模型。需要注意的是要载入.meta文件。
reset_graph() saver=tf.train.import_meta_graph("./tf_logs/run-2019013101001/tensorflowmodel01clip.ckpt.meta")
- 获取训练所需的所有操作。如果不知道图结构,可以列出所有操作
for op in tf.get_default_graph().get_operations(): print(op.name)
结果会呈现一大串操作名称,也可以调用tensorboard来可视化。
可以看到整体结构
- 如果知道需要使用哪些操作,可以使用get_default_graph()下的get_tensor_by_name()和get_operation_by_name()获取
X = tf.get_default_graph().get_tensor_by_name("X:0") y = tf.get_default_graph().get_tensor_by_name("y:0") accuracy=tf.get_default_graph().get_tensor_by_name("eval/accuracy:0") training_op=tf.get_default_graph().get_operation_by_name("GradientDescent")
- 现在可以开始一个会话,重载模型并继续训练
reset_graph() import tensorflow as tf mnist=tf.keras.datasets.mnist.load_data() (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data() X_train = X_train.astype(np.float32).reshape(-1, 28 * 28) / 255.0 X_test = X_test.astype(np.float32).reshape(-1, 28 * 28) / 255.0 y_train = y_train.astype(np.int32) y_test = y_test.astype(np.int32) X_valid, X_train = X_train[:5000], X_train[5000:] y_valid, y_train = y_train[:5000], y_train[5000:] def shuffle_batch(X, y, batch_size): rnd_idx = np.random.permutation(len(X)) n_batches = len(X) // batch_size for batch_idx in np.array_split(rnd_idx, n_batches): X_batch, y_batch = X[batch_idx], y[batch_idx] yield X_batch, y_batch n_inputs = 28 * 28 # MNIST n_hidden1 = 300 n_hidden2 = 50 n_hidden3 = 50 n_hidden4 = 50 n_hidden5 = 50 n_outputs = 10 X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X") y = tf.placeholder(tf.int64, shape=(None), name="y") with tf.name_scope("dnn"): hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu, name="hidden1") hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name="hidden2") hidden3 = tf.layers.dense(hidden2, n_hidden3, activation=tf.nn.relu, name="hidden3") hidden4 = tf.layers.dense(hidden3, n_hidden4, activation=tf.nn.relu, name="hidden4") hidden5 = tf.layers.dense(hidden4, n_hidden5, activation=tf.nn.relu, name="hidden5") logits = tf.layers.dense(hidden5, n_outputs, name="outputs") with tf.name_scope("loss"): xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits) loss = tf.reduce_mean(xentropy, name="loss") with tf.name_scope("eval"): correct = tf.nn.in_top_k(logits, y, 1) accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy") learning_rate = 0.01 threshold = 1.0 optimizer = tf.train.GradientDescentOptimizer(learning_rate) grads_and_vars = optimizer.compute_gradients(loss) capped_gvs = [(tf.clip_by_value(grad, -threshold, threshold), var) for grad, var in grads_and_vars] training_op = optimizer.apply_gradients(capped_gvs) init = tf.global_variables_initializer() saver = tf.train.Saver() n_epochs = 20 batch_size = 200 with tf.Session() as sess: saver.restore(sess, "./tf_logs/run-2019013101001/tensorflowmodel01clip.ckpt") for epoch in range(n_epochs): for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size): sess.run(training_op, feed_dict={X: X_batch, y: y_batch}) accuracy_val = accuracy.eval(feed_dict={X: X_valid, y: y_valid}) print(epoch, "Validation accuracy:", accuracy_val) save_path = saver.save(sess, "./tf_logs/run-2019013101002/tensorflowmodel01reuse.ckpt") 0 Validation accuracy: 0.9614 1 Validation accuracy: 0.9622 2 Validation accuracy: 0.963 3 Validation accuracy: 0.9628 4 Validation accuracy: 0.9648 5 Validation accuracy: 0.9638 6 Validation accuracy: 0.9664 7 Validation accuracy: 0.967 8 Validation accuracy: 0.967 9 Validation accuracy: 0.9672 10 Validation accuracy: 0.969 11 Validation accuracy: 0.9694 12 Validation accuracy: 0.9642 13 Validation accuracy: 0.9676 14 Validation accuracy: 0.9708 15 Validation accuracy: 0.9694 16 Validation accuracy: 0.972 17 Validation accuracy: 0.9702 18 Validation accuracy: 0.9722 19 Validation accuracy: 0.9708
会发现正确率比之前的要高。