.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/4_train_model.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Go to the end to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_4_train_model.py:


Training a deep learning model
==============================

This example shows how to train a deep learning model on the dataset created in the previous example.

.. GENERATED FROM PYTHON SOURCE LINES 9-18

Loading the dataset
-------------------

To load the dataset, we use the ``EMGDatasetLoader`` class. It requires two parameters:

- ``data_path``: path to the dataset created in the previous example (here, the Zarr store ``data/dataset.zarr``).
- ``dataloader_parameters``: parameters for the PyTorch ``DataLoader``, such as ``batch_size`` and ``drop_last``.

.. GENERATED FROM PYTHON SOURCE LINES 18-23

.. code-block:: Python


    from pathlib import Path

    from doc_octopy.datasets.loader import EMGDatasetLoader

    # Point the loader at the dataset and configure the DataLoader
    loader = EMGDatasetLoader(
        Path("data/dataset.zarr").resolve(),
        dataloader_parameters={"batch_size": 16, "drop_last": True},
    )

.. GENERATED FROM PYTHON SOURCE LINES 24-26

Training the model
------------------

We create a ``RaulNetV16`` model and train it for one epoch with a Lightning ``Trainer``, passing the loader as the ``datamodule``.

.. GENERATED FROM PYTHON SOURCE LINES 26-58

.. code-block:: Python


    from doc_octopy.models.definitions.raul_net.online.v16 import RaulNetV16
    import lightning as L

    # Create the model
    model = RaulNetV16(
        learning_rate=1e-4,
        nr_of_input_channels=2,
        input_length__samples=192,
        nr_of_outputs=60,
        nr_of_electrode_grids=5,
        nr_of_electrodes_per_grid=64,
        # The following channel counts are kept small for this example;
        # multiply them by 4, 8, or 16 to get a useful network
        cnn_encoder_channels=(4, 1, 1),
        mlp_encoder_channels=(8, 8),
        event_search_kernel_length=31,
        event_search_kernel_stride=8,
    )

    # Train for a single epoch with mixed precision and without checkpointing
    trainer = L.Trainer(
        accelerator="auto",
        devices=1,
        precision="16-mixed",
        max_epochs=1,
        log_every_n_steps=50,
        logger=None,
        enable_checkpointing=False,
        deterministic=False,
    )

    trainer.fit(model, datamodule=loader)

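The loader above is configured with ``batch_size=16`` and ``drop_last=True``. Assuming ``dataloader_parameters`` is forwarded to a PyTorch ``DataLoader``, an epoch over ``N`` training samples yields ``N // 16`` full batches, and the last ``N % 16`` samples are skipped. The sketch below illustrates this arithmetic with two hypothetical helpers (``batches_per_epoch`` and ``dropped_samples``, not part of ``doc_octopy``) and a made-up dataset size.

.. code-block:: Python

    # Hypothetical helpers illustrating DataLoader batching; not part of doc_octopy.


    def batches_per_epoch(n_samples: int, batch_size: int, drop_last: bool) -> int:
        """Number of batches yielded per epoch for a dataset of n_samples."""
        full_batches, remainder = divmod(n_samples, batch_size)
        # drop_last=True discards the final, incomplete batch
        if drop_last or remainder == 0:
            return full_batches
        return full_batches + 1


    def dropped_samples(n_samples: int, batch_size: int, drop_last: bool) -> int:
        """Number of samples skipped per epoch because of drop_last."""
        return n_samples % batch_size if drop_last else 0


    # Made-up dataset size; the real one depends on the dataset from the previous example
    n_samples = 1000
    print(batches_per_epoch(n_samples, 16, drop_last=True))   # 62
    print(dropped_samples(n_samples, 16, drop_last=True))     # 8
    print(batches_per_epoch(n_samples, 16, drop_last=False))  # 63

Dropping the incomplete batch keeps every batch the same size; with ``drop_last=False`` the final, smaller batch would be kept instead.
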
.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /home/runner/work/DocOctopy/DocOctopy/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:512: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
    /home/runner/work/DocOctopy/DocOctopy/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
    Sanity Checking: | | 0/? [00:00


.. container:: sphx-glr-download sphx-glr-download-python

    :download:`Download Python source code: 4_train_model.py <4_train_model.py>`

.. container:: sphx-glr-download sphx-glr-download-zip

    :download:`Download zipped: 4_train_model.zip <4_train_model.zip>`


.. only:: html

    .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery