Last active
February 28, 2024 22:08
-
-
Save rrampage/c2fe7a585c6639163eeaca749df4cac7 to your computer and use it in GitHub Desktop.
ARENA exercises (chapter 0, fundamentals): backpropagation from scratch (part 4) and CNNs/ResNets (part 2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
# Setup cell for the backprop exercises (cells are delimited by `# %%`).
import os
# Set before `torch`/`numpy` are imported below. Presumably a workaround for the
# duplicate OpenMP runtime error some platforms hit when loading both; confirm if needed.
os.environ["KMP_DUPLICATE_LIB_OK"] ="TRUE"
import sys
import re
import time
import torch as t
import numpy as np
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Iterable, Optional, Union, Dict, List, Tuple
from torch.utils.data import DataLoader
from tqdm import tqdm
# Short alias for NumPy arrays, used in type hints throughout the file.
Arr = np.ndarray
# Global switch read by the forward functions: when False, no Recipe (graph) is recorded.
grad_tracking_enabled = True
# Make sure exercises are in the path
chapter = r"chapter0_fundamentals"
# Everything in the working directory's path before `chapter`, then <chapter>/exercises.
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part4_backprop"
# Must run before the project-local imports below, which are resolved relative to exercises_dir.
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))
import part4_backprop.tests as tests
from part4_backprop.utils import visualize, get_mnist
from plotly_utils import line
# %% | |
def log_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    """Backward pass for the elementwise natural log, f(x) = log(x).

    Because d/dx log(x) = 1 / x, the chain rule gives dL/dx = dL/dout / x.

    Args:
        grad_out: gradient of some loss with respect to ``out``.
        out: the forward result ``np.log(x)`` (unused; part of the shared backward signature).
        x: the forward input to ``np.log``.

    Returns:
        Gradient of the loss with respect to ``x``.
    """
    grad_x = grad_out / x
    return grad_x
# Run the exercise's provided check for log_back.
tests.test_log_back(log_back)
# %% | |
def unbroadcast(broadcasted: Arr, original: Arr) -> Arr:
    """Sum ``broadcasted`` back down to the shape of ``original``.

    ``broadcasted`` is expected to have the shape that NumPy broadcasting produced
    from an array shaped like ``original``. A gradient flowing back through a
    broadcast must be summed over every axis that broadcasting created or stretched.

    Args:
        broadcasted: array with possibly extra leading axes and/or stretched size-1 axes.
        original: array whose shape the result must match.

    Returns:
        ``broadcasted`` reduced by summation to ``original.shape``.
    """
    # Leading axes that broadcasting prepended are summed away entirely.
    extra_leading = broadcasted.ndim - original.ndim
    reduced = broadcasted.sum(axis=tuple(range(extra_leading)))

    # Axes that were length 1 in `original` but got stretched are summed while
    # keeping the (now length-1) axis in place.
    stretched_axes = tuple(
        axis
        for axis, (orig_len, new_len) in enumerate(zip(original.shape, reduced.shape))
        if orig_len == 1 and new_len > 1
    )
    return reduced.sum(axis=stretched_axes, keepdims=True)
# Run the exercise's provided check for unbroadcast.
tests.test_unbroadcast(unbroadcast)
# %% | |
def multiply_back0(grad_out: Arr, out: Arr, x: Arr, y: Union[Arr, float]) -> Arr:
    '''Backwards function for x * y wrt argument 0 aka x.

    d(x * y)/dx = y, so the upstream gradient is multiplied elementwise by ``y``
    and then summed back to ``x``'s shape (undoing any broadcasting done in the
    forward pass).

    Args:
        grad_out: gradient of the loss with respect to ``out``.
        out: forward result ``x * y`` (unused; part of the shared backward signature).
        x: forward argument 0; only its shape is used.
        y: forward argument 1; may be a Python scalar.

    Returns:
        Gradient of the loss with respect to ``x``, shaped like ``x``.
    '''
    if not isinstance(y, Arr):
        y = np.array(y)
    return unbroadcast(y * grad_out, x)
def multiply_back1(grad_out: Arr, out: Arr, x: Union[Arr, float], y: Arr) -> Arr:
    '''Backwards function for x * y wrt argument 1 aka y.

    d(x * y)/dy = x, so the upstream gradient is multiplied elementwise by ``x``
    and then summed back to ``y``'s shape (undoing any broadcasting done in the
    forward pass).

    Args:
        grad_out: gradient of the loss with respect to ``out``.
        out: forward result ``x * y`` (unused; part of the shared backward signature).
        x: forward argument 0; may be a Python scalar.
        y: forward argument 1; only its shape is used.

    Returns:
        Gradient of the loss with respect to ``y``, shaped like ``y``.
    '''
    if not isinstance(x, Arr):
        x = np.array(x)
    return unbroadcast(x * grad_out, y)
# Run the exercise's provided checks for both multiply backward functions
# (array operands, then a Python-float operand).
tests.test_multiply_back(multiply_back0, multiply_back1)
tests.test_multiply_back_float(multiply_back0, multiply_back1)
# %% | |
def forward_and_back(a: Arr, b: Arr, c: Arr) -> Tuple[Arr, Arr, Arr]:
    '''
    Evaluate a small computational graph, then backpropagate through it by hand.

    The graph is::

        d = a * b
        e = log(c)
        f = d * e
        g = log(f)

    Returns:
        (dg/da, dg/db, dg/dc), each shaped like the corresponding input.
    '''
    # Forward pass.
    d = a * b
    e = np.log(c)
    f = d * e
    g = np.log(f)

    # Backward pass, seeded with dg/dg = 1 and walking the graph in reverse.
    seed = np.ones_like(g)
    grad_f = log_back(seed, g, f)
    grad_d = multiply_back0(grad_f, f, d, e)
    grad_e = multiply_back1(grad_f, f, d, e)
    grad_c = log_back(grad_e, e, c)
    grad_a = multiply_back0(grad_d, d, a, b)
    grad_b = multiply_back1(grad_d, d, a, b)
    return grad_a, grad_b, grad_c
# Run the exercise's provided check for the hand-written forward/backward pass.
tests.test_forward_and_back(forward_and_back)
# %% | |
@dataclass(frozen=True)
class Recipe:
    '''Extra information necessary to run backpropagation. You don't need to modify this.

    Attributes:
        func: the 'inner' NumPy function that did the actual forward computation
            (called 'inner' to distinguish it from the Tensor wrapper built around it).
        args: positional arguments that were passed to ``func``; e.g. if ``func`` was
            ``np.sum``, this is a length-1 tuple holding the array that was summed.
        kwargs: keyword arguments that were passed to ``func``; e.g. for ``np.sum``
            these might hold the axis and keepdims options.
        parents: map from positional-argument index to the Tensor at that position,
            used to pass gradients back along the computational graph.
    '''

    func: Callable
    args: tuple
    kwargs: Dict[str, Any]
    parents: Dict[int, "Tensor"]
# %% | |
class BackwardFuncLookup:
    """Registry mapping (forward function, argument position) to a backward function."""

    def __init__(self) -> None:
        # Keys are (forward_fn, arg_position) tuples; values are backward callables.
        self.data = dict()

    def add_back_func(self, forward_fn: Callable, arg_position: int, back_fn: Callable) -> None:
        """Register ``back_fn`` as the gradient of ``forward_fn`` w.r.t. argument ``arg_position``."""
        key = (forward_fn, arg_position)
        self.data[key] = back_fn

    def get_back_func(self, forward_fn: Callable, arg_position: int) -> Callable:
        """Return the registered backward function; raises KeyError if none was registered."""
        key = (forward_fn, arg_position)
        return self.data[key]
# Global registry that `backprop` consults to find the gradient function for each
# (forward NumPy function, argument index) pair.
BACK_FUNCS = BackwardFuncLookup()
BACK_FUNCS.add_back_func(np.log, 0, log_back)
BACK_FUNCS.add_back_func(np.multiply, 0, multiply_back0)
BACK_FUNCS.add_back_func(np.multiply, 1, multiply_back1)
# Sanity-check that each registration round-trips.
assert BACK_FUNCS.get_back_func(np.log, 0) == log_back
assert BACK_FUNCS.get_back_func(np.multiply, 0) == multiply_back0
assert BACK_FUNCS.get_back_func(np.multiply, 1) == multiply_back1
print("Tests passed - BackwardFuncLookup class is working as expected!")
# %% | |
# Re-binds the array alias from the top of the file (presumably so this `# %%`
# cell can be executed on its own).
Arr = np.ndarray
class Tensor:
    '''
    A drop-in replacement for torch.Tensor supporting a subset of features.

    Wraps a NumPy array. Operators and most methods forward to module-level
    functions (e.g. `add`, `multiply`, `log`, `sum`, `permute`). These are looked up
    only when the method is called, so they can be defined later in the file.
    '''
    array: Arr
    "The underlying array. Can be shared between multiple Tensors."
    requires_grad: bool
    "If True, calling functions or methods on this tensor will track relevant data for backprop."
    grad: Optional["Tensor"]
    "Backpropagation will accumulate gradients into this field."
    recipe: Optional[Recipe]
    "Extra information necessary to run backpropagation."

    def __init__(self, array: Union[Arr, list], requires_grad=False):
        # Non-ndarray input (e.g. a nested list) goes through np.array.
        # float64 data is downcast to float32, which makes a copy.
        # Any other dtype is stored unchanged; an ndarray of another dtype is kept
        # as-is, without copying.
        self.array = array if isinstance(array, Arr) else np.array(array)
        if self.array.dtype == np.float64:
            self.array = self.array.astype(np.float32)
        self.requires_grad = requires_grad
        self.grad = None
        self.recipe = None
        "If not None, this tensor's array was created via recipe.func(*recipe.args, **recipe.kwargs)."

    # ---- Arithmetic operators: each forwards to a module-level function. ----
    # Reflected (__r*__) variants keep the operand order of the written expression,
    # e.g. `2 - x` calls subtract(2, x).
    def __neg__(self) -> "Tensor":
        return negative(self)

    def __add__(self, other) -> "Tensor":
        return add(self, other)

    def __radd__(self, other) -> "Tensor":
        return add(other, self)

    def __sub__(self, other) -> "Tensor":
        return subtract(self, other)

    def __rsub__(self, other):
        return subtract(other, self)

    def __mul__(self, other) -> "Tensor":
        return multiply(self, other)

    def __rmul__(self, other) -> "Tensor":
        return multiply(other, self)

    def __truediv__(self, other) -> "Tensor":
        return true_divide(self, other)

    def __rtruediv__(self, other) -> "Tensor":
        return true_divide(other, self)

    def __matmul__(self, other) -> "Tensor":
        return matmul(self, other)

    def __rmatmul__(self, other) -> "Tensor":
        return matmul(other, self)

    def __eq__(self, other) -> "Tensor":
        # Elementwise comparison: returns a Tensor, not a bool.
        return eq(self, other)

    def __repr__(self) -> str:
        return f"Tensor({repr(self.array)}, requires_grad={self.requires_grad})"

    def __len__(self) -> int:
        # Length of the first axis; a 0-d tensor has no length.
        if self.array.ndim == 0:
            raise TypeError
        return self.array.shape[0]

    def __hash__(self) -> int:
        # Identity-based hash. Defining __eq__ alone would make instances unhashable.
        # This lets Tensors act as dict keys and set members; backprop keys its
        # gradient dict by Tensor.
        return id(self)

    def __getitem__(self, index) -> "Tensor":
        return getitem(self, index)

    def add_(self, other: "Tensor", alpha: float = 1.0) -> "Tensor":
        # In-place add. Forwards to the module-level `add_` function, which is not
        # visible in this chunk (presumably it updates self.array; confirm).
        # Returns self.
        add_(self, other, alpha=alpha)
        return self

    @property
    def T(self) -> "Tensor":
        # Swap the last two axes.
        # NOTE(review): passes a 2-element `axes`. If `permute` forwards to
        # np.transpose, this only works for 2-D tensors; confirm for higher ranks.
        return permute(self, axes=(-1, -2))

    def item(self):
        # Python scalar from a single-element tensor (delegates to ndarray.item).
        return self.array.item()

    # ---- Methods forwarding to module-level Tensor functions. ----
    def sum(self, dim=None, keepdim=False):
        # `sum` here is the module-level wrapper (which shadows the builtin), not this method.
        return sum(self, dim=dim, keepdim=keepdim)

    def log(self):
        return log(self)

    def exp(self):
        return exp(self)

    def reshape(self, new_shape):
        return reshape(self, new_shape)

    def expand(self, new_shape):
        return expand(self, new_shape)

    def permute(self, dims):
        return permute(self, dims)

    def maximum(self, other):
        return maximum(self, other)

    def relu(self):
        return relu(self)

    def argmax(self, dim=None, keepdim=False):
        return argmax(self, dim=dim, keepdim=keepdim)

    def uniform_(self, low: float, high: float) -> "Tensor":
        # Fill the existing array in place with U(low, high) samples; returns self.
        self.array[:] = np.random.uniform(low, high, self.array.shape)
        return self

    def backward(self, end_grad: Union[Arr, "Tensor", None] = None) -> None:
        # Run backprop from this tensor. A raw ndarray end_grad is wrapped in a Tensor first.
        if isinstance(end_grad, Arr):
            end_grad = Tensor(end_grad)
        return backprop(self, end_grad)

    def size(self, dim: Optional[int] = None):
        # Like torch.Tensor.size: the full shape, or the length of axis `dim`.
        if dim is None:
            return self.shape
        return self.shape[dim]

    @property
    def shape(self):
        return self.array.shape

    @property
    def ndim(self):
        return self.array.ndim

    @property
    def is_leaf(self):
        '''Same as https://pytorch.org/docs/stable/generated/torch.Tensor.is_leaf.html'''
        # Non-leaf only when this tensor tracks grad AND was produced by an op
        # that recorded at least one Tensor parent.
        if self.requires_grad and self.recipe and self.recipe.parents:
            return False
        return True

    def __bool__(self):
        # Truthiness is only defined for tensors holding exactly one value.
        if np.array(self.shape).prod() != 1:
            raise RuntimeError("bool value of Tensor with more than one value is ambiguous")
        return bool(self.item())
def empty(*shape: int) -> Tensor:
    '''Like torch.empty: a Tensor with uninitialised contents and the given shape.'''
    uninitialised = np.empty(shape)
    return Tensor(array=uninitialised)
def zeros(*shape: int) -> Tensor:
    '''Like torch.zeros: a Tensor of the given shape filled with 0.'''
    filled = np.zeros(shape)
    return Tensor(array=filled)
def arange(start: int, end: int, step=1) -> Tensor:
    '''Like torch.arange(start, end, step): evenly spaced values in [start, end).'''
    values = np.arange(start, end, step=step)
    return Tensor(array=values)
def tensor(array: Arr, requires_grad=False) -> Tensor:
    '''Like torch.tensor: wrap `array` in a Tensor, optionally tracking gradients.'''
    wrapped = Tensor(array=array, requires_grad=requires_grad)
    return wrapped
# %% | |
def log_forward(x: Tensor) -> Tensor:
    '''Performs np.log on a Tensor object.

    A Recipe is attached to the result only when global grad tracking is on AND
    ``x`` either requires grad itself or was produced by a tracked operation.
    '''
    track = grad_tracking_enabled and (x.requires_grad or x.recipe is not None)
    result = Tensor(array=np.log(x.array), requires_grad=track)
    if not track:
        return result
    result.recipe = Recipe(func=np.log, args=(x.array,), kwargs={}, parents={0: x})
    return result
# Bind the Tensor-level `log`, then run the provided checks.
log = log_forward
tests.test_log(Tensor, log_forward)
tests.test_log_no_grad(Tensor, log_forward)

# With grad tracking globally disabled, no graph may be recorded even though the
# input requires grad. try/finally guarantees the global flag is restored if the
# forward call raises, so later cells are not silently left with tracking off.
a = Tensor([1], requires_grad=True)
grad_tracking_enabled = False
try:
    b = log_forward(a)
finally:
    grad_tracking_enabled = True
assert not b.requires_grad, "should not require grad if grad tracking globally disabled"
assert b.recipe is None, "should not create recipe if grad tracking globally disabled"
# %% | |
def multiply_forward(a: Union[Tensor, int], b: Union[Tensor, int]) -> Tensor:
    '''Performs np.multiply on a Tensor object.

    At least one of ``a`` and ``b`` must be a Tensor. A non-Tensor operand (a Python
    number) is used directly and is not recorded as a parent. A Recipe is attached
    only when global grad tracking is on and some Tensor operand either requires
    grad or was itself produced by a tracked operation.
    '''
    assert isinstance(a, Tensor) or isinstance(b, Tensor)

    # Unwrap Tensor operands to their arrays, remembering which positions held Tensors.
    parents = {}
    lhs, rhs = a, b
    if isinstance(a, Tensor):
        parents[0] = a
        lhs = a.array
    if isinstance(b, Tensor):
        parents[1] = b
        rhs = b.array

    # Put a Tensor's array on the left of `*`.
    product = lhs * rhs if 0 in parents else rhs * lhs

    needs_grad = grad_tracking_enabled and any(
        p.requires_grad or p.recipe is not None for p in parents.values()
    )
    out = Tensor(array=product, requires_grad=needs_grad)
    if needs_grad:
        out.recipe = Recipe(func=np.multiply, args=(lhs, rhs), kwargs={}, parents=parents)
    return out
# Bind the Tensor-level `multiply`, then run the provided checks.
multiply = multiply_forward
tests.test_multiply(Tensor, multiply_forward)
tests.test_multiply_no_grad(Tensor, multiply_forward)
tests.test_multiply_float(Tensor, multiply_forward)

# With grad tracking globally disabled, no graph may be recorded even though both
# inputs require grad. try/finally guarantees the global flag is restored if the
# forward call raises, so later cells are not silently left with tracking off.
a = Tensor([2], requires_grad=True)
b = Tensor([3], requires_grad=True)
grad_tracking_enabled = False
try:
    b = multiply_forward(a, b)
finally:
    grad_tracking_enabled = True
assert not b.requires_grad, "should not require grad if grad tracking globally disabled"
assert b.recipe is None, "should not create recipe if grad tracking globally disabled"
# %% | |
def wrap_forward_fn(numpy_func: Callable, is_differentiable=True) -> Callable:
    '''
    numpy_func: Callable
        takes any number of positional arguments, some of which may be NumPy arrays, and
        any number of keyword arguments which we aren't allowing to be NumPy arrays at
        present. It returns a single NumPy array.
    is_differentiable:
        if True, numpy_func is differentiable with respect to some input argument, so we
        may need to track information in a Recipe. If False, we definitely don't need to
        track information.
    Return: Callable
        It has the same signature as numpy_func, except wherever there was a NumPy array,
        this has a Tensor instead.

    The wrapper records a Recipe only if the op is differentiable, global grad
    tracking is on, and at least one positional Tensor argument is itself tracked
    (requires grad or has a recipe). Only positional Tensor arguments become parents.
    '''
    def tensor_func(*args: Any, **kwargs: Any) -> Tensor:
        req_grad = (
            is_differentiable
            and grad_tracking_enabled
            and any(
                isinstance(x, Tensor) and (x.requires_grad or x.recipe is not None)
                for x in args
            )
        )
        # Unwrap Tensors to their arrays; other positional arguments pass through unchanged.
        # Stored as a tuple to match Recipe.args (and the hand-written forward functions).
        in_args = tuple(a.array if isinstance(a, Tensor) else a for a in args)
        out = Tensor(array=numpy_func(*in_args, **kwargs), requires_grad=req_grad)
        if req_grad:
            parents = {idx: arg for idx, arg in enumerate(args) if isinstance(arg, Tensor)}
            out.recipe = Recipe(func=numpy_func, args=in_args, kwargs=kwargs, parents=parents)
        return out
    return tensor_func
def _sum(x: Arr, dim=None, keepdim=False) -> Arr: | |
# need to be careful with sum, because kwargs have different names in torch and numpy | |
return np.sum(x, axis=dim, keepdims=keepdim) | |
# Tensor-level versions of the NumPy ops, built with the generic wrapper.
# These rebind the earlier hand-written `log` / `multiply`.
log = wrap_forward_fn(np.log)
multiply = wrap_forward_fn(np.multiply)
# Equality has no gradient, so this wrapper never records a Recipe.
eq = wrap_forward_fn(np.equal, is_differentiable=False)
# NOTE: rebinding `sum` shadows Python's builtin sum() for the rest of this module.
sum = wrap_forward_fn(_sum)
tests.test_log(Tensor, log)
tests.test_log_no_grad(Tensor, log)
tests.test_multiply(Tensor, multiply)
tests.test_multiply_no_grad(Tensor, multiply)
tests.test_multiply_float(Tensor, multiply)
tests.test_sum(Tensor)
# %% | |
class Node:
    """Minimal graph node for exercising `topological_sort`; it just stores its children."""

    def __init__(self, *children):
        self.children = [*children]
def get_children(node: Node) -> List[Node]:
    """Return ``node``'s children list (the stored list itself, not a copy)."""
    kids = node.children
    return kids
def topological_sort(node: "Node", get_children: Callable) -> List["Node"]:
    '''
    Return a list of node's descendants in reverse topological order from future to past (i.e. `node` should be last).

    Every node appears after all of its children (DFS post-order, children visited
    in the order `get_children` yields them). The traversal is iterative rather than
    recursive, so very deep graphs (long chains of operations) cannot hit Python's
    recursion limit.

    Args:
        node: root node to start from.
        get_children: callable mapping a node to an iterable of its children.

    Raises:
        ValueError: if the graph reachable from `node` contains a cycle.
    '''
    ordered = []
    finished = set()   # nodes whose whole subtree has already been emitted
    on_path = {node}   # nodes on the current DFS path; meeting one again means a cycle
    stack = [(node, iter(get_children(node)))]
    while stack:
        current, pending = stack[-1]
        for child in pending:
            if child in finished:
                continue
            if child in on_path:
                raise ValueError("Cycle!!")
            on_path.add(child)
            stack.append((child, iter(get_children(child))))
            break
        else:
            # Every child of `current` is done, so emit it after its descendants.
            stack.pop()
            on_path.discard(current)
            finished.add(current)
            ordered.append(current)
    return ordered
# Run the exercise's provided checks: linear chain, branching, rejoining (diamond)
# graphs, and cycle detection.
tests.test_topological_sort_linked_list(topological_sort)
tests.test_topological_sort_branching(topological_sort)
tests.test_topological_sort_rejoining(topological_sort)
tests.test_topological_sort_cyclic(topological_sort)
# %% | |
def sorted_computational_graph(tensor: Tensor) -> List[Tensor]:
    '''
    For a given tensor, return a list of Tensors that make up the nodes of the given Tensor's computational graph,
    in reverse topological order (i.e. `tensor` should be first).
    '''
    def parents_of(node: Tensor) -> List[Tensor]:
        # Tensors that don't track grad, or have no recipe, end the traversal.
        if node.requires_grad and node.recipe is not None:
            return list(node.recipe.parents.values())
        return []

    past_to_future = topological_sort(tensor, parents_of)
    return list(reversed(past_to_future))
# Build the example graph g = log((a * b) * log(c)) and print its node order.
a, b, c = (Tensor([value], requires_grad=True) for value in (1, 2, 3))
d = a * b
e = c.log()
f = d * e
g = f.log()
name_lookup = dict(zip((a, b, c, d, e, f, g), "abcdefg"))
print([name_lookup[node] for node in sorted_computational_graph(g)])
# %% | |
def backprop(end_node: Tensor, end_grad: Optional[Tensor] = None) -> None:
    '''Accumulates gradients in the grad field of each leaf node.

    tensor.backward() is equivalent to backprop(tensor).

    end_node:
        The rightmost node in the computation graph.
        If it contains more than one element, end_grad should be provided.
    end_grad:
        A tensor of the same shape as end_node.
        If not specified, a gradient of ones with end_node's shape is used.
    '''
    end_grad_arr = np.ones_like(end_node.array) if end_grad is None else end_grad.array

    # Gradient of the final output w.r.t. every node reached so far.
    # Keyed by Tensor, which hashes by identity.
    grads: Dict[Tensor, Arr] = {end_node: end_grad_arr}

    # Nodes come children-before-parents, so a node's gradient is complete by the
    # time it is reached.
    for node in sorted_computational_graph(end_node):
        # The entry is final now and never needed again, so drop it to free memory.
        outgradient = grads.pop(node)

        if node.is_leaf and node.requires_grad:
            if node.grad is None:
                # Copy, so the stored grad never aliases the caller's end_grad or an
                # intermediate array; later `+=` accumulation must not mutate them.
                node.grad = Tensor(np.array(outgradient))
            else:
                node.grad.array += outgradient

        if node.recipe is None or node.recipe.parents is None:
            continue

        # Route the gradient to each Tensor parent via the registered backward function.
        for argnum, parent in node.recipe.parents.items():
            back_fn = BACK_FUNCS.get_back_func(node.recipe.func, argnum)
            in_grad = back_fn(outgradient, node.array, *node.recipe.args, **node.recipe.kwargs)
            # Accumulate out of place: either operand may be a view of another node's
            # gradient, and an in-place `+=` would silently corrupt it.
            if parent in grads:
                grads[parent] = grads[parent] + in_grad
            else:
                grads[parent] = in_grad
# Run the exercise's provided backprop checks: simple chain, branching, inputs
# that don't require grad, float (non-Tensor) arguments, and a shared parent.
tests.test_backprop(Tensor)
tests.test_backprop_branching(Tensor)
tests.test_backprop_requires_grad_false(Tensor)
tests.test_backprop_float_arg(Tensor)
tests.test_backprop_shared_parent(Tensor)
# %% | |
def negative_back(grad_out: Arr, out: Arr, x: Arr) -> Arr:
    '''Backward function for f(x) = -x elementwise.

    d(-x)/dx = -1 everywhere, so the upstream gradient is simply negated. No
    unbroadcasting is needed because np.negative is elementwise: `out` has x's shape.

    Args:
        grad_out: gradient of the loss with respect to ``out``.
        out: forward result ``-x`` (unused).
        x: forward input (unused).

    Returns:
        Gradient of the loss with respect to ``x``.
    '''
    return -grad_out
# Tensor-level negation (used by Tensor.__neg__) and its gradient registration.
negative = wrap_forward_fn(np.negative)
BACK_FUNCS.add_back_func(np.negative, 0, negative_back)
tests.test_negative_back(Tensor)
# %% | |
# True when this file is executed as a script rather than imported.
MAIN = __name__ == "__main__"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
import os | |
import sys | |
import numpy as np | |
import einops | |
from typing import Union, Optional, Tuple, List, Dict | |
import torch as t | |
from torch import Tensor | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from jaxtyping import Float, Int | |
import functools | |
from pathlib import Path | |
from torchvision import datasets, transforms, models | |
from torch.utils.data import DataLoader, Subset | |
from tqdm.notebook import tqdm | |
from dataclasses import dataclass | |
from PIL import Image | |
import json | |
# Make sure exercises are in the path
chapter = r"chapter0_fundamentals"
# Everything in the working directory's path before `chapter`, then <chapter>/exercises.
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part2_cnns"
# Must run before the project-local imports below, which are resolved relative to exercises_dir.
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))
from plotly_utils import imshow, line, bar
import part2_cnns.tests as tests
from part2_cnns.utils import print_param_count
# Pick the fastest available backend: CUDA GPU, then Apple-silicon MPS, then CPU.
# (Previously only MPS was probed, so CUDA machines silently trained on the CPU.)
if t.cuda.is_available():
    device = t.device("cuda")
elif t.backends.mps.is_available():
    device = t.device("mps")
else:
    device = t.device("cpu")
# %% | |
class ReLU(nn.Module):
    """Elementwise rectified linear unit: returns max(x, 0) for every element."""

    def forward(self, x: t.Tensor) -> t.Tensor:
        zeros = t.zeros_like(x)
        return t.maximum(x, zeros)
