import numpy as np
import theano
import theano.tensor as T

from keras.models import Sequential
from keras.layers.recurrent import Recurrent, GRU, LSTM
from keras import backend as K
# from seya.utils import rnn_states

floatX = theano.config.floatX
tol = 1e-4

def _wta(X):
    # winner-take-all: keep only the maximum entries along the last axis
    M = K.max(X, axis=-1, keepdims=True)
    R = K.switch(K.equal(X, M), X, 0.)
    return R

def _update_controller(self, inp, h_tm1, M):
    """We have to update the inner RNN inside the NTM; this is the
    function that does it. Pretty much copy+pasta from Keras.
    """
    x = K.concatenate([inp, M], axis=-1)
    if isinstance(self.controller, Sequential):
        h = self.controller(x)
        h = [h, ]
    else:
        # update state
        _, h = self.controller.step(x, h_tm1)
    return h

def _circulant(leng, n_shifts):
    """
    I confess, I'm actually proud of this hack. I hope you enjoy!
    This generates a tensor with `n_shifts` rotated versions of the
    identity matrix. When this tensor is multiplied by a vector, the
    result is `n_shifts` shifted versions of that vector. Since
    everything is done with inner products, everything is differentiable.

    Parameters
    ----------
    leng: int > 0, number of memory locations
    n_shifts: int > 0, number of allowed shifts (if 1, no shift)

    Returns
    -------
    shift operation, a tensor with dimensions (n_shifts, leng, leng)
    """
    eye = np.eye(leng)
    shifts = range(n_shifts // 2, -n_shifts // 2, -1)
    C = np.asarray([np.roll(eye, s, axis=1) for s in shifts])
    return theano.shared(C.astype(theano.config.floatX))
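
# A minimal sketch (not part of the original gist) of what the circulant
# tensor does: multiplying each (leng, leng) slice with a one-hot vector
# returns that vector rotated by one of the allowed offsets. Only numpy is
# assumed; the function name is illustrative.
def _circulant_demo(leng=5, n_shifts=3):
    eye = np.eye(leng)
    shifts = range(n_shifts // 2, -n_shifts // 2, -1)
    C = np.asarray([np.roll(eye, s, axis=1) for s in shifts])
    v = np.zeros(leng)
    v[leng // 2] = 1.  # a one-hot "attention" vector
    # each row of the result is `v` rotated by a different offset
    return C.dot(v)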

def _renorm(x):
    return x / (x.sum(axis=1, keepdims=True))


def _softmax(x):
    wt = x.flatten(ndim=2)
    w = T.nnet.softmax(wt)
    return w.reshape(x.shape)  # T.clip(s, 0, 1)


def _cosine_distance(M, k):
    # M: (nb_samples, n_slots, m_length), k: (nb_samples, m_length)
    dot = (M * k[:, None, :]).sum(axis=-1)
    nM = T.sqrt((M ** 2).sum(axis=-1))
    nk = T.sqrt((k ** 2).sum(axis=-1, keepdims=True))
    return dot / (nM * nk)
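
# A hedged numpy sketch (not in the original gist) of the content-based
# addressing that `_get_content_w` implements below with the Theano ops above:
# w_c = softmax(beta * cosine_similarity(M, k)), eqs. (5)-(6) in Graves et al.,
# "Neural Turing Machines". The name and single-sample shapes are assumptions.
def _content_addressing_demo(M, k, beta):
    # M: (n_slots, m_length), k: (m_length,), beta: scalar >= 0
    sim = M.dot(k) / (np.linalg.norm(M, axis=1) * np.linalg.norm(k) + 1e-8)
    e = np.exp(beta * sim - np.max(beta * sim))  # numerically stable softmax
    return e / e.sum()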

class NeuralTuringMachine(Recurrent):
    """Neural Turing Machines.

    Non-obvious parameters
    ----------------------
    shift_range: int, number of available shifts, e.g. if 3, the available
        shifts are (-1, 0, 1)
    n_slots: number of memory locations
    m_length: memory length at each location

    Known issues
    ------------
    Theano may complain when n_slots == 1.

    A usage sketch is given at the end of this file.
    """
    def __init__(self, output_dim, n_slots, m_length, shift_range=3,
                 controller_type='gru',
                 init='glorot_uniform', inner_init='orthogonal',
                 input_dim=None, input_length=None, **kwargs):
        self.output_dim = output_dim
        self.n_slots = n_slots
        self.m_length = m_length
        self.shift_range = shift_range
        self.init = init
        self.inner_init = inner_init
        self.controller_type = controller_type
        self.input_dim = input_dim
        self.input_length = input_length
        if self.input_dim:
            kwargs['input_shape'] = (self.input_length, self.input_dim)
        super(NeuralTuringMachine, self).__init__(**kwargs)

    def build(self):
        input_leng, input_dim = self.input_shape[1:]
        self.input = T.tensor3()

        if self.controller_type == 'gru':
            self.controller = GRU(
                activation='relu',
                input_dim=input_dim + self.m_length,
                input_length=input_leng,
                output_dim=self.output_dim, init=self.init,
                inner_init=self.inner_init)
        elif self.controller_type == 'lstm':
            self.controller = LSTM(
                input_dim=input_dim + self.m_length,
                input_length=input_leng,
                output_dim=self.output_dim, init=self.init,
                forget_bias_init='zero',
                inner_init=self.inner_init)
        elif isinstance(self.controller_type, Sequential):
            self.controller = self.controller_type
            self.controller.init = self.controller.layers[0].init
        else:
            raise ValueError('this controller_type is not implemented.')
        self.controller.build()

        # initial memory, state, read and write vectors
        self.M = theano.shared(.001 * np.ones((1,)).astype(floatX))
        self.init_h = K.zeros((self.output_dim,))
        self.init_wr = self.controller.init((self.n_slots,))
        self.init_ww = self.controller.init((self.n_slots,))

        # write
        self.W_e = self.controller.init((self.output_dim, self.m_length))  # erase
        self.b_e = K.zeros((self.m_length,))
        self.W_a = self.controller.init((self.output_dim, self.m_length))  # add
        self.b_a = K.zeros((self.m_length,))

        # get_w parameters for the reading operation
        self.W_k_read = self.controller.init((self.output_dim, self.m_length))
        self.b_k_read = self.controller.init((self.m_length,))
        self.W_c_read = self.controller.init((self.output_dim, 3))  # 3 = beta, g, gamma, see eqs. 5, 7, 9
        self.b_c_read = K.zeros((3,))
        self.W_s_read = self.controller.init((self.output_dim, self.shift_range))
        self.b_s_read = K.zeros((self.shift_range,))  # b_s lol! not intentional

        # get_w parameters for the writing operation
        self.W_k_write = self.controller.init((self.output_dim, self.m_length))
        self.b_k_write = self.controller.init((self.m_length,))
        self.W_c_write = self.controller.init((self.output_dim, 3))  # 3 = beta, g, gamma, see eqs. 5, 7, 9
        self.b_c_write = K.zeros((3,))
        self.W_s_write = self.controller.init((self.output_dim, self.shift_range))
        self.b_s_write = K.zeros((self.shift_range,))

        self.C = _circulant(self.n_slots, self.shift_range)

        self.trainable_weights = self.controller.trainable_weights + [
            self.W_e, self.b_e,
            self.W_a, self.b_a,
            self.W_k_read, self.b_k_read,
            self.W_c_read, self.b_c_read,
            self.W_s_read, self.b_s_read,
            self.W_k_write, self.b_k_write,
            self.W_s_write, self.b_s_write,
            self.W_c_write, self.b_c_write,
            self.M,
            self.init_h, self.init_wr, self.init_ww]

        if self.controller_type == 'lstm':
            self.init_c = K.zeros((self.output_dim,))
            self.trainable_weights = self.trainable_weights + [self.init_c, ]

    def _read(self, w, M):
        # weighted sum over the memory rows, eq. (2) in Graves et al.
        return (w[:, :, None] * M).sum(axis=1)

    def _write(self, w, e, a, M):
        # erase then add, eqs. (3)-(4) in Graves et al.
        Mtilda = M * (1 - w[:, :, None] * e[:, None, :])
        Mout = Mtilda + w[:, :, None] * a[:, None, :]
        return Mout

    def _get_content_w(self, beta, k, M):
        # content-based addressing, eqs. (5)-(6)
        num = beta[:, None] * _cosine_distance(M, k)
        return _softmax(num)

    def _get_location_w(self, g, s, C, gamma, wc, w_tm1):
        # interpolate, shift and sharpen, eqs. (7)-(9)
        wg = g[:, None] * wc + (1 - g[:, None]) * w_tm1
        Cs = (C[None, :, :, :] * wg[:, None, None, :]).sum(axis=3)
        wtilda = (Cs * s[:, :, None]).sum(axis=1)
        wout = _renorm(wtilda ** gamma[:, None])
        return wout

    def _get_controller_output(self, h, W_k, b_k, W_c, b_c, W_s, b_s):
        k = T.tanh(T.dot(h, W_k) + b_k)  # + 1e-6
        c = T.dot(h, W_c) + b_c
        beta = T.nnet.relu(c[:, 0]) + 1e-4     # key strength, >= 0
        g = T.nnet.sigmoid(c[:, 1])            # interpolation gate, in (0, 1)
        gamma = T.nnet.relu(c[:, 2]) + 1.0001  # sharpening exponent, >= 1
        s = T.nnet.softmax(T.dot(h, W_s) + b_s)  # shift weighting
        return k, beta, g, gamma, s
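
    # Note (added for clarity, not in the original gist): per time step and
    # per head, `step` below chains the helpers above in the order of the NTM
    # paper: content addressing (_get_content_w), then interpolation with the
    # gate g, circular shift with s and sharpening with gamma
    # (_get_location_w), and finally _read or _write with the resulting
    # weighting.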

    def get_initial_states(self, X):
        batch_size = X.shape[0]
        init_M = self.M.dimshuffle(0, 'x', 'x').repeat(
            batch_size, axis=0).repeat(self.n_slots, axis=1).repeat(
            self.m_length, axis=2)
        init_M = init_M.flatten(ndim=2)

        init_h = self.init_h.dimshuffle(('x', 0)).repeat(batch_size, axis=0)
        init_wr = self.init_wr.dimshuffle(('x', 0)).repeat(batch_size, axis=0)
        init_ww = self.init_ww.dimshuffle(('x', 0)).repeat(batch_size, axis=0)
        if self.controller_type == 'lstm':
            init_c = self.init_c.dimshuffle(('x', 0)).repeat(batch_size, axis=0)
            return [init_M, T.nnet.softmax(init_wr), T.nnet.softmax(init_ww),
                    init_h, init_c]
        else:
            return [init_M, T.nnet.softmax(init_wr), T.nnet.softmax(init_ww),
                    init_h]

    @property
    def output_shape(self):
        input_shape = self.input_shape
        if self.return_sequences:
            return input_shape[0], input_shape[1], self.output_dim
        else:
            return input_shape[0], self.output_dim

    def get_full_output(self, train=False):
        """
        This method is for research and visualization purposes. Use it as:

            X = model.get_input()  # full model
            Y = ntm.get_full_output()  # this layer
            F = theano.function([X], Y, allow_input_downcast=True)
            [memory, read_address, write_address, rnn_state] = F(x)

        If controller_type == "lstm", use it as:

            [memory, read_address, write_address, rnn_cell, rnn_state] = F(x)
        """
        # input shape: (nb_samples, time (padded with zeros), input_dim)
        X = self.get_input(train)
        assert K.ndim(X) == 3
        if K._BACKEND == 'tensorflow':
            if not self.input_shape[1]:
                raise Exception('When using TensorFlow, you should define '
                                'explicitly the number of timesteps of '
                                'your sequences. Make sure the first layer '
                                'has a "batch_input_shape" argument '
                                'including the samples axis.')

        # mask = self.get_output_mask(train)
        # if mask:
        #     # apply mask
        #     X *= K.cast(K.expand_dims(mask), X.dtype)
        #     masking = True
        # else:
        #     masking = False

        if self.stateful:
            initial_states = self.states
        else:
            initial_states = self.get_initial_states(X)

        states = K.rnn(self.step, X, initial_states,
                       go_backwards=self.go_backwards)
        return states

    def step(self, x, states):
        M_tm1, wr_tm1, ww_tm1 = states[:3]
        # reshape the flattened memory back to (nb_samples, n_slots, m_length)
        M_tm1 = M_tm1.reshape((x.shape[0], self.n_slots, self.m_length))

        # read
        h_tm1 = states[3:]
        k_read, beta_read, g_read, gamma_read, s_read = self._get_controller_output(
            h_tm1[0], self.W_k_read, self.b_k_read, self.W_c_read, self.b_c_read,
            self.W_s_read, self.b_s_read)
        wc_read = self._get_content_w(beta_read, k_read, M_tm1)
        wr_t = self._get_location_w(g_read, s_read, self.C, gamma_read,
                                    wc_read, wr_tm1)
        M_read = self._read(wr_t, M_tm1)

        # update controller
        h_t = _update_controller(self, x, h_tm1, M_read)

        # write
        k_write, beta_write, g_write, gamma_write, s_write = self._get_controller_output(
            h_t[0], self.W_k_write, self.b_k_write, self.W_c_write,
            self.b_c_write, self.W_s_write, self.b_s_write)
        wc_write = self._get_content_w(beta_write, k_write, M_tm1)
        ww_t = self._get_location_w(g_write, s_write, self.C, gamma_write,
                                    wc_write, ww_tm1)
        e = T.nnet.sigmoid(T.dot(h_t[0], self.W_e) + self.b_e)
        a = T.tanh(T.dot(h_t[0], self.W_a) + self.b_a)
        M_t = self._write(ww_t, e, a, M_tm1)
        M_t = M_t.flatten(ndim=2)

        return h_t[0], [M_t, wr_t, ww_t] + h_t
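
# A minimal usage sketch (not part of the original gist), assuming the Theano
# backend and the Keras 0.x-era API that the layer above targets. The layer
# sizes, the optimizer and the dummy copy-task-style data below are
# illustrative assumptions, not values from the original.
if __name__ == '__main__':
    from keras.layers.core import TimeDistributedDense, Activation

    model = Sequential()
    model.add(NeuralTuringMachine(output_dim=64, n_slots=50, m_length=20,
                                  shift_range=3, controller_type='gru',
                                  return_sequences=True,
                                  input_dim=8, input_length=None))
    model.add(TimeDistributedDense(8))
    model.add(Activation('sigmoid'))
    model.compile(loss='binary_crossentropy', optimizer='adam')

    # dummy binary sequences just to show the expected tensor shape:
    # (nb_samples, timesteps, input_dim)
    X = np.random.binomial(1, 0.5, size=(32, 10, 8)).astype(floatX)
    model.fit(X, X, batch_size=16, nb_epoch=1)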