from pathlib import Path
from typing import Any, Dict, Optional, Sequence
import zarr
import numpy as np
import lightning as L
from torch.utils.data import DataLoader, Dataset
from myoverse.datasets.filters.generic import (
IdentityFilter,
FilterBaseClass,
IndexDataFilter,
)
from myoverse.datatypes import EMGData, VirtualHandKinematics
class EMGDatasetLoader(L.LightningDataModule):
    """Dataset loader for the EMG dataset.

    Parameters
    ----------
    data_path : Path
        The path to the zarr file.
    seed : Optional[int], optional
        The seed for the random number generator, by default None.
    dataloader_parameters : Dict[str, Any], optional
        The parameters for the DataLoader, by default None.
        Must be provided; a ValueError is raised otherwise.
    shuffle_training_data : bool, optional
        Whether to shuffle the training data, by default True.
    input_type : numpy.dtype, optional
        The type of the input data, by default np.float32.
    ground_truth_type : numpy.dtype, optional
        The type of the ground_truth data, by default np.float32.
    ground_truth_name : str, optional
        The name of the ground truth data, by default "ground_truth".
    input_augmentation_pipeline : list[list[FilterBaseClass]], optional
        The augmentation pipeline for the input data, by default
        [[IdentityFilter(is_output=True)]].
    input_augmentation_probabilities : Sequence[float], optional
        The probabilities for the augmentation pipeline, by default (1,).
        The sum of the probabilities must be equal to 1 and the number of
        probabilities must be equal to the number of augmentation sequences.
    ground_truth_augmentation_pipeline : list[list[FilterBaseClass]], optional
        The augmentation pipeline for the ground truth data, by default
        [[IndexDataFilter(indices=(0,), is_output=True)]].
    ground_truth_augmentation_probabilities : Sequence[float], optional
        The probabilities for the augmentation pipeline, by default (1,).
        The sum of the probabilities must be equal to 1 and the number of
        probabilities must be equal to the number of augmentation sequences.
    """

    class _EMGDatasetLoader(Dataset):
        """torch Dataset over one subset ("training"/"testing"/"validation") of a
        zarr file; applies a randomly drawn augmentation filter sequence per item."""

        def __init__(
            self,
            zarr_file: Path,
            subset_name: str,
            ground_truth_name: str,
            input_type=np.float32,
            ground_truth_type=np.float32,
            input_augmentation_pipeline: Optional[
                list[list[FilterBaseClass]]
            ] = None,
            input_augmentation_probabilities: Sequence[float] = (1,),
            ground_truth_augmentation_pipeline: Optional[
                list[list[FilterBaseClass]]
            ] = None,
            ground_truth_augmentation_probabilities: Sequence[float] = (1,),
        ):
            # None-sentinel instead of mutable default arguments: the previous
            # list-literal defaults were shared across every instance of the class.
            if input_augmentation_pipeline is None:
                input_augmentation_pipeline = [[IdentityFilter(is_output=True)]]
            if ground_truth_augmentation_pipeline is None:
                ground_truth_augmentation_pipeline = [
                    [IndexDataFilter(indices=(0,), is_output=True)]
                ]

            self.zarr_file = zarr_file
            self.subset_name = subset_name
            self.ground_truth_name = ground_truth_name

            # Materialize a {name: zarr array} mapping of all EMG arrays in the subset.
            self._emg_data = zarr.open(str(self.zarr_file))[self.subset_name]["emg"]
            self._emg_data = {
                key: self._emg_data[key] for key in self._emg_data.array_keys()
            }

            self._ground_truth_data = zarr.open(str(self.zarr_file))[self.subset_name][
                self.ground_truth_name
            ]
            try:
                # Ground truth stored as a zarr group of named arrays.
                self._ground_truth_data = {
                    key: self._ground_truth_data[key]
                    for key in self._ground_truth_data.array_keys()
                }
            except AttributeError:
                # Ground truth stored as a single array (no .array_keys()).
                self._ground_truth_data = {"temp": self._ground_truth_data}

            try:
                # Dataset length is taken from the first EMG array's sample axis;
                # all EMG arrays are assumed to share that first dimension.
                self.length = list(self._emg_data.values())[0].shape[0]
            except IndexError:
                # No EMG arrays present -> empty dataset.
                self.length = 0

            self.input_type = input_type
            self.ground_truth_type = ground_truth_type
            self.input_augmentation_pipeline = input_augmentation_pipeline
            self.input_augmentation_probabilities = input_augmentation_probabilities
            self.ground_truth_augmentation_pipeline = ground_truth_augmentation_pipeline
            self.ground_truth_augmentation_probabilities = (
                ground_truth_augmentation_probabilities
            )

        def __len__(self) -> int:
            """Return the number of samples in the subset."""
            return self.length

        def __getitem__(self, idx):
            """Return one (input, ground_truth) pair as numpy arrays, each with a
            randomly drawn augmentation filter sequence applied."""
            input_data = []
            ground_truth_data = []

            # Draw one filter sequence per side according to the configured
            # probabilities; the same sequence is applied to every array of that side.
            input_augmentation_chosen = self.input_augmentation_pipeline[
                np.random.choice(
                    len(self.input_augmentation_pipeline),
                    p=self.input_augmentation_probabilities,
                )
            ]
            ground_truth_augmentation_chosen = self.ground_truth_augmentation_pipeline[
                np.random.choice(
                    len(self.ground_truth_augmentation_pipeline),
                    p=self.ground_truth_augmentation_probabilities,
                )
            ]

            # NOTE(review): the sampling frequency is hard-coded to 2048 Hz here —
            # confirm this matches the recordings stored in the zarr file.
            for v in self._emg_data.values():
                temp = EMGData(v[idx], sampling_frequency=2048)
                temp.apply_filter_sequence(
                    input_augmentation_chosen, representation_to_filter="Input"
                )
                input_data.append(list(temp.output_representations.values())[0])

            # Ground truth is wrapped in EMGData as well so the same filter
            # machinery can be applied to it; atleast_2d matches EMGData's
            # expectation of a 2D input.
            for v in self._ground_truth_data.values():
                temp = EMGData(np.atleast_2d(v[idx]), sampling_frequency=2048)
                temp.apply_filter_sequence(
                    ground_truth_augmentation_chosen, representation_to_filter="Input"
                )
                ground_truth_data.append(
                    list(temp.output_representations.values())[0]
                )

            return np.array(input_data).astype(self.input_type), np.array(
                ground_truth_data
            ).astype(self.ground_truth_type)

    def __init__(
        self,
        data_path: Path,
        seed: Optional[int] = None,
        dataloader_parameters: Optional[Dict[str, Any]] = None,
        shuffle_training_data: bool = True,
        input_type=np.float32,
        ground_truth_type=np.float32,
        ground_truth_name: str = "ground_truth",
        input_augmentation_pipeline: Optional[list[list[FilterBaseClass]]] = None,
        input_augmentation_probabilities: Sequence[float] = (1,),
        ground_truth_augmentation_pipeline: Optional[
            list[list[FilterBaseClass]]
        ] = None,
        ground_truth_augmentation_probabilities: Sequence[float] = (1,),
    ):
        """Initializes the dataset.

        Attributes
        ----------
        data_path : Path
            The path to the zarr file.
        seed : Optional[int], optional
            The seed for the random number generator, by default None.
        dataloader_parameters : Dict[str, Any], optional
            The parameters for the DataLoader, by default None.
            Must be provided; a ValueError is raised otherwise.
        shuffle_training_data : bool, optional
            Whether to shuffle the training data, by default True.
        input_type : np.dtype, optional
            The type of the input data, by default np.float32.
        ground_truth_type : np.dtype, optional
            The type of the label data, by default np.float32.
        ground_truth_name : str, optional
            The name of the ground truth data, by default "ground_truth".
        input_augmentation_pipeline : list[list[FilterBaseClass]], optional
            The augmentation pipeline for the input data, by default
            [[IdentityFilter(is_output=True)]].
        input_augmentation_probabilities : Sequence[float], optional
            The probabilities for the augmentation pipeline, by default (1,).
            The sum of the probabilities must be equal to 1 and the number of
            probabilities must be equal to the number of augmentation sequences.
        ground_truth_augmentation_pipeline : list[list[FilterBaseClass]], optional
            The augmentation pipeline for the ground truth data, by default
            [[IndexDataFilter(indices=(0,), is_output=True)]].
        ground_truth_augmentation_probabilities : Sequence[float], optional
            The probabilities for the augmentation pipeline, by default (1,).
            The sum of the probabilities must be equal to 1 and the number of
            probabilities must be equal to the number of augmentation sequences.

        Raises
        ------
        ValueError
            If dataloader_parameters is None, or if either probability sequence
            does not match its pipeline length or does not sum to 1.
        """
        super().__init__()

        # None-sentinel instead of mutable default arguments (shared list pitfall).
        if input_augmentation_pipeline is None:
            input_augmentation_pipeline = [[IdentityFilter(is_output=True)]]
        if ground_truth_augmentation_pipeline is None:
            ground_truth_augmentation_pipeline = [
                [IndexDataFilter(indices=(0,), is_output=True)]
            ]

        self.data_path = data_path
        self.seed = seed

        if dataloader_parameters is None:
            raise ValueError("DataLoader parameters must be set!")
        self.dataloader_parameters = dataloader_parameters

        self.shuffle_training_data = shuffle_training_data
        self.input_type = input_type
        self.ground_truth_type = ground_truth_type
        self.ground_truth_name = ground_truth_name
        self.input_augmentation_pipeline = input_augmentation_pipeline
        self.input_augmentation_probabilities = input_augmentation_probabilities
        self.ground_truth_augmentation_pipeline = ground_truth_augmentation_pipeline
        self.ground_truth_augmentation_probabilities = (
            ground_truth_augmentation_probabilities
        )

        # Validate that each probability sequence matches its pipeline and sums
        # to 1. np.isclose is used instead of exact equality so that float
        # probabilities (e.g. three equal thirds) are not spuriously rejected.
        if len(self.input_augmentation_pipeline) != len(
            self.input_augmentation_probabilities
        ):
            raise ValueError(
                "The number of probabilities must be equal to the number of augmentation sequences"
            )
        if not np.isclose(sum(self.input_augmentation_probabilities), 1):
            raise ValueError("The sum of the probabilities must be equal to 1")
        if len(self.ground_truth_augmentation_pipeline) != len(
            self.ground_truth_augmentation_probabilities
        ):
            raise ValueError(
                "The number of probabilities must be equal to the number of augmentation sequences"
            )
        if not np.isclose(sum(self.ground_truth_augmentation_probabilities), 1):
            raise ValueError("The sum of the probabilities must be equal to 1")

    def train_dataloader(self) -> DataLoader:
        """Returns the training set as a DataLoader.

        Returns
        -------
        DataLoader
            The training set
        """
        # NOTE: dataloader_parameters must not also contain "shuffle" — it is
        # supplied explicitly here.
        return DataLoader(
            self._EMGDatasetLoader(
                self.data_path,
                subset_name="training",
                ground_truth_name=self.ground_truth_name,
                input_type=self.input_type,
                ground_truth_type=self.ground_truth_type,
                input_augmentation_pipeline=self.input_augmentation_pipeline,
                input_augmentation_probabilities=self.input_augmentation_probabilities,
                ground_truth_augmentation_pipeline=self.ground_truth_augmentation_pipeline,
                ground_truth_augmentation_probabilities=self.ground_truth_augmentation_probabilities,
            ),
            shuffle=self.shuffle_training_data,
            **self.dataloader_parameters,
        )

    def test_dataloader(self) -> DataLoader:
        """Returns the testing set as a DataLoader.

        The configured augmentation pipelines are deliberately NOT passed here,
        so evaluation uses the Dataset's default (identity / index) pipelines.

        Returns
        -------
        DataLoader
            The testing set
        """
        return DataLoader(
            self._EMGDatasetLoader(
                self.data_path,
                subset_name="testing",
                ground_truth_name=self.ground_truth_name,
                input_type=self.input_type,
                ground_truth_type=self.ground_truth_type,
            ),
            shuffle=False,
            **self.dataloader_parameters,
        )

    def val_dataloader(self) -> DataLoader:
        """Returns the validation set as a DataLoader.

        Like test_dataloader, no augmentation pipelines are passed, and
        drop_last is forced to False so every validation sample is seen.

        Returns
        -------
        DataLoader
            The validation set
        """
        dataloader_params = self.dataloader_parameters.copy()
        if "drop_last" in dataloader_params:
            dataloader_params["drop_last"] = False
        return DataLoader(
            self._EMGDatasetLoader(
                self.data_path,
                subset_name="validation",
                ground_truth_name=self.ground_truth_name,
                input_type=self.input_type,
                ground_truth_type=self.ground_truth_type,
            ),
            shuffle=False,
            **dataloader_params,
        )