# Run the exercise's provided check for ReLU.
tests.test_relu(ReLU)
# %% | |
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias=True):
        '''
        A simple linear (technically, affine) transformation: y = x @ weight.T + bias.

        The fields are named `weight` (out_features x in_features) and `bias`
        (out_features) for compatibility with PyTorch. If `bias` is False, `self.bias`
        is None. Both are initialised uniformly in
        [-1/sqrt(in_features), 1/sqrt(in_features)].
        '''
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        bound = in_features ** -0.5
        self.weight = nn.Parameter(t.zeros([out_features, in_features]).uniform_(-bound, bound))
        self.bias = nn.Parameter(t.zeros([out_features]).uniform_(-bound, bound)) if bias else None
        # Kept for backward compatibility; equivalent to `self.bias is not None`.
        self.is_bias = bias

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (*, in_features)
        Return: shape (*, out_features)
        '''
        res = x @ self.weight.T
        return res if self.bias is None else res + self.bias

    def extra_repr(self) -> str:
        # Same layout as torch.nn.Linear's repr.
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
# Run the exercise's provided checks: forward output, parameter registration, and bias=False.
tests.test_linear_forward(Linear)
tests.test_linear_parameters(Linear)
tests.test_linear_no_bias(Linear)
# %% | |
class Flatten(nn.Module):
    """Merge the dimensions start_dim..end_dim (inclusive) into one, like torch.nn.Flatten."""

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: t.Tensor) -> t.Tensor:
        '''
        Flatten out dimensions from start_dim to end_dim, inclusive of both.
        '''
        shape = input.shape
        rank = len(shape)
        # Convert negative dims to absolute positions; `stop` is exclusive.
        first = self.start_dim + rank if self.start_dim < 0 else self.start_dim
        stop = self.end_dim + 1
        if self.end_dim < 0:
            stop += rank
        merged = functools.reduce(lambda acc, size: acc * size, shape[first:stop], 1)
        new_shape = [*shape[:first], merged, *shape[stop:]]
        return t.reshape(input, new_shape)

    def extra_repr(self) -> str:
        return f'Flatten: {self.start_dim=} {self.end_dim=}'
# Run the exercise's provided check for Flatten.
tests.test_flatten(Flatten)
# %% | |
class SimpleMLP(nn.Module):
    """Flatten -> Linear(784, 100) -> ReLU -> Linear(100, 10)."""

    def __init__(self):
        super().__init__()
        # Submodules are created in this order so registration order and RNG draws are unchanged.
        self.flatten = Flatten(1, -1)
        self.layer0 = Linear(784, 100, True)
        self.relu = ReLU()
        self.layer1 = Linear(100, 10, True)

    def forward(self, x: t.Tensor) -> t.Tensor:
        hidden = self.flatten(x)
        hidden = self.layer0(hidden)
        hidden = self.relu(hidden)
        return self.layer1(hidden)
# Run the exercise's provided check for SimpleMLP.
tests.test_mlp(SimpleMLP)
# %% | |
# Convert PIL images to tensors, then standardise with mean 0.1307 and std 0.3081
# (presumably the MNIST training-set pixel statistics).
MNIST_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
def get_mnist(subset: int = 1):
    '''Return the (train, test) MNIST datasets.

    Data is downloaded to ./data if missing. When `subset` > 1, only every
    `subset`-th example of each split is kept.
    '''
    def load(is_train: bool):
        return datasets.MNIST(root="./data", train=is_train, download=True, transform=MNIST_TRANSFORM)

    train_set = load(True)
    test_set = load(False)
    if subset <= 1:
        return train_set, test_set

    def every_nth(dataset):
        return Subset(dataset, indices=range(0, len(dataset), subset))

    return every_nth(train_set), every_nth(test_set)
# Module-level loaders over the full MNIST datasets (batch size 64). This downloads
# MNIST to ./data on first run. Note that `train` builds its own loaders via get_mnist.
mnist_trainset, mnist_testset = get_mnist()
mnist_trainloader = DataLoader(mnist_trainset, batch_size=64, shuffle=True)
mnist_testloader = DataLoader(mnist_testset, batch_size=64, shuffle=False)
# %% | |
@dataclass
class SimpleMLPTrainingArgs:
    '''
    Hyperparameters for `train`.

    @dataclass generates an __init__ that sets these fields, so any of them can be
    overridden at construction, e.g. SimpleMLPTrainingArgs(batch_size=128).
    '''
    batch_size: int = 64         # examples per DataLoader batch / optimizer step
    epochs: int = 3              # passes over the (subsetted) training set
    learning_rate: float = 1e-3  # Adam learning rate
    subset: int = 10             # keep every `subset`-th MNIST example (see get_mnist)
@t.inference_mode()
def validate(model: SimpleMLP, loader: DataLoader) -> float:
    '''
    Compute the top-1 classification accuracy of `model` over all batches of `loader`.

    The model is switched to eval mode for the pass (this matters for layers such as
    BatchNorm or dropout), and its previous train/eval mode is restored afterwards.
    Batches are moved to the module-level `device`.

    Returns:
        Fraction of correctly classified examples, in [0, 1].
    Raises:
        ValueError: if `loader` yields no examples.
    '''
    was_training = model.training
    model.eval()
    try:
        correct = 0
        total = 0
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            predictions = model(imgs).argmax(dim=-1)
            # Kept as a tensor across batches to avoid a device sync per batch.
            correct += (predictions == labels).sum()
            total += imgs.shape[0]
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("validate() received a loader that yielded no examples")
    return int(correct) / total
def train(args: SimpleMLPTrainingArgs):
    '''
    Train a fresh SimpleMLP on (a subset of) MNIST with Adam, then plot the curves.

    Uses args.batch_size, args.epochs, args.learning_rate and args.subset.
    Two curves are plotted:
      - the training cross-entropy loss after every batch;
      - the test-set accuracy from `validate`, measured once before training and
        again after every epoch.
    '''
    model = SimpleMLP().to(device)
    mnist_trainset, mnist_testset = get_mnist(subset=args.subset)
    mnist_trainloader = DataLoader(mnist_trainset, batch_size=args.batch_size, shuffle=True)
    # Accuracy doesn't depend on example order, so the evaluation loader needn't shuffle.
    mnist_testloader = DataLoader(mnist_testset, batch_size=args.batch_size, shuffle=False)
    optimizer = t.optim.Adam(model.parameters(), lr=args.learning_rate)

    loss_list = []
    # `validate` returns classification accuracy (not a loss).
    accuracy_list = [validate(model, mnist_testloader)]

    for _ in tqdm(range(args.epochs)):
        for imgs, labels in mnist_trainloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            logits = model(imgs)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_list.append(loss.item())
        accuracy_list.append(validate(model, mnist_testloader))

    line(
        loss_list,
        yaxis_range=[0, max(loss_list) + 0.1],
        labels={"x": "Num batches seen", "y": "Cross entropy loss"},
        title="SimpleMLP training on MNIST",
        width=700
    )
    line(
        accuracy_list,
        yaxis_range=[0, max(accuracy_list) + 0.1],
        labels={"x": "Num epochs seen", "y": "Test accuracy"},
        title="SimpleMLP test accuracy on MNIST",
        width=700
    )
# Train with the default hyperparameters (3 epochs over every 10th MNIST example).
args = SimpleMLPTrainingArgs()
train(args)
# %% | |
class Conv2d(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0
    ):
        '''
        Same as torch.nn.Conv2d with bias=False (square kernels only).

        The weight is stored in `self.weight`, with shape
        (out_channels, in_channels, kernel_size, kernel_size), for compatibility with
        the PyTorch version. It is initialised uniformly in [-b, b], where
        b = sqrt(6) / sqrt(fan_in + fan_out + 1),
        fan_in = in_channels * k * k and fan_out = out_channels * k * k.

        Raises:
            ValueError: if `stride` is not positive.
        '''
        super().__init__()
        if stride <= 0:
            raise ValueError(f"stride must be a positive integer, got {stride}")
        self.stride = stride
        self.padding = padding
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        fan_in = in_channels * kernel_size * kernel_size
        fan_out = out_channels * kernel_size * kernel_size
        bound = (6 ** 0.5) / (fan_in + fan_out + 1) ** 0.5
        # Only the Parameter is kept. A separate plain-tensor copy would not follow
        # `.to(device)` and would go stale.
        self.weight = nn.Parameter(
            t.zeros([out_channels, in_channels, kernel_size, kernel_size]).uniform_(-bound, bound)
        )

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Apply the functional conv2d, which you can import.'''
        return nn.functional.conv2d(input=x, weight=self.weight, stride=self.stride, padding=self.padding, bias=None)

    def extra_repr(self) -> str:
        return ", ".join(
            f"{key}={getattr(self, key)}"
            for key in ["in_channels", "out_channels", "kernel_size", "stride", "padding"]
        )
tests.test_conv2d_module(Conv2d)
# Build a strided, padded instance so its repr can be inspected by eye.
m = Conv2d(in_channels=24, out_channels=12, kernel_size=3, stride=2, padding=1)
print(f"Manually verify that this is an informative repr: {m}")
# %% | |
class MaxPool2d(nn.Module):
    '''Max-pooling module that forwards to `torch.nn.functional.max_pool2d`.

    Note: `padding` defaults to 1 here, whereas PyTorch's own nn.MaxPool2d defaults
    to 0. `stride=None` is passed through unchanged; per the PyTorch docs, the
    functional op then uses `kernel_size` as the stride.
    '''

    def __init__(self, kernel_size: int, stride: Optional[int] = None, padding: int = 1):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Call the functional version of max_pool2d.'''
        return nn.functional.max_pool2d(
            x,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
        )

    def extra_repr(self) -> str:
        '''Add additional information to the string representation of this class.'''
        return ", ".join(f"{name}={getattr(self, name)}" for name in ("kernel_size", "stride", "padding"))
tests.test_maxpool2d_module(MaxPool2d)
# Build an instance so its repr can be inspected by eye.
m = MaxPool2d(kernel_size=3, stride=2, padding=1)
print(f"Manually verify that this is an informative repr: {m}")
# %% | |
class Sequential(nn.Module):
    '''Run submodules in order, each output feeding the next.

    Modules are registered under the keys "0", "1", ... so they appear in
    parameters()/state_dict() like torch.nn.Sequential's. Indexing supports
    negative indices. Out-of-range indices raise IndexError instead of wrapping
    around silently.
    '''
    _modules: Dict[str, nn.Module]

    def __init__(self, *modules: nn.Module):
        super().__init__()
        for index, mod in enumerate(modules):
            self._modules[str(index)] = mod

    def _normalize_index(self, index: int) -> int:
        '''Map a possibly negative index onto 0..len-1, raising IndexError if out of range.'''
        n = len(self._modules)
        position = index + n if index < 0 else index
        if not 0 <= position < n:
            raise IndexError(f"Sequential index {index} out of range for {n} modules")
        return position

    def __getitem__(self, index: int) -> nn.Module:
        return self._modules[str(self._normalize_index(index))]

    def __setitem__(self, index: int, module: nn.Module) -> None:
        self._modules[str(self._normalize_index(index))] = module

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Chain each module together, with the output from one feeding into the next one.'''
        for mod in self._modules.values():
            x = mod(x)
        return x
