Source code for myoverse.datasets.filters.generic

import warnings
from functools import partial
from typing import Any, Literal

import numpy as np
from numpy.lib.stride_tricks import as_strided

from myoverse.datasets.filters._template import FilterBaseClass


def _get_windows_with_shift(
    input_array: np.ndarray, window_size: int, shift: int
) -> np.ndarray:
    """Create windows of specified size and shift from input array using strided operations.

    Parameters
    ----------
    input_array : numpy.ndarray
        The input array to window.
    window_size : int
        Size of each window.
    shift : int
        Number of samples to shift between consecutive windows.

    Returns
    -------
    numpy.ndarray
        Array of windows with shape (n_windows, *input_array.shape[:-1], window_size)
        where n_windows = (input_array.shape[-1] - window_size) // shift + 1

    Notes
    -----
    This function uses numpy's as_strided function for efficient windowing without
    creating copies of the data. The returned array is read-only (writeable=False).
    """
    # Calculate how many windows we'll have
    n_windows = (input_array.shape[-1] - window_size) // shift + 1

    # Calculate new shape with windows
    # Original dimensions (except the last) + number of windows + window size
    window_shape = (*input_array.shape[:-1], n_windows, window_size)

    # Calculate strides for the windowed view
    # Original strides + stride for windows + original stride for last dimension
    original_strides = input_array.strides
    window_strides = (
        *original_strides[:-1],
        shift * original_strides[-1],  # Step between windows
        original_strides[-1],
    )  # Step within a window

    # Create windowed view using as_strided
    windows = as_strided(
        input_array, shape=window_shape, strides=window_strides, writeable=False
    )

    # Transpose the windows dimension to be the first dimension
    # The final shape will be (n_windows, *input_array.shape[:-1], window_size)
    transpose_axes = (
        len(window_shape) - 2,
        *range(len(window_shape) - 2),
        len(window_shape) - 1,
    )
    return np.transpose(windows, transpose_axes)


class ApplyFunctionFilter(FilterBaseClass):
    """Filter that applies a function to the input array.

    This filter provides a flexible way to apply any function to the input
    data array. The function can be a simple lambda, a NumPy function, or any
    custom function that operates on numpy arrays.

    Parameters
    ----------
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool, optional
        Whether this is an output filter, by default False.
    name : str, optional
        Name of the filter, by default None.
    run_checks : bool
        Whether to run the checks when filtering. By default, True.
        If False can potentially speed up performance.

        .. warning:: If False, the user is responsible for ensuring that
            the input array is valid.
    function : callable, optional
        The function to apply to the input array, by default None.
    **function_kwargs
        Additional keyword arguments to pass to the function.

    Raises
    ------
    TypeError
        If ``function`` is not callable.

    Examples
    --------
    >>> import numpy as np
    >>> from myoverse.datasets.filters.generic import ApplyFunctionFilter
    >>> data = np.random.rand(10, 500)
    >>> # Apply absolute value function
    >>> abs_filter = ApplyFunctionFilter(function=np.abs, input_is_chunked=False)
    >>> abs_data = abs_filter(data)
    >>> # Apply mean function with axis parameter
    >>> mean_filter = ApplyFunctionFilter(function=np.mean, axis=-1, input_is_chunked=False)
    >>> mean_data = mean_filter(data)
    >>> # Apply custom function
    >>> custom_filter = ApplyFunctionFilter(function=lambda x: x**2 - x, input_is_chunked=False)
    >>> custom_data = custom_filter(data)

    Notes
    -----
    The function is applied directly to the input array without any
    pre-processing. The output shape depends on the function being applied:
    for example, ``np.mean`` with ``axis=-1`` will reduce the last dimension.

    See Also
    --------
    IdentityFilter : A filter that returns the input unchanged
    """

    def __init__(
        self,
        input_is_chunked: bool,
        is_output: bool = False,
        name: str | None = None,
        run_checks: bool = True,
        *,
        function: callable = None,
        **function_kwargs,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
            run_checks=run_checks,
        )

        # Validate BEFORE wrapping in partial: partial() always produces a
        # callable object, so checking self.function after wrapping (as done
        # previously in _run_filter_checks) could never fail, while
        # partial(None) raised an unhelpful TypeError here anyway.
        if not callable(function):
            raise TypeError(
                f"The provided function must be callable, but got {type(function)}"
            )
        self.function = partial(function, **function_kwargs)

    def _run_filter_checks(self, input_array: np.ndarray):
        """Run checks on the input array.

        Parameters
        ----------
        input_array : np.ndarray
            The input array to check.

        Raises
        ------
        ValueError
            If multiple inputs are provided (a list instead of a single array).
        """
        if isinstance(input_array, list):
            raise ValueError(
                f"ApplyFunctionFilter expects a single numpy array, but received a list of {len(input_array)} arrays. "
                f"This filter can only process one input at a time."
            )

    def _filter(self, input_array: np.ndarray, **kwargs) -> np.ndarray:
        """Apply the function to the input array.

        Parameters
        ----------
        input_array : np.ndarray
            The input array to apply the function to.
        **kwargs
            Additional keyword arguments from the Data object (unused).

        Returns
        -------
        np.ndarray
            The result of applying the function.
        """
        return self.function(input_array)
class IndexDataFilter(FilterBaseClass):
    """Filter that indexes the input array using NumPy-style indexing.

    This filter provides a flexible way to select specific elements or slices
    from the input array using NumPy's indexing syntax. It supports basic
    slicing, integer array indexing, boolean masks, ellipsis, and advanced
    indexing.

    Parameters
    ----------
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool, optional
        Whether the filter is an output filter. If True, the resulting signal
        will be outputted by any dataset pipeline, by default False.
    name : str, optional
        The name of the filter, by default None.
    run_checks : bool
        Whether to run the checks when filtering. By default, True.
        If False can potentially speed up performance.

        .. warning:: If False, the user is responsible for ensuring that
            the input array is valid.
    indices : any valid NumPy index
        The indices to use for indexing the input array. Can be:

        - Single index: 0, -1
        - Slice: slice(0, 10), slice(None, None, 2)
        - Tuple of indices for multiple dimensions: (0, slice(None), [1, 2, 3])
        - Ellipsis: ... or Ellipsis
        - Boolean mask: array([True, False, True])
        - Integer arrays for fancy indexing: np.array([0, 2, 4])
        - Any combination of the above

    Examples
    --------
    >>> import numpy as np
    >>> from myoverse.datasets.filters.generic import IndexDataFilter
    >>> data = np.random.rand(5, 10, 100)
    >>>
    >>> # Select first element of first dimension
    >>> filter_first = IndexDataFilter(indices=0, input_is_chunked=False)
    >>> output = filter_first(data)  # shape: (10, 100)
    >>>
    >>> # Select first three elements of last dimension
    >>> filter_slice = IndexDataFilter(
    ...     indices=(slice(None), slice(None), slice(0, 3)), input_is_chunked=False
    ... )
    >>> output = filter_slice(data)  # shape: (5, 10, 3)
    >>>
    >>> # Select specific elements with fancy indexing
    >>> filter_fancy = IndexDataFilter(
    ...     indices=([0, 2], slice(None), [0, 50, 99]), input_is_chunked=False
    ... )
    >>> output = filter_fancy(data)  # shape: (2, 10, 3)
    >>>
    >>> # Use ellipsis to simplify indexing (equivalent to the above)
    >>> filter_ellipsis = IndexDataFilter(
    ...     indices=([0, 2], ..., [0, 50, 99]), input_is_chunked=False
    ... )
    >>> output = filter_ellipsis(data)  # shape: (2, 10, 3)

    Notes
    -----
    This filter directly passes the provided indices to NumPy's indexing
    system. The behavior will match exactly what you would expect from
    numpy.ndarray indexing.
    """

    def __init__(
        self,
        input_is_chunked: bool,
        is_output: bool = False,
        name: str | None = None,
        run_checks: bool = True,
        *,
        # typing.Any, not the builtin function any(), which the previous
        # annotation mistakenly referenced.
        indices: Any = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
            run_checks=run_checks,
        )
        self.indices = indices

    def _filter(self, input_array: np.ndarray, **kwargs) -> np.ndarray:
        """Apply the indices to the input array.

        Parameters
        ----------
        input_array : np.ndarray
            The input array to index.
        **kwargs
            Additional keyword arguments from the Data object (unused).

        Returns
        -------
        np.ndarray
            The indexed array.
        """
        return input_array[self.indices]
