@kachayev
Last active May 3, 2021 21:06
Simplified Encoder implementation for "NLP From Scratch: Translation with a Sequence to Sequence Network and Attention" tutorial
#
# The following is a simplification of the EncoderRNN code from the
# "NLP From Scratch: Translation with a Sequence to Sequence Network and Attention"
# PyTorch tutorial (link: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html)
#
# In fact, the `nn.GRU` module can run the loop over an input sequence by itself when it is
# given the whole sequence at once. In the tutorial, each sentence is represented as a tensor
# of 1-word sequences that are fed to the encoder one at a time, hence the loop is handled
# "manually" in `train` and `evaluate`. Given that `nn.Embedding` accepts a tensor of indices
# of any shape and `nn.GRU` loops over the sequence when given a (seq_len, batch_size, elem_dim)
# shaped input, the encoder can be constructed as follows:
#
import torch
import torch.nn as nn


class EncoderRNNSimplified(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNNSimplified, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        # Given a [batch_size, seq_len] tensor of indices, `nn.Embedding` outputs
        # [batch_size, seq_len, hidden_dim], whereas `nn.GRU` expects
        # [seq_len, batch_size, hidden_dim] by default. So we either need
        # `permute(1, 0, 2)` to get the proper view of the embedding tensor, or
        # set `batch_first=True` (in which case `nn.GRU` accepts
        # [batch_size, seq_len, hidden_dim]).
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)

    # If `hidden` is not given, `nn.GRU` initializes it to a properly shaped tensor
    # of zeros. Returns (outputs, last hidden state); note that `batch_first` does not
    # apply to the hidden state, which stays [num_layers, batch_size, hidden_dim].
    def forward(self, x, hidden=None):
        return self.rnn(self.embedding(x), hidden)
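#
# A minimal sketch checking the layout comment in `__init__` above, with randomly
# initialized weights and arbitrary toy sizes (vocabulary of 10, hidden size 8): a GRU with
# `batch_first=True` gives the same values as a default GRU with identical weights fed the
# permuted embedding, and omitting `hidden` matches passing zeros.
#
```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, hidden_size = 10, 8  # arbitrary toy sizes, not the tutorial's

embedding = nn.Embedding(vocab_size, hidden_size)
gru_batch_first = nn.GRU(hidden_size, hidden_size, batch_first=True)
gru_seq_first = nn.GRU(hidden_size, hidden_size)  # expects [seq_len, batch, dim]
# Copy the weights so both GRUs compute the same function
gru_seq_first.load_state_dict(gru_batch_first.state_dict())

x = torch.tensor([[1, 4, 2, 9], [3, 3, 0, 5]])  # made-up indices, [batch_size=2, seq_len=4]

with torch.no_grad():
    emb = embedding(x)                                   # [2, 4, 8]
    out_bf, h_bf = gru_batch_first(emb)                  # out: [2, 4, 8]
    out_sf, h_sf = gru_seq_first(emb.permute(1, 0, 2))   # out: [4, 2, 8]
    # Explicit zero hidden state instead of `None`
    h0 = torch.zeros(1, x.size(0), hidden_size)
    out_zero, h_zero = gru_batch_first(emb, h0)

# Same values (up to float rounding), only the output layout differs
print(torch.allclose(out_bf, out_sf.permute(1, 0, 2), atol=1e-6))  # True
# The hidden state is [num_layers, batch, dim] in both cases
print(h_bf.shape, torch.allclose(h_bf, h_sf, atol=1e-6))            # torch.Size([1, 2, 8]) True
print(torch.allclose(out_bf, out_zero, atol=1e-6))                  # True
```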
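#
# Now, here's how the encoder can be used in `evaluate` (using the simple decoder as an
# example; the decoder with attention works the same way, except that it also receives
# `encoder_outputs`). `train` is similar.
# The code below assumes `device`, `MAX_LENGTH`, `SOS_TOKEN`, `EOS_TOKEN`, `input_lang`,
# `output_lang` and the helper `sentence_to_index` are defined elsewhere, analogous to
# `device`, `MAX_LENGTH`, `SOS_token`, `EOS_token` and `indexesFromSentence` in the tutorial.
#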
def sentence_to_tensor_simplified(lang, sentence):
    # list of indices, len = (number of words + 1 for EOS)
    idx = sentence_to_index(lang, sentence)
    idx.append(EOS_TOKEN)
    # note that, unlike the tutorial, we don't reshape the tensor with `.view(-1, 1)`
    # here: the result is a 1-D tensor of shape [seq_len]
    return torch.tensor(idx, dtype=torch.long, device=device)
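#
# A small sketch of the shapes involved (the indices are made up; in
# `sentence_to_tensor_simplified` they come from `sentence_to_index` plus `EOS_TOKEN`):
# `nn.Embedding` appends `embedding_dim` to whatever index shape it receives, so the
# tutorial's `.view(-1, 1)` layout and the 1-D layout used here only differ in where the
# batch dimension of size 1 ends up.
#
```python
import torch
import torch.nn as nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=256)
idx = [5, 2, 7, 1]  # hypothetical indices; the last one stands in for EOS

flat = torch.tensor(idx, dtype=torch.long)  # this gist: [seq_len]
tutorial_style = flat.view(-1, 1)           # tutorial: [seq_len, 1]
batched = flat.unsqueeze(0)                 # as in evaluate_simplified: [1, seq_len]

# nn.Embedding appends embedding_dim to the index tensor's shape
print(embedding(flat).shape)            # torch.Size([4, 256])
print(embedding(tutorial_style).shape)  # torch.Size([4, 1, 256])
print(embedding(batched).shape)         # torch.Size([1, 4, 256])
```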
def evaluate_simplified(encoder, decoder, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = sentence_to_tensor_simplified(input_lang, sentence)
        # `unsqueeze` here creates a batch of size 1; in most practical cases,
        # when working with `DataLoader`s, that won't be necessary, as loaders
        # typically return data in batches
        encoder_outputs, encoder_hidden = encoder(input_tensor.unsqueeze(0))
        # `squeeze` to "ignore" batching: the result has shape [seq_len, hidden_dim]
        # and its rows are the same per-step outputs that the tutorial gathers in a
        # for loop into the outer `encoder_outputs` tensor (the tutorial's tensor has
        # `max_length` rows, with the rows beyond seq_len left as zeros)
        encoder_outputs = encoder_outputs.squeeze(0)
        # `encoder_hidden` has shape [1, batch_size, hidden_dim] and can be passed
        # to the decoder "as is"
        decoder_hidden = encoder_hidden
        decoder_input = torch.tensor([[SOS_TOKEN]], device=device)
        decoded_words = []
        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            # greedy decoding: pick the most likely next word
            topi = torch.argmax(decoder_output)
            if topi.item() == EOS_TOKEN:
                break
            decoded_words.append(output_lang.index2word[topi.item()])
            decoder_input = topi.detach()
        return " ".join(decoded_words)
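#
# A sketch checking that a single call reproduces the tutorial's per-word loop, using random
# weights, made-up indices and arbitrary sizes (hidden size 8, `max_length` 6). The loop
# mimics the tutorial's `evaluate`: one word per call, each step's output copied into a zero
# tensor with `max_length` rows.
#
```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, hidden_size, max_length = 10, 8, 6  # arbitrary toy sizes

embedding = nn.Embedding(vocab_size, hidden_size)
gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
input_tensor = torch.tensor([3, 1, 4, 1, 5])  # made-up indices, [seq_len]
seq_len = input_tensor.size(0)

with torch.no_grad():
    # Tutorial-style: one word per call, carrying the hidden state forward
    loop_outputs = torch.zeros(max_length, hidden_size)
    hidden = torch.zeros(1, 1, hidden_size)
    for ei in range(seq_len):
        step_input = embedding(input_tensor[ei:ei + 1]).view(1, 1, -1)  # [1, 1, hidden]
        step_output, hidden = gru(step_input, hidden)
        loop_outputs[ei] = step_output[0, 0]

    # Simplified: one call over the whole sequence (batch of one)
    outputs, last_hidden = gru(embedding(input_tensor.unsqueeze(0)))
    outputs = outputs.squeeze(0)  # [seq_len, hidden]

print(torch.allclose(loop_outputs[:seq_len], outputs, atol=1e-5))  # True
print(torch.allclose(hidden, last_hidden, atol=1e-5))              # True
# The tutorial-style tensor keeps zero padding beyond seq_len
print(loop_outputs[seq_len:].abs().sum().item())                   # 0.0
```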
#
# How it works step by step with an already trained `encoder` (hidden_size = 256):
#
# sentence = "i am just going for a walk"
# #> 7 words
# input_tensor = sentence_to_tensor_simplified(input_lang, sentence)
# #> torch.Size([8])
# input_batch = input_tensor.unsqueeze(0)
# #> torch.Size([1, 8])
# embedding_batch = encoder.embedding(input_batch)
# #> torch.Size([1, 8, 256])
# outputs, hidden = encoder.rnn(embedding_batch)
# #> (torch.Size([1, 8, 256]), torch.Size([1, 1, 256]))
# outputs = outputs.squeeze(0)
# #> torch.Size([8, 256])
#
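#
# The same trace as a runnable sketch. It uses an untrained (randomly initialized) encoder
# and a hypothetical toy vocabulary built from the example sentence, with index 1 standing
# in for `EOS_TOKEN`; the shapes do not depend on training.
#
```python
import torch
import torch.nn as nn


class EncoderRNNSimplified(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, x, hidden=None):
        return self.rnn(self.embedding(x), hidden)


# Hypothetical toy vocabulary (not the tutorial's Lang class); index 1 plays EOS_TOKEN
EOS = 1
sentence = "i am just going for a walk"
word2index = {w: i + 2 for i, w in enumerate(sentence.split())}

encoder = EncoderRNNSimplified(input_size=len(word2index) + 2, hidden_size=256)
input_tensor = torch.tensor([word2index[w] for w in sentence.split()] + [EOS])

with torch.no_grad():
    input_batch = input_tensor.unsqueeze(0)
    embedding_batch = encoder.embedding(input_batch)
    outputs, hidden = encoder.rnn(embedding_batch)
    outputs = outputs.squeeze(0)
    # forward() performs the two steps above in one call
    full_outputs, full_hidden = encoder(input_batch)

print(input_tensor.shape)     # torch.Size([8])
print(input_batch.shape)      # torch.Size([1, 8])
print(embedding_batch.shape)  # torch.Size([1, 8, 256])
print(hidden.shape)           # torch.Size([1, 1, 256])
print(outputs.shape)          # torch.Size([8, 256])
```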