minimal multi-head attention implementation
import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    def __init__(self, in_size, n_heads=1, scaled=True):
        super().__init__()
        # in_size = d_k in the paper
        self.scale = in_size ** 0.5 if scaled else 1
        self.n_heads = n_heads
        # one in_size-dimensional projection per head, packed into a single linear layer
        self.q_linear = nn.Linear(in_size, in_size * n_heads)
        self.k_linear = nn.Linear(in_size, in_size * n_heads)
        self.v_linear = nn.Linear(in_size, in_size * n_heads)
        self.o_linear = nn.Linear(in_size * n_heads, in_size)

    def forward(self, x):
        batch_size, seq_len, in_size = x.shape
        # projection: split the last dim into heads, then move the head dim
        # ahead of the sequence dim -> (batch, n_heads, seq_len, in_size)
        q = self.q_linear(x).reshape(batch_size, seq_len, self.n_heads, in_size).transpose(1, 2)
        k = self.k_linear(x).reshape(batch_size, seq_len, self.n_heads, in_size).transpose(1, 2)
        v = self.v_linear(x).reshape(batch_size, seq_len, self.n_heads, in_size).transpose(1, 2)
        # scaled dot-product attention
        attn = torch.matmul(q, k.transpose(2, 3)) / self.scale
        score = torch.softmax(attn, dim=-1)
        attn_out = torch.matmul(score, v)
        # concat: merge heads back into the last dim -> (batch, seq_len, n_heads * in_size)
        concatenated = attn_out.transpose(1, 2).reshape(batch_size, seq_len, -1)
        # output projection back to in_size
        projected = self.o_linear(concatenated)
        # residual connection
        out = x + projected
        return out
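A minimal usage sketch (not part of the original gist): the batch, sequence, and feature sizes below are illustrative assumptions. Appended to the file above, it runs a forward pass and confirms that the residual connection keeps the output shape equal to the input shape.

if __name__ == "__main__":
    # hypothetical sizes chosen only for this smoke test
    batch_size, seq_len, in_size = 2, 5, 16
    mha = MultiHeadAttention(in_size, n_heads=4)
    x = torch.randn(batch_size, seq_len, in_size)
    out = mha(x)
    print(out.shape)  # torch.Size([2, 5, 16]), same shape as the input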