Source code for myogestic.models.definitions.catboost_models
"""
This module contains the functions to save, load, train and predict using CatBoost models.
"""
from typing import Union
import numpy as np
from catboost.core import _CatBoostBase
from myogestic.gui.widgets.logger import CustomLogger
[docs]
def save(model_path: str, model: _CatBoostBase) -> str:
"""
Save a CatBoost model.
Parameters
----------
model_path: str
The path to save the model.
model: catboost.core._CatBoostBase
The CatBoost model to save.
Returns
-------
str
The path where the model was saved.
"""
output_model_path = str(model_path).split(".")[0] + "_model" + ".cbm"
model.save_model(output_model_path)
return output_model_path
[docs]
def load(model_path: str, model: _CatBoostBase) -> _CatBoostBase:
"""
Load a CatBoost model.
Parameters
----------
model_path: str
The path to load the model.
model: _CatBoostBase
A new instance of the CatBoost model. This instance is used to load the model.
Returns
-------
_CatBoostBase
The loaded CatBoost model.
"""
model.load_model(model_path)
return model
[docs]
def train(
model: _CatBoostBase, dataset: dict, is_classifier: bool, logger: CustomLogger
) -> _CatBoostBase:
"""
Train a CatBoost model.
Parameters
----------
model: _CatBoostBase
The CatBoost model to train.
dataset: dict
The dataset to train the model.
is_classifier: bool
If the model is a classifier.
logger: CustomLogger
The logger to use.
Returns
-------
_CatBoostBase
The trained CatBoost model.
"""
x_train = dataset["emg"][()]
x_train = np.reshape(
x_train, (x_train.shape[0], x_train.shape[1] * x_train.shape[2])
)
if is_classifier:
y_train = dataset["classes"][()]
else:
y_train = dataset["kinematics"][()]
# add small noise to the target to avoid errors
y_train[y_train == 0] = np.random.uniform(
0.0001, 0.001, y_train[y_train == 0].shape
)
model.fit(x_train, y_train, log_cerr=logger.print, log_cout=logger.print)
return model
[docs]
def predict(
model: _CatBoostBase, input: np.ndarray, is_classifier: bool
) -> Union[np.array, list[float]]:
"""
Predict with a CatBoost model.
Parameters
----------
model: _CatBoostBase
The CatBoost model to predict with.
input: np.ndarray
The input data to predict. The shape of the input data will be (1, n_features, n_samples).
is_classifier: bool
If the model is a classifier.
Returns
-------
Union[np.array, list[float]]
The prediction. If the model is a classifier, the prediction will be a np.array.
If the model is a regressor, the prediction will be a list of floats.
"""
prediction = model.predict(
np.reshape(input, (input.shape[0], input.shape[1] * input.shape[2]))
)
if is_classifier:
try:
prediction = prediction[0, 0]
except IndexError:
prediction = prediction[0]
return prediction
return list(prediction[0])