"""Base dataset for windowed multi-modal data loading.
This module provides the paradigm-agnostic infrastructure for loading
windowed data from zarr stores. It handles:
- Zarr I/O with optional GPU Direct Storage (GDS)
- Window sampling (random or deterministic)
- RAM caching for performance
- Device management (CPU/GPU)
- Multiprocessing support
The WindowedDataset class returns all modalities as a dict, without
making assumptions about the learning paradigm (supervised, contrastive, etc.).
Paradigm-specific datasets should subclass WindowedDataset.
Example:
-------
>>> from myoverse.datasets.base import WindowedDataset
>>>
>>> # Load all modalities from zarr
>>> ds = WindowedDataset(
... "data.zip",
... split="training",
... modalities=["emg", "kinematics"],
... window_size=200,
... n_windows=10000,
... device="cuda",
... )
>>> data = ds[0] # dict[str, Tensor] with 'emg' and 'kinematics'
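>>>
>>> # Because __getstate__/__setstate__ handle pickling, the dataset also
>>> # works with multiple DataLoader workers. A minimal sketch; the loader
>>> # settings here are illustrative:
>>> from torch.utils.data import DataLoader
>>> loader = DataLoader(ds, batch_size=32, num_workers=4)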
"""
from __future__ import annotations
import warnings
from collections.abc import Sequence
from pathlib import Path
import numpy as np
import torch
import zarr
from torch.utils.data import Dataset
from zarr.storage import ZipStore
# Suppress named tensor experimental warning
warnings.filterwarnings("ignore", category=UserWarning, message=".*Named tensors.*")
class WindowedDataset(Dataset):
"""Base dataset that loads windows from zarr for any modality.
This is the infrastructure layer - it handles loading, windowing, caching,
and device management. It returns ALL requested modalities as a dict.
Subclasses implement paradigm-specific logic (e.g., SupervisedDataset
splits into inputs/targets, ContrastiveDataset creates augmented views).
Parameters
----------
zarr_path : Path | str
Path to the Zarr dataset.
split : str
Dataset split ('training', 'validation', 'testing').
modalities : Sequence[str] | None
Modality names to load. If None, loads all available modalities.
window_size : int
Number of samples per window.
window_stride : int | None
Stride between windows. If None, uses random positions.
n_windows : int | None
Number of windows per epoch. Required if window_stride is None.
seed : int | None
Random seed for reproducible window positions.
device : torch.device | str | None
Output device:
- None: return numpy arrays
- "cpu": return tensors on CPU
- "cuda": return tensors on GPU (uses kvikio GDS if available)
dtype : torch.dtype
Data type for tensors. Default: torch.float32.
cache_in_ram : bool
Cache entire split in RAM for faster access. Default: True.
Examples
--------
>>> # Return numpy arrays
>>> ds = WindowedDataset("data.zip", modalities=["emg"], device=None)
>>> data = ds[0]
>>> type(data["emg"]) # numpy.ndarray
>>>
>>> # Return tensors on GPU with named dimensions
>>> ds = WindowedDataset("data.zip", modalities=["emg"], device="cuda")
>>> data["emg"].device # cuda:0
>>> data["emg"].names # ('channel', 'time')
"""
def __init__(
self,
zarr_path: Path | str,
split: str = "training",
modalities: Sequence[str] | None = None,
window_size: int = 200,
window_stride: int | None = None,
n_windows: int | None = None,
seed: int | None = None,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.float32,
cache_in_ram: bool = True,
):
self.zarr_path = Path(zarr_path)
self.split = split
self.window_size = window_size
self.window_stride = window_stride
self.seed = seed
self._rng = np.random.default_rng(seed)
self.device = torch.device(device) if device else None
self.cache_in_ram = cache_in_ram
self.dtype = dtype
# Validate path
if not self.zarr_path.exists():
raise FileNotFoundError(f"Dataset not found: {self.zarr_path}")
# Reject directory-based .zarr (no longer supported)
if self.zarr_path.is_dir() or self.zarr_path.suffix.lower() == ".zarr":
raise ValueError(
f"Directory-based .zarr is no longer supported: {self.zarr_path}\n"
f"Please recreate the dataset using DatasetCreator with a .zip path."
)
if window_stride is None and n_windows is None:
raise ValueError("Must specify n_windows when window_stride is None")
# Open zarr store from zip file
self._zip_store = ZipStore(self.zarr_path, mode="r")
self._store = zarr.open(self._zip_store, mode="r")
# Get metadata from standard zarr attrs
self._available_modalities = self._store.attrs.get("modalities", [])
self._tasks = self._store.attrs.get("tasks", [])
self._dims_info = self._store.attrs.get("dims", {})
# Get split group
if split not in self._store:
raise FileNotFoundError(f"Split '{split}' not found in {self.zarr_path}")
self._split_group = self._store[split]
# Determine modalities to load
if modalities is None:
self.modalities = list(self._available_modalities)
else:
self.modalities = list(modalities)
# Validate modalities exist
missing = set(self.modalities) - set(self._available_modalities)
if missing:
raise ValueError(
f"Requested modalities {missing} not in dataset. "
f"Available: {self._available_modalities}",
)
# RAM cache is loaded lazily on first data access
self._ram_cache = None
self._cache_loaded = False
# Build task lists for each modality (nested structure: split/modality/task)
self._modality_tasks: dict[str, list[str]] = {mod: [] for mod in self.modalities}
for mod in self.modalities:
if mod in self._split_group:
mod_group = self._split_group[mod]
self._modality_tasks[mod] = sorted(mod_group.keys())
# For compatibility, also store full paths
self._modality_vars: dict[str, list[str]] = {
mod: [f"{mod}/{task}" for task in tasks]
for mod, tasks in self._modality_tasks.items()
}
# Get recording lengths from first modality (nested structure)
first_mod = self.modalities[0]
self._recording_lengths = []
self._recording_tasks = [] # Store task names directly
mod_group = self._split_group[first_mod]
for task in self._modality_tasks[first_mod]:
arr = mod_group[task]
length = arr.shape[-1] # Time is last dimension
self._recording_lengths.append(length)
self._recording_tasks.append(task)
self._total_length = sum(self._recording_lengths)
# Compute number of windows
if window_stride is not None:
self._n_windows = sum(
max(0, (length - window_size) // window_stride + 1)
for length in self._recording_lengths
)
self._random_mode = False
else:
self._n_windows = n_windows
self._random_mode = True
self._setup_recording_ranges()
    def __getstate__(self):
        """Prepare state for pickling (used by multiprocessing workers)."""
        state = self.__dict__.copy()
        # Drop open zarr handles and the RNG; __setstate__ recreates them.
        state["_zip_store"] = None
        state["_store"] = None
        state["_split_group"] = None
        state["_rng"] = None
        return state
def __setstate__(self, state):
"""Restore state after unpickling (in worker processes)."""
try:
self.__dict__.update(state)
self._rng = np.random.default_rng(self.seed)
            # Best-effort reset of zarr's global config in the worker process
            try:
                zarr.config.reset()
            except Exception:
                pass
# If cache already loaded, no need to reopen store
if self._ram_cache is not None and self._cache_loaded:
return
# Reopen store for lazy cache loading or direct reads
self._zip_store = ZipStore(self.zarr_path, mode="r")
self._store = zarr.open(self._zip_store, mode="r")
self._split_group = self._store[self.split]
except Exception as e:
import sys
print(f"ERROR in __setstate__: {e}", file=sys.stderr)
raise
def get_sample_shape(self, modality: str) -> tuple[int, ...]:
"""Get the shape of a sample for a given modality (without time dimension).
Parameters
----------
modality : str
Modality name.
Returns
-------
tuple[int, ...]
Shape without time dimension.
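
        Examples
        --------
        >>> # A sketch with illustrative shapes: an EMG task stored as a
        >>> # (64, T) array yields
        >>> ds.get_sample_shape("emg")
        (64,)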
"""
task_list = self._modality_tasks.get(modality)
if not task_list:
raise ValueError(f"Modality '{modality}' not found")
first_task = task_list[0]
arr = self._split_group[modality][first_task]
return arr.shape[:-1]
def _setup_recording_ranges(self) -> None:
"""Setup valid sampling ranges for each recording."""
self._valid_ranges = []
cumsum = 0
for rec_idx, length in enumerate(self._recording_lengths):
if length >= self.window_size:
valid_start = cumsum
valid_end = cumsum + length - self.window_size
self._valid_ranges.append((rec_idx, valid_start, valid_end))
cumsum += length
if not self._valid_ranges:
raise ValueError(
f"No recordings long enough for window_size={self.window_size}",
)
self._total_valid = sum(end - start + 1 for _, start, end in self._valid_ranges)
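        # Worked example (illustrative numbers): with window_size=200 and
        # recording lengths [150, 500], the first recording is skipped (too
        # short) and the second contributes global starts 150..450, so
        # _total_valid == 450 - 150 + 1 == 301.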
def _global_to_local(self, global_pos: int) -> tuple[int, int]:
"""Convert global position to (recording_idx, local_position)."""
cumsum = 0
for rec_idx, length in enumerate(self._recording_lengths):
if global_pos < cumsum + length:
return rec_idx, global_pos - cumsum
cumsum += length
raise ValueError(f"Position {global_pos} out of range")
def _sample_random_position(self) -> tuple[int, int]:
"""Sample a random valid window position."""
pos = self._rng.integers(0, self._total_valid)
cumsum = 0
for rec_idx, start, end in self._valid_ranges:
range_size = end - start + 1
if pos < cumsum + range_size:
global_pos = start + (pos - cumsum)
return self._global_to_local(global_pos)
cumsum += range_size
raise RuntimeError(
f"Failed to map random position {pos} to valid range "
f"(total_valid={self._total_valid}, n_ranges={len(self._valid_ranges)})",
)
def _get_deterministic_position(self, idx: int) -> tuple[int, int]:
"""Get deterministic window position for given index."""
cumsum = 0
for rec_idx, length in enumerate(self._recording_lengths):
valid_positions = max(
0,
(length - self.window_size) // self.window_stride + 1,
)
if idx < cumsum + valid_positions:
local_idx = idx - cumsum
local_pos = local_idx * self.window_stride
return rec_idx, local_pos
cumsum += valid_positions
raise ValueError(f"Index {idx} out of range")
[docs]
def _get_task_for_recording(self, rec_idx: int) -> str:
"""Get the task name for a recording index."""
return self._recording_tasks[rec_idx]
# Default dimension names by modality (fallback when not in metadata)
_DEFAULT_DIMS: dict[str, tuple[str, ...]] = {
"emg": ("channel", "time"),
"kinematics": ("joint", "time"),
"eeg": ("electrode", "time"),
}
def _get_dim_names(self, modality: str) -> tuple[str, ...]:
"""Get dimension names for a modality from metadata."""
if modality in self._dims_info:
return tuple(self._dims_info[modality])
return self._DEFAULT_DIMS.get(modality, ("channel", "time"))
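
    # Downstream consumers can rely on these names via PyTorch's named-tensor
    # API, e.g. (a sketch):
    #     window = ds[0]["emg"]                        # names ('channel', 'time')
    #     window = window.align_to("time", "channel")  # reorder by name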
def _ensure_cache_loaded(self) -> None:
"""Load data into RAM cache if caching is enabled and not yet loaded."""
if not self.cache_in_ram or self._cache_loaded:
return
        # Raise zarr's async read concurrency for the one-off bulk load into RAM
        zarr.config.set({"async.concurrency": 32})
self._ram_cache = {}
# Load nested structure: modality -> task -> array
for mod in self.modalities:
if mod not in self._split_group:
continue
mod_group = self._split_group[mod]
for task in mod_group.keys():
cache_key = f"{mod}/{task}"
self._ram_cache[cache_key] = np.asarray(mod_group[task][:])
self._cache_loaded = True
def _to_tensor(self, data) -> torch.Tensor:
"""Convert data to tensor on target device."""
tensor = torch.from_numpy(np.ascontiguousarray(data))
if self.device is not None:
return tensor.to(device=self.device, dtype=self.dtype)
return tensor.to(dtype=self.dtype)
def _load_window(
self, var_path: str, local_pos: int, modality: str
) -> torch.Tensor | np.ndarray:
"""Load a window for a variable and convert to tensor.
Parameters
----------
var_path : str
Variable path in zarr (e.g., "emg/task1").
local_pos : int
Starting position within the recording.
modality : str
Modality name for dimension info.
Returns
-------
torch.Tensor | np.ndarray
Window data as tensor (if device set) or numpy array.
"""
if self._ram_cache is not None:
arr = self._ram_cache[var_path]
else:
# Navigate nested structure: modality/task
mod, task = var_path.split("/", 1)
arr = self._split_group[mod][task]
# Validate window fits within recording
end_pos = local_pos + self.window_size
if end_pos > arr.shape[-1]:
raise ValueError(
f"Window [{local_pos}:{end_pos}] exceeds recording length {arr.shape[-1]} "
f"for variable {var_name}",
)
data = arr[..., local_pos:end_pos]
if self.device is None:
return np.ascontiguousarray(data)
tensor = self._to_tensor(data)
names = self._get_dim_names(modality)
tensor = tensor.rename(*names)
return tensor
def __len__(self) -> int:
return self._n_windows
def __getitem__(self, idx: int) -> dict[str, torch.Tensor | np.ndarray]:
"""Load windows for all modalities.
Parameters
----------
idx : int
Sample index.
Returns
-------
dict[str, torch.Tensor | np.ndarray]
Dict mapping modality names to data windows.
"""
# Lazy load cache on first access
self._ensure_cache_loaded()
# Get window position
if self._random_mode:
rec_idx, local_pos = self._sample_random_position()
else:
rec_idx, local_pos = self._get_deterministic_position(idx)
task = self._get_task_for_recording(rec_idx)
# Extract windows for all modalities (nested structure: modality/task)
data = {}
for mod in self.modalities:
cache_key = f"{mod}/{task}"
if self._ram_cache is not None and cache_key in self._ram_cache:
data[mod] = self._load_window(cache_key, local_pos, mod)
elif mod in self._split_group and task in self._split_group[mod]:
data[mod] = self._load_window(cache_key, local_pos, mod)
return data
def reseed(self, seed: int | None = None) -> None:
"""Reseed the random number generator."""
self._rng = np.random.default_rng(seed)
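
    # A minimal sketch of per-worker reseeding in random mode, so workers do
    # not draw identical windows after unpickling (assumes a standard torch
    # DataLoader; names here are illustrative):
    #
    #     def _worker_init(worker_id: int) -> None:
    #         info = torch.utils.data.get_worker_info()
    #         info.dataset.reseed(info.seed)
    #
    #     loader = DataLoader(ds, num_workers=4, worker_init_fn=_worker_init)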