from pathlib import Path
from typing import Sequence, Optional
import numpy as np
from scipy.signal import butter
from myoverse.datasets.filters.emg_augmentations import (
GaussianNoise,
MagnitudeWarping,
WaveletDecomposition,
)
from myoverse.datasets.filters.generic import (
ApplyFunctionFilter,
IndexDataFilter,
IdentityFilter,
)
from myoverse.datasets.filters.temporal import RMSFilter, SOSFrequencyFilter
from myoverse.datasets.supervised import EMGDataset
class EMBCDataset:
    """Official dataset maker for the EMBC paper [1].

    The class only stores its configuration; all work happens in
    :meth:`create_dataset`, which builds an ``EMGDataset`` with the fixed
    EMBC settings and runs it.

    Parameters
    ----------
    emg_data_path : Path
        The path to the pickle file containing the EMG data.
        This should be a dictionary with the keys as the tasks in tasks_to_use and the values as the EMG data.
        The EMG data should be of shape (320, samples).
    ground_truth_data_path : Path
        The path to the pickle file containing the ground truth data.
        This should be a dictionary with the keys as the tasks in tasks_to_use and the values as the ground truth data.
        The ground truth data should be of shape (21, 3, samples).
    save_path : Path
        The path to save the dataset to. This should be a zarr file.
    emg_data : dict[str, np.ndarray], optional
        Optional dictionary containing EMG data if not loading from a file.
        ``None`` (the default) is replaced by a new empty dictionary.
    ground_truth_data : dict[str, np.ndarray], optional
        Optional dictionary containing ground truth data if not loading from a file.
        ``None`` (the default) is replaced by a new empty dictionary.
    tasks_to_use : Sequence[str], optional
        The tasks to use. The default ``("Change Me",)`` is a placeholder.
    debug_level : int, optional
        Debug level (0-2). Default is 0 (no debugging).
    silence_zarr_warnings : bool, optional
        Whether to silence all Zarr-related warnings. Default is False.

    Methods
    -------
    create_dataset()
        Creates the dataset.

    References
    ----------
    [1] Sîmpetru, R.C., Osswald, M., Braun, D.I., Souza de Oliveira, D., Cakici, A.L., Del Vecchio, A., 2022.
    Accurate Continuous Prediction of 14 Degrees of Freedom of the Hand from Myoelectrical Signals through
    Convolutive Deep Learning, in: Proceedings of the 2022 44th Annual International Conference of
    the IEEE Engineering in Medicine & Biology Society (EMBC) pp. 702–706. https://doi.org/10/gq2f47
    """

    def __init__(
        self,
        emg_data_path: Path,
        ground_truth_data_path: Path,
        save_path: Path,
        emg_data: Optional[dict[str, np.ndarray]] = None,
        ground_truth_data: Optional[dict[str, np.ndarray]] = None,
        tasks_to_use: Sequence[str] = ("Change Me",),
        debug_level: int = 0,
        silence_zarr_warnings: bool = False,
    ):
        self.emg_data_path = emg_data_path
        # A fresh dict per instance: a literal ``{}`` default would be one
        # object shared by every instance created without this argument.
        self.emg_data = {} if emg_data is None else emg_data
        self.ground_truth_data_path = ground_truth_data_path
        self.ground_truth_data = (
            {} if ground_truth_data is None else ground_truth_data
        )
        self.tasks_to_use = tasks_to_use
        self.save_path = save_path
        self.debug_level = debug_level
        self.silence_zarr_warnings = silence_zarr_warnings

    def create_dataset(self) -> None:
        """Build the ``EMGDataset`` with the EMBC settings and create it.

        The stored paths, in-memory data, task list, debug level and warning
        switch are forwarded unchanged; everything else is fixed here.
        Nothing is returned.
        """
        # EMBC default settings
        EMGDataset(
            emg_data_path=self.emg_data_path,
            emg_data=self.emg_data,
            ground_truth_data_path=self.ground_truth_data_path,
            ground_truth_data=self.ground_truth_data,
            ground_truth_data_type="kinematics",
            sampling_frequency=2048.0,
            tasks_to_use=self.tasks_to_use,
            save_path=self.save_path,
            chunk_size=192,
            chunk_shift=64,
            testing_split_ratio=0.2,
            validation_split_ratio=0.2,
            debug_level=self.debug_level,
            silence_zarr_warnings=self.silence_zarr_warnings,
            # EMBC-specific filter pipelines and augmentations
            emg_filter_pipeline_after_chunking=[
                [
                    # Both filters are flagged as outputs: the unfiltered
                    # chunk ("raw") and its 20 Hz low-passed version.
                    IdentityFilter(is_output=True, name="raw", input_is_chunked=True),
                    SOSFrequencyFilter(
                        sos_filter_coefficients=butter(
                            4, 20, "lowpass", output="sos", fs=2048
                        ),
                        is_output=True,
                        input_is_chunked=True,
                    ),
                ]
            ],
            emg_representations_to_filter_after_chunking=[["Last"]],
            ground_truth_filter_pipeline_before_chunking=[
                [
                    # Flatten to 63 rows, matching the documented
                    # (21, 3, samples) ground truth layout.
                    # NOTE(review): the ``newshape`` keyword of np.reshape is
                    # reported as deprecated from NumPy 2.1 in favour of
                    # ``shape`` - confirm the supported NumPy range before
                    # changing it.
                    ApplyFunctionFilter(
                        function=np.reshape,
                        name="Reshape",
                        newshape=(63, -1),
                        input_is_chunked=False,
                    ),
                    # Drop the first three rows (presumably the x/y/z of the
                    # first keypoint, i.e. the wrist - confirm).
                    IndexDataFilter(indices=(slice(3, 63),), input_is_chunked=False),
                ]
            ],
            ground_truth_representations_to_filter_before_chunking=[["Input"]],
            ground_truth_filter_pipeline_after_chunking=[
                [
                    # Average each chunk over its last axis.
                    ApplyFunctionFilter(
                        function=np.mean,
                        name="Mean",
                        axis=-1,
                        is_output=True,
                        input_is_chunked=True,
                    )
                ]
            ],
            ground_truth_representations_to_filter_after_chunking=[["Last"]],
            # Three independent single-filter augmentation pipelines.
            augmentation_pipelines=[
                [GaussianNoise(is_output=True, input_is_chunked=False)],
                [
                    MagnitudeWarping(
                        is_output=True, nr_of_grids=5, input_is_chunked=False
                    )
                ],
                [
                    WaveletDecomposition(
                        level=3, is_output=True, nr_of_grids=5, input_is_chunked=False
                    )
                ],
            ],
            amount_of_chunks_to_augment_at_once=500,
        ).create_dataset()
