@akesling
Last active August 15, 2024 03:08
MNist loading helper for Python 2.7. For Python 3.x, see https://gist.github.com/akesling/42393ccb868125071fdea77d98a0d2f0
```python
"""
MNist loading helper for Python 2.7.
For Python 3.x, see https://gist.github.com/akesling/42393ccb868125071fdea77d98a0d2f0
Loosely inspired by http://abel.ee.ucla.edu/cvxopt/_downloads/mnist.py
which is GPL licensed.
"""
import os
import struct
import numpy as np

def read(dataset = "training", path = "."):
    """
    Python function for importing the MNIST data set. It returns an iterator
    of 2-tuples with the first element being the label and the second element
    being a numpy.uint8 2D array of pixel data for the given image.
    """

    # Compare strings with ==, not `is`: equal strings are not guaranteed
    # to be the same object, so `is` can fail for a valid argument.
    if dataset == "training":
        fname_img = os.path.join(path, 'train-images-idx3-ubyte')
        fname_lbl = os.path.join(path, 'train-labels-idx1-ubyte')
    elif dataset == "testing":
        fname_img = os.path.join(path, 't10k-images-idx3-ubyte')
        fname_lbl = os.path.join(path, 't10k-labels-idx1-ubyte')
    else:
        raise ValueError, "dataset must be 'testing' or 'training'"

    # Load everything in some numpy arrays
    # Label file: 8-byte big-endian header (magic, count), then one byte per label
    with open(fname_lbl, 'rb') as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        lbl = np.fromfile(flbl, dtype=np.int8)

    # Image file: 16-byte header (magic, count, rows, cols), then rows*cols bytes per image
    with open(fname_img, 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        img = np.fromfile(fimg, dtype=np.uint8).reshape(len(lbl), rows, cols)

    get_img = lambda idx: (lbl[idx], img[idx])

    # Create an iterator which returns each image in turn
    for i in xrange(len(lbl)):
        yield get_img(i)

def show(image):
    """
    Render a given numpy.uint8 2D array of pixel data.
    """
    from matplotlib import pyplot
    import matplotlib as mpl
    fig = pyplot.figure()
    ax = fig.add_subplot(1,1,1)
    imgplot = ax.imshow(image, cmap=mpl.cm.Greys)
    imgplot.set_interpolation('nearest')
    ax.xaxis.set_ticks_position('top')
    ax.yaxis.set_ticks_position('left')
    pyplot.show()
```
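read() unpacks each file's header but throws away `magic`. In the IDX format used by MNIST, the magic number is four bytes: two zero bytes, a type code (0x08 means unsigned byte), and the number of dimensions, followed by one big-endian 32-bit size per dimension. That makes the standard values 2049 (0x00000801, labels) and 2051 (0x00000803, images). The following is a minimal sketch of validating the header instead of ignoring it. `read_idx_header` is a hypothetical helper, not part of the gist, and it is written for Python 3.

```python
import io
import struct

# Standard MNIST magic numbers: 0x00000801 (labels, 1-D) and 0x00000803 (images, 3-D).
LABEL_MAGIC = 2049
IMAGE_MAGIC = 2051

def read_idx_header(f):
    """Read an IDX header from a binary file object; return (type_code, dims).

    Hypothetical helper. The magic number is split into its four bytes:
    two zero bytes, a type code (0x08 = unsigned byte) and the number of
    dimensions. Each dimension size follows as a big-endian uint32.
    """
    zero1, zero2, type_code, ndim = struct.unpack(">BBBB", f.read(4))
    if zero1 != 0 or zero2 != 0:
        raise ValueError("not an IDX file: first two bytes must be zero")
    dims = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))
    return type_code, dims

# Usage on a tiny in-memory label file: magic 2049, 3 labels, then the label bytes.
buf = io.BytesIO(struct.pack(">II", LABEL_MAGIC, 3) + b"\x07\x02\x01")
code, dims = read_idx_header(buf)
```

After the header, the file position sits at the first data byte, so the rest can still be loaded with `np.fromfile` as in read(). Comparing the image count in `dims` against the label count would also catch a mismatched pair of files.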
@Shunxintime

@Jae1015, the files on the original website are named like 'train-labels.idx1-ubyte'. Please pay attention to the dot versus the hyphen.
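Because read() hard-codes the hyphenated names (e.g. 'train-labels-idx1-ubyte'), a copy that uses the dotted spelling fails with a missing-file error. A small sketch that tries both spellings follows; `find_mnist_file` is a hypothetical helper, not something the gist provides, and it only checks these two naming variants.

```python
import os

def find_mnist_file(path, stem, suffix):
    """Return the path of an MNIST file under either naming convention.

    Hypothetical helper: for stem='train-labels', suffix='idx1-ubyte' it tries
    'train-labels-idx1-ubyte' first, then 'train-labels.idx1-ubyte'.
    """
    candidates = [stem + "-" + suffix, stem + "." + suffix]
    for name in candidates:
        full = os.path.join(path, name)
        if os.path.exists(full):
            return full
    raise IOError("none of %s found in %r" % (candidates, path))
```

Inside read(), a line such as `fname_lbl = find_mnist_file(path, 'train-labels', 'idx1-ubyte')` could then replace the hard-coded `os.path.join` call.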

@yarcowang

Just a note: to run this under Python 3, you need to change these lines:

raise ValueError("dataset must be 'testing' or 'training'") # lineno:24

for i in range(len(lbl)): # lineno:38
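Applying both changes, and also replacing the `is` string comparisons with `==`, gives roughly the following Python 3 version of read(). This is a sketch built from the original function, not the author's separate Python 3 gist, which may differ; using `np.uint8` for the labels (the IDX type code 0x08 is unsigned byte) is an assumption that does not change the values 0 to 9.

```python
import os
import struct
import numpy as np

def read(dataset="training", path="."):
    """Yield (label, image) pairs from the MNIST IDX files (Python 3 sketch)."""
    if dataset == "training":
        fname_img = os.path.join(path, "train-images-idx3-ubyte")
        fname_lbl = os.path.join(path, "train-labels-idx1-ubyte")
    elif dataset == "testing":
        fname_img = os.path.join(path, "t10k-images-idx3-ubyte")
        fname_lbl = os.path.join(path, "t10k-labels-idx1-ubyte")
    else:
        # Python 3 only accepts the call form of raise
        raise ValueError("dataset must be 'testing' or 'training'")

    # Label file: 8-byte header (magic, count), then one unsigned byte per label
    with open(fname_lbl, "rb") as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        lbl = np.fromfile(flbl, dtype=np.uint8)

    # Image file: 16-byte header (magic, count, rows, cols), then the pixels
    with open(fname_img, "rb") as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        img = np.fromfile(fimg, dtype=np.uint8).reshape(len(lbl), rows, cols)

    for i in range(len(lbl)):  # range replaces Python 2's xrange
        yield lbl[i], img[i]
```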

@shm007g

shm007g commented Nov 22, 2018

This script works well for me, I think, for plotting the right digits.

However, something is weird: different loading tools (https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/src/mnist_loader.py) seem to give me different results when testing the network here (https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/src/network.py).
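One plausible source of such differences is data representation. read() yields 28x28 `uint8` arrays with pixel values 0 to 255 and a scalar label. Nielsen's `mnist_loader.load_data_wrapper`, assuming the linked version, instead yields (784, 1) float column vectors scaled to [0, 1], with (10, 1) one-hot label vectors for training data and plain integer labels for test data; it also holds out 10,000 training images as a validation set. The following sketch converts one pair from read() into that assumed layout. `to_nielsen_format` is a hypothetical name, and matching the layout does not guarantee identical results.

```python
import numpy as np

def to_nielsen_format(label, image, one_hot=True):
    """Convert one (label, image) pair from read() into an (x, y) pair.

    Assumed target layout: x is a (784, 1) float column scaled to [0, 1];
    y is a (10, 1) one-hot column (training) or a plain int (testing).
    `image` must be 28x28 so that it reshapes to 784 pixels.
    """
    x = image.reshape(784, 1).astype(np.float32) / 255.0
    if not one_hot:
        return x, int(label)
    y = np.zeros((10, 1))
    y[int(label)] = 1.0  # set the row for this digit to 1
    return x, y
```

Usage: `training_data = [to_nielsen_format(l, im) for l, im in read("training")]`, then `test_data = [to_nielsen_format(l, im, one_hot=False) for l, im in read("testing")]`.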

@YidongEric

I got this error. Can anyone help me? Thanks very much.

raise ValueError, "dataset must be 'testing' or 'training'"

@akesling
Author

For all those who find this and want something working on Python 3.x, I've created an updated gist: https://gist.github.com/akesling/42393ccb868125071fdea77d98a0d2f0
