WindowedDataset
- class myoverse.datasets.base.WindowedDataset(zarr_path, split='training', modalities=None, window_size=200, window_stride=None, n_windows=None, seed=None, device=None, dtype=torch.float32, cache_in_ram=True)
Base dataset that loads windows of any modality from a Zarr store.
This is the infrastructure layer: it handles loading, windowing, caching, and device management, and returns all requested modalities as a dict keyed by modality name.
Subclasses implement paradigm-specific logic (e.g., SupervisedDataset splits into inputs/targets, ContrastiveDataset creates augmented views).
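The split between the infrastructure layer and the paradigm layer can be illustrated with a toy, in-memory analogue. `ToyWindowedBase` and `ToySupervised` below are hypothetical stand-ins, not myoverse classes: the base slices the same time span out of every modality and returns a dict, and the subclass only regroups that dict into (inputs, targets). The real `SupervisedDataset` reads from Zarr, and its constructor arguments may differ from this sketch.

```python
from typing import Dict, List, Sequence, Tuple

# Channels x time, as nested lists so the sketch needs no third-party packages.
Window = List[List[float]]


class ToyWindowedBase:
    """Hypothetical stand-in for the infrastructure layer: returns a dict of windows."""

    def __init__(self, signals: Dict[str, Window], window_size: int, window_stride: int) -> None:
        self.signals = signals
        self.window_size = window_size
        self.window_stride = window_stride
        # Assumes every modality has the same number of samples.
        n_samples = len(next(iter(signals.values()))[0])
        self.n_windows = (n_samples - window_size) // window_stride + 1

    def __len__(self) -> int:
        return self.n_windows

    def __getitem__(self, idx: int) -> Dict[str, Window]:
        if not 0 <= idx < self.n_windows:
            raise IndexError(idx)
        start = idx * self.window_stride
        stop = start + self.window_size
        # Slice the same time span out of every modality.
        return {name: [ch[start:stop] for ch in sig] for name, sig in self.signals.items()}


class ToySupervised(ToyWindowedBase):
    """Hypothetical paradigm layer: regroups the base dict into (inputs, targets)."""

    def __init__(self, signals: Dict[str, Window], window_size: int, window_stride: int,
                 inputs: Sequence[str], targets: Sequence[str]) -> None:
        super().__init__(signals, window_size, window_stride)
        self.inputs = list(inputs)
        self.targets = list(targets)

    def __getitem__(self, idx: int) -> Tuple[Dict[str, Window], Dict[str, Window]]:
        data = super().__getitem__(idx)  # windowing is inherited unchanged
        return {k: data[k] for k in self.inputs}, {k: data[k] for k in self.targets}
```

Keeping windowing in the base class means every paradigm slices the data identically; a subclass only decides how to package the dict it receives.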
- Parameters:
zarr_path (Path | str) – Path to the Zarr dataset.
split (str) – Dataset split (‘training’, ‘validation’, ‘testing’).
modalities (Sequence[str] | None) – Modality names to load. If None, loads all available modalities.
window_size (int) – Number of samples per window.
window_stride (int | None) – Stride between consecutive window starts, in samples. If None, window positions are sampled randomly.
n_windows (int | None) – Number of windows per epoch. Required if window_stride is None.
seed (int | None) – Random seed for reproducible window positions.
device (torch.device | str | None) – Output device:
- None: return NumPy arrays
- “cpu”: return tensors on the CPU
- “cuda”: return tensors on the GPU (uses kvikio for GPUDirect Storage (GDS) if available)
dtype (torch.dtype) – Data type for tensors. Default: torch.float32.
cache_in_ram (bool) – Cache the entire split in RAM for faster access. Default: True.
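How window_size, window_stride, n_windows, and seed interact can be sketched as follows. `ToyWindowSampler` is a hypothetical helper, not part of myoverse. It assumes that windows never cross recording boundaries, that strided windows are counted recording by recording, and that random mode draws start positions uniformly from one global index spanning all recordings (the role the Methods list gives `_global_to_local`). The library's actual sampling may differ in these details.

```python
import bisect
import random
from itertools import accumulate
from typing import List, Optional, Tuple


class ToyWindowSampler:
    """Hypothetical sketch of window placement over several recordings."""

    def __init__(
        self,
        recording_lengths: List[int],
        window_size: int,
        window_stride: Optional[int] = None,
        n_windows: Optional[int] = None,
        seed: Optional[int] = None,
    ) -> None:
        if window_stride is None and n_windows is None:
            raise ValueError("n_windows is required when window_stride is None")
        self.window_size = window_size
        self.window_stride = window_stride
        # Valid start positions per recording: a window must fit entirely inside it.
        self.valid = [max(0, n - window_size + 1) for n in recording_lengths]
        # offsets[i] is the first global start index that belongs to recording i.
        self.offsets = [0, *accumulate(self.valid)]
        if window_stride is not None:
            # Strided windows per recording and their cumulative counts.
            counts = [(v - 1) // window_stride + 1 if v > 0 else 0 for v in self.valid]
            self.count_offsets = [0, *accumulate(counts)]
            self.n_windows = self.count_offsets[-1]
        else:
            self.n_windows = n_windows
        self.rng = random.Random(seed)

    def __len__(self) -> int:
        return self.n_windows

    def global_to_local(self, global_pos: int) -> Tuple[int, int]:
        """Map a global start index to (recording_idx, local_start)."""
        # bisect_right skips recordings that are too short to hold a window.
        rec = bisect.bisect_right(self.offsets, global_pos) - 1
        return rec, global_pos - self.offsets[rec]

    def position(self, idx: int) -> Tuple[int, int]:
        """Return (recording_idx, local_start) for window idx."""
        if not 0 <= idx < self.n_windows:
            raise IndexError(idx)
        if self.window_stride is not None:
            # Deterministic: the idx-th strided window, counted across recordings.
            rec = bisect.bisect_right(self.count_offsets, idx) - 1
            return rec, (idx - self.count_offsets[rec]) * self.window_stride
        # Random: draw a start uniformly from all valid positions of all recordings.
        return self.global_to_local(self.rng.randrange(self.offsets[-1]))

    def reseed(self, seed: Optional[int] = None) -> None:
        """Restart the random stream, e.g. to repeat an epoch's positions."""
        self.rng = random.Random(seed)
```

With a stride, the length is fixed by the data; in random mode it is simply n_windows, and reusing the same seed reproduces the same sequence of draws.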
Examples
>>> # Return numpy arrays
>>> ds = WindowedDataset("data.zip", modalities=["emg"], device=None)
>>> data = ds[0]
>>> type(data["emg"])  # numpy.ndarray
>>>
>>> # Return tensors on GPU with named dimensions
>>> ds = WindowedDataset("data.zip", modalities=["emg"], device="cuda")
>>> data = ds[0]
>>> data["emg"].device  # cuda:0
>>> data["emg"].names  # ('channel', 'time')
Methods
__getitem__(idx) – Load windows for all modalities.
Prepare state for pickling (used by multiprocessing workers).
__init__(zarr_path[, split, modalities, ...])
__len__()
__setstate__(state) – Restore state after unpickling (in worker processes).
Load data into the RAM cache if caching is enabled and not yet loaded.
Get the deterministic window position for a given index.
_get_dim_names(modality) – Get dimension names for a modality from metadata.
_get_task_for_recording(rec_idx) – Get the task name for a recording index.
_global_to_local(global_pos) – Convert a global position to (recording_idx, local_position).
_load_window(var_path, local_pos, modality) – Load a window for a variable and convert it to a tensor.
Sample a random valid window position.
Set up valid sampling ranges for each recording.
_to_tensor(data) – Convert data to a tensor on the target device.
get_sample_shape(modality) – Get the shape of a sample for a given modality (without the time dimension).
reseed([seed]) – Reseed the random number generator.
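When multiprocessing workers are started (for example by a PyTorch DataLoader using the spawn start method), the dataset object is pickled and rebuilt in each worker, which is why the Methods list includes pickling support and __setstate__. The sketch below assumes a common pattern: drop unpicklable handles and the RAM cache in the standard Python hook __getstate__, reopen the handle in __setstate__, and refill the cache lazily on first access. `ToyCachedDataset` is hypothetical; a `threading.Lock` stands in for an open store handle, and whether myoverse actually drops its cache during pickling is not stated in this reference.

```python
import pickle
import threading
from typing import Dict, List, Optional


class ToyCachedDataset:
    """Hypothetical sketch: lazy RAM cache plus worker-safe pickling."""

    def __init__(self, source: Dict[str, List[float]], cache_in_ram: bool = True) -> None:
        self.source = source             # stands in for the on-disk Zarr store
        self.cache_in_ram = cache_in_ram
        self._cache: Optional[Dict[str, List[float]]] = None
        self._handle = threading.Lock()  # stands in for an unpicklable open handle
        self.loads = 0                   # counts cache fills, for illustration

    def _ensure_loaded(self) -> None:
        # Fill the cache once, on first access, and only if caching is enabled.
        if self.cache_in_ram and self._cache is None:
            self._cache = {k: list(v) for k, v in self.source.items()}
            self.loads += 1

    def __getitem__(self, key: str) -> List[float]:
        self._ensure_loaded()
        return (self._cache if self._cache is not None else self.source)[key]

    def __getstate__(self) -> dict:
        # Copy the attributes, then drop the handle and the (possibly large) cache
        # so the pickled payload stays small and picklable.
        state = self.__dict__.copy()
        state["_handle"] = None
        state["_cache"] = None
        return state

    def __setstate__(self, state: dict) -> None:
        # Restore attributes and reopen the handle inside the worker.
        self.__dict__.update(state)
        self._handle = threading.Lock()


# Simulate handing the dataset to a worker process.
ds = ToyCachedDataset({"emg": [1.0, 2.0, 3.0]})
first = ds["emg"]                              # fills the cache in the parent
worker_copy = pickle.loads(pickle.dumps(ds))   # cache and handle dropped in transit
```

Dropping the cache keeps the pickled payload small at the cost of each worker loading its own copy; keeping it would instead ship the whole cached split to every worker.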