from pathlib import Path
from typing import Optional, Any, Union
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from tqdm import tqdm
from beartype import beartype
from beartype.cave import IterableType
from myogen.utils.types import SURFACE_EMG__TENSOR
def _get_axis(axes, row_idx: int, col_idx: int, n_rows: int, n_cols: int):
    """
    Return the axis at grid position (row_idx, col_idx) from a subplots result.

    ``plt.subplots`` returns differently shaped objects depending on the grid
    size and on its ``squeeze`` argument: a single object for a 1x1 grid, a
    1-D container for a single row or column, or a 2-D array. This helper
    hides that difference from the caller.

    Parameters
    ----------
    axes : object
        A single axis, a 1-D indexable container of axes, or a 2-D array of
        axes (anything exposing ``ndim == 2`` and ``[row, col]`` indexing).
    row_idx : int
        Row of the requested axis.
    col_idx : int
        Column of the requested axis.
    n_rows : int
        Number of rows in the subplot grid.
    n_cols : int
        Number of columns in the subplot grid.

    Returns
    -------
    object
        The element of ``axes`` at the requested grid position.
    """
    # A 2-D container (e.g. the result of squeeze=False) is always indexed by
    # (row, col), even when one of the grid dimensions has length 1. Without
    # this, a (1, n) array would be indexed on its first axis only.
    if getattr(axes, "ndim", None) == 2:
        return axes[row_idx, col_idx]

    if n_rows == 1 and n_cols == 1:
        # Squeezed 1x1 grid: ``axes`` already is the axis.
        return axes
    if n_rows == 1:
        return axes[col_idx]
    if n_cols == 1:
        return axes[row_idx]
    return axes[row_idx, col_idx]
def _auto_zoom_muaps(
    muap_data: np.ndarray,
    threshold: float = 0.01,
    padding: float = 0.1,
    center: bool = True,
) -> np.ndarray:
    """
    Auto-zoom MUAPs by centering them and cropping to significant regions.

    Based on the algorithm from check_generated_surface_muaps.py.

    Parameters
    ----------
    muap_data : np.ndarray
        MUAP data with shape (n_muaps, n_rows, n_cols, n_time). The input is
        never modified; all work happens on a copy.
    threshold : float
        Threshold as fraction of max amplitude for significant regions
    padding : float
        Padding as fraction of signal length to add around significant regions
    center : bool
        Whether to center MUAPs temporally before cropping

    Returns
    -------
    np.ndarray
        Processed MUAP data with the same leading dimensions but potentially
        fewer time samples. Empty input, or input in which no sample exceeds
        the threshold (e.g. all zeros), is returned uncropped.
    """
    n_muaps, n_rows, n_cols, n_time = muap_data.shape
    processed_data = muap_data.copy()

    # Nothing to center or crop; also keeps np.max below from raising on an
    # empty array.
    if processed_data.size == 0:
        return processed_data

    if center:
        time_indices = np.arange(n_time)
        midpoint = n_time // 2

        for muap_idx in range(n_muaps):
            # Temporal center of mass of |signal| at every electrode at once.
            # float64 avoids integer overflow in the weighted sum.
            weights = np.abs(processed_data[muap_idx]).astype(np.float64)
            totals = weights.sum(axis=-1)
            moments = (weights * time_indices).sum(axis=-1)

            # Electrodes without any signal fall back to the window midpoint;
            # `where` leaves those pre-filled entries untouched.
            centers = np.full((n_rows, n_cols), float(midpoint))
            np.divide(moments, totals, out=centers, where=totals > 0)

            # Use the mean center across all electrodes
            mean_center = int(np.mean(centers))
            shift = midpoint - mean_center
            if shift == 0:
                continue

            # Zero-filled shift along time (samples pushed past the edge are
            # dropped, not wrapped around).
            shifted_muap = np.zeros_like(processed_data[muap_idx])
            if shift > 0:
                # Shift right
                shifted_muap[..., shift:] = processed_data[muap_idx, ..., :-shift]
            else:
                # Shift left
                shifted_muap[..., :shift] = processed_data[muap_idx, ..., -shift:]
            processed_data[muap_idx] = shifted_muap

    # Find significant regions across all MUAPs and crop
    max_amplitude = np.max(np.abs(processed_data))
    threshold_value = threshold * max_amplitude

    # Time samples at which any MUAP on any electrode exceeds the threshold
    significant_mask = np.abs(processed_data) > threshold_value
    significant_indices = np.where(significant_mask.any(axis=(0, 1, 2)))[0]

    if len(significant_indices) > 0:
        first_idx = int(significant_indices[0])
        last_idx = int(significant_indices[-1])

        # Add padding
        padding_samples = int(padding * n_time)
        start_idx = max(0, first_idx - padding_samples)
        # `last_idx` is inclusive while a slice end is exclusive, hence +1.
        # Without it the last significant sample is lost when padding is 0.
        end_idx = min(n_time, last_idx + padding_samples + 1)

        # Crop all MUAPs to this region
        processed_data = processed_data[..., start_idx:end_idx]

    return processed_data
# [docs]  (documentation-site source link marker left over from extraction; not code)
@beartype
def plot_surface_emg(
    surface_emg__tensor: SURFACE_EMG__TENSOR,
    axs: list[Union[Axes, np.ndarray]],
    apply_default_formatting: bool = True,
    **kwargs: Any,
) -> list[Union[Axes, np.ndarray]]:
    """
    Draw the surface EMG trace of every electrode, one axes set per pool.

    Electrodes are visited row by row and matched to the axes of a pool in
    flattened order. If a pool provides fewer axes than there are electrodes,
    the remaining electrodes are silently skipped.

    Parameters
    ----------
    surface_emg__tensor : SURFACE_EMG__TENSOR
        Tensor of shape (n_pools, n_rows, n_cols, n_time) containing EMG signals
    axs : list[Union[Axes, np.ndarray]]
        One entry per pool. Each entry may be a single axis, an array of axes
        (e.g. from plt.subplots), or another indexable collection of axes.
    apply_default_formatting : bool, default=True
        If True, each subplot gets a title and axis labels and ``kwargs`` are
        ignored. If False, ``kwargs`` are forwarded to ``ax.plot`` and no
        title or labels are set.
    **kwargs : dict
        Extra keyword arguments for ``ax.plot``. Only used if
        apply_default_formatting is False.

    Returns
    -------
    list[Union[Axes, np.ndarray]]
        The ``axs`` argument, unchanged.

    Raises
    ------
    ValueError
        If the number of entries in ``axs`` differs from the number of pools
    """
    pool_axes_sets = list(axs)
    n_pools = surface_emg__tensor.shape[0]

    if len(pool_axes_sets) != n_pools:
        raise ValueError(
            f"Number of axes must match number of pools. Got {len(pool_axes_sets)} axes, but {n_pools} pools."
        )

    n_rows, n_cols = surface_emg__tensor.shape[1], surface_emg__tensor.shape[2]

    for pool_idx, pool_axes in enumerate(pool_axes_sets):
        # Normalise whatever was supplied into something indexable by a flat
        # electrode number.
        candidates: Any
        if isinstance(pool_axes, Axes):
            candidates = [pool_axes]
        elif hasattr(pool_axes, "flat"):
            candidates = pool_axes.flat
        else:
            candidates = pool_axes

        # Row-major walk over the electrode grid
        for electrode_idx in range(n_rows * n_cols):
            if electrode_idx >= len(candidates):
                # No axis available for this electrode
                continue

            row_idx, col_idx = divmod(electrode_idx, n_cols)
            ax = candidates[electrode_idx]
            trace = surface_emg__tensor[pool_idx, row_idx, col_idx]

            if not apply_default_formatting:
                ax.plot(trace, **kwargs)
                continue

            ax.plot(trace)
            ax.set_title(f"Pool {pool_idx + 1} - R{row_idx} C{col_idx}")
            ax.set_xlabel("Time (samples)")
            ax.set_ylabel("Amplitude")

    return axs
# [docs]  (documentation-site source link marker left over from extraction; not code)
@beartype
def plot_muap_grid(
    muap_data: np.ndarray,
    axs: list[Union[Axes, np.ndarray]],
    muap_indices: Optional[list[int]] = None,
    apply_default_formatting: bool = True,
    **kwargs: Any,
) -> list[Union[Axes, np.ndarray]]:
    """
    Plot Motor Unit Action Potentials (MUAPs) on an electrode-grid layout.

    Each selected MUAP is drawn on its own set of axes, one subplot per
    electrode, visiting electrodes row by row and matching them to the axes
    in flattened order. If a set provides fewer axes than there are
    electrodes, the remaining electrodes are silently skipped. A progress
    bar is shown while plotting.

    Parameters
    ----------
    muap_data : np.ndarray
        MUAP data with shape (n_muaps, n_electrode_rows, n_electrode_cols, n_time_samples)
        or (n_electrode_rows, n_electrode_cols, n_time_samples) for a single MUAP.
    axs : list[Union[Axes, np.ndarray]]
        One entry per MUAP to plot. Each entry may be a single axis, an array
        of axes (e.g. from plt.subplots), or another indexable collection of
        axes.
    muap_indices : list[int], optional
        List of MUAP indices to plot. If None, plots all MUAPs.
    apply_default_formatting : bool, default=True
        If True, every subplot of a MUAP shares the y-limits given by that
        MUAP's overall min and max, ticks and spines are hidden, and
        ``kwargs`` are ignored. If False, ``kwargs`` are forwarded to
        ``ax.plot`` and no formatting is applied.
    **kwargs : dict
        Extra keyword arguments for ``ax.plot``. Only used if
        apply_default_formatting is False.

    Returns
    -------
    list[Union[Axes, np.ndarray]]
        The ``axs`` argument, unchanged.

    Raises
    ------
    ValueError
        If ``muap_data`` is neither 3- nor 4-dimensional, if any requested
        index is outside ``[0, n_muaps - 1]``, or if the number of entries in
        ``axs`` differs from the number of MUAPs to plot
    """
    # Promote a lone MUAP to a batch of one
    if muap_data.ndim == 3:
        muap_data = muap_data[np.newaxis, ...]

    if muap_data.ndim != 4:
        raise ValueError(
            f"muap_data must have 3 or 4 dimensions, got {muap_data.ndim}. "
            f"Expected shape: (n_muaps, n_rows, n_cols, n_time) or (n_rows, n_cols, n_time)"
        )

    n_muaps, n_rows, n_cols, _ = muap_data.shape

    # Default selection: every MUAP
    selected = list(range(n_muaps)) if muap_indices is None else muap_indices

    out_of_range = [idx for idx in selected if not 0 <= idx < n_muaps]
    if out_of_range:
        raise ValueError(
            f"Invalid MUAP indices: {out_of_range}. Must be in range [0, {n_muaps - 1}]"
        )

    axes_sets = list(axs)
    if len(axes_sets) != len(selected):
        raise ValueError(
            f"Number of axes must match number of MUAPs to plot. Got {len(axes_sets)} axes, but {len(selected)} MUAPs."
        )

    hidden_spines = ("top", "right", "left", "bottom")

    for position, muap_idx in enumerate(tqdm(selected, desc="Plotting MUAPs")):
        muap_axes = axes_sets[position]

        # Normalise whatever was supplied into something indexable by a flat
        # electrode number.
        candidates: Any
        if isinstance(muap_axes, Axes):
            candidates = [muap_axes]
        elif hasattr(muap_axes, "flat"):
            candidates = muap_axes.flat
        else:
            candidates = muap_axes

        # Shared amplitude range so electrodes of one MUAP are comparable
        waveforms = muap_data[muap_idx]
        y_low = np.min(waveforms)
        y_high = np.max(waveforms)

        # Row-major walk over the electrode grid
        for electrode_idx in range(n_rows * n_cols):
            if electrode_idx >= len(candidates):
                # No axis available for this electrode
                continue

            row_idx, col_idx = divmod(electrode_idx, n_cols)
            ax = candidates[electrode_idx]
            trace = muap_data[muap_idx, row_idx, col_idx]

            if not apply_default_formatting:
                ax.plot(trace, **kwargs)
                continue

            ax.plot(trace)
            ax.set_ylim(y_low, y_high)

            # Strip ticks and frame so only the waveform remains
            ax.set_xticks([])
            ax.set_yticks([])
            for side in hidden_spines:
                ax.spines[side].set_visible(False)

    return axs