""" | |
This is a batched LSTM forward and backward pass | |
""" | |
import numpy as np | |
import code | |
class LSTM: | |
@staticmethod | |
def init(input_size, hidden_size, fancy_forget_bias_init = 3): | |
""" | |
Initialize parameters of the LSTM (both weights and biases in one matrix) | |
One might way to have a positive fancy_forget_bias_init number (e.g. maybe even up to 5, in some papers) | |
""" | |
# +1 for the biases, which will be the first row of WLSTM | |
WLSTM = np.random.randn(input_size + hidden_size + 1, 4 * hidden_size) / np.sqrt(input_size + hidden_size) | |
WLSTM[0,:] = 0 # initialize biases to zero | |
if fancy_forget_bias_init != 0: | |
# forget gates get little bit negative bias initially to encourage them to be turned off | |
# remember that due to Xavier initialization above, the raw output activations from gates before | |
# nonlinearity are zero mean and on order of standard deviation ~1 | |
WLSTM[0,hidden_size:2*hidden_size] = fancy_forget_bias_init | |
return WLSTM | |
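
  # Layout of WLSTM: shape (1 + input_size + hidden_size, 4 * hidden_size). Row 0 holds the biases,
  # the next input_size rows multiply x_t, and the last hidden_size rows multiply h_{t-1}. The
  # 4 * hidden_size columns are grouped as [input gate | forget gate | output gate | candidate g],
  # which is the IFOG ordering used by forward() and backward() below.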
  @staticmethod
  def forward(X, WLSTM, c0 = None, h0 = None):
    """
    X should be of shape (n,b,input_size), where n = length of sequence, b = batch size
    """
    n,b,input_size = X.shape
    d = WLSTM.shape[1]/4 # hidden size
    if c0 is None: c0 = np.zeros((b,d))
    if h0 is None: h0 = np.zeros((b,d))

    # Perform the LSTM forward pass with X as the input
    xphpb = WLSTM.shape[0] # x plus h plus bias, lol
    Hin = np.zeros((n, b, xphpb)) # input [1, xt, ht-1] to each tick of the LSTM
    Hout = np.zeros((n, b, d)) # hidden representation of the LSTM (gated cell content)
    IFOG = np.zeros((n, b, d * 4)) # input, forget, output, gate (IFOG)
    IFOGf = np.zeros((n, b, d * 4)) # after nonlinearity
    C = np.zeros((n, b, d)) # cell content
    Ct = np.zeros((n, b, d)) # tanh of cell content
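
    # The loop below implements, per timestep t (with i, f, o, g the d-wide column blocks of IFOGf):
    #   i_t = sigmoid([1, x_t, h_{t-1}] . W_i)
    #   f_t = sigmoid([1, x_t, h_{t-1}] . W_f)
    #   o_t = sigmoid([1, x_t, h_{t-1}] . W_o)
    #   g_t =    tanh([1, x_t, h_{t-1}] . W_g)
    #   c_t = i_t * g_t + f_t * c_{t-1}
    #   h_t = o_t * tanh(c_t)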
    for t in xrange(n):
      # concat [x,h] as input to the LSTM
      prevh = Hout[t-1] if t > 0 else h0
      Hin[t,:,0] = 1 # bias
      Hin[t,:,1:input_size+1] = X[t]
      Hin[t,:,input_size+1:] = prevh
      # compute all gate activations in one matrix multiply (most of the work is this line)
      IFOG[t] = Hin[t].dot(WLSTM)
      # non-linearities
      IFOGf[t,:,:3*d] = 1.0/(1.0+np.exp(-IFOG[t,:,:3*d])) # sigmoids; these are the gates
      IFOGf[t,:,3*d:] = np.tanh(IFOG[t,:,3*d:]) # tanh
      # compute the cell activation
      prevc = C[t-1] if t > 0 else c0
      C[t] = IFOGf[t,:,:d] * IFOGf[t,:,3*d:] + IFOGf[t,:,d:2*d] * prevc
      Ct[t] = np.tanh(C[t])
      Hout[t] = IFOGf[t,:,2*d:3*d] * Ct[t]
    cache = {}
    cache['WLSTM'] = WLSTM
    cache['Hout'] = Hout
    cache['IFOGf'] = IFOGf
    cache['IFOG'] = IFOG
    cache['C'] = C
    cache['Ct'] = Ct
    cache['Hin'] = Hin
    cache['c0'] = c0
    cache['h0'] = h0

    # also return the final cell state C[t] and hidden state Hout[t], so the LSTM can be continued from this state if needed
    return Hout, C[t], Hout[t], cache
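
  # Note on backward(): dHout_in must have the same shape (n,b,d) as Hout, i.e. the gradient of the
  # loss with respect to every hidden state; timesteps that do not feed into the loss simply get zero
  # rows. dcn and dhn optionally carry gradients arriving from a later segment of the sequence.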
  @staticmethod
  def backward(dHout_in, cache, dcn = None, dhn = None):

    WLSTM = cache['WLSTM']
    Hout = cache['Hout']
    IFOGf = cache['IFOGf']
    IFOG = cache['IFOG']
    C = cache['C']
    Ct = cache['Ct']
    Hin = cache['Hin']
    c0 = cache['c0']
    h0 = cache['h0']
    n,b,d = Hout.shape
    input_size = WLSTM.shape[0] - d - 1 # -1 due to bias

    # backprop the LSTM
    dIFOG = np.zeros(IFOG.shape)
    dIFOGf = np.zeros(IFOGf.shape)
    dWLSTM = np.zeros(WLSTM.shape)
    dHin = np.zeros(Hin.shape)
    dC = np.zeros(C.shape)
    dX = np.zeros((n,b,input_size))
    dh0 = np.zeros((b, d))
    dc0 = np.zeros((b, d))
    dHout = dHout_in.copy() # make a copy so we don't have any funny side effects
    if dcn is not None: dC[n-1] += dcn.copy() # carry over gradients from later
    if dhn is not None: dHout[n-1] += dhn.copy()
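
    # The loop below walks backwards in time. At each step the incoming gradient dHout[t] splits into
    # a gradient on the output gate and, through the tanh, a gradient on the cell c_t; the cell
    # gradient dC[t] then splits into gradients on the input gate, the candidate g, the forget gate,
    # and the previous cell c_{t-1}, which is why dC[t-1] and dHout[t-1] accumulate with += below.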
    for t in reversed(xrange(n)):

      tanhCt = Ct[t]
      dIFOGf[t,:,2*d:3*d] = tanhCt * dHout[t]
      # backprop tanh non-linearity first then continue backprop
      dC[t] += (1-tanhCt**2) * (IFOGf[t,:,2*d:3*d] * dHout[t])

      if t > 0:
        dIFOGf[t,:,d:2*d] = C[t-1] * dC[t]
        dC[t-1] += IFOGf[t,:,d:2*d] * dC[t]
      else:
        dIFOGf[t,:,d:2*d] = c0 * dC[t]
        dc0 = IFOGf[t,:,d:2*d] * dC[t]
      dIFOGf[t,:,:d] = IFOGf[t,:,3*d:] * dC[t]
      dIFOGf[t,:,3*d:] = IFOGf[t,:,:d] * dC[t]

      # backprop activation functions
      dIFOG[t,:,3*d:] = (1 - IFOGf[t,:,3*d:] ** 2) * dIFOGf[t,:,3*d:]
      y = IFOGf[t,:,:3*d]
      dIFOG[t,:,:3*d] = (y*(1.0-y)) * dIFOGf[t,:,:3*d]

      # backprop matrix multiply
      dWLSTM += np.dot(Hin[t].transpose(), dIFOG[t])
      dHin[t] = dIFOG[t].dot(WLSTM.transpose())

      # backprop the identity transforms into Hin
      dX[t] = dHin[t,:,1:input_size+1]
      if t > 0:
        dHout[t-1,:] += dHin[t,:,input_size+1:]
      else:
        dh0 += dHin[t,:,input_size+1:]

    return dX, dWLSTM, dc0, dh0

# -------------------
# TEST CASES
# -------------------

def checkSequentialMatchesBatch():
  """ check LSTM I/O forward/backward interactions """

  n,b,d = (5, 3, 4) # sequence length, batch size, hidden size
  input_size = 10
  WLSTM = LSTM.init(input_size, d) # input size, hidden size
  X = np.random.randn(n,b,input_size)
  h0 = np.random.randn(b,d)
  c0 = np.random.randn(b,d)

  # sequential forward
  cprev = c0
  hprev = h0
  caches = [{} for t in xrange(n)]
  Hcat = np.zeros((n,b,d))
  for t in xrange(n):
    xt = X[t:t+1]
    _, cprev, hprev, cache = LSTM.forward(xt, WLSTM, cprev, hprev)
    caches[t] = cache
    Hcat[t] = hprev

  # sanity check: perform batch forward to check that we get the same thing
  H, _, _, batch_cache = LSTM.forward(X, WLSTM, c0, h0)
  assert np.allclose(H, Hcat), "Sequential and Batch forward don't match!"

  # eval loss
  wrand = np.random.randn(*Hcat.shape)
  loss = np.sum(Hcat * wrand)
  dH = wrand

  # get the batched version gradients
  BdX, BdWLSTM, Bdc0, Bdh0 = LSTM.backward(dH, batch_cache)

  # now perform sequential backward
  dX = np.zeros_like(X)
  dWLSTM = np.zeros_like(WLSTM)
  dc0 = np.zeros_like(c0)
  dh0 = np.zeros_like(h0)
  dcnext = None
  dhnext = None
  for t in reversed(xrange(n)):
    dht = dH[t].reshape(1, b, d)
    dx, dWLSTMt, dcprev, dhprev = LSTM.backward(dht, caches[t], dcnext, dhnext)
    dhnext = dhprev
    dcnext = dcprev
    dWLSTM += dWLSTMt # accumulate LSTM gradient
    dX[t] = dx[0]
    if t == 0:
      dc0 = dcprev
      dh0 = dhprev

  # and make sure the gradients match
  print 'Making sure batched version agrees with sequential version: (should all be True)'
  print np.allclose(BdX, dX)
  print np.allclose(BdWLSTM, dWLSTM)
  print np.allclose(Bdc0, dc0)
  print np.allclose(Bdh0, dh0)

def checkBatchGradient():
  """ check that the batch gradient is correct """

  # let's gradient check this beast
  n,b,d = (5, 3, 4) # sequence length, batch size, hidden size
  input_size = 10
  WLSTM = LSTM.init(input_size, d) # input size, hidden size
  X = np.random.randn(n,b,input_size)
  h0 = np.random.randn(b,d)
  c0 = np.random.randn(b,d)

  # batch forward backward
  H, Ct, Ht, cache = LSTM.forward(X, WLSTM, c0, h0)
  wrand = np.random.randn(*H.shape)
  loss = np.sum(H * wrand) # weighted sum is a nice hash to use I think
  dH = wrand
  dX, dWLSTM, dc0, dh0 = LSTM.backward(dH, cache)

  def fwd():
    h,_,_,_ = LSTM.forward(X, WLSTM, c0, h0)
    return np.sum(h * wrand)

  # now gradient check all
  delta = 1e-5
  rel_error_thr_warning = 1e-2
  rel_error_thr_error = 1
  tocheck = [X, WLSTM, c0, h0]
  grads_analytic = [dX, dWLSTM, dc0, dh0]
  names = ['X', 'WLSTM', 'c0', 'h0']
  for j in xrange(len(tocheck)):
    mat = tocheck[j]
    dmat = grads_analytic[j]
    name = names[j]
    # gradcheck
    for i in xrange(mat.size):
      old_val = mat.flat[i]
      mat.flat[i] = old_val + delta
      loss0 = fwd()
      mat.flat[i] = old_val - delta
      loss1 = fwd()
      mat.flat[i] = old_val

      grad_analytic = dmat.flat[i]
      grad_numerical = (loss0 - loss1) / (2 * delta)

      if grad_numerical == 0 and grad_analytic == 0:
        rel_error = 0 # both are zero, OK.
        status = 'OK'
      elif abs(grad_numerical) < 1e-7 and abs(grad_analytic) < 1e-7:
        rel_error = 0 # not enough precision to check this
        status = 'VAL SMALL WARNING'
      else:
        rel_error = abs(grad_analytic - grad_numerical) / abs(grad_numerical + grad_analytic)
        status = 'OK'
        if rel_error > rel_error_thr_warning: status = 'WARNING'
        if rel_error > rel_error_thr_error: status = '!!!!! NOTOK'

      # print stats
      print '%s checking param %s index %s (val = %+8f), analytic = %+8f, numerical = %+8f, relative error = %+8f' \
            % (status, name, `np.unravel_index(i, mat.shape)`, old_val, grad_analytic, grad_numerical, rel_error)


if __name__ == "__main__":

  checkSequentialMatchesBatch()
  raw_input('check OK, press key to continue to gradient check')
  checkBatchGradient()
  print 'every line should start with OK. Have a nice day!'
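
A minimal sketch of how forward() and backward() compose into a single parameter update. This is not part of the original gist: the weighted-sum objective below is the same toy loss used in checkBatchGradient(), standing in for a real loss whose gradient you would feed in as dH.

# minimal usage sketch; assumes the LSTM class defined above is in scope
import numpy as np

n, b, input_size, hidden_size = 5, 3, 10, 4
WLSTM = LSTM.init(input_size, hidden_size)
X = np.random.randn(n, b, input_size)            # a batch of b sequences of length n

H, cn, hn, cache = LSTM.forward(X, WLSTM)        # c0/h0 default to zero states
wrand = np.random.randn(*H.shape)
loss = np.sum(H * wrand)                         # toy scalar objective (stand-in for a real loss)
dH = wrand                                       # gradient of that objective w.r.t. H

dX, dWLSTM, dc0, dh0 = LSTM.backward(dH, cache)  # gradients w.r.t. inputs, weights, initial states
WLSTM -= 0.1 * dWLSTM                            # one vanilla gradient-descent step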
Thanks a lot, karpathy!!!!
Could you please suggest a good reference for getting familiar with the LSTM equations? Presently I'm using
http://arxiv.org/abs/1503.04069 and https://apaszke.github.io/lstm-explained.html
Thanks
This is fantastic and clearly written. Thanks for this
Hello,
Has anyone tried converting it to Cython? (it should give roughly a 3x speedup)
Thanks for this!
I rewrote the code in R, if anyone is interested.
What should be the shape of 'dHout_in' in the backward pass if I only want to feed the nth (last) timestep's state into my classification/softmax layer?
Can someone post an example of using this batched LSTM for training on a dataset? I'm new to the domain and have learned a lot by reading this well-written code, but when it comes to using it in a real training setting I run into problems.
This is indeed very efficient. I sat down hoping to rewrite it to be faster; early on I changed a lot of things, but later I reverted most of them.
Have you used Numba? It gave me quite a bit of speedup.