"""Modality configuration for dataset creation.
This module provides the Modality dataclass for configuring data sources
when creating datasets with DatasetCreator.
Examples
--------
>>> from myoverse.datasets import Modality
>>> from myoverse.transforms import Compose, Flatten, Index
>>>
>>> emg = Modality(
... path="emg.pkl",
... dims=("channel", "time"),
... )
>>>
>>> # With preprocessing transform
>>> kinematics = Modality(
... path="kinematics.pkl",
... dims=("dof", "time"),
... transform=Compose([
... Flatten(0, 1), # (21, 3, time) -> (63, time)
... Index(slice(3, None), dim="channel"), # Remove wrist
... ]),
... )
"""
from __future__ import annotations
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
if TYPE_CHECKING:
import torch
@dataclass
class Modality:
    """Configuration for a data modality.

    A modality represents a single data stream (EMG, kinematics, EEG, etc.)
    with its dimensionality and optional preprocessing.

    Parameters
    ----------
    data : np.ndarray | dict[str, np.ndarray] | None
        Data array or dict of arrays per task. A bare array is exposed
        under the task name ``"default"``.
    path : Path | str | None
        Path to pickle file containing data (a bare array, which is exposed
        as task ``"default"``, or a dict of arrays per task). Only read when
        ``data`` is None. Unpickling can execute arbitrary code, so only
        point this at trusted files.
    dims : tuple[str, ...]
        Dimension names (last must be 'time'). Any sequence is accepted and
        stored as a tuple.
    attrs : dict
        Optional attributes to store with the data.
    transform : Callable | None
        Transform to apply after loading. It receives one float32
        ``torch.Tensor`` per task and should return a tensor (an array-like
        result is also accepted); the result is converted back to NumPy.

    Raises
    ------
    ValueError
        If ``dims`` is empty or does not end with ``"time"``, or if neither
        ``data`` nor an existing ``path`` is provided.

    Examples
    --------
    >>> emg = Modality(
    ...     path="emg.pkl",
    ...     dims=("channel", "time"),
    ... )
    >>> # With preprocessing transform
    >>> from myoverse.transforms import Compose, Flatten, Index
    >>> kinematics = Modality(
    ...     path="kinematics.pkl",
    ...     dims=("dof", "time"),
    ...     transform=Compose([
    ...         Flatten(0, 1),  # (21, 3, time) -> (63, time)
    ...         Index(slice(3, None), dim="channel"),  # Remove wrist -> (60, time)
    ...     ]),
    ... )
    """

    data: np.ndarray | dict[str, np.ndarray] | None = None
    path: Path | str | None = None
    dims: tuple[str, ...] = ("channel", "time")
    attrs: dict[str, Any] = field(default_factory=dict)
    transform: Any = None

    def __post_init__(self) -> None:
        """Normalise ``path``/``dims`` and validate the configuration."""
        if self.path is not None:
            self.path = Path(self.path)
        # Accept lists or other sequences, but store the annotated tuple type.
        self.dims = tuple(self.dims)
        # Check emptiness first: ``dims[-1]`` would otherwise raise IndexError.
        if not self.dims or self.dims[-1] != "time":
            raise ValueError(f"Last dimension must be 'time', got {self.dims}")
        if self.data is None and (self.path is None or not self.path.exists()):
            raise ValueError("Must provide data or valid path")

    def load(self) -> dict[str, np.ndarray]:
        """Load data from path or return data dict, applying transform if set.

        Returns
        -------
        dict[str, np.ndarray]
            Dict mapping task names to data arrays. For array or dict inputs
            a fresh dict is returned, so mutating it does not alter
            ``self.data``. When ``transform`` is set, each array is passed to
            it as a float32 tensor and the result is converted back to NumPy.
        """
        data = self._read_raw()
        if self.transform is None:
            return data
        return {task: self._apply_transform(arr) for task, arr in data.items()}

    def _read_raw(self) -> dict[str, np.ndarray]:
        """Return the untransformed per-task data from ``data`` or ``path``."""
        if self.data is not None:
            raw = self.data
        else:
            # NOTE(security): pickle.load runs code embedded in the file;
            # ``path`` must point at trusted data only.
            with open(self.path, "rb") as f:
                raw = pickle.load(f)
        # A bare array, whether passed directly or pickled, is a single task.
        if isinstance(raw, np.ndarray):
            return {"default": raw}
        if isinstance(raw, dict):
            # Shallow copy so callers mutating the result cannot alter self.data.
            return dict(raw)
        # Any other pickled object is passed through unchanged, as before.
        return raw

    def _apply_transform(self, arr: Any) -> np.ndarray:
        """Run ``self.transform`` on one array via a float32 tensor round-trip."""
        # Lazy import: torch is only required when a transform is configured.
        import torch

        # np.array copies by default, so an in-place transform can never
        # mutate the caller's source array.
        tensor = torch.from_numpy(np.array(arr, dtype=np.float32))
        result = self.transform(tensor)
        if not isinstance(result, torch.Tensor):
            return np.asarray(result)
        # Strip dimension names if any are set; ``names`` is () for 0-d tensors.
        if any(name is not None for name in result.names):
            result = result.rename(None)
        # detach/cpu so outputs that track gradients or live on a GPU convert.
        return result.detach().cpu().numpy()