class ChunkizeDataFilter(FilterBaseClass):
    """Filter that chunks the input array into overlapping or non-overlapping segments.

    This filter divides a continuous signal into chunks along the last
    dimension. It's useful for preparing data for window-based analysis or
    for applying sliding window techniques.

    Parameters
    ----------
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool, optional
        Whether the filter is an output filter. If True, the resulting signal
        will be outputted by any dataset pipeline, by default False.
    name : str, optional
        The name of the filter, by default None.
    run_checks : bool
        Whether to run the checks when filtering. By default, True.
        If False can potentially speed up performance.

        .. warning:: If False, the user is responsible for ensuring that
            the input array is valid.
    chunk_size : int
        The size of each chunk along the last dimension.
    chunk_shift : int, optional
        The shift between consecutive chunks. If provided, chunk_overlap is
        ignored. A small shift creates more overlapping chunks.
    chunk_overlap : int, optional
        The overlap between consecutive chunks. If provided, chunk_shift is
        ignored. Overlap = chunk_size - chunk_shift.

    Raises
    ------
    ValueError
        If input_is_chunked is True (this filter only accepts unchunked input).
    ValueError
        If chunk_size is not specified or is less than 1.
    ValueError
        If neither chunk_shift nor chunk_overlap is specified.
    ValueError
        If chunk_shift is less than 1.
    ValueError
        If chunk_overlap is less than 0 or not less than chunk_size.

    Examples
    --------
    >>> import numpy as np
    >>> from myoverse.datasets.filters.generic import ChunkizeDataFilter
    >>> data = np.random.rand(10, 1000)
    >>> # Create non-overlapping chunks
    >>> no_overlap = ChunkizeDataFilter(
    ...     chunk_size=100,
    ...     chunk_shift=100,
    ...     input_is_chunked=False
    ... )
    >>> chunked_data = no_overlap(data)  # shape: (10, 10, 100)
    >>> # Create overlapping chunks
    >>> with_overlap = ChunkizeDataFilter(
    ...     chunk_size=100,
    ...     chunk_overlap=50,
    ...     input_is_chunked=False
    ... )
    >>> overlapped_data = with_overlap(data)  # shape: (19, 10, 100)

    Notes
    -----
    The output shape will be (n_chunks, *original_dims, chunk_size), where
    n_chunks = (input_length - chunk_size) // chunk_shift + 1 or
    n_chunks = (input_length - chunk_size) // (chunk_size - chunk_overlap) + 1

    When both chunk_shift and chunk_overlap are provided, chunk_shift takes
    precedence.

    See Also
    --------
    _get_windows_with_shift : Efficient windowing function used in temporal filters
    """

    def __init__(
        self,
        input_is_chunked: bool,
        is_output: bool = False,
        name: str | None = None,
        run_checks: bool = True,
        *,
        chunk_size: int = None,
        chunk_shift: int = None,
        chunk_overlap: int = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="not chunked",
            is_output=is_output,
            name=name,
            run_checks=run_checks,
        )

        if input_is_chunked:
            raise ValueError("This filter only accepts unchunked input.")

        self.chunk_size = chunk_size
        self.chunk_shift = chunk_shift
        self.chunk_overlap = chunk_overlap

        if self.chunk_size is None:
            raise ValueError("chunk_size must be specified.")
        if self.chunk_size < 1:
            raise ValueError("chunk_size must be greater than 0.")

        if self.chunk_shift is None and self.chunk_overlap is None:
            raise ValueError("Either chunk_shift or chunk_overlap must be specified.")

        if self.chunk_shift is not None:
            if self.chunk_shift < 1:
                raise ValueError("chunk_shift must be greater than 0.")
            # A shift strictly larger than the chunk size leaves gaps between
            # chunks; a shift EQUAL to the chunk size is the ordinary
            # contiguous case and must not warn (the previous >= condition
            # flagged the class's own non-overlapping example).
            if self.chunk_shift > self.chunk_size:
                warnings.warn(
                    "chunk_shift is greater than or equal to chunk_size. "
                    "Some parts of the data will be skipped. Be sure this is intended."
                )

        if self.chunk_overlap is not None:
            if self.chunk_overlap < 0:
                raise ValueError("chunk_overlap must be greater than or equal to 0.")
            # overlap == chunk_size would make the effective shift zero and
            # crash _filter with a division by zero, so it must be rejected.
            if self.chunk_overlap >= self.chunk_size:
                raise ValueError("chunk_overlap must be less than chunk_size.")

    def _filter(self, input_array: np.ndarray, **kwargs) -> np.ndarray:
        """Chunk the input array into overlapping segments.

        Parameters
        ----------
        input_array : np.ndarray
            The input array to chunk.
        **kwargs
            Additional keyword arguments from the Data object (unused).

        Returns
        -------
        np.ndarray
            The chunked array with shape (n_chunks, *original_shape, chunk_size).
        """
        # chunk_shift takes precedence; otherwise derive the shift from the
        # requested overlap (shift = chunk_size - chunk_overlap).
        return _get_windows_with_shift(
            input_array,
            self.chunk_size,
            self.chunk_shift
            if self.chunk_shift is not None
            else self.chunk_size - self.chunk_overlap,
        )
