@proger
Last active October 8, 2024 04:13
"""
DeltaNet implementation reference for Accelerated Scan. DeltaNet performs efficient management of a large fixed-size memory.
`forward` is inspired by Yang et al. 2024: it applies the single-chunk recurrence to every chunk in parallel and then performs chunk-level stitching.
`forward_loop` is the reference implementation of the original recurrence.
References:
[1] The WY Representation for Products of Householder Matrices (Bischof and Van Loan 1985)
Method 1, section 3 guides `decay_values`.
https://ecommons.cornell.edu/items/92a11030-dca1-45d4-a0ba-732cf962b2b2
[2] Parallelizing Linear Transformers with the Delta Rule over Sequence Length (Yang et al 2024)
- equation 5, a specialization of Method 1 of [1], is implemented in `decay_values`
- equation 6, the application of the decayed keys to the values, is also in `decay_values`
- `forward` uses the distributed form of equations 7 and 8
  (actually, look at the two equations just before them instead; they are easier to read)
https://arxiv.org/abs/2406.06484
[3] Linear Transformers Are Secretly Fast Weight Programmers (Schlag et al 2021)
Introduction to Transformers as RNNs. Ignore all of the kernel stuff.
https://arxiv.org/abs/2102.11174
"""
#%%
import os
os.environ['TORCH_LOGS'] = 'output_code'
import torch
from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile, arange
#set_float32_matmul_precision('high')
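
# A minimal single-step sketch of the delta rule recurrence that `forward_loop` below
# applies per head: the state S (a D x D fast-weight matrix) is nudged towards the new
# value v_t along key k_t with write strength beta_t. `delta_step` is a hypothetical
# helper added for illustration only; it is not called anywhere else in this file.
def delta_step(S, q_t, k_t, v_t, beta_t):
    "one step of w_t = w_{t-1} + beta_t (v_t - w_{t-1} k_t) k_t for a single head"
    v_old = S @ k_t                                 # value currently stored at key k_t
    S = S + beta_t * torch.outer(v_t - v_old, k_t)  # overwrite it towards v_t
    return S, S @ q_t                               # new state and this step's output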

def tileprint(K, name='K'):
    "format matches tileprint in tk code so you can diff it"
    assert K.shape == (16, 16)
    for laneid in range(32):
        row_top = laneid // 4
        row_bottom = row_top + 8
        col_left = laneid % 4 * 2
        col_right = col_left + 8
        def fmt(r,c,tag):
            odd = "y" in tag
            if odd: # do not print r for odd rows because cuda printf silently runs out of function arguments
                return f"{name}[,{c:02}] {tag}={K[r,c]: .3f}"
            else:
                return f"{name}[{r:02},{c:02}] {tag}={K[r,c]: .3f}"
        print(f"lane={laneid:02}", " ".join([
            " ".join([fmt(row_top, col_left, "0x"), fmt(row_top, col_left+1, "0y")]),
            " ".join([fmt(row_bottom, col_left, "1x"), fmt(row_bottom, col_left+1, "1y")]),
            " ".join([fmt(row_top, col_right, "2x"), fmt(row_top, col_right+1, "2y")]),
            " ".join([fmt(row_bottom, col_right, "3x"), fmt(row_bottom, col_right+1, "3y")])
        ]))
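# Hypothetical usage sketch: print a random 16x16 tile in the lane layout above so it can
# be diffed against a CUDA-side tileprint dump, e.g.
#   tileprint(randn(16, 16), name='K')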

def decay_values(q, k, v, beta, chunk_size=2):
    NH, T, D = shape(q, k, v, beta)
    C = T // chunk_size
    q_, k_, v_, beta_ = (
        q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D),
        v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )
    # evaluate all chunks in parallel
    beta__ = beta_.unsqueeze(-1)
    w = beta__ * k_.clone()
    u = beta__ * v_.clone()
    K = einsum('nsd,ntd->nst', k_, k_) # (chunk_size,chunk_size) matrix
    for t in range(1,chunk_size):
        w[:, t] -= beta__[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t].clone())
        u[:, t] -= beta__[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())
    # attend to decayed values
    qk = einsum("nsk,ntk->nst", q_, k_)
    qk.tril_()
    y = einsum("nst,ntj->nsj", qk, u)
    return w, u, y

def forward(q, k, v, beta, chunk_size=2):
    "decay values applying deltanet forgetting rules, then stitch chunks"
    NH, T, D = shape(q, k, v, beta)
    C = T // chunk_size
    w, u, y = decay_values(q, k, v, beta, chunk_size=chunk_size)
    # stitch chunks sequentially
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    y = y.view(NH, C, chunk_size, D)
    # materialize the state for the leading chunk
    kc = k_[:, 0]
    uc = u[:, 0]
    state = u.new_zeros(NH, D, D)
    for c in range(1, C):
        state = state + einsum('ntv,ntk->nvk', uc, kc)
        wc = w[:, c] # load w
        uc = einsum('ntk,nvk->ntv', wc, state) # DDT
        qc = q_[:, c] # load q
        kc = k_[:, c] # load k
        # attend to old values
        qk = einsum("nsi,nti->nst", qc, kc) # TDT
        qk = qk.tril()
        yc = y[:, c].clone() # load y
        y_prev = einsum("nst,ntv->nsv", qk, uc) # TTD
        yc = yc - y_prev
        y_cur = einsum('nsk,nvk->nsv', qc, state) # DDT
        yc = yc + y_cur
        y[:, c] = yc # store
        u1 = u[:, c] # load u
        uc = u1 - uc
    w = w.view(NH, T, D)
    u = u.view(NH, T, D)
    y = y.view(NH, T, D)
    return w, u, y

def forward_loop(q, k, v, beta):
    "reference recurrence: w_t = w_{t-1} + beta_t (v_t - w_{t-1} k_t) k_t"
    NH, T, D = shape(q, k, v, beta)
    w = k.new_zeros(NH, D, D)
    y = []
    for t in range(T):
        q_ = q[:, t]
        k_ = k[:, t]
        v_ = v[:, t]
        beta_ = beta[:, t].unsqueeze(-1)
        v_old = einsum("nij,nj->ni", w, k_)
        delta = beta_ * (v_ - v_old)
        w = w + einsum("ni,nj->nij", delta, k_)
        y.append(einsum("nij,nj->ni", w, q_))
    return stack(y, dim=1)

def shape(q, k, v, beta=None):
    NH, T, D = (q if q is not None else k).shape
    if q is not None:
        assert q.shape == (NH, T, D)
    if v is not None:
        assert k.shape == v.shape
    if beta is not None:
        assert beta.shape == (NH, T)
    return NH, T, D

def make_example(NH, T, D, device='cpu', dtype=torch.float32):
    manual_seed(0)
    q = randn(NH, T, D, device=device, dtype=dtype) / D**0.5
    q.requires_grad_()
    k = randn(NH, T, D, device=device, dtype=dtype) / D**0.5
    k.requires_grad_()
    v = randn(NH, T, D, device=device, dtype=dtype) / D**0.5
    v.requires_grad_()
    beta = randn(NH, T, device=device, dtype=dtype).sigmoid()
    beta.requires_grad_()
    return q, k, v, beta

@no_grad()
def backward(d_out_w_long, d_out_u_long, d_out_y_long, q_long, k_long, v_long, beta_long, chunk_size=2):
    NH, T, D = shape(q_long, k_long, v_long, beta_long)
    C = T // chunk_size
    q, k, v, beta, d_out_y = (
        q_long.view(NH*C, chunk_size, D), k_long.view(NH*C, chunk_size, D),
        v_long.view(NH*C, chunk_size, D), beta_long.view(NH*C, chunk_size),
        d_out_y_long.view(NH*C, chunk_size, D)
    )
    #
    # allocations
    #
    # this group is loaded from global memory
    q = q.clone() # load q
    k = k.clone() # load k
    v = v.clone() # load v
    beta = beta.clone() # load beta
    #d_out_w = d_out_w.clone() # ntk # placeholders
    #d_out_y = d_out_y.clone() # ntv # placeholders
    w = k.new_zeros(NH*C, chunk_size, D) # ntk
    u = v.new_zeros(NH*C, chunk_size, D) # ntw
    w_bases = w.clone() # ntk
    u_bases = u.clone() # ntw
    bk = einsum('nt,ntk->ntk', beta, k)
    bKl = k.new_zeros(NH*C, chunk_size, chunk_size)
    tt = k.new_zeros(NH*C, chunk_size, chunk_size)
    d_k = k.new_zeros(NH*C, chunk_size, D) # nsk
    tk = k.new_zeros(NH*C, chunk_size, D) # ntk
    #
    # forward
    #
    tt = einsum('ntk,nsk->nts', k, k)
    tt = tt.tril(diagonal=-1) # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', beta, tt) # multiply each row of K by beta
    u_bases = v
    v = einsum('nt,ntw->ntw', beta, v)
    for t in range(chunk_size):
        tk = einsum('nts,nsk->ntk', bKl, w) # matmul for the sake of one row
        w[:, t] = bk[:, t, :] - tk[:, t, :]
        tk = einsum('nts,nsw->ntw', bKl, u) # matmul for the sake of one row
        u[:, t] = v[:, t, :] - tk[:, t, :]
    w.clone() # store w
    u.clone() # store u
    #
    # stitch_backward
    #
    w_long = w.view(NH, T, D)
    u_long = u.view(NH, T, D)
    d_q_1_long, d_k_1_long, d_out_w_long, d_out_u_long = stitch_backward(d_out_y_long, q_long, k_long, w_long, u_long, C, chunk_size)
    d_out_w, d_out_u = (
        d_out_w_long.view(NH*C, chunk_size, D), d_out_u_long.view(NH*C, chunk_size, D)
    )
    w_bases = einsum('nts,nsk->ntk', tt, w)
    w_bases = k - w_bases
    v = einsum('nts,nsw->ntw', tt, u)
    u_bases = u_bases - v
    #
    # causal_attend_backward for d_q, d_k_2, d_out_u
    #
    tt = einsum('nsv,ntv->nst', d_out_y, u)
    tt = tt.tril()
    d_q = einsum('nst,ntk->nsk', tt, k)
    d_q.clone() # store
    d_k_2 = einsum('nst,nsk->ntk', tt, q)
    d_k_2.clone() # store to shared memory?
    tt = einsum('nsk,ntk->nst', q, k)
    tt = tt.tril()
    v.zero_() # reuse register space of v for d_out_u
    d_out_u = d_out_u.clone() # load ntw
    d_out_u += einsum('nst,nsv->ntv', tt, d_out_y)
    #
    # backward for d_k, d_v, d_beta
    #
    d_k.zero_()
    for t in range(chunk_size-1,-1,-1):
        # d_k
        tt = einsum('njw,ntw->njt', w, d_out_w) # matmul for the sake of one column t
        tt[:, t:, :] = 0
        tk = einsum('njt,njk->ntk', tt, k)
        tt = einsum('njv,ntv->njt', u, d_out_u) # matmul for the sake of one column t
        tt[:, t:, :] = 0
        tk += einsum('njt,njk->ntk', tt, k)
        d_k[:, t] += tk[:, t]
        # backpropagate through time, updating only remaining timestamps
        tt.zero_()
        tt[:, t] += bKl[:, t]
        tk = einsum('ntj,ntk->njk', tt, d_out_w)
        d_out_w = d_out_w - tk
        tk = einsum('ntj,ntk->njk', tt, d_out_u)
        d_out_u = d_out_u - tk
    d_k = d_out_w - d_k
    d_k = einsum('ntk,nt->ntk', d_k, beta)
    # decay w and u
    tt = einsum('ntw,njw->ntj', d_out_w, w)
    tt += einsum('ntw,njw->ntj', d_out_u, u)
    tt.tril_(diagonal=-1)
    tk = einsum('ntj,ntk->njk', tt, bk)
    d_k = d_k - tk
    d_k_2 = d_k_2.clone() # load from shared memory
    d_k = d_k_2 + d_k
    d_k = d_k.clone() # store
    # d_beta
    w_bases = einsum('ntk,ntk->ntk', w_bases, d_out_w)
    u_bases = einsum('ntw,ntw->ntw', u_bases, d_out_u)
    # d_v using d_out_u register
    d_out_u = einsum('nt,ntv->ntv', beta, d_out_u)
    d_v = d_out_u.clone() # store
    # continue d_beta reusing the beta register
    beta = einsum('ntk->nt', w_bases)
    beta += einsum('ntv->nt', u_bases)
    d_beta = beta.clone() # store
    d_q_long = d_q.view(NH, T, D) + d_q_1_long
    d_k_long = d_k.view(NH, T, D) + d_k_1_long
    d_v_long = d_v.view(NH, T, D)
    d_beta_long = d_beta.view(NH, T)
    return d_q_long, d_k_long, d_v_long, d_beta_long

def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)
    # outputs
    d_q_ = q.new_zeros(NH, C, chunk_size, D)
    d_k_ = k.new_zeros(NH, C, chunk_size, D)
    d_w = w.new_zeros(NH, C, chunk_size, D)
    d_u = u.new_zeros(NH, C, chunk_size, D)
    # chunked inputs
    d_y_delta = d_y_delta.view(NH, C, chunk_size, D)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    # shared memory copy
    u = u.view(NH, C, chunk_size, D).clone()
    state = w.new_zeros(NH, D, D)
    d_state = w.new_zeros(NH, D, D) # NHVK
    state_delta = w.new_zeros(NH, D, D) # NHVK # can this be float32?
    qk = k.new_zeros(NH, chunk_size, C)
    tk = k.new_zeros(NH, chunk_size, D)
    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    # stitch forward
    for c in range(1, C):
        tk = einsum('nvk,ntk->ntv', state, w[:, c])
        u[:, c] = u[:, c] - tk
        state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
        if c < C-1:
            state = state + state_delta # walk the state forwards
    # from now on, u's are decayed
    # stitch backward
    for c in range(C-1, 0, -1):
        if c < C-1:
            state_delta = einsum('ntv,ntk->nvk', u[:, c], k_[:, c])
            state = state - state_delta # uncompute the state backwards
        tk = einsum('nvk,ntk->ntv', state, w[:, c]) # state_decay
        d_y_delta_c = d_y_delta[:, c]
        d_y_delta_c = -d_y_delta_c # neg
        # d_q, d_k
        qk = einsum('nsv,ntv->nst', d_y_delta_c, tk)
        qk.tril_()
        # d_q
        tk = einsum('nst,ntk->nsk', qk, k_[:, c]) # causal_attend_backward for delta
        tk.sub_(einsum('nsv,nvk->nsk', d_y_delta_c, state)) # prev_output
        d_q_[:, c] = tk
        # d_k
        tk = einsum('nst,nsk->ntk', qk, q_[:, c])
        if c < C-1:
            tk.add_(einsum('nvk,ntv->ntk', d_state, u[:, c])) # state_add
        else:
            # d_state is zero
            pass
        d_k_[:, c] = tk
        # d_u
        if c < C-1:
            d_u[:, c] = einsum('nvk,ntk->ntv', d_state, k_[:, c]) # state_add
        else:
            # d_state is zero
            pass
        # d_state_decays
        qk = einsum('nsk,ntk->nst', q_[:, c], k_[:, c])
        qk.tril_()
        d_state_decays = einsum('nsv,nst->ntv', d_y_delta_c, qk)
        if c < C-1:
            d_state_decays.sub_(einsum('nvk,ntk->ntv', d_state, k_[:, c])) # state_add
        # d_w
        tk = einsum('ntv,nvk->ntk', d_state_decays, state)
        d_w[:, c] = tk # state_decays
        # backpropagate through time
        d_state.sub_(einsum('nsv,nsk->nvk', d_y_delta_c, q_[:, c])) # prev_output
        d_state.add_(einsum('ntv,ntk->nvk', d_state_decays, w[:, c])) # state_decays
    tk = einsum('nvk,ntk->ntv', d_state, k_[:, 0])
    d_u[:, 0] = tk # state_add
    tk = einsum('nvk,ntv->ntk', d_state, u[:, 0])
    d_k_[:, 0] = tk # state_add
    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)

class Delta(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
        w, u, y = forward(q, k, v, beta, chunk_size)
        ctx.save_for_backward(q, k, v, beta)
        ctx.chunk_size = chunk_size
        return y

    @staticmethod
    def backward(ctx, d_y):
        q, k, v, beta = ctx.saved_tensors
        NH, T, D = shape(q, k, v, beta)
        d_w = k.new_zeros(NH, T, D)
        d_u = v.new_zeros(NH, T, D)
        d_q, d_k, d_v, d_beta = backward(d_w, d_u, d_y, q, k, v, beta, chunk_size=ctx.chunk_size)
        return d_q, d_k, d_v, d_beta, None

def test_delta():
    NH, T, D = 1, 64, 16
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y0 = forward_loop(q1, k1, v1, beta1)
    chunk_size = 8
    w1, u1, y1 = forward(q1, k1, v1, beta1, chunk_size=chunk_size)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()
    assert allclose(y0, y1, atol=1e-5), 'y1 is wrong'
    q, k, v, beta = make_example(NH, T, D)
    y = Delta.apply(q, k, v, beta, chunk_size)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()
    assert allclose(y1, y, atol=1e-5), 'y is wrong'
    # print(beta1.grad - beta.grad, 'beta.grad diff')
    # print(q1.grad - q.grad, 'q.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')
    assert allclose(q1.grad, q.grad, atol=1e-5), 'q.grad is wrong'
    assert allclose(beta1.grad, beta.grad, atol=1e-5), 'beta.grad is wrong'
    assert allclose(k1.grad, k.grad, atol=1e-5), 'k.grad is wrong'
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'

if __name__ == '__main__':
    test_delta()
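# Minimal usage sketch (assumes shapes as in make_example; chunk_size must divide T
# because of the chunked .view calls):
#   q, k, v, beta = make_example(NH=1, T=64, D=16)
#   y = Delta.apply(q, k, v, beta, 8)  # y has shape (NH, T, D)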