Skip to content

Instantly share code, notes, and snippets.

@mrajchl
Created June 8, 2018 15:56
Show Gist options
  • Save mrajchl/02018f165232e8e5b0882009e818e504 to your computer and use it in GitHub Desktop.
Save mrajchl/02018f165232e8e5b0882009e818e504 to your computer and use it in GitHub Desktop.
Feed .nii imaging data via a native Python generator
# Generator function
def f():
    """Yield training examples from ``read_fn`` one at a time.

    Meant to be passed as the generator callable to
    ``tf.data.Dataset.from_generator``. A reader is created over
    ``all_filenames`` in ``tf.estimator.ModeKeys.TRAIN`` mode with
    ``reader_params``, and every example that reader produces is yielded
    in order.

    Yields:
        Each example produced by ``read_fn``, unchanged.
        NOTE(review): the consuming code indexes the dataset output with
        'features' and 'labels', so examples are presumably dicts with
        those keys -- confirm against ``read_fn``.
    """
    fn = read_fn(file_references=all_filenames,
                 mode=tf.estimator.ModeKeys.TRAIN,
                 params=reader_params)

    # Drain the reader instead of calling next(fn) once. A single next()
    # ends this generator after one example, so a repeating dataset would
    # rebuild the reader for every example and only ever consume its first
    # item. A bare next() on an empty reader would also raise StopIteration
    # inside a generator, which surfaces as RuntimeError (PEP 479).
    for ex in fn:
        # Yield the next image
        yield ex
# Timed example with generator io
# Build the input pipeline in a single chain: wrap the Python generator `f`,
# repeat it indefinitely, group examples into batches and prefetch one batch.
dataset = (
    tf.data.Dataset
    .from_generator(f, reader_example_dtypes, reader_example_shapes)
    .repeat(None)
    .batch(batch_size)
    .prefetch(1)
)

iterator = dataset.make_initializable_iterator()
next_dict = iterator.get_next()

# The same two tensors are fetched on every step, so name them once up front.
batch_fetches = [next_dict['features'], next_dict['labels']]

with tf.train.MonitoredTrainingSession() as sess_gen:
    # Initialize generator
    sess_gen.run(iterator.initializer)

    # NOTE(review): Timer is defined elsewhere; presumably it reports the
    # elapsed time of the enclosed loop under the given label -- confirm.
    with Timer('Generator'):
        for i in range(iterations):
            # Fetch the next batch of images
            gen_batch_feat, gen_batch_lbl = sess_gen.run(batch_fetches)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment