Created
May 5, 2015 07:31
-
-
Save karpathy/7bae8033dcf5ca2630ba to your computer and use it in GitHub Desktop.
Efficient LSTM cell in Torch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--[[
Efficient LSTM in Torch using the nngraph library. This code was optimized
by Justin Johnson (@jcjohnson) based on the trick of batching up the
LSTM GEMMs, as also seen in my efficient Python LSTM gist.
--]]
--- Build a single LSTM time-step as an nngraph module.
--
-- All four gate pre-activations are produced by one pair of Linear layers
-- (input-to-hidden and hidden-to-hidden), each of width 4 * rnn_size, and
-- then sliced apart with nn.Narrow. Column layout of the summed
-- pre-activations along dimension 2:
--
--   [1,              rnn_size    ]  input gate      (sigmoid)
--   [rnn_size + 1,   2 * rnn_size]  forget gate     (sigmoid)
--   [2*rnn_size + 1, 3 * rnn_size]  output gate     (sigmoid)
--   [3*rnn_size + 1, 4 * rnn_size]  input transform (tanh)
--
-- NOTE(review): every Narrow slices along dimension 2, so the graph expects
-- activations with at least two dimensions (presumably batch x features);
-- confirm against callers before feeding unbatched 1-D input.
--
-- @param input_size  number of features of the input x
-- @param rnn_size    number of hidden units (size of the c and h states)
-- @return an nn.gModule taking {x, prev_c, prev_h} and
--         producing {next_c, next_h}
function LSTM.fast_lstm(input_size, rnn_size)
  -- Graph inputs: current input, previous cell state, previous hidden state.
  local x = nn.Identity()()
  local prev_c = nn.Identity()()
  local prev_h = nn.Identity()()

  -- One wide projection per source instead of four narrow ones per gate.
  local i2h = nn.Linear(input_size, 4 * rnn_size)(x)
  local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h)
  local all_input_sums = nn.CAddTable()({i2h, h2h})

  -- The three sigmoid gates are contiguous, so the nonlinearity is applied
  -- once to the first 3 * rnn_size columns and the gates are sliced after.
  local sigmoid_chunk = nn.Narrow(2, 1, 3 * rnn_size)(all_input_sums)
  sigmoid_chunk = nn.Sigmoid()(sigmoid_chunk)
  local in_gate = nn.Narrow(2, 1, rnn_size)(sigmoid_chunk)
  local forget_gate = nn.Narrow(2, rnn_size + 1, rnn_size)(sigmoid_chunk)
  local out_gate = nn.Narrow(2, 2 * rnn_size + 1, rnn_size)(sigmoid_chunk)

  -- The last rnn_size columns are the candidate values, squashed with tanh.
  local in_transform = nn.Narrow(2, 3 * rnn_size + 1, rnn_size)(all_input_sums)
  in_transform = nn.Tanh()(in_transform)

  -- next_c = forget_gate * prev_c + in_gate * in_transform
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform})
  })
  -- next_h = out_gate * tanh(next_c)
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

  return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment