Source code for myoverse.datasets.filters.temporal

from __future__ import annotations

import numpy as np
from numpy.lib.stride_tricks import as_strided
from numpy.random.mtrand import Sequence
from scipy.fft import irfft, rfft, rfftfreq
from scipy.signal import savgol_filter, sosfilt, sosfiltfilt

from myoverse.datasets.filters._template import FilterBaseClass
from myoverse.datasets.filters.generic import (
    ApplyFunctionFilter,
    _get_windows_with_shift,
)


[docs] class SOSFrequencyFilter(FilterBaseClass): """Filter that applies a second-order-section filter to the input array. Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. If True, the resulting signal will be outputted by and dataset pipeline. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. sos_filter_coefficients : tuple[np.ndarray, np.ndarray | float, np.ndarray] The second-order-section filter coefficients, typically from scipy.signal.butter with output="sos". forwards_and_backwards : bool Whether to apply the filter forwards and backwards or only forwards. continuous_approach : tuple[int, int] | None If input_is_chunked is True and a tuple of window size and shift is provided, the filter will first unchunk the data, filter, then re-chunk the data. .. warning:: Some data **will** be lost at the end of the signal if the window shift is not equal to the window size. real_time_support_chunks: np.ndarray | None If not None and continuous_approach is not None, the support chunks will be prepended to the input array. That way boundary Methods ------- __call__(input_array: np.ndarray) -> np.ndarray Filters the input array. Input shape is determined by whether the allowed_input_type is "both", "chunked" or "not chunked". reset_state() Resets the internal filter state. Only relevant when real_time_mode=True. """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, *, sos_filter_coefficients: tuple[np.ndarray, np.ndarray | float, np.ndarray], forwards_and_backwards: bool = True, continuous_approach: tuple[int, int] | None = None, real_time_support_chunks: np.ndarray | None = None, ): super().__init__( input_is_chunked=input_is_chunked, allowed_input_type="both", is_output=is_output, name=name, run_checks=run_checks, ) self.sos_filter_coefficients = sos_filter_coefficients self.forwards_and_backwards = forwards_and_backwards self.continuous_approach = continuous_approach self.real_time_support_chunks = real_time_support_chunks self._filtering_method = sosfiltfilt if self.forwards_and_backwards else sosfilt
[docs] def _filter(self, input_array: np.ndarray, **kwargs) -> np.ndarray: """Apply the filter to the input array. Parameters ---------- input_array : numpy.ndarray Input array to filter **kwargs Additional keyword arguments from the Data object Returns ------- numpy.ndarray Filtered array """ if not self.input_is_chunked: # Simply apply the filter to the non-chunked data return self._filtering_method(self.sos_filter_coefficients, input_array) if self.continuous_approach: temp = input_array if self.real_time_support_chunks is not None: # Prepend the support chunks to the input array temp = np.concatenate( (self.real_time_support_chunks, temp), axis=0 ) temp = self._filtering_method(self.sos_filter_coefficients, np.concatenate(temp[..., : self.continuous_approach[1]], axis=-1)) # Reshape the filtered data back to the original shape temp = _get_windows_with_shift(temp, *self.continuous_approach) if self.real_time_support_chunks is not None: # Remove the prepended support chunks temp = temp[self.real_time_support_chunks.shape[0] :] return temp return self._filtering_method(self.sos_filter_coefficients, input_array)
[docs] class RectifyFilter(ApplyFunctionFilter): """Filter that rectifies the input array. Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. If True, the resulting signal will be outputted by and dataset pipeline. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, ): super().__init__( input_is_chunked=input_is_chunked, is_output=is_output, name=name, run_checks=run_checks, function=np.abs, )
[docs] class WindowedFunctionFilter(ApplyFunctionFilter): """Base class for filters that apply a function to windowed data. This filter creates windows using _get_windows_with_shift and then applies a specified function to each window. Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. window_size : int The window size to use. shift : int The shift to use. Default is 1. window_function : callable Function to apply to each window (along the last axis). .. note:: The function should take two arguments: the window and the axis to apply the function along. """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, *, window_size: int, shift: int = 1, window_function: callable, ): # Validate parameters if window_size < 1: raise ValueError("window_size must be greater than 0.") if shift < 1: raise ValueError("shift must be greater than 0.") # Define the windowed function application def apply_window_function(x, window_size, shift, window_function): # Get windows if not already chunked windowed_array = _get_windows_with_shift(x, window_size, shift) # Apply the function to each window func_result = np.squeeze(window_function(windowed_array, axis=-1), axis=-1) return np.transpose(func_result, (*list(range(func_result.ndim))[1:], 0)) # Initialize parent with the windowed function super().__init__( input_is_chunked=input_is_chunked, is_output=is_output, name=name, run_checks=run_checks, function=apply_window_function, window_size=window_size, shift=shift, window_function=window_function, ) # Store parameters for reference self.window_size = window_size self.shift = shift
[docs] class RMSFilter(WindowedFunctionFilter): """Filter that computes the root mean squared value [1]_ of the input array. Root mean squared value is the square root of the mean of the squared values of the input array. It is a measure of the magnitude of the signal. .. math:: \\text{RMS} = \\sqrt{\\frac{1}{N} \\sum_{i=1}^{N} x_i^2} Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. window_size : int The window size to use. shift : int The shift to use. Default is 1. stabilization_factor : float A small value to add to the squared values before taking the mean to stabilize the computation. By default, this is the machine epsilon for float values. See numpy.finfo(float).eps. References ---------- .. [1] https://doi.org/10.1080/10255842.2023.2165068 """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, *, window_size: int, shift: int = 1, stabilization_factor: float = np.finfo(float).eps, ): super().__init__( input_is_chunked=input_is_chunked, is_output=is_output, name=name, run_checks=run_checks, window_size=window_size, shift=shift, window_function=lambda x, axis: np.sqrt( np.mean(np.square(x), axis=axis, keepdims=True) + stabilization_factor ), )
[docs] class VARFilter(WindowedFunctionFilter): """Filter that computes the variance [1]_ of the input array. Variance is the average of the squared differences from the mean. It is a measure of the spread of the signal. .. math:: \\text{VAR} = \\frac{1}{N} \\sum_{i=1}^{N} (x_i - \\mu)^2 Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. window_size : int The window size to use. shift : int The shift to use. Default is 1. References ---------- .. [1] https://doi.org/10.1080/10255842.2023.2165068 """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, *, window_size: int, shift: int = 1, ): # Initialize with variance function super().__init__( input_is_chunked=input_is_chunked, is_output=is_output, name=name, run_checks=run_checks, window_size=window_size, shift=shift, window_function=lambda x, axis: np.var(x, axis=axis, keepdims=True), )
[docs] class MAVFilter(WindowedFunctionFilter): """Filter that computes the mean absolute value [1]_ of the input array. Mean absolute value is the average of the absolute values of the input array. It is a measure of the average magnitude of the signal. .. math:: \\text{MAV} = \\frac{1}{N} \\sum_{i=1}^{N} |x_i| Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. If True, the resulting signal will be outputted by and dataset pipeline. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. window_size : int The window size to use. shift : int The shift to use. Default is 1. References ---------- .. [1] https://doi.org/10.1080/10255842.2023.2165068 """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, *, window_size: int, shift: int = 1, ): super().__init__( input_is_chunked=input_is_chunked, is_output=is_output, name=name, run_checks=run_checks, window_size=window_size, shift=shift, window_function=lambda x, axis: np.mean( np.abs(x), axis=axis, keepdims=True ), )
[docs] class IAVFilter(WindowedFunctionFilter): """Filter that computes the integrated absolute value [1]_ of the input array. Integrated absolute value is the sum of the absolute values of the input array. It is a measure of the total magnitude of the signal. .. math:: \\text{IAV} = \\sum_{i=1}^{N} |x_i| Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. If True, the resulting signal will be outputted by and dataset pipeline. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. window_size : int The window size to use. shift : int The shift to use. Default is 1. References ---------- .. [1] https://doi.org/10.1080/10255842.2023.2165068 """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, *, window_size: int, shift: int = 1, ): super().__init__( input_is_chunked=input_is_chunked, is_output=is_output, name=name, run_checks=run_checks, window_size=window_size, shift=shift, window_function=lambda x, axis: np.sum(np.abs(x), axis=axis, keepdims=True), )
[docs] class WFLFilter(WindowedFunctionFilter): """Filter that computes the waveform length [1]_ of the input array. Waveform length is the sum of the absolute differences between consecutive samples. It is a measure of the total magnitude of the signal. .. math:: \\text{WFL} = \\sum_{i=1}^{N} |x_i - x_{i-1}| Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. If True, the resulting signal will be outputted by and dataset pipeline. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. window_size : int The window size to use. shift : int The shift to use. Default is 1. References ---------- .. [1] https://doi.org/10.1080/10255842.2023.2165068 """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, *, window_size: int, shift: int = 1, ): super().__init__( input_is_chunked=input_is_chunked, is_output=is_output, name=name, run_checks=run_checks, window_size=window_size, shift=shift, window_function=lambda x, axis: np.sum( np.abs(np.diff(x, axis=axis)), axis=axis, keepdims=True ), )
[docs] class ZCFilter(WindowedFunctionFilter): """Computes the zero crossings [1]_ of the input array. Zero crossings are the number of times the signal crosses the zero axis. It is a measure of the number of times the signal changes sign. .. math:: \\text{ZC} = \\sum_{i=1}^{N} \\frac{1}{2} |\\text{sign}(x_i) - \\text{sign}(x_{i-1})| Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. window_size : int The window size to use. shift : int The shift to use. Default is 1. References ---------- .. [1] https://doi.org/10.1080/10255842.2023.2165068 """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, *, window_size: int, shift: int = 1, ): # Define Zero Crossing function def zc_function(windowed_array, axis=-1): # 1. Calculate sign of all elements in each window signs = np.sign(windowed_array) # 2. Calculate differences of signs to find changes sign_changes = np.diff(signs, axis=axis) # 3. Count absolute changes (divided by 2 since each crossing counts twice) return np.sum(np.abs(sign_changes) // 2, axis=axis, keepdims=True) super().__init__( input_is_chunked=input_is_chunked, is_output=is_output, name=name, run_checks=run_checks, window_size=window_size, shift=shift, window_function=zc_function, )
[docs] class SSCFilter(WindowedFunctionFilter): """Computes the slope sign change [1]_ of the input array. Slope sign change is the number of times the slope of the signal changes sign. It is a measure of the number of times the signal changes direction. .. math:: \\text{SSC} = \\sum_{i=1}^{N} \\frac{1}{2} |\\text{sign}(x_i - x_{i-1}) - \\text{sign}(x_{i-1} - x_{i-2})| Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. window_size : int The window size to use. shift : int The shift to use. Default is 1. References ---------- .. [1] https://doi.org/10.1080/10255842.2023.2165068 """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, *, window_size: int, shift: int = 1, ): # Define Slope Sign Change function def ssc_function(windowed_array, axis=-1): # 1. Calculate differences (first derivative) diffs = np.diff(windowed_array, axis=axis) # 2. Calculate sign of differences sign_diffs = np.sign(diffs) # 3. Calculate sign changes of the sign differences (second derivative sign changes) sign_changes = np.diff(sign_diffs, axis=axis) # 4. Count number of sign changes in each window return np.sum(np.abs(sign_changes) // 2, axis=axis, keepdims=True) super().__init__( input_is_chunked=input_is_chunked, is_output=is_output, name=name, run_checks=run_checks, window_size=window_size, shift=shift, window_function=ssc_function, )
[docs] class SpectralInterpolationFilter(FilterBaseClass): """Filter that removes certain frequency bands from the signal and interpolates the gaps. This is ideal for removing power line interference or other narrowband noise. The filter works by: 1. Computing FFT of the signal 2. Setting the magnitude of the specified frequency bands to interpolated values 3. Preserving the phase information 4. Converting back to time domain .. warning:: When used for real-time applications, performance depends on chunk size. Smaller chunks reduce latency but may decrease frequency resolution and interpolation quality, especially for narrow frequency bands. Parameters ---------- input_is_chunked : bool Whether the input is chunked or not. is_output : bool Whether the filter is an output filter. name : str | None Name of the filter, by default None. run_checks : bool Whether to run the checks when filtering. By default, True. If False can potentially speed up performance. .. warning:: If False, the user is responsible for ensuring that the input array is valid. bandwidth : tuple[float, float] Frequency band to remove (min_freq, max_freq) in Hz, by default (47.5, 52.5) number_of_harmonics : int Number of harmonics to remove, by default 3 sampling_frequency : float Sampling frequency in Hz, by default 2000 interpolation_window : int Window size for interpolation, must be an odd number, by default 15 interpolation_poly_order : int Polynomial order for interpolation, must be less than interpolation_window, by default 3 remove_dc : bool Whether to remove the DC component (set FFT[0] to 0), by default True """
[docs] def __init__( self, input_is_chunked: bool, is_output: bool = False, name: str | None = None, run_checks: bool = True, *, bandwidth: tuple[float, float] = (47.5, 52.5), number_of_harmonics: int = 3, sampling_frequency: float = 2000, interpolation_window: int = 15, interpolation_poly_order: int = 3, remove_dc: bool = True, ): super().__init__( input_is_chunked=input_is_chunked, allowed_input_type="both", is_output=is_output, name=name, run_checks=run_checks, ) self.bandwidth = bandwidth self.number_of_harmonics = number_of_harmonics self.sampling_frequency = sampling_frequency # Validate interpolation parameters if interpolation_window % 2 == 0: raise ValueError("interpolation_window must be an odd number") if interpolation_poly_order >= interpolation_window: raise ValueError( "interpolation_poly_order must be less than interpolation_window" ) self.interpolation_window = interpolation_window self.interpolation_poly_order = interpolation_poly_order self.remove_dc = remove_dc # Pre-compute harmonic frequencies center_freq = (bandwidth[0] + bandwidth[1]) / 2 self.harmonic_freqs = [ (i * center_freq, i * bandwidth[0], i * bandwidth[1]) for i in range(1, number_of_harmonics + 1) ]
[docs] def _get_indices_to_interpolate(self, freqs): """Get the indices of the frequency bins to interpolate. Parameters ---------- freqs : numpy.ndarray Frequency bins from rfftfreq Returns ------- list List of arrays, each containing indices for a harmonic frequency band """ indices_list = [] for _, min_freq, max_freq in self.harmonic_freqs: indices = np.where((freqs >= min_freq) & (freqs <= max_freq))[0] indices_list.append(indices) return indices_list
[docs] def _filter(self, input_array, **kwargs): """Apply the filter to the input array. Parameters ---------- input_array : numpy.ndarray Input array to filter **kwargs Additional keyword arguments from the Data object Returns ------- numpy.ndarray Filtered array """ # Use sampling_frequency from kwargs if available, otherwise use the instance attribute sampling_frequency = kwargs.get("sampling_frequency", self.sampling_frequency) # Save original shape original_shape = input_array.shape # For multidimensional arrays, reshape to 2D for easier processing if len(original_shape) > 1: reshaped_array = input_array.reshape(-1, original_shape[-1]) # Process each signal for j in range(reshaped_array.shape[0]): # Compute the FFT signal_fft = rfft(reshaped_array[j], axis=-1) # Set the DC component to zero (optional) if self.remove_dc: signal_fft[0] = 0 # Calculate frequency bins freqs = rfftfreq(original_shape[-1], d=1 / sampling_frequency) # Get interpolation indices for each harmonic's frequency band for indices in self._get_indices_to_interpolate(freqs): if len(indices) > 0: # Get magnitude and phase magnitude = np.abs(signal_fft) phase = np.angle(signal_fft) # Determine appropriate window size based on available data window_size = min(self.interpolation_window, len(magnitude) - 2) if window_size % 2 == 0: window_size -= 1 # Ensure it's odd # Ensure polynomial order is appropriate for window size poly_order = min(self.interpolation_poly_order, window_size - 1) # Use 'nearest' mode if data is too short for 'interp' filter_mode = ( "nearest" if window_size >= len(magnitude) else "interp" ) # Apply savgol_filter to get a smooth interpolation, with adjusted parameters if ( window_size >= 3 ): # Minimum window size for Savitzky-Golay filter smooth_magnitude = savgol_filter( magnitude, window_size, poly_order, axis=-1, mode=filter_mode, ) # Replace the magnitude in the specified indices while preserving phase signal_fft[indices] = smooth_magnitude[indices] * np.exp( 1j * phase[indices] ) # Convert back to time domain reshaped_array[j] = irfft(signal_fft, n=original_shape[-1], axis=-1) # Reshape back to original dimensions output_array = reshaped_array.reshape(original_shape) else: # Simple case: just one dimension freqs = rfftfreq(original_shape[-1], d=1 / sampling_frequency) # Compute the FFT signal_fft = rfft(input_array, axis=-1) # Set the DC component to zero (optional) if self.remove_dc: signal_fft[0] = 0 # Get interpolation indices for each harmonic's frequency band for indices in self._get_indices_to_interpolate(freqs): if len(indices) > 0: # Get magnitude and phase magnitude = np.abs(signal_fft) phase = np.angle(signal_fft) # Determine appropriate window size based on available data window_size = min(self.interpolation_window, len(magnitude) - 2) if window_size % 2 == 0: window_size -= 1 # Ensure it's odd # Ensure polynomial order is appropriate for window size poly_order = min(self.interpolation_poly_order, window_size - 1) # Use 'nearest' mode if data is too short for 'interp' filter_mode = ( "nearest" if window_size >= len(magnitude) else "interp" ) # Apply savgol_filter to get a smooth interpolation, with adjusted parameters if ( window_size >= 3 ): # Minimum window size for Savitzky-Golay filter smooth_magnitude = savgol_filter( magnitude, window_size, poly_order, axis=-1, mode=filter_mode, ) # Replace the magnitude in the specified indices while preserving phase signal_fft[indices] = smooth_magnitude[indices] * np.exp( 1j * phase[indices] ) # Convert back to time domain output_array = irfft(signal_fft, n=original_shape[-1], axis=-1) return output_array