Source code for myoverse.datasets.defaults

from pathlib import Path
from typing import Sequence, Optional

import numpy as np
from scipy.signal import butter

from myoverse.datasets.filters.emg_augmentations import (
    GaussianNoise,
    MagnitudeWarping,
    WaveletDecomposition,
)
from myoverse.datasets.filters.generic import (
    ApplyFunctionFilter,
    IndexDataFilter,
    IdentityFilter,
)
from myoverse.datasets.filters.temporal import RMSFilter, SOSFrequencyFilter
from myoverse.datasets.supervised import EMGDataset


class EMBCDataset:
    """Official dataset maker for the EMBC paper [1].

    Parameters
    ----------
    emg_data_path : Path
        The path to the pickle file containing the EMG data.
        This should be a dictionary with the keys as the tasks in tasks_to_use
        and the values as the EMG data.
        The EMG data should be of shape (320, samples).
    ground_truth_data_path : Path
        The path to the pickle file containing the ground truth data.
        This should be a dictionary with the keys as the tasks in tasks_to_use
        and the values as the ground truth data.
        The ground truth data should be of shape (21, 3, samples).
    save_path : Path
        The path to save the dataset to. This should be a zarr file.
    emg_data : dict[str, np.ndarray], optional
        Optional dictionary containing EMG data if not loading from a file.
        Defaults to a new, empty dictionary for every instance.
    ground_truth_data : dict[str, np.ndarray], optional
        Optional dictionary containing ground truth data if not loading from a
        file. Defaults to a new, empty dictionary for every instance.
    tasks_to_use : Sequence[str], optional
        The tasks to use. The default ``("Change Me",)`` looks like a
        placeholder and should presumably be replaced with real task names.
    debug_level : int, optional
        Debug level (0-2). Default is 0 (no debugging).
    silence_zarr_warnings : bool, optional
        Whether to silence all Zarr-related warnings. Default is False.

    Methods
    -------
    create_dataset()
        Creates the dataset.

    References
    ----------
    [1] Sîmpetru, R.C., Osswald, M., Braun, D.I., Souza de Oliveira, D.,
    Cakici, A.L., Del Vecchio, A., 2022. Accurate Continuous Prediction of
    14 Degrees of Freedom of the Hand from Myoelectrical Signals through
    Convolutive Deep Learning, in: Proceedings of the 2022 44th Annual
    International Conference of the IEEE Engineering in Medicine & Biology
    Society (EMBC) pp. 702–706. https://doi.org/10/gq2f47
    """

    def __init__(
        self,
        emg_data_path: Path,
        ground_truth_data_path: Path,
        save_path: Path,
        emg_data: Optional[dict[str, np.ndarray]] = None,
        ground_truth_data: Optional[dict[str, np.ndarray]] = None,
        tasks_to_use: Sequence[str] = ("Change Me",),
        debug_level: int = 0,
        silence_zarr_warnings: bool = False,
    ):
        self.emg_data_path = emg_data_path
        # Create a fresh dict per instance. A literal ``{}`` default is
        # evaluated once, so every instance would otherwise share (and could
        # corrupt) the same dictionary. A caller-supplied dict is stored as-is.
        self.emg_data = {} if emg_data is None else emg_data
        self.ground_truth_data_path = ground_truth_data_path
        self.ground_truth_data = (
            {} if ground_truth_data is None else ground_truth_data
        )
        self.tasks_to_use = tasks_to_use
        self.save_path = save_path
        self.debug_level = debug_level
        self.silence_zarr_warnings = silence_zarr_warnings

    def create_dataset(self):
        """Create the dataset using the EMBC default settings.

        Builds an ``EMGDataset`` configured with the EMBC sampling rate,
        chunking, train/test/validation split, filter pipelines and
        augmentations, then calls its ``create_dataset()``.
        Returns nothing.
        """
        EMGDataset(
            emg_data_path=self.emg_data_path,
            emg_data=self.emg_data,
            ground_truth_data_path=self.ground_truth_data_path,
            ground_truth_data=self.ground_truth_data,
            ground_truth_data_type="kinematics",
            sampling_frequency=2048.0,
            tasks_to_use=self.tasks_to_use,
            save_path=self.save_path,
            # 192 / 2048 Hz = 93.75 ms windows, advanced by 64 samples (31.25 ms).
            chunk_size=192,
            chunk_shift=64,
            testing_split_ratio=0.2,
            validation_split_ratio=0.2,
            debug_level=self.debug_level,
            silence_zarr_warnings=self.silence_zarr_warnings,
            # EMBC-specific filter pipelines and augmentations.
            # Two EMG outputs per chunk: the unfiltered "raw" signal and a
            # 4th-order Butterworth 20 Hz low-pass (second-order sections).
            emg_filter_pipeline_after_chunking=[
                [
                    IdentityFilter(is_output=True, name="raw", input_is_chunked=True),
                    SOSFrequencyFilter(
                        sos_filter_coefficients=butter(
                            4, 20, "lowpass", output="sos", fs=2048
                        ),
                        is_output=True,
                        input_is_chunked=True,
                    ),
                ]
            ],
            emg_representations_to_filter_after_chunking=[["Last"]],
            # Flatten (21, 3, samples) to (63, samples), then keep rows 3..62,
            # i.e. drop the first of the 21 points (presumably the wrist).
            # NOTE(review): NumPy 2.1 deprecated the ``newshape`` keyword of
            # np.reshape in favour of ``shape``; kept for older NumPy support.
            ground_truth_filter_pipeline_before_chunking=[
                [
                    ApplyFunctionFilter(
                        function=np.reshape,
                        name="Reshape",
                        newshape=(63, -1),
                        input_is_chunked=False,
                    ),
                    IndexDataFilter(indices=(slice(3, 63),), input_is_chunked=False),
                ]
            ],
            ground_truth_representations_to_filter_before_chunking=[["Input"]],
            # Average each chunk over its last axis (presumably time) so that
            # every chunk yields a single ground-truth target.
            ground_truth_filter_pipeline_after_chunking=[
                [
                    ApplyFunctionFilter(
                        function=np.mean,
                        name="Mean",
                        axis=-1,
                        is_output=True,
                        input_is_chunked=True,
                    )
                ]
            ],
            ground_truth_representations_to_filter_after_chunking=[["Last"]],
            augmentation_pipelines=[
                [GaussianNoise(is_output=True, input_is_chunked=False)],
                [
                    MagnitudeWarping(
                        is_output=True, nr_of_grids=5, input_is_chunked=False
                    )
                ],
                [
                    WaveletDecomposition(
                        level=3, is_output=True, nr_of_grids=5, input_is_chunked=False
                    )
                ],
            ],
            amount_of_chunks_to_augment_at_once=500,
        ).create_dataset()
