from typing import Literal
import numpy as np
import numba
import astropy.units as u
from ._indices import (
rectify_index_lower,
rectify_index_upper,
)
__all__ = [
"convolve",
]
[docs]
def convolve(
array: np.ndarray | u.Quantity,
kernel: np.ndarray | u.Quantity,
axis: None | int | tuple[int, ...] = None,
where: bool | np.ndarray = True,
mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror",
) -> np.ndarray:
"""
Multidimensional convolution of an array with a given kernel.
This function differs from :func:`scipy.ndimage.convolve` and
:func:`astropy.convolution.convolve` because it implements a vectorized
convolution operation where the kernel is allowed to vary along axes
orthogonal to the convolution axes.
Parameters
----------
array
The input array to be convolved.
kernel
The convolution kernel.
Any non-convolution axes must be broadcastable with `array`.
axis
The axes of `array` over which to apply the kernel.
If :obj:`None`, it is assumed that the convolution is applied to all
the axes of `array`.
where
An optional mask that can be used to exclude elements of `array`
during the convolution.
mode
The method used to extend `array` beyond its boundaries.
Examples
--------
Apply a Gaussian blur to a sample image.
.. jupyter-execute::
import numpy as np
import matplotlib.pyplot as plt
import scipy
import ndfilters
# Define arbitrary coordinate system
# of the kenrel
x = np.linspace(-1, 1, 51)
y = np.linspace(-1, 1, 51)
x, y = np.meshgrid(x, y)
# Rotate the coordinate system
t = np.pi / 4
u = x * np.cos(t) - y * np.sin(t)
v = x * np.sin(t) + y * np.cos(t)
# Define the standard deviation
# in each dimension of the 2D kernel
sigma_u = 0.5
sigma_v = 0.1
# Compute a 2D Gaussian kernel
kernel_u = np.exp(-np.square(u / sigma_u) / 2)
kernel_v = np.exp(-np.square(v / sigma_v) / 2)
kernel = kernel_u * kernel_v
kernel = kernel / kernel.sum()
# Download a sample image
img = scipy.datasets.ascent()
# Convolve the sample image with the kernel
img_convolved = ndfilters.convolve(
array=img,
kernel=kernel,
)
# Plot the results
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True)
axs[0].set_title("original image");
axs[0].imshow(img, cmap="gray");
axs[1].set_title("convolved image");
axs[1].imshow(img_convolved, cmap="gray");
"""
if isinstance(array, u.Quantity):
unit = array.unit
array = array.value
else:
unit = None
if axis is None:
axis = tuple(range(array.ndim))
axis = np.array(axis)
axis = np.lib.array_utils.normalize_axis_tuple(~axis, ndim=array.ndim)
axis = ~np.array(axis)
shape_kernel = list(kernel.shape)
for ax in axis:
shape_kernel[ax] = 1
shape = np.broadcast_shapes(array.shape, shape_kernel, np.shape(where))
shape_kernel = list(shape)
for ax in axis:
shape_kernel[ax] = kernel.shape[ax]
array = np.broadcast_to(array, shape)
kernel = np.broadcast_to(kernel, shape_kernel)
where = np.broadcast_to(where, shape)
axis_numba = ~np.arange(len(axis))[::-1]
shape_numba = tuple(shape[ax] for ax in axis)
shape_kernel_numba = tuple(shape_kernel[ax] for ax in axis)
array_ = np.moveaxis(array, axis, axis_numba)
kernel_ = np.moveaxis(kernel, axis, axis_numba)
where_ = np.moveaxis(where, axis, axis_numba)
if len(axis) == 1:
_convolve_nd = _convolve_1d
elif len(axis) == 2:
_convolve_nd = _convolve_2d
elif len(axis) == 3:
_convolve_nd = _convolve_3d
else: # pragma: nocover
raise ValueError(f"Only 1-3 axes supported, got {axis=}.")
result = _convolve_nd(
array=array_.reshape(-1, *shape_numba),
kernel=kernel_.reshape(-1, *shape_kernel_numba),
where=where_.reshape(-1, *shape_numba),
mode=mode,
)
result = result.reshape(array_.shape)
result = np.moveaxis(result, axis_numba, axis)
if unit is not None:
result = result << unit
return result
@numba.njit(parallel=True, cache=True)
def _convolve_1d(
array: np.ndarray,
kernel: np.ndarray,
where: np.ndarray,
mode: str,
):
result = np.zeros_like(array)
array_shape_t, array_shape_x = array.shape
_, kernel_shape_x = kernel.shape
for it in range(array_shape_t):
for ix in numba.prange(array_shape_x):
r = 0
for kx in range(kernel_shape_x):
px = kx - (kernel_shape_x - 1) // 2
jx = ix + px
if jx < 0:
if mode == "truncate":
continue
jx = rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
if mode == "truncate":
continue
jx = rectify_index_upper(jx, array_shape_x, mode)
if where[it, jx]:
array_tx = array[it, jx]
kernel_tx = kernel[it, ~kx]
r += array_tx * kernel_tx
result[it, ix] = r
return result
@numba.njit(parallel=True, cache=True)
def _convolve_2d(
array: np.ndarray,
kernel: np.ndarray,
where: np.ndarray,
mode: str,
):
result = np.empty_like(array)
array_shape_t, array_shape_x, array_shape_y = array.shape
_, kernel_shape_x, kernel_shape_y = kernel.shape
for it in range(array_shape_t):
for ix in numba.prange(array_shape_x):
for iy in numba.prange(array_shape_y):
r = 0
for kx in range(kernel_shape_x):
px = kx - (kernel_shape_x - 1) // 2
jx = ix + px
if jx < 0:
if mode == "truncate":
continue
jx = rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
if mode == "truncate":
continue
jx = rectify_index_upper(jx, array_shape_x, mode)
for ky in range(kernel_shape_y):
py = ky - (kernel_shape_y - 1) // 2
jy = iy + py
if jy < 0:
if mode == "truncate":
continue
jy = rectify_index_lower(jy, array_shape_y, mode)
elif jy >= array_shape_y:
if mode == "truncate":
continue
jy = rectify_index_upper(jy, array_shape_y, mode)
if where[it, jx, jy]:
array_txy = array[it, jx, jy]
kernel_txy = kernel[it, ~kx, ~ky]
r += array_txy * kernel_txy
result[it, ix, iy] = r
return result
@numba.njit(parallel=True, cache=True)
def _convolve_3d(
array: np.ndarray,
kernel: np.ndarray,
where: np.ndarray,
mode: str,
):
result = np.empty_like(array)
array_shape_t, array_shape_x, array_shape_y, array_shape_z = array.shape
_, kernel_shape_x, kernel_shape_y, kernel_shape_z = kernel.shape
for it in range(array_shape_t):
for ix in numba.prange(array_shape_x):
for iy in numba.prange(array_shape_y):
for iz in numba.prange(array_shape_z):
r = 0
for kx in range(kernel_shape_x):
px = kx - (kernel_shape_x - 1) // 2
jx = ix + px
if jx < 0:
if mode == "truncate":
continue
jx = rectify_index_lower(jx, array_shape_x, mode)
elif jx >= array_shape_x:
if mode == "truncate":
continue
jx = rectify_index_upper(jx, array_shape_x, mode)
for ky in range(kernel_shape_y):
py = ky - (kernel_shape_y - 1) // 2
jy = iy + py
if jy < 0:
if mode == "truncate":
continue
jy = rectify_index_lower(jy, array_shape_y, mode)
elif jy >= array_shape_y:
if mode == "truncate":
continue
jy = rectify_index_upper(jy, array_shape_y, mode)
for kz in range(kernel_shape_z):
pz = kz - (kernel_shape_z - 1) // 2
jz = iz + pz
if jz < 0:
if mode == "truncate":
continue
jz = rectify_index_lower(jz, array_shape_z, mode)
elif jz >= array_shape_z:
if mode == "truncate":
continue
jz = rectify_index_upper(jz, array_shape_z, mode)
if where[it, jx, jy, jz]:
array_txyz = array[it, jx, jy, jz]
kernel_txyz = kernel[it, ~kx, ~ky, ~kz]
r += array_txyz * kernel_txyz
result[it, ix, iy, iz] = r
return result