Training a deep learning model

This example shows how to train a deep learning model using the dataset created in the previous example.

Loading the dataset

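To load the dataset, use the EMGDatasetLoader class.

Two parameters are required:

  • data_path: Path to the dataset file.

  • dataloader_params: Parameters for the DataLoader (here the batch size and whether to drop the last incomplete batch).

The example below also passes target_data_class, the class used to wrap the target data.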
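Assuming dataloader_params is forwarded unchanged to torch.utils.data.DataLoader, batch_size and drop_last determine how many steps one epoch takes. The training log further down reports 55 steps per epoch. With batch_size=16 and drop_last=True, that would correspond to a training split of 880 to 895 samples. The sketch below shows the arithmetic. The sample count of 890 is hypothetical, because the real size of the split is not stated here.

```python
import math


def batches_per_epoch(nr_of_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a DataLoader-style batcher yields in one epoch."""
    if drop_last:
        # The final, incomplete batch is discarded
        return nr_of_samples // batch_size
    # The final, incomplete batch is kept
    return math.ceil(nr_of_samples / batch_size)


# Hypothetical number of training samples (not taken from the dataset)
nr_of_samples = 890
print(batches_per_epoch(nr_of_samples, batch_size=16, drop_last=True))   # 55
print(batches_per_epoch(nr_of_samples, batch_size=16, drop_last=False))  # 56
```

Dropping the last batch keeps every batch the same size.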

from pathlib import Path

from myoverse.datasets.loader import EMGDatasetLoader
from myoverse.datatypes import _Data


# Target-data class that accepts targets of any layout: the raw data is simply
# reshaped to (1, 60), so it only has to contain exactly 60 values
class CustomDataClass(_Data):
    def __init__(self, raw_data, sampling_frequency=None):
        # Initialize parent class with raw data
        super().__init__(raw_data.reshape(1, 60), sampling_frequency, nr_of_dimensions_when_unchunked=2)

# Load the dataset in batches of 16, dropping the last incomplete batch,
# and wrap the targets in CustomDataClass
loader = EMGDatasetLoader(
    Path(r"../data/dataset.zarr").resolve(),
    dataloader_params={"batch_size": 16, "drop_last": True},
    target_data_class=CustomDataClass,
)
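CustomDataClass constrains only the number of target values, not their layout. NumPy's reshape accepts any array that holds 60 elements and raises ValueError otherwise. The sketch below illustrates this with a hypothetical helper, to_target_shape, and made-up arrays. The actual target shape stored in dataset.zarr is not shown in this example.

```python
import numpy as np


def to_target_shape(raw_data) -> np.ndarray:
    """Mirror the reshape done in CustomDataClass.__init__ (illustration only)."""
    return np.asarray(raw_data).reshape(1, 60)


print(to_target_shape(np.arange(60)).shape)      # (1, 60)
print(to_target_shape(np.zeros((3, 20))).shape)  # (1, 60): the input layout is not checked

# Anything that does not hold exactly 60 values cannot be reshaped
try:
    to_target_shape(np.zeros(59))
    reshape_failed = False
except ValueError:
    reshape_failed = True
print(reshape_failed)  # True
```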

Training the model
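The RaulNetV16 call below describes the input as 5 electrode grids of 64 electrodes each (320 electrodes in total), 2 input channels, and windows of 192 samples. Suppose the event-search kernel (length 31, stride 8) slides along the time axis without padding. That is an assumption, not something this example confirms. Under it, the convolution arithmetic documented for torch.nn.Conv1d gives the number of positions the kernel visits. The helper conv_output_length is hypothetical.

```python
def conv_output_length(length: int, kernel: int, stride: int = 1,
                       padding: int = 0, dilation: int = 1) -> int:
    """Output length along one axis, using the formula documented for torch.nn.Conv1d."""
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Values from the RaulNetV16 call below
input_length__samples = 192
event_search_kernel_length = 31
event_search_kernel_stride = 8

nr_of_electrodes = 5 * 64  # nr_of_electrode_grids * nr_of_electrodes_per_grid

# Assumes no padding along the time axis (not confirmed for RaulNetV16)
positions = conv_output_length(
    input_length__samples, event_search_kernel_length, event_search_kernel_stride
)
print(nr_of_electrodes, positions)  # 320 21
```

Under the same assumption, a larger stride shortens this axis. For example, a stride of 16 would give 11 positions.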

from myoverse.models.definitions.raul_net.online.v16 import RaulNetV16
import lightning as L

# Create the model
model = RaulNetV16(
    learning_rate=1e-4,
    nr_of_input_channels=2,
    input_length__samples=192,
    nr_of_outputs=60,
    nr_of_electrode_grids=5,
    nr_of_electrodes_per_grid=64,
    # These encoder sizes are kept tiny so the example runs quickly;
    # multiply them by 4, 8 or 16 to get a useful network
    cnn_encoder_channels=(4, 1, 1),
    mlp_encoder_channels=(8, 8),
    event_search_kernel_length=31,
    event_search_kernel_stride=8,
)

trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    precision="16-mixed",
    max_epochs=1,
    log_every_n_steps=50,
    logger=None,
    enable_checkpointing=False,
    deterministic=False,
)

trainer.fit(model, datamodule=loader)
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.

/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.

Epoch 0: 100%|██████████| 55/55 [01:25<00:00,  0.64it/s, v_num=0, loss_step=50.60, val_loss=49.60, loss_epoch=52.30]
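The log above shows two warnings. Requesting precision="16-mixed" on a CPU-only machine makes Lightning fall back to "bf16-mixed". The dataloaders also run with fewer workers than suggested. Below is a minimal sketch of settings that would address both warnings up front. pick_precision is a hypothetical helper. The num_workers entry assumes that EMGDatasetLoader forwards dataloader_params unchanged to torch.utils.data.DataLoader.

```python
def pick_precision(has_cuda: bool) -> str:
    """Hypothetical helper: choose a Lightning precision string for the hardware."""
    # fp16 AMP is not supported on CPU; Lightning falls back to bf16-mixed there
    # (see the first warning in the log), so request bf16-mixed directly.
    return "16-mixed" if has_cuda else "bf16-mixed"


# In a real script, has_cuda would come from torch.cuda.is_available()
trainer_kwargs = {
    "accelerator": "auto",
    "devices": 1,
    "precision": pick_precision(has_cuda=False),
    "max_epochs": 1,
}

# The workers warning suggests num_workers=3 on this machine; assumes the loader
# passes these keys straight to torch.utils.data.DataLoader
dataloader_params = {"batch_size": 16, "drop_last": True, "num_workers": 3}

print(trainer_kwargs["precision"])  # bf16-mixed
```

These dictionaries would then be used as L.Trainer(**trainer_kwargs) and EMGDatasetLoader(..., dataloader_params=dataloader_params). The suggested number of workers depends on the CPU cores available, so 3 is specific to the machine that produced this log.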

Total running time of the script: (1 minutes 27.434 seconds)

Estimated memory usage: 589 MB
