DataModule#
- class myoverse.datasets.datamodule.DataModule(data_path, inputs=('emg',), targets=('kinematics',), batch_size=32, window_size=200, window_stride=None, n_windows_per_epoch=None, num_workers=4, train_transform=None, val_transform=None, test_transform=None, target_transform=None, pin_memory=True, persistent_workers=True, device=None, dtype=torch.float32, cache_in_ram=True)[source]#
Lightning DataModule for supervised learning.
Wraps SupervisedDataset instances for train/val/test splits and provides DataLoaders.
- Parameters:
data_path (Path | str) – Path to the Zarr dataset.
inputs (Sequence[str]) – Modality names to use as model inputs.
targets (Sequence[str]) – Modality names to use as model targets.
batch_size (int) – Batch size for all dataloaders.
window_size (int) – Window size in samples.
window_stride (int | None) – Window stride for validation/test.
n_windows_per_epoch (int | None) – Number of random windows per training epoch.
num_workers (int) – Number of dataloader workers.
train_transform (Callable | None) – Transform for training inputs.
val_transform (Callable | None) – Transform for validation inputs.
test_transform (Callable | None) – Transform for test inputs.
target_transform (Callable | None) – Transform for targets.
pin_memory (bool) – Pin memory for faster GPU transfer.
persistent_workers (bool) – Keep workers alive between epochs.
device (torch.device | str | None) – Output device (‘cpu’, ‘cuda’, or None for numpy).
dtype (torch.dtype) – Data type for tensors.
cache_in_ram (bool) – Cache entire split in RAM.
Examples
>>> dm = DataModule(
...     "data.zip",
...     inputs=["emg"],
...     targets=["kinematics"],
...     window_size=200,
...     n_windows_per_epoch=10000,
...     device="cuda",
... )
>>> dm.setup("fit")
>>> for inputs, targets in dm.train_dataloader():
...     # inputs: Tensor of shape (batch, channels, time)
...     # targets: Tensor of shape (batch, joints)
...     pass
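Since this is a standard LightningDataModule, it can also be passed straight to a Lightning Trainer, which drives setup() and the dataloader hooks itself. A minimal sketch, assuming the lightning.pytorch import path and a hypothetical LightningModule named MyModel (not part of myoverse):
>>> import lightning.pytorch as pl
>>> model = MyModel()  # placeholder LightningModule; substitute your own network
>>> dm = DataModule(
...     "data.zip",
...     inputs=["emg"],
...     targets=["kinematics"],
...     window_size=200,
...     n_windows_per_epoch=10000,
... )
>>> trainer = pl.Trainer(max_epochs=10, accelerator="gpu", devices=1)
>>> trainer.fit(model, datamodule=dm)   # calls setup("fit"), train_dataloader(), val_dataloader()
>>> trainer.test(model, datamodule=dm)  # calls setup("test"), test_dataloader()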
- prepare_data_per_node#
If True, each node’s LOCAL_RANK=0 process will call prepare_data(). Otherwise only the process with NODE_RANK=0, LOCAL_RANK=0 will prepare data.
- allow_zero_length_dataloader_with_multiple_devices#
If True, a dataloader with zero length within a local rank is allowed. Default value is False.
Methods
__init__(data_path[, inputs, targets, ...])
setup([stage]) – Setup datasets for each stage.
test_dataloader() – An iterable or collection of iterables specifying test samples.
train_dataloader() – An iterable or collection of iterables specifying training samples.
val_dataloader() – An iterable or collection of iterables specifying validation samples.
- setup(stage=None)[source]#
Setup datasets for each stage.
- Parameters:
stage (str | None)
- Return type:
None
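Outside of a Trainer, setup() can be called manually before requesting dataloaders. A minimal sketch, assuming the standard Lightning stage names ("fit", "validate", "test") and that each stage builds the corresponding splits:
>>> dm = DataModule("data.zip", inputs=["emg"], targets=["kinematics"])
>>> dm.setup("fit")   # prepares the train and validation splits (standard Lightning convention)
>>> dm.setup("test")  # prepares the test split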
- test_dataloader()[source]#
An iterable or collection of iterables specifying test samples.
For more information about returning multiple dataloaders, see the Lightning documentation.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data(); it runs in a single process, so state set there is not shared with the other processes.
During testing, the related hooks are called in this order:
- test()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
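A minimal sketch of iterating the test loader directly, assuming setup("test") has been called and model stands in for your trained network:
>>> dm.setup("test")
>>> for inputs, targets in dm.test_dataloader():
...     # windows are taken at the fixed window_stride over the test split
...     predictions = model(inputs)  # model is a placeholder for your network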
- train_dataloader()[source]#
An iterable or collection of iterables specifying training samples.
For more information about returning multiple dataloaders, see the Lightning documentation.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs on the Trainer to a positive integer.
For data processing use the following pattern:
- download in prepare_data()
- process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data(); it runs in a single process, so state set there is not shared with the other processes.
During fitting, the related hooks are called in this order:
- fit()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
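If the training loader should be rebuilt every epoch (for example to redraw the random training windows), the Trainer flag mentioned above can be used. A minimal sketch, assuming the lightning.pytorch import path and the model and dm objects from the earlier example:
>>> import lightning.pytorch as pl
>>> # Rebuild train_dataloader() at the start of every epoch.
>>> trainer = pl.Trainer(max_epochs=10, reload_dataloaders_every_n_epochs=1)
>>> trainer.fit(model, datamodule=dm)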
- val_dataloader()[source]#
An iterable or collection of iterables specifying validation samples.
For more information about returning multiple dataloaders, see the Lightning documentation.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs on the Trainer to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
During fitting and validation, the related hooks are called in this order:
- fit()
- validate()
- prepare_data()
- setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
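A minimal sketch of iterating the validation loader directly, assuming setup("fit") has been called; unlike the randomly sampled training windows, validation windows follow the fixed window_stride:
>>> dm.setup("fit")
>>> for inputs, targets in dm.val_dataloader():
...     # windows are taken at the fixed window_stride over the validation split
...     val_outputs = model(inputs)  # model is a placeholder for your network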