四时宝库

程序员的知识宝库

机器学习100天-Day2304 深层神经网络(复用模型)

说明:本文依据《Sklearn 与 TensorFlow 机器学习实用指南》完成,所有版权和解释权均归作者和翻译成员所有,我只是搬运和做注解。

进入第二部分深度学习

第十一章训练深层神经网络

在第十章以及之前tf练习中,训练的深度神经网络都只是简单的demo,如果增大数据量或是面对更多的特征,遇到的问题就会棘手起来。

  • 梯度消失(梯度爆炸),这会影响深度神经网络,并使较低层难以训练
  • 训练效率
  • 包含数百万参数的模型将会有严重的过拟合训练集的风险
  • 本章中,教程从解释梯度消失问题开始,并探讨解决这个问题的一些最流行的解决方案。
  • 接下来讨论各种优化器,与普通梯度下降相比,它们可以加速大型模型的训练。
  • 介绍大型神经网络正则化技术。

7.复用预训练层

教材里不提倡从零开始训练DNN,可以寻找现有的神经网络来解决类似的任务,然后复用该网络较低层数据,这就是迁移学习

这样做的好处是可以加快训练速率。

如果新任务的输入图像与原始任务中使用的输入图像的大小不一致,则必须添加预处理步骤以将其大小调整为原始模型的预期大小。一般来说如果输入具有类似的低级层次的特征,则迁移学习将很好地工作。

8.复用Tensorflow模型

如果原始模型使用了Tensorflow进行训练,则可以较好地进行迁移。

  • 载入图结构(Graph's structure)。使用import_meta_graph()函数完成,该函数能够将图操作载入默认图中,返回一个saver,用户可以使用并重载模型。需要注意的是要载入.meta文件。
reset_graph()
saver=tf.train.import_meta_graph("./tf_logs/run-2019013101001/tensorflowmodel01clip.ckpt.meta")
  • 获取训练所需的所有操作。如果不知道图结构,可以列出所有操作
for op in tf.get_default_graph().get_operations():
 print(op.name)

结果会呈现一大串操作名称,也可以调用tensorboard来可视化。

可以看到整体结构

  • 如果知道需要使用哪些操作,可以使用get_default_graph()下的get_tensor_by_name()和get_operation_by_name()获取
X = tf.get_default_graph().get_tensor_by_name("X:0")
y = tf.get_default_graph().get_tensor_by_name("y:0")
accuracy=tf.get_default_graph().get_tensor_by_name("eval/accuracy:0")
training_op=tf.get_default_graph().get_operation_by_name("GradientDescent")
  • 现在可以开始一个会话,重载模型并继续训练
reset_graph()
import tensorflow as tf
mnist=tf.keras.datasets.mnist.load_data()
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = X_train.astype(np.float32).reshape(-1, 28 * 28) / 255.0
X_test = X_test.astype(np.float32).reshape(-1, 28 * 28) / 255.0
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)
X_valid, X_train = X_train[:5000], X_train[5000:]
y_valid, y_train = y_train[:5000], y_train[5000:]
def shuffle_batch(X, y, batch_size):
 rnd_idx = np.random.permutation(len(X))
 n_batches = len(X) // batch_size
 for batch_idx in np.array_split(rnd_idx, n_batches):
 X_batch, y_batch = X[batch_idx], y[batch_idx]
 yield X_batch, y_batch
n_inputs = 28 * 28 # MNIST
n_hidden1 = 300
n_hidden2 = 50
n_hidden3 = 50
n_hidden4 = 50
n_hidden5 = 50
n_outputs = 10
X = tf.placeholder(tf.float32, shape=(None, n_inputs), name="X")
y = tf.placeholder(tf.int64, shape=(None), name="y")
with tf.name_scope("dnn"):
 hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu, name="hidden1")
 hidden2 = tf.layers.dense(hidden1, n_hidden2, activation=tf.nn.relu, name="hidden2")
 hidden3 = tf.layers.dense(hidden2, n_hidden3, activation=tf.nn.relu, name="hidden3")
 hidden4 = tf.layers.dense(hidden3, n_hidden4, activation=tf.nn.relu, name="hidden4")
 hidden5 = tf.layers.dense(hidden4, n_hidden5, activation=tf.nn.relu, name="hidden5")
 logits = tf.layers.dense(hidden5, n_outputs, name="outputs")
with tf.name_scope("loss"):
 xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
 loss = tf.reduce_mean(xentropy, name="loss")
with tf.name_scope("eval"):
 correct = tf.nn.in_top_k(logits, y, 1)
 accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name="accuracy")
learning_rate = 0.01
threshold = 1.0
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
grads_and_vars = optimizer.compute_gradients(loss)
capped_gvs = [(tf.clip_by_value(grad, -threshold, threshold), var)
 for grad, var in grads_and_vars]
training_op = optimizer.apply_gradients(capped_gvs)
init = tf.global_variables_initializer()
saver = tf.train.Saver()
n_epochs = 20
batch_size = 200
with tf.Session() as sess:
 saver.restore(sess, "./tf_logs/run-2019013101001/tensorflowmodel01clip.ckpt")
 for epoch in range(n_epochs):
 for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size):
 sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
 accuracy_val = accuracy.eval(feed_dict={X: X_valid, y: y_valid})
 print(epoch, "Validation accuracy:", accuracy_val)
 save_path = saver.save(sess, "./tf_logs/run-2019013101002/tensorflowmodel01reuse.ckpt")
0 Validation accuracy: 0.9614
1 Validation accuracy: 0.9622
2 Validation accuracy: 0.963
3 Validation accuracy: 0.9628
4 Validation accuracy: 0.9648
5 Validation accuracy: 0.9638
6 Validation accuracy: 0.9664
7 Validation accuracy: 0.967
8 Validation accuracy: 0.967
9 Validation accuracy: 0.9672
10 Validation accuracy: 0.969
11 Validation accuracy: 0.9694
12 Validation accuracy: 0.9642
13 Validation accuracy: 0.9676
14 Validation accuracy: 0.9708
15 Validation accuracy: 0.9694
16 Validation accuracy: 0.972
17 Validation accuracy: 0.9702
18 Validation accuracy: 0.9722
19 Validation accuracy: 0.9708

会发现正确率比之前的要高。

发表评论:

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言
    友情链接