# Source code for myoverse.datasets.filters.temporal

from __future__ import annotations

from typing import List, Literal, Sequence, Tuple, Union

import numpy as np
from scipy.fft import irfft, rfft, rfftfreq
from scipy.signal import savgol_filter, sosfilt, sosfiltfilt
from statsmodels.tsa.ar_model import AutoReg

from myoverse.datasets.filters._template import FilterBaseClass
from myoverse.datasets.filters.generic import ApplyFunctionFilter


class SOSFrequencyFilter(FilterBaseClass):
    """Apply a second-order-section (SOS) filter to the input array.

    Parameters
    ----------
    sos_filter_coefficients : tuple[np.ndarray, np.ndarray | float, np.ndarray]
        The second-order-section filter coefficients. The value is handed
        unchanged to the SciPy routine as its ``sos`` argument.
        NOTE(review): the annotation describes a ``(sos, gain, delay)`` tuple,
        yet nothing is unpacked before the SciPy call - confirm the expected
        format against the callers.
    forwards_and_backwards : bool
        If True the signal is filtered with ``scipy.signal.sosfiltfilt``,
        otherwise with ``scipy.signal.sosfilt``.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Whether the filter is an output filter.
    name : str
        Passed through to ``FilterBaseClass``.

    Methods
    -------
    __call__(input_array: np.ndarray) -> np.ndarray
        Filters the input array. The base class is configured with
        ``allowed_input_type="both"``.
    """

    def __init__(
        self,
        sos_filter_coefficients: tuple[
            np.ndarray, Union[np.ndarray, float], np.ndarray
        ],
        forwards_and_backwards: bool = True,
        input_is_chunked: bool = None,
        is_output: bool = False,
        name: str = None,
    ):
        base_kwargs = {
            "input_is_chunked": input_is_chunked,
            "allowed_input_type": "both",
            "is_output": is_output,
            "name": name,
        }
        super().__init__(**base_kwargs)

        self.sos_filter_coefficients = sos_filter_coefficients
        self.forwards_and_backwards = forwards_and_backwards

        # Pick the SciPy routine once so that _filter only has to call it.
        if self.forwards_and_backwards:
            self._filtering_method = sosfiltfilt
        else:
            self._filtering_method = sosfilt

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Run the selected SciPy SOS routine on ``input_array``."""
        apply_sos = self._filtering_method
        return apply_sos(self.sos_filter_coefficients, input_array)
class RectifyFilter(ApplyFunctionFilter):
    """Rectify the input array by delegating ``np.abs`` to ``ApplyFunctionFilter``.

    Parameters
    ----------
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Whether the filter is an output filter.
    """

    def __init__(self, input_is_chunked: bool = None, is_output: bool = False):
        # The rectification itself is just the absolute value; everything else
        # is handled by the generic function filter.
        super().__init__(
            function=np.abs,
            is_output=is_output,
            input_is_chunked=input_is_chunked,
        )
class ARFilter(FilterBaseClass):
    """Fit an autoregressive model to every window of the selected representations.

    For each selected representation ``r`` and every index pair ``(j, k)`` of
    the second and third axis, ``statsmodels`` ``AutoReg`` is fitted to
    ``input_array[r, j, k]`` with ``n_coefficients - 1`` lags and the fitted
    parameters are collected.

    Parameters
    ----------
    n_coefficients : int
        Number of parameters requested per fit. ``AutoReg`` is called with
        ``lags=n_coefficients - 1``; with its default constant trend this
        yields the intercept plus ``n_coefficients - 1`` lag coefficients.
    input_is_chunked : bool
        Whether the input is chunked or not.
    representations_to_filter : "all" or sequence of int
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``n_coefficients`` is smaller than 1.
    """

    def __init__(
        self,
        n_coefficients: int = 4,
        input_is_chunked: bool = True,
        representations_to_filter: Union[Literal["all"], Sequence[int]] = (0,),
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            representations_to_filter=representations_to_filter,
        )

        # A negative lag count would only fail later inside statsmodels.
        if n_coefficients < 1:
            raise ValueError("n_coefficients must be greater than 0.")

        self.n_coefficients = n_coefficients

    def _filter(
        self, input_array: np.ndarray, representations_to_filter_indices: np.ndarray
    ) -> np.ndarray:
        """Return the fitted AR parameters, one row per fitted series.

        Rows are ordered by representation first, then by the second and
        third axis of ``input_array``.
        """
        lags = self.n_coefficients - 1
        fitted_params = []
        for representation in representations_to_filter_indices:
            for j, k in np.ndindex(input_array.shape[1], input_array.shape[2]):
                model = AutoReg(input_array[representation, j, k], lags=lags)
                fitted_params.append(model.fit().params)

        # The method is annotated to return an ndarray; previously the plain
        # Python list was returned. Stacking keeps row order and contents.
        return np.asarray(fitted_params)
class RMSFilter(FilterBaseClass):
    """Compute the root mean squared value [1]_ of the input array.

    For chunked input the RMS is evaluated over sliding windows along the
    last axis. Otherwise the RMS over the entire last axis is returned and
    ``window_size`` / ``shift`` are not used.

    Parameters
    ----------
    window_size : int
        Number of samples per window.
    shift : int
        Step between the starts of consecutive windows.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Passed through to ``FilterBaseClass``.
    name : str
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``window_size`` or ``shift`` is smaller than 1.

    Methods
    -------
    __call__(input_array: np.ndarray) -> np.ndarray
        Filters the input array. The base class is configured with
        ``allowed_input_type="both"``.

    References
    ----------
    .. [1] https://doi.org/10.1080/10255842.2023.2165068
    """

    def __init__(
        self,
        window_size: int,
        shift: int = 1,
        input_is_chunked: bool = None,
        is_output: bool = False,
        name: str = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
        )

        self.window_size = window_size
        self.shift = shift

        # window_size is checked first, so it wins when both are invalid.
        for value, message in (
            (self.window_size, "window_size must be greater than 0."),
            (self.shift, "shift must be greater than 0."),
        ):
            if value < 1:
                raise ValueError(message)

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Return the (windowed) RMS of ``input_array`` along the last axis."""
        if not self.input_is_chunked:
            return np.sqrt(np.mean(input_array**2, axis=-1))

        last_start = input_array.shape[-1] - self.window_size
        rms_per_window = []
        for start in range(0, last_start + 1, self.shift):
            window = input_array[..., start : start + self.window_size]
            rms_per_window.append(
                np.sqrt(np.mean(window**2, axis=-1, keepdims=True))
            )
        return np.concatenate(rms_per_window, axis=-1)
class VARFilter(FilterBaseClass):
    """Compute the variance over sliding windows along the last axis.

    Parameters
    ----------
    window_size : int
        Number of samples per window.
    shift : int
        Step between the starts of consecutive windows.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Passed through to ``FilterBaseClass``.
    name : str
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``window_size`` or ``shift`` is smaller than 1.
    """

    def __init__(
        self,
        window_size: int,
        shift: int = 1,
        input_is_chunked: bool = True,
        is_output: bool = False,
        name: str = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
        )

        # Same argument checks as RMSFilter, so bad values fail at
        # construction time instead of deep inside _filter.
        if window_size < 1:
            raise ValueError("window_size must be greater than 0.")
        if shift < 1:
            raise ValueError("shift must be greater than 0.")

        self.window_size = window_size
        self.shift = shift

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Return one variance per window; the last axis indexes the windows."""
        window_starts = range(
            0, input_array.shape[-1] - self.window_size + 1, self.shift
        )
        return np.concatenate(
            [
                np.var(
                    input_array[..., start : start + self.window_size],
                    axis=-1,
                    keepdims=True,
                )
                for start in window_starts
            ],
            axis=-1,
        )
# TODO: Check if this is correct
class HISTFilter(FilterBaseClass):
    """Compute a histogram over sliding windows along the last axis.

    For every window, ``np.histogram`` is evaluated separately for each index
    pair of the first two axes. Because only a bin count is passed, the bin
    edges are derived from the value range of each individual segment.

    Parameters
    ----------
    window_size : int
        Number of samples per window.
    shift : int
        Step between the starts of consecutive windows.
    bins : int
        Number of histogram bins.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Passed through to ``FilterBaseClass``.
    name : str
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``window_size``, ``shift`` or ``bins`` is smaller than 1.
    """

    def __init__(
        self,
        window_size: int,
        shift: int = 1,
        bins: int = 10,
        input_is_chunked: bool = True,
        is_output: bool = False,
        name: str = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
        )

        # Same argument checks as RMSFilter, extended to the bin count.
        if window_size < 1:
            raise ValueError("window_size must be greater than 0.")
        if shift < 1:
            raise ValueError("shift must be greater than 0.")
        if bins < 1:
            raise ValueError("bins must be greater than 0.")

        self.window_size = window_size
        self.shift = shift
        self.bins = bins

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Return the per-window bin counts stacked along the last axis.

        NOTE(review): the indexing below uses exactly the first two axes plus
        the window axis, i.e. it appears to assume a 3-D input - confirm.
        """
        per_window = []
        for start in range(
            0, input_array.shape[-1] - self.window_size + 1, self.shift
        ):
            segment = input_array[..., start : start + self.window_size]

            counts = np.zeros((segment.shape[0], segment.shape[1], self.bins))
            for j, k in np.ndindex(counts.shape[0], counts.shape[1]):
                counts[j, k] = np.histogram(segment[j, k], bins=self.bins)[0]

            # Fortran order merges (second axis, bin) with the second axis
            # varying fastest, giving shape (first axis, second axis * bins, 1).
            per_window.append(counts.reshape((counts.shape[0], -1, 1), order="F"))

        return np.concatenate(per_window, axis=-1)
class MAVFilter(FilterBaseClass):
    """Compute the Mean Absolute Value over sliding windows along the last axis.

    See formula in the following paper:
    https://doi.org/10.1080/10255842.2023.2165068.

    Parameters
    ----------
    window_size : int
        Number of samples per window.
    shift : int
        Step between the starts of consecutive windows.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Passed through to ``FilterBaseClass``.
    name : str
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``window_size`` or ``shift`` is smaller than 1.
    """

    def __init__(
        self,
        window_size: int,
        shift: int = 1,
        input_is_chunked: bool = True,
        is_output: bool = False,
        name: str = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
        )

        # Same argument checks as RMSFilter, so bad values fail early.
        if window_size < 1:
            raise ValueError("window_size must be greater than 0.")
        if shift < 1:
            raise ValueError("shift must be greater than 0.")

        self.window_size = window_size
        self.shift = shift

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Return one mean absolute value per window on the last axis."""
        window_starts = range(
            0, input_array.shape[-1] - self.window_size + 1, self.shift
        )
        return np.concatenate(
            [
                np.mean(
                    np.abs(input_array[..., start : start + self.window_size]),
                    axis=-1,
                    keepdims=True,
                )
                for start in window_starts
            ],
            axis=-1,
        )
class IAVFilter(FilterBaseClass):
    """Compute the Integrated Absolute Value over sliding windows along the last axis.

    See formula in the following paper:
    https://doi.org/10.1080/10255842.2023.2165068.

    Parameters
    ----------
    window_size : int
        Number of samples per window.
    shift : int
        Step between the starts of consecutive windows.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Passed through to ``FilterBaseClass``.
    name : str
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``window_size`` or ``shift`` is smaller than 1.
    """

    def __init__(
        self,
        window_size: int,
        shift: int = 1,
        input_is_chunked: bool = True,
        is_output: bool = False,
        name: str = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
        )

        # Same argument checks as RMSFilter, so bad values fail early.
        if window_size < 1:
            raise ValueError("window_size must be greater than 0.")
        if shift < 1:
            raise ValueError("shift must be greater than 0.")

        self.window_size = window_size
        self.shift = shift

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Return the sum of absolute values per window on the last axis."""
        window_starts = range(
            0, input_array.shape[-1] - self.window_size + 1, self.shift
        )
        return np.concatenate(
            [
                np.sum(
                    np.abs(input_array[..., start : start + self.window_size]),
                    axis=-1,
                    keepdims=True,
                )
                for start in window_starts
            ],
            axis=-1,
        )
class WFLFilter(FilterBaseClass):
    """Compute the Waveform Length over sliding windows along the last axis.

    The waveform length of a window is the sum of the absolute differences
    between consecutive samples. See formula in the following paper:
    https://doi.org/10.1080/10255842.2023.2165068.

    Parameters
    ----------
    window_size : int
        Number of samples per window.
    shift : int
        Step between the starts of consecutive windows.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Passed through to ``FilterBaseClass``.
    name : str
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``window_size`` or ``shift`` is smaller than 1.
    """

    def __init__(
        self,
        window_size: int,
        shift: int = 1,
        input_is_chunked: bool = True,
        is_output: bool = False,
        name: str = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
        )

        # Same argument checks as RMSFilter, so bad values fail early.
        if window_size < 1:
            raise ValueError("window_size must be greater than 0.")
        if shift < 1:
            raise ValueError("shift must be greater than 0.")

        self.window_size = window_size
        self.shift = shift

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Return one waveform length per window on the last axis."""
        window_starts = range(
            0, input_array.shape[-1] - self.window_size + 1, self.shift
        )
        return np.concatenate(
            [
                np.sum(
                    np.abs(
                        np.diff(input_array[..., start : start + self.window_size])
                    ),
                    axis=-1,
                    keepdims=True,
                )
                for start in window_starts
            ],
            axis=-1,
        )
class ZCFilter(FilterBaseClass):
    """Count the Zero Crossings over sliding windows along the last axis.

    See formula in the following paper:
    https://doi.org/10.1080/10255842.2023.2165068.

    Parameters
    ----------
    window_size : int
        Number of samples per window.
    shift : int
        Step between the starts of consecutive windows.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Passed through to ``FilterBaseClass``.
    name : str
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``window_size`` or ``shift`` is smaller than 1.
    """

    def __init__(
        self,
        window_size: int,
        shift: int = 1,
        input_is_chunked: bool = True,
        is_output: bool = False,
        name: str = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
        )

        # Same argument checks as RMSFilter, so bad values fail early.
        if window_size < 1:
            raise ValueError("window_size must be greater than 0.")
        if shift < 1:
            raise ValueError("shift must be greater than 0.")

        self.window_size = window_size
        self.shift = shift

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Return the number of sign flips per window on the last axis.

        A direct change between a negative and a positive sample contributes
        1. A step to or from an exact zero has an absolute sign difference of
        1, which the floor division by 2 turns into 0.
        """
        window_starts = range(
            0, input_array.shape[-1] - self.window_size + 1, self.shift
        )
        return np.concatenate(
            [
                np.sum(
                    np.abs(
                        np.diff(
                            np.sign(
                                input_array[..., start : start + self.window_size]
                            )
                        )
                    )
                    // 2,
                    axis=-1,
                    keepdims=True,
                )
                for start in window_starts
            ],
            axis=-1,
        )
