-
-
Save rookiepig/1435cd881ceba782831d61f7d2e3147b to your computer and use it in GitHub Desktop.
Example of benchmarking session.run call
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example of profiling session.run overhead
#
# For Python-level profiling:
#   python -m cProfile -o session-run-benchmark-feed.prof session-run-benchmark.py feed_dict
#   python -m cProfile -o session-run-benchmark-variable.prof session-run-benchmark.py variable
#   pip install snakeviz
#   snakeviz session-run-benchmark-feed.prof
#   snakeviz session-run-benchmark-variable.prof
#
# Feed_dict: 147 usec; no feed_dict: 71 usec
import os
import sys
import time

import numpy as np
import tensorflow as tf

USAGE = "usage: session-run-benchmark.py {feed_dict|variable} [timelines]"


def parse_mode(argv):
    """Parse the command line into ``(use_feed_dict, timelines)``.

    Args:
        argv: full argument vector, program name at index 0.

    Returns:
        A pair of booleans: whether to feed the input through ``feed_dict``
        (mode ``feed_dict``) or read it from a variable (mode ``variable``),
        and whether the optional ``timelines`` argument was given.

    Exits with status 2 after writing a usage line to stderr when the mode is
    missing or unknown, or when the second argument is not ``timelines``.
    """
    mode = argv[1] if len(argv) > 1 else None
    extra = argv[2] if len(argv) > 2 else None
    if mode not in ('feed_dict', 'variable') or extra not in (None, 'timelines'):
        # Write the message ourselves instead of passing it to sys.exit():
        # a wrapper that catches SystemExit would otherwise hide it.
        sys.stderr.write(USAGE + "\n")
        sys.exit(2)
    return mode == 'feed_dict', extra == 'timelines'


# Validate arguments before any TensorFlow state is created so that a bad
# invocation fails immediately.
use_feed_dict, timelines = parse_mode(sys.argv)

# make sure our ops aren't getting optimized away
config = tf.ConfigProto(
    graph_options=tf.GraphOptions(
        optimizer_options=tf.OptimizerOptions(
            opt_level=tf.OptimizerOptions.L0)))
sess = tf.Session(config=config)

n = 1024
# np.random.random yields float64; cast to float32 so the array matches the
# placeholder's declared dtype and both benchmarked paths square the same dtype.
x0 = np.random.random([1, n]).astype(np.float32)
x = tf.placeholder(tf.float32, shape=x0.shape)
x_cached = tf.Variable(x0)
simple_op_feed_dict = tf.square(x)
simple_op = tf.square(x_cached)
sess.run(tf.global_variables_initializer())

num_iters = 100000

# Select the op to run (and its feeds, if any) once, outside the timed loops.
# Fetching the `.op` runs the computation without returning its result.
if use_feed_dict:
    fetch, feeds = simple_op_feed_dict.op, {x: x0}
else:
    fetch, feeds = simple_op.op, None

if timelines:
    # Full tracing mode: collect step stats for every run and print the
    # summary. Uses a tenth of the iterations of the plain benchmark.
    run_metadata = tf.RunMetadata()
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    ss = tf.contrib.stat_summarizer.NewStatSummarizer(
        tf.get_default_graph().as_graph_def().SerializeToString())
    for _ in range(num_iters // 10):
        sess.run(fetch, feed_dict=feeds,
                 options=run_options,
                 run_metadata=run_metadata)
        ss.ProcessStepStatsStr(run_metadata.step_stats.SerializeToString())
    print(ss.GetOutputString())
else:
    # Plain benchmark loop; intended to be run under an external profiler.
    for _ in range(num_iters):
        sess.run(fetch, feed_dict=feeds)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment