Source code for myogestic.models.definitions.raulnet_models

"""
This module contains the functions to save, load, train and predict using RaulNet models.
"""
import multiprocessing
import platform
from pathlib import Path
from typing import Any

import lightning as L
import numpy as np
import torch
from lightning.pytorch.callbacks import StochasticWeightAveraging, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from myoverse.datasets.filters.generic import IndexDataFilter
from myoverse.datasets.loader import EMGDatasetLoader
from myoverse.datatypes import _Data

from myogestic.gui.widgets.logger import CustomLogger


def save(_: str, __: L.LightningModule) -> str:
    """
    Save a RaulNet model.

    .. note:: Saving the model is not necessary as the model is saved automatically
        by PyTorch Lightning. This function only returns the path of the last saved
        model.

    Parameters
    ----------
    _: str
        The path to save the model. Unused.
    __: L.LightningModule
        The RaulNet model to save. Unused.

    Returns
    -------
    str
        The path of the ``last.ckpt`` checkpoint that belongs to the run with the
        highest version number under ``data/logs/RaulNet_models/``.

    Raises
    ------
    FileNotFoundError
        If no ``last.ckpt`` checkpoint exists under ``data/logs/RaulNet_models/``.
    """
    checkpoints = Path("data/logs/RaulNet_models/").rglob("last.ckpt")

    def _run_version(checkpoint: Path) -> int:
        # Assumes the layout ``<run>_<N>/checkpoints/last.ckpt`` (e.g. ``version_3``),
        # i.e. the run directory is parts[-3] and ends in its numeric version.
        return int(checkpoint.parts[-3].split("_")[-1])

    latest = max(checkpoints, key=_run_version, default=None)
    if latest is None:
        raise FileNotFoundError(
            "No 'last.ckpt' checkpoint found under data/logs/RaulNet_models/."
        )
    # Return a str, as annotated and documented.
    return str(latest)
def save_per_finger(_: str, __: L.LightningModule) -> str:
    """
    Return the location of the most recent per-finger RaulNet training run.

    .. note:: Nothing is written here: PyTorch Lightning already stores the
        checkpoints during training. This function only resolves where the latest
        run lives.

    Parameters
    ----------
    _: str
        The path to save the model. Unused.
    __: L.LightningModule
        The RaulNet model to save. Unused.

    Returns
    -------
    str
        ``data/logs/RaulNet_models_per_finger/<version>``, where ``<version>`` is the
        highest numeric prefix among the ``<version>_<index>`` entries in that
        directory. The per-model directories are obtained by appending ``_<index>``.
    """

    def _version_of(entry: Path) -> int:
        # Entries are named "<version>_<index>"; sort on the numeric version prefix.
        return int(entry.name.split("_")[0])

    runs = sorted(
        Path("data/logs/RaulNet_models_per_finger/").glob("*_*"), key=_version_of
    )
    newest = runs[-1]
    version = newest.name.split("_")[0]
    return str(newest.parent / version)
def load(model_path: str, model: L.LightningModule) -> L.LightningModule:
    """
    Load a RaulNet model.

    Parameters
    ----------
    model_path: str
        The path of the checkpoint to load.
    model: L.LightningModule
        A new instance of the RaulNet model. Only its class is used to restore the
        checkpoint.

    Returns
    -------
    L.LightningModule
        The loaded RaulNet model, placed on CUDA if available (otherwise CPU), in
        eval mode and with gradients disabled.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Map the checkpoint tensors directly onto the target device so that a
    # checkpoint trained on a GPU can also be restored on a CPU-only machine.
    restored = model.__class__.load_from_checkpoint(model_path, map_location=device)
    return restored.to(device).eval().requires_grad_(False)
def load_per_finger(
    model_path: str, model: L.LightningModule
) -> list[L.LightningModule]:
    """
    Load the per-finger RaulNet models.

    Parameters
    ----------
    model_path: str
        Path prefix of the training run, as returned by ``save_per_finger``
        (``data/logs/RaulNet_models_per_finger/<version>``). The checkpoint of model
        ``i`` is the first ``last.ckpt`` found (recursively) in ``<model_path>_<i>``.
    model: L.LightningModule
        A new instance of the RaulNet model. Only its class is used to restore the
        checkpoints.

    Returns
    -------
    list[L.LightningModule]
        The three loaded RaulNet models, placed on CUDA if available (otherwise
        CPU), in eval mode and with gradients disabled.

    Raises
    ------
    FileNotFoundError
        If one of the run directories contains no ``last.ckpt`` checkpoint.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    models = []
    for i in range(3):
        run_dir = Path(f"{model_path}_{i}")
        checkpoint = next(run_dir.rglob("last.ckpt"), None)
        if checkpoint is None:
            raise FileNotFoundError(f"No 'last.ckpt' checkpoint found under {run_dir}.")

        restored = model.__class__.load_from_checkpoint(
            checkpoint, map_location=device
        )
        # Same treatment as ``load``: these models are only used for inference, so
        # put them in eval mode (affects dropout/normalization layers, if any) and
        # disable gradient tracking.
        models.append(restored.to(device).eval().requires_grad_(False))
    return models
