Source code for myogestic.models.definitions.sklearn_models
"""
This module contains the functions to save, load, train and predict using sklearn models.
"""
import os
from typing import Union

import joblib
import numpy as np

from myogestic.gui.widgets.logger import CustomLogger
[docs]
def save(model_path: str, model: object) -> str:
    """
    Save a sklearn model with joblib.

    The output path is derived from ``model_path``. In the file-name component,
    everything from the first ``.`` onwards is dropped and ``_model.pkl`` is
    appended. The directory part is kept exactly as given, so
    ``"./runs/v1.2/emg.json"`` becomes ``"./runs/v1.2/emg_model.pkl"``.

    Parameters
    ----------
    model_path: str
        The path to save the model. Any object whose ``str()`` is a path,
        such as ``pathlib.Path``, is also accepted.
    model: Any
        The sklearn model to save.

    Returns
    -------
    str
        The path where the model was saved.
    """
    path = str(model_path)
    file_name = os.path.basename(path)
    # Everything before the file name (directories and separators), kept
    # verbatim so the caller's separator style is preserved.
    directory_prefix = path[: len(path) - len(file_name)]
    # Strip the extension(s) from the file name only. Splitting the whole path
    # on "." would cut at the first dot anywhere (e.g. in "./" or "v1.2/") and
    # write the model to the wrong location.
    stem = file_name.split(".")[0]
    output_model_path: str = directory_prefix + stem + "_model" + ".pkl"
    joblib.dump(model, output_model_path)
    return output_model_path
[docs]
def load(model_path: str, _: object) -> object:
    """
    Load a sklearn model that was serialized with joblib.

    Parameters
    ----------
    model_path: str
        Path of the serialized model file to read.
    _: Any
        A new instance of the sklearn model. It is ignored; the model is
        reconstructed entirely from the file.

    Returns
    -------
    Any
        The loaded sklearn model
    """
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. Only load model files from trusted sources.
    with open(model_path, mode="rb") as model_file:
        return joblib.load(model_file)
[docs]
def train(
    model: object, dataset, is_classifier: bool, _: "CustomLogger"
) -> object:
    """
    Train a sklearn model.

    ``dataset["emg"]`` is flattened to one row per sample before it is passed
    to ``model.fit``: every axis after the first is merged into one feature
    axis. For example, shape ``(n, a, b)`` becomes ``(n, a * b)``.

    Parameters
    ----------
    model: Any
        The sklearn model to train. It must provide ``fit(x, y)``.
    dataset
        Container indexed by key, whose entries are read in full with ``[()]``.
        It presumably is an HDF5 group or a dict of numpy arrays (TODO confirm
        against callers). It must contain ``"emg"``, plus ``"classes"`` for
        classifiers or ``"kinematics"`` for regressors. It is not modified.
    is_classifier: bool
        Whether the model is a classifier. This selects the ``"classes"``
        target instead of ``"kinematics"``.
    _: CustomLogger
        The logger to log the training process. This parameter is not used.

    Returns
    -------
    Any
        The trained sklearn model (the same object that was passed in).
    """
    x_train = dataset["emg"][()]
    # Merge all trailing axes into a single feature axis. For 3-D input this
    # is (samples, a * b), as before; 2-D input passes through unchanged.
    n_features = int(np.prod(x_train.shape[1:]))
    x_train = np.reshape(x_train, (x_train.shape[0], n_features))

    if is_classifier:
        y_train = dataset["classes"][()]
    else:
        # Copy: for an in-memory ndarray, ``[()]`` returns a view, so the
        # in-place noise injection below would overwrite the caller's data.
        y_train = np.array(dataset["kinematics"][()], copy=True)
        # Replace exact zeros with small random values ("to avoid errors",
        # per the original comment, presumably for estimators that reject
        # zero targets; TODO confirm). With an integer dtype, the noise is
        # truncated back to 0.
        zero_mask = y_train == 0
        y_train[zero_mask] = np.random.uniform(
            0.0001, 0.001, np.count_nonzero(zero_mask)
        )

    model.fit(x_train, y_train)
    return model
[docs]
def predict(
    model: object, input: np.ndarray, is_classifier: bool
) -> Union[np.ndarray, list[float]]:
    """
    Predict with a sklearn model and return the result for the first sample.

    ``input`` is flattened to one row per sample before ``model.predict`` is
    called: every axis after the first is merged into one feature axis.

    Parameters
    ----------
    model: Any
        The sklearn model to predict with. It must provide ``predict(x)``.
    input: np.ndarray
        The input data to predict. Its first axis is the sample axis. (The
        name shadows the builtin but is kept for keyword callers.)
    is_classifier: bool
        Whether the model is a classifier.

    Returns
    -------
    Union[np.ndarray, list[float]]
        If the model is a classifier, a single element: the first predicted
        label (``prediction[0, 0]`` for 2-D output, ``prediction[0]`` for 1-D
        output). Otherwise, the first sample's outputs as a list. A
        single-output regressor yields a one-element list.
    """
    n_features = int(np.prod(input.shape[1:]))
    prediction = model.predict(np.reshape(input, (input.shape[0], n_features)))

    if is_classifier:
        try:
            # 2-D output, e.g. shape (n_samples, 1).
            return prediction[0, 0]
        except IndexError:
            # 1-D output, shape (n_samples,).
            return prediction[0]

    # For 1-D output, prediction[0] is a scalar and list() would raise
    # TypeError; atleast_1d wraps it so a one-element list is returned.
    return list(np.atleast_1d(prediction[0]))