Note
Go to the end to download the full example code.
Training a deep learning model#
This example shows how to train a deep learning model using the dataset created in the previous example.
Loading the dataset#
To load the dataset we use the EMGDatasetLoader class.
Two parameters are required:
data_path: Path to the dataset file.
dataloader_params: Parameters passed to the underlying DataLoader.
from pathlib import Path
from myoverse.datasets.loader import EMGDatasetLoader
from myoverse.datatypes import _Data
# Create a class to handle the target data that doesn't enforce a specific shape
class CustomDataClass(_Data):
    def __init__(self, raw_data, sampling_frequency=None):
        # Initialize parent class with raw data reshaped to (1, 60)
        super().__init__(raw_data.reshape(1, 60), sampling_frequency, nr_of_dimensions_when_unchunked=2)
# Create the loader; the custom data class above simply passes the target data through unchanged
loader = EMGDatasetLoader(
    Path(r"../data/dataset.zarr").resolve(),
    dataloader_params={"batch_size": 16, "drop_last": True},
    target_data_class=CustomDataClass,
)
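Optionally, we can peek at a single batch before training. The snippet below is a sketch: it assumes that EMGDatasetLoader follows the standard lightning.LightningDataModule interface (which passing it as datamodule to trainer.fit later implies) and that it yields (input, target) pairs; the exact shapes depend on the dataset built in the previous example.
# Optional sanity check (sketch, not part of the original example).
# Assumes EMGDatasetLoader behaves like a regular LightningDataModule.
loader.setup("fit")
emg_batch, target_batch = next(iter(loader.train_dataloader()))
print(emg_batch.shape)     # first dimension should be 16 (the batch_size above)
print(target_batch.shape)  # e.g. (16, 60), given the reshape in CustomDataClass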
Training the model#
from myoverse.models.definitions.raul_net.online.v16 import RaulNetV16
import lightning as L
# Create the model
model = RaulNetV16(
    learning_rate=1e-4,
    nr_of_input_channels=2,
    input_length__samples=192,
    nr_of_outputs=60,
    nr_of_electrode_grids=5,
    nr_of_electrodes_per_grid=64,
    # Multiply the following by 4, 8, or 16 to get a usefully sized network
    cnn_encoder_channels=(4, 1, 1),
    mlp_encoder_channels=(8, 8),
    event_search_kernel_length=31,
    event_search_kernel_stride=8,
)
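Before building the trainer it can be useful to confirm how large the model actually is. The check below only assumes that RaulNetV16 is a regular torch.nn.Module, which every lightning.LightningModule is.
# Optional: count trainable parameters (works for any torch.nn.Module)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params:,}")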
trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    precision="16-mixed",
    max_epochs=1,
    log_every_n_steps=50,
    logger=None,
    enable_checkpointing=False,
    deterministic=False,
)
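The trainer above is deliberately minimal (a single epoch, no logger, no checkpointing) so that the example finishes quickly. For a real experiment you would typically enable both; the following sketch uses Lightning's standard CSVLogger and ModelCheckpoint and is not used in the run below. The monitored metric val_loss matches the one logged during validation in this example.
# Sketch of a fuller trainer configuration (not used in this run)
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

full_trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    precision="16-mixed",
    max_epochs=100,
    logger=CSVLogger("logs", name="raulnet_v16"),
    callbacks=[ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)],
)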
trainer.fit(model, datamodule=loader)
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=3` in the `DataLoader` to improve performance.
Sanity Checking DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 1.63it/s]
Epoch 0: 100%|██████████| 55/55 [01:25<00:00, 0.64it/s, v_num=0, loss_step=50.60]
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 1.78it/s]
Epoch 0: 100%|██████████| 55/55 [01:25<00:00, 0.64it/s, v_num=0, loss_step=50.60, val_loss=49.60, loss_epoch=52.30]
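After training you will usually want to persist the learned weights. The sketch below uses plain Lightning checkpointing; the file name is arbitrary, and whether RaulNetV16 stores its hyperparameters in the checkpoint (via save_hyperparameters) is an assumption. If it does not, pass the constructor arguments to load_from_checkpoint explicitly.
# Sketch: save and restore the trained model with Lightning's checkpoint API
trainer.save_checkpoint("raulnet_v16.ckpt")  # file name is just an example

# Restore later; assumes the hyperparameters were saved with the checkpoint
restored_model = RaulNetV16.load_from_checkpoint("raulnet_v16.ckpt")
restored_model.eval()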
Total running time of the script: (1 minute 27.434 seconds)
Estimated memory usage: 589 MB