class IdentityFilter(FilterBaseClass):
    """Pass-through filter: its output is exactly its input.

    Handy as a placeholder inside filter sequences and as a baseline when
    testing pipelines.

    Parameters
    ----------
    input_is_chunked : bool
        Whether the input is chunked or not. This must be explicitly set by
        the user as they have the best knowledge of their data structure.
    is_output : bool, optional
        Whether this is an output filter, by default False.
    name : str, optional
        Name of the filter, by default None.
    run_checks : bool
        Whether to run the checks when filtering. By default, True.
        If False can potentially speed up performance.

        .. warning:: If False, the user is responsible for ensuring that
            the input array is valid.

    Examples
    --------
    >>> import numpy as np
    >>> from myoverse.datasets.filters.generic import IdentityFilter
    >>> data = np.random.rand(10, 500)
    >>> identity_filter = IdentityFilter(input_is_chunked=False)
    >>> output = identity_filter(data)
    >>> np.array_equal(data, output)  # True

    Notes
    -----
    The input array is returned as-is, without copying.

    See Also
    --------
    ApplyFunctionFilter : A more general filter for applying arbitrary functions
    """

    def __init__(
        self,
        input_is_chunked: bool,
        is_output: bool = False,
        name: str | None = None,
        run_checks: bool = True,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
            run_checks=run_checks,
        )

    def _run_filter_checks(self, input_array: np.ndarray):
        """Validate that a single array (not a list of arrays) was supplied.

        Parameters
        ----------
        input_array : np.ndarray
            The input array to check.

        Raises
        ------
        ValueError
            If a list of arrays is supplied instead of a single array.
        """
        if isinstance(input_array, list):
            raise ValueError(
                f"IdentityFilter expects a single numpy array, but received a list of {len(input_array)} arrays. "
                f"This filter can only process one input at a time."
            )

    def _filter(self, input_array: np.ndarray, **kwargs) -> np.ndarray:
        """Return the input array unchanged.

        Parameters
        ----------
        input_array : np.ndarray
            The array to pass through.
        **kwargs
            Additional keyword arguments from the Data object (unused).

        Returns
        -------
        np.ndarray
            The very same array object that was passed in.
        """
        return input_array