.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/01_tutorials/4_train_model.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_01_tutorials_4_train_model.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_01_tutorials_4_train_model.py:


Training a deep learning model
==============================

This example shows how to train a deep learning model using the dataset created in the previous example.

.. GENERATED FROM PYTHON SOURCE LINES 9-18

Loading the dataset
-------------------
To load the dataset, we use the EMGDatasetLoader class.

Two parameters are required:

- data_path: Path to the dataset file.

- dataloader_params: Keyword arguments passed to the underlying DataLoader.

.. GENERATED FROM PYTHON SOURCE LINES 18-37

.. code-block:: Python

    from pathlib import Path

    from myoverse.datasets.loader import EMGDatasetLoader
    from myoverse.datatypes import _Data


    # Define a target data class that reshapes each flat sample to (1, 60)
    class CustomDataClass(_Data):
        def __init__(self, raw_data, sampling_frequency=None):
            # Initialize parent class with raw data
            super().__init__(raw_data.reshape(1, 60), sampling_frequency, nr_of_dimensions_when_unchunked=2)

    # Create the dataset loader, using our custom class for the target data
    loader = EMGDatasetLoader(
        Path(r"../data/dataset.zarr").resolve(),
        dataloader_params={"batch_size": 16, "drop_last": True},
        target_data_class=CustomDataClass,
    )
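
The reshape in ``CustomDataClass`` above turns each flat target sample into the
``(1, 60)`` layout that ``_Data`` expects when ``nr_of_dimensions_when_unchunked=2``.
A minimal NumPy sketch of just that step (the values here are made up for
illustration):

.. code-block:: Python

    import numpy as np

    # A hypothetical flat target vector with 60 kinematics values
    raw_target = np.arange(60, dtype=np.float32)

    # CustomDataClass reshapes it to (1, 60): one leading axis, 60 outputs
    reshaped = raw_target.reshape(1, 60)
    print(reshaped.shape)  # (1, 60)
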








.. GENERATED FROM PYTHON SOURCE LINES 38-40

Training the model
------------------

.. GENERATED FROM PYTHON SOURCE LINES 40-70

.. code-block:: Python

    from myoverse.models.definitions.raul_net.online.v16 import RaulNetV16
    import lightning as L

    # Create the model
    model = RaulNetV16(
        learning_rate=1e-4,
        nr_of_input_channels=2,
        input_length__samples=192,
        nr_of_outputs=60,
        nr_of_electrode_grids=5,
        nr_of_electrodes_per_grid=64,
        # Multiply the following by 4, 8, or 16 to obtain a practically useful network
        cnn_encoder_channels=(4, 1, 1),
        mlp_encoder_channels=(8, 8),
        event_search_kernel_length=31,
        event_search_kernel_stride=8,
    )

    trainer = L.Trainer(
        accelerator="auto",
        devices=1,
        precision="16-mixed",
        max_epochs=1,
        log_every_n_steps=50,
        logger=None,
        enable_checkpointing=False,
        deterministic=False,
    )

    trainer.fit(model, datamodule=loader)




.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
    /home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.

    Sanity Checking: |          | 0/? [00:00<?, ?it/s]/home/runner/work/MyoVerse/MyoVerse/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.

    Sanity Checking DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.72it/s]

    Epoch 0:   2%|▏         | 1/55 [00:01<01:15,  0.71it/s, v_num=0, loss_step=54.20]
    Epoch 0:  51%|█████     | 28/55 [00:38<00:36,  0.73it/s, v_num=0, loss_step=50.20]
    Epoch 0: 100%|██████████| 55/55 [01:15<00:00,  0.73it/s, v_num=0, loss_step=44.60]

    Validation DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  1.84it/s]

    Epoch 0: 100%|██████████| 55/55 [01:15<00:00,  0.73it/s, v_num=0, loss_step=44.60, val_loss=44.00, loss_epoch=49.70]





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minute 17.108 seconds)

**Estimated memory usage:**  641 MB
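
As an aside on the numbers above: the training log reports 55 optimizer steps
per epoch, which follows directly from ``batch_size=16`` together with
``drop_last=True`` — the DataLoader yields only full batches, so the step count
is the floor of the sample count divided by the batch size. A small sketch of
that arithmetic (the helper name is ours, not part of MyoVerse):

.. code-block:: Python

    def steps_per_epoch(n_samples: int, batch_size: int, drop_last: bool = True) -> int:
        """Number of batches a DataLoader yields per epoch."""
        full_batches, remainder = divmod(n_samples, batch_size)
        if drop_last or remainder == 0:
            return full_batches
        return full_batches + 1  # the last, partial batch is kept

    # Any training-set size from 880 to 895 samples yields the 55 steps seen here
    print(steps_per_epoch(880, 16))                   # 55
    print(steps_per_epoch(895, 16))                   # 55
    print(steps_per_epoch(895, 16, drop_last=False))  # 56
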


.. _sphx_glr_download_auto_examples_01_tutorials_4_train_model.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 4_train_model.ipynb <4_train_model.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 4_train_model.py <4_train_model.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 4_train_model.zip <4_train_model.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_