Last active
April 9, 2024 18:52
-
-
Save tencia/afb129122a64bde3bd0c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from PIL import Image | |
import sys | |
import os | |
import math | |
import numpy as np | |
########################################################################################### | |
# script to generate moving mnist video dataset (frame by frame) as described in | |
# [1] arXiv:1502.04681 - Unsupervised Learning of Video Representations Using LSTMs | |
# Srivastava et al | |
# by Tencia Lee | |
# saves in hdf5, npz, or jpg (individual frames) format | |
########################################################################################### | |
# helper functions | |
def arr_from_img(im, shift=0):
    """Convert an image to a float32 array laid out as (channels, width, height).

    Pixel values are divided by 255 and then ``shift`` is subtracted.  The
    axes are the transpose of the usual (height, width, channels) layout so the
    result matches the (channel, width, height) convention that
    ``get_picture_array`` reverses.

    Args:
        im: a PIL-style image exposing ``size`` as (width, height) and
            ``getdata()`` returning one value (or one tuple of band values)
            per pixel in row-major order.
        shift: constant subtracted after scaling.

    Returns:
        float32 ndarray of shape (channels, width, height).
    """
    w, h = im.size
    data = np.asarray(im.getdata(), dtype=np.float32)
    # Channel count = total elements / pixel count.  Integer division keeps it
    # an int under Python 3 so reshape accepts it, and it correctly yields 3
    # for multi-band (e.g. RGB) images whose getdata() returns tuples.
    c = data.size // (w * h)
    return data.reshape((h, w, c)).transpose(2, 1, 0) / 255. - shift
def get_picture_array(X, index, shift=0):
    """Turn one entry of a (N, channels, width, height) float array into pixels.

    The selected entry is offset by ``shift``, scaled by 255, reordered to
    (height, width, channels), saturated to [0, 255] and cast to uint8.  For a
    single-channel input the trailing channel axis is dropped, giving a 2-D
    (height, width) array suitable for ``Image.fromarray``.
    """
    n_channels = X.shape[1]
    width = X.shape[2]
    height = X.shape[3]
    scaled = (X[index] + shift) * 255.
    pixels = scaled.reshape(n_channels, width, height).transpose(2, 1, 0)
    # clip before the cast so out-of-range values saturate instead of wrapping
    pixels = pixels.clip(0, 255).astype(np.uint8)
    if n_channels != 1:
        return pixels
    return pixels.reshape(height, width)
# loads mnist from web on demand | |
def load_dataset(filename='train-images-idx3-ubyte.gz', sources=None):
    """Load the MNIST training images, downloading the archive if absent.

    Args:
        filename: local path of the gzipped IDX3 image file.  If it does not
            exist, its basename is fetched from the first URL prefix in
            ``sources`` that succeeds.
        sources: base URLs tried in order.  Defaults to the torchvision S3
            mirror followed by the original yann.lecun.com location.

    Returns:
        float32 ndarray of shape (num_images, 1, cols, rows) with values in
        [0, 1].  The last two axes are swapped relative to the file's
        row-major (rows, cols) layout to match the (channel, width, height)
        convention used by ``get_picture_array``.

    Raises:
        IOError: if the file is missing and no source could be downloaded.
        ValueError: if the file does not carry the IDX3 image magic number.
    """
    import gzip
    import struct
    if sys.version_info[0] == 2:
        from urllib import urlretrieve
    else:
        from urllib.request import urlretrieve
    if sources is None:
        sources = ['https://ossci-datasets.s3.amazonaws.com/mnist/',
                   'http://yann.lecun.com/exdb/mnist/']

    def download(path):
        basename = os.path.basename(path)
        # Download to a temporary name and rename only on success, so a failed
        # or interrupted transfer is never mistaken for a cached dataset.
        tmp = path + '.part'
        last_err = None
        for source in sources:
            print("Downloading %s" % (source + basename))
            try:
                urlretrieve(source + basename, tmp)
            except (IOError, OSError) as err:
                last_err = err
                if os.path.exists(tmp):
                    os.remove(tmp)
                continue
            os.rename(tmp, path)
            return
        raise IOError('could not download %s: %s' % (basename, last_err))

    if not os.path.exists(filename):
        download(filename)
    with gzip.open(filename, 'rb') as f:
        raw = f.read()
    # IDX3 header: big-endian magic (2051), image count, rows, cols.
    magic, num, rows, cols = struct.unpack('>IIII', raw[:16])
    if magic != 2051:
        raise ValueError('%s is not an IDX3 image file (magic %d)' % (filename, magic))
    data = np.frombuffer(raw, np.uint8, offset=16)
    data = data.reshape(num, 1, rows, cols).transpose(0, 1, 3, 2)
    return data / np.float32(255)