# %% | |
class BatchNorm2d(nn.Module):
    # Shape hints for documentation only (jaxtyping notation). They are quoted so
    # they are stored as strings rather than evaluated when the class is created.
    running_mean: "Float[Tensor, 'num_features']"
    running_var: "Float[Tensor, 'num_features']"
    num_batches_tracked: "Int[Tensor, '']"  # a scalar tensor

    def __init__(self, num_features: int, eps=1e-05, momentum=0.1):
        '''
        Like nn.BatchNorm2d with track_running_stats=True and affine=True.

        The learnable affine parameters are `weight` (init 1) and `bias` (init 0),
        in that order. Running statistics are buffers:
          - running_mean, initialised to 0;
          - running_var, initialised to 1;
          - num_batches_tracked, initialised to 0.
        '''
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.register_buffer("running_mean", t.zeros(num_features))
        self.register_buffer("running_var", t.ones(num_features))
        self.register_buffer("num_batches_tracked", t.tensor(0))
        self.weight = nn.Parameter(t.ones(num_features))
        self.bias = nn.Parameter(t.zeros(num_features))

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        Normalize each channel.

        Training mode uses the batch's per-channel mean and biased variance
        (`torch.var(x, unbiased=False)`). It also updates the running statistics:
            running <- (1 - momentum) * running + momentum * batch_stat
        Eval mode uses the stored running statistics instead.

        x: shape (batch, channels, height, width)
        Return: shape (batch, channels, height, width)
        '''
        if self.training:
            mean = x.mean(dim=(0, 2, 3))
            var = t.var(x, unbiased=False, dim=(0, 2, 3))
            # The running statistics are bookkeeping, not part of the differentiable
            # output, so update them outside autograd. Otherwise each buffer would
            # hold this batch's graph (t.var keeps its input alive), and that chain
            # would grow every step.
            with t.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
                self.num_batches_tracked += 1
        else:
            mean = self.running_mean
            var = self.running_var
        # Per-channel vectors are viewed as (1, C, 1, 1) to broadcast over (B, C, H, W).
        x_hat = (x - mean.view(1, -1, 1, 1)) / (var.view(1, -1, 1, 1) + self.eps) ** 0.5
        return x_hat * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)

    def extra_repr(self) -> str:
        '''Add additional information to the string representation of this class.'''
        return ", ".join([f"{key}={getattr(self, key)}" for key in ["num_features", "eps", "momentum"]])
# Run the exercise's provided checks: module attributes, forward output, and running-mean updates.
tests.test_batchnorm2d_module(BatchNorm2d)
tests.test_batchnorm2d_forward(BatchNorm2d)
tests.test_batchnorm2d_running_mean(BatchNorm2d)
# %% | |
class AveragePool(nn.Module):
    """Global average pooling: the mean over the two spatial dimensions."""

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, channels, height, width)
        Return: shape (batch, channels)
        '''
        spatial_dims = (2, 3)
        return t.mean(x, dim=spatial_dims)
# %% | |
class ResidualBlock(nn.Module):
    def __init__(self, in_feats: int, out_feats: int, first_stride=1):
        '''
        A single residual block with an optional downsampling/projection shortcut.

        Left branch (declared first, for compatibility with the pretrained model):
            conv3x3(stride=first_stride) -> BN -> ReLU -> conv3x3 -> BN
        Right branch (declared second): a `Sequential` of
        conv1x1(stride=first_stride) -> BN. It is used whenever the shortcut must
        change shape, i.e. first_stride > 1 or in_feats != out_feats; otherwise
        the shortcut is the identity.
        For ResNet34-style configurations (stride > 1 exactly when the channel
        count changes), the layout is the same as before.
        '''
        super().__init__()
        self.first_stride = first_stride
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.left = Sequential(
            Conv2d(in_channels=in_feats, out_channels=out_feats, stride=first_stride, kernel_size=3, padding=1),
            BatchNorm2d(num_features=out_feats),
            ReLU(),
            Conv2d(in_channels=out_feats, out_channels=out_feats, stride=1, kernel_size=3, padding=1),
            BatchNorm2d(num_features=out_feats),
        )
        # With stride 1 but a channel change, an identity shortcut would not match
        # the left branch's channel count, so project in that case too.
        needs_projection = first_stride > 1 or in_feats != out_feats
        if needs_projection:
            self.right = Sequential(
                Conv2d(in_channels=in_feats, out_channels=out_feats, stride=first_stride, kernel_size=1, padding=0),
                BatchNorm2d(num_features=out_feats),
            )
        else:
            self.right = nn.Identity()
        self.relu = ReLU()

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        Compute the forward pass: relu(left(x) + right(x)).

        x: shape (batch, in_feats, height, width)
        Return: shape (batch, out_feats, ceil(height / first_stride), ceil(width / first_stride))
        '''
        return self.relu(self.left(x) + self.right(x))
# %% | |
class BlockGroup(nn.Module):
    def __init__(self, n_blocks: int, in_feats: int, out_feats: int, first_stride=1):
        '''
        An n_blocks-long sequence of ResidualBlock where only the first block uses the provided stride.

        The first block maps in_feats -> out_feats with first_stride; every following block maps
        out_feats -> out_feats with stride 1. Note: n_blocks <= 1 still produces a single block.
        '''
        super().__init__()
        self.n_blocks = n_blocks
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.first_stride = first_stride
        blocks = [ResidualBlock(in_feats=in_feats, out_feats=out_feats, first_stride=first_stride)]
        for _ in range(n_blocks - 1):
            blocks.append(ResidualBlock(in_feats=out_feats, out_feats=out_feats, first_stride=1))
        # The attribute name `weight` determines this group's state_dict key prefix; keep it stable.
        self.weight = Sequential(*blocks)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        Run x through every residual block of the group in order.
        x: shape (batch, in_feats, height, width)
        Return: shape (batch, out_feats, height / first_stride, width / first_stride)
        '''
        blocks = self.weight
        return blocks(x)
