from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
import warnings
import zarr
import numpy as np
import lightning as L
from torch.utils.data import DataLoader, Dataset
from myoverse.datasets.filters.generic import (
IdentityFilter,
FilterBaseClass,
IndexDataFilter,
)
from myoverse.datatypes import EMGData, KinematicsData, _Data
class EMGZarrDataset(Dataset):
"""Dataset class for loading EMG data from Zarr files.
Parameters
----------
zarr_file : Path
The path to the zarr file
subset_name : str
The name of the subset to load (e.g., "training", "validation", "testing")
target_name : str
The name of the target data
emg_dtype : np.dtype, optional
The data type of the EMG data, by default np.float32
target_dtype : np.dtype, optional
The data type of the target data, by default np.float32
sampling_frequency : float, optional
The sampling frequency of the EMG data in Hz, by default 2048.0
input_data_class : Type[_Data], optional
The class to use for input data, by default EMGData
target_data_class : Type[_Data], optional
The class to use for target data, by default KinematicsData
    input_augmentation_pipeline : Optional[List[List[FilterBaseClass]]], optional
        The augmentation pipeline for the input data, by default None (identity filter)
    input_augmentation_probabilities : Optional[Sequence[float]], optional
        The probabilities for each input augmentation pipeline, by default None (1.0)
    target_augmentation_pipeline : Optional[List[List[FilterBaseClass]]], optional
        The augmentation pipeline for the target data, by default None (index filter)
    target_augmentation_probabilities : Optional[Sequence[float]], optional
        The probabilities for each target augmentation pipeline, by default None (1.0)
cache_size : int, optional
The maximum number of items to cache, by default 100
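
    Examples
    --------
    A minimal usage sketch (illustrative only; the file name, subset name and
    target name are assumptions and must match your own Zarr layout):

    >>> from pathlib import Path
    >>> dataset = EMGZarrDataset(
    ...     zarr_file=Path("dataset.zarr"),
    ...     subset_name="training",
    ...     target_name="ground_truth",
    ... )
    >>> emg, target = dataset[0]

    Augmentation is configured as a list of pipelines, each itself a list of
    filters; one pipeline is drawn per item according to the given
    probabilities, which must sum to 1.0. A sketch with two alternative input
    pipelines:

    >>> from myoverse.datasets.filters.generic import IdentityFilter
    >>> dataset = EMGZarrDataset(
    ...     zarr_file=Path("dataset.zarr"),
    ...     subset_name="training",
    ...     target_name="ground_truth",
    ...     input_augmentation_pipeline=[
    ...         [IdentityFilter(is_output=True, input_is_chunked=True)],
    ...         [IdentityFilter(is_output=True, input_is_chunked=True)],
    ...     ],
    ...     input_augmentation_probabilities=(0.7, 0.3),
    ... )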
"""
def __init__(
self,
zarr_file: Path,
subset_name: str,
target_name: str,
emg_dtype=np.float32,
target_dtype=np.float32,
sampling_frequency: float = 2048.0,
input_data_class: Type[_Data] = EMGData,
target_data_class: Type[_Data] = KinematicsData,
        input_augmentation_pipeline: Optional[List[List[FilterBaseClass]]] = None,
        input_augmentation_probabilities: Optional[Sequence[float]] = None,
        target_augmentation_pipeline: Optional[List[List[FilterBaseClass]]] = None,
        target_augmentation_probabilities: Optional[Sequence[float]] = None,
cache_size: int = 100,
):
self.zarr_file = zarr_file
self.subset_name = subset_name
self.target_name = target_name
self.sampling_frequency = sampling_frequency
self.input_data_class = input_data_class
self.target_data_class = target_data_class
# Validate zarr file exists
if not Path(zarr_file).exists():
raise FileNotFoundError(f"Zarr file not found at {zarr_file}")
# Set default augmentation pipelines if None
if input_augmentation_pipeline is None:
self.input_augmentation_pipeline = [
[IdentityFilter(is_output=True, input_is_chunked=True)]
]
else:
self.input_augmentation_pipeline = input_augmentation_pipeline
if input_augmentation_probabilities is None:
self.input_augmentation_probabilities = (1.0,)
else:
self.input_augmentation_probabilities = input_augmentation_probabilities
if target_augmentation_pipeline is None:
self.target_augmentation_pipeline = [
[IndexDataFilter(indices=(0,), is_output=True, input_is_chunked=True)]
]
else:
self.target_augmentation_pipeline = target_augmentation_pipeline
if target_augmentation_probabilities is None:
self.target_augmentation_probabilities = (1.0,)
else:
self.target_augmentation_probabilities = target_augmentation_probabilities
# Validate augmentation probabilities
self._validate_augmentation_probabilities()
# Initialize cache
self._cache = {}
self.cache_size = cache_size
# Load data from zarr file
try:
zarr_root = zarr.open(store=str(self.zarr_file), mode="r", zarr_version=2)
if self.subset_name not in zarr_root:
raise ValueError(f"Subset '{self.subset_name}' not found in Zarr file")
subset = zarr_root[self.subset_name]
if "emg" not in subset:
raise ValueError(f"EMG data not found in subset '{self.subset_name}'")
self._emg_data = subset["emg"]
self._emg_data = {
key: self._emg_data[key] for key in self._emg_data.array_keys()
}
if self.target_name not in subset:
raise ValueError(
f"Target data '{self.target_name}' not found in subset '{self.subset_name}'"
)
self._target_data = subset[self.target_name]
try:
self._target_data = {
key: self._target_data[key]
for key in self._target_data.array_keys()
}
except AttributeError:
self._target_data = {"temp": self._target_data}
try:
self.length = list(self._emg_data.values())[0].shape[0]
if self.length == 0:
warnings.warn(f"Dataset '{self.subset_name}' is empty")
except IndexError:
self.length = 0
warnings.warn(
f"Failed to determine dataset length for '{self.subset_name}'"
)
        except Exception as e:
            raise ValueError(f"Failed to load data from Zarr file: {e}") from e
self.emg_dtype = emg_dtype
self.target_dtype = target_dtype
def _validate_augmentation_probabilities(self):
"""Validate that augmentation probabilities are valid."""
# Check input augmentation probabilities
if len(self.input_augmentation_pipeline) != len(
self.input_augmentation_probabilities
):
raise ValueError(
"Number of input augmentation probabilities must match number of pipelines"
)
if abs(sum(self.input_augmentation_probabilities) - 1.0) > 1e-6:
raise ValueError("Sum of input augmentation probabilities must be 1.0")
# Check target augmentation probabilities
if len(self.target_augmentation_pipeline) != len(
self.target_augmentation_probabilities
):
raise ValueError(
"Number of target augmentation probabilities must match number of pipelines"
)
if abs(sum(self.target_augmentation_probabilities) - 1.0) > 1e-6:
raise ValueError("Sum of target augmentation probabilities must be 1.0")
def __len__(self):
return self.length
def __getitem__(self, idx):
# Check if item is in cache
if idx in self._cache:
return self._cache[idx]
input_data = []
target_data = []
# Choose augmentation pipeline based on probabilities
input_augmentation_chosen = self.input_augmentation_pipeline[
np.random.choice(
len(self.input_augmentation_pipeline),
p=self.input_augmentation_probabilities,
)
]
target_augmentation_chosen = self.target_augmentation_pipeline[
np.random.choice(
len(self.target_augmentation_pipeline),
p=self.target_augmentation_probabilities,
)
]
# Process EMG data
for v in self._emg_data.values():
try:
temp = self.input_data_class(
v[idx], sampling_frequency=self.sampling_frequency
)
temp.apply_filter_sequence(
input_augmentation_chosen, representations_to_filter=["Input"]
)
input_data.append(list(temp.output_representations.values())[0])
            except Exception as e:
                raise RuntimeError(
                    f"Error processing EMG data at index {idx}: {e}"
                ) from e
# Process target data
for v in self._target_data.values():
try:
# Use the specified target_data_class without forced np.atleast_2d
temp = self.target_data_class(
v[idx], sampling_frequency=self.sampling_frequency
)
temp.apply_filter_sequence(
target_augmentation_chosen,
representations_to_filter=["Input"],
)
target_data.append(list(temp.output_representations.values())[0])
            except Exception as e:
                raise RuntimeError(
                    f"Error processing target data at index {idx}: {e}"
                ) from e
# Convert to appropriate data types
result = (
np.array(input_data).astype(self.emg_dtype),
np.array(target_data).astype(self.target_dtype),
)
        # Update cache (simple FIFO eviction)
        if len(self._cache) >= self.cache_size:
            # Remove the oldest inserted item
            self._cache.pop(next(iter(self._cache)))
self._cache[idx] = result
return result
class EMGDatasetLoader(L.LightningDataModule):
"""Dataset loader for the EMG dataset.
Parameters
----------
data_path : Path
The path to the zarr file
seed : Optional[int], optional
The seed for the random number generator, by default None
dataloader_params : Optional[Dict[str, Any]], optional
The parameters for the DataLoader, by default None
shuffle_train : bool, optional
Whether to shuffle the training data, by default True
emg_dtype : np.dtype, optional
The data type of the EMG data, by default np.float32
target_dtype : np.dtype, optional
The data type of the target data, by default np.float32
target_name : str, optional
The name of the target data, by default "ground_truth"
sampling_frequency : float, optional
The sampling frequency of the EMG data in Hz, by default 2048.0
input_data_class : Type[_Data], optional
The class to use for input data, by default EMGData
target_data_class : Type[_Data], optional
The class to use for target data, by default KinematicsData
input_augmentation_pipeline : Optional[List[List[FilterBaseClass]]], optional
The augmentation pipeline for the input data, by default None (identity filter)
input_augmentation_probabilities : Optional[Sequence[float]], optional
The probabilities for the input augmentation pipeline, by default None (1.0)
target_augmentation_pipeline : Optional[List[List[FilterBaseClass]]], optional
The augmentation pipeline for the target data, by default None (index filter)
target_augmentation_probabilities : Optional[Sequence[float]], optional
The probabilities for the target augmentation pipeline, by default None (1.0)
cache_size : int, optional
The maximum number of items to cache, by default 100
preprocessing_hooks : Optional[Dict[str, Callable]], optional
Custom preprocessing functions, by default None
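
    Examples
    --------
    A minimal usage sketch (illustrative only; the path and DataLoader settings
    are assumptions):

    >>> from pathlib import Path
    >>> loader = EMGDatasetLoader(
    ...     data_path=Path("dataset.zarr"),
    ...     dataloader_params={"batch_size": 16, "num_workers": 0},
    ... )
    >>> train_loader = loader.train_dataloader()

    The same arguments can be supplied as a dictionary via ``from_config``:

    >>> loader = EMGDatasetLoader.from_config(
    ...     {
    ...         "data_path": Path("dataset.zarr"),
    ...         "dataloader_params": {"batch_size": 16, "num_workers": 0},
    ...     }
    ... )

    Being a ``LightningDataModule``, the loader can also be passed directly to
    ``lightning.Trainer.fit(model, datamodule=loader)``.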
"""
def __init__(
self,
data_path: Path,
seed: Optional[int] = None,
dataloader_params: Optional[Dict[str, Any]] = None,
shuffle_train: bool = True,
emg_dtype=np.float32,
target_dtype=np.float32,
target_name: str = "ground_truth",
sampling_frequency: float = 2048.0,
input_data_class: Type[_Data] = EMGData,
target_data_class: Type[_Data] = KinematicsData,
input_augmentation_pipeline: Optional[List[List[FilterBaseClass]]] = None,
input_augmentation_probabilities: Optional[Sequence[float]] = None,
target_augmentation_pipeline: Optional[List[List[FilterBaseClass]]] = None,
target_augmentation_probabilities: Optional[Sequence[float]] = None,
cache_size: int = 100,
preprocessing_hooks: Optional[Dict[str, Callable]] = None,
):
super().__init__()
# Validate and set data path
if not isinstance(data_path, Path):
data_path = Path(data_path)
self.data_path = data_path
# Set random seed
self.seed = seed
if seed is not None:
np.random.seed(seed)
# Set default dataloader parameters
if dataloader_params is None:
self.dataloader_params = {
"batch_size": 32,
"num_workers": 4,
"pin_memory": True,
}
else:
self.dataloader_params = dataloader_params
# Set basic parameters
self.shuffle_train = shuffle_train
self.emg_dtype = emg_dtype
self.target_dtype = target_dtype
self.target_name = target_name
self.sampling_frequency = sampling_frequency
self.input_data_class = input_data_class
self.target_data_class = target_data_class
self.cache_size = cache_size
self.preprocessing_hooks = preprocessing_hooks or {}
# Set augmentation pipelines and probabilities
self.input_augmentation_pipeline = input_augmentation_pipeline
self.input_augmentation_probabilities = input_augmentation_probabilities
self.target_augmentation_pipeline = target_augmentation_pipeline
self.target_augmentation_probabilities = target_augmentation_probabilities
# Validate zarr file
self._validate_zarr_file()
def _validate_zarr_file(self):
"""Validate that the zarr file exists and has the expected structure."""
if not self.data_path.exists():
raise FileNotFoundError(f"Zarr file not found at {self.data_path}")
try:
            zarr_root = zarr.open(str(self.data_path), mode="r")
required_subsets = ["training", "validation", "testing"]
for subset in required_subsets:
if subset not in zarr_root:
warnings.warn(f"Subset '{subset}' not found in Zarr file")
except Exception as e:
raise ValueError(f"Failed to validate Zarr file: {e}")
def _create_dataset(self, subset_name: str, use_augmentation: bool = False):
"""Create a dataset for a specific subset.
Parameters
----------
subset_name : str
The name of the subset to load
use_augmentation : bool, optional
            Whether to use augmentation, by default False

        Returns
-------
EMGZarrDataset
The dataset for the specified subset
"""
if use_augmentation:
# For training, use the full augmentation pipeline
return EMGZarrDataset(
zarr_file=self.data_path,
subset_name=subset_name,
target_name=self.target_name,
emg_dtype=self.emg_dtype,
target_dtype=self.target_dtype,
sampling_frequency=self.sampling_frequency,
input_data_class=self.input_data_class,
target_data_class=self.target_data_class,
input_augmentation_pipeline=self.input_augmentation_pipeline,
input_augmentation_probabilities=self.input_augmentation_probabilities,
target_augmentation_pipeline=self.target_augmentation_pipeline,
target_augmentation_probabilities=self.target_augmentation_probabilities,
cache_size=self.cache_size,
)
else:
# For validation/testing, use IdentityFilter to ensure consistent shape without augmentation
validation_target_pipeline = [
[IdentityFilter(is_output=True, input_is_chunked=True)]
]
return EMGZarrDataset(
zarr_file=self.data_path,
subset_name=subset_name,
target_name=self.target_name,
emg_dtype=self.emg_dtype,
target_dtype=self.target_dtype,
sampling_frequency=self.sampling_frequency,
input_data_class=self.input_data_class,
target_data_class=self.target_data_class,
# No input augmentation for validation
input_augmentation_pipeline=None,
input_augmentation_probabilities=None,
# Use identity filter for validation to ensure consistent shape without augmentation
target_augmentation_pipeline=validation_target_pipeline,
target_augmentation_probabilities=[1.0],
cache_size=self.cache_size,
)
def train_dataloader(self) -> DataLoader:
"""Returns the training set as a DataLoader.
Returns
-------
DataLoader
The training set
"""
return DataLoader(
self._create_dataset("training", use_augmentation=True),
shuffle=self.shuffle_train,
**self.dataloader_params.copy(),
)
def test_dataloader(self) -> DataLoader:
"""Returns the testing set as a DataLoader.
Returns
-------
DataLoader
The testing set
"""
return DataLoader(
self._create_dataset("testing", use_augmentation=False),
shuffle=False,
**self.dataloader_params.copy(),
)
def val_dataloader(self) -> DataLoader:
"""Returns the validation set as a DataLoader.
Returns
-------
DataLoader
The validation set
"""
# Create a copy of dataloader parameters to avoid side effects
dataloader_params = self.dataloader_params.copy()
        # Always set drop_last to False for validation so every sample is evaluated
        dataloader_params["drop_last"] = False
return DataLoader(
self._create_dataset("validation", use_augmentation=False),
shuffle=False,
**dataloader_params,
)
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "EMGDatasetLoader":
"""Create an EMGDatasetLoader from a configuration dictionary.
Parameters
----------
config : Dict[str, Any]
            Configuration dictionary

        Returns
-------
EMGDatasetLoader
The dataset loader created from the configuration
"""
return cls(**config)