Note
Go to the end to download the full example code.
Training a deep learning model#
This example shows how to train a deep learning model using the dataset created in the previous example.
Loading the dataset#
To load the dataset, we use the EMGDatasetLoader class.
Two parameters are required:
data_path: Path to the dataset file.
dataloader_params: Parameters passed to the PyTorch DataLoader.
from pathlib import Path
from myoverse.datasets.loader import EMGDatasetLoader
from myoverse.datatypes import _Data
# Create a class to handle the target data that doesn't enforce a specific shape
class CustomDataClass(_Data):
    def __init__(self, raw_data, sampling_frequency=None):
        # Initialize parent class with raw data
        super().__init__(
            raw_data.reshape(1, 60),
            sampling_frequency,
            nr_of_dimensions_when_unchunked=2,
        )
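The reshape to (1, 60) assumes each target sample stored in the dataset is a flat vector of 60 values, matching the nr_of_outputs of the model defined below. A quick NumPy-only sketch of that shape handling (illustrative, not MyoVerse-specific):
import numpy as np

# One hypothetical target sample as a flat vector of 60 values
example_target = np.zeros(60, dtype=np.float32)
# This is the array shape that CustomDataClass hands to _Data
print(example_target.reshape(1, 60).shape)  # -> (1, 60)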
# Create the dataset loader, using CustomDataClass for the target data
loader = EMGDatasetLoader(
    Path(r"../data/dataset.zarr").resolve(),
    dataloader_params={"batch_size": 16, "drop_last": True},
    target_data_class=CustomDataClass,
)
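To check that the loader produces what the model expects, you can peek at a single training batch. This is a minimal sketch that assumes EMGDatasetLoader follows the standard Lightning DataModule interface (setup() and train_dataloader()) and that batches come as (input, target) pairs; verify against the MyoVerse documentation.
# Sketch: inspect one training batch (assumes the standard LightningDataModule API)
loader.setup("fit")
emg_batch, target_batch = next(iter(loader.train_dataloader()))
print(emg_batch.shape)     # EMG input batch; expect batch size 16 with 192-sample windows
print(target_batch.shape)  # target batch; exact shape depends on how CustomDataClass stores the data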
Training the model#
from myoverse.models.definitions.raul_net.online.v16 import RaulNetV16
import lightning as L
# Create the model
model = RaulNetV16(
    learning_rate=1e-4,
    nr_of_input_channels=2,
    input_length__samples=192,
    nr_of_outputs=60,
    nr_of_electrode_grids=5,
    nr_of_electrodes_per_grid=64,
    # Multiply the following channel counts by 4, 8, or 16 to get a usefully sized network
    cnn_encoder_channels=(4, 1, 1),
    mlp_encoder_channels=(8, 8),
    event_search_kernel_length=31,
    event_search_kernel_stride=8,
)
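To get a rough sense of the model size before training, you can count its trainable parameters; this works for any LightningModule because it is also a torch.nn.Module.
# Count trainable parameters (purely informational)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")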
trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    precision="16-mixed",
    max_epochs=1,
    log_every_n_steps=50,
    logger=None,
    enable_checkpointing=False,
    deterministic=False,
)
trainer.fit(model, datamodule=loader)
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
Sanity Checking DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 6.55it/s]
Epoch 0:   2%|▏         | 1/55 [00:00<00:22, 2.43it/s, v_num=0, loss_step=55.00]
Epoch 0: 100%|██████████| 55/55 [00:21<00:00, 2.56it/s, v_num=0, loss_step=48.10]
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 9.93it/s]
Epoch 0: 100%|██████████| 55/55 [00:21<00:00, 2.55it/s, v_num=0, loss_step=48.10, val_loss=48.30, loss_epoch=51.70]
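Since enable_checkpointing=False, Lightning does not write a checkpoint automatically. If you want to keep the trained weights, you can save them manually with the standard Lightning call; the file path below is just an example.
# Save the trained weights to a checkpoint file of your choosing
trainer.save_checkpoint("../data/raul_net_v16_trained.ckpt")
The saved checkpoint can later be restored with RaulNetV16.load_from_checkpoint.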
Total running time of the script: (0 minutes 22.564 seconds)
Estimated memory usage: 729 MB