Skip to content

Instantly share code, notes, and snippets.

@d0znpp
Created December 12, 2017 00:46
Show Gist options
  • Save d0znpp/49b165bf8c3e04c4fb8f50f5c70faaa0 to your computer and use it in GitHub Desktop.
Save d0znpp/49b165bf8c3e04c4fb8f50f5c70faaa0 to your computer and use it in GitHub Desktop.
# Train the current model from freshly initialised variables for a fixed
# number of steps, then evaluate it once and return a (reward-like value,
# accuracy) pair.
# NOTE(review): indentation was lost in this paste, and the enclosing `def`
# is not shown. `model`, `train_op`, `loss_op`, `pre_acc` and
# `cnn_drop_rate` come from that enclosing scope.
with tf.Session() as train_sess:
# Every call re-runs the global initializer, so training starts from fresh
# variable values rather than continuing from a previous call.
init = tf.global_variables_initializer()
train_sess.run(init)
# Training phase: self.max_step_per_action minibatch updates.
for step in range(self.max_step_per_action):
# NOTE(review): `bathc_size` looks like a misspelling of `batch_size`.
# It must match the attribute name defined elsewhere; confirm before renaming.
batch_x, batch_y = self.mnist.train.next_batch(self.bathc_size)
# NOTE(review): `self.dropout_rate` is fed to a placeholder named
# `dropout_keep_prob`. Confirm it holds a keep probability, not a drop rate.
feed = {model.X: batch_x,
model.Y: batch_y,
model.dropout_keep_prob: self.dropout_rate,
model.cnn_dropout_rates: cnn_drop_rate}
_ = train_sess.run(train_op, feed_dict=feed)
# Evaluation phase: draw 10000 examples from the test split.
# Assumption: the test split has 10000 examples, so this covers it once.
# TODO confirm.
batch_x, batch_y = self.mnist.test.next_batch(10000)
# `loss` is computed here but not used anywhere in the visible code.
loss, acc = train_sess.run(
[loss_op, model.accuracy],
# Every dropout input is fed 1.0 for evaluation, one value per entry of
# cnn_drop_rate. This presumably disables dropout at test time.
feed_dict={model.X: batch_x,
model.Y: batch_y,
model.dropout_keep_prob: 1.0,
model.cnn_dropout_rates: [1.0]*len(cnn_drop_rate)})
# Returns (acc, acc) when accuracy improved by at most 0.01 over pre_acc.
# Otherwise returns (0.1, acc).
# NOTE(review): if the first element is a reward, this looks inverted:
# small or no improvement earns the full accuracy, while a larger gain
# earns only 0.1. Confirm the intent.
if acc - pre_acc <= 0.01:
return acc, acc
else:
return 0.1, acc
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment