Source code for myogen.utils.currents

import numpy as np
import quantities as pq
from neo.core import AnalogSignal

from myogen.utils.decorators import beartowertype
from myogen.utils.types import (
    CURRENT__AnalogSignal,
    Quantity__Hz,
    Quantity__ms,
    Quantity__nA,
    Quantity__rad,
)


def _broadcast_and_validate(
    param_name: str,
    value: pq.Quantity | list[pq.Quantity],
    n_pools: int,
) -> list[pq.Quantity]:
    """Convert a scalar or a sequence of quantities to a validated list.

    Handles four cases:
    1. Scalar value (anything that is not an ndarray, list or tuple):
       broadcast to a list of ``n_pools`` copies.
    2. Scalar Quantity (0-dim array): broadcast to a list of ``n_pools`` copies.
    3. List/tuple/array with length ``n_pools``: returned as a new list holding
       the same elements.
    4. List/tuple/array with any other length: raise ``ValueError``.

    Note: ``pq.Quantity`` is a subclass of ``np.ndarray``, so scalar Quantities
    (e.g., ``5.0 * pq.ms``) are 0-dimensional arrays that must be treated as
    scalars rather than as sequences.

    Parameters
    ----------
    param_name : str
        Name of the parameter (used only in error messages).
    value : pq.Quantity | list[pq.Quantity]
        The parameter value to broadcast/validate. Can be:
        - Scalar Quantity (e.g., 5.0 * pq.ms)
        - List (or tuple) of Quantities (e.g., [1*pq.ms, 2*pq.ms, 3*pq.ms])
        - Array of Quantities
    n_pools : int
        Expected length of the output list.

    Returns
    -------
    list[pq.Quantity]
        A new list of length ``n_pools``. It is always a ``list`` (never the
        caller's own list or an ndarray), so callers cannot alias the input.

    Raises
    ------
    ValueError
        If ``value`` is a list/tuple/array and its length doesn't match n_pools.
    """
    # A 0-dim ndarray (e.g. a scalar Quantity) counts as a scalar even though
    # it is an ndarray instance.
    is_scalar = not isinstance(value, (np.ndarray, list, tuple)) or (
        isinstance(value, np.ndarray) and value.ndim == 0
    )

    if is_scalar:
        # Broadcast the single value to every pool.
        return [value] * n_pools

    if len(value) != n_pools:
        raise ValueError(
            f"Length of {param_name} ({len(value)}) must match n_pools ({n_pools})"
        )

    # Convert arrays/tuples (and copy lists) so the result honours the
    # declared ``list`` return type.
    return list(value)


def _broadcast_and_validate_float(
    param_name: str,
    value: float | list[float],
    n_pools: int,
) -> list[float]:
    """Convert scalar or list of floats to validated list.

    Parameters
    ----------
    param_name : str
        Name of the parameter (for error messages)
    value : float | list[float]
        The parameter value to broadcast/validate
    n_pools : int
        Expected length of the output list

    Returns
    -------
    list[float]
        List of length n_pools with float values

    Raises
    ------
    ValueError
        If value is a list and its length doesn't match n_pools
    """
    if isinstance(value, (list, np.ndarray)):
        if len(value) != n_pools:
            raise ValueError(
                f"Length of {param_name} ({len(value)}) must match n_pools ({n_pools})"
            )
        return list(value)
    else:
        return [value] * n_pools


