Source code for myogestic.utils.config

from typing import Any, Callable, Literal, TypedDict

from myoverse.transforms import Transform

from myogestic.gui.widgets.templates.output_system import OutputSystemTemplate
from myogestic.gui.widgets.templates.visual_interface import (
    RecordingInterfaceTemplate,
    SetupInterfaceTemplate,
)


def _custom_message_handler(mode, context, message):
    """
    Filter console messages, dropping one known-noisy QLayout warning.

    Any message containing ``"QLayout: Attempting to add QLayout"`` is
    discarded silently; every other message is printed to the console as
    ``"<mode>: <message>"``.

    NOTE(review): the ``(mode, context, message)`` signature looks like the
    callback shape used by Qt's message-handler hook -- confirm where this
    function is installed.

    Parameters
    ----------
    mode
        Label of the message; only used as the prefix of the printed line.
    context
        Not used by this handler.
    message : str
        The message text to inspect and, unless suppressed, print.
    """
    is_layout_noise = "QLayout: Attempting to add QLayout" in message
    if not is_layout_noise:
        # Everything except the suppressed warning still reaches the console.
        print(f"{mode}: {message}")


class IntParameter(TypedDict):
    """
    Specification of an integer-valued parameter.

    One of the members of the ``ChangeableParameter`` union. All four keys
    hold ``int`` values.

    NOTE(review): the key names suggest a value range (``start_value`` to
    ``end_value``), an increment (``step``) and an initial value
    (``default_value``); how the bounds are interpreted is decided by the
    code that consumes this dictionary, which is not part of this module.
    """

    start_value: int
    end_value: int
    step: int
    default_value: int


class FloatParameter(TypedDict):
    """
    Specification of a float-valued parameter.

    One of the members of the ``ChangeableParameter`` union. It has the same
    keys as ``IntParameter`` but every value is a ``float``.

    NOTE(review): the key names suggest a value range (``start_value`` to
    ``end_value``), an increment (``step``) and an initial value
    (``default_value``); how the bounds are interpreted is decided by the
    code that consumes this dictionary, which is not part of this module.
    """

    start_value: float
    end_value: float
    step: float
    default_value: float


class StringParameter(TypedDict):
    """
    Specification of a string-valued parameter.

    One of the members of the ``ChangeableParameter`` union. Its only key,
    ``default_value``, holds a ``str``.
    """

    default_value: str


class BoolParameter(TypedDict):
    """
    Specification of a boolean parameter.

    One of the members of the ``ChangeableParameter`` union. Its only key,
    ``default_value``, holds a ``bool``.
    """

    default_value: bool


class CategoricalParameter(TypedDict):
    """
    Specification of a parameter chosen from a list of strings.

    One of the members of the ``ChangeableParameter`` union. ``values`` is a
    list of strings and ``default_value`` is a single string.

    NOTE(review): ``default_value`` is presumably expected to be one of
    ``values``; nothing in this module enforces that -- confirm against the
    consumer.
    """

    values: list[str]
    default_value: str


# Union of the parameter-specification TypedDicts defined above. It is the value
# type of the ``changeable_parameters`` argument of ``Registry.register_model``.
ChangeableParameter = (
    IntParameter | FloatParameter | StringParameter | BoolParameter | CategoricalParameter
)
# Plain value (or None). It is the value type of the ``unchangeable_parameters``
# argument of ``Registry.register_model``.
UnchangeableParameter = int | float | str | bool | list[str] | None


