from abc import ABC, abstractmethod
import math
from typing import Any, Dict, List, Optional, Tuple, Literal, Union
import numpy as np
from scipy.signal import convolve
from myoverse.datasets.filters._template import FilterBaseClass
class SpatialFilterGridAware(FilterBaseClass, ABC):
    """Base class for spatial filters that need to be grid-aware.

    This class stores which electrode grids a spatial filter should be applied
    to and resolves that selection into concrete grid indices. Subclasses
    implement :meth:`_filter` and are expected to leave unselected grids
    untouched.

    Parameters
    ----------
    input_is_chunked : bool
        Whether the input is chunked or not. Forwarded to the base class.
    allowed_input_type : Literal["both", "chunked", "not chunked"]
        Type of input this filter accepts. Forwarded to the base class.
    grids_to_process : Union[Literal["all"], int, List[int]]
        Specifies which grids to apply the filter to:

        - "all": Process all grids (default)
        - int: Process only the grid with this index
        - List[int]: Process only the grids with these indices
    is_output : bool, optional
        Whether the filter is an output filter. Forwarded to the base class.
    name : str, optional
        Name of the filter. Forwarded to the base class.
    run_checks : bool, optional
        Whether to run validation checks when filtering. Forwarded to the
        base class.
    """

    def __init__(
        self,
        input_is_chunked: bool | None = None,
        allowed_input_type: Literal["both", "chunked", "not chunked"] = "both",
        grids_to_process: Union[Literal["all"], int, List[int]] = "all",
        is_output: bool = False,
        name: str | None = None,
        run_checks: bool = True,
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type=allowed_input_type,
            is_output=is_output,
            name=name,
            run_checks=run_checks,
        )
        self.grids_to_process = grids_to_process

    def _get_grids_to_filter(self, grid_layouts) -> list[int]:
        """Resolve ``grids_to_process`` into a list of grid indices.

        Parameters
        ----------
        grid_layouts : sequence
            The grid layouts of the data being filtered. Only its length is
            used, and only when all grids are selected.

        Returns
        -------
        list[int]
            Indices of the grids the filter should be applied to. A new list
            is returned so callers cannot mutate the filter's configuration.

        Raises
        ------
        ValueError
            If ``grids_to_process`` is neither "all", an integer, nor a
            sequence of integers.
        """
        selection = self.grids_to_process

        # Strings are checked first so that array-like selections are never
        # compared element-wise against "all".
        if isinstance(selection, str):
            if selection == "all":
                return list(range(len(grid_layouts)))
        elif isinstance(selection, (int, np.integer)):
            # NumPy integer scalars are accepted alongside built-in ints.
            return [int(selection)]
        elif isinstance(selection, (list, tuple, range, np.ndarray)):
            return list(selection)

        raise ValueError(
            'grids_to_process should be either Literal["all"], int, or list[int]'
        )

    @abstractmethod
    def _filter(self, input_array: np.ndarray, **kwargs) -> np.ndarray:
        """Apply the filter to the input array.

        Parameters
        ----------
        input_array : np.ndarray
            The input array to filter.
        **kwargs
            Additional keyword arguments from the Data object.

            .. important:: "grid_layouts" must be passed in kwargs for this filter to work.

        Returns
        -------
        np.ndarray
            The filtered array.

        Raises
        ------
        AttributeError
            If the grid_layouts are not provided in kwargs. This filter only operates in grid-aware mode.

        Notes
        -----
        .. important:: Ensure that the caller's grid_layouts list is updated
           in place with the new grid layouts, so the change is visible
           outside of this method.

        .. code-block:: python
            :linenos:

            for i, new_grid_layout in enumerate(new_grid_layouts):
                grid_layouts[i] = new_grid_layout
        """
        raise NotImplementedError
class DifferentialSpatialFilter(SpatialFilterGridAware):
    """Differential spatial filter for EMG data.

    This filter applies various differential spatial filters to EMG data,
    which help improve signal quality by enhancing differences between adjacent electrodes.
    The filters are defined according to https://doi.org/10.1088/1741-2552/ad3498.

    Parameters
    ----------
    filter_name : Literal["LSD", "TSD", "LDD", "TDD", "NDD", "IB2", "IR", "identity"]
        Name of the filter to be applied. Options include:

        - "LSD": Longitudinal Single Differential - computes difference between adjacent electrodes along columns
        - "TSD": Transverse Single Differential - computes difference between adjacent electrodes along rows
        - "LDD": Longitudinal Double Differential - computes double difference along columns
        - "TDD": Transverse Double Differential - computes double difference along rows
        - "NDD": Normal Double Differential - combines information from electrodes in cross pattern
        - "IB2": Inverse Binomial filter of the 2nd order
        - "IR": Inverse Rectangle filter
        - "identity": No filtering, returns the original signal
    input_is_chunked : bool
        Whether the input data is organized in chunks (3D array) or not (2D array).
    grids_to_process : Union[Literal["all"], int, List[int]]
        Specifies which grids to apply the filter to:

        - "all": Process all grids (default)
        - int: Process only the grid with this index
        - List[int]: Process only the grids with these indices
    is_output : bool, default=False
        Whether the filter is an output filter.
    name : str, optional
        Custom name for the filter. If None, the class name will be used.
    run_checks : bool, default=True
        Whether to run validation checks when filtering.

    Notes
    -----
    This filter can work with both chunked and non-chunked EMG data, and can selectively
    process specific grids when multiple grids are present in the data.

    The convolution operation reduces the spatial dimensions based on the filter size,
    which means the output will have fewer electrodes than the input.

    Examples
    --------
    >>> import numpy as np
    >>> from myoverse.datatypes import EMGData
    >>> from myoverse.datasets.filters.spatial import DifferentialSpatialFilter
    >>>
    >>> # Create sample EMG data (64 channels, 1000 samples)
    >>> emg_data = np.random.randn(64, 1000)
    >>> emg = EMGData(emg_data, 2000)
    >>>
    >>> # Apply Laplacian filter to all grids
    >>> ndd_filter = DifferentialSpatialFilter(
    ...     filter_name="NDD",
    ...     input_is_chunked=False
    ... )
    >>> filtered_data = emg.apply_filter(ndd_filter)
    >>>
    >>> # Apply Laplacian filter to only the first grid
    >>> ndd_first_grid = DifferentialSpatialFilter(
    ...     filter_name="NDD",
    ...     input_is_chunked=False,
    ...     grids_to_process=0
    ... )
    >>> filtered_first = emg.apply_filter(ndd_first_grid)
    """

    # Kernels of the differential filters that can be applied across the
    # monopolar electrode grids. Axis 0 runs over grid rows, axis 1 over
    # grid columns.
    _DIFFERENTIAL_FILTERS = {
        "identity": np.array([[1]]),  # identity case when no filtering is applied
        "LSD": np.array([[-1], [1]]),  # longitudinal single differential
        "LDD": np.array([[1], [-2], [1]]),  # longitudinal double differential
        "TSD": np.array([[-1, 1]]),  # transverse single differential
        "TDD": np.array([[1, -2, 1]]),  # transverse double differential
        "NDD": np.array(
            [[0, -1, 0], [-1, 4, -1], [0, -1, 0]]
        ),  # normal double differential or Laplacian filter
        "IB2": np.array(
            [[-1, -2, -1], [-2, 12, -2], [-1, -2, -1]]
        ),  # inverse binomial filter of order 2
        "IR": np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]),  # inverse rectangle
    }

    def __init__(
        self,
        input_is_chunked: bool,
        is_output: bool = False,
        name: str | None = None,
        run_checks: bool = True,
        *,
        filter_name: Literal[
            "LSD", "TSD", "LDD", "TDD", "NDD", "IB2", "IR", "identity"
        ],
        grids_to_process: Union[Literal["all"], int, List[int]] = "all",
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            grids_to_process=grids_to_process,
            is_output=is_output,
            name=name,
            run_checks=run_checks,
        )
        self.filter_name = filter_name

        # Validate filter name (skipped when checks are disabled; an unknown
        # name then only surfaces as a KeyError when the filter is applied).
        valid_filters = list(self._DIFFERENTIAL_FILTERS)
        if self.run_checks and filter_name not in valid_filters:
            raise ValueError(
                f"Invalid filter_name: '{filter_name}'. Must be one of: {', '.join(valid_filters)}"
            )

    def _run_filter_checks(self, input_array: np.ndarray):
        """Additional validation for input data.

        Currently only delegates to the parent class checks.

        Parameters
        ----------
        input_array : np.ndarray
            The input array to validate.
        """
        super()._run_filter_checks(input_array)

    def _filter(self, input_array: np.ndarray, **kwargs) -> np.ndarray:
        """Apply the selected differential spatial filter to the input array.

        The channel axis (second to last) is split into one contiguous run of
        channels per grid, in the order of ``grid_layouts``. Selected grids
        are filtered, the others are passed through unchanged, and everything
        is concatenated again along the channel axis.

        Parameters
        ----------
        input_array : np.ndarray
            The input EMG data to filter. The last axis is time and the
            second to last axis is the electrode/channel axis.
        **kwargs
            Additional keyword arguments from the Data object, including:

            - grid_layouts: List of 2D arrays specifying electrode arrangements.
              The list is updated in place with the layouts of the output.

        Returns
        -------
        np.ndarray
            The filtered EMG data, with dimensions depending on the filter size and
            convolution mode. The number of electrodes will typically be reduced.

        Raises
        ------
        AttributeError
            If the grid_layouts are not provided in kwargs (or are None).
        """
        grid_layouts = kwargs.get("grid_layouts")
        if grid_layouts is None:
            raise AttributeError(
                "grid_layouts not found in kwargs. This filter only operates in grid-aware mode."
            )

        grids_to_filter = self._get_grids_to_filter(grid_layouts)

        outputs, new_grid_layouts = [], []
        channel_start = 0  # first input channel of the current grid
        next_index = 0  # first electrode index of the next output grid
        for i, grid_layout in enumerate(grid_layouts):
            first_index = np.min(grid_layout)
            # Each grid owns as many consecutive channels as its index range
            # spans. A running offset keeps the slicing correct when grids
            # differ in size, e.g. after an earlier spatial filter shrank
            # only some of them.
            n_channels = int(np.max(grid_layout) - first_index) + 1
            output = input_array[..., channel_start : channel_start + n_channels, :]
            channel_start += n_channels

            # Make the electrode indices relative to the grid's own channels.
            new_grid_layout = grid_layout - first_index

            if i in grids_to_filter:
                output, new_grid_layout = self._apply_differential_filter(
                    output,
                    new_grid_layout,
                )

            outputs.append(output)
            # Shift the layout so indices continue after the previous grid.
            new_grid_layout = new_grid_layout + next_index
            new_grid_layouts.append(new_grid_layout)
            next_index = np.max(new_grid_layout) + 1

        # Write the new layouts back into the caller's list so the change is
        # visible outside of this method.
        for i, new_grid_layout in enumerate(new_grid_layouts):
            grid_layouts[i] = new_grid_layout

        return np.concatenate(outputs, axis=-2)

    def _apply_differential_filter(self, grid_data, grid_layout):
        """Apply differential filter to a single grid's data.

        Parameters
        ----------
        grid_data : np.ndarray
            Data for a single grid to filter. The last axis is time, the
            second to last axis holds the grid's channels.
        grid_layout : np.ndarray
            The grid layout. Shape is (n_rows, n_cols). Each entry is the
            index of the channel in ``grid_data`` located at that position.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Filtered grid data and new grid layout. For the identity filter
            both inputs are returned unchanged; otherwise the data is float32.
        """
        # Special case for identity filter
        if self.filter_name == "identity":
            return grid_data, grid_layout

        differential_filter = self._DIFFERENTIAL_FILTERS[self.filter_name]

        # Arrange the channels spatially: (..., n_rows, n_cols, time).
        # Indexing with the 2D layout does this in one step.
        reshaped_grid_data = grid_data[..., grid_layout, :].astype(
            np.float64, copy=False
        )

        # Every output electrode corresponds to one kernel-sized window of
        # the layout; rank the window minima to get consecutive indices
        # starting at 0.
        window_minima = np.lib.stride_tricks.sliding_window_view(
            grid_layout, differential_filter.shape
        ).min((-2, -1))
        new_grid_layout = np.searchsorted(np.unique(window_minima), window_minima)

        # Give the 2D kernel singleton axes for the chunk (if any) and time
        # dimensions so that the convolution only acts spatially.
        convolution_result = convolve(
            reshaped_grid_data,
            np.expand_dims(
                differential_filter,
                axis=(0, -1) if self.input_is_chunked else (-1,),
            ),
            mode="valid",
        ).astype(np.float32)

        # Flatten the spatial result back to a channel axis that follows the
        # new layout.
        output = np.zeros(
            ((convolution_result.shape[0],) if self.input_is_chunked else ())
            + (np.max(new_grid_layout) + 1, convolution_result.shape[-1]),
            dtype=convolution_result.dtype,
        )
        for i, j in np.ndindex(new_grid_layout.shape):
            output[..., new_grid_layout[i, j], :] = convolution_result[..., i, j, :]

        return output, new_grid_layout
class ApplyFunctionSpatialFilter(SpatialFilterGridAware):
    """Apply a function over the EMG grids using a user defined kernel.

    The electrode grid is cut into (optionally strided) windows of
    ``kernel_size`` and the user supplied function is applied to every window.

    Parameters
    ----------
    kernel_size : tuple[int, int]
        The kernel size to use for the convolution. Must be a tuple of two integers.
    strides : tuple[int, int]
        The strides to use for the convolution. Must be a tuple of two integers.
    padding : str
        The padding to use for the convolution. Must be one of "same" or "valid".
        With "same" the grid is zero-padded before the windows are extracted.
    function : callable
        The function to apply over the grid. If input_is_chunked is True, the function must take and return a 4D array otherwise it must take and return a 3D array.

        .. note:: The input shape will be (chunks, time, y, x) if input_is_chunked is True and (time, y, x) if input_is_chunked is False.

        .. warning:: The function should only modify the y and x dimensions of the input array.
    input_is_chunked : bool
        Whether the input data is organized in chunks (3D array) or not (2D array).
    grids_to_process : Union[Literal["all"], int, List[int]]
        Specifies which grids to apply the filter to:

        - "all": Process all grids (default)
        - int: Process only the grid with this index
        - List[int]: Process only the grids with these indices
    is_output : bool, default=False
        Whether the filter is an output filter.
    name : str, optional
        Custom name for the filter. If None, the class name will be used.
    run_checks : bool, default=True
        Whether to run validation checks when filtering.

    Notes
    -----
    This filter can work with both chunked and non-chunked EMG data, and can selectively
    process specific grids when multiple grids are present in the data.

    .. important:: Because the filter can have strides, the electrodes of a
       processed grid are always renumbered. The new grid layout and the
       output channels are in column-major order (indices increase down each
       column first), regardless of the input layout.
    """

    def __init__(
        self,
        input_is_chunked: bool,
        is_output: bool = False,
        name: str | None = None,
        run_checks: bool = True,
        *,
        kernel_size: tuple[int, int],
        function: callable,
        strides: tuple[int, int] = (1, 1),
        padding: str = "same",
        grids_to_process: Union[Literal["all"], int, List[int]] = "all",
    ):
        super().__init__(
            input_is_chunked=input_is_chunked,
            allowed_input_type="both",
            grids_to_process=grids_to_process,
            is_output=is_output,
            name=name,
            run_checks=run_checks,
        )

        if not isinstance(kernel_size, tuple) or len(kernel_size) != 2:
            raise ValueError("Kernel must be a tuple of two integers.")
        self.kernel_size = kernel_size

        if not callable(function):
            raise ValueError("Function must be a callable.")
        self.function = function

        if not isinstance(strides, tuple) or len(strides) != 2:
            raise ValueError("Strides must be a tuple of two integers.")
        self.strides = strides

        if padding not in ["same", "valid"]:
            raise ValueError("Padding must be 'same' or 'valid'.")
        self.padding = padding

    def _run_filter_checks(self, input_array: np.ndarray):
        """Additional validation for input data.

        Currently only delegates to the parent class checks.

        Parameters
        ----------
        input_array : np.ndarray
            The input array to validate.
        """
        super()._run_filter_checks(input_array)

    def _filter(self, input_array: np.ndarray, **kwargs) -> np.ndarray:
        """Apply the user defined function to the selected grids.

        The channel axis (second to last) is split into one contiguous run of
        channels per grid, in the order of ``grid_layouts``. Selected grids
        are processed, the others are passed through unchanged, and
        everything is concatenated again along the channel axis.

        Parameters
        ----------
        input_array : np.ndarray
            The input EMG data to filter. The last axis is time and the
            second to last axis is the electrode/channel axis.
        **kwargs
            Additional keyword arguments from the Data object, including:

            - grid_layouts: List of 2D arrays specifying electrode arrangements.
              The list is updated in place with the layouts of the output.

        Returns
        -------
        np.ndarray
            The filtered EMG data. The number of electrodes depends on the
            kernel size, strides, padding and the output of the function.

        Raises
        ------
        AttributeError
            If the grid_layouts are not provided in kwargs (or are None).
        """
        grid_layouts = kwargs.get("grid_layouts")
        if grid_layouts is None:
            raise AttributeError(
                "grid_layouts not found in kwargs. This filter only operates in grid-aware mode."
            )

        grids_to_filter = self._get_grids_to_filter(grid_layouts)

        outputs, new_grid_layouts = [], []
        channel_start = 0  # first input channel of the current grid
        next_index = 0  # first electrode index of the next output grid
        for i, grid_layout in enumerate(grid_layouts):
            first_index = np.min(grid_layout)
            # Each grid owns as many consecutive channels as its index range
            # spans. A running offset keeps the slicing correct when grids
            # differ in size, e.g. after an earlier spatial filter shrank
            # only some of them.
            n_channels = int(np.max(grid_layout) - first_index) + 1
            output = input_array[..., channel_start : channel_start + n_channels, :]
            channel_start += n_channels

            # Make the electrode indices relative to the grid's own channels.
            new_grid_layout = grid_layout - first_index

            if i in grids_to_filter:
                output, new_grid_layout = self._apply_custom_filter(
                    output,
                    new_grid_layout,
                )

            outputs.append(output)
            # Shift the layout so indices continue after the previous grid.
            new_grid_layout = new_grid_layout + next_index
            new_grid_layouts.append(new_grid_layout)
            next_index = np.max(new_grid_layout) + 1

        # Write the new layouts back into the caller's list so the change is
        # visible outside of this method.
        for i, new_grid_layout in enumerate(new_grid_layouts):
            grid_layouts[i] = new_grid_layout

        return np.concatenate(outputs, axis=-2)

    @staticmethod
    def _same_padding(size: int, kernel: int, stride: int) -> tuple[int, int]:
        """Return the (before, after) zero padding of one axis for "same" mode.

        The padding is chosen so that ``ceil(size / stride)`` windows fit
        along the axis; an odd total puts the extra element at the end.
        """
        n_windows = math.ceil(size / stride)
        total = max((n_windows - 1) * stride + kernel - size, 0)
        before = total // 2
        return before, total - before

    def _apply_custom_filter(self, grid_data, grid_layout):
        """Apply custom filter to a single grid's data.

        Parameters
        ----------
        grid_data : np.ndarray
            Data for a single grid to filter. The last axis is time, the
            second to last axis holds the grid's channels.
        grid_layout : np.ndarray
            The grid layout. Shape is (n_rows, n_cols). Each entry is the
            index of the channel in ``grid_data`` located at that position.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            Filtered grid data and new grid layout. Both are in column-major
            order.

        Raises
        ------
        ValueError
            If the padding mode is unknown or no window could be extracted
            from the grid.
        """
        ky, kx = self.kernel_size
        sy, sx = self.strides
        leading_pad = ((0, 0),) if self.input_is_chunked else ()

        # Arrange the channels spatially: (..., n_rows, n_cols, time).
        # Indexing with the 2D layout does this in one step.
        reshaped_grid_data = grid_data[..., grid_layout, :].astype(
            np.float64, copy=False
        )

        if self.padding == "same":
            pad_y = self._same_padding(reshaped_grid_data.shape[-3], ky, sy)
            pad_x = self._same_padding(reshaped_grid_data.shape[-2], kx, sx)
            grid_data_padded = np.pad(
                reshaped_grid_data,
                leading_pad + (pad_y, pad_x, (0, 0)),
                mode="constant",
            )
        elif self.padding == "valid":
            # The windows must be taken from the spatially arranged data,
            # not from the flat channel array.
            grid_data_padded = reshaped_grid_data
        else:
            raise ValueError("Padding must be 'same' or 'valid'")

        # Generate sliding windows and apply strides. Resulting shape:
        # (..., n_windows_y, n_windows_x, time, ky, kx).
        windows = np.lib.stride_tricks.sliding_window_view(
            grid_data_padded, (ky, kx), axis=(-3, -2)
        )[..., ::sy, ::sx, :, :, :]
        n_windows_y, n_windows_x, n_samples = windows.shape[-5:-2]

        # Apply the function to each window and tile the results into one
        # large grid of shape (..., y, x, time).
        new_y = new_x = 0
        new_windows = None
        for y, x in np.ndindex(n_windows_y, n_windows_x):
            # Each window is (..., time, ky, kx), as documented for `function`.
            function_output = np.asarray(self.function(windows[..., y, x, :, :, :]))

            if new_windows is None:
                # The output size of the function is only known after the
                # first call, so the result buffer is allocated lazily.
                new_y, new_x = function_output.shape[-2:]
                new_windows = np.zeros(
                    ((windows.shape[0],) if self.input_is_chunked else ())
                    + (new_y * n_windows_y, new_x * n_windows_x, n_samples),
                    dtype=windows.dtype,
                )

            # Move time behind the spatial axes: (..., time, y, x) ->
            # (..., y, x, time). This holds for chunked and non-chunked data.
            new_windows[
                ..., y * new_y : (y + 1) * new_y, x * new_x : (x + 1) * new_x, :
            ] = np.moveaxis(function_output, -3, -1)

        if new_windows is None:
            raise ValueError(
                "No windows could be extracted from the grid with the given kernel_size and strides."
            )

        # Flatten the spatial axes column-major (Fortran order) and number
        # the new layout the same way so data and layout stay consistent.
        n_rows, n_cols = new_windows.shape[-3:-1]
        new_grid_layout = np.arange(0, n_rows * n_cols).reshape(
            (n_rows, n_cols), order="F"
        )
        new_windows = np.reshape(
            new_windows,
            ((new_windows.shape[0],) if self.input_is_chunked else ())
            + (-1, new_windows.shape[-1]),
            order="F",
        )

        return new_windows, new_grid_layout