@bobchennan
Forked from lucidrains/pytorch_reformer.py
Last active January 23, 2020 00:58
reformer(pytorch)
import torch
import torch.nn as nn
import torch.nn.functional as F

# helpers

def make_unit_length(x, epsilon=1e-6):
    norm = x.norm(p=2, dim=-1, keepdim=True)
    return x.div(norm + epsilon)

def sort_key_val(t1, t2, dim=-1):
    values, indices = t1.sort(dim=dim)
    t2 = t2.expand_as(t1)
    return values, t2.gather(dim, indices)

def batched_index_select(values, indices):
    b = values.shape[0]
    return values[torch.arange(0, b), indices.transpose(0, 1)].transpose(0, 1)
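
# Quick illustration of the two gather helpers above (the shapes are arbitrary
# demo values, not anything prescribed by the gist):
_keys = torch.randn(2, 8)
_pos = torch.arange(8).unsqueeze(0)                    # (1, 8), broadcast against _keys
_sorted_keys, _sorted_pos = sort_key_val(_keys, _pos)  # keys in ascending order, and where each came from
_vals = torch.randn(2, 8, 4)
assert batched_index_select(_vals, _sorted_pos).shape == _vals.shape  # rows reordered per batch element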

# reversible net helper classes

class ReversibleBlock(nn.Module):
    def __init__(self, f_block, g_block, dim = 1):
        super().__init__()
        self.dim = dim
        self.f_block = f_block
        self.g_block = g_block

    def forward(self, x):
        x1, x2 = torch.chunk(x, 2, dim=self.dim)
        y1, y2 = None, None
        with torch.no_grad():
            y1 = x1 + self.f_block(x2)
            y2 = x2 + self.g_block(y1)
        return torch.cat([y1, y2], dim=self.dim)

    def backward_pass(self, y, dy):
        # Reconstruct (x1, x2) from (y1, y2) and propagate gradients without
        # ever having stored the intermediate activations.
        y1, y2 = torch.chunk(y, 2, dim=self.dim)
        del y
        dy1, dy2 = torch.chunk(dy, 2, dim=self.dim)
        del dy

        y1.requires_grad = True
        y2.requires_grad = True

        with torch.enable_grad():
            gy1 = self.g_block(y1)
            gy1.backward(dy2)

        with torch.no_grad():
            x2 = y2 - gy1          # x2 = y2 - g(y1)
            del y2, gy1

            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f_block(x2)
            fx2.backward(dx1)

        with torch.no_grad():
            x1 = y1 - fx2          # x1 = y1 - f(x2)
            del y1, fx2

            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            x = torch.cat([x1, x2.detach()], dim=self.dim)
            dx = torch.cat([dx1, dx2], dim=self.dim)

        return x, dx

class _ReversibleModuleFunction(torch.autograd.function.Function):
    @staticmethod
    def forward(ctx, x, reversible_blocks):
        for block in reversible_blocks:
            x = block(x)
        ctx.y = x.detach()
        ctx.reversible_blocks = reversible_blocks
        return x

    @staticmethod
    def backward(ctx, dy):
        y = ctx.y
        del ctx.y
        for i in range(len(ctx.reversible_blocks) - 1, -1, -1):
            y, dy = ctx.reversible_blocks[i].backward_pass(y, dy)
        del ctx.reversible_blocks
        return dy, None

class ReversibleSequence(nn.Module):
    def __init__(self, reversible_blocks):
        super().__init__()
        self.reversible_blocks = reversible_blocks

    def forward(self, x):
        x = _ReversibleModuleFunction.apply(x, self.reversible_blocks)
        return x
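
# Small sanity check of the reversible machinery: gradients computed by the
# custom backward should match plain autograd on the same blocks. The tiny
# Linear blocks below are stand-ins picked for this demo only, not part of the
# Reformer itself.
_dim = 4
_rev_blocks = nn.ModuleList([
    ReversibleBlock(nn.Linear(_dim, _dim), nn.Linear(_dim, _dim), dim=-1)
    for _ in range(2)
])
_rev_inp = torch.randn(3, 2 * _dim, requires_grad=True)
ReversibleSequence(_rev_blocks)(_rev_inp).sum().backward()

# reference: the same computation under ordinary autograd
_ref_inp = _rev_inp.detach().clone().requires_grad_(True)
_x1, _x2 = torch.chunk(_ref_inp, 2, dim=-1)
for _blk in _rev_blocks:
    _y1 = _x1 + _blk.f_block(_x2)
    _y2 = _x2 + _blk.g_block(_y1)
    _x1, _x2 = _y1, _y2
torch.cat([_x1, _x2], dim=-1).sum().backward()
assert torch.allclose(_rev_inp.grad, _ref_inp.grad, atol=1e-5)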

# lsh attention

class LSHAttention(nn.Module):
    def __init__(self,
                 dropout = 0.,
                 bucket_size = 64,
                 n_hashes = 8,
                 allow_duplicate_attention = False,
                 attend_across_buckets = False,
                 rehash_each_round = True,
                 drop_for_hash_rate = 0.0):
        super().__init__()
        if dropout >= 1.0:
            raise ValueError('Dropout rates must be lower than 1.')

        self.dropout = nn.Dropout(dropout)
        self.dropout_for_hash = nn.Dropout(drop_for_hash_rate)

        assert rehash_each_round or allow_duplicate_attention, (
            'The setting {allow_duplicate_attention=False, rehash_each_round=False}'
            ' is not implemented.')

        self.n_hashes = n_hashes
        self.bucket_size = bucket_size

        self._allow_duplicate_attention = allow_duplicate_attention
        self._attend_across_buckets = attend_across_buckets
        self._rehash_each_round = rehash_each_round

    def _sample_rotation(self, shape, vecs):
        device = vecs.device
        return torch.randn(shape, device=device)

    def hash_vectors(self, n_buckets, vecs):
        batch_size = vecs.shape[0]
        device = vecs.device

        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        assert n_buckets % 2 == 0

        rot_size = n_buckets

        rotations_shape = (
            vecs.shape[-1],
            self.n_hashes if self._rehash_each_round else 1,
            rot_size // 2)

        random_rotations = self._sample_rotation(rotations_shape, vecs)
        dropped_vecs = self.dropout_for_hash(vecs)
        rotated_vecs = torch.einsum('btf,fhi->bhti', dropped_vecs, random_rotations)

        if self._rehash_each_round:
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            buckets = torch.argmax(rotated_vecs, dim=-1)
            # buckets is now (batch_size, self.n_hashes, seqlen). Next we add offsets
            # so that bucket numbers from different hashing rounds don't overlap.
            offsets = torch.arange(self.n_hashes, device=device)
            offsets = torch.reshape(offsets * n_buckets, (1, -1, 1))
            buckets = torch.reshape(buckets + offsets, (batch_size, -1,))
        else:
            assert not self._factorize_hash
            rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1)
            # In this configuration, we map each item to the top self.n_hashes buckets
            rotated_vecs = torch.squeeze(rotated_vecs, 0)
            bucket_range = torch.arange(0, rotated_vecs.shape[-1], device=device)
            bucket_range = torch.reshape(bucket_range, (1, -1))
            bucket_range = bucket_range.expand_as(rotated_vecs)

            _, buckets = sort_key_val(rotated_vecs, bucket_range, dim=-1)
            buckets = buckets[:, -self.n_hashes:]

            h, *_ = buckets.shape
            buckets = torch.reshape(buckets.permute((*_, h)), (-1,))

        return buckets
    def forward(self, qk, v):
        batch_size, seqlen, _ = qk.shape
        device = qk.device

        n_buckets = seqlen // self.bucket_size
        n_bins = n_buckets

        buckets = self.hash_vectors(n_buckets, qk)
        # We use the same vector as both a query and a key.
        assert int(buckets.shape[1]) == self.n_hashes * seqlen

        ticker = torch.arange(0, self.n_hashes * seqlen, device=device).unsqueeze(0)
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = buckets_and_t.detach()

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)
        _, undo_sort = sort_key_val(sticker, ticker, dim=-1)

        sbuckets_and_t = sbuckets_and_t.detach()
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

        st = (sticker % seqlen)
        sqk = batched_index_select(qk, st)
        sv = batched_index_select(v, st)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = torch.reshape(st, (batch_size, self.n_hashes * n_bins, -1))
        bqk = torch.reshape(sqk, (batch_size, self.n_hashes * n_bins, -1, sqk.shape[-1]))
        bv = torch.reshape(sv, (batch_size, self.n_hashes * n_bins, -1, sv.shape[-1]))
        bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // seqlen, (batch_size, self.n_hashes * n_bins, -1))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        def look_one_back(x):
            x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
            return torch.cat([x, x_extra], dim=2)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        bkv_buckets = look_one_back(bkv_buckets)

        # Dot-product attention.
        dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (bq.shape[-1] ** -0.5)

        # Causal masking
        mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]
        dots = dots - 1e9 * mask.float()

        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots = dots - 1e5 * self_mask.float()

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :]
            dots = dots - 1e7 * bucket_mask.float()

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurrence of each query-key pair.
        if not self._allow_duplicate_attention:
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % (self.n_hashes * n_bins)
            if not self._attend_across_buckets:
                locs1 = buckets * (self.n_hashes * n_bins) + locs1
                locs2 = buckets * (self.n_hashes * n_bins) + locs2
            locs = torch.cat([
                torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)),
                torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)),
            ], 1).permute((0, 2, 1))

            slocs = batched_index_select(locs, st)
            b_locs = torch.reshape(slocs, (batch_size, self.n_hashes * n_bins, -1, 2 * self.n_hashes))

            b_locs1 = b_locs[:, :, :, None, :self.n_hashes]

            bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes))
            bq_locs = torch.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :]).float().sum(dim=-1)
            dup_counts = dup_counts.detach()
            assert dup_counts.shape == dots.shape
            dots = dots - torch.log(dup_counts + 1e-9)

        # Softmax.
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp)
        dots = self.dropout(dots)

        bo = torch.einsum('buij,buje->buie', dots, bv)
        so = torch.reshape(bo, (batch_size, -1, bo.shape[-1]))
        slogits = torch.reshape(dots_logsumexp, (batch_size, -1,))

        o = batched_index_select(so, undo_sort)
        _, logits = sort_key_val(sticker, slogits, dim=-1)

        if self.n_hashes == 1:
            out = o
        else:
            o = torch.reshape(o, (batch_size, self.n_hashes, seqlen, o.shape[-1]))
            logits = torch.reshape(logits, (batch_size, self.n_hashes, seqlen, 1))
            probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True))
            out = torch.sum(o * probs, dim=1)

        assert out.shape == v.shape
        return out
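
# Illustrative direct use of LSHAttention (sizes are arbitrary demo values;
# seqlen must be a multiple of bucket_size so that n_buckets is a positive,
# even number):
_lsh = LSHAttention(bucket_size=16, n_hashes=2)
_qk = torch.randn(2, 64, 32)
_v = torch.randn(2, 64, 32)
_buckets = _lsh.hash_vectors(64 // 16, _qk)   # (2, n_hashes * seqlen) bucket ids
assert _buckets.shape == (2, 2 * 64)
assert _lsh(_qk, _v).shape == _v.shape        # output keeps the value shape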

class LSHSelfAttention(nn.Module):
    def __init__(self, emb, heads = 8, bucket_size = 64, n_hashes = 8, **kwargs):
        super().__init__()
        self.heads = heads

        self.toqk = nn.Linear(emb, emb * heads)
        self.tov = nn.Linear(emb, emb * heads)
        self.unify_heads = nn.Linear(emb * heads, emb)

        self.bucket_size = bucket_size
        self.lsh_attn = LSHAttention(bucket_size=bucket_size, **kwargs)

    def forward(self, x):
        b, t, e, h = *x.shape, self.heads
        assert t % self.bucket_size == 0, f'Sequence length needs to be divisible by target bucket size - {self.bucket_size}'

        qk = self.toqk(x)
        v = self.tov(x)

        def merge_heads(v):
            return v.view(b, t, h, e).transpose(1, 2).reshape(b * h, t, e)

        def split_heads(v):
            return v.view(b, h, t, e).transpose(1, 2).contiguous()

        qk = merge_heads(qk)
        v = merge_heads(v)
        attn_out = self.lsh_attn(qk, v)

        out = split_heads(attn_out).view(b, t, h * e)
        return self.unify_heads(out)
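
# Illustrative standalone use of LSHSelfAttention (the sizes below are arbitrary
# demo values; the sequence length must be divisible by bucket_size):
_self_attn = LSHSelfAttention(emb=32, heads=4, bucket_size=16, n_hashes=2)
_tokens = torch.randn(2, 64, 32)
assert _self_attn(_tokens).shape == (2, 64, 32)  # output keeps the input shape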

# feedforward with chunking

class FeedForward(nn.Module):
    def __init__(self, emb, mult = 4):
        super().__init__()
        self.emb = emb

        self.proj_in = nn.Linear(emb, emb * mult)
        self.proj_out = nn.Linear(emb * mult, emb)

    def forward(self, x):
        x = self.proj_in(x)
        x = F.gelu(x)  # F.gelu needs a recent PyTorch version; on older releases substitute F.relu, for example
        x = self.proj_out(x)
        return x

class WithLayerNorm(nn.Module):
    def __init__(self, emb, fn):
        super().__init__()
        self.emb = emb
        self.norm = nn.LayerNorm(emb)
        self.fn = fn

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

class Chunk(nn.Module):
    def __init__(self, chunks, fn, dim = -1):
        super().__init__()
        self.dim = dim
        self.chunks = chunks
        self.fn = fn

    def forward(self, x):
        chunks = x.chunk(self.chunks, dim = self.dim)
        return torch.cat([self.fn(c) for c in chunks], dim = self.dim)
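
# Chunking lowers peak memory: the position-wise FeedForward is applied to
# slices of the sequence axis and the results are concatenated, which matches
# the unchunked call (demo sizes are arbitrary):
_ff = FeedForward(emb=32)
_chunked_ff = Chunk(4, _ff, dim=-2)
_h = torch.randn(2, 64, 32)
assert torch.allclose(_ff(_h), _chunked_ff(_h), atol=1e-5)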

# reformer auto-regressive lm

class Reformer(nn.Module):
    def __init__(self, emb, depth, max_seq_len, num_tokens = 10000, heads = 8, bucket_size = 64, n_hashes = 8, ff_chunks = 100):
        super().__init__()
        self.emb = emb
        self.depth = depth

        self.token_emb = nn.Embedding(num_tokens, emb)
        self.pos_emb = nn.Embedding(max_seq_len, emb)

        blocks = []
        for _ in range(depth):
            f = WithLayerNorm(emb, LSHSelfAttention(emb, heads, bucket_size, n_hashes))
            g = Chunk(ff_chunks, WithLayerNorm(emb, FeedForward(emb)), dim = -2)
            blocks.append(ReversibleBlock(f, g, dim=-1))

        self.layers = ReversibleSequence(nn.ModuleList(blocks))
        self.to_logits = nn.Linear(emb, num_tokens)

    def forward(self, x):
        x = self.token_emb(x) + self.pos_emb(torch.arange(0, x.shape[1]).to(x.device))
        x = torch.cat([x, x], dim = -1)
        x = self.layers(x)
        x = torch.stack(x.chunk(2, dim=-1)).sum(dim=0)
        return self.to_logits(x)

# testing

num_tokens = 1000
seq_len = 3456

r = Reformer(
    emb = 512,
    depth = 12,
    max_seq_len = seq_len,
    num_tokens = num_tokens,
    heads = 8,
    bucket_size = 64,
    n_hashes = 8,
    ff_chunks = 200
)

x = torch.randint(0, num_tokens, (1, seq_len)).long()
y = r(x)
y.sum().backward()
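
# A minimal training-step sketch on top of the smoke test above; the optimizer,
# learning rate, and random targets are illustrative assumptions, not part of
# the original gist.
opt = torch.optim.Adam(r.parameters(), lr=1e-4)
targets = torch.randint(0, num_tokens, (1, seq_len)).long()
opt.zero_grad()
loss = F.cross_entropy(r(x).transpose(1, 2), targets)  # logits reshaped to (batch, num_tokens, seq_len)
loss.backward()
opt.step()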