Source code for myoverse.transforms.spatial

"""GPU-accelerated spatial transforms for grid-based EMG.

Spatial filters operate on electrode grids using 2D convolutions.
All transforms work with named tensors on any device.

Examples
--------
>>> import torch
>>> import myoverse
>>> from myoverse.transforms import NDD, LSD, Pipeline
>>>
>>> # Create EMG tensor with grid info
>>> emg = myoverse.emg_tensor(data, grid_layouts=[grid1, grid2])
>>>
>>> # Apply spatial filter
>>> ndd = NDD(grids="all")
>>> filtered = ndd(emg)

"""

from __future__ import annotations

import numpy as np
import torch
import torch.nn.functional as F

from myoverse.transforms.base import TensorTransform

# Standard spatial filter kernels, keyed by the name SpatialFilter accepts.
# SpatialFilter applies them with F.conv2d, which is a cross-correlation (the
# kernel is NOT flipped): entry [i, j] weights the electrode at offset
# (i - 1, j - 1) rows/columns from the electrode being computed.
SPATIAL_KERNELS = {
    name: torch.tensor(weights, dtype=torch.float32) / scale
    for name, weights, scale in (
        # Normal Double Differential (Laplacian): centre minus mean of the
        # four direct neighbours.
        ("NDD", [[0, -1, 0], [-1, 4, -1], [0, -1, 0]], 4),
        # Longitudinal Single Differential (vertical): centre minus the
        # electrode in the row above.
        ("LSD", [[0, -1, 0], [0, 1, 0], [0, 0, 0]], 1),
        # Transverse Single Differential (horizontal): centre minus the
        # electrode in the column to the left.
        ("TSD", [[0, 0, 0], [-1, 1, 0], [0, 0, 0]], 1),
        # Inverse Binomial 2nd order.
        ("IB2", [[1, -2, 1], [-2, 4, -2], [1, -2, 1]], 4),
    )
}


def _channels_to_grid(
    data: torch.Tensor,
    grid_layout: np.ndarray,
) -> torch.Tensor:
    """Reshape channel data to grid layout.

    Parameters
    ----------
    data : torch.Tensor
        Data with shape (..., n_channels, time).
    grid_layout : np.ndarray
        2D array mapping (row, col) to channel index. -1 for gaps.

    Returns
    -------
    torch.Tensor
        Data with shape (..., rows, cols, time). Gaps filled with 0.

    """
    rows, cols = grid_layout.shape
    time_dim = data.shape[-1]
    batch_shape = data.shape[:-2]

    # Create output tensor
    out = torch.zeros(
        *batch_shape,
        rows,
        cols,
        time_dim,
        dtype=data.dtype,
        device=data.device,
    )

    # Fill in valid electrodes using relative indexing
    ch_idx = 0
    for r in range(rows):
        for c in range(cols):
            if grid_layout[r, c] >= 0:
                out[..., r, c, :] = data[..., ch_idx, :]
                ch_idx += 1

    return out


def _grid_to_channels(
    data: torch.Tensor,
    grid_layout: np.ndarray,
) -> torch.Tensor:
    """Reshape grid data back to channels.

    Parameters
    ----------
    data : torch.Tensor
        Data with shape (..., rows, cols, time).
    grid_layout : np.ndarray
        2D array mapping (row, col) to channel index. -1 for gaps.

    Returns
    -------
    torch.Tensor
        Data with shape (..., n_channels, time).

    """
    rows, cols = grid_layout.shape
    time_dim = data.shape[-1]
    batch_shape = data.shape[:-3]

    # Count valid channels
    n_channels = np.sum(grid_layout >= 0)

    # Create output
    out = torch.zeros(
        *batch_shape,
        n_channels,
        time_dim,
        dtype=data.dtype,
        device=data.device,
    )

    # Extract valid electrodes
    ch_idx = 0
    for r in range(rows):
        for c in range(cols):
            if grid_layout[r, c] >= 0:
                out[..., ch_idx, :] = data[..., r, c, :]
                ch_idx += 1

    return out