def train(
    model: L.LightningModule, dataset, _: bool, __: CustomLogger
) -> L.LightningModule:
    """
    Train a RaulNet model.

    A fresh instance of ``model``'s class is created, with its hyperparameters
    adapted to the dataset (input length, electrodes per grid, number of outputs),
    and trained with PyTorch Lightning. Logs and checkpoints are written via
    ``CSVLogger`` under ``data/logs/RaulNet_models/``.

    Parameters
    ----------
    model: L.LightningModule
        The RaulNet model to train. Only its class and hyperparameters are used; the
        instance itself is not modified.
    dataset: dict
        The dataset to train the model with. Read keys: ``"emg"`` and
        ``"kinematics"`` (objects with a ``.shape``),
        ``"device_information"["sampling_frequency"]`` and ``"zarr_file_path"``
        (relative to ``data/datasets/``).
    _: bool
        If the model is a classifier. Unused.
    __: CustomLogger
        The logger to log the training process. Unused.

    Returns
    -------
    L.LightningModule
        The trained RaulNet model.
    """
    torch.set_float32_matmul_precision("medium")
    torch.backends.cudnn.benchmark = True

    # Work on a copy so the caller's model keeps its original hyperparameters.
    hparams = dict(model.hparams)
    hparams["input_length__samples"] = dataset["emg"].shape[-1]
    hparams["nr_of_electrodes_per_grid"] = dataset["emg"].shape[-2]
    hparams["nr_of_outputs"] = dataset["kinematics"].shape[-1]
    model = model.__class__(**hparams)

    class CustomDataClass(_Data):
        def __init__(
            self,
            raw_data,
            sampling_frequency=dataset["device_information"]["sampling_frequency"],
        ):
            # Initialize parent class with the raw data flattened to a single row.
            super().__init__(
                raw_data.reshape(1, -1),
                sampling_frequency,
                nr_of_dimensions_when_unchunked=2,
            )

    # No DataLoader worker processes on Windows; elsewhere leave one core free.
    num_workers = (
        0 if platform.system() == "Windows" else multiprocessing.cpu_count() - 1
    )

    loader = EMGDatasetLoader(
        Path(r"data/datasets/" + dataset["zarr_file_path"]).resolve(),
        target_data_class=CustomDataClass,
        dataloader_params={
            "batch_size": 64,
            "drop_last": True,
            "num_workers": num_workers,
            "pin_memory": True,
            # PyTorch's DataLoader rejects persistent_workers=True when
            # num_workers == 0 (Windows, or a single-core machine).
            "persistent_workers": num_workers > 0,
        },
    )

    Path("data/logs/").mkdir(parents=True, exist_ok=True)
    trainer = L.Trainer(
        accelerator="auto",
        devices=1,
        check_val_every_n_epoch=5,
        callbacks=[
            StochasticWeightAveraging(
                swa_lrs=10 ** (-4), swa_epoch_start=0.5, annealing_epochs=5
            ),
            ModelCheckpoint(
                monitor="val_loss", mode="min", save_top_k=1, save_last=True
            ),
        ],
        precision="16-mixed",
        max_epochs=50,
        logger=CSVLogger(
            name="RaulNet_models", save_dir=str(Path(r"data/logs/").resolve())
        ),
        enable_checkpointing=True,
        enable_model_summary=True,
        deterministic=False,
    )
    trainer.fit(model, datamodule=loader)
    return model
def train_per_finger(model: L.LightningModule, dataset, _: bool, __: CustomLogger):
    """
    Train one RaulNet model per finger.

    Three models are trained. Model ``i`` uses a target pipeline with
    ``IndexDataFilter(indices=(0, [i + 1]), ...)``, i.e. a different target index
    per model. Each model is logged and checkpointed under
    ``data/logs/RaulNet_models_per_finger/<version>_<i>``. ``<version>`` is one more
    than the largest numeric version prefix already present in that directory, or 0
    if there is none.

    Parameters
    ----------
    model: L.LightningModule
        The RaulNet model to train. Only its class and hyperparameters are used; the
        instance itself is not modified.
    dataset: dict
        The dataset to train the model with. Read keys: ``"emg"`` (object with a
        ``.shape``), ``"device_information"["sampling_frequency"]`` and
        ``"zarr_file_path"`` (relative to ``data/datasets/``).
    _: bool
        If the model is a classifier. Unused.
    __: CustomLogger
        The logger to log the training process. Unused.

    Returns
    -------
    list[L.LightningModule]
        The three trained RaulNet models, in training order.
    """
    torch.set_float32_matmul_precision("medium")
    torch.backends.cudnn.benchmark = True

    Path("data/logs/").mkdir(parents=True, exist_ok=True)

    # Find the next free version number. Entries are named "<version>_<index>";
    # entries whose prefix is not a number are ignored rather than resetting the
    # numbering to 0 (which would collide with existing runs).
    existing_versions = [
        int(prefix)
        for prefix in (
            entry.name.split("_")[0]
            for entry in Path("data/logs/RaulNet_models_per_finger/").glob("*_*")
        )
        if prefix.isdecimal()
    ]
    version_nr = max(existing_versions, default=-1) + 1

    class CustomDataClass(_Data):
        def __init__(
            self,
            raw_data,
            sampling_frequency=dataset["device_information"]["sampling_frequency"],
        ):
            # Initialize parent class with the raw data flattened to a single row.
            super().__init__(
                raw_data.reshape(1, -1),
                sampling_frequency,
                nr_of_dimensions_when_unchunked=2,
            )

    # No DataLoader worker processes on Windows; elsewhere leave one core free.
    num_workers = (
        0 if platform.system() == "Windows" else multiprocessing.cpu_count() - 1
    )

    # Work on a copy so the caller's model keeps its original hyperparameters.
    hparams = dict(model.hparams)
    hparams["input_length__samples"] = dataset["emg"].shape[-1]

    trained_models = []
    for i in range(3):
        finger_model = model.__class__(**hparams)

        loader = EMGDatasetLoader(
            Path(r"data/datasets/" + dataset["zarr_file_path"]).resolve(),
            target_data_class=CustomDataClass,
            dataloader_params={
                "batch_size": 64,
                "drop_last": True,
                "num_workers": num_workers,
                "pin_memory": True,
                # PyTorch's DataLoader rejects persistent_workers=True when
                # num_workers == 0 (Windows, or a single-core machine).
                "persistent_workers": num_workers > 0,
            },
            target_augmentation_pipeline=[
                [
                    IndexDataFilter(
                        indices=(0, [i + 1]), is_output=True, input_is_chunked=False
                    )
                ]
            ],
        )

        trainer = L.Trainer(
            accelerator="auto",
            devices=1,
            check_val_every_n_epoch=5,
            callbacks=[
                StochasticWeightAveraging(
                    swa_lrs=10 ** (-4), swa_epoch_start=0.5, annealing_epochs=5
                ),
                ModelCheckpoint(
                    monitor="val_loss", mode="min", save_top_k=1, save_last=True
                ),
            ],
            precision="16-mixed",
            max_epochs=20,
            logger=CSVLogger(
                name="RaulNet_models_per_finger",
                save_dir=str(Path(r"data/logs/").resolve()),
                version=f"{version_nr}_{i}",
            ),
            enable_checkpointing=True,
            enable_model_summary=True,
            deterministic=False,
        )
        trainer.fit(finger_model, datamodule=loader)
        trained_models.append(finger_model)

    return trained_models
def predict(
    model: L.LightningModule, input: np.ndarray, is_classifier: bool
) -> list[Any] | None:
    """
    Run a RaulNet model on a single input window.

    Parameters
    ----------
    model: L.LightningModule
        The RaulNet model to predict with.
    input: np.ndarray
        The input data to predict. The shape of the input data will be
        (1, n_features, n_samples). A leading batch axis is added before the
        forward pass.
    is_classifier
        If the model is a classifier.

    Returns
    -------
    list[float] | None
        The model output for the single batch element, or None when
        ``is_classifier`` is True (classification is not handled here).
    """
    if is_classifier:
        return None

    with torch.inference_mode():
        window = torch.from_numpy(input).to(torch.float32).to(model.device)
        output = model(window[None, ...])
        first_item = output.detach().cpu().numpy()[0]
        return list(first_item)
def predict_per_finger(
    model: list[L.LightningModule], input: np.ndarray, is_classifier: bool
) -> list[float] | None:
    """
    Predict with the per-finger RaulNet models.

    Parameters
    ----------
    model: list[L.LightningModule]
        The per-finger RaulNet models to predict with (as returned by
        ``load_per_finger``). Every model in the list is evaluated, in order.
    input: np.ndarray
        The input data to predict. The shape of the input data will be
        (1, n_features, n_samples). A leading batch axis is added before the
        forward passes.
    is_classifier
        If the model is a classifier.

    Returns
    -------
    list[float] | None
        ``[0, p_0, ..., p_{k-1}, 0]``, where ``p_i`` is the first output value of
        ``model[i]``, or None when ``is_classifier`` is True.
    """
    if is_classifier:
        return None

    with torch.inference_mode():
        # Convert the input once; only the device transfer is done per model.
        batch = torch.from_numpy(input).to(torch.float32)[None, ...]
        outputs = torch.concatenate(
            [finger_model(batch.to(finger_model.device)) for finger_model in model]
        )
        per_finger = list(outputs.detach().cpu().numpy()[:, 0])

    # The first and last positions are not predicted by any model and are
    # padded with 0.
    return [0] + per_finger + [0]