Last active September 11, 2024 08:28
  • Save swyoon/8185b3dcf08ec728fb22b99016dd533f to your computer and use it in GitHub Desktop.
From numpy ndarray to tfrecords
```python
import numpy as np
import tensorflow as tf

__author__ = "Sangwoong Yoon"


def np_to_tfrecords(X, Y, file_path_prefix, verbose=True):
    """
    Converts a Numpy array (or two Numpy arrays) into a tfrecord file.
    For supervised learning, feed training inputs to X and training labels to Y.
    For unsupervised learning, only feed training inputs to X, and feed None to Y.
    The length of the first dimensions of X and Y should be the number of samples.

    Parameters
    ----------
    X : numpy.ndarray of rank 2
        Numpy array for training inputs. Its dtype should be float32, float64, or int64.
        If X has a higher rank, it should be reshaped before being fed to this function.
    Y : numpy.ndarray of rank 2 or None
        Numpy array for training labels. Its dtype should be float32, float64, or int64.
        None if there is no label array.
    file_path_prefix : str
        The path and name of the resulting tfrecord file to be generated, without '.tfrecords'
    verbose : bool
        If true, progress is reported.

    Raises
    ------
    ValueError
        If the input dtype is not float32, float64, or int64.
    """

    def _dtype_feature(ndarray):
        """Match the appropriate tf.train.Feature class with the dtype of ndarray."""
        assert isinstance(ndarray, np.ndarray)
        dtype_ = ndarray.dtype
        if dtype_ == np.float64 or dtype_ == np.float32:
            return lambda array: tf.train.Feature(float_list=tf.train.FloatList(value=array))
        elif dtype_ == np.int64:
            return lambda array: tf.train.Feature(int64_list=tf.train.Int64List(value=array))
        else:
            raise ValueError("The input should be a float32, float64, or int64 numpy ndarray. "
                             "Instead got {}".format(ndarray.dtype))

    assert isinstance(X, np.ndarray)
    assert len(X.shape) == 2  # If X has a higher rank,
                              # it should be reshaped before being fed to this function.
    assert isinstance(Y, np.ndarray) or Y is None

    # load appropriate tf.train.Feature class depending on dtype
    dtype_feature_x = _dtype_feature(X)
    if Y is not None:
        assert X.shape[0] == Y.shape[0]
        assert len(Y.shape) == 2
        dtype_feature_y = _dtype_feature(Y)

    # Generate tfrecord writer
    # (tf.python_io.TFRecordWriter in TF 1.x; tf.io.TFRecordWriter in TF 1.14+ and TF 2.x)
    result_tf_file = file_path_prefix + '.tfrecords'
    writer = tf.io.TFRecordWriter(result_tf_file)
    if verbose:
        print("Serializing {:d} examples into {}".format(X.shape[0], result_tf_file))

    # iterate over each sample,
    # and serialize it as ProtoBuf.
    for idx in range(X.shape[0]):
        x = X[idx]
        if Y is not None:
            y = Y[idx]

        d_feature = {}
        d_feature['X'] = dtype_feature_x(x)
        if Y is not None:
            d_feature['Y'] = dtype_feature_y(y)

        features = tf.train.Features(feature=d_feature)
        example = tf.train.Example(features=features)
        serialized = example.SerializeToString()
        writer.write(serialized)  # append this sample as one record

    writer.close()
    if verbose:
        print("Writing {} done!".format(result_tf_file))
```
```python
## Test and Use Cases ##

# 1-1. Saving a dataset with input and label (supervised learning)
xx = np.random.randn(10, 5)
yy = np.random.randn(10, 1)
np_to_tfrecords(xx, yy, 'test1', verbose=True)

# 1-2. Check if the data is stored correctly
# open the saved file and check the first entries
# (TF 2 no longer exposes tf.python_io.tf_record_iterator; tf.data.TFRecordDataset reads the records)
for serialized_example in tf.data.TFRecordDataset('test1.tfrecords'):
    example = tf.train.Example()
    example.ParseFromString(serialized_example.numpy())
    x_1 = np.array(example.features.feature['X'].float_list.value)
    y_1 = np.array(example.features.feature['Y'].float_list.value)
    break  # only the first record is needed for this check

# The numbers may differ slightly: FloatList stores 32-bit floats,
# so the float64 inputs are rounded when written.
print(xx[0])
print(x_1)
print(yy[0])
print(y_1)

# 2. Saving a dataset with only inputs (unsupervised learning)
xx = np.random.randn(100, 100)
np_to_tfrecords(xx, None, 'test2', verbose=True)
```
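The floating-point differences mentioned in test 1-2 come from the storage format: `tf.train.FloatList` holds 32-bit floats, so float64 inputs such as `np.random.randn` output are rounded when they are stored. A small check, assuming a random row like those in `xx`, that the stored values match the input rounded to float32:

```python
import numpy as np
import tensorflow as tf

# A float64 row, like one row of xx above
row = np.random.randn(5)

# Build a feature the same way np_to_tfrecords does for float inputs
feature = tf.train.Feature(float_list=tf.train.FloatList(value=row))

# Round-trip through the wire format to mimic writing and then reading a record
restored = tf.train.Feature()
restored.ParseFromString(feature.SerializeToString())
stored = np.array(restored.float_list.value)

# The float64 values are generally not reproduced bit for bit ...
print(np.max(np.abs(stored - row)))
# ... but they equal the input rounded to float32
print(np.array_equal(stored.astype(np.float32), row.astype(np.float32)))  # True
```

If exact float64 values matter, one option is to store them some other way (for example as raw bytes), since `FloatList` cannot hold them.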
nairouz commented May 23, 2018

Thank you.

filmo commented Jun 13, 2018

You might want to extend _dtype_feature to recognize uint8, which is a common image datatype, and then pass it as a bytes_list.

Storing images as uint8 will be 4x smaller than storing them as float32.
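Following the uint8 suggestion above, one way to extend the dtype dispatch is to store the raw bytes of each row as a single `BytesList` entry and decode them with `np.frombuffer`. This is a sketch; `_uint8_feature` and `_decode_uint8` are hypothetical helper names, and the reader has to know the dtype (and, for images, the shape) in advance because the bytes carry neither.

```python
import numpy as np
import tensorflow as tf


def _uint8_feature(array):
    """Hypothetical helper: pack one uint8 row as a single bytes value."""
    assert array.dtype == np.uint8
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[array.tobytes()]))


def _decode_uint8(feature):
    """Hypothetical helper: recover the flat uint8 array from the feature."""
    return np.frombuffer(feature.bytes_list.value[0], dtype=np.uint8)


# Example: a fake 8x8 grayscale image, flattened to one row
image = np.random.randint(0, 256, size=(8, 8), dtype=np.uint8)
feature = _uint8_feature(image.ravel())

# 64 pixels take 64 bytes here, versus 4 bytes each as float32
print(len(feature.bytes_list.value[0]))  # 64
restored = _decode_uint8(feature).reshape(8, 8)
print(np.array_equal(restored, image))   # True
```

Inside a `tf.data` pipeline, `tf.io.decode_raw(..., tf.uint8)` can play the role of `np.frombuffer`.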

chisnova commented Jul 5, 2018

Very good code for a TFRecord tutorial, thank you.

TyJK commented Aug 26, 2018

Thank you. I've been tearing my hair out trying to confirm that my data/labels were written to the record properly, and this made it so much easier.

Very useful code. Thank you!

ypxie commented Dec 23, 2019

Very useful code. I'm not sure why returning an anonymous function from _dtype_feature works, though?
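On why returning a lambda works: in Python, functions are ordinary objects, so `_dtype_feature` can pick the right converter once, from the array's dtype, and hand it back; the loop then calls that same converter on every row without re-checking the dtype. A TensorFlow-free sketch of the same pattern, where `pick_converter` is a hypothetical stand-in for `_dtype_feature`:

```python
import numpy as np


def pick_converter(ndarray):
    """Choose a row converter based on dtype, like _dtype_feature does."""
    if ndarray.dtype in (np.float32, np.float64):
        return lambda row: [float(v) for v in row]
    elif ndarray.dtype == np.int64:
        return lambda row: [int(v) for v in row]
    raise ValueError("Unsupported dtype {}".format(ndarray.dtype))


X = np.array([[1, 2], [3, 4]], dtype=np.int64)
convert_x = pick_converter(X)  # decided once, from the whole array's dtype

rows = [convert_x(X[i]) for i in range(X.shape[0])]
print(rows)  # [[1, 2], [3, 4]]
```

The returned lambda does not need anything from the enclosing call; it is just a function value chosen by the `if`/`elif` branch.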

tchaye59 commented May 4, 2020

Thank you very much. But your code is very slow for huge datasets; I think this is because you write the records one by one. I wonder whether it is possible to write the whole dataset at once using tf.train.SequenceExample?
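Writing one `tf.train.Example` per row does carry Python overhead for every sample. If the goal is only to persist one large array, one alternative (not what `np_to_tfrecords` does, and it gives up reading or shuffling individual samples from disk) is to write the entire array as a single record with `tf.io.serialize_tensor` and read it back with `tf.io.parse_tensor`. A minimal sketch, assuming TF 2 eager execution and an array that fits in memory and stays under protobuf's 2 GB message limit; `whole_array.tfrecords` is a hypothetical file name:

```python
import numpy as np
import tensorflow as tf

array = np.random.randn(1000, 100).astype(np.float32)
path = "whole_array.tfrecords"  # hypothetical output file

# Write: one record containing the whole serialized tensor
with tf.io.TFRecordWriter(path) as writer:
    serialized = tf.io.serialize_tensor(tf.convert_to_tensor(array))
    writer.write(serialized.numpy())

# Read: take the single record and parse it back into a tensor
record = next(iter(tf.data.TFRecordDataset(path)))
restored = tf.io.parse_tensor(record, out_type=tf.float32).numpy()

print(restored.shape)                   # (1000, 100)
print(np.array_equal(restored, array))  # True
```

Because the tensor is stored in its own dtype, float32 data round-trips exactly here.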

If someone is looking for a faster solution, check my Kaggle kernel.

Well, that notebook (version 17) has an error.
But I'm glad that someone is working on a faster version :-)

Give it 10 minutes; I just sent a new commit.

Thanks for this.

Please, after saving test1, how can I recover the whole array, i.e., xx and not just x_1? Thanks in advance.
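To recover all of `xx` and `yy` rather than just the first row, one option is to read every record instead of stopping after the first, parse each one, and stack the rows. A sketch for TF 2; `read_all` is a hypothetical helper, the block first writes a small `test1.tfrecords` in the same layout that test 1-1 produces so it runs on its own, and the recovered arrays are float32 because that is what `FloatList` stores:

```python
import numpy as np
import tensorflow as tf


def read_all(path):
    """Hypothetical helper: parse every record in `path` and stack the rows."""
    xs, ys = [], []
    for raw in tf.data.TFRecordDataset(path).as_numpy_iterator():
        example = tf.train.Example()
        example.ParseFromString(raw)
        feature = example.features.feature
        xs.append(np.array(feature['X'].float_list.value, dtype=np.float32))
        if 'Y' in feature:
            ys.append(np.array(feature['Y'].float_list.value, dtype=np.float32))
    X = np.stack(xs)
    Y = np.stack(ys) if ys else None
    return X, Y


# Build a small file in the same layout np_to_tfrecords produces
xx = np.random.randn(10, 5)
yy = np.random.randn(10, 1)
with tf.io.TFRecordWriter("test1.tfrecords") as writer:
    for x, y in zip(xx, yy):
        example = tf.train.Example(features=tf.train.Features(feature={
            'X': tf.train.Feature(float_list=tf.train.FloatList(value=x)),
            'Y': tf.train.Feature(float_list=tf.train.FloatList(value=y)),
        }))
        writer.write(example.SerializeToString())

X_restored, Y_restored = read_all("test1.tfrecords")
print(X_restored.shape, Y_restored.shape)      # (10, 5) (10, 1)
print(np.allclose(X_restored, xx, atol=1e-6))  # True, up to float32 rounding
```

For an unsupervised file such as `test2.tfrecords`, the `'Y' in feature` check leaves `Y` as `None`.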

What about just using
tensor = tf.convert_to_tensor(array) result =

patmorli commented Feb 4, 2021

What would 1-2 look like with the TF 2 updates? tf_record_iterator does not work anymore, and I can't figure out how to read the file back.
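In TF 2, `tf.data.TFRecordDataset` takes over the role of `tf_record_iterator`. Besides decoding records one at a time with `tf.train.Example.ParseFromString`, the records can be parsed inside the input pipeline with `tf.io.parse_single_example`, which needs each feature's per-sample shape up front. A sketch assuming the 5-wide `X` and 1-wide `Y` layout of test 1-1; `demo.tfrecords` and `parse_fn` are hypothetical names, and the block writes its own small file so it runs standalone:

```python
import numpy as np
import tensorflow as tf

path = "demo.tfrecords"  # hypothetical file in the layout of test 1-1
xx = np.random.randn(10, 5)
yy = np.random.randn(10, 1)
with tf.io.TFRecordWriter(path) as writer:
    for x, y in zip(xx, yy):
        example = tf.train.Example(features=tf.train.Features(feature={
            'X': tf.train.Feature(float_list=tf.train.FloatList(value=x)),
            'Y': tf.train.Feature(float_list=tf.train.FloatList(value=y)),
        }))
        writer.write(example.SerializeToString())

# FixedLenFeature needs the per-sample shape; FloatList values come back as float32
feature_description = {
    'X': tf.io.FixedLenFeature([5], tf.float32),
    'Y': tf.io.FixedLenFeature([1], tf.float32),
}


def parse_fn(serialized):
    parsed = tf.io.parse_single_example(serialized, feature_description)
    return parsed['X'], parsed['Y']


dataset = tf.data.TFRecordDataset(path).map(parse_fn).batch(4)
for x_batch, y_batch in dataset.take(1):
    print(x_batch.shape, y_batch.shape)  # (4, 5) (4, 1)
```

The resulting `dataset` yields `(inputs, labels)` batches, which is the form Keras `fit` accepts.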

Why is it important that len(X.shape) == 2?

Thanks for sharing! May I know what happens if I have an X with higher dimensions? Is it still possible to convert it to tfrecords? Thanks!

swyoon commented Jan 22, 2022

It's been so long since I wrote this code, and personally, I have moved to PyTorch.
I have no idea if this would work for tensors with a higher rank.
My personal guess is it should probably work, if you comment out the assertions (L45, L53).
Could you try it? @songssssss
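On higher-rank inputs: removing the rank assertions alone is likely to fail, because `tf.train.FloatList(value=...)` expects a flat sequence of numbers, while each `X[idx]` of a rank-3 or rank-4 array is itself multi-dimensional. A common workaround is to flatten each sample and store its original shape next to it so that a reader can restore it. A sketch, where the `X_shape` feature key and the `encode_sample`/`decode_sample` helpers are hypothetical names:

```python
import numpy as np
import tensorflow as tf


def encode_sample(sample):
    """Hypothetical helper: flatten an N-D float sample and record its shape."""
    return tf.train.Example(features=tf.train.Features(feature={
        'X': tf.train.Feature(float_list=tf.train.FloatList(value=sample.ravel())),
        'X_shape': tf.train.Feature(int64_list=tf.train.Int64List(value=sample.shape)),
    }))


def decode_sample(serialized):
    """Hypothetical helper: parse the bytes and restore the original shape."""
    example = tf.train.Example()
    example.ParseFromString(serialized)
    feature = example.features.feature
    flat = np.array(feature['X'].float_list.value, dtype=np.float32)
    shape = tuple(feature['X_shape'].int64_list.value)
    return flat.reshape(shape)


X = np.random.randn(3, 4, 4, 2).astype(np.float32)  # e.g. 3 small 2-channel images
records = [encode_sample(X[i]).SerializeToString() for i in range(X.shape[0])]
restored = np.stack([decode_sample(r) for r in records])
print(restored.shape)               # (3, 4, 4, 2)
print(np.array_equal(restored, X))  # True
```

If every sample has the same shape, a simpler option is `X.reshape(X.shape[0], -1)` before calling `np_to_tfrecords`, reshaping back after reading with the shape known in advance.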

Thanks for your reply! I tried to use instead, and it solved my memory issue.

MohammedHAlali commented Sep 11, 2024

Thanks for the code. I'm using TF version 2.15.0
I received this error:
module 'tensorflow' has no attribute 'python_io'
I fixed it by changing tf.python_io to However, I couldn't find a replacement for "tf_record_iterator", but I guess it is not needed.

I also received another error:

    return lambda array: tf.train.Feature(int64_list=tf.train.Int64List(value=array))
TypeError: Value must be iterable
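That `TypeError` is most likely raised because the list type received a single number instead of a sequence, which happens when each `X[idx]` or `Y[idx]` is a scalar, for example when labels are stored as a 1-D array of shape `(n,)`. (The function's rank assertions are meant to catch this before the writer loop.) Reshaping the labels into a column makes each row a length-1 array; `labels` below is a hypothetical example array:

```python
import numpy as np
import tensorflow as tf

labels = np.array([0, 1, 1, 0], dtype=np.int64)  # shape (4,): labels[i] is a scalar

# Passing labels[0] directly, e.g. tf.train.Int64List(value=labels[0]),
# is the kind of call that fails because a scalar is not iterable.
labels_2d = labels.reshape(-1, 1)  # shape (4, 1): labels_2d[i] is a length-1 array

feature = tf.train.Feature(int64_list=tf.train.Int64List(value=labels_2d[0]))
print(list(feature.int64_list.value))  # [0]
```

The same reshape applies to `X` if it was passed as a 1-D array.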
