WindowedDataset

class myoverse.datasets.base.WindowedDataset(zarr_path, split='training', modalities=None, window_size=200, window_stride=None, n_windows=None, seed=None, device=None, dtype=torch.float32, cache_in_ram=True)

Base dataset that loads windows from zarr for any modality.

This is the infrastructure layer: it handles loading, windowing, caching, and device management. It returns all requested modalities as a dict keyed by modality name.

Subclasses implement paradigm-specific logic (e.g., SupervisedDataset splits into inputs/targets, ContrastiveDataset creates augmented views).

Parameters:
  • zarr_path (Path | str) – Path to the Zarr dataset.

  • split (str) – Dataset split (‘training’, ‘validation’, ‘testing’).

  • modalities (Sequence[str] | None) – Modality names to load. If None, loads all available modalities.

  • window_size (int) – Number of samples per window.

  • window_stride (int | None) – Stride between windows. If None, uses random positions.

  • n_windows (int | None) – Number of windows per epoch. Required if window_stride is None.

  • seed (int | None) – Random seed for reproducible window positions.

  • device (torch.device | str | None) – Output device:

      - None: return numpy arrays
      - “cpu”: return tensors on CPU
      - “cuda”: return tensors on GPU (uses kvikio GDS if available)

  • dtype (torch.dtype) – Data type for tensors. Default: torch.float32.

  • cache_in_ram (bool) – Cache entire split in RAM for faster access. Default: True.
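The interaction between window_stride and n_windows can be illustrated with a small standalone sketch. It does not use myoverse: the helper names strided_positions and random_positions are hypothetical, a single recording of n_samples samples is assumed, and the actual class may choose positions differently (for example, per recording).

```python
import random


def strided_positions(n_samples, window_size, window_stride):
    """Start indices of evenly spaced windows that fit fully inside the data."""
    last_start = n_samples - window_size  # last index where a full window still fits
    if last_start < 0:
        return []
    return list(range(0, last_start + 1, window_stride))


def random_positions(n_samples, window_size, n_windows, seed=None):
    """Randomly drawn start indices; n_windows fixes how many make up one epoch."""
    last_start = n_samples - window_size
    if last_start < 0:
        raise ValueError("window_size is larger than the recording")
    rng = random.Random(seed)  # the same seed yields the same positions
    return [rng.randint(0, last_start) for _ in range(n_windows)]
```

With a stride, the number of windows follows from the data length. With random positions there is no natural epoch length, which is why n_windows has to be given when window_stride is None.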

Examples

>>> # Return numpy arrays
>>> ds = WindowedDataset("data.zip", modalities=["emg"], device=None)
>>> data = ds[0]
>>> type(data["emg"])  # numpy.ndarray
>>>
>>> # Return tensors on GPU with named dimensions
>>> ds = WindowedDataset("data.zip", modalities=["emg"], device="cuda")
>>> data = ds[0]
>>> data["emg"].device  # cuda:0
>>> data["emg"].names   # ('channel', 'time')

Methods

  • __getitem__(idx) – Load windows for all modalities.

  • __getstate__() – Prepare state for pickling (used by multiprocessing workers).

  • __init__(zarr_path[, split, modalities, ...])

  • __len__()

  • __setstate__(state) – Restore state after unpickling (in worker processes).

  • _ensure_cache_loaded() – Load data into RAM cache if caching is enabled and not yet loaded.

  • _get_deterministic_position(idx) – Get deterministic window position for given index.

  • _get_dim_names(modality) – Get dimension names for a modality from metadata.

  • _get_task_for_recording(rec_idx) – Get the task name for a recording index.

  • _global_to_local(global_pos) – Convert global position to (recording_idx, local_position).

  • _load_window(var_path, local_pos, modality) – Load a window for a variable and convert to tensor.

  • _sample_random_position() – Sample a random valid window position.

  • _setup_recording_ranges() – Setup valid sampling ranges for each recording.

  • _to_tensor(data) – Convert data to tensor on target device.

  • get_sample_shape(modality) – Get the shape of a sample for a given modality (without time dimension).

  • reseed([seed]) – Reseed the random number generator.
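The global-to-local conversion listed above can be pictured as a lookup in a table of cumulative offsets. The following is a minimal sketch of that idea, not the library's implementation: build_offsets, global_to_local as a free function, and the example lengths are assumptions made for illustration.

```python
from bisect import bisect_right
from itertools import accumulate


def build_offsets(valid_lengths):
    """offsets[i] is the global position at which recording i starts."""
    return [0] + list(accumulate(valid_lengths))


def global_to_local(global_pos, offsets):
    """Map a position in the concatenated index space to (recording_idx, local_position)."""
    if not 0 <= global_pos < offsets[-1]:
        raise IndexError(f"position {global_pos} is out of range")
    # Rightmost recording whose start offset is <= global_pos
    rec_idx = bisect_right(offsets, global_pos) - 1
    return rec_idx, global_pos - offsets[rec_idx]


# Hypothetical example: three recordings with 801, 301 and 501 valid start positions
offsets = build_offsets([801, 301, 501])
print(global_to_local(800, offsets))  # (0, 800)
print(global_to_local(801, offsets))  # (1, 0)
```

A binary search keeps the lookup logarithmic in the number of recordings, which matters little for a handful of recordings but avoids a linear scan on every sampled window.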

get_sample_shape(modality)

Get the shape of a sample for a given modality (without time dimension).

Parameters:

modality (str) – Modality name.

Returns:

Shape without time dimension.

Return type:

tuple[int, …]
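As an illustration of "shape without time dimension", the sketch below drops the axis labelled "time" from a full window shape. The function name sample_shape_without_time, the 64-channel example, and the assumption that the time axis is identified by the name "time" (as in the ('channel', 'time') names shown in the class examples) are illustrative and not taken from the library's implementation.

```python
def sample_shape_without_time(full_shape, dim_names, time_dim="time"):
    """Return the sizes of all axes except the one named time_dim."""
    if len(full_shape) != len(dim_names):
        raise ValueError("full_shape and dim_names must have the same length")
    return tuple(size for size, name in zip(full_shape, dim_names) if name != time_dim)


# Hypothetical window with 64 channels and 200 time samples
print(sample_shape_without_time((64, 200), ("channel", "time")))  # (64,)
```

A shape of this kind is typically what is needed to size the input layer of a model, since the window length is already known from window_size.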

reseed(seed=None)

Reseed the random number generator.

Parameters:

seed (int | None)

Return type:

None
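One way to picture reseeding is a sampler that owns its random number generator and replaces it on request. The class below is a hypothetical stand-in built on Python's random module. How WindowedDataset stores its generator, and what it does when seed is None, is not stated here, so the behaviour shown for None (fresh entropy from the operating system) is an assumption.

```python
import random


class RandomWindowSampler:
    """Toy sampler that draws window start positions in [0, last_start]."""

    def __init__(self, last_start, seed=None):
        self.last_start = last_start
        self._rng = random.Random(seed)

    def reseed(self, seed=None):
        # Replace the generator; in this sketch, None means "seed from OS entropy"
        self._rng = random.Random(seed)

    def sample(self):
        return self._rng.randint(0, self.last_start)


sampler = RandomWindowSampler(last_start=800, seed=1)
first_epoch = [sampler.sample() for _ in range(5)]
sampler.reseed(1)  # same seed again: the same positions are replayed
replayed = [sampler.sample() for _ in range(5)]
```

Reseeding with a fixed value makes a run repeatable, for example for validation, while reseeding with a different value per epoch varies which windows are drawn.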