class Registry:
    """
    The registry class is used to store different components of a MyoGestic application pipeline.

    Every ``register_*`` method adds one entry to the matching map(s) and raises
    ``ValueError`` when the given name is already present. All maps start out empty.

    Attributes
    ----------
    models_map : dict[str, tuple[Any, bool]]
        A dictionary that maps model names to tuples of model classes and whether
        the model is a classifier. The tuple is in the form (model_class, is_classifier).
    models_functions_map : dict[str, dict[Literal["save", "load", "train", "predict"], Callable]]
        A dictionary that maps model names to dictionaries of model functions.
        The functions are `save`, `load`, `train`, and `predict`.
    models_parameters_map : dict[str, dict[Literal["changeable", "unchangeable"], dict]]
        A dictionary that maps model names to dictionaries of model parameters.
        The `changeable` entry is a ``dict[str, ChangeableParameter]`` and the
        `unchangeable` entry is a ``dict[str, UnchangeableParameter]``; both are
        empty dictionaries when no parameters were given at registration.
        See the `ChangeableParameter` and `UnchangeableParameter` types for more information.
    models_metadata_map : dict[str, dict]
        A dictionary that maps model names to a metadata dictionary with the keys
        ``"requires_temporal_preservation"`` and ``"feature_window_size"``.
    features_map : dict[str, type[Transform]]
        A dictionary that maps feature names to feature classes.
        The feature class must be subclasses of `Transform` (TensorTransform).
    features_metadata_map : dict[str, dict]
        A dictionary that maps feature names to a metadata dictionary with the key
        ``"requires_temporal_preservation"``.
    real_time_filters_map : dict[str, Callable]
        A dictionary that maps filter names to filter functions.
        A filter function is a callable that takes a single argument, which is the data to filter.
        The data will be a list of floats that represent the regression output of a model.
    visual_interfaces_map : dict[str, tuple[type[SetupInterfaceTemplate], type[RecordingInterfaceTemplate]]]
        A dictionary that maps visual interface names to tuples of setup and recording interface classes.
        The setup interface class must be a subclass of `SetupInterfaceTemplate`,
        while the recording interface class must be a subclass of `RecordingInterfaceTemplate`.
    output_systems_map : dict[str, type[OutputSystemTemplate]]
        A dictionary that maps output system names to output system classes.
        The output system class must be a subclass of `OutputSystemTemplate`.
    """

    def __init__(self):
        """Create a registry in which every map is empty."""
        self.models_map: dict[str, tuple[Any, bool]] = {}
        self.models_functions_map: dict[
            str, dict[Literal["save", "load", "train", "predict"], Callable]
        ] = {}
        # Each entry holds two *dictionaries* of parameters (name -> spec/value),
        # which is what ``register_model`` stores.
        self.models_parameters_map: dict[
            str,
            dict[
                Literal["changeable", "unchangeable"],
                dict[str, ChangeableParameter] | dict[str, UnchangeableParameter],
            ],
        ] = {}
        self.models_metadata_map: dict[str, dict] = {}

        self.features_map: dict[str, type[Transform]] = {}
        self.features_metadata_map: dict[str, dict] = {}

        self.real_time_filters_map: dict[str, Callable] = {}
        self.visual_interfaces_map: dict[
            str, tuple[type[SetupInterfaceTemplate], type[RecordingInterfaceTemplate]]
        ] = {}
        self.output_systems_map: dict[str, type[OutputSystemTemplate]] = {}

    def register_model(
        self,
        name: str,
        model_class: type,
        is_classifier: bool,
        save_function: Callable,
        load_function: Callable,
        train_function: Callable,
        predict_function: Callable,
        changeable_parameters: dict[str, ChangeableParameter] | None = None,
        unchangeable_parameters: dict[str, UnchangeableParameter] | None = None,
        requires_temporal_preservation: bool = False,
        feature_window_size: int | None = None,
    ) -> None:
        """
        Register a model in the registry.

        .. note:: The model name must be unique.

        Parameters
        ----------
        name : str
            The name of the model.
        model_class : type
            The class of the model.
        is_classifier : bool
            Whether the model is a classifier.
        save_function : callable
            The function to save the model.
        load_function : callable
            The function to load the model.
        train_function : callable
            The function to train the model.
        predict_function : callable
            The function to make predictions with the model.
        changeable_parameters : dict of str to ChangeableParameter, optional
            A dictionary of changeable parameters for the model. Default is None,
            which is stored as an empty dictionary.
        unchangeable_parameters : dict of str to UnchangeableParameter, optional
            A dictionary of unchangeable parameters for the model. Default is None,
            which is stored as an empty dictionary.
        requires_temporal_preservation : bool, optional
            Whether the model requires temporal preservation in features. Default is False.
            Models like RaulNet with CNN layers need multiple temporal samples,
            so features should use smaller window sizes to preserve time dimension.
        feature_window_size : int, optional
            The window size to use for feature extraction. Default is None, which uses
            the full buffer size. For models requiring temporal preservation, this should
            be smaller than the buffer size (e.g., 120 for RaulNet with buffer of 360).

        Raises
        ------
        ValueError
            If the model is already registered.
        """
        if name in self.models_map:
            raise ValueError(
                f'Model "{name}" is already registered. Please choose a different name.'
            )

        self.models_map[name] = (model_class, is_classifier)
        self.models_functions_map[name] = {
            "save": save_function,
            "load": load_function,
            "train": train_function,
            "predict": predict_function,
        }
        # ``or {}`` turns a missing (None) argument into an empty dictionary so
        # readers of the map never have to handle None.
        self.models_parameters_map[name] = {
            "changeable": changeable_parameters or {},
            "unchangeable": unchangeable_parameters or {},
        }
        self.models_metadata_map[name] = {
            "requires_temporal_preservation": requires_temporal_preservation,
            "feature_window_size": feature_window_size,
        }

    def register_feature(
        self,
        name: str,
        feature: type[Transform],
        requires_temporal_preservation: bool = False,
    ) -> None:
        """
        Register a feature in the registry.

        .. note:: The feature name must be unique.

        Parameters
        ----------
        name : str
            The name of the feature.
        feature : type[Transform]
            The feature transform class to register.
        requires_temporal_preservation : bool, optional
            Whether this feature requires temporal preservation (keeps time dimension).
            Features like RMS Small Window that preserve the time dimension should
            set this to True. Default is False.

        Raises
        ------
        ValueError
            If the feature is already registered.
        """
        if name in self.features_map:
            raise ValueError(
                f'Feature "{name}" is already registered. Please choose a different name.'
            )

        self.features_map[name] = feature
        self.features_metadata_map[name] = {
            "requires_temporal_preservation": requires_temporal_preservation,
        }

    def register_real_time_filter(self, name: str, function: Callable) -> None:
        """
        Register a real-time filter in the registry.

        .. note:: The filter name must be unique.

        Parameters
        ----------
        name : str
            The name of the filter.
        function : callable
            The filter function.

        Raises
        ------
        ValueError
            If the filter is already registered.
        """
        if name in self.real_time_filters_map:
            raise ValueError(
                f'Filter "{name}" is already registered. Please choose a different name.'
            )

        self.real_time_filters_map[name] = function

    def register_visual_interface(
        self,
        name: str,
        setup_interface_ui: type[SetupInterfaceTemplate],
        recording_interface_ui: type[RecordingInterfaceTemplate],
    ) -> None:
        """
        Register a visual interface in the registry.

        .. note:: The visual interface name must be unique.

        Parameters
        ----------
        name : str
            The name of the visual interface.
        setup_interface_ui : type[SetupInterfaceTemplate]
            The setup interface class.
        recording_interface_ui : type[RecordingInterfaceTemplate]
            The recording interface class.

        Raises
        ------
        ValueError
            If the visual interface is already registered.
        """
        if name in self.visual_interfaces_map:
            raise ValueError(
                f'Visual interface "{name}" is already registered. Please choose a different name.'
            )

        self.visual_interfaces_map[name] = (setup_interface_ui, recording_interface_ui)

    def register_output_system(
        self, name: str, output_system: type[OutputSystemTemplate]
    ) -> None:
        """
        Register an output system in the registry.

        .. note:: The output system name must be unique.

        Parameters
        ----------
        name : str
            The name of the output system.
        output_system : type[OutputSystemTemplate]
            The output system class.

        Raises
        ------
        ValueError
            If the output system is already registered.
        """
        if name in self.output_systems_map:
            raise ValueError(
                f'Output system "{name}" is already registered. Please choose a different name.'
            )

        self.output_systems_map[name] = output_system
# ------------------------------------------------------------------------------ if "CONFIG_REGISTRY" not in globals(): CONFIG_REGISTRY = Registry() import myogestic.default_config # noqa import myogestic.user_config # noqa