@beartowertype
def create_sinusoidal_current(
    n_pools: int,
    t_points: int,
    timestep__ms: Quantity__ms,
    amplitudes__nA: Quantity__nA | list[Quantity__nA],
    frequencies__Hz: Quantity__Hz | list[Quantity__Hz],
    offsets__nA: Quantity__nA | list[Quantity__nA],
    phases__rad: Quantity__rad | list[Quantity__rad] = 0.0 * pq.rad,
) -> CURRENT__AnalogSignal:
    """Create a matrix of sinusoidal currents for multiple pools.

    Pool ``i`` receives ``A_i * sin(2*pi*f_i*t + phi_i) + offset_i`` sampled at
    ``t = k * timestep`` for ``k = 0 .. t_points - 1``.

    Parameters
    ----------
    n_pools : int
        Number of current pools to generate
    t_points : int
        Number of time points
    timestep__ms : Quantity__ms
        Time step as a Quantity (converted to milliseconds internally)
    amplitudes__nA : Quantity__nA | list[Quantity__nA]
        Amplitude(s) of the sinusoidal current(s) in nanoamperes.
    frequencies__Hz : Quantity__Hz | list[Quantity__Hz]
        Frequency(s) of the sinusoidal current(s) in Hertz.
    offsets__nA : Quantity__nA | list[Quantity__nA]
        DC offset(s) to add to the sinusoidal current(s) in nanoamperes.
    phases__rad : Quantity__rad | list[Quantity__rad]
        Phase(s) of the sinusoidal current(s) in radians.

    Raises
    ------
    ValueError
        If the amplitudes, frequencies, offsets, or phases are lists and the
        length of the parameters does not match n_pools, or if the timestep is
        not positive.

    Notes
    -----
    If a parameter is provided as a single Quantity, it is broadcasted to all
    pools. If provided as a list, its length must match n_pools.

    Returns
    -------
    CURRENT__AnalogSignal
        Analog signal of shape (t_points, n_pools) * pq.nA containing
        sinusoidal currents, starting at t = 0 s.
    """
    dt_ms = float(timestep__ms.rescale(pq.ms).magnitude)
    if dt_ms <= 0:
        raise ValueError(f"timestep__ms must be positive, got {timestep__ms}")

    # Build the time axis from integer sample indices. np.arange with a float
    # step can return t_points +/- 1 samples because of rounding (e.g. 3 points
    # at 0.1 ms yields 4), which would break the (t_points, n_pools) shape.
    t_ms = np.arange(t_points) * dt_ms

    amplitudes_list = _broadcast_and_validate("amplitudes__nA", amplitudes__nA, n_pools)
    frequencies_list = _broadcast_and_validate("frequencies__Hz", frequencies__Hz, n_pools)
    offsets_list = _broadcast_and_validate("offsets__nA", offsets__nA, n_pools)
    phases_list = _broadcast_and_validate("phases__rad", phases__rad, n_pools)

    # NOTE(review): parameter values are read via .magnitude, i.e. they are
    # assumed to already be in nA / Hz / rad (presumably enforced by the
    # Quantity__* annotations checked by @beartowertype -- confirm).
    # The / 1000 converts t from ms to s so that f (Hz) * t is in cycles.
    columns = [
        amplitude.magnitude
        * np.sin(2 * np.pi * frequency.magnitude * t_ms / 1000 + phase.magnitude)
        + offset.magnitude
        for amplitude, frequency, offset, phase in zip(
            amplitudes_list, frequencies_list, offsets_list, phases_list
        )
    ]

    return AnalogSignal(
        signal=np.stack(columns, axis=-1) * pq.nA,
        t_start=0 * pq.s,
        sampling_period=timestep__ms.rescale(pq.s),
    )
