Skip to content

Instantly share code, notes, and snippets.

@itamarst
Last active December 15, 2023 14:49
Show Gist options
  • Save itamarst/32bfbd4550d22cfa0bb47b522c5a9f78 to your computer and use it in GitHub Desktop.
Save itamarst/32bfbd4550d22cfa0bb47b522c5a9f78 to your computer and use it in GitHub Desktop.
Optimized Numba kernel
@numba.jit(nopython=True, parallel=True, fastmath=True)
def spatial_regularisation_numbaoptimized(x: np.ndarray, dy: int = 1, dx: int = 1) -> float:
    """Return the local-smoothness penalty of a 2-D array.

    For every element ``x[i, j]`` this sums the squared difference to each
    element of the window ``x[i-dy : i+dy+1, j-dx : j+dx+1]``, clipped at the
    array edges. The element itself is inside its window but contributes zero.
    The window is symmetric, so each unordered pair of elements is visited
    twice. The total is therefore halved.

    Parameters
    ----------
    x : np.ndarray
        Two-dimensional input array (its shape is unpacked as ``ny, nx``).
    dy : int, default 1
        Half-height of the neighbourhood window along axis 0.
    dx : int, default 1
        Half-width of the neighbourhood window along axis 1.

    Returns
    -------
    float
        ``0.5 * sum((x[i, j] - x[m, n]) ** 2)`` over all ``(i, j)`` and all
        ``(m, n)`` in the clipped window around ``(i, j)``.

    Notes
    -----
    The function is compiled with ``fastmath=True`` and ``parallel=True``.
    Rows are summed in parallel and then combined as a reduction. Floating-point
    addition may therefore be reassociated, so the result can differ in the
    last bits from a strictly sequential sum.
    """
    ny, nx = x.shape
    total_cost = 0.0
    # Only this outermost loop is parallelised. Numba treats ``total_cost += ...``
    # inside it as a parallel reduction.
    for i in numba.prange(ny):
        # The prange index may be typed as unsigned. Cast it to a signed int so
        # that ``row - dy`` stays an integer (uint64 - int64 would promote to
        # float) and can be used as a ``range`` bound.
        row = numba.int64(i)
        min_y = max(row - dy, 0)
        max_y = min(row + dy + 1, ny)
        # Numba does not parallelise nested prange loops, so a plain ``range``
        # states what actually happens. Its index is already a signed int64.
        for col in range(nx):
            min_x = max(col - dx, 0)
            max_x = min(col + dx + 1, nx)
            # Loop-invariant for the two window loops below.
            center = x[row, col]
            for m in range(min_y, max_y):
                for n in range(min_x, max_x):
                    total_cost += (center - x[m, n]) ** 2
    return 0.5 * total_cost
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment