Source code for myogen.simulator.neuron.simulation_runner

from itertools import count
from typing import Any, Callable, Optional, Union

import numpy as np
import quantities as pq
from neo import AnalogSignal, Block, Segment, SpikeTrain
from tqdm import tqdm

from myogen.simulator.neuron.network import Network

from neuron import h

from myogen.utils.decorators import beartowertype
from myogen.utils.types import Quantity__m_per_s, Quantity__ms


@beartowertype
class SimulationRunner:
    """
    Manages NEURON simulation execution with automated setup, initialization,
    and result collection for neuromuscular simulations.

    Provides a clean interface for running complex neuromuscular simulations
    while maintaining full user control over populations, connections, and
    step-by-step simulation logic. Automatically handles NEURON environment
    setup, voltage initialization, and structured result collection.
    Separates simulation control from plotting and analysis concerns.
    """

    # Smart defaults for common MyoGen model output attributes, keyed by the
    # model's class name. This table is shared by every instance, so it is only
    # ever read; _resolve_model_outputs hands out per-instance copies.
    _DEFAULT_MODEL_OUTPUTS = {
        "HillModel": [
            "muscle_length",
            "muscle_force",
            "muscle_torque",
            "type1_activation",
            "type2_activation",
        ],
        "SpindleModel": [
            "primary_afferent_firing__Hz",
            "secondary_afferent_firing__Hz",
            "bag1_activation",
            "bag2_activation",
            "chain_activation",
            "intrafusal_tensions",
        ],
        "GolgiTendonOrganModel": ["ib_afferent_firing__Hz"],
    }

    def __init__(
        self,
        network: Network,
        models: dict[str, Any],
        step_callback: Callable[[Any], Any],
        model_outputs: Optional[dict[str, Union[list[str], None]]] = None,
        temperature__celsius: float = 36.0,
    ):
        """
        Initialize SimulationRunner with network, models, and step callback.

        Parameters
        ----------
        network : Network
            Configured Network instance with populations and connections.
        models : Dict[str, Any]
            Physiological models (e.g., {"hill": hill_model, "spin": spindle_model}).
        step_callback : Callable
            User-defined function called at each simulation timestep. It
            receives an ``itertools.count`` step counter (restarted at 0 for
            every ``run``).
        model_outputs : Optional[Dict[str, Union[List[str], None]]], optional
            Explicit model output attributes to collect. None uses smart defaults.
            Format: {"model_name": ["attr1", "attr2"]} or {"model_name": None}
            for defaults, by default None. Entries whose key is not in
            ``models`` are ignored.
        temperature__celsius : float, optional
            NEURON simulation temperature, by default 36.0.
        """
        # Store immutable parameters following project pattern
        self.network = network
        self.populations = network.populations  # Expose populations from network
        self.models = models
        self.step_callback = step_callback
        self.model_outputs = model_outputs
        self.temperature__celsius = temperature__celsius

        # Private working copies
        self._network = network
        self._populations = network.populations  # Get populations from network
        self._models = models
        self._step_callback = step_callback
        self._model_outputs = self._resolve_model_outputs()
        self._temperature__celsius = temperature__celsius

        # Runtime state
        self._trace_vectors: dict[str, dict[int, Any]] = {}
        self._step_counter = None
        self._progress_bar = None
        self._total_steps = None
        self._timestep__ms = None
        # NEURON callback objects installed by run(); they are kept referenced
        # while the simulation runs and released/unregistered afterwards.
        self._step_wrapper = None
        self._finitialize_handler = None

        # Setup internal spike recording vectors
        self._spike_recording = self._setup_spike_recording()

    def _resolve_model_outputs(self) -> dict[str, list[str]]:
        """
        Resolve model output attributes using smart defaults and user overrides.

        A model with no override, or with an override of ``None``, gets the
        defaults registered for its class name (or ``[]`` if none exist).
        Every returned list is a fresh copy, so later edits cannot leak into
        the caller's lists or into ``_DEFAULT_MODEL_OUTPUTS``.

        Returns
        -------
        Dict[str, List[str]]
            Final mapping of model names to output attribute lists.
        """
        overrides = self.model_outputs or {}
        resolved: dict[str, list[str]] = {}
        for model_name, model_instance in self._models.items():
            user_specified = overrides.get(model_name)
            if user_specified is None:
                defaults = self._DEFAULT_MODEL_OUTPUTS.get(
                    model_instance.__class__.__name__, []
                )
                resolved[model_name] = list(defaults)
            else:
                resolved[model_name] = list(user_specified)
        return resolved

    def _setup_spike_recording(self) -> dict[str, Any]:
        """
        Create NEURON spike recording vectors for all populations.

        Returns
        -------
        dict[str, Any]
            Dictionary containing 'idvec' and 'spkvec', each mapping a
            population name to its own NEURON Vector.
        """
        return {
            "idvec": {pop_name: h.Vector() for pop_name in self._populations},
            "spkvec": {pop_name: h.Vector() for pop_name in self._populations},
        }

    def _setup_network_spike_recording(self) -> None:
        """
        Configure the network with spike recording vectors and activate recording.
        """
        # Set spike recording on network
        self._network.spike_recording = self._spike_recording
        # Setup spike recording (calls NEURON setup)
        self._network.setup_spike_recording()

    def run(
        self,
        duration__ms: Quantity__ms,
        timestep__ms: Quantity__ms,
        membrane_recording: Optional[dict[str, list[int]]] = None,
    ) -> Block:
        """
        Execute NEURON simulation with automated setup and result collection.

        Parameters
        ----------
        duration__ms : Quantity__ms
            Total simulation duration in milliseconds.
        timestep__ms : Quantity__ms
            Integration timestep in milliseconds.
        membrane_recording : Optional[Dict[str, List[int]]], optional
            Populations and cell indices for membrane potential recording.
            Format: {"population_name": [cell_id1, cell_id2, ...]}, by default
            None. Traces recorded by a previous ``run`` are discarded.

        Returns
        -------
        Block
            Structured simulation results containing:
            - one Segment per population with spike trains and any requested
              membrane traces
            - one Segment per model with its output attributes
            - run metadata in ``block.annotations``

        Raises
        ------
        RuntimeError
            If anything fails during validation, setup, integration, or result
            collection. The original exception (for example a ``ValueError``
            for a missing model output attribute, an unknown population, or an
            out-of-range cell index) is chained as ``__cause__``.
        """
        try:
            # Validate first so a bad configuration fails before any NEURON
            # state, progress bar, or callback has been set up.
            self._validate_model_outputs()

            # Setup NEURON environment
            self._setup_neuron_environment(duration__ms, timestep__ms)

            # Drop traces from any earlier run so only cells requested now are
            # recorded and returned.
            self._trace_vectors = {}
            if membrane_recording:
                self._setup_membrane_recording(membrane_recording)

            # Initialize population voltages
            self._initialize_voltages()

            # Register step callback for closed-loop dynamics
            self._register_step_callback()

            # Setup spike recording on network
            self._setup_network_spike_recording()

            h.run()

            self._close_progress_bar()
            print("Simulation completed")

            # Collect and structure results
            return self._collect_results(duration__ms, timestep__ms)

        except Exception as e:
            self._close_progress_bar()
            raise RuntimeError(f"Simulation failed: {str(e)}") from e

        finally:
            # Without this, every run() would leave its step wrapper attached
            # to NEURON and later runs would invoke stale callbacks.
            self._unregister_neuron_callbacks()

    def _close_progress_bar(self) -> None:
        """Close the progress bar if one is open; closing errors are ignored."""
        if self._progress_bar is None:
            return
        try:
            self._progress_bar.close()
        except (TypeError, AttributeError):
            # Best-effort: a broken progress bar must not hide results or the
            # real simulation error.
            pass
        self._progress_bar = None

    def _unregister_neuron_callbacks(self) -> None:
        """Detach the step wrapper and release the init handler set up by run()."""
        if self._step_wrapper is not None:
            try:
                h.CVode().extra_scatter_gather_remove(self._step_wrapper)
            except (AttributeError, RuntimeError):
                # Best-effort: if NEURON cannot remove it, it simply stays
                # registered (the previous behaviour).
                pass
            self._step_wrapper = None
        self._finitialize_handler = None

    @staticmethod
    def _magnitude_ms(value: Any) -> float:
        """Return ``value`` as a plain float in milliseconds."""
        if isinstance(value, pq.Quantity):
            return float(value.rescale(pq.ms).magnitude)
        return float(value)

    def _setup_neuron_environment(
        self, duration__ms: Quantity__ms, timestep__ms: Quantity__ms
    ) -> None:
        """Configure NEURON global simulation parameters and the progress bar."""
        # NEURON globals are plain floats; convert explicitly (rescaling to ms)
        # instead of relying on implicit Quantity -> float coercion.
        duration = self._magnitude_ms(duration__ms)
        timestep = self._magnitude_ms(timestep__ms)

        h.load_file("stdrun.hoc")
        h.celsius = self._temperature__celsius
        h.tstop = duration
        h.dt = timestep
        h.secondorder = 2  # Use Crank-Nicolson for better accuracy and speed

        # round() guards against float error such as 0.3 / 0.1 == 2.999...
        self._total_steps = int(round(duration / timestep))
        # Store timestep for progress bar updates
        self._timestep__ms = timestep__ms

        self._progress_bar = tqdm(
            total=duration,
            desc="Simulation Progress",
            unit="ms",
        )

        # Reset step counter for step callback
        self._step_counter = count(0)

    def _setup_membrane_recording(self, membrane_recording: dict[str, list[int]]) -> None:
        """
        Create Vectors recording ``soma(0.5)`` membrane potential of the requested cells.

        Raises
        ------
        ValueError
            If a population name is unknown or a cell index is >= the
            population size.
        """
        self._trace_vectors = {}
        for pop_name, cell_indices in membrane_recording.items():
            if pop_name not in self._populations:
                raise ValueError(f"Population '{pop_name}' not found in populations")

            pop_traces = {}
            population = self._populations[pop_name]
            for cell_idx in cell_indices:
                if cell_idx >= len(population):
                    raise ValueError(
                        f"Cell index {cell_idx} out of range for population "
                        f"'{pop_name}' (size: {len(population)})"
                    )
                vector = h.Vector()
                vector.record(population[cell_idx].soma(0.5)._ref_v)
                pop_traces[cell_idx] = vector

            self._trace_vectors[pop_name] = pop_traces

    def _initialize_voltages(self) -> None:
        """
        Collect each population's initialization data and apply it at finitialize.

        Populations whose ``get_initialization_data`` is missing, or whose
        returned data cannot be unpacked or iterated, are skipped.
        """
        pairs: list[tuple[Any, Any]] = []
        for population in self._populations.values():
            try:
                sec_list, v_hold = population.get_initialization_data()
                pop_sections = list(sec_list)
                pop_voltages = list(v_hold)
            except (AttributeError, TypeError):
                # Skip populations without initialization data
                continue
            # Pair per population so a skipped or short population can never
            # shift voltages onto another population's sections.
            pairs.extend(zip(pop_sections, pop_voltages))

        if not pairs:
            self._finitialize_handler = None
            return

        def set_initial_voltages():
            for sec, voltage in pairs:
                sec.v = voltage

        # NEURON discards an FInitializeHandler once nothing references it, so
        # keep it on the instance for the duration of run().
        self._finitialize_handler = h.FInitializeHandler(0, set_initial_voltages)

    def _register_step_callback(self) -> None:
        """
        Register the user's step callback with NEURON's integration system.

        The wrapper advances the progress bar by the simulated time elapsed
        since its last update, then calls the user's callback with the step
        counter. Calls at or after ``h.tstop`` are ignored.
        """
        last_progress_time = -1.0

        def step_wrapper():
            nonlocal last_progress_time

            # Don't process calls made once the simulation time limit is reached
            if h.t >= h.tstop:
                return None

            # Update the bar only when simulation time has actually advanced
            if self._progress_bar is not None and h.t > last_progress_time:
                time_advance = h.t - max(0, last_progress_time)
                try:
                    self._progress_bar.update(time_advance)
                    last_progress_time = h.t
                except (TypeError, AttributeError) as e:
                    print(f"Progress bar error (disabling): {e}")
                    self._progress_bar = None

            # Always forward the step: a progress-bar failure must not make the
            # closed-loop models miss an update.
            return self._step_callback(self._step_counter)

        self._step_wrapper = step_wrapper
        h.CVode().extra_scatter_gather(0, step_wrapper)

    def _validate_model_outputs(self) -> None:
        """
        Validate that all specified model output attributes exist.

        Raises
        ------
        ValueError
            If a model name is unknown or a model lacks a listed attribute.
        """
        for model_name, output_attrs in self._model_outputs.items():
            if model_name not in self._models:
                raise ValueError(f"Model '{model_name}' not found in models")

            model_instance = self._models[model_name]
            for attr_name in output_attrs:
                if not hasattr(model_instance, attr_name):
                    raise ValueError(
                        f"Model '{model_name}' ({model_instance.__class__.__name__}) "
                        f"does not have attribute '{attr_name}'"
                    )

    def _collect_results(self, duration__ms: Quantity__ms, timestep__ms: Quantity__ms) -> Block:
        """
        Collect simulation results from network, models, and recordings.

        Returns structured results compatible with existing analysis code:
        one Segment per population (spike trains per cell id, plus membrane
        traces), one Segment per model (iterable outputs as AnalogSignals,
        scalar/string outputs as annotations), and run metadata in
        ``block.annotations``.
        """
        block = Block()
        sampling_period = timestep__ms.rescale(pq.s)
        sampling_rate = (1.0 / sampling_period).rescale(pq.Hz)
        t_stop = duration__ms.rescale(pq.s)

        spkvecs = self._spike_recording.get("spkvec", {}) if self._spike_recording else {}
        # Spike ids of the last population (in iteration order) with a spike
        # recording; reported as "active_MNs". Starting empty avoids a
        # NameError when no population is recorded.
        # NOTE(review): presumably the motor-neuron population is expected to
        # be last — confirm against callers.
        active_ids = np.array([])

        for pop_name in self._populations.keys():
            segment = Segment(name=pop_name)

            if pop_name in spkvecs:
                spike_times = spkvecs[pop_name].as_numpy()
                spike_ids = self._spike_recording["idvec"][pop_name].as_numpy()
                active_ids = spike_ids
                # np.unique returns sorted ids, each having at least one spike
                for spike_id in np.unique(spike_ids):
                    times_for_id = spike_times[spike_ids == spike_id]
                    segment.spiketrains.append(
                        SpikeTrain(
                            name=str(int(spike_id)),
                            times=(times_for_id * pq.ms).rescale(pq.s),
                            t_start=0.0 * pq.s,
                            t_stop=t_stop,
                            sampling_rate=sampling_rate,
                        )
                    )

            for cell_idx, vector in self._trace_vectors.get(pop_name, {}).items():
                segment.analogsignals.append(
                    AnalogSignal(
                        name=str(cell_idx),
                        sampling_period=sampling_period,
                        # Copy out of the NEURON Vector rather than aliasing
                        # memory NEURON owns and may overwrite on a later run.
                        signal=np.array(vector.as_numpy(), dtype=float) * pq.mV,
                    )
                )

            block.segments.append(segment)

        for model_name, output_attrs in self._model_outputs.items():
            segment = Segment(name=model_name)
            model_instance = self._models[model_name]
            for attr_name in output_attrs:
                attr_value = getattr(model_instance, attr_name)
                # Check scalars/strings first: str is iterable, so testing
                # __iter__ first would try to build an AnalogSignal from text.
                if isinstance(attr_value, (int, float, str)):
                    segment.annotations[attr_name] = attr_value
                elif hasattr(attr_value, "__iter__"):
                    segment.analogsignals.append(
                        AnalogSignal(
                            name=attr_name,
                            # NOTE(review): assumes each model stores one sample
                            # per simulation timestep — confirm per model.
                            sampling_period=sampling_period,
                            signal=attr_value * pq.dimensionless,
                        )
                    )
            block.segments.append(segment)

        block.annotations["time__ms"] = duration__ms
        block.annotations["timestep__ms"] = timestep__ms
        block.annotations["temperature__celsius"] = self._temperature__celsius
        block.annotations["active_MNs"] = np.unique(active_ids).astype(int)

        return block

    def get_model_outputs(self, model_name: str) -> list[str]:
        """
        Get the list of output attributes that will be collected for a model.

        Parameters
        ----------
        model_name : str
            Name of the model as specified in the models dictionary.

        Returns
        -------
        List[str]
            A copy of the attribute names that will be collected from this
            model (empty if the model is unknown). Use ``set_model_outputs``
            to change the list.
        """
        return list(self._model_outputs.get(model_name, []))

    def set_model_outputs(self, model_name: str, output_attrs: list[str]) -> None:
        """
        Override the output attributes for a specific model.

        Parameters
        ----------
        model_name : str
            Name of the model as specified in the models dictionary.
        output_attrs : List[str]
            List of attribute names to collect from this model. A copy is
            stored, so later changes to the caller's list have no effect.

        Raises
        ------
        ValueError
            If model_name is not found in the models dictionary.
        """
        if model_name not in self._models:
            raise ValueError(f"Model '{model_name}' not found in models")

        self._model_outputs[model_name] = list(output_attrs)