# %% | |
class ResNet34(nn.Module):
    def __init__(
        self,
        n_blocks_per_group=[3, 4, 6, 3],
        out_features_per_group=[64, 128, 256, 512],
        first_strides_per_group=[1, 2, 2, 2],
        n_classes=1000,
    ):
        '''
        ResNet built from a 64-channel stem, one BlockGroup per entry of the per-group lists, and a
        pooling + linear classifier.

        n_blocks_per_group, out_features_per_group, first_strides_per_group: per-group BlockGroup
            settings; they should have equal lengths (zip silently truncates to the shortest).
        n_classes: number of output logits of the final linear layer.
        '''
        super().__init__()
        # Copy the configs so instances never share (and never mutate) the default lists.
        self.n_blocks_per_group = list(n_blocks_per_group)
        self.out_features_per_group = list(out_features_per_group)
        self.first_strides_per_group = list(first_strides_per_group)
        self.n_classes = n_classes
        # Stem: 7x7 stride-2 conv -> BN -> ReLU -> 3x3 stride-2 max-pool.
        head_seq = [
            Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3),
            BatchNorm2d(num_features=64),
            ReLU(),
            MaxPool2d(kernel_size=3, stride=2),
        ]
        # Each group consumes the previous group's output width; the first consumes the stem's 64.
        all_in_feats = [64] + self.out_features_per_group[:-1]
        block_grp_seq = [
            BlockGroup(*params)
            for params in zip(self.n_blocks_per_group, all_in_feats, self.out_features_per_group, self.first_strides_per_group)
        ]
        # Submodule registration order (head, groups, tail) fixes the state_dict order used by copy_weights.
        self.head = Sequential(*head_seq)
        self.block_grp_seq = Sequential(*block_grp_seq)
        # Classifier sized from the last group's width and n_classes (previously hard-coded 512 -> 1000).
        tail_seq = [AveragePool(), Flatten(1, -1), Linear(self.out_features_per_group[-1], self.n_classes)]
        self.tail = Sequential(*tail_seq)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, channels, height, width)
        Return: shape (batch, n_classes)
        '''
        x = self.head(x)
        x = self.block_grp_seq(x)
        return self.tail(x)
# Build our own ResNet-34 with the default configuration (1000 output classes).
my_resnet = ResNet34(n_classes=1000)
# %% | |
from pprint import pprint as pp | |
def copy_weights(my_resnet: ResNet34, pretrained_resnet: models.resnet.ResNet) -> ResNet34:
    '''
    Copy over the weights of `pretrained_resnet` to your resnet.

    Entries are paired purely by position: the i-th parameter/buffer of `pretrained_resnet`'s
    state_dict is loaded under the i-th key of `my_resnet`'s state_dict. Both models must therefore
    register their parameters and buffers in the same order. `my_resnet` is modified in place and returned.
    '''
    own_state = my_resnet.state_dict()
    source_state = pretrained_resnet.state_dict()
    assert len(own_state) == len(source_state), "Mismatching state dictionaries."
    # Our key names, the pretrained model's tensors.
    remapped_state = dict(zip(own_state.keys(), source_state.values()))
    my_resnet.load_state_dict(remapped_state)
    return my_resnet
# Fetch torchvision's ImageNet-pretrained ResNet-34 and transplant its weights into our model.
pretrained_resnet = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
my_resnet = copy_weights(my_resnet=my_resnet, pretrained_resnet=pretrained_resnet)
print("Success!")
# %% | |
IMAGE_FILENAMES = [
    "chimpanzee.jpg",
    "golden_retriever.jpg",
    "platypus.jpg",
    "frogs.jpg",
    "fireworks.jpg",
    "astronaut.jpg",
    "iguana.jpg",
    "volcano.jpg",
    "goofy.jpg",
    "dragonfly.jpg",
]
IMAGE_FOLDER = section_dir / "resnet_inputs"
# Convert to RGB so every image yields exactly 3 channels. A greyscale, palette or CMYK JPEG would
# otherwise become a 1- or 4-channel tensor and break the 3-channel Normalize and the stacking below.
# convert() also loads the pixels into a new in-memory image.
images = [Image.open(IMAGE_FOLDER / filename).convert("RGB") for filename in IMAGE_FILENAMES]
IMAGE_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# PIL image -> float tensor in [0, 1] -> fixed square size -> per-channel ImageNet standardization.
IMAGENET_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
prepared_images = t.stack([IMAGENET_TRANSFORM(img) for img in images], dim=0)
assert prepared_images.shape == (len(images), 3, IMAGE_SIZE, IMAGE_SIZE)
# %% | |
def predict(model, images: t.Tensor) -> t.Tensor:
    '''
    Returns the predicted class for each image (as a 1D array of ints).

    If `model` is an nn.Module, it is run in eval mode: BatchNorm layers use their stored running
    statistics instead of the batch's, and do not update them. Its previous train/eval mode is
    restored afterwards. Gradient tracking is disabled for the forward pass. Plain callables are
    simply called.

    model: callable mapping `images` to logits; the class axis is assumed to be the last one.
    images: the batch of preprocessed images passed straight to `model`.
    '''
    is_module = isinstance(model, t.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()
    try:
        with t.no_grad():
            logits = model(images)
    finally:
        # Restore the caller's mode even if the forward pass raises.
        if is_module:
            model.train(was_training)
    return t.argmax(logits, dim=-1)
# Class index -> human-readable label, in the order the JSON file lists them.
with open(section_dir / "imagenet_labels.json") as f:
    imagenet_labels = [*json.load(f).values()]

# Check your predictions match those of the pretrained model
my_predictions = predict(my_resnet, prepared_images)
pretrained_predictions = predict(pretrained_resnet, prepared_images)
assert (my_predictions == pretrained_predictions).all()
print("All predictions match!")

# Print out your predictions, next to the corresponding images
for img, label in zip(images, my_predictions):
    print(f"Class {label}: {imagenet_labels[label]}")
    display(img)
    print()
# %% |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment