Source code for ndfilters._generic

from typing import Callable, Literal
import numpy as np
import numba
import astropy.units as u
from ._indices import (
    rectify_index_lower,
    rectify_index_upper,
)

__all__ = [
    "generic_filter",
]


[docs] def generic_filter( array: np.ndarray | u.Quantity, function: Callable[[np.ndarray, tuple], float], size: int | tuple[int, ...], axis: None | int | tuple[int, ...] = None, where: bool | np.ndarray = True, mode: Literal["mirror", "nearest", "wrap", "truncate"] = "mirror", args: tuple = (), ) -> np.ndarray: """ Filter a multidimensional array using an arbitrary compiled function. Parameters ---------- array The input array to be filtered function The function to applied to each kernel footprint. This is usually either a Numpy reduction function like :func:`numpy.mean`, or a function compiled using :func:`numba.njit`. This function must accept a 1D array and a tuple of extra arguments as input and return a scalar. size The shape of the kernel over which the trimmed mean will be calculated. axis The axes over which to apply the kernel. Should either be a scalar or have the same number of items as `size`. If :obj:`None` (the default) the kernel spans every axis of the array. where An optional mask that can be used to exclude parts of the array during filtering. mode The method used to extend the input array beyond its boundaries. See :func:`scipy.ndimage.generic_filter` for the definitions. Currently, only "mirror", "nearest", "wrap", and "truncate" modes are supported. args Extra arguments to pass to function. Examples -------- .. jupyter-execute:: import numpy as np import numba import matplotlib.pyplot as plt import scipy.datasets import ndfilters # Download a sample image img = scipy.datasets.ascent() # Define a compiled function to apply at every # kernel footprint. @numba.njit def function(a: np.ndarray, args: tuple) -> float: return np.mean(a) # Filter the image using an arbitrary function. img_filtered = ndfilters.generic_filter( function=function, array=img, size=21, ) fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True) axs[0].set_title("original image"); axs[0].imshow(img, cmap="gray"); axs[1].set_title("filtered image"); axs[1].imshow(img_filtered, cmap="gray"); """ if isinstance(array, u.Quantity): unit = array.unit array = array.value else: unit = None if axis is None: axis = tuple(range(array.ndim)) axis = np.lib.array_utils.normalize_axis_tuple(axis, ndim=array.ndim) if isinstance(size, int): size = (size,) * len(axis) else: if len(size) != len(axis): raise ValueError( f"{size=} should have the same number of elements as {axis=}." ) axis_numba = ~np.arange(len(axis))[::-1] shape = array.shape shape_numba = tuple(shape[ax] for ax in axis) where = np.broadcast_to(where, shape) array_ = np.moveaxis(array, axis, axis_numba) where_ = np.moveaxis(where, axis, axis_numba) if len(axis) == 1: _generic_filter_nd = _generic_filter_1d elif len(axis) == 2: _generic_filter_nd = _generic_filter_2d elif len(axis) == 3: _generic_filter_nd = _generic_filter_3d else: # pragma: nocover raise ValueError(f"Only 1-3 axes supported, got {axis=}.") result = _generic_filter_nd( array=array_.reshape(-1, *shape_numba), function=function, size=size, where=where_.reshape(-1, *shape_numba), mode=mode, args=args, ) result = result.reshape(array_.shape) result = np.moveaxis(result, axis_numba, axis) if unit is not None: result = result << unit return result
@numba.njit(parallel=True, cache=True) def _generic_filter_1d( array: np.ndarray, function: Callable[[np.ndarray, tuple], float], size: tuple[int], where: np.ndarray, mode: str, args: tuple, ): result = np.empty_like(array) array_shape_t, array_shape_x = array.shape (kernel_shape_x,) = size for it in range(array_shape_t): for ix in numba.prange(array_shape_x): values = np.zeros(shape=size) mask = np.zeros(shape=size, dtype=np.bool_) for kx in range(kernel_shape_x): px = kx - kernel_shape_x // 2 jx = ix + px if jx < 0: if mode == "truncate": continue jx = rectify_index_lower(jx, array_shape_x, mode) elif jx >= array_shape_x: if mode == "truncate": continue jx = rectify_index_upper(jx, array_shape_x, mode) values[kx] = array[it, jx] mask[kx] = where[it, jx] if np.any(mask): result[it, ix] = function(values[mask], args) else: result[it, ix] = np.nan return result @numba.njit(parallel=True, cache=True) def _generic_filter_2d( array: np.ndarray, function: Callable[[np.ndarray, tuple], float], size: tuple[int, int], where: np.ndarray, mode: str, args: tuple, ): result = np.empty_like(array) array_shape_t, array_shape_x, array_shape_y = array.shape kernel_shape_x, kernel_shape_y = size for it in range(array_shape_t): for ix in numba.prange(array_shape_x): for iy in numba.prange(array_shape_y): values = np.zeros(shape=size) mask = np.zeros(shape=size, dtype=np.bool_) for kx in range(kernel_shape_x): px = kx - kernel_shape_x // 2 jx = ix + px if jx < 0: if mode == "truncate": continue jx = rectify_index_lower(jx, array_shape_x, mode) elif jx >= array_shape_x: if mode == "truncate": continue jx = rectify_index_upper(jx, array_shape_x, mode) for ky in range(kernel_shape_y): py = ky - kernel_shape_y // 2 jy = iy + py if jy < 0: if mode == "truncate": continue jy = rectify_index_lower(jy, array_shape_y, mode) elif jy >= array_shape_y: if mode == "truncate": continue jy = rectify_index_upper(jy, array_shape_y, mode) values[kx, ky] = array[it, jx, jy] mask[kx, ky] = where[it, jx, jy] values = values.reshape(-1) mask = mask.reshape(-1) if np.any(mask): result[it, ix, iy] = function(values[mask], args) else: result[it, ix, iy] = np.nan return result @numba.njit(parallel=True, cache=True) def _generic_filter_3d( array: np.ndarray, function: Callable[[np.ndarray, tuple], float], size: tuple[int, int, int], where: np.ndarray, mode: str, args: tuple, ): result = np.empty_like(array) array_shape_t, array_shape_x, array_shape_y, array_shape_z = array.shape kernel_shape_x, kernel_shape_y, kernel_shape_z = size for it in range(array_shape_t): for ix in numba.prange(array_shape_x): for iy in numba.prange(array_shape_y): for iz in numba.prange(array_shape_z): values = np.zeros(shape=size) mask = np.zeros(shape=size, dtype=np.bool_) for kx in range(kernel_shape_x): px = kx - kernel_shape_x // 2 jx = ix + px if jx < 0: if mode == "truncate": continue jx = rectify_index_lower(jx, array_shape_x, mode) elif jx >= array_shape_x: if mode == "truncate": continue jx = rectify_index_upper(jx, array_shape_x, mode) for ky in range(kernel_shape_y): py = ky - kernel_shape_y // 2 jy = iy + py if jy < 0: if mode == "truncate": continue jy = rectify_index_lower(jy, array_shape_y, mode) elif jy >= array_shape_y: if mode == "truncate": continue jy = rectify_index_upper(jy, array_shape_y, mode) for kz in range(kernel_shape_z): pz = kz - kernel_shape_z // 2 jz = iz + pz if jz < 0: if mode == "truncate": continue jz = rectify_index_lower(jz, array_shape_z, mode) elif jz >= array_shape_z: if mode == "truncate": continue jz = rectify_index_upper(jz, array_shape_z, mode) values[kx, ky, kz] = array[it, jx, jy, jz] mask[kx, ky, kz] = where[it, jx, jy, jz] values = values.reshape(-1) mask = mask.reshape(-1) if np.any(mask): result[it, ix, iy, iz] = function(values[mask], args) else: result[it, ix, iy, iz] = np.nan return result