class CastelliniDataset:
    """Dataset maker modelled on the Castellini paper [1].

    This is not the official dataset maker used in the paper, but our own
    reimplementation based on its description.

    Parameters
    ----------
    emg_data_path : Path
        The path to the pickle file containing the EMG data.
        This should be a dictionary with the keys as the tasks in tasks_to_use
        and the values as the EMG data.
        The EMG data should be of shape (320, samples).
    ground_truth_data_path : Path
        The path to the pickle file containing the ground truth data.
        This should be a dictionary with the keys as the tasks in tasks_to_use
        and the values as the ground truth data.
        The ground truth data should be of shape (21, 3, samples).
    save_path : Path
        The path to save the dataset to. This should be a zarr file.
    emg_data : dict[str, np.ndarray], optional
        Optional dictionary containing EMG data if not loading from a file.
        Defaults to a new, empty dictionary for every instance.
    ground_truth_data : dict[str, np.ndarray], optional
        Optional dictionary containing ground truth data if not loading from a
        file. Defaults to a new, empty dictionary for every instance.
    tasks_to_use : Sequence[str], optional
        The tasks to use. The default ``("Change Me",)`` looks like a
        placeholder and should presumably be replaced with real task names.
    debug_level : int, optional
        Debug level (0-2). Default is 0 (no debugging).
    silence_zarr_warnings : bool, optional
        Whether to silence all Zarr-related warnings. Default is False.

    Methods
    -------
    create_dataset()
        Creates the dataset.

    References
    ----------
    [1] Nowak, M., Vujaklija, I., Sturma, A., Castellini, C., Farina, D.,
    2023. Simultaneous and Proportional Real-Time Myocontrol of Up to Three
    Degrees of Freedom of the Wrist and Hand. IEEE Transactions on Biomedical
    Engineering 70, 459–469. https://doi.org/10/grc7qf
    """

    def __init__(
        self,
        emg_data_path: Path,
        ground_truth_data_path: Path,
        save_path: Path,
        emg_data: Optional[dict[str, np.ndarray]] = None,
        ground_truth_data: Optional[dict[str, np.ndarray]] = None,
        tasks_to_use: Sequence[str] = ("Change Me",),
        debug_level: int = 0,
        silence_zarr_warnings: bool = False,
    ):
        self.emg_data_path = emg_data_path
        # Create a fresh dict per instance. A literal ``{}`` default is
        # evaluated once, so every instance would otherwise share (and could
        # corrupt) the same dictionary. A caller-supplied dict is stored as-is.
        self.emg_data = {} if emg_data is None else emg_data
        self.ground_truth_data_path = ground_truth_data_path
        self.ground_truth_data = (
            {} if ground_truth_data is None else ground_truth_data
        )
        self.save_path = save_path
        self.tasks_to_use = tasks_to_use
        self.debug_level = debug_level
        self.silence_zarr_warnings = silence_zarr_warnings

    def create_dataset(self):
        """Create the dataset using the Castellini-style settings.

        Builds an ``EMGDataset`` configured with a band-pass / band-stop / RMS
        EMG pipeline, the ground-truth reshaping and averaging pipelines and
        the augmentations, then calls its ``create_dataset()``.
        Returns nothing.
        """
        dataset = EMGDataset(
            emg_data_path=self.emg_data_path,
            emg_data=self.emg_data,
            ground_truth_data_path=self.ground_truth_data_path,
            ground_truth_data=self.ground_truth_data,
            ground_truth_data_type="kinematics",
            sampling_frequency=2048,
            tasks_to_use=self.tasks_to_use,
            save_path=self.save_path,
            debug_level=self.debug_level,
            silence_zarr_warnings=self.silence_zarr_warnings,
            # Castellini-specific filter pipelines, applied before chunking:
            # 5th-order Butterworth 20-500 Hz band-pass, 45-55 Hz band-stop,
            # then a 204-sample (~99.6 ms at 2048 Hz) RMS window with shift 20.
            emg_filter_pipeline_before_chunking=[
                [
                    SOSFrequencyFilter(
                        sos_filter_coefficients=butter(
                            5,
                            (20, 500),
                            "bandpass",
                            output="sos",
                            fs=2048,
                        ),
                        name="Bandpass 20-500 Hz",
                        input_is_chunked=False,
                    ),
                    SOSFrequencyFilter(
                        sos_filter_coefficients=butter(
                            5, (45, 55), "bandstop", output="sos", fs=2048
                        ),
                        name="Bandstop 45-55 Hz",
                        input_is_chunked=False,
                    ),
                    RMSFilter(
                        window_size=204,
                        shift=20,
                        name=f"RMS {204 / 2048 * 1000} ms",
                        input_is_chunked=False,
                    ),
                ]
            ],
            emg_representations_to_filter_before_chunking=[["Input"]],
            # Flatten (21, 3, samples) to (63, samples), then keep rows 3..62,
            # dropping the first of the 21 points (the wrist, per the name).
            # NOTE(review): NumPy 2.1 deprecated the ``newshape`` keyword of
            # np.reshape in favour of ``shape``; kept for older NumPy support.
            ground_truth_filter_pipeline_before_chunking=[
                [
                    ApplyFunctionFilter(
                        function=np.reshape,
                        newshape=(63, -1),
                        name="Reshape",
                        input_is_chunked=False,
                    ),
                    IndexDataFilter(
                        indices=(slice(3, 63),),
                        name="Indexing (Remove Wrist)",
                        input_is_chunked=False,
                    ),
                ]
            ],
            ground_truth_representations_to_filter_before_chunking=[["Input"]],
            # Average each chunk over its last axis (presumably time) so that
            # every chunk yields a single ground-truth target.
            ground_truth_filter_pipeline_after_chunking=[
                [
                    ApplyFunctionFilter(
                        function=np.mean,
                        axis=-1,
                        is_output=True,
                        name="Mean",
                        input_is_chunked=True,
                    )
                ]
            ],
            ground_truth_representations_to_filter_after_chunking=[["Last"]],
            augmentation_pipelines=[
                [GaussianNoise(is_output=True, input_is_chunked=False)],
                [
                    MagnitudeWarping(
                        is_output=True, input_is_chunked=False, nr_of_grids=5
                    )
                ],
                [
                    WaveletDecomposition(
                        level=3, is_output=True, input_is_chunked=False, nr_of_grids=5
                    )
                ],
            ],
            amount_of_chunks_to_augment_at_once=500,
        )
        dataset.create_dataset()