Training a deep learning model
This example shows how to train a deep learning model using the dataset created in the previous example.
Loading the dataset

To load the dataset we use the EMGDatasetLoader class. Two parameters are required:

data_path: Path to the dataset created in the previous example.
dataloader_params: Parameters passed to the PyTorch DataLoader.
from pathlib import Path

from myoverse.datasets.loader import EMGDatasetLoader
from myoverse.datatypes import _Data


# Create a class to handle the target data that doesn't enforce a specific shape
class CustomDataClass(_Data):
    def __init__(self, raw_data, sampling_frequency=None):
        # Initialize the parent class with the raw data reshaped to (1, 60)
        super().__init__(
            raw_data.reshape(1, 60),
            sampling_frequency,
            nr_of_dimensions_when_unchunked=2,
        )


# Create the loader for the dataset built in the previous example
loader = EMGDatasetLoader(
    Path(r"../data/dataset.zarr").resolve(),
    dataloader_params={"batch_size": 16, "drop_last": True},
    target_data_class=CustomDataClass,
)
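EMGDatasetLoader is used below as a Lightning data module. Before training, it can be reassuring to pull a single batch and check its shapes. The following is a minimal sketch that assumes the standard LightningDataModule interface (setup() plus train_dataloader()) and that each batch is an (EMG, target) pair; the exact shapes depend on how the dataset was chunked in the previous example.

# Optional sanity check, assuming the standard LightningDataModule interface.
loader.setup("fit")
emg_batch, target_batch = next(iter(loader.train_dataloader()))
# Assumed batch structure: (EMG windows, targets produced by CustomDataClass).
print(emg_batch.shape)     # e.g. (16, ..., 192) EMG windows
print(target_batch.shape)  # e.g. (16, 60) targets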
Training the model
from myoverse.models.definitions.raul_net.online.v16 import RaulNetV16
import lightning as L
# Create the model
model = RaulNetV16(
learning_rate=1e-4,
nr_of_input_channels=2,
input_length__samples=192,
nr_of_outputs=60,
nr_of_electrode_grids=5,
nr_of_electrodes_per_grid=64,
    # Multiply the following by 4, 8, or 16 to obtain a usefully sized network
cnn_encoder_channels=(4, 1, 1),
mlp_encoder_channels=(8, 8),
event_search_kernel_length=31,
event_search_kernel_stride=8,
)
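RaulNetV16 is a LightningModule and therefore a plain torch.nn.Module, so its size can be checked directly before committing to a training run. A quick sketch:

# Count the trainable parameters of the (deliberately small) example network.
nr_of_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {nr_of_parameters:,}")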
trainer = L.Trainer(
accelerator="auto",
devices=1,
precision="16-mixed",
max_epochs=1,
log_every_n_steps=50,
logger=None,
enable_checkpointing=False,
deterministic=False,
)
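The trainer above is kept deliberately minimal so the example finishes quickly: a single epoch, no logger, and no checkpointing. For an actual experiment you would usually log metrics and keep the best checkpoint. The sketch below (not used in this example) relies on standard Lightning components; the logging directory and run name are placeholders.

# Not used in this quick example: a fuller Trainer with logging and
# checkpointing (directory and run name are placeholders).
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

full_trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    precision="16-mixed",
    max_epochs=100,
    logger=CSVLogger("logs", name="raul_net_v16"),
    callbacks=[ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)],
)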
trainer.fit(model, datamodule=loader)
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
Sanity Checking DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 1.68it/s]
Epoch 0:   2%|▏         | 1/55 [00:01<01:13, 0.73it/s, v_num=0, loss_step=54.10]
[... intermediate progress updates elided; loss_step decreases from roughly 54 to 45 over the epoch ...]
Epoch 0: 100%|██████████| 55/55 [01:13<00:00, 0.75it/s, v_num=0, loss_step=46.60]
Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 1.89it/s]
Epoch 0: 100%|██████████| 55/55 [01:14<00:00, 0.74it/s, v_num=0, loss_step=46.60, val_loss=45.80, loss_epoch=50.60]
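Because enable_checkpointing was set to False, nothing is written to disk automatically. If you want to keep the trained weights, they can be saved and restored with standard Lightning calls; the file name below is a placeholder, and loading assumes the model stores its hyperparameters in the checkpoint.

# Save the trained weights manually (checkpointing was disabled above).
trainer.save_checkpoint("raul_net_v16.ckpt")

# Restore the model later, e.g. for inference (file name is a placeholder).
model = RaulNetV16.load_from_checkpoint("raul_net_v16.ckpt")
model.eval()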
Total running time of the script: (1 minutes 15.625 seconds)
Estimated memory usage: 725 MB