.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/01_tutorials/4_train_model.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_01_tutorials_4_train_model.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_01_tutorials_4_train_model.py:


Training a deep learning model
==============================

This example shows how to train a deep learning model on the dataset created in the previous example.

.. GENERATED FROM PYTHON SOURCE LINES 9-18

Loading the dataset
-------------------
To load the dataset we use the ``EMGDatasetLoader`` class.

Two parameters are required:

- ``data_path``: Path to the dataset file.
- ``dataloader_params``: Parameters for the DataLoader.

The example below also passes ``target_data_class``, a custom class that wraps the target data.

.. GENERATED FROM PYTHON SOURCE LINES 18-37

.. code-block:: Python


    from pathlib import Path

    from myoverse.datasets.loader import EMGDatasetLoader
    from myoverse.datatypes import _Data


    # Custom class for the target data: instead of enforcing a predefined shape,
    # it reshapes the raw targets to (1, 60)
    class CustomDataClass(_Data):
        def __init__(self, raw_data, sampling_frequency=None):
            # Initialize the parent class with the reshaped raw data
            super().__init__(
                raw_data.reshape(1, 60),
                sampling_frequency,
                nr_of_dimensions_when_unchunked=2,
            )


    # Create the loader: batches of 16 samples, dropping the last incomplete batch
    loader = EMGDatasetLoader(
        Path(r"../data/dataset.zarr").resolve(),
        dataloader_params={"batch_size": 16, "drop_last": True},
        target_data_class=CustomDataClass,
    )








.. GENERATED FROM PYTHON SOURCE LINES 38-40

Training the model
------------------

.. GENERATED FROM PYTHON SOURCE LINES 40-70

.. code-block:: Python


    from myoverse.models.definitions.raul_net.online.v16 import RaulNetV16
    import lightning as L

    # Create the model
    model = RaulNetV16(
        learning_rate=1e-4,
        nr_of_input_channels=2,
        input_length__samples=192,
        nr_of_outputs=60,
        nr_of_electrode_grids=5,
        nr_of_electrodes_per_grid=64,
        # Kept deliberately small for this example;
        # multiply the following by 4, 8, or 16 to get a useful network
        cnn_encoder_channels=(4, 1, 1),
        mlp_encoder_channels=(8, 8),
        event_search_kernel_length=31,
        event_search_kernel_stride=8,
    )

    # Train for a single epoch on one device
    trainer = L.Trainer(
        accelerator="auto",
        devices=1,
        precision="16-mixed",
        max_epochs=1,
        log_every_n_steps=50,
        logger=None,
        enable_checkpointing=False,
        deterministic=False,
    )

    trainer.fit(model, datamodule=loader)




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
    /home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
    /home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
    Sanity Checking DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.72it/s]
    Epoch 0:   2%|▏         | 1/55 [00:01<01:15,  0.71it/s, v_num=0, loss_step=54.20]
    ...
    Epoch 0: 100%|██████████| 55/55 [01:15<00:00,  0.73it/s, v_num=0, loss_step=44.60]
    Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s]
    Epoch 0: 100%|██████████| 55/55 [01:15<00:00,  0.73it/s, v_num=0, loss_step=44.60, val_loss=44.00, loss_epoch=49.70]





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 17.108 seconds)

**Estimated memory usage:**  641 MB


.. _sphx_glr_download_auto_examples_01_tutorials_4_train_model.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 4_train_model.ipynb <4_train_model.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 4_train_model.py <4_train_model.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 4_train_model.zip <4_train_model.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_