"""
Modular plotting functions for neuron simulation results.
This module provides specialized plotting functions for different aspects of
neuromuscular simulation results from MyoGen's neuron simulator, including
raster plots, membrane traces, and physiological model dynamics.
"""
import warnings
from typing import Any
import numpy as np
from beartype.cave import IterableType
from matplotlib.axes import Axes
from neo import Block
from myogen.utils.decorators import beartowertype
[docs]
@beartowertype
def plot_raster_spikes(
results: Block,
axs: IterableType[Axes],
populations: list[str],
time_range: tuple[float, float] | None = None,
dot_size: float = 0.8,
alpha: float = 1.0,
title: str = "Raster Plot",
xlabel: str = "Time [ms]",
ylabel: str = "Neuron ID",
apply_default_formatting: bool = True,
**kwargs: Any,
) -> IterableType[Axes]:
"""
Plot spike raster plots for neural populations, one per axis.
Parameters
----------
results : Block
NEO Block containing spike train segments for each population.
axs : IterableType[Axes]
Matplotlib axes to plot on. Must have as many axes as populations.
populations : list[str]
List of population names to plot.
time_range : tuple[float, float], optional
Time range to plot (start, end) in milliseconds, by default None (full range).
dot_size : float, optional
Size of spike markers, by default 0.8.
alpha : float, optional
Transparency of spike markers (0.0 to 1.0), by default 1.0.
title : str, optional
Plot title, by default "Raster Plot".
xlabel : str, optional
X-axis label, by default "Time [ms]".
ylabel : str, optional
Y-axis label, by default "Neuron ID".
apply_default_formatting : bool, optional
Whether to apply default formatting, by default True.
**kwargs : Any
Additional keyword arguments passed to matplotlib plot functions.
Returns
-------
IterableType[Axes]
The axes that were plotted on.
"""
ax_list = list(axs)
if len(ax_list) != len(populations):
raise ValueError(
f"plot_raster_spikes requires {len(populations)} axes (one per population), got {len(ax_list)}"
)
# Plot each population on its own axis
for pop_idx, pop_name in enumerate(populations):
ax = ax_list[pop_idx]
# Find segment for this population
segment = None
for seg in results.segments:
if seg.name == pop_name:
segment = seg
break
if segment is None or not segment.spiketrains:
# No spikes for this population, skip but still format axis
if apply_default_formatting:
ax.set_xlabel(xlabel if pop_idx == len(populations) - 1 else "")
ax.set_ylabel(ylabel)
ax.set_title(f"{title} - {pop_name}")
if time_range is not None:
ax.set_xlim(time_range)
continue
# Extract spike times and IDs from spike trains
spike_times = []
spike_ids = []
for i, spiketrain in enumerate(segment.spiketrains):
for spike_time in spiketrain.times:
spike_times.append(float(spike_time.rescale("ms").magnitude))
spike_ids.append(i)
if spike_times:
spike_times = np.array(spike_times)
spike_ids = np.array(spike_ids)
# Apply time range filter if specified
if time_range is not None:
time_mask = (spike_times >= time_range[0]) & (spike_times <= time_range[1])
spike_times = spike_times[time_mask]
spike_ids = spike_ids[time_mask]
# Plot spikes for this population
ax.plot(spike_times, spike_ids, ".", ms=dot_size, alpha=alpha, **kwargs)
if apply_default_formatting:
# Only show xlabel on bottom plot
ax.set_xlabel(xlabel if pop_idx == len(populations) - 1 else "")
ax.set_ylabel(ylabel)
ax.set_title(f"{title} - {pop_name}")
if time_range is not None:
ax.set_xlim(time_range)
# Set y-axis to show neuron IDs from 0 to max
if spike_times.size > 0:
ax.set_ylim(-0.5, len(segment.spiketrains) - 0.5)
# Remove spines for cleaner appearance when multiple populations
if len(populations) > 1:
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
# Remove bottom spine for all except the last plot
if pop_idx < len(populations) - 1:
ax.spines["bottom"].set_visible(False)
ax.tick_params(bottom=False, labelbottom=False)
return axs
[docs]
@beartowertype
def plot_membrane_potentials(
results: Block,
axs: IterableType[Axes],
populations: list[str] | str = "aMN",
cell_indices: list[int] = [0, 10, 20, 30, 40],
time_range: tuple[float, float] | None = None,
title: str = "Membrane Potential",
xlabel: str = "Time [ms]",
ylabel: str = "Voltage [mV]",
apply_default_formatting: bool = True,
**kwargs: Any,
) -> IterableType[Axes]:
"""
Plot membrane potential traces for selected cells across populations.
Parameters
----------
results : Block
NEO Block containing analog signal segments for membrane potentials.
axs : IterableType[Axes]
Matplotlib axes to plot on. If populations is a list, must have as many axes as populations.
populations : list[str] | str, optional
Population name(s) to plot. If list, plots one per axis. By default "aMN".
cell_indices : list[int], optional
List of cell indices to plot, by default [0, 10, 20, 30, 40].
time_range : tuple[float, float], optional
Time range to plot (start, end) in milliseconds, by default None (full range).
title : str, optional
Plot title, by default "Membrane Potential".
xlabel : str, optional
X-axis label, by default "Time [ms]".
ylabel : str, optional
Y-axis label, by default "Voltage [mV]".
apply_default_formatting : bool, optional
Whether to apply default formatting, by default True.
**kwargs : Any
Additional keyword arguments passed to matplotlib plot functions.
Returns
-------
IterableType[Axes]
The axes that were plotted on.
"""
ax_list = list(axs)
# Handle backward compatibility - convert single population to list
if isinstance(populations, str):
populations = [populations]
if len(ax_list) != len(populations):
raise ValueError(
f"plot_membrane_potentials requires {len(populations)} axes (one per population), got {len(ax_list)}"
)
# Plot each population on its own axis
for pop_idx, population in enumerate(populations):
ax = ax_list[pop_idx]
# Find segment for this population
segment = None
for seg in results.segments:
if seg.name == population:
segment = seg
break
if segment is None or not segment.analogsignals:
warnings.warn(f"No membrane potential data found for population '{population}'")
if apply_default_formatting:
ax.set_xlabel(xlabel if pop_idx == len(populations) - 1 else "")
ax.set_ylabel(ylabel)
ax.set_title(f"{title} - {population}")
if time_range is not None:
ax.set_xlim(time_range)
continue
# Plot traces for requested cell indices
for signal in segment.analogsignals:
cell_id = int(signal.name)
if cell_id in cell_indices:
times = signal.times.rescale("ms").magnitude
voltage = signal.magnitude.flatten()
# Apply time range filter if specified
if time_range is not None:
time_mask = (times >= time_range[0]) & (times <= time_range[1])
times = times[time_mask]
voltage = voltage[time_mask]
ax.plot(times, voltage, label=f"{population}[{cell_id}]", **kwargs)
if apply_default_formatting:
# Only show xlabel on bottom plot
ax.set_xlabel(xlabel if pop_idx == len(populations) - 1 else "")
ax.set_ylabel(ylabel)
ax.set_title(f"{title} - {population}")
ax.legend(loc="upper right")
if time_range is not None:
ax.set_xlim(time_range)
# Remove spines for cleaner appearance when multiple populations
if len(populations) > 1:
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
# Remove bottom spine for all except the last plot
if pop_idx < len(populations) - 1:
ax.spines["bottom"].set_visible(False)
ax.tick_params(bottom=False, labelbottom=False)
return axs
[docs]
@beartowertype
def plot_muscle_dynamics(
results: Block,
joint_angle: np.ndarray,
time: np.ndarray,
axs: IterableType[Axes],
muscle_name: str = "hill",
include_signals: list[str] = ["artAng", "L", "force", "torque"],
include_activations: list[str] = ["TypeI", "TypeII"],
normalize: bool = True,
time_range: tuple[float, float] | None = None,
title: str = "Muscle Hill Model Dynamics",
xlabel: str = "Time [ms]",
apply_default_formatting: bool = True,
**kwargs: Any,
) -> IterableType[Axes]:
"""
Plot muscle dynamics from Hill model.
Parameters
----------
results : Block
NEO Block containing muscle model segment.
joint_angle : np.ndarray
Joint angle time series data.
time : np.ndarray
Time vector in milliseconds.
axs : IterableType[Axes]
Matplotlib axes to plot on (one per signal).
muscle_name : str, optional
Name of muscle segment to plot, by default "hill".
include_signals : list[str], optional
Signals to plot, by default ["artAng", "L", "force", "torque"].
include_activations : list[str], optional
Activation types to plot, by default ["TypeI", "TypeII"].
normalize : bool, optional
Whether to normalize signals, by default True.
time_range : tuple[float, float], optional
Time range to plot, by default None.
title : str, optional
Plot title, by default "Muscle Hill Model Dynamics".
xlabel : str, optional
X-axis label, by default "Time [ms]".
apply_default_formatting : bool, optional
Whether to apply default formatting, by default True.
**kwargs : Any
Additional keyword arguments passed to matplotlib plot functions.
Returns
-------
IterableType[Axes]
The axes that were plotted on.
"""
ax_list = list(axs)
# Find muscle segment by name
muscle_segment = None
for seg in results.segments:
if seg.name == muscle_name:
muscle_segment = seg
break
if muscle_segment is None:
warnings.warn(f"No muscle dynamics data found for '{muscle_name}' in results")
return axs
# Create signal data dictionary
signal_data = {}
# Add joint angle if requested
if "artAng" in include_signals:
signal_data["artAng"] = (joint_angle, "Joint Angle [deg]")
# Extract muscle signals from analog signals
for signal in muscle_segment.analogsignals:
signal_name = signal.name
signal_array = signal.magnitude.flatten()
if signal_name == "muscle_length" and "L" in include_signals:
signal_data["L"] = (signal_array, "Length [L0]")
elif signal_name == "muscle_force" and "force" in include_signals:
signal_data["force"] = (signal_array, "Force [F0]")
elif signal_name == "muscle_torque" and "torque" in include_signals:
signal_data["torque"] = (signal_array, "Torque [F0·cm]")
elif signal_name == "type1_activation" and "TypeI" in include_activations:
if "act" not in signal_data:
signal_data["act"] = {}
signal_data["act"]["TypeI"] = signal_array
elif signal_name == "type2_activation" and "TypeII" in include_activations:
if "act" not in signal_data:
signal_data["act"] = {}
signal_data["act"]["TypeII"] = signal_array
# Apply time range filter
plot_time = time
if time_range is not None:
time_mask = (time >= time_range[0]) & (time <= time_range[1])
plot_time = time[time_mask]
# Apply same mask to all signal data
for key, value in signal_data.items():
if key != "act" and isinstance(value, tuple):
if len(value) >= 2:
# Ensure signal array and time mask have compatible lengths
signal_array = value[0]
if len(signal_array) != len(time_mask):
# Trim signal array to match time array length
min_length = min(len(signal_array), len(time_mask))
signal_array = signal_array[:min_length]
time_mask_trimmed = time_mask[:min_length]
signal_data[key] = (signal_array[time_mask_trimmed], value[1])
else:
signal_data[key] = (signal_array[time_mask], value[1])
else:
print(
f"Warning: signal_data['{key}'] is a tuple but has length {len(value)}: {value}"
)
elif key == "act":
for act_type, act_data in value.items():
# Handle activation data length mismatch
if len(act_data) != len(time_mask):
min_length = min(len(act_data), len(time_mask))
act_data_trimmed = act_data[:min_length]
time_mask_trimmed = time_mask[:min_length]
signal_data[key][act_type] = act_data_trimmed[time_mask_trimmed]
else:
signal_data[key][act_type] = act_data[time_mask]
else:
print(
f"Warning: signal_data['{key}'] is not a tuple or 'act': type={type(value)}, value={value}"
)
# Plot signals
plot_idx = 0
# Regular signals
for signal_name in include_signals:
if signal_name in signal_data and plot_idx < len(ax_list):
ax = ax_list[plot_idx]
data, ylabel = signal_data[signal_name]
ax.plot(plot_time, data, **kwargs)
if apply_default_formatting:
ax.set_ylabel(ylabel)
if plot_idx == 0:
ax.set_title(title)
if plot_idx == len(ax_list) - 1:
ax.set_xlabel(xlabel)
if time_range is not None:
ax.set_xlim(time_range)
plot_idx += 1
# Activations (if any to plot)
if include_activations and "act" in signal_data and plot_idx < len(ax_list):
ax = ax_list[plot_idx]
for act_type in include_activations:
if act_type in signal_data["act"]:
ax.plot(plot_time, signal_data["act"][act_type], label=act_type, **kwargs)
if apply_default_formatting:
ax.set_ylabel("Activation [a.u.]")
ax.legend()
if plot_idx == 0:
ax.set_title(title)
if plot_idx == len(ax_list) - 1:
ax.set_xlabel(xlabel)
if time_range is not None:
ax.set_xlim(time_range)
return axs
[docs]
@beartowertype
def plot_antagonist_muscle_comparison(
results: Block,
joint_angle: np.ndarray,
time: np.ndarray,
axs: IterableType[Axes],
flexor_name: str = "hill_flexor",
extensor_name: str = "hill_extensor",
include_signals: list[str] = ["artAng", "force", "torque"],
time_range: tuple[float, float] | None = None,
title: str = "Antagonist Muscle Comparison",
xlabel: str = "Time [ms]",
apply_default_formatting: bool = True,
**kwargs: Any,
) -> IterableType[Axes]:
"""
Plot comparison of antagonist muscle dynamics.
Parameters
----------
results : Block
NEO Block containing both muscle model segments.
joint_angle : np.ndarray
Joint angle time series data.
time : np.ndarray
Time vector in milliseconds.
axs : IterableType[Axes]
Matplotlib axes to plot on (one per signal).
flexor_name : str, optional
Name of flexor muscle segment, by default "hill_flexor".
extensor_name : str, optional
Name of extensor muscle segment, by default "hill_extensor".
include_signals : list[str], optional
Signals to compare, by default ["artAng", "force", "torque"].
time_range : tuple[float, float], optional
Time range to plot, by default None.
title : str, optional
Plot title, by default "Antagonist Muscle Comparison".
xlabel : str, optional
X-axis label, by default "Time [ms]".
apply_default_formatting : bool, optional
Whether to apply default formatting, by default True.
**kwargs : Any
Additional keyword arguments passed to matplotlib plot functions.
Returns
-------
IterableType[Axes]
The axes that were plotted on.
"""
ax_list = list(axs)
# Find muscle segments
flexor_segment = None
extensor_segment = None
for seg in results.segments:
if seg.name == flexor_name:
flexor_segment = seg
elif seg.name == extensor_name:
extensor_segment = seg
if flexor_segment is None or extensor_segment is None:
warnings.warn("Could not find both flexor and extensor muscle segments")
return axs
# Extract data from both muscles
flexor_data = {}
extensor_data = {}
for signal in flexor_segment.analogsignals:
flexor_data[signal.name] = signal.magnitude.flatten()
for signal in extensor_segment.analogsignals:
extensor_data[signal.name] = signal.magnitude.flatten()
# Apply time range filter
plot_time = time
if time_range is not None:
time_mask = (time >= time_range[0]) & (time <= time_range[1])
plot_time = time[time_mask]
# Handle length mismatches for joint angle
if len(joint_angle) != len(time_mask):
min_length = min(len(joint_angle), len(time_mask))
joint_angle = joint_angle[:min_length]
time_mask_trimmed = time_mask[:min_length]
joint_angle = joint_angle[time_mask_trimmed]
else:
joint_angle = joint_angle[time_mask]
# Apply mask to muscle data with length checks
for key in flexor_data:
if len(flexor_data[key]) != len(time_mask):
min_length = min(len(flexor_data[key]), len(time_mask))
data_trimmed = flexor_data[key][:min_length]
time_mask_trimmed = time_mask[:min_length]
flexor_data[key] = data_trimmed[time_mask_trimmed]
else:
flexor_data[key] = flexor_data[key][time_mask]
for key in extensor_data:
if len(extensor_data[key]) != len(time_mask):
min_length = min(len(extensor_data[key]), len(time_mask))
data_trimmed = extensor_data[key][:min_length]
time_mask_trimmed = time_mask[:min_length]
extensor_data[key] = data_trimmed[time_mask_trimmed]
else:
extensor_data[key] = extensor_data[key][time_mask]
# Plot signals
plot_idx = 0
# Joint angle
if "artAng" in include_signals and plot_idx < len(ax_list):
ax = ax_list[plot_idx]
ax.plot(plot_time, joint_angle, label="Joint Angle", **kwargs)
if apply_default_formatting:
ax.set_ylabel("Joint Angle [deg]")
ax.legend()
if plot_idx == 0:
ax.set_title(title)
if plot_idx == len(ax_list) - 1:
ax.set_xlabel(xlabel)
if time_range is not None:
ax.set_xlim(time_range)
plot_idx += 1
# Force comparison
if "force" in include_signals and plot_idx < len(ax_list):
ax = ax_list[plot_idx]
if "muscle_force" in flexor_data:
ax.plot(plot_time, flexor_data["muscle_force"], label="Flexor", **kwargs)
if "muscle_force" in extensor_data:
ax.plot(plot_time, extensor_data["muscle_force"], label="Extensor", **kwargs)
if apply_default_formatting:
ax.set_ylabel("Force [F0]")
ax.legend()
if plot_idx == 0:
ax.set_title(title)
if plot_idx == len(ax_list) - 1:
ax.set_xlabel(xlabel)
if time_range is not None:
ax.set_xlim(time_range)
plot_idx += 1
# Torque comparison (including net torque)
if "torque" in include_signals and plot_idx < len(ax_list):
ax = ax_list[plot_idx]
if "muscle_torque" in flexor_data and "muscle_torque" in extensor_data:
flex_torque = flexor_data["muscle_torque"]
ext_torque = -extensor_data["muscle_torque"] # Negative for extensor
net_torque = flex_torque + ext_torque
ax.plot(plot_time, flex_torque, label="Flexor", **kwargs)
ax.plot(plot_time, ext_torque, label="Extensor", **kwargs)
ax.plot(plot_time, net_torque, label="Net", linewidth=2, **kwargs)
if apply_default_formatting:
ax.set_ylabel("Torque [F0·cm]")
ax.legend()
if plot_idx == 0:
ax.set_title(title)
if plot_idx == len(ax_list) - 1:
ax.set_xlabel(xlabel)
if time_range is not None:
ax.set_xlim(time_range)
return axs
[docs]
@beartowertype
def plot_spindle_dynamics(
results: Block,
axs: IterableType[Axes],
muscle_name: str = "hill_flexor",
include_signals: list[str] = ["L"],
include_activations: list[str] = ["Bag1", "Bag2", "Chain"],
include_tensions: list[str] = ["Bag1", "Bag2", "Chain"],
include_afferents: list[str] = ["Ia", "II"],
time_range: tuple[float, float] | None = None,
title: str = "Spindle Model Dynamics",
xlabel: str = "Time [ms]",
apply_default_formatting: bool = True,
**kwargs: Any,
) -> IterableType[Axes]:
"""
Plot spindle model dynamics.
Parameters
----------
results : Block
NEO Block containing spindle model segment.
axs : IterableType[Axes]
Matplotlib axes to plot on (one per signal group).
muscle_name : str, optional
Name of muscle segment to use for length and time data, by default "hill_flexor".
include_signals : list[str], optional
Basic signals to plot, by default ["L"].
include_activations : list[str], optional
Intrafusal activations to plot, by default ["Bag1", "Bag2", "Chain"].
include_tensions : list[str], optional
Intrafusal tensions to plot, by default ["Bag1", "Bag2", "Chain"].
include_afferents : list[str], optional
Afferent firing rates to plot, by default ["Ia", "II"].
time_range : tuple[float, float], optional
Time range to plot, by default None.
title : str, optional
Plot title, by default "Spindle Model Dynamics".
xlabel : str, optional
X-axis label, by default "Time [ms]".
apply_default_formatting : bool, optional
Whether to apply default formatting, by default True.
**kwargs : Any
Additional keyword arguments passed to matplotlib plot functions.
Returns
-------
IterableType[Axes]
The axes that were plotted on.
"""
ax_list = list(axs)
# Find spindle segment
spindle_segment = None
for seg in results.segments:
if seg.name == "spin":
spindle_segment = seg
break
if spindle_segment is None:
warnings.warn("No spindle dynamics data found in results")
return axs
# Extract spindle data
spindle_data = {}
# Get muscle length from specified muscle segment for "L" signal
if "L" in include_signals:
for seg in results.segments:
if seg.name == muscle_name:
for signal in seg.analogsignals:
if signal.name == "muscle_length":
spindle_data["L"] = (signal.magnitude.flatten(), "Length [L0]")
break
break
# Extract spindle-specific signals
for signal in spindle_segment.analogsignals:
signal_name = signal.name
signal_array = signal.magnitude
if signal_name == "bag1_activation":
if "act" not in spindle_data:
spindle_data["act"] = {}
spindle_data["act"]["Bag1"] = signal_array.flatten()
elif signal_name == "bag2_activation":
if "act" not in spindle_data:
spindle_data["act"] = {}
spindle_data["act"]["Bag2"] = signal_array.flatten()
elif signal_name == "chain_activation":
if "act" not in spindle_data:
spindle_data["act"] = {}
spindle_data["act"]["Chain"] = signal_array.flatten()
elif signal_name == "intrafusal_tensions":
# Handle 2D tensions array (3 x time_points) or (time_points x 3)
if signal_array.ndim == 2:
if signal_array.shape[0] == 3: # (3, time_points)
spindle_data["tension"] = {
"Bag1": signal_array[0, :],
"Bag2": signal_array[1, :],
"Chain": signal_array[2, :],
}
elif signal_array.shape[1] == 3: # (time_points, 3)
spindle_data["tension"] = {
"Bag1": signal_array[:, 0],
"Bag2": signal_array[:, 1],
"Chain": signal_array[:, 2],
}
elif signal_name == "primary_afferent_firing__Hz":
if "aff" not in spindle_data:
spindle_data["aff"] = {}
spindle_data["aff"]["Ia"] = signal_array.flatten()
elif signal_name == "secondary_afferent_firing__Hz":
if "aff" not in spindle_data:
spindle_data["aff"] = {}
spindle_data["aff"]["II"] = signal_array.flatten()
# Get time vector from specified muscle segment
time_vector = None
for seg in results.segments:
if seg.name == muscle_name:
for signal in seg.analogsignals:
time_vector = signal.times.rescale("ms").magnitude
break
break
if time_vector is None:
warnings.warn("Could not find time vector for spindle plotting")
return axs
# Apply time range filter
plot_time = time_vector
if time_range is not None:
time_mask = (time_vector >= time_range[0]) & (time_vector <= time_range[1])
plot_time = time_vector[time_mask]
# Apply same mask to all signal data
for key, value in spindle_data.items():
if key in ["L"] and isinstance(value, tuple):
spindle_data[key] = (value[0][time_mask], value[1])
elif key in ["act", "tension", "aff"]:
for subkey, subvalue in value.items():
spindle_data[key][subkey] = subvalue[time_mask]
# Plot signals
plot_idx = 0
# Basic signals (like length)
for signal_name in include_signals:
if signal_name in spindle_data and plot_idx < len(ax_list):
ax = ax_list[plot_idx]
data, ylabel = spindle_data[signal_name]
ax.plot(plot_time, data, **kwargs)
if apply_default_formatting:
ax.set_ylabel(ylabel)
if plot_idx == 0:
ax.set_title(title)
if plot_idx == len(ax_list) - 1:
ax.set_xlabel(xlabel)
if time_range is not None:
ax.set_xlim(time_range)
plot_idx += 1
# Activations
if include_activations and "act" in spindle_data and plot_idx < len(ax_list):
ax = ax_list[plot_idx]
for act_type in include_activations:
if act_type in spindle_data["act"]:
ax.plot(plot_time, spindle_data["act"][act_type], label=act_type, **kwargs)
if apply_default_formatting:
ax.set_ylabel("Activation [a.u.]")
ax.legend()
if plot_idx == 0:
ax.set_title(title)
if plot_idx == len(ax_list) - 1:
ax.set_xlabel(xlabel)
if time_range is not None:
ax.set_xlim(time_range)
plot_idx += 1
# Tensions
if include_tensions and "tension" in spindle_data and plot_idx < len(ax_list):
ax = ax_list[plot_idx]
for tension_type in include_tensions:
if tension_type in spindle_data["tension"]:
ax.plot(
plot_time,
spindle_data["tension"][tension_type],
label=tension_type,
**kwargs,
)
if apply_default_formatting:
ax.set_ylabel("Tension [F0]")
ax.legend()
if plot_idx == 0:
ax.set_title(title)
if plot_idx == len(ax_list) - 1:
ax.set_xlabel(xlabel)
if time_range is not None:
ax.set_xlim(time_range)
plot_idx += 1
# Afferents
if include_afferents and "aff" in spindle_data and plot_idx < len(ax_list):
ax = ax_list[plot_idx]
for aff_type in include_afferents:
if aff_type in spindle_data["aff"]:
ax.plot(plot_time, spindle_data["aff"][aff_type], label=aff_type, **kwargs)
if apply_default_formatting:
ax.set_ylabel("Firing Rate [Hz]")
ax.legend()
if plot_idx == 0:
ax.set_title(title)
if plot_idx == len(ax_list) - 1:
ax.set_xlabel(xlabel)
if time_range is not None:
ax.set_xlim(time_range)
return axs
[docs]
@beartowertype
def plot_gto_dynamics(
results: Block,
axs: IterableType[Axes],
muscle_name: str = "hill_flexor",
include_signals: list[str] = ["force", "Ib"],
time_range: tuple[float, float] | None = None,
title: str = "GTO Model Dynamics",
xlabel: str = "Time [ms]",
apply_default_formatting: bool = True,
**kwargs: Any,
) -> IterableType[Axes]:
"""
Plot Golgi tendon organ (GTO) dynamics.
Parameters
----------
results : Block
NEO Block containing GTO model segment.
axs : IterableType[Axes]
Matplotlib axes to plot on (one per signal).
muscle_name : str, optional
Name of muscle segment to use for force and time data, by default "hill_flexor".
include_signals : list[str], optional
Signals to plot, by default ["force", "Ib"].
time_range : tuple[float, float], optional
Time range to plot, by default None.
title : str, optional
Plot title, by default "GTO Model Dynamics".
xlabel : str, optional
X-axis label, by default "Time [ms]".
apply_default_formatting : bool, optional
Whether to apply default formatting, by default True.
**kwargs : Any
Additional keyword arguments passed to matplotlib plot functions.
Returns
-------
IterableType[Axes]
The axes that were plotted on.
"""
ax_list = list(axs)
# Find GTO segment
gto_segment = None
for seg in results.segments:
if seg.name == "gto":
gto_segment = seg
break
if gto_segment is None:
warnings.warn("No GTO dynamics data found in results")
return axs
# Extract GTO data
gto_data = {}
# Get muscle force from specified muscle segment for "force" signal
if "force" in include_signals:
for seg in results.segments:
if seg.name == muscle_name:
for signal in seg.analogsignals:
if signal.name == "muscle_force":
gto_data["force"] = (signal.magnitude.flatten(), "Force [F0]")
break
break
# Extract GTO-specific signals
for signal in gto_segment.analogsignals:
if signal.name == "ib_afferent_firing__Hz" and "Ib" in include_signals:
gto_data["Ib"] = (signal.magnitude.flatten(), "Firing Rate [Hz]")
# Get time vector from specified muscle segment
time_vector = None
for seg in results.segments:
if seg.name == muscle_name:
for signal in seg.analogsignals:
time_vector = signal.times.rescale("ms").magnitude
break
break
if time_vector is None:
warnings.warn("Could not find time vector for GTO plotting")
return axs
# Apply time range filter
plot_time = time_vector
if time_range is not None:
time_mask = (time_vector >= time_range[0]) & (time_vector <= time_range[1])
plot_time = time_vector[time_mask]
# Apply same mask to all signal data
for key, value in gto_data.items():
if isinstance(value, tuple):
gto_data[key] = (value[0][time_mask], value[1])
# Plot signals
plot_idx = 0
for signal_name in include_signals:
if signal_name in gto_data and plot_idx < len(ax_list):
ax = ax_list[plot_idx]
data, ylabel = gto_data[signal_name]
ax.plot(plot_time, data, **kwargs)
if apply_default_formatting:
ax.set_ylabel(ylabel)
if plot_idx == 0:
ax.set_title(title)
if plot_idx == len(ax_list) - 1:
ax.set_xlabel(xlabel)
if time_range is not None:
ax.set_xlim(time_range)
plot_idx += 1
return axs