@eugeneyan
Last active April 4, 2024 15:52
Benchmarking Mojo vs. Python on Mandelbrot sets

Mandelbrot in Mojo with Python plots

Not only is Mojo great for writing high-performance code, it also lets us leverage the huge Python ecosystem of libraries and tools. With seamless Python interoperability, Mojo can use Python for what it's good at, especially GUIs, without sacrificing performance in critical code. Let's take the classic Mandelbrot set algorithm and implement it in Mojo.

We'll introduce a Complex type and use it in our implementation.

Mandelbrot in Python

%%python
import numpy as np
import numba
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import time
%%python
# Constants
xmin = -2.25
xmax = 0.75
xn = 450
ymin = -1.25
ymax = 1.25
yn = 375
max_iter = 200

# Compute the number of steps to escape
def mandelbrot_kernel(c):
    z = c
    for i in range(max_iter):
        z = z * z + c
        if abs(z) > 2:
            return i
    return max_iter

def mandelbrot():
    # Create a matrix. Each element of the matrix corresponds to a pixel
    result = np.zeros((yn, xn), dtype=np.uint32)

    dx = (xmax - xmin) / xn
    dy = (ymax - ymin) / yn

    y = ymin
    for j in range(yn):
        x = xmin
        for i in range(xn):
            result[j, i] = mandelbrot_kernel(complex(x, y))
            x += dx
        y += dy
    return result

def make_plot_python(m):
    dpi = 32
    width = 5
    height = 5 * yn // xn

    fig = plt.figure(1, [width, height], dpi=dpi)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], frame_on=False, aspect=1)

    light = colors.LightSource(315, 10, 0, 1, 1, 0)

    image = light.shade(m, plt.cm.hot, colors.PowerNorm(0.3), blend_mode='hsv', vert_exag=1.5)
    plt.imshow(image)
    plt.axis("off")
    plt.show()
%%python
start_time = time.time()
mandelbrot_set = mandelbrot()
end_time = time.time()
execution_time = (end_time - start_time) * 1000  # Make it milliseconds

make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot: {execution_time:.0f} ms")

[Plot: Mandelbrot set, pure Python]

Execution time for Python Mandelbrot: 1266 ms

Python numba JIT compiler

%%python

# Run with Numba JIT compiler
@numba.jit(nopython=True)
def mandelbrot_kernel_numba(c):
    z = c
    for i in range(max_iter):
        z = z * z + c
        if abs(z) > 2:
            return i
    return max_iter

@numba.jit(nopython=True)
def mandelbrot_numba():
    # Create a matrix. Each element of the matrix corresponds to a pixel
    result = np.zeros((yn, xn), dtype=np.uint32)

    dx = (xmax - xmin) / xn
    dy = (ymax - ymin) / yn

    y = ymin
    for j in range(yn):
        x = xmin
        for i in range(xn):
            result[j, i] = mandelbrot_kernel_numba(complex(x, y))
            x += dx
        y += dy
    return result
%%python
dummy = mandelbrot_numba()  # Compile numba first

start_time = time.time()
mandelbrot_set = mandelbrot_numba()
end_time = time.time()
execution_time = (end_time - start_time) * 1000  # Make it milliseconds

make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (numba): {execution_time:.0f} ms")

[Plot: Mandelbrot set, Python with numba]

Execution time for Python Mandelbrot (numba): 60 ms

Python vectorized

%%python
def mandelbrot_vectorized(xn, yn, max_iter=200):
    # Define the boundaries of the complex plane
    xmin = -2.25
    xmax = 0.75
    ymin = -1.25
    ymax = 1.25

    # Create the grid of complex numbers
    x = np.linspace(xmin, xmax, xn)
    y = np.linspace(ymin, ymax, yn)
    c = np.array([[complex(re, im) for re in x] for im in y])

    # Initialize the Mandelbrot set and iteration count array
    mandelbrot_set = np.zeros((yn, xn), dtype=np.uint32)
    iter_count = np.zeros_like(mandelbrot_set)

    # Initialize the z values with the complex grid
    z = c.copy()

    # Iterate over each point using vectorized operations
    for i in range(max_iter):
        # Update z values based on the Mandelbrot equation
        z = z**2 + c
        # Update the iteration count for points that have not escaped
        iter_count[(np.abs(z) < 2) & (mandelbrot_set == 0)] = i
        # Mark points that have escaped
        mandelbrot_set[np.abs(z) >= 2] = 1

    # Replace points that never escaped with the maximum iteration count
    iter_count[mandelbrot_set == 0] = max_iter

    return iter_count
%%python
start_time = time.time()
mandelbrot_set = mandelbrot_vectorized(xn, yn)
end_time = time.time()
execution_time = (end_time - start_time) * 1000

make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (vectorized): {execution_time:.0f} ms")
<string>:47: RuntimeWarning: overflow encountered in square
<string>:47: RuntimeWarning: invalid value encountered in square

[Plot: Mandelbrot set, vectorized Python]

Execution time for Python Mandelbrot (vectorized): 239 ms

Python vectorized numba JIT

%%python

@numba.vectorize([numba.uint32(numba.complex128, numba.uint32)], nopython=True)
def mandelbrot_element(c, max_iter):
    z = c
    for i in range(max_iter):
        z = z**2 + c
        if abs(z) > 2:
            return i
    return max_iter

def mandelbrot_vectorized_numba(xn, yn, max_iter=200):
    # Define the boundaries of the complex plane
    xmin = -2.25
    xmax = 0.75
    ymin = -1.25
    ymax = 1.25

    # Create the grid of complex numbers
    x = np.linspace(xmin, xmax, xn)
    y = np.linspace(ymin, ymax, yn)
    c = np.array([[complex(re, im) for re in x] for im in y])

    # Compute the Mandelbrot set element-wise using the vectorized function
    iter_count = mandelbrot_element(c, max_iter)

    return iter_count
%%python
dummy = mandelbrot_vectorized_numba(xn, yn)  # Compile numba first

start_time = time.time()
mandelbrot_set = mandelbrot_vectorized_numba(xn, yn)
end_time = time.time()
execution_time = (end_time - start_time) * 1000

make_plot_python(mandelbrot_set)
print(f"Execution time for Python Mandelbrot (vectorized-numba): {execution_time:.0f} ms")

[Plot: Mandelbrot set, vectorized numba]

Execution time for Python Mandelbrot (vectorized-numba): 93 ms

Cython (can't load extension)

# %%python
# import os
# os.system('pip install cython')
# %load_ext cython
# %%cython

# ## Try with cython
# import numpy as np
# cimport numpy as np

# # Constants
# cdef double xmin = -2.25
# cdef double xmax = 0.75
# cdef int xn = 450
# cdef double ymin = -1.25
# cdef double ymax = 1.25
# cdef int yn = 375
# cdef int max_iter = 200

# # Mandelbrot computation in Cython
# cpdef np.ndarray[np.uint32_t, ndim=2] mandelbrot_cython():
#     cdef double dx = (xmax - xmin) / xn
#     cdef double dy = (ymax - ymin) / yn
#     cdef np.ndarray[np.uint32_t, ndim=2] result = np.zeros((yn, xn), dtype=np.uint32)
#     cdef double x, y, real, imag, abs_val
#     cdef int i, j, k
#     y = ymin
#     for j in range(yn):
#         x = xmin
#         for i in range(xn):
#             real = x
#             imag = y
#             for k in range(max_iter):
#                 abs_val = real * real + imag * imag
#                 if abs_val > 4:
#                     break
#                 real, imag = real * real - imag * imag + x, 2 * real * imag + y
#             result[j, i] = k
#             x += dx
#         y += dy
#     return result
# start_time = time.time()
# mandelbrot_set = mandelbrot_cython()
# end_time = time.time()
# execution_time = (end_time - start_time) * 1000  # Make it milliseconds

# make_plot_python(mandelbrot_set)
# print(f"Execution time for Python Mandelbrot (cython): {execution_time:.0f} ms")

Mandelbrot in Mojo

from Benchmark import Benchmark
from DType import DType
from Memory import memset_zero
from Object import object, Attr
from Pointer import DTypePointer, Pointer
from Random import rand
from Range import range
from TargetInfo import dtype_sizeof
from Time import now
from Complex import ComplexSIMD as ComplexGenericSIMD
struct Matrix:
    var data: DTypePointer[DType.si64]
    var rows: Int
    var cols: Int
    var rc: Pointer[Int]

    fn __init__(self&, cols: Int, rows: Int):
        self.data = DTypePointer[DType.si64].alloc(rows * cols)
        self.rows = rows
        self.cols = cols
        self.rc = Pointer[Int].alloc(1)
        self.rc.store(1)

    fn __copyinit__(self&, other: Self):
        other._inc_rc()
        self.data = other.data
        self.rc   = other.rc
        self.rows = other.rows
        self.cols = other.cols

    fn __del__(owned self):
        self._dec_rc()

    fn _get_rc(self) -> Int:
        return self.rc.load()

    fn _dec_rc(self):
        let rc = self._get_rc()
        if rc > 1:
            self.rc.store(rc - 1)
            return
        self._free()

    fn _inc_rc(self):
        let rc = self._get_rc()
        self.rc.store(rc + 1)

    fn _free(self):
        self.data.free()
        self.rc.free()

    @always_inline
    fn __getitem__(self, col: Int, row: Int) -> SI64:
        return self.load[1](col, row)

    @always_inline
    fn load[nelts:Int](self, col: Int, row: Int) -> SIMD[DType.si64, nelts]:
        return self.data.simd_load[nelts](row * self.cols + col)

    @always_inline
    fn __setitem__(self, col: Int, row: Int, val: SI64):
        return self.store[1](col, row, val)

    @always_inline
    fn store[nelts:Int](self, col: Int, row: Int, val: SIMD[DType.si64, nelts]):
        self.data.simd_store[nelts](row * self.cols + col, val)

    def to_numpy(self) -> PythonObject:
        let np = Python.import_module("numpy")
        let numpy_array = np.zeros((self.rows, self.cols), np.uint32)
        for col in range(self.cols):
            for row in range(self.rows):
                numpy_array.itemset((row, col), self[col, row].cast[DType.f32]())
        return numpy_array
@register_passable("trivial")
struct Complex:
    var real: F32
    var imag: F32

    fn __init__(real: F32, imag: F32) -> Self:
        return Self {real: real, imag: imag}

    fn __add__(lhs, rhs: Self) -> Self:
        return Self(lhs.real + rhs.real, lhs.imag + rhs.imag)

    fn __mul__(lhs, rhs: Self) -> Self:
        return Self(
            lhs.real * rhs.real - lhs.imag * rhs.imag,
            lhs.real * rhs.imag + lhs.imag * rhs.real,
        )

    fn norm(self) -> F32:
        return self.real * self.real + self.imag * self.imag

Then we can write the core Mandelbrot algorithm, which involves computing an iterative complex function for each pixel until it "escapes" the complex circle of radius 2, counting the number of iterations to escape.

$$z_{i+1} = z_i^2 + c$$
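As a plain-Python illustration of this recurrence, here is a minimal sketch of the escape-time loop. The name `escape_count` is made up for this example; it mirrors the Python kernel earlier in this notebook but tests the squared norm against 4, which is equivalent to testing the norm against 2 and avoids a square root.

```python
def escape_count(c: complex, max_iter: int = 200) -> int:
    """Return the iteration at which z leaves the radius-2 disk, or max_iter."""
    z = c
    for i in range(max_iter):
        z = z * z + c
        # |z| > 2 is equivalent to |z|^2 > 4, so no square root is needed
        if z.real * z.real + z.imag * z.imag > 4.0:
            return i
    return max_iter


inside = escape_count(0j)       # z stays at 0 forever, so it never escapes
outside = escape_count(1 + 0j)  # z goes 1 -> 2 -> 5 and leaves the disk at i == 1
print(inside, outside)          # prints "200 1"
```

Points that never escape get the maximum count, which is what produces the solid interior of the set in the plots.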

alias xmin: F32 = -2.25
alias xmax: F32 = 0.75
alias xn = 450
alias ymin: F32 = -1.25
alias ymax: F32 = 1.25
alias yn = 375

# Compute the number of steps to escape.
def mandelbrot_kernel(c: Complex) -> Int:
    max_iter = 200
    z = c
    for i in range(max_iter):
        z = z * z + c
        if z.norm() > 4:
            return i
    return max_iter


def compute_mandelbrot() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    result = Matrix(xn, yn)

    dx = (xmax - xmin) / xn
    dy = (ymax - ymin) / yn

    y = ymin
    for j in range(yn):
        x = xmin
        for i in range(xn):
            result[i, j] = mandelbrot_kernel(Complex(x, y))
            x += dx
        y += dy
    return result

Plotting the number of iterations to escape with some color gives us the canonical Mandelbrot set plot. To render it we can directly leverage Python's matplotlib right from Mojo!

def make_plot(m: Matrix):
    np = Python.import_module("numpy")
    plt = Python.import_module("matplotlib.pyplot")
    colors = Python.import_module("matplotlib.colors")
    dpi = 32
    width = 5
    height = 5 * yn // xn

    fig = plt.figure(1, [width, height], dpi)
    ax = fig.add_axes([0.0, 0.0, 1.0, 1.0], False, 1)

    light = colors.LightSource(315, 10, 0, 1, 1, 0)

    image = light.shade(m.to_numpy(), plt.cm.hot, colors.PowerNorm(0.3), "hsv", 0, 0, 1.5)
    plt.imshow(image)
    plt.axis("off")
    plt.show()
let eval_begin: Int = now()  # This is in nanoseconds
let mandelbrot_set = compute_mandelbrot()
let eval_end: Int = now()
let execution_time = (eval_end - eval_begin) // 1000000

make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot: ', execution_time, 'ms')

[Plot: Mandelbrot set, Mojo]

Execution time for Mojo Mandelbrot:  27 ms

Vectorizing Mandelbrot

We showed a naive implementation of the Mandelbrot algorithm, but there are two things we can do to speed it up. We can stop iterating early once a pixel is known to have escaped, and we can leverage Mojo's access to hardware by vectorizing the loop, computing multiple pixels simultaneously. To do that we will use the vectorize higher-order generator.

We start by defining our main iteration loop in a vectorized fashion:

fn mandelbrot_kernel_simd[simd_width:Int](c: ComplexGenericSIMD[DType.f32, simd_width]) -> SIMD[DType.si64, simd_width]:
    var z = c
    var nv = SIMD[DType.si64, simd_width](0)
    var escape_mask = SIMD[DType.bool, simd_width](0)

    var i = 200
    while i != 0 and not escape_mask:
        z = z*z + c
        # Only update elements that haven't escaped yet
        escape_mask = escape_mask.select(escape_mask, z.norm() > 4)
        nv = escape_mask.select(nv, nv + 1) 
        i -= 1
    
    return nv

The above function is parameterized on simd_width and processes simd_width pixels at a time. Its loop exits only once every pixel in the vector has escaped or the iteration limit is reached. We can use the same iteration loop as above, but this time we vectorize within each row instead. We use the vectorize generator to make this a simple function call.
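For readers more familiar with NumPy, the masking idea in the kernel above can be sketched roughly as follows. This is only an analogy, not the Mojo implementation: `escape_counts_masked` is a hypothetical name, and a NumPy array stands in for the SIMD vector. Unlike the Mojo kernel, this sketch also freezes escaped lanes so their values cannot overflow.

```python
import numpy as np


def escape_counts_masked(c: np.ndarray, max_iter: int = 200) -> np.ndarray:
    """Count iterations for a whole vector of points, one lane per point."""
    z = c.copy()
    counts = np.zeros(c.shape, dtype=np.int64)
    escaped = np.zeros(c.shape, dtype=bool)
    for _ in range(max_iter):
        if escaped.all():
            break  # leave the loop only when every lane has escaped
        # Lanes that already escaped keep their old value
        z = np.where(escaped, z, z * z + c)
        escaped |= (z.real * z.real + z.imag * z.imag) > 4.0
        # Only lanes that are still inside get their counter incremented
        counts += ~escaped
    return counts


points = np.array([0j, 1 + 0j, 2 + 0j, -1 + 0j])
print(escape_counts_masked(points))  # [200   1   0 200]
```

The trade-off is that a vector keeps iterating until its slowest lane finishes, so lanes that escaped early do wasted work.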

from Functional import vectorize
from Math import iota
from TargetInfo import dtype_simd_width


def compute_mandelbrot_simd() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    var result = Matrix(xn, yn)

    let dx = (xmax - xmin) / xn
    let dy = (ymax - ymin) / yn

    var y = ymin
    alias simd_width = dtype_simd_width[DType.f32]()

    for row in range(yn):
        var x = xmin
        @parameter
        fn _process_simd_element[simd_width:Int](col: Int):
            let c = ComplexGenericSIMD[DType.f32, simd_width](dx*iota[simd_width, DType.f32]() + x, 
                                                              SIMD[DType.f32, simd_width](y))
            result.store[simd_width](col, row, mandelbrot_kernel_simd[simd_width](c))
            x += simd_width*dx

        vectorize[simd_width, _process_simd_element](xn)
        y += dy
    return result
let eval_begin: Int = now()
let mandelbrot_set = compute_mandelbrot_simd()
let eval_end: Int = now()
let execution_time = (eval_end - eval_begin) // 1000000

make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot (vectorized): ', execution_time, 'ms')

[Plot: Mandelbrot set, vectorized Mojo]

Execution time for Mojo Mandelbrot (vectorized):  2 ms

Parallelizing Mandelbrot

While the vectorized implementation above is efficient, we can get better performance by parallelizing on the rows. This again is simple in Mojo using the parallelize higher-order function. Only the function that performs the invocation needs to change.

from Functional import parallelize 

def compute_mandelbrot_simd_parallel() -> Matrix:
    # create a matrix. Each element of the matrix corresponds to a pixel
    var result = Matrix(xn, yn)

    let dx = (xmax - xmin) / xn
    let dy = (ymax - ymin) / yn

    alias simd_width = dtype_simd_width[DType.f32]()

    @parameter
    fn _process_row(row:Int):
        var y = ymin + dy*row
        var x = xmin
        @parameter
        fn _process_simd_element[simd_width:Int](col: Int):
            let c = ComplexGenericSIMD[DType.f32, simd_width](dx*iota[simd_width, DType.f32]() + x, 
                                                              SIMD[DType.f32, simd_width](y))
            result.store[simd_width](col, row, mandelbrot_kernel_simd[simd_width](c))
            x += simd_width*dx
            
        vectorize[simd_width, _process_simd_element](xn)

    parallelize[_process_row](yn)
    return result
let eval_begin: Int = now()
let mandelbrot_set = compute_mandelbrot_simd_parallel()
let eval_end: Int = now()
let execution_time = (eval_end - eval_begin) // 1000000

make_plot(mandelbrot_set)
print('Execution time for Mojo Mandelbrot (vectorized-parallelized): ', execution_time, 'ms')

[Plot: Mandelbrot set, vectorized and parallelized Mojo]

Execution time for Mojo Mandelbrot (vectorized-parallelized):  4 ms
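
The row-wise decomposition used above can be sketched in plain Python as well. This is a hedged illustration of the structure only: the names `process_row` and `compute_rows_parallel` are made up for this example, and because of CPython's global interpreter lock a thread pool running pure-Python arithmetic is not expected to reproduce Mojo's speedup. The key point is that y is derived from the row index instead of being accumulated across rows, so each row is an independent task.

```python
from concurrent.futures import ThreadPoolExecutor

XMIN, XMAX, XN = -2.25, 0.75, 450
YMIN, YMAX, YN = -1.25, 1.25, 375
MAX_ITER = 200


def escape_count(c, max_iter=MAX_ITER):
    z = c
    for i in range(max_iter):
        z = z * z + c
        if z.real * z.real + z.imag * z.imag > 4.0:
            return i
    return max_iter


def process_row(row, xn=XN, yn=YN):
    """Compute one image row; it depends only on the row index."""
    dx = (XMAX - XMIN) / xn
    dy = (YMAX - YMIN) / yn
    y = YMIN + dy * row  # derived from the index, not carried over from the previous row
    return [escape_count(complex(XMIN + dx * col, y)) for col in range(xn)]


def compute_rows_parallel(xn=XN, yn=YN, workers=4):
    """Hand each row to a worker; pool.map returns the rows in order."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(lambda r: process_row(r, xn, yn), range(yn)))


image = compute_rows_parallel(xn=45, yn=30)  # small grid to keep the demo fast
print(len(image), len(image[0]))             # prints "30 45"
```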
@Nicholaswogan

As the numba examples are written, it looks like you are including numba JIT compile time in the benchmark, which isn’t a very fair comparison. Is this the case? Or did you re-run the numba example to make sure that you were timing already-compiled code?

@eugeneyan
Author

Ah you're right. Rerunning it again puts non-vectorized numba at 60 ms.

@guillaume-michel

guillaume-michel commented May 12, 2023

To be fair, you should make the exit condition the same for Python and Mojo: either compare the norm of z to 2, or the squared norm of z to 4. Avoiding millions of square roots in the Mojo version is not fair.

Python's floating-point numbers are 64-bit by default, but Mojo is using 32-bit precision, which is again not fair.
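
One way to check both points is a rough sketch like the following. The helper name `escape_count_typed` is made up for this example; it uses the squared-norm test in both cases and only varies the complex dtype. No particular mismatch count is claimed, since it depends on the grid.

```python
import numpy as np


def escape_count_typed(c, dtype, max_iter=200):
    """Escape-time loop carried out in the given NumPy complex dtype."""
    c = dtype(c)
    z = c
    for i in range(max_iter):
        z = dtype(z * z + c)
        # Same exit test for both precisions: squared norm against 4
        if z.real * z.real + z.imag * z.imag > 4:
            return i
    return max_iter


# Bytes per complex value: two 64-bit floats versus two 32-bit floats
print(np.dtype(np.complex128).itemsize, np.dtype(np.complex64).itemsize)  # 16 8

# Coarse grid over the same region as the gist, kept small so it runs quickly
xs = np.linspace(-2.25, 0.75, 15)
ys = np.linspace(-1.25, 1.25, 12)
mismatches = sum(
    escape_count_typed(complex(x, y), np.complex64)
    != escape_count_typed(complex(x, y), np.complex128)
    for y in ys
    for x in xs
)
print("pixels whose count differs between precisions:", mismatches)
```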

@pauljurczak

@guillaume-michel Good points. Here is a single process version with float32 representation and norm squared. It is 43.8 times faster than the initial Python version on my PC.

import numpy as np, numba as nb, timeit as ti

xmin, xmax, xn = -2.25, 0.75, 450
ymin, ymax, yn = -1.25, 1.25, 375
imax = 200


@nb.njit(fastmath=True, locals=dict(x=nb.complex64))
def abs2(x):
  return x.real**2 + x.imag**2


@nb.njit(fastmath=True, locals=dict(c=nb.complex64))
def kernel(c):
    z = c

    for i in range(imax):
        z = z * z + c
        if abs2(z) > 4:
            return i
        
    return imax


@nb.njit(fastmath=True)
def mandelbrot():
    result = np.zeros((yn, xn), dtype=np.uint32)

    for j, y in zip(range(yn), np.arange(ymin, ymax, (ymax-ymin)/yn)):
        for i, x in zip(range(xn), np.arange(xmin, xmax, (xmax-xmin)/xn)):
            result[j, i] = kernel(np.csingle(x+y*1j))
            
    return result


fun = f'mandelbrot()'
t = 1000 * np.array(ti.repeat(stmt=fun, setup=fun, globals=globals(), number=1, repeat=100))
print(f'{fun}:  {np.amin(t):6.3f}ms  {np.median(t):6.3f}ms')

@pauljurczak

A small change allows for Numba parallelization. The speedup is now 266 times:

import numpy as np, numba as nb, timeit as ti

xmin, xmax, xn = -2.25, 0.75, 450
ymin, ymax, yn = -1.25, 1.25, 375
imax = 200


@nb.njit(fastmath=True, locals=dict(x=nb.complex64))
def abs2(x):
  return x.real**2 + x.imag**2


@nb.njit(fastmath=True, locals=dict(c=nb.complex64))
def kernel(c):
  z = c

  for i in range(imax):
    z = z*z + c
    if abs2(z) > 4:
      return i
      
  return imax


@nb.njit(fastmath=True, parallel=True)
def mandelbrot():
  result = np.zeros((yn, xn), dtype=np.uint32)
  dy = (ymax-ymin)/yn

  for j in nb.prange(yn):
    y = ymin+j*dy
    for i, x in enumerate(np.arange(xmin, xmax, (xmax-xmin)/xn)):
      result[j, i] = kernel(np.csingle(x+y*1j))

  return result


if __name__ == "__main__":
  fun = f'mandelbrot()'
  t = 1000 * np.array(ti.repeat(stmt=fun, setup=fun, globals=globals(), number=1, repeat=100))
  print(f'{fun}:  {np.amin(t):6.3f}ms  {np.median(t):6.3f}ms')

@guillaume-michel

@pauljurczak very nice code indeed! To be fair, we should keep fastmath=False, as I don't think Mojo is using fastmath by default. That would be a terrible default setup, but it works great in this case.

I observe a 237x speedup between the original version and your last version, which is consistent with your observations. That means there is still room for improvement to get to Mojo's level (633x speedup from the original).

Mojo is using SIMD, but I am not sure numba is able to auto-vectorize the early exit in kernel. You will have to do it yourself like Mojo did, using a mask and iterating until all elements in the SIMD vector have escaped. That should bring you into the same ballpark as Mojo.

To conclude, I am not saying Mojo is bad or anything. I am just trying to understand what it is good for and how MLIR can be leveraged to good use.
Using an embarrassingly parallel algorithm like Mandelbrot is a very bad way to showcase supremacy, as it is the kind of problem the competition is already very good at.
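
For reference, the 633x figure matches the ratio of the timings printed in the gist (1266 ms for pure Python, 2 ms for vectorized Mojo). A quick sketch of that arithmetic, using only the numbers reported above; they are whole-millisecond, single-run measurements, so the ratios are coarse, and speedups measured on other machines are not directly comparable:

```python
# Execution times in milliseconds, copied from the gist's printed output
timings_ms = {
    "Python": 1266,
    "Python (numba)": 60,
    "Python (vectorized)": 239,
    "Python (vectorized-numba)": 93,
    "Mojo": 27,
    "Mojo (vectorized)": 2,
    "Mojo (vectorized-parallelized)": 4,
}

baseline = timings_ms["Python"]
speedups = {name: baseline / ms for name, ms in timings_ms.items()}

for name, factor in speedups.items():
    print(f"{name:32s} {factor:7.1f}x")
```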

@pauljurczak

@guillaume-michel I will push back on the fastmath issue. I tried both fastmath=False and fastmath=True and the results are identical, as expected. Bear in mind that we are calculating with values in a fairly close neighborhood of 1.0.

@pauljurczak

pauljurczak commented May 15, 2023

Using Numba with CUDA on a mediocre GPU (NVIDIA GeForce RTX 3050) achieves 2662x speedup over the initial version:

import numpy as np, timeit as ti
from numba import cuda

xmin, xmax, xn = -2.25, 0.75, 450
ymin, ymax, yn = -1.25, 1.25, 375
dx = (xmax-xmin)/xn
dy = (ymax-ymin)/yn
imax = 200


@cuda.jit
def kernel(a):
  y, x = cuda.grid(2)

  if y < a.shape[0] and x < a.shape[1]:
    c = np.csingle(xmin+x*dx + (ymin+y*dy)*1j)
    z = c
    a[y, x] = imax

    for i in range(imax):
      z = z*z + c

      if z.real**2 + z.imag**2 > 4:
        a[y, x] = i
        break


def mandelbrot():
  a = cuda.to_device(np.empty((yn, xn), dtype='u1'))
  tpb = (16, 16)
  bpg = (a.shape[0]//tpb[0] + 1, a.shape[1]//tpb[1] + 1)  # one block count per axis, each using its own block size
  kernel[bpg, tpb](a)

  return a.copy_to_host()


if __name__ == "__main__":
  fun = f'mandelbrot()'
  t = 1000 * np.array(ti.repeat(stmt=fun, setup=fun, globals=globals(), number=1, repeat=1000))
  print(f'{fun}:  {np.amin(t):6.3f}ms  {np.median(t):6.3f}ms')
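
A side note on the launch configuration: the blocks-per-grid computation in mandelbrot() above can also be written as a ceiling division, which avoids launching an extra block when a dimension is an exact multiple of the block size. A minimal sketch, with `blocks_per_grid` being a made-up helper name that needs no GPU to run:

```python
def blocks_per_grid(shape, threads_per_block):
    """Smallest block count per axis whose threads cover every array element."""
    return tuple((n + t - 1) // t for n, t in zip(shape, threads_per_block))


print(blocks_per_grid((375, 450), (16, 16)))  # (24, 29)
print(blocks_per_grid((32, 32), (16, 16)))    # (2, 2)
```

For the 375 x 450 image both forms give (24, 29). The extra block that `n // t + 1` launches on exact multiples is harmless here because the kernel checks array bounds before writing.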
