"""PyTorch Lightning DataModule for MyoVerse datasets.
This module provides the DataModule class that integrates datasets
with PyTorch Lightning's training loop.
Example:
-------
>>> from myoverse.datasets import DataModule
>>> from myoverse.transforms import Compose, ZScore, RMS
>>>
>>> dm = DataModule(
... "data.zip",
... inputs=["emg"],
... targets=["kinematics"],
... window_size=200,
... n_windows_per_epoch=10000,
... device="cuda",
... train_transform=Compose([ZScore(), RMS(50)]),
... )
>>> dm.setup("fit")
>>> train_loader = dm.train_dataloader()
"""
from __future__ import annotations

from collections.abc import Callable, Sequence
from pathlib import Path

import lightning as L
import numpy as np
import torch
from torch.utils.data import DataLoader

from myoverse.datasets.paradigms import SupervisedDataset

def _stack_modalities(
samples: list[dict[str, torch.Tensor | np.ndarray]],
) -> dict[str, torch.Tensor | np.ndarray]:
"""Stack samples for each modality, stripping named tensor names."""
first = next(iter(samples[0].values()))
is_numpy = isinstance(first, np.ndarray)
result = {}
for key in samples[0]:
items = [s[key] for s in samples]
if is_numpy:
result[key] = np.stack(items)
else:
# Strip named tensor names - models don't support them
                items_unnamed = [
                    t.rename(None) if any(n is not None for n in t.names) else t
                    for t in items
                ]
result[key] = torch.stack(items_unnamed)
return result
def collate_supervised(
batch: list[tuple[dict, dict]],
) -> tuple[
    dict[str, torch.Tensor | np.ndarray] | torch.Tensor | np.ndarray,
    dict[str, torch.Tensor | np.ndarray] | torch.Tensor | np.ndarray,
]:
"""Collate function for supervised datasets.
Handles both numpy arrays and tensors.
Strips named tensor names since most models don't support them.
Parameters
----------
batch : list[tuple[dict, dict]]
        List of (inputs, targets) tuples from the dataset.

    Returns
    -------
tuple
Batched (inputs, targets). If single modality, returns tensor directly.
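
    Examples
    --------
    A minimal sketch with two hypothetical single-modality samples; with a
    single input and target modality the batched tensors are returned directly:

    >>> import torch
    >>> batch = [
    ...     ({"emg": torch.zeros(8, 200)}, {"kinematics": torch.zeros(20)}),
    ...     ({"emg": torch.ones(8, 200)}, {"kinematics": torch.ones(20)}),
    ... ]
    >>> x, y = collate_supervised(batch)
    >>> x.shape, y.shape
    (torch.Size([2, 8, 200]), torch.Size([2, 20]))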
"""
inputs_list = [b[0] for b in batch]
targets_list = [b[1] for b in batch]
inputs = _stack_modalities(inputs_list)
targets = _stack_modalities(targets_list)
# Return directly if single input/target
if len(inputs) == 1 and len(targets) == 1:
return next(iter(inputs.values())), next(iter(targets.values()))
return inputs, targets
class DataModule(L.LightningDataModule):
"""Lightning DataModule for supervised learning.
Wraps SupervisedDataset instances for train/val/test splits
and provides DataLoaders.
Parameters
----------
data_path : Path | str
Path to the Zarr dataset.
inputs : Sequence[str]
Modality names to use as model inputs.
targets : Sequence[str]
Modality names to use as model targets.
batch_size : int
Batch size for all dataloaders.
window_size : int
Window size in samples.
window_stride : int | None
Window stride for validation/test.
n_windows_per_epoch : int | None
Number of random windows per training epoch.
num_workers : int
Number of dataloader workers.
train_transform : Callable | None
Transform for training inputs.
val_transform : Callable | None
Transform for validation inputs.
test_transform : Callable | None
Transform for test inputs.
target_transform : Callable | None
Transform for targets.
    pin_memory : bool
        Pin memory for faster GPU transfer (only applied when device is None).
persistent_workers : bool
Keep workers alive between epochs.
device : torch.device | str | None
Output device ('cpu', 'cuda', or None for numpy).
dtype : torch.dtype
Data type for tensors.
cache_in_ram : bool
        Cache entire split in RAM.

    Examples
    --------
>>> dm = DataModule(
... "data.zip",
... inputs=["emg"],
... targets=["kinematics"],
... window_size=200,
... n_windows_per_epoch=10000,
... device="cuda",
... )
>>> dm.setup("fit")
>>> for inputs, targets in dm.train_dataloader():
... # inputs: Tensor of shape (batch, channels, time)
... # targets: Tensor of shape (batch, joints)
... pass
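
    The DataModule can also be passed straight to a Lightning Trainer
    (a sketch, assuming model is a compatible LightningModule):

    >>> trainer = L.Trainer(max_epochs=10)
    >>> trainer.fit(model, datamodule=dm)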
"""
def __init__(
self,
data_path: Path | str,
inputs: Sequence[str] = ("emg",),
targets: Sequence[str] = ("kinematics",),
batch_size: int = 32,
window_size: int = 200,
window_stride: int | None = None,
n_windows_per_epoch: int | None = None,
num_workers: int = 4,
train_transform: Callable | None = None,
val_transform: Callable | None = None,
test_transform: Callable | None = None,
target_transform: Callable | None = None,
pin_memory: bool = True,
persistent_workers: bool = True,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.float32,
cache_in_ram: bool = True,
):
super().__init__()
self.data_path = Path(data_path)
self.inputs = list(inputs)
self.targets = list(targets)
self.batch_size = batch_size
self.window_size = window_size
self.window_stride = window_stride
self.n_windows_per_epoch = n_windows_per_epoch
self.num_workers = num_workers
self.train_transform = train_transform
self.val_transform = val_transform
self.test_transform = test_transform or val_transform
self.target_transform = target_transform
self.pin_memory = pin_memory
self.persistent_workers = persistent_workers and num_workers > 0
self.device = device
self.dtype = dtype
self.cache_in_ram = cache_in_ram
if not self.data_path.exists():
raise FileNotFoundError(f"Dataset not found: {self.data_path}")
        if n_windows_per_epoch is None and window_stride is None:
            raise ValueError(
                "Either n_windows_per_epoch or window_stride must be provided"
            )
self.train_dataset = None
self.val_dataset = None
self.test_dataset = None
def setup(self, stage: str | None = None) -> None:
"""Setup datasets for each stage."""
if stage == "fit" or stage is None:
self.train_dataset = SupervisedDataset(
self.data_path,
split="training",
inputs=self.inputs,
targets=self.targets,
window_size=self.window_size,
n_windows=self.n_windows_per_epoch,
transform=self.train_transform,
target_transform=self.target_transform,
device=self.device,
dtype=self.dtype,
cache_in_ram=self.cache_in_ram,
)
self.val_dataset = SupervisedDataset(
self.data_path,
split="validation",
inputs=self.inputs,
targets=self.targets,
window_size=self.window_size,
window_stride=self.window_stride or self.window_size,
transform=self.val_transform,
target_transform=self.target_transform,
device=self.device,
dtype=self.dtype,
cache_in_ram=self.cache_in_ram,
)
# Pre-load cache in main process before workers are spawned
# (avoids zarr ZipStore concurrency issues in multiprocessing)
if self.cache_in_ram and self.num_workers > 0:
_ = self.train_dataset[0]
_ = self.val_dataset[0]
if stage == "test" or stage is None:
self.test_dataset = SupervisedDataset(
self.data_path,
split="testing",
inputs=self.inputs,
targets=self.targets,
window_size=self.window_size,
window_stride=self.window_stride or self.window_size,
transform=self.test_transform,
target_transform=self.target_transform,
device=self.device,
dtype=self.dtype,
cache_in_ram=self.cache_in_ram,
)
# Pre-load cache in main process before workers are spawned
if self.cache_in_ram and self.num_workers > 0:
_ = self.test_dataset[0]
def train_dataloader(self) -> DataLoader:
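        """Return the DataLoader for the training split.

        Shuffling is left disabled because the training dataset already
        draws random windows each epoch (n_windows_per_epoch).
        """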
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=self.pin_memory and self.device is None,
persistent_workers=self.persistent_workers,
collate_fn=collate_supervised,
)
def val_dataloader(self) -> DataLoader:
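        """Return the DataLoader for the validation split."""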
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=self.pin_memory and self.device is None,
persistent_workers=self.persistent_workers,
collate_fn=collate_supervised,
)
def test_dataloader(self) -> DataLoader:
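        """Return the DataLoader for the test split."""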
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=self.pin_memory and self.device is None,
persistent_workers=self.persistent_workers,
collate_fn=collate_supervised,
)