@beartowertype
def create_sawtooth_current(
    n_pools: int,
    t_points: int,
    timestep__ms: Quantity__ms,
    amplitudes__nA: Quantity__nA | list[Quantity__nA],
    frequencies__Hz: Quantity__Hz | list[Quantity__Hz],
    offsets__nA: Quantity__nA | list[Quantity__nA] = 0.0 * pq.nA,
    widths__ratio: float | list[float] = 0.5,
    phases__rad: Quantity__rad | list[Quantity__rad] = 0.0 * pq.rad,
) -> CURRENT__AnalogSignal:
    """Create a matrix of sawtooth currents for multiple pools.

    Within each period the normalized waveform rises linearly from 0 to 1 over
    the first ``widths__ratio`` fraction of the period and then falls linearly
    back to 0. It is scaled by the amplitude and shifted by the offset, so
    pool ``i`` ranges over ``[offset_i, offset_i + amplitude_i]``. A width of
    0.5 gives a symmetric triangle wave.

    Parameters
    ----------
    n_pools : int
        Number of current pools to generate
    t_points : int
        Number of time points
    timestep__ms : Quantity__ms
        Time step as a Quantity (converted to milliseconds internally)
    amplitudes__nA : Quantity__nA | list[Quantity__nA]
        Amplitude(s) of the sawtooth current(s) in nanoamperes. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools
    frequencies__Hz : Quantity__Hz | list[Quantity__Hz]
        Frequency(s) of the sawtooth current(s) in Hertz. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools
    offsets__nA : Quantity__nA | list[Quantity__nA]
        DC offset(s) to add to the sawtooth current(s) in nanoamperes. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools
    widths__ratio : float | list[float]
        Width(s) of the rising edge as proportion of period (0 to 1). Must be:
        - Single float: used for all pools
        - List of floats: must match n_pools
    phases__rad : Quantity__rad | list[Quantity__rad]
        Phase(s) of the sawtooth current(s) in radians. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools

    Raises
    ------
    ValueError
        If the parameters are lists and the length of the parameters does not
        match n_pools, or if the timestep is not positive.

    Returns
    -------
    CURRENT__AnalogSignal
        Analog signal of shape (t_points, n_pools) * pq.nA containing sawtooth
        currents, starting at t = 0 s.
    """
    dt_ms = float(timestep__ms.rescale(pq.ms).magnitude)
    if dt_ms <= 0:
        raise ValueError(f"timestep__ms must be positive, got {timestep__ms}")

    # Integer-indexed time axis: np.arange with a float step can yield
    # t_points +/- 1 samples because of rounding.
    t_ms = np.arange(t_points) * dt_ms

    # Convert parameters to lists and validate
    amplitudes_list = _broadcast_and_validate("amplitudes__nA", amplitudes__nA, n_pools)
    frequencies_list = _broadcast_and_validate("frequencies__Hz", frequencies__Hz, n_pools)
    offsets_list = _broadcast_and_validate("offsets__nA", offsets__nA, n_pools)
    widths_list = _broadcast_and_validate_float("widths__ratio", widths__ratio, n_pools)
    phases_list = _broadcast_and_validate("phases__rad", phases__rad, n_pools)

    def sawtooth_for_pool(i: int) -> np.ndarray:
        # Fraction of the current period that has elapsed (computed once; the
        # / 1000 converts t from ms to s).
        cycle = (
            (
                2 * np.pi * frequencies_list[i].magnitude * t_ms / 1000
                + phases_list[i].magnitude
            )
            / (2 * np.pi)
        ) % 1
        width = widths_list[i]
        # np.where evaluates both branches. With width == 0 or 1, one branch
        # divides by zero but is normally not the branch selected, so silence
        # the spurious RuntimeWarnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            shape = np.where(
                cycle < width,
                cycle / width,
                (1 - cycle) / (1 - width),
            )
        return amplitudes_list[i].magnitude * shape + offsets_list[i].magnitude

    return AnalogSignal(
        signal=np.stack([sawtooth_for_pool(i) for i in range(n_pools)], axis=-1)
        * pq.nA,
        t_start=0 * pq.s,
        sampling_period=timestep__ms.rescale(pq.s),
    )
@beartowertype
def create_step_current(
    n_pools: int,
    t_points: int,
    timestep__ms: Quantity__ms,
    step_heights__nA: Quantity__nA | list[Quantity__nA],
    step_durations__ms: Quantity__ms | list[Quantity__ms],
    offsets__nA: Quantity__nA | list[Quantity__nA] = 0.0 * pq.nA,
) -> CURRENT__AnalogSignal:
    """Create a matrix of step currents for multiple pools.

    Pool ``i`` is ``offset_i + step_height_i`` for the first
    ``floor(step_duration_i / timestep)`` samples (clipped to ``t_points``)
    and ``offset_i`` afterwards.

    Parameters
    ----------
    n_pools : int
        Number of current pools to generate
    t_points : int
        Number of time points
    timestep__ms : Quantity__ms
        Time step as a Quantity (converted to milliseconds internally)
    step_heights__nA : Quantity__nA | list[Quantity__nA]
        Step height(s) for the current(s) in nanoamperes. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools
    step_durations__ms : Quantity__ms | list[Quantity__ms]
        Step duration(s) as Quantities (converted to milliseconds). Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools
    offsets__nA : Quantity__nA | list[Quantity__nA]
        DC offset(s) to add to the step current(s) in nanoamperes. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools

    Raises
    ------
    ValueError
        If the parameters are lists and the length of the parameters does not
        match n_pools, or if the timestep is not positive.

    Returns
    -------
    CURRENT__AnalogSignal
        Analog signal of shape (t_points, n_pools) * pq.nA containing step
        currents, starting at t = 0 s.
    """
    dt_ms = float(timestep__ms.rescale(pq.ms).magnitude)
    if dt_ms <= 0:
        raise ValueError(f"timestep__ms must be positive, got {timestep__ms}")

    # Convert parameters to lists and validate
    step_heights_list = _broadcast_and_validate("step_heights__nA", step_heights__nA, n_pools)
    step_durations_list = _broadcast_and_validate("step_durations__ms", step_durations__ms, n_pools)
    offsets_list = _broadcast_and_validate("offsets__nA", offsets__nA, n_pools)

    def create_step_for_pool(i: int) -> np.ndarray:
        current = np.zeros(t_points)
        # Express the duration in the same unit as the timestep before dividing.
        ratio = float(step_durations_list[i].rescale(pq.ms).magnitude) / dt_ms
        # Floor with a tiny tolerance: plain int() truncates float artefacts
        # such as 0.3 / 0.1 == 2.9999999999999996 down to 2 samples instead of 3.
        duration_points = int(np.floor(ratio + 1e-9))
        # The guard keeps a negative count from slicing from the array's end.
        if duration_points > 0:
            current[: min(duration_points, t_points)] = step_heights_list[i].magnitude
        return current + offsets_list[i].magnitude

    return AnalogSignal(
        signal=np.stack([create_step_for_pool(i) for i in range(n_pools)], axis=-1) * pq.nA,
        t_start=0 * pq.s,
        sampling_period=timestep__ms.rescale(pq.s),
    )
@beartowertype
def create_ramp_current(
    n_pools: int,
    t_points: int,
    timestep__ms: Quantity__ms,
    start_currents__nA: Quantity__nA | list[Quantity__nA],
    end_currents__nA: Quantity__nA | list[Quantity__nA],
    offsets__nA: Quantity__nA | list[Quantity__nA] = 0.0 * pq.nA,
) -> CURRENT__AnalogSignal:
    """Build one linear current ramp per pool.

    Each pool's current goes linearly from its start value to its end value
    over ``t_points`` samples (both end points included), plus a DC offset.

    Parameters
    ----------
    n_pools : int
        Number of current pools to generate
    t_points : int
        Number of time points
    timestep__ms : Quantity__ms
        Time step in milliseconds as a Quantity
    start_currents__nA : Quantity__nA | list[Quantity__nA]
        Value(s) at the first sample, in nanoamperes. Either a single Quantity
        shared by all pools or a list with one entry per pool.
    end_currents__nA : Quantity__nA | list[Quantity__nA]
        Value(s) at the last sample, in nanoamperes. Either a single Quantity
        shared by all pools or a list with one entry per pool.
    offsets__nA : Quantity__nA | list[Quantity__nA]
        DC offset(s) added on top of the ramp, in nanoamperes. Either a single
        Quantity shared by all pools or a list with one entry per pool.

    Raises
    ------
    ValueError
        If a list-valued parameter does not contain exactly n_pools entries.

    Returns
    -------
    CURRENT__AnalogSignal
        Analog signal of shape (t_points, n_pools) * pq.nA containing the
        ramps, starting at t = 0 s.
    """
    starts = _broadcast_and_validate("start_currents__nA", start_currents__nA, n_pools)
    stops = _broadcast_and_validate("end_currents__nA", end_currents__nA, n_pools)
    dc_levels = _broadcast_and_validate("offsets__nA", offsets__nA, n_pools)

    # One linearly spaced column per pool, shifted by that pool's offset.
    columns = [
        np.linspace(start.magnitude, stop.magnitude, t_points) + dc.magnitude
        for start, stop, dc in zip(starts, stops, dc_levels)
    ]
    signal_nA = np.stack(columns, axis=-1) * pq.nA

    return AnalogSignal(
        signal=signal_nA,
        t_start=0 * pq.s,
        sampling_period=timestep__ms.rescale(pq.s),
    )
