"""
DeltaNet implementation reference for Accelerated Scan. DeltaNet performs efficient management of a large fixed-size memory.
For a simple single-chunk version, see `forward_simple`.
It computes decayed values with a short recurrence (`decay_values`)
and then applies linear attention (`causal_attend`).
`forward_chunkwise` is inspired by Yang 2024. It applies the single-chunk version pointwise and
then performs chunk-level stitching.
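The single-chunk and chunkwise descriptions above can be checked against the plain delta-rule recurrence with a small NumPy sketch. It assumes the standard DeltaNet update S_t = S_{t-1} + beta_t (v_t - S_{t-1} k_t) k_t^T with output o_t = S_t q_t. Writing the state as a sum of outer products u_i k_i^T turns the recurrence into a forward substitution for the "decayed" values u (our reading of `decay_values`), after which the output is causal linear attention over u. The names `delta_rule_loop`, `pseudo_values`, `causal_linear_attention` and `chunkwise` are hypothetical; this is not the gist's code.

```python
import numpy as np


def delta_rule_loop(q, k, v, beta):
    """Reference recurrence: S_t = S_{t-1} + beta_t (v_t - S_{t-1} k_t) k_t^T, o_t = S_t q_t."""
    S = np.zeros((v.shape[1], k.shape[1]))  # state maps keys to values
    out = np.zeros((q.shape[0], v.shape[1]))
    for t in range(q.shape[0]):
        S = S + beta[t] * np.outer(v[t] - S @ k[t], k[t])
        out[t] = S @ q[t]
    return out


def pseudo_values(k, v, beta):
    """Write S_t = sum_{i<=t} u_i k_i^T and solve for the rows u_t by forward substitution."""
    gram = k @ k.T  # gram[i, t] = k_i . k_t
    u = np.zeros_like(v)
    for t in range(k.shape[0]):
        # u_t = beta_t (v_t - S_{t-1} k_t), where S_{t-1} k_t = sum_{i<t} u_i (k_i . k_t)
        u[t] = beta[t] * (v[t] - gram[:t, t] @ u[:t])
    return u


def causal_linear_attention(q, k, u):
    """o_t = sum_{i<=t} (q_t . k_i) u_i, i.e. linear attention with a causal mask."""
    return np.tril(q @ k.T) @ u


def chunkwise(q, k, v, beta, chunk=4):
    """Apply the single-chunk form to each chunk and carry the state S across chunk boundaries."""
    S = np.zeros((v.shape[1], k.shape[1]))
    outs = []
    for s in range(0, q.shape[0], chunk):
        qc, kc, vc, bc = q[s:s + chunk], k[s:s + chunk], v[s:s + chunk], beta[s:s + chunk]
        # Subtracting the incoming state's prediction S k_t reduces the chunk to a zero-state problem.
        u = pseudo_values(kc, vc - kc @ S.T, bc)
        outs.append(qc @ S.T + causal_linear_attention(qc, kc, u))
        S = S + u.T @ kc  # state after the last step of the chunk
    return np.concatenate(outs)


rng = np.random.default_rng(0)
T, D = 16, 8
q, k, v = (rng.standard_normal((T, D)) for _ in range(3))
k /= np.linalg.norm(k, axis=1, keepdims=True)  # unit-norm keys keep each update non-expansive
beta = rng.uniform(0.0, 1.0, T)
ref = delta_rule_loop(q, k, v, beta)
print(np.allclose(ref, causal_linear_attention(q, k, pseudo_values(k, v, beta))))
print(np.allclose(ref, chunkwise(q, k, v, beta, chunk=5)))
```

The chunk size here only changes how much work is done as dense matrix products versus the sequential loop; a ragged final chunk (16 steps with `chunk=5`) is handled by slicing.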

digit 10076 27
digit 10154 2011
digit 1017 4
digit 10191 33
digit 1025 5
digit 10353 31
digit 10389 2008
digit 10411 120
digit 10607 01
digit 10858 195

"""
Randomized Binary Search Trees
https://www.cs.upc.edu/~conrado/research/papers/jacm-mr98.pdf
"""
import math
import random
from collections import Counter
class root:
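The linked paper builds randomized BSTs by making a newly inserted key the root of a subtree of size n with probability 1/(n+1), and otherwise recursing into the appropriate child; insertion at the root uses a split. A minimal sketch of that insertion rule, assuming distinct keys and subtree sizes stored in each node. `Node`, `split`, `insert` and `inorder` are hypothetical names, not the gist's `root` class.

```python
import random


class Node:
    """Hypothetical node type: key, children, and subtree size."""

    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.size = 1


def size(t):
    return t.size if t is not None else 0


def update(t):
    t.size = 1 + size(t.left) + size(t.right)
    return t


def split(t, key):
    """Split t into (keys < key, keys >= key)."""
    if t is None:
        return None, None
    if key <= t.key:
        less, t.left = split(t.left, key)
        return less, update(t)
    t.right, more = split(t.right, key)
    return update(t), more


def insert(t, key, rng=random):
    """Insert key; it becomes the root of this subtree with probability 1/(size(t) + 1)."""
    if rng.randrange(size(t) + 1) == 0:  # always true for an empty subtree
        new_root = Node(key)
        new_root.left, new_root.right = split(t, key)
        return update(new_root)
    if key < t.key:
        t.left = insert(t.left, key, rng)
    else:
        t.right = insert(t.right, key, rng)
    return update(t)


def inorder(t):
    return inorder(t.left) + [t.key] + inorder(t.right) if t is not None else []


rng = random.Random(0)
keys = list(range(100))
rng.shuffle(keys)
tree = None
for key in keys:
    tree = insert(tree, key, rng)
print(inorder(tree) == list(range(100)), size(tree))
```

Because the root decision depends only on subtree sizes, the resulting tree has the same distribution as a BST built from a uniformly random insertion order, whatever order the keys actually arrive in.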

"""
DeltaNet implementation reference for Accelerated Scan. DeltaNet performs efficient management of a large fixed-size memory.
`forward` is inspired by Yang 2024. It applies the single-chunk version pointwise and then performs chunk-level stitching.
`forward_loop` is the reference implementation of the original recurrence.
References:
[1] The WY Representation for Products of Householder Matrices (Bischof and Van Loan 1985)
Method 1 in section 3 guides `decay_values`.
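The connection to [1] can be illustrated numerically: the delta rule multiplies the state by factors (I - beta_t k_t k_t^T), which are Householder reflections when beta_t = 2 and k_t has unit norm, and their product can be stored compactly as I - W^T K, with the rows of W obtained by forward substitution. This NumPy sketch only checks that identity; `wy_factor` and `explicit_product` are hypothetical names, and it is not claimed to match how `decay_values` is written.

```python
import numpy as np


def wy_factor(k, beta):
    """Rows w_t with prod_t (I - beta_t k_t k_t^T) = I - sum_t w_t k_t^T = I - W.T @ K.

    Forward substitution: w_t = beta_t * (k_t - sum_{i<t} w_i (k_i . k_t)).
    """
    gram = k @ k.T  # gram[i, t] = k_i . k_t
    w = np.zeros_like(k)
    for t in range(k.shape[0]):
        w[t] = beta[t] * (k[t] - gram[:t, t] @ w[:t])
    return w


def explicit_product(k, beta):
    """Multiply the factors left to right: (I - b_1 k_1 k_1^T) ... (I - b_T k_T k_T^T)."""
    eye = np.eye(k.shape[1])
    P = eye.copy()
    for k_t, b_t in zip(k, beta):
        P = P @ (eye - b_t * np.outer(k_t, k_t))
    return P


rng = np.random.default_rng(0)
k = rng.standard_normal((6, 4))
k /= np.linalg.norm(k, axis=1, keepdims=True)
beta = rng.uniform(0.0, 1.0, 6)
P_wy = np.eye(4) - wy_factor(k, beta).T @ k
print(np.allclose(P_wy, explicit_product(k, beta)))
```

The WY form needs only the T x D matrices W and K instead of a D x D matrix per step, and building W touches the keys only through their Gram matrix.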

// uses https://github.com/HazyResearch/ThunderKittens
#include "tk/src/kittens.cuh"
#include "tk/src/common/pyutils/torch_helpers.cuh"
#define NUM_WORKERS 2 // This kernel uses this many workers in parallel per block, to help issue instructions more quickly.
#define DIMENSION 64 // This kernel operates over 64-dimensional vectors
#define DEBUG 0
using namespace kittens; // this kernel only handles headdim=q_reg.cols for simplicity. Also n should be a multiple of 256 here.

"a linear RNN that receives ones as input and gives increasingly better approximations to pi as output"
import numpy as np
import math
def binary(digits: int):
    "Make a basis of powers of two of dimension `digits`, lowest bits first"
    return 1 << np.arange(digits)
def leibniz(n):
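One way to build such an RNN, not necessarily the gist's (it ignores the `binary` basis above): since 1/(2k+1) is the integral of s^(2k) over [0, 1], a Gauss-Legendre quadrature with nodes s_j turns each Leibniz term (-1)^k/(2k+1) into a sum of decaying exponentials (-s_j^2)^k. A diagonal linear RNN with decays a_j = -s_j^2, a constant input of one and readout 4 w_j then outputs the Leibniz partial sums (exactly for the first `nodes` steps) and converges to 4 times the integral of 1/(1+s^2), i.e. pi. `leibniz_rnn` is a hypothetical name.

```python
import numpy as np


def leibniz_rnn(steps, nodes=16):
    """Diagonal linear RNN h_t = a * h_{t-1} + x_t with x_t = 1, read out as y_t = c @ h_t.

    Uses 1/(2k+1) = integral_0^1 s^(2k) ds, approximated by Gauss-Legendre quadrature,
    so channel j with decay a_j = -s_j**2 contributes w_j * (-1)^k * s_j^(2k) at lag k.
    """
    s, w = np.polynomial.legendre.leggauss(nodes)  # nodes and weights on [-1, 1]
    s, w = (s + 1) / 2, w / 2                      # map to [0, 1]
    a = -s**2                                      # per-channel decay
    c = 4 * w                                      # readout weights
    h = np.zeros(nodes)
    ys = np.empty(steps)
    for t in range(steps):
        h = a * h + 1.0  # the input is always one
        ys[t] = c @ h
    return ys


ys = leibniz_rnn(5000)
print(ys[[0, 1, 2, 9, 99, 4999]] - np.pi)  # error after 1, 2, 3, 10, 100 and 5000 steps
```

Quadrature with n nodes is exact for polynomials up to degree 2n - 1, which is why the first `nodes` outputs coincide with the Leibniz partial sums; after that, each channel decays geometrically, so convergence is faster than the series itself.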

#%%
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
plt.rcParams['axes.spines.left'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False

#%%
from collections import defaultdict
import bisect
import json
import matplotlib.pyplot as plt
import torch
from matplotlib import rcParams
rcParams['font.family'] = 'serif'

# prompt: https://twitter.com/francoisfleuret/status/1783479122418716805
import os
os.environ['TORCH_LOGS'] = 'output_code' # shows all the bmms
import torch
torch.set_float32_matmul_precision('high')
N, T, D, U, C = 3, 128, 5, 32, 32 # batch, time, heads, head_dim, dim
S = T
A = torch.randn(N, T, D, U) / U**0.5

#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
X = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]]).float()
y = torch.logical_xor(X[:, 0], X[:, 1]).float()
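XOR is the standard example of targets that no linear function of the inputs can fit, which is what makes it a common first test for a nonlinear model. A small NumPy sketch, independent of the PyTorch setup above, shows this with ordinary least squares: on the features [1, x0, x1] every prediction collapses to 0.5, while adding the product feature x0*x1 fits exactly via y = x0 + x1 - 2*x0*x1. The helper `least_squares_fit` is a hypothetical name.

```python
import numpy as np

X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = np.logical_xor(X[:, 0], X[:, 1]).astype(float)


def least_squares_fit(features, targets):
    """Ordinary least squares; returns coefficients and fitted values."""
    coef, *_ = np.linalg.lstsq(features, targets, rcond=None)
    return coef, features @ coef


linear = np.hstack([np.ones((4, 1)), X])                  # [1, x0, x1]
with_product = np.hstack([linear, X[:, :1] * X[:, 1:]])   # [1, x0, x1, x0*x1]

coef_lin, pred_lin = least_squares_fit(linear, y)
coef_prod, pred_prod = least_squares_fit(with_product, y)
print(pred_lin)   # approximately 0.5 for every row
print(coef_prod)  # approximately [0, 1, 1, -2]
```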