class SpatialFilter(TensorTransform):
    """Apply spatial filtering using grid layouts.

    Spatial filters use 2D convolution on electrode grids. Grid layouts must be
    stored as a tensor attribute (via myoverse.emg_tensor). Electrodes outside
    the grid and gap cells are treated as zeros by the convolution.

    Parameters
    ----------
    kernel : str | torch.Tensor
        Filter kernel. Either a name ("NDD", "LSD", "TSD", "IB2") or a custom
        2D tensor. Kernels of any size are applied with "same" padding, so each
        grid keeps its shape.
    grids : str | list[int]
        Which grids to filter. "all" for all grids, or list of indices
        (negative indices count from the last grid).
    dim : str
        Channel dimension name.

    Raises
    ------
    ValueError
        If ``kernel`` is an unknown name or not 2D, or if ``grids`` is a
        string other than "all".

    Examples
    --------
    >>> import myoverse
    >>> emg = myoverse.emg_tensor(data, grid_layouts=[grid1, grid2])
    >>> ndd = SpatialFilter("NDD", grids="all")
    >>> filtered = ndd(emg)

    """

    def __init__(
        self,
        kernel: str | torch.Tensor = "NDD",
        grids: str | list[int] = "all",
        dim: str = "channel",
        **kwargs,
    ):
        super().__init__(dim=dim, **kwargs)
        if isinstance(kernel, str):
            if kernel not in SPATIAL_KERNELS:
                raise ValueError(
                    f"Unknown kernel '{kernel}'. "
                    f"Available: {list(SPATIAL_KERNELS.keys())}",
                )
            self.kernel = SPATIAL_KERNELS[kernel]
            self.kernel_name = kernel
        else:
            kernel = torch.as_tensor(kernel)
            if kernel.ndim != 2:
                raise ValueError(
                    f"Custom kernel must be 2D, got shape {tuple(kernel.shape)}.",
                )
            self.kernel = kernel
            self.kernel_name = "custom"

        # Fail at construction instead of with an obscure TypeError in _apply.
        if isinstance(grids, str) and grids != "all":
            raise ValueError(
                f"grids must be 'all' or a list of grid indices, got '{grids}'.",
            )
        self.grids = grids

    def _selected_grids(self, n_grids: int) -> set[int]:
        """Return the (non-negative) indices of the grids to filter.

        Raises
        ------
        ValueError
            If a requested index is outside ``[-n_grids, n_grids)``.

        """
        if isinstance(self.grids, str):  # only "all" passes __init__
            return set(range(n_grids))
        if isinstance(self.grids, (int, np.integer)):
            requested = [self.grids]
        else:
            requested = list(self.grids)

        selected = set()
        for idx in requested:
            idx = int(idx)
            if not -n_grids <= idx < n_grids:
                raise ValueError(
                    f"Grid index {idx} out of range for {n_grids} grid layout(s).",
                )
            selected.add(idx % n_grids)
        return selected

    @staticmethod
    def _filter_grid(
        grid_data: torch.Tensor,
        grid_layout: np.ndarray,
        kernel_4d: torch.Tensor,
    ) -> torch.Tensor:
        """Convolve one grid's channels (..., n_channels, time) with ``kernel_4d``.

        Returns a tensor with the same shape as ``grid_data``.
        """
        grid = _channels_to_grid(grid_data, grid_layout)  # (..., R, C, T)
        # conv2d expects (N, 1, H, W): treat every (batch..., time) sample as an
        # independent image over the electrode grid.
        stacked = grid.movedim(-1, -3)  # (..., T, R, C)
        images = stacked.flatten(0, -3).unsqueeze(1)  # (N, 1, R, C)
        filtered = F.conv2d(images, kernel_4d, padding="same")
        filtered = filtered.reshape(stacked.shape).movedim(-3, -1)  # (..., R, C, T)
        return _grid_to_channels(filtered, grid_layout)

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        """Filter the selected grids of ``x``.

        ``x`` must carry a ``grid_layouts`` attribute: one 2D layout per grid,
        -1 marking gaps. Grids are read as consecutive blocks along the channel
        axis (dim -2, time on dim -1) in the order of ``grid_layouts``.
        Channels after the last grid are returned unchanged.

        Raises
        ------
        ValueError
            If ``grid_layouts`` is missing, describes more electrodes than the
            tensor has channels, or a requested grid index is out of range.

        """
        if not hasattr(x, "grid_layouts"):
            raise ValueError(
                "Tensor missing 'grid_layouts' attribute. "
                "Create with myoverse.emg_tensor(data, grid_layouts=[...]):\n\n"
                "\timport myoverse\n"
                "\temg = myoverse.emg_tensor(data, grid_layouts=[grid1, grid2])\n",
            )

        grid_layouts = x.grid_layouts
        names = x.names
        selected = self._selected_grids(len(grid_layouts))
        grid_sizes = [int(np.sum(np.asarray(layout) >= 0)) for layout in grid_layouts]

        # NOTE(review): the channel axis is taken to be dim -2 regardless of
        # ``self.dim``; confirm TensorTransform guarantees this arrangement.
        x = x.rename(None)
        n_total = x.shape[-2]
        if sum(grid_sizes) > n_total:
            raise ValueError(
                f"grid_layouts describe {sum(grid_sizes)} electrodes but the "
                f"tensor has only {n_total} channels.",
            )

        # (out_channels=1, in_channels=1, kH, kW) on the input's device/dtype.
        kernel_4d = self.kernel.to(device=x.device, dtype=x.dtype)[None, None]

        results = []
        channel_offset = 0
        for grid_idx, (grid_layout, n_channels) in enumerate(
            zip(grid_layouts, grid_sizes),
        ):
            grid_data = x[..., channel_offset : channel_offset + n_channels, :]
            if grid_idx in selected:
                grid_data = self._filter_grid(grid_data, grid_layout, kernel_4d)
            results.append(grid_data)
            channel_offset += n_channels

        if channel_offset < n_total:
            # Channels not covered by any layout (e.g. auxiliary signals) are
            # passed through unfiltered rather than dropped.
            results.append(x[..., channel_offset:, :])

        result = torch.cat(results, dim=-2) if results else x
        if any(name is not None for name in names):
            result = result.rename(*names)
        # torch.cat / rename return new tensor objects that need not carry the
        # Python-level ``grid_layouts`` attribute; re-attach it (the channel
        # count is unchanged) so the output can feed another spatial filter.
        if not hasattr(result, "grid_layouts"):
            result.grid_layouts = grid_layouts
        return result
class NDD(SpatialFilter):
    """Normal Double Differential (Laplacian) spatial filter.

    Each electrode becomes its own value minus the mean of its four direct
    neighbours (up, down, left, right). Neighbours outside the grid or on gap
    cells contribute zero.

    Parameters
    ----------
    grids : str | list[int]
        Grids to filter: "all", or a list of grid indices.

    """

    def __init__(self, grids: str | list[int] = "all", **kwargs):
        # The kernel is fixed; all remaining options go to SpatialFilter.
        super().__init__("NDD", grids, **kwargs)
class LSD(SpatialFilter):
    """Longitudinal Single Differential spatial filter.

    Each electrode becomes its own value minus the electrode in the row above
    (zero for the first row), i.e. a vertical difference, which is along the
    muscle fibres when the grid rows are aligned with them.

    Parameters
    ----------
    grids : str | list[int]
        Grids to filter: "all", or a list of grid indices.

    """

    def __init__(self, grids: str | list[int] = "all", **kwargs):
        # The kernel is fixed; all remaining options go to SpatialFilter.
        super().__init__("LSD", grids, **kwargs)
class TSD(SpatialFilter):
    """Transverse Single Differential spatial filter.

    Each electrode becomes its own value minus the electrode in the column to
    its left (zero for the first column), i.e. a horizontal difference, which
    is across the muscle fibres when the grid rows are aligned with them.

    Parameters
    ----------
    grids : str | list[int]
        Grids to filter: "all", or a list of grid indices.

    """

    def __init__(self, grids: str | list[int] = "all", **kwargs):
        # The kernel is fixed; all remaining options go to SpatialFilter.
        super().__init__("TSD", grids, **kwargs)
class IB2(SpatialFilter):
    """Inverse Binomial 2nd order spatial filter.

    A 2D high-pass filter with binomial weighting, using the 3x3 "IB2" entry
    of SPATIAL_KERNELS.

    Parameters
    ----------
    grids : str | list[int]
        Grids to filter: "all", or a list of grid indices.

    """

    def __init__(self, grids: str | list[int] = "all", **kwargs):
        # The kernel is fixed; all remaining options go to SpatialFilter.
        super().__init__("IB2", grids, **kwargs)