Skip to content

Instantly share code, notes, and snippets.

@awni
Last active October 4, 2024 04:33
Show Gist options
  • Save awni/fde217c67e6be098e0773d3a7de93f02 to your computer and use it in GitHub Desktop.
Save awni/fde217c67e6be098e0773d3a7de93f02 to your computer and use it in GitHub Desktop.
Conway's Game of Life Accelerated with Custom Kernels in MLX
import numpy as np
import mlx.core as mx
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import tqdm
def conway(a: mx.array):
    """Advance a Game of Life board by one generation using a custom Metal kernel.

    One GPU thread is launched per cell. The board wraps around at its edges
    (toroidal topology): the neighbours of row 0 come from the last row, and
    likewise for columns.

    Args:
        a: 2-D board indexed ``a[row, col]`` where a true/nonzero entry is a
            live cell. Expected to hold 0/1 (boolean) cells — the kernel sums
            neighbour values directly, so other values would skew the count.

    Returns:
        A new array with the same shape and dtype as ``a`` holding the next
        generation.
    """
    source = """
        uint i = thread_position_in_grid.x;
        uint j = thread_position_in_grid.y;
        uint n = threads_per_grid.x;
        uint m = threads_per_grid.y;

        // Wrap-around neighbours. The first row/column must wrap to the LAST
        // valid index (n - 1 / m - 1); wrapping to n / m would read past the
        // end of the row/buffer.
        uint down = (i == 0) ? (n - 1) : (i - 1);
        uint up = (i + 1 == n) ? 0 : (i + 1);
        uint left = (j == 0) ? (m - 1) : (j - 1);
        uint right = (j + 1 == m) ? 0 : (j + 1);

        size_t idx = i * m + j;
        int count = grid[up * m + left] + grid[up * m + j] + grid[up * m + right]
                  + grid[i * m + left] + grid[i * m + right]
                  + grid[down * m + left] + grid[down * m + j] + grid[down * m + right];

        // A live cell survives with 2 neighbours; any cell is alive with 3.
        out[idx] = (grid[idx] && count == 2) || count == 3;
    """
    kernel = mx.fast.metal_kernel(
        name="conway",
        input_names=["grid"],
        output_names=["out"],
        source=source,
        # The flat index i * m + j assumes a row-contiguous buffer, which is
        # what metal_kernel guarantees by default (ensure_row_contiguous=True).
    )
    return kernel(
        inputs=[a],
        # One thread per cell: x walks rows, y walks columns, so
        # threads_per_grid inside the kernel equals (rows, cols).
        grid=(a.shape[0], a.shape[1], 1),
        # NOTE(review): 2 * 512 = 1024 threads per group; confirm this is
        # suitable for boards narrower than 512 columns.
        threadgroup=(2, 512, 1),
        output_shapes=[a.shape],
        output_dtypes=[a.dtype],
    )[0]
def generator(grid, steps=1000):
    """Yield ``steps`` rendered frames of the simulation starting from ``grid``.

    Each frame is ``(~board)`` cast to uint8 and scaled by 255, so for a
    boolean board live cells render as 0 (black) and dead cells as 255
    (white). After a frame is yielded the board is advanced with ``conway``.
    """
    board = grid
    for _ in range(steps):
        # Materialise the lazily-built board before deriving the frame.
        mx.eval(board)
        frame = (~board).astype(mx.uint8) * 255
        yield frame
        board = conway(board)
def animate(grid, steps=300, fps=30, save_as="out.mp4"):
    """Render ``steps`` generations of ``grid`` to a video, then show the figure.

    Args:
        grid: Initial board, handed to ``generator``.
        steps: Number of frames to produce.
        fps: Frame rate for the saved video; also sets the on-screen delay
            between frames.
        save_as: Output path passed to ``FuncAnimation.save`` (ffmpeg writer).
    """
    fig, ax = plt.subplots(figsize=(8, 4))
    im = ax.imshow(mx.zeros_like(grid), cmap="gray", vmin=0, vmax=255)
    progress_bar = tqdm.tqdm(total=steps, desc="Animating", ncols=100)

    def update(frame):
        """Draw one frame and advance the progress bar."""
        im.set_data(frame)
        progress_bar.update(1)
        return [im]

    ani = FuncAnimation(
        fig,
        update,
        frames=generator(grid, steps=steps),
        # `interval` is the delay between frames in milliseconds, so it is
        # derived from the frame rate (previously `steps // fps`, which gave
        # 10 ms == ~100 fps for the defaults).
        interval=max(1, round(1000 / fps)),
        blit=True,
        cache_frame_data=False,
    )
    fig.tight_layout()
    try:
        ani.save(save_as, writer="ffmpeg", fps=fps, dpi=300)
    finally:
        # Close the bar even if ffmpeg fails mid-render.
        progress_bar.close()
    plt.show()
if __name__ == "__main__":
    # Square random board: each cell starts alive with probability 0.3.
    grid_size = 2048
    grid = mx.random.bernoulli(p=0.3, shape=(grid_size,) * 2)
    animate(grid)
@cdotwang
Copy link

cdotwang commented Oct 3, 2024

This code fails with the latest MLX release (0.18.0):

TypeError: metal_kernel(): incompatible function arguments. The following argument types are supported:
    1. metal_kernel(name: str, input_names: collections.abc.Sequence[str], output_names: collections.abc.Sequence[str], source: str, header: str = '', ensure_row_contiguous: bool = True, atomic_outputs: bool = False) -> object

Invoked with types: kwargs = { name: str, source: str }

@awni
Copy link
Author

awni commented Oct 4, 2024

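I updated it: `input_names` and `output_names` are now passed to `mx.fast.metal_kernel` itself, matching the new API.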

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment