DataModule#

class myoverse.datasets.datamodule.DataModule(data_path, inputs=('emg',), targets=('kinematics',), batch_size=32, window_size=200, window_stride=None, n_windows_per_epoch=None, num_workers=4, train_transform=None, val_transform=None, test_transform=None, target_transform=None, pin_memory=True, persistent_workers=True, device=None, dtype=torch.float32, cache_in_ram=True)[source]#

Lightning DataModule for supervised learning.

Wraps one SupervisedDataset per train/val/test split and provides the corresponding DataLoaders.

Parameters:
  • data_path (Path | str) – Path to the Zarr dataset.

  • inputs (Sequence[str]) – Modality names to use as model inputs.

  • targets (Sequence[str]) – Modality names to use as model targets.

  • batch_size (int) – Batch size for all dataloaders.

  • window_size (int) – Window size in samples.

  • window_stride (int | None) – Window stride in samples for validation/test windows.

  • n_windows_per_epoch (int | None) – Number of random windows per training epoch.

  • num_workers (int) – Number of dataloader workers.

  • train_transform (Callable | None) – Transform for training inputs.

  • val_transform (Callable | None) – Transform for validation inputs.

  • test_transform (Callable | None) – Transform for test inputs.

  • target_transform (Callable | None) – Transform for targets.

  • pin_memory (bool) – Pin memory for faster GPU transfer.

  • persistent_workers (bool) – Keep workers alive between epochs.

  • device (torch.device | str | None) – Output device (‘cpu’, ‘cuda’, or None for numpy).

  • dtype (torch.dtype) – Data type for tensors.

  • cache_in_ram (bool) – Cache entire split in RAM.

Examples

>>> dm = DataModule(
...     "data.zip",
...     inputs=["emg"],
...     targets=["kinematics"],
...     window_size=200,
...     n_windows_per_epoch=10000,
...     device="cuda",
... )
>>> dm.setup("fit")
>>> for inputs, targets in dm.train_dataloader():
...     # inputs: Tensor of shape (batch, channels, time)
...     # targets: Tensor of shape (batch, joints)
...     pass
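
For evaluation, windows are drawn with a fixed stride rather than sampled at random. A minimal sketch of the test workflow, reusing the constructor arguments from the example above (the stride value is illustrative):

>>> dm = DataModule(
...     "data.zip",
...     inputs=["emg"],
...     targets=["kinematics"],
...     window_size=200,
...     window_stride=200,  # non-overlapping test windows
... )
>>> dm.setup("test")
>>> for inputs, targets in dm.test_dataloader():
...     # inputs: Tensor of shape (batch, channels, time)
...     pass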
Attributes

prepare_data_per_node#

If True, prepare_data() is called by the LOCAL_RANK=0 process on every node. Otherwise it is called only once, by the NODE_RANK=0, LOCAL_RANK=0 process.

allow_zero_length_dataloader_with_multiple_devices#

If True, a dataloader with zero length within a local rank is allowed. Defaults to False.

Methods

__init__(data_path[, inputs, targets, ...])
    Initialize the DataModule.

setup([stage])
    Set up datasets for each stage.

test_dataloader()
    An iterable or collection of iterables specifying test samples.

train_dataloader()
    An iterable or collection of iterables specifying training samples.

val_dataloader()
    An iterable or collection of iterables specifying validation samples.

setup(stage=None)[source]#

Set up the datasets for the given stage.

Parameters:

stage (str | None)

Return type:

None
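
A short usage sketch; the stage strings follow the usual Lightning convention ("fit", "validate", "test"), and passing no stage is assumed to set up all splits:

>>> dm.setup("fit")   # builds the train/val datasets
>>> dm.setup("test")  # builds the test dataset
>>> dm.setup()        # stage=None: assumed to build every split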

test_dataloader()[source]#

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the Lightning documentation.

For data processing use the following pattern (a sketch follows the warning below):

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state (self.x = ...) in prepare_data().
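
A minimal sketch of this pattern in a custom LightningDataModule; the download_archive and build_splits helpers are hypothetical stand-ins for your own I/O code:

>>> import lightning.pytorch as pl
>>> class MyDataModule(pl.LightningDataModule):
...     def prepare_data(self):
...         # runs once per node: write to disk only, assign no state here
...         download_archive("data.zip")  # hypothetical helper
...     def setup(self, stage=None):
...         # runs on every process: safe to assign state
...         self.train_set, self.val_set = build_splits("data.zip")  # hypothetical helper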

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Return type:

DataLoader
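
In practice this hook is rarely called by hand; the Trainer invokes it. A sketch, assuming model is a LightningModule and dm is the DataModule from the class-level example:

>>> import lightning.pytorch as pl
>>> trainer = pl.Trainer(accelerator="auto", devices=1)
>>> trainer.test(model, datamodule=dm)  # calls dm.setup("test") and dm.test_dataloader()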

train_dataloader()[source]#

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state (self.x = ...) in prepare_data().

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type:

DataLoader
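
As with the test loader, this hook is normally driven by the Trainer. A sketch, assuming model is a LightningModule:

>>> import lightning.pytorch as pl
>>> trainer = pl.Trainer(max_epochs=10)
>>> trainer.fit(model, datamodule=dm)  # calls dm.setup("fit") and dm.train_dataloader()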

val_dataloader()[source]#

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the Lightning documentation.

The dataloader you return will not be reloaded unless you set Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data(). The relevant Trainer stages and hooks are:

  • fit()

  • validate()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.

Return type:

DataLoader
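
A quick sanity check on the validation loader; this assumes, per the usual Lightning convention, that setup("fit") also builds the validation split:

>>> dm.setup("fit")
>>> inputs, targets = next(iter(dm.val_dataloader()))
>>> inputs.shape  # (batch, channels, time)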