# generates and returns video frames in uint8 array | |
def generate_moving_mnist(shape=(64,64), seq_len=30, seqs=10000, num_sz=28, nums_per_image=2):
    """Generate a Moving MNIST video dataset frame by frame.

    Each of ``seqs`` sequences shows ``nums_per_image`` randomly chosen MNIST
    digits, resized to ``num_sz`` x ``num_sz`` pixels.  Every digit moves in a
    random direction in [-pi, pi) at a random integer speed in [2, 6] pixels
    per frame and bounces off the frame edges (with 2 px of slack).
    Overlapping digits are summed and saturated at 255.

    Args:
        shape: (width, height) of each frame in pixels.
        seq_len: number of frames per sequence.
        seqs: number of sequences.
        num_sz: side length in pixels of each pasted digit.
        nums_per_image: number of digits per frame.

    Returns:
        uint8 ndarray of shape (seq_len * seqs, 1, width, height).  Frames of
        sequence k occupy indices [k * seq_len, (k + 1) * seq_len).
    """
    mnist = load_dataset()
    width, height = shape
    lims = (x_lim, y_lim) = width - num_sz, height - num_sz
    dataset = np.empty((seq_len * seqs, 1, width, height), dtype=np.uint8)
    for seq_idx in range(seqs):
        # randomly generate direction/speed/position, calculate velocity vector
        direcs = np.pi * (np.random.rand(nums_per_image) * 2 - 1)
        speeds = np.random.randint(5, size=nums_per_image) + 2
        # lists (not tuples / map objects) so velocities can be flipped in place
        veloc = [[v * math.cos(d), v * math.sin(d)] for d, v in zip(direcs, speeds)]
        mnist_images = [
            Image.fromarray(get_picture_array(mnist, r, shift=0)).resize(
                (num_sz, num_sz), Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10
            for r in np.random.randint(0, mnist.shape[0], nums_per_image)
        ]
        positions = [[np.random.rand() * x_lim, np.random.rand() * y_lim]
                     for _ in range(nums_per_image)]
        for frame_idx in range(seq_len):
            canvas = np.zeros((1, width, height), dtype=np.float32)
            for img, pos in zip(mnist_images, positions):
                canv = Image.new('L', (width, height))
                canv.paste(img, (int(round(pos[0])), int(round(pos[1]))))
                canvas += arr_from_img(canv, shift=0)
            # bounce off a wall if the next step would leave the allowed range
            for pos, vel in zip(positions, veloc):
                for j in range(2):
                    nxt = pos[j] + vel[j]
                    if nxt < -2 or nxt > lims[j] + 2:
                        vel[j] = -vel[j]
            # concrete lists: Python 3 map() objects would be exhausted after one use
            positions = [[p + v for p, v in zip(pos, vel)]
                         for pos, vel in zip(positions, veloc)]
            # clip BEFORE casting: overlapping digits can exceed 255 and would
            # otherwise wrap around when converted to uint8
            dataset[seq_idx * seq_len + frame_idx] = (canvas * 255).clip(0, 255).astype(np.uint8)
    return dataset
def main(dest, filetype='npz', frame_size=64, seq_len=30, seqs=100, num_sz=28, nums_per_image=2):
    """Generate a Moving MNIST dataset and write it to ``dest``.

    Args:
        dest: output location.  For 'npz' and 'hdf5' this is a file path; for
            'jpg' it is a directory (created if missing) that receives one
            ``<frame index>.jpg`` per frame.
        filetype: 'npz' (``np.savez``, array stored under the default key),
            'hdf5' (fuel-style H5PYDataset with a 90/10 train/test split of
            frames), or 'jpg'.
        frame_size: width and height of each square frame in pixels.
        seq_len, seqs, num_sz, nums_per_image: forwarded to
            ``generate_moving_mnist``.

    Raises:
        ValueError: if ``filetype`` is not one of the supported formats.
    """
    # Validate up front so a typo does not waste a full generation run.
    if filetype not in ('npz', 'hdf5', 'jpg'):
        raise ValueError("filetype must be 'npz', 'hdf5' or 'jpg', got %r" % (filetype,))
    dat = generate_moving_mnist(shape=(frame_size, frame_size), seq_len=seq_len, seqs=seqs,
                                num_sz=num_sz, nums_per_image=nums_per_image)
    n = seqs * seq_len
    if filetype == 'hdf5':
        import h5py
        from fuel.datasets.hdf5 import H5PYDataset
        # integer split point: Python 3 true division would yield float indices
        split = n * 9 // 10
        split_dict = {'train': {'images': (0, split)},
                      'test': {'images': (split, n)}}
        with h5py.File(dest, mode='w') as f:
            images = f.create_dataset('images', dat.shape, dtype='uint8')
            images[...] = dat
            f.attrs['split'] = H5PYDataset.create_split_array(split_dict)
    elif filetype == 'npz':
        np.savez(dest, dat)
    else:  # 'jpg'
        if not os.path.isdir(dest):
            os.makedirs(dest)
        for i in range(dat.shape[0]):
            # dat is already uint8 in [0, 255] with layout (1, width, height).
            # get_picture_array expects floats in [0, 1] and multiplies by 255,
            # which would saturate every non-zero pixel, so just reorder to
            # the (height, width) layout Image.fromarray expects.
            frame = np.ascontiguousarray(dat[i, 0].T)
            Image.fromarray(frame).save(os.path.join(dest, '{}.jpg'.format(i)))
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Command line options')
    # main() has no default for dest, so require it here to get a usage error
    # rather than a TypeError when it is omitted.
    parser.add_argument('--dest', type=str, dest='dest', required=True,
                        help='output file (npz/hdf5) or directory (jpg)')
    parser.add_argument('--filetype', type=str, dest='filetype',
                        choices=('npz', 'hdf5', 'jpg'))
    parser.add_argument('--frame_size', type=int, dest='frame_size')
    parser.add_argument('--seq_len', type=int, dest='seq_len',
                        help='length of each sequence')
    parser.add_argument('--seqs', type=int, dest='seqs',
                        help='number of sequences to generate')
    parser.add_argument('--num_sz', type=int, dest='num_sz',
                        help='size of mnist digit within frame')
    parser.add_argument('--nums_per_image', type=int, dest='nums_per_image',
                        help='number of digits in each frame')
    args = parser.parse_args(sys.argv[1:])
    # forward only the options the user set, so main()'s defaults apply otherwise
    main(**{k: v for (k, v) in vars(args).items() if v is not None})
Yes, thank you! Fixed.
This works great, thanks!
I'm using it with an implementation of video pixel networks, and will be happy to post code and results if everything works out.
Thanks for the gist, it helps a lot.
On line 77, clipping should happen before the cast to uint8; casting first lets values above 255 overflow and wrap around.
?
Thank you! Exactly what I needed.
Commented the code and made it Python 3 compatible. Also added training data set argument, here.
Thanks for the gist, was much needed!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
At line 61, it should be xrange(seq_len), not hard-coded xrange(30), right?