"""
This module contains the functions to save, load, train and predict using RaulNet models.
"""
from pathlib import Path
import lightning as L
import numpy as np
import torch
from myoverse.datasets.filters.generic import IndexDataFilter
from myoverse.datasets.loader import EMGDatasetLoader
from lightning.pytorch.callbacks import StochasticWeightAveraging, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from myogestic.gui.widgets.logger import CustomLogger
[docs]
def save(_: str, __: L.LightningModule) -> str:
"""
Save a RaulNet model.
.. note:: Saving the model is not necessary as the model is saved automatically by PyTorch Lightning. This function only returns the path of the last saved model.
Parameters
----------
_: str
The path to save the model.
__: L.LightningModule
The RaulNet model to save.
Returns
-------
str
The path where the model was saved.
"""
return sorted(
list(Path("data/logs/RaulNet_models/").rglob("last.ckpt")),
key=lambda x: int(x.parts[-3].split("_")[-1]),
)[-1]
def save_per_finger(_: str, __: L.LightningModule) -> str:
"""
Save a RaulNet model.
.. note:: Saving the model is not necessary as the model is saved automatically by PyTorch Lightning. This function only returns the path of the last saved model.
Parameters
----------
_: str
The path to save the model.
__: L.LightningModule
The RaulNet model to save.
Returns
-------
str
The path where the model was saved.
"""
parts = list(
sorted(
list(Path("data/logs/RaulNet_models_per_finger/").glob("*_*")),
key=lambda x: int(x.parts[-1].split("_")[0]),
)[-1].parts
)
parts[-1] = parts[-1].split("_")[0]
return str(Path(*parts))
[docs]
def load(model_path: str, model: L.LightningModule) -> L.LightningModule:
"""
load a RaulNet model.
Parameters
----------
model_path: str
The path to load the model.
model: _CatBoostBase
A new instance of the CatBoost model. This instance is used to load the model.
Returns
-------
_CatBoostBase
The loaded RaulNet model.
"""
return model.__class__.load_from_checkpoint(model_path).to(
"cuda" if torch.cuda.is_available() else "cpu"
)
def load_per_finger(
model_path: str, model: L.LightningModule
) -> list[L.LightningModule]:
"""
load a RaulNet model.
Parameters
----------
model_path: str
The path to load the model.
model: L.LightningModule
A new instance of the CatBoost model. This instance is used to load the model.
Returns
-------
L.LightningModule
The loaded RaulNet model.
"""
return [
model.__class__.load_from_checkpoint(
list(Path(model_path + f"_{i}").rglob("last.ckpt"))[0]
).to("cuda" if torch.cuda.is_available() else "cpu")
for i in range(3)
]
[docs]
def train(
model: L.LightningModule, dataset, _: bool, __: CustomLogger
) -> L.LightningModule:
"""
Train a RaulNet model.
Parameters
----------
model: L.LightningModule
The RaulNet model to train.
dataset: dict
The dataset to train the model with.
_: bool
If the model is a classifier.
__: CustomLogger
The logger to log the training process.
Returns
-------
L.LightningModule
The trained RaulNet model.
"""
torch.set_float32_matmul_precision("medium")
torch.backends.cudnn.benchmark = True
hparams = model.hparams
hparams["input_length__samples"] = dataset["emg"].shape[-1]
hparams["nr_of_electrodes_per_grid"] = dataset["emg"].shape[-2]
hparams["nr_of_outputs"] = dataset["kinematics"].shape[-1]
model = model.__class__(**hparams)
loader = EMGDatasetLoader(
Path(r"data/datasets/" + dataset["zarr_file_path"]).resolve(),
dataloader_parameters={
"batch_size": 64,
"drop_last": True,
"num_workers": 10,
"pin_memory": True,
"persistent_workers": True,
},
)
Path("data/logs/").mkdir(parents=True, exist_ok=True)
trainer = L.Trainer(
accelerator="auto",
devices=1,
check_val_every_n_epoch=5,
callbacks=[
StochasticWeightAveraging(
swa_lrs=10 ** (-4), swa_epoch_start=0.5, annealing_epochs=5
),
ModelCheckpoint(
monitor="val_loss", mode="min", save_top_k=1, save_last=True
),
],
precision="16-mixed",
max_epochs=50,
logger=CSVLogger(
name="RaulNet_models", save_dir=str(Path(r"data/logs/").resolve())
),
enable_checkpointing=True,
enable_model_summary=True,
deterministic=False,
)
trainer.fit(model, datamodule=loader)
return model
def train_per_finger(model: L.LightningModule, dataset, _: bool, __: CustomLogger):
"""
Train a RaulNet model.
Parameters
----------
model: L.LightningModule
The RaulNet model to train.
dataset: dict
The dataset to train the model with.
_: bool
If the model is a classifier.
__: CustomLogger
The logger to log the training process.
Returns
-------
L.LightningModule
The trained RaulNet model.
"""
torch.set_float32_matmul_precision("medium")
torch.backends.cudnn.benchmark = True
Path("data/logs/").mkdir(parents=True, exist_ok=True)
# find the latest version of the models
try:
version_nr = (
int(
sorted(
list(Path("data/logs/RaulNet_models_per_finger/").glob("*_*")),
key=lambda x: int(x.parts[-1].split("_")[0]),
)[-1].name.split("_")[0]
)
+ 1
)
except Exception:
version_nr = 0
for i in range(3):
hparams = model.hparams
hparams["input_length__samples"] = dataset["emg"].shape[-1]
model = model.__class__(**hparams)
loader = EMGDatasetLoader(
Path(r"data/datasets/" + dataset["zarr_file_path"]).resolve(),
dataloader_parameters={
"batch_size": 64,
"drop_last": True,
"num_workers": 10,
"pin_memory": True,
"persistent_workers": True,
},
ground_truth_augmentation_pipeline=[
[IndexDataFilter(indices=(0, [i + 1]), is_output=True)]
],
)
trainer = L.Trainer(
accelerator="auto",
devices=1,
check_val_every_n_epoch=5,
callbacks=[
StochasticWeightAveraging(
swa_lrs=10 ** (-4), swa_epoch_start=0.5, annealing_epochs=5
),
ModelCheckpoint(
monitor="val_loss", mode="min", save_top_k=1, save_last=True
),
],
precision="16-mixed",
max_epochs=20,
logger=CSVLogger(
name="RaulNet_models_per_finger",
save_dir=str(Path(r"data/logs/").resolve()),
version=f"{version_nr}_{i}",
),
enable_checkpointing=True,
enable_model_summary=True,
deterministic=False,
)
trainer.fit(model, datamodule=loader)
return
[docs]
def predict(
model: L.LightningModule, input: np.ndarray, is_classifier: bool
) -> list[float]:
"""
Predict with a RaulNet model.
Parameters
----------
model: L.LightningModule
The RaulNet model to predict with.
input: np.ndarray
The input data to predict. The shape of the input data will be (1, n_features, n_samples).
is_classifier
If the model is a classifier.
Returns
-------
list[float]
The predicted output.
"""
if not is_classifier:
with torch.inference_mode():
return list(
model(
torch.from_numpy(input)
.to(torch.float32)
.to(model.device)[None, ...]
)
.detach()
.cpu()
.numpy()[0]
)
def predict_per_finger(
model: list[L.LightningModule], input: np.ndarray, is_classifier: bool
) -> list[float]:
"""
Predict with a RaulNet model.
Parameters
----------
model: list[L.LightningModule]
The RaulNet model to predict with.
input: np.ndarray
The input data to predict. The shape of the input data will be (1, n_features, n_samples).
is_classifier
If the model is a classifier.
Returns
-------
list[float]
The predicted output.
"""
if not is_classifier:
with torch.inference_mode():
return (
[0]
+ list(
torch.concatenate(
[
model[i](
torch.from_numpy(input)
.to(torch.float32)
.to(model[i].device)[None, ...]
)
for i in range(3)
]
)
.detach()
.cpu()
.numpy()[:, 0]
)
+ [0]
)