class SSCFilter(FilterBaseClass):
    """Count the Slope Sign Changes over sliding windows along the last axis.

    See formula in the following paper:
    https://doi.org/10.1080/10255842.2023.2165068.

    Parameters
    ----------
    window_size : int
        Number of samples per window.
    shift : int
        Step between the starts of consecutive windows.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Passed through to ``FilterBaseClass``.
    name : str
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``window_size`` or ``shift`` is smaller than 1.
    """

    def __init__(
        self,
        window_size: int,
        shift: int = 1,
        input_is_chunked: bool = True,
        is_output: bool = False,
        name: str = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
        )

        # Same argument checks as RMSFilter, so bad values fail early.
        if window_size < 1:
            raise ValueError("window_size must be greater than 0.")
        if shift < 1:
            raise ValueError("shift must be greater than 0.")

        self.window_size = window_size
        self.shift = shift

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Return the number of slope sign flips per window on the last axis.

        The sign of the first difference is taken, and direct flips between a
        negative and a positive slope are counted. A step to or from an exact
        zero slope is floored to 0 by the division by 2.
        """
        window_starts = range(
            0, input_array.shape[-1] - self.window_size + 1, self.shift
        )
        return np.concatenate(
            [
                np.sum(
                    np.abs(
                        np.diff(
                            np.sign(
                                np.diff(
                                    input_array[
                                        ..., start : start + self.window_size
                                    ]
                                )
                            )
                        )
                    )
                    // 2,
                    axis=-1,
                    keepdims=True,
                )
                for start in window_starts
            ],
            axis=-1,
        )
class GaileyFeature2(FilterBaseClass):
    """Compute the second EMG feature from the Gailey et al. paper.

    For every sliding window along the last axis the feature is
    ``log((sum(x ** 2) + eps) / window_size)``, where ``eps`` is the float
    machine epsilon that keeps the logarithm finite for all-zero windows.
    See formula in the following paper:
    https://doi.org/10.3389/fneur.2017.00007.

    Parameters
    ----------
    window_size : int
        Number of samples per window.
    shift : int
        Step between the starts of consecutive windows.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Passed through to ``FilterBaseClass``.
    name : str
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``window_size`` or ``shift`` is smaller than 1.
    """

    def __init__(
        self,
        window_size: int,
        shift: int = 1,
        input_is_chunked: bool = True,
        is_output: bool = False,
        name: str = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
        )

        # Same argument checks as RMSFilter; this also rules out a division
        # by a zero window_size in _filter.
        if window_size < 1:
            raise ValueError("window_size must be greater than 0.")
        if shift < 1:
            raise ValueError("shift must be greater than 0.")

        self.window_size = window_size
        self.shift = shift

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Return one feature value per window on the last axis."""
        eps = np.finfo(float).eps
        window_starts = range(
            0, input_array.shape[-1] - self.window_size + 1, self.shift
        )
        return np.concatenate(
            [
                np.log(
                    (
                        np.sum(
                            input_array[..., start : start + self.window_size] ** 2,
                            axis=-1,
                            keepdims=True,
                        )
                        + eps
                    )
                    / self.window_size
                )
                for start in window_starts
            ],
            axis=-1,
        )
class GaileyFeature3(FilterBaseClass):
    """Compute the third EMG feature from the Gailey et al. paper.

    For every sliding window along the last axis the feature is
    ``log((IF + eps) / (WL + eps))`` with

    * ``m0`` - sum of squared samples,
    * ``m2`` - sum of squared first differences,
    * ``m4`` - sum of squared second differences,
    * ``IF = sqrt(m2 ** 2 / (m0 + eps) / (m4 + eps))``,
    * ``WL`` - sum of absolute first differences.

    See formula in the following paper:
    https://doi.org/10.3389/fneur.2017.00007.

    Parameters
    ----------
    window_size : int
        Number of samples per window.
    shift : int
        Step between the starts of consecutive windows.
    input_is_chunked : bool
        Whether the input is chunked or not.
    is_output : bool
        Passed through to ``FilterBaseClass``.
    name : str
        Passed through to ``FilterBaseClass``.

    Raises
    ------
    ValueError
        If ``window_size`` or ``shift`` is smaller than 1.
    """

    def __init__(
        self,
        window_size: int,
        shift: int = 1,
        input_is_chunked: bool = True,
        is_output: bool = False,
        name: str = None,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            is_output=is_output,
            name=name,
        )

        # Same argument checks as RMSFilter, so bad values fail early.
        if window_size < 1:
            raise ValueError("window_size must be greater than 0.")
        if shift < 1:
            raise ValueError("shift must be greater than 0.")

        self.window_size = window_size
        self.shift = shift

    def _filter(self, input_array: np.ndarray) -> np.ndarray:
        """Return one feature value per window on the last axis."""
        eps = np.finfo(float).eps
        features = []

        for start in range(
            0, input_array.shape[-1] - self.window_size + 1, self.shift
        ):
            segment = input_array[..., start : start + self.window_size]

            # Prepending the first sample keeps the window length and makes
            # the first difference start at zero.
            first_diff = np.diff(segment, axis=-1, prepend=segment[..., [0]])

            # Pad the second difference with the first *difference* value.
            # Prepending the raw first sample here (as done previously) would
            # add the squared amplitude of that sample to m4.
            second_diff = np.diff(
                first_diff, axis=-1, prepend=first_diff[..., [0]]
            )

            m0 = np.sum(segment**2, axis=-1, keepdims=True)
            m2 = np.sum(first_diff**2, axis=-1, keepdims=True)
            m4 = np.sum(second_diff**2, axis=-1, keepdims=True)

            irregularity_factor = np.sqrt(m2**2 / (m0 + eps) / (m4 + eps))
            waveform_length = np.sum(np.abs(first_diff), axis=-1, keepdims=True)

            features.append(
                np.log((irregularity_factor + eps) / (waveform_length + eps))
            )

        return np.concatenate(features, axis=-1)
# TODO
class SpectralInterpolationFilter(FilterBaseClass):
    """Replace narrow frequency bands of the spectrum with a smoothed version.

    The real FFT of the selected representations is taken along the last
    axis, the bins that fall into ``bandwidth`` and its integer multiples are
    overwritten with a Savitzky-Golay smoothed magnitude spectrum, and the
    result is transformed back.

    Parameters
    ----------
    bandwidth : tuple[float, float]
        Lower and upper edge of the fundamental band to replace, in the same
        unit as ``emg_frequency``.
    number_of_harmonics : int
        Number of bands to replace; band ``h`` (starting at 1) is
        ``bandwidth`` multiplied by ``h``.
    emg_frequency : float
        Sampling frequency used to build the FFT frequency axis.
    input_is_chunked : bool
        Whether the input is chunked or not.
    representations_to_filter : "all" or sequence of int
        Passed through to ``FilterBaseClass``.
    """

    def __init__(
        self,
        bandwidth: Tuple[float, float] = (47.5, 50.75),
        number_of_harmonics: int = 5,
        emg_frequency: float = 2044,
        input_is_chunked: bool = True,
        representations_to_filter: Union[Literal["all"], Sequence[int]] = "all",
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            representations_to_filter=representations_to_filter,
        )

        self.bandwidth = bandwidth
        self.number_of_harmonics = number_of_harmonics
        self.emg_frequency = emg_frequency

        # Despite its name this holds frequency bands, not indices: row h is
        # (low, high) of the (h + 1)-th multiple of ``bandwidth``.
        self._indices_to_interpolate = (
            np.repeat(np.array([bandwidth]), self.number_of_harmonics, axis=0)
            * np.arange(1, self.number_of_harmonics + 1)[..., None]
        )

    def _get_indices_to_interpolate(self, rfft_freqs: np.ndarray) -> List[np.ndarray]:
        """Return, per band, the FFT bin indices that fall inside the band.

        Each band is widened by half the mean bin spacing on both sides.
        """
        half_bin_width = np.mean(np.diff(rfft_freqs)) / 2

        return [
            np.argwhere(
                np.logical_and(
                    band[0] - half_bin_width <= rfft_freqs,
                    rfft_freqs <= band[1] + half_bin_width,
                )
            ).flatten()
            for band in self._indices_to_interpolate
        ]

    def _filter(
        self, input_array: np.ndarray, representations_to_filter_indices: np.ndarray
    ) -> np.ndarray:
        """Return a copy of ``input_array`` with the selected representations filtered."""
        n_samples = input_array.shape[-1]
        output_array = np.copy(input_array)

        fourier = rfft(input_array[representations_to_filter_indices], axis=-1)
        fourier[..., 0] = 0  # drop the DC component

        smooth_fourier = savgol_filter(np.abs(fourier), 15, 3, axis=-1)

        # NOTE(review): the selected bins are overwritten with the real-valued
        # smoothed magnitude, so their original phase is discarded - confirm
        # whether the phase should be retained instead.
        for indices in self._get_indices_to_interpolate(
            rfftfreq(n_samples, d=1 / self.emg_frequency)
        ):
            fourier[..., indices] = smooth_fourier[..., indices]

        # Without an explicit n, irfft returns 2 * (bins - 1) samples, which
        # is one short for odd-length input and breaks the assignment below.
        output_array[representations_to_filter_indices] = irfft(
            fourier, n=n_samples, axis=-1
        )

        return output_array