Not only is Mojo great for writing high-performance code, but it also allows us to leverage the huge Python ecosystem of libraries and tools. With seamless Python interoperability, Mojo can use Python for what it's good at, especially GUIs, without sacrificing performance in critical code. Let's take the classic Mandelbrot set algorithm and implement it in Mojo.
We'll introduce a `Complex` type and use it in our implementation.
%%python
import numpy as np
import numba
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import time
%%python
# Sampled region of the complex plane and its resolution:
# the real axis runs from xmin towards xmax in xn columns,
# the imaginary axis from ymin towards ymax in yn rows.
xmin: float = -2.25
xmax: float = 0.75
xn: int = 450
ymin: float = -1.25
ymax: float = 1.25
yn: int = 375
# Iteration cap; points still bounded after this many steps get this value.
max_iter: int = 200
# Compute the number of steps to escape
def mandelbrot_kernel(c):
z = c
for i in range(max_iter):
z = z * z + c
if abs(z) > 2:
return i
return max_iter
def mandelbrot():
    """Evaluate ``mandelbrot_kernel`` on every pixel of the yn x xn grid.

    Pixel ``[j, i]`` samples the point reached from (xmin, ymin) by adding
    the column step ``i`` times and the row step ``j`` times.

    Returns:
        A ``(yn, xn)`` uint32 array of escape counts.
    """
    step_x = (xmax - xmin) / xn
    step_y = (ymax - ymin) / yn
    # One element per pixel.
    counts = np.zeros((yn, xn), dtype=np.uint32)
    im = ymin
    for row_values in counts:  # each item is a writable view of one row
        re = xmin
        for col in range(xn):
            row_values[col] = mandelbrot_kernel(complex(re, im))
            re += step_x
        im += step_y
    return counts
def make_plot_python(m):
    """Render a 2-D array of escape counts as a hill-shaded 'hot' image.

    Args:
        m: 2-D array of iteration counts, indexed ``[row, column]``.
    """
    dpi = 32
    width = 5
    # Size the figure from the data itself so its aspect ratio matches the
    # grid. The previous `5 * yn // xn` floored the height (4 instead of
    # ~4.17). It also read module globals even when m had another shape.
    rows, cols = np.shape(m)
    height = width * rows / cols
    fig = plt.figure(1, [width, height], dpi=dpi)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], frame_on=False, aspect=1)
    light = colors.LightSource(315, 10, 0, 1, 1, 0)
    # Colour with the 'hot' map under a gamma-0.3 power norm, then blend in
    # HSV mode with a vertical exaggeration of 1.5.
    image = light.shade(m, plt.cm.hot, colors.PowerNorm(0.3), blend_mode='hsv', vert_exag=1.5)
    # Draw on the axes created above, not on pyplot's implicit "current"
    # axes.
    ax.imshow(image)
    ax.set_axis_off()
    plt.show()
%%python
# Time only the computation (plotting excluded).
# time.perf_counter() is a monotonic, high-resolution clock. time.time() is
# wall-clock time and can be coarse or adjusted mid-measurement.
start_time = time.perf_counter()
mandelbrot_set = mandelbrot()
end_time = time.perf_counter()
execution_time = (end_time - start_time) * 1000  # Make it milliseconds
make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot: {execution_time:.0f} ms")
Execution time for Python Mandelbrot: 1266 ms
%%python
# Run with Numba JIT compiler
@numba.jit(nopython=True)
def mandelbrot_kernel_numba(c):
    """Numba-compiled twin of ``mandelbrot_kernel``.

    Iterates z <- z*z + c from z = c. Returns the 0-based step at which |z|
    first exceeds 2, or the module constant ``max_iter`` if it never does.
    """
    z = c
    step = 0
    while step < max_iter:
        z = z * z + c
        if abs(z) > 2:
            return step
        step += 1
    return max_iter
@numba.jit(nopython=True)
def mandelbrot_numba():
    """Numba-compiled twin of ``mandelbrot()``.

    Builds pixel coordinates by repeatedly adding the row/column steps,
    starting from (xmin, ymin).

    Returns:
        A ``(yn, xn)`` uint32 array of escape counts from
        ``mandelbrot_kernel_numba``.
    """
    step_x = (xmax - xmin) / xn
    step_y = (ymax - ymin) / yn
    # One element per pixel.
    counts = np.zeros((yn, xn), dtype=np.uint32)
    im = ymin
    for row in range(yn):
        re = xmin
        for col in range(xn):
            counts[row, col] = mandelbrot_kernel_numba(complex(re, im))
            re += step_x
        im += step_y
    return counts
%%python
# Call once first so numba's JIT compilation is not part of the timing.
dummy = mandelbrot_numba()
# Time only the computation with a monotonic, high-resolution clock.
start_time = time.perf_counter()
mandelbrot_set = mandelbrot_numba()
end_time = time.perf_counter()
execution_time = (end_time - start_time) * 1000  # Make it milliseconds
make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (numba): {execution_time:.0f} ms")
Execution time for Python Mandelbrot (numba): 60 ms
%%python
def mandelbrot_vectorized(xn, yn, max_iter=200, xmin=-2.25, xmax=0.75,
                          ymin=-1.25, ymax=1.25):
    """Compute Mandelbrot escape counts over a grid with NumPy array ops.

    Each entry follows the same convention as the scalar
    ``mandelbrot_kernel``: the 0-based iteration ``i`` at which |z| first
    exceeds 2 (iterating z <- z*z + c from z = c), or ``max_iter`` if the
    point never escapes.

    Args:
        xn: Number of columns (samples along the real axis).
        yn: Number of rows (samples along the imaginary axis).
        max_iter: Iteration cap; non-escaping points get this value.
        xmin, xmax, ymin, ymax: Bounds of the sampled region. ``np.linspace``
            is used, so both endpoints are included.

    Returns:
        A ``(yn, xn)`` uint32 array of escape counts.
    """
    # Build the complex grid by broadcasting the two axes, not with a
    # per-pixel Python loop.
    x = np.linspace(xmin, xmax, xn)
    y = np.linspace(ymin, ymax, yn)
    c = np.empty((yn, xn), dtype=np.complex128)
    c.real[:] = x
    c.imag[:] = y[:, np.newaxis]

    # Points that never escape keep the initial value max_iter.
    iter_count = np.full((yn, xn), max_iter, dtype=np.uint32)
    z = c.copy()
    # Only points that have not escaped yet are iterated further. Squaring
    # escaped values overflowed to inf/nan (the RuntimeWarnings seen before).
    # Recording `i` on every non-escaped step also produced an off-by-one
    # versus the scalar kernel.
    active = np.ones((yn, xn), dtype=bool)
    for i in range(max_iter):
        if not active.any():
            break
        z[active] = z[active] * z[active] + c[active]
        escaped = active & (np.abs(z) > 2)
        iter_count[escaped] = i
        active &= ~escaped
    return iter_count
%%python
# Time only the computation (plotting excluded).
# time.perf_counter() is a monotonic, high-resolution clock. time.time() is
# wall-clock time and can be coarse or adjusted mid-measurement.
start_time = time.perf_counter()
mandelbrot_set = mandelbrot_vectorized(xn, yn)
end_time = time.perf_counter()
execution_time = (end_time - start_time) * 1000  # Make it milliseconds
make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (vectorized): {execution_time:.0f} ms")
<string>:47: RuntimeWarning: overflow encountered in square
<string>:47: RuntimeWarning: invalid value encountered in square
Execution time for Python Mandelbrot (vectorized): 239 ms
%%python
@numba.vectorize([numba.uint32(numba.complex128, numba.uint32)], nopython=True)
def mandelbrot_element(c, max_iter):
    """Per-element escape count, compiled by numba into a NumPy ufunc.

    Iterates z <- z**2 + c from z = c. Returns the 0-based step at which
    |z| first exceeds 2, or ``max_iter`` if it never does.
    """
    z = c
    step = 0
    while step < max_iter:
        z = z**2 + c
        if abs(z) > 2:
            return step
        step += 1
    return max_iter
def mandelbrot_vectorized_numba(xn, yn, max_iter=200, xmin=-2.25, xmax=0.75,
                                ymin=-1.25, ymax=1.25):
    """Compute escape counts on a linspace grid via the ``mandelbrot_element`` ufunc.

    Args:
        xn: Number of columns (samples along the real axis).
        yn: Number of rows (samples along the imaginary axis).
        max_iter: Iteration cap forwarded to ``mandelbrot_element``.
        xmin, xmax, ymin, ymax: Bounds of the sampled region. ``np.linspace``
            is used, so both endpoints are included.

    Returns:
        A ``(yn, xn)`` array produced by ``mandelbrot_element`` (uint32 per
        its declared signature).
    """
    x = np.linspace(xmin, xmax, xn)
    y = np.linspace(ymin, ymax, yn)
    # Broadcast the axes into the complex grid. A nested list comprehension
    # costs one interpreter step per pixel and yields a 1-D array when yn == 0.
    c = np.empty((yn, xn), dtype=np.complex128)
    c.real[:] = x
    c.imag[:] = y[:, np.newaxis]
    return mandelbrot_element(c, max_iter)
%%python
# Call once first so any one-time numba setup is not part of the timing.
dummy = mandelbrot_vectorized_numba(xn, yn)
# Time only the computation with a monotonic, high-resolution clock.
start_time = time.perf_counter()
mandelbrot_set = mandelbrot_vectorized_numba(xn, yn)
end_time = time.perf_counter()
execution_time = (end_time - start_time) * 1000  # Make it milliseconds
make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (vectorized-numba): {execution_time:.0f} ms")
Execution time for Python Mandelbrot (vectorized-numba): 93 ms
# %%python
# import os
# os.system('pip install cython')
# %load_ext cython
# %%cython
# ## Try with cython
# import numpy as np
# cimport numpy as np
# # Constants
# cdef double xmin = -2.25
# cdef double xmax = 0.75
# cdef int xn = 450
# cdef double ymin = -1.25
# cdef double ymax = 1.25
# cdef int yn = 375
# cdef int max_iter = 200
# # Mandelbrot computation in Cython
# cpdef np.ndarray[np.uint32_t, ndim=2] mandelbrot_cython():
# cdef double dx = (xmax - xmin) / xn
# cdef double dy = (ymax - ymin) / yn
# cdef np.ndarray[np.uint32_t, ndim=2] result = np.zeros((yn, xn), dtype=np.uint32)
# cdef double x, y, real, imag, abs_val
# cdef int i, j, k
# y = ymin
# for j in range(yn):
# x = xmin
# for i in range(xn):
# real = x
# imag = y
# for k in range(max_iter):
# abs_val = real * real + imag * imag
# if abs_val > 4:
# break
# real, imag = real * real - imag * imag + x, 2 * real * imag + y
# result[j, i] = k
# x += dx
# y += dy
# return result
# start_time = time.time()
# mandelbrot_set = mandelbrot_cython()
# end_time = time.time()
# execution_time = (end_time - start_time) * 1000 # Make it milliseconds
# make_plot_python(mandelbrot_set)
# print(f"Execution time for Python Mandelbrot (cython): {execution_time:.0f} ms")
from Benchmark import Benchmark
from DType import DType
from Memory import memset_zero
from Object import object, Attr
from Pointer import DTypePointer, Pointer
from Random import rand
from Range import range
from TargetInfo import dtype_sizeof
from Time import now
from Complex import ComplexSIMD as ComplexGenericSIMD
# Row-major 2-D matrix of si64 values stored in a manually allocated buffer.
# Copies do not duplicate the data. They share `data` and a heap-allocated
# reference count `rc`, and the buffer is released when the last copy is
# destroyed.
struct Matrix:
    var data: DTypePointer[DType.si64]  # rows * cols elements, row-major
    var rows: Int
    var cols: Int
    var rc: Pointer[Int]  # shared count of live copies referencing `data`

    # Allocate a cols x rows matrix. Note the argument order: columns first.
    # NOTE(review): the buffer is not explicitly zeroed here (memset_zero is
    # imported but unused), so every element must be written before it is read.
    fn __init__(self&, cols: Int, rows: Int):
        self.data = DTypePointer[DType.si64].alloc(rows * cols)
        self.rows = rows
        self.cols = cols
        self.rc = Pointer[Int].alloc(1)
        self.rc.store(1)

    # Shallow copy: share `other`'s buffer and bump the shared count.
    fn __copyinit__(self&, other: Self):
        other._inc_rc()
        self.data = other.data
        self.rc = other.rc
        self.rows = other.rows
        self.cols = other.cols

    # Drop this copy's reference; frees the buffer if it was the last one.
    fn __del__(owned self):
        self._dec_rc()

    fn _get_rc(self) -> Int:
        return self.rc.load()

    # Decrement the shared count, or free everything when this is the last
    # reference.
    # NOTE(review): this is a plain load/store, not atomic. Copies made or
    # destroyed concurrently on several threads would race on `rc`.
    fn _dec_rc(self):
        let rc = self._get_rc()
        if rc > 1:
            self.rc.store(rc - 1)
            return
        self._free()

    fn _inc_rc(self):
        let rc = self._get_rc()
        self.rc.store(rc + 1)

    # Release both the element buffer and the reference-count cell.
    fn _free(self):
        self.data.free()
        self.rc.free()

    # Element access is (col, row), i.e. (x, y); no bounds checking.
    @always_inline
    fn __getitem__(self, col: Int, row: Int) -> SI64:
        return self.load[1](col, row)

    # Load `nelts` consecutive elements of row `row`, starting at column `col`.
    @always_inline
    fn load[nelts:Int](self, col: Int, row: Int) -> SIMD[DType.si64, nelts]:
        return self.data.simd_load[nelts](row * self.cols + col)

    @always_inline
    fn __setitem__(self, col: Int, row: Int, val: SI64):
        return self.store[1](col, row, val)

    # Store `nelts` consecutive elements into row `row`, starting at column `col`.
    @always_inline
    fn store[nelts:Int](self, col: Int, row: Int, val: SIMD[DType.si64, nelts]):
        self.data.simd_store[nelts](row * self.cols + col, val)

    # Copy into a new numpy uint32 array of shape (rows, cols), one Python
    # call per element. Each value is converted to f32 before being stored.
    # NOTE(review): ndarray.itemset was removed in NumPy 2.0, so this needs an
    # older NumPy. Confirm the version in use.
    def to_numpy(self) -> PythonObject:
        let np = Python.import_module("numpy")
        let numpy_array = np.zeros((self.rows, self.cols), np.uint32)
        for col in range(self.cols):
            for row in range(self.rows):
                numpy_array.itemset((row, col), self[col, row].cast[DType.f32]())
        return numpy_array
# Minimal single-precision complex number used by the scalar kernel.
# It is a plain value type with no custom copy or destroy logic.
@register_passable("trivial")
struct Complex:
    var real: F32
    var imag: F32

    fn __init__(real: F32, imag: F32) -> Self:
        return Self {real: real, imag: imag}

    fn __add__(lhs, rhs: Self) -> Self:
        return Self(lhs.real + rhs.real, lhs.imag + rhs.imag)

    # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    fn __mul__(lhs, rhs: Self) -> Self:
        return Self(
            lhs.real * rhs.real - lhs.imag * rhs.imag,
            lhs.real * rhs.imag + lhs.imag * rhs.real,
        )

    # Squared magnitude |z|^2 (no square root), so escape tests compare
    # against 4 rather than 2.
    fn norm(self) -> F32:
        return self.real * self.real + self.imag * self.imag
Then we can write the core Mandelbrot algorithm, which involves computing an iterative complex function for each pixel until it "escapes" the circle of radius 2 in the complex plane, counting the number of iterations it takes to escape.
# Compile-time parameters of the sampled region, matching the Python cells'
# constants:
# - the real axis runs from xmin towards xmax in xn columns;
# - the imaginary axis runs from ymin towards ymax in yn rows.
alias xmin: F32 = -2.25
alias xmax: F32 = 0.75
alias xn = 450
alias ymin: F32 = -1.25
alias ymax: F32 = 1.25
alias yn = 375
# Compute the number of steps to escape.
# Iterates z <- z*z + c from z = c and returns the 0-based iteration at which
# |z|^2 > 4 (i.e. |z| > 2) first holds, or max_iter (200) if it never does.
def mandelbrot_kernel(c: Complex) -> Int:
    max_iter = 200  # same cap as the Python cells' max_iter
    z = c
    for i in range(max_iter):
        z = z * z + c
        # norm() is the squared magnitude, hence the comparison with 4.
        if z.norm() > 4:
            return i
    return max_iter
# Scalar reference implementation: evaluate mandelbrot_kernel for every pixel.
# Returns a Matrix with xn columns and yn rows; result[i, j] is column i,
# row j. Coordinates start at (xmin, ymin) and advance by repeated addition
# of dx and dy.
def compute_mandelbrot() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    result = Matrix(xn, yn)
    dx = (xmax - xmin) / xn
    dy = (ymax - ymin) / yn
    y = ymin
    for j in range(yn):
        x = xmin
        for i in range(xn):
            result[i, j] = mandelbrot_kernel(Complex(x, y))
            x += dx
        y += dy
    return result
Plotting the number of iterations to escape with some color gives us the canonical Mandelbrot set plot. To render it, we can leverage Python's `matplotlib` directly from Mojo!
# Render a Matrix of escape counts with matplotlib through Python interop,
# mirroring make_plot_python from the Python cells.
def make_plot(m: Matrix):
    plt = Python.import_module("matplotlib.pyplot")
    colors = Python.import_module("matplotlib.colors")
    dpi = 32
    width = 5
    height = 5 * yn // xn
    fig = plt.figure(1, [width, height], dpi)
    # Figure.add_axes takes only the rect positionally. Extra positional
    # arguments are ignored (and deprecated in newer matplotlib), so the old
    # `add_axes(rect, False, 1)` never applied frame_on/aspect. Keyword
    # arguments are not used here, so set them with explicit calls instead.
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
    ax.set_frame_on(False)
    ax.set_aspect(1)
    light = colors.LightSource(315, 10, 0, 1, 1, 0)
    # Positional form of shade(data, cmap, norm, blend_mode, vmin, vmax,
    # vert_exag). The 0, 0 for vmin/vmax are placeholders needed to reach
    # vert_exag. They are only used when norm is None, so PowerNorm(0.3)
    # governs the colouring.
    image = light.shade(m.to_numpy(), plt.cm.hot, colors.PowerNorm(0.3), "hsv", 0, 0, 1.5)
    plt.imshow(image)
    plt.axis("off")
    plt.show()
# Time only the computation; plotting happens after eval_end is taken.
let eval_begin: Int = now() # This is in nanoseconds
let mandelbrot_set = compute_mandelbrot()
let eval_end: Int = now()
let execution_time = (eval_end - eval_begin) // 1000000  # ns -> whole ms (truncated)
make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot: ', execution_time, 'ms')
Execution time for Mojo Mandelbrot: 27 ms
We showed a naive implementation of the Mandelbrot algorithm, but there are two things we can do to speed it up. We can stop the loop early when a pixel is known to have escaped, and we can leverage Mojo's access to hardware by vectorizing the loop, computing multiple pixels simultaneously. To do that, we will use the `vectorize` higher-order generator.
We start by defining our main iteration loop in a vectorized fashion:
# Vectorised escape-count kernel that iterates `simd_width` pixels at once.
# For each lane it returns the number of iterations completed before
# |z|^2 > 4 first held, which is the same convention as the scalar
# mandelbrot_kernel. Lanes that never escape get 200.
# NOTE(review): this assumes ComplexSIMD.norm() returns the squared
# magnitude, as Complex.norm() does, so that 4 corresponds to |z| > 2.
# Confirm against this Mojo version.
fn mandelbrot_kernel_simd[simd_width:Int](c: ComplexGenericSIMD[DType.f32, simd_width]) -> SIMD[DType.si64, simd_width]:
    var z = c
    var nv = SIMD[DType.si64, simd_width](0)  # per-lane iteration counts
    var escape_mask = SIMD[DType.bool, simd_width](0)  # lanes that have escaped
    var i = 200  # iteration budget; same cap as the scalar kernel's max_iter
    # NOTE(review): the intent is to run until *all* lanes have escaped. That
    # depends on `not escape_mask` reducing the bool vector over all lanes.
    # Confirm for this Mojo version.
    while i != 0 and not escape_mask:
        # Escaped lanes keep iterating too; their values are simply ignored.
        z = z*z + c
        # Only update elements that haven't escaped yet
        # (sticky: escape_mask | (|z|^2 > 4)).
        escape_mask = escape_mask.select(escape_mask, z.norm() > 4)
        # Count this iteration only for lanes that are still inside.
        nv = escape_mask.select(nv, nv + 1)
        i -= 1
    return nv
The above function is parameterized on `simd_width` and processes `simd_width` pixels at a time. It only exits once all pixels within the vector are done. We can use the same iteration loop as above, but this time we vectorize within each row instead. We use the `vectorize` generator to make this a simple function call.
from Functional import vectorize
from Math import iota
from TargetInfo import dtype_simd_width
# Vectorised version: each row is processed `simd_width` pixels at a time
# with mandelbrot_kernel_simd. The Matrix layout is the same as in
# compute_mandelbrot: xn columns, yn rows.
def compute_mandelbrot_simd() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    var result = Matrix(xn, yn)
    let dx = (xmax - xmin) / xn
    let dy = (ymax - ymin) / yn
    var y = ymin
    # SIMD width for f32 as reported by TargetInfo.
    alias simd_width = dtype_simd_width[DType.f32]()
    for row in range(yn):
        var x = xmin
        # Closure that vectorize invokes for each chunk starting at `col`. Its
        # own `simd_width` parameter shadows the alias above, so the x advance
        # below uses the width of the current call.
        # NOTE(review): presumably vectorize calls it with a narrower width for
        # the tail when xn is not a multiple of simd_width. Confirm for this
        # Mojo version.
        @parameter
        fn _process_simd_element[simd_width:Int](col: Int):
            # Real parts x + dx*[0, 1, ...]; the same imaginary part y in
            # every lane.
            let c = ComplexGenericSIMD[DType.f32, simd_width](dx*iota[simd_width, DType.f32]() + x,
                                                                SIMD[DType.f32, simd_width](y))
            result.store[simd_width](col, row, mandelbrot_kernel_simd[simd_width](c))
            # Mutates the captured row-local x to move to the next chunk.
            x += simd_width*dx
        vectorize[simd_width, _process_simd_element](xn)
        y += dy
    return result
# Time only the computation; plotting happens after eval_end is taken.
# now() is in nanoseconds.
let eval_begin: Int = now()
let mandelbrot_set = compute_mandelbrot_simd()
let eval_end: Int = now()
let execution_time = (eval_end - eval_begin) // 1000000  # ns -> whole ms (truncated)
make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot (vectorized): ', execution_time, 'ms')
Execution time for Mojo Mandelbrot (vectorized): 2 ms
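While the vectorized implementation above is efficient, we can try to get better performance by parallelizing over the rows. This again is simple in Mojo using the `parallelize` higher-order function. Only the function that performs the invocation needs to change. (Note: for this small image, the run below actually measures a higher time than the vectorized version: 4 ms vs. 2 ms.)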
from Functional import parallelize
# Vectorised and parallelised version: parallelize distributes the rows, and
# each row is vectorised as in compute_mandelbrot_simd.
def compute_mandelbrot_simd_parallel() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    var result = Matrix(xn, yn)
    let dx = (xmax - xmin) / xn
    let dy = (ymax - ymin) / yn
    alias simd_width = dtype_simd_width[DType.f32]()
    # Work item for one row. It stores only into row `row` of `result`, so
    # different rows write disjoint elements.
    @parameter
    fn _process_row(row:Int):
        # y is derived from the row index instead of being accumulated, so no
        # row depends on another. The values may differ in the last bits
        # from the accumulated y of compute_mandelbrot_simd.
        var y = ymin + dy*row
        var x = xmin
        # Same per-chunk closure as in compute_mandelbrot_simd. It mutates
        # this row's own x.
        @parameter
        fn _process_simd_element[simd_width:Int](col: Int):
            let c = ComplexGenericSIMD[DType.f32, simd_width](dx*iota[simd_width, DType.f32]() + x,
                                                                SIMD[DType.f32, simd_width](y))
            result.store[simd_width](col, row, mandelbrot_kernel_simd[simd_width](c))
            x += simd_width*dx
        vectorize[simd_width, _process_simd_element](xn)
    parallelize[_process_row](yn)
    return result
# Time only the computation; plotting happens after eval_end is taken.
# now() is in nanoseconds.
let eval_begin: Int = now()
let mandelbrot_set = compute_mandelbrot_simd_parallel()
let eval_end: Int = now()
let execution_time = (eval_end - eval_begin) // 1000000  # ns -> whole ms (truncated)
make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot (vectorized-parallelized): ', execution_time, 'ms')
Execution time for Mojo Mandelbrot (vectorized-parallelized): 4 ms
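As the numba examples are written, it looks like you are including numba JIT compile time in the benchmark, which isn't a very fair comparison. Is this the case? Or did you re-run the numba example to make sure that you were timing already-compiled code? (Note: both numba timing cells above call the function once before starting the timer, via `dummy = mandelbrot_numba()` and `dummy = mandelbrot_vectorized_numba(xn, yn)`, each commented "Compile numba first". So JIT compilation is not included in the reported times.)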