@beartowertype
def create_trapezoid_current(
    n_pools: int,
    t_points: int,
    timestep__ms: Quantity__ms,
    amplitudes__nA: Quantity__nA | list[Quantity__nA],
    rise_times__ms: Quantity__ms | list[Quantity__ms] = 100.0 * pq.ms,
    plateau_times__ms: Quantity__ms | list[Quantity__ms] = 200.0 * pq.ms,
    fall_times__ms: Quantity__ms | list[Quantity__ms] = 100.0 * pq.ms,
    offsets__nA: Quantity__nA | list[Quantity__nA] = 0.0 * pq.nA,
    delays__ms: Quantity__ms | list[Quantity__ms] = 0.0 * pq.ms,
) -> CURRENT__AnalogSignal:
    """Create a matrix of trapezoidal currents for multiple pools.

    Each pool's waveform is:
    - ``offset`` during the delay;
    - a linear rise from ``offset`` to ``offset + amplitude``;
    - a plateau at ``offset + amplitude``;
    - a linear fall back to ``offset``;
    - ``offset`` until the end.

    Any phase that extends past ``t_points`` is truncated, not compressed, so
    the rise/fall slope does not depend on the signal length. A negative delay
    starts the trapezoid before t = 0; samples before t = 0 are dropped.
    Negative durations are treated as empty phases.

    Parameters
    ----------
    n_pools : int
        Number of current pools to generate
    t_points : int
        Number of time points
    timestep__ms : Quantity__ms
        Time step as a Quantity (converted to milliseconds internally)
    amplitudes__nA : Quantity__nA | list[Quantity__nA]
        Amplitude(s) of the trapezoidal current(s) in nanoamperes. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools
    rise_times__ms : Quantity__ms | list[Quantity__ms]
        Duration(s) of the rising phase. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools
    plateau_times__ms : Quantity__ms | list[Quantity__ms]
        Duration(s) of the plateau phase. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools
    fall_times__ms : Quantity__ms | list[Quantity__ms]
        Duration(s) of the falling phase. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools
    offsets__nA : Quantity__nA | list[Quantity__nA]
        DC offset(s) to add to the trapezoidal current(s) in nanoamperes.
        Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools
    delays__ms : Quantity__ms | list[Quantity__ms]
        Delay(s) before starting the trapezoid. Must be:
        - Single Quantity: used for all pools
        - List of Quantities: must match n_pools

    Raises
    ------
    ValueError
        If the parameters are lists and the length of the parameters does not
        match n_pools, or if the timestep is not positive.

    Returns
    -------
    CURRENT__AnalogSignal
        Analog signal of shape (t_points, n_pools) * pq.nA containing
        trapezoidal currents, starting at t = 0 s.
    """
    dt_ms = float(timestep__ms.rescale(pq.ms).magnitude)
    if dt_ms <= 0:
        raise ValueError(f"timestep__ms must be positive, got {timestep__ms}")

    # Convert parameters to lists and validate
    amplitudes_list = _broadcast_and_validate("amplitudes__nA", amplitudes__nA, n_pools)
    rise_times_list = _broadcast_and_validate("rise_times__ms", rise_times__ms, n_pools)
    plateau_times_list = _broadcast_and_validate("plateau_times__ms", plateau_times__ms, n_pools)
    fall_times_list = _broadcast_and_validate("fall_times__ms", fall_times__ms, n_pools)
    offsets_list = _broadcast_and_validate("offsets__nA", offsets__nA, n_pools)
    delays_list = _broadcast_and_validate("delays__ms", delays__ms, n_pools)

    def to_points(duration) -> int:
        """Convert a time Quantity to a whole number of samples (floored)."""
        # The small tolerance guards against float artefacts such as
        # 0.3 / 0.1 == 2.9999999999999996, which int() would truncate to 2.
        ratio = float(duration.rescale(pq.ms).magnitude) / dt_ms
        return int(np.floor(ratio + 1e-9))

    def write_clipped(buffer: np.ndarray, start: int, values: np.ndarray) -> None:
        """Copy values into buffer at index start, dropping out-of-range samples."""
        lo = max(start, 0)
        hi = min(start + len(values), len(buffer))
        if hi > lo:
            buffer[lo:hi] = values[lo - start : hi - start]

    def create_trapezoid_for_pool(i: int) -> np.ndarray:
        rise_points = max(to_points(rise_times_list[i]), 0)
        plateau_points = max(to_points(plateau_times_list[i]), 0)
        fall_points = max(to_points(fall_times_list[i]), 0)

        # Start index of each phase.
        rise_start = to_points(delays_list[i])
        plateau_start = rise_start + rise_points
        fall_start = plateau_start + plateau_points

        # Build each phase at its full length and only then clip it, so a
        # truncated ramp keeps its slope instead of being squeezed to 0 -> 1.
        trapezoid = np.zeros(t_points)
        write_clipped(trapezoid, rise_start, np.linspace(0.0, 1.0, rise_points))
        write_clipped(trapezoid, plateau_start, np.ones(plateau_points))
        write_clipped(trapezoid, fall_start, np.linspace(1.0, 0.0, fall_points))

        return amplitudes_list[i].magnitude * trapezoid + offsets_list[i].magnitude

    return AnalogSignal(
        signal=np.stack([create_trapezoid_for_pool(i) for i in range(n_pools)], axis=-1)
        * pq.nA,
        t_start=0 * pq.s,
        sampling_period=timestep__ms.rescale(pq.s),
    )