class CastelliniDataset:
    """Dataset maker made after the Castellini paper [1].

    This is not the official dataset maker used but our own version made after the paper.
    The class only stores its configuration; all work happens in
    :meth:`create_dataset`, which builds an ``EMGDataset`` with the fixed
    filter and augmentation settings and runs it.

    Parameters
    ----------
    emg_data_path : Path
        The path to the pickle file containing the EMG data.
        This should be a dictionary with the keys as the tasks in tasks_to_use and the values as the EMG data.
        The EMG data should be of shape (320, samples).
    ground_truth_data_path : Path
        The path to the pickle file containing the ground truth data.
        This should be a dictionary with the keys as the tasks in tasks_to_use and the values as the ground truth data.
        The ground truth data should be of shape (21, 3, samples).
    save_path : Path
        The path to save the dataset to. This should be a zarr file.
    emg_data : dict[str, np.ndarray], optional
        Optional dictionary containing EMG data if not loading from a file.
        ``None`` (the default) is replaced by a new empty dictionary.
    ground_truth_data : dict[str, np.ndarray], optional
        Optional dictionary containing ground truth data if not loading from a file.
        ``None`` (the default) is replaced by a new empty dictionary.
    tasks_to_use : Sequence[str], optional
        The tasks to use. The default ``("Change Me",)`` is a placeholder.
    debug_level : int, optional
        Debug level (0-2). Default is 0 (no debugging).
    silence_zarr_warnings : bool, optional
        Whether to silence all Zarr-related warnings. Default is False.

    Methods
    -------
    create_dataset()
        Creates the dataset.

    References
    ----------
    [1] Nowak, M., Vujaklija, I., Sturma, A., Castellini, C., Farina, D., 2023.
    Simultaneous and Proportional Real-Time Myocontrol of Up to Three Degrees of Freedom of the Wrist and Hand.
    IEEE Transactions on Biomedical Engineering 70, 459–469. https://doi.org/10/grc7qf
    """

    def __init__(
        self,
        emg_data_path: Path,
        ground_truth_data_path: Path,
        save_path: Path,
        emg_data: Optional[dict[str, np.ndarray]] = None,
        ground_truth_data: Optional[dict[str, np.ndarray]] = None,
        tasks_to_use: Sequence[str] = ("Change Me",),
        debug_level: int = 0,
        silence_zarr_warnings: bool = False,
    ):
        self.emg_data_path = emg_data_path
        # A fresh dict per instance: a literal ``{}`` default would be one
        # object shared by every instance created without this argument.
        self.emg_data = {} if emg_data is None else emg_data
        self.ground_truth_data_path = ground_truth_data_path
        self.ground_truth_data = (
            {} if ground_truth_data is None else ground_truth_data
        )
        self.save_path = save_path
        self.tasks_to_use = tasks_to_use
        self.debug_level = debug_level
        self.silence_zarr_warnings = silence_zarr_warnings

    def create_dataset(self) -> None:
        """Build the ``EMGDataset`` with the Castellini settings and create it.

        The stored paths, in-memory data, task list, debug level and warning
        switch are forwarded unchanged; everything else is fixed here.
        Chunking and split parameters are not passed, so whatever defaults
        ``EMGDataset`` defines apply. Nothing is returned.
        """
        EMGDataset(
            emg_data_path=self.emg_data_path,
            emg_data=self.emg_data,
            ground_truth_data_path=self.ground_truth_data_path,
            ground_truth_data=self.ground_truth_data,
            ground_truth_data_type="kinematics",
            sampling_frequency=2048,
            tasks_to_use=self.tasks_to_use,
            save_path=self.save_path,
            debug_level=self.debug_level,
            silence_zarr_warnings=self.silence_zarr_warnings,
            # Castellini-specific filter pipelines
            emg_filter_pipeline_before_chunking=[
                [
                    # Band-limit the EMG, then notch out the 45-55 Hz band,
                    # then take a windowed RMS.
                    SOSFrequencyFilter(
                        sos_filter_coefficients=butter(
                            5,
                            (20, 500),
                            "bandpass",
                            output="sos",
                            fs=2048,
                        ),
                        name="Bandpass 20-500 Hz",
                        input_is_chunked=False,
                    ),
                    SOSFrequencyFilter(
                        sos_filter_coefficients=butter(
                            5, (45, 55), "bandstop", output="sos", fs=2048
                        ),
                        name="Bandstop 45-55 Hz",
                        input_is_chunked=False,
                    ),
                    # 204 samples at 2048 Hz is the ~99.6 ms window spelled
                    # out in the filter name.
                    RMSFilter(
                        window_size=204,
                        shift=20,
                        name=f"RMS {204 / 2048 * 1000} ms",
                        input_is_chunked=False,
                    ),
                ]
            ],
            emg_representations_to_filter_before_chunking=[["Input"]],
            ground_truth_filter_pipeline_before_chunking=[
                [
                    # Flatten to 63 rows, matching the documented
                    # (21, 3, samples) ground truth layout.
                    # NOTE(review): the ``newshape`` keyword of np.reshape is
                    # reported as deprecated from NumPy 2.1 in favour of
                    # ``shape`` - confirm the supported NumPy range before
                    # changing it.
                    ApplyFunctionFilter(
                        function=np.reshape,
                        newshape=(63, -1),
                        name="Reshape",
                        input_is_chunked=False,
                    ),
                    # Rows 0-2 are discarded; the filter name identifies them
                    # as the wrist.
                    IndexDataFilter(
                        indices=(slice(3, 63),),
                        name="Indexing (Remove Wrist)",
                        input_is_chunked=False,
                    ),
                ]
            ],
            ground_truth_representations_to_filter_before_chunking=[["Input"]],
            ground_truth_filter_pipeline_after_chunking=[
                [
                    # Average each chunk over its last axis.
                    ApplyFunctionFilter(
                        function=np.mean,
                        axis=-1,
                        is_output=True,
                        name="Mean",
                        input_is_chunked=True,
                    )
                ]
            ],
            ground_truth_representations_to_filter_after_chunking=[["Last"]],
            # Three independent single-filter augmentation pipelines.
            augmentation_pipelines=[
                [GaussianNoise(is_output=True, input_is_chunked=False)],
                [
                    MagnitudeWarping(
                        is_output=True, input_is_chunked=False, nr_of_grids=5
                    )
                ],
                [
                    WaveletDecomposition(
                        level=3, is_output=True, input_is_chunked=False, nr_of_grids=5
                    )
                ],
            ],
            amount_of_chunks_to_augment_at_once=500,
        ).create_dataset()