Source code for myogestic.models.definitions.catboost_models
"""
This module contains the functions to save, load, train and predict using CatBoost models.
"""
import os
from typing import Union

import numpy as np
from catboost.core import _CatBoostBase

from myogestic.gui.widgets.logger import CustomLogger
def _flatten_3d_to_2d(data: np.ndarray) -> np.ndarray:
"""Flatten 3D data (samples, channels, time) to 2D (samples, features)."""
if data.ndim == 3:
return data.reshape(data.shape[0], data.shape[1] * data.shape[2])
return data
def _extract_classifier_prediction(prediction: np.ndarray):
"""Extract single prediction value for classifier output."""
if prediction.ndim == 2:
return prediction[0, 0]
return prediction[0]
[docs]
def save(model_path: str, model: _CatBoostBase) -> str:
    """
    Save a CatBoost model.

    The output file name is derived from ``model_path``. Everything after the
    first ``.`` in the *file name* is dropped and ``_model.cbm`` is appended,
    e.g. ``"runs/v1.2/model.pkl"`` -> ``"runs/v1.2/model_model.cbm"``. Dots in
    directory names are preserved.

    Parameters
    ----------
    model_path: str
        The path to save the model. Any object whose ``str()`` is a path
        (e.g. ``pathlib.Path``) is accepted.
    model: catboost.core._CatBoostBase
        The CatBoost model to save.

    Returns
    -------
    str
        The path where the model was saved.
    """
    directory, filename = os.path.split(str(model_path))
    # Strip the extension from the file name only. Splitting the whole path on
    # "." would truncate paths such as "./models/x.pkl" (-> "_model.cbm") or
    # "/home/a.b/x.pkl" (-> "/home/a_model.cbm").
    stem = filename.split(".")[0]
    output_model_path = os.path.join(directory, stem + "_model.cbm")
    model.save_model(output_model_path)
    return output_model_path
[docs]
def load(model_path: str, model: _CatBoostBase) -> _CatBoostBase:
    """
    Populate a CatBoost model instance from a file on disk.

    Parameters
    ----------
    model_path: str
        Location of the saved model file.
    model: _CatBoostBase
        A fresh CatBoost model instance; its ``load_model`` method is called
        with ``model_path`` and the instance itself is returned.

    Returns
    -------
    _CatBoostBase
        The same instance that was passed in, now holding the loaded model.
    """
    restored = model
    restored.load_model(model_path)
    return restored
[docs]
def train(
    model: _CatBoostBase, dataset: dict, is_classifier: bool, logger: CustomLogger
) -> _CatBoostBase:
    """
    Train a CatBoost model.

    Parameters
    ----------
    model: _CatBoostBase
        The CatBoost model to train.
    dataset: dict
        The dataset to train the model. It must provide an ``"emg"`` entry and,
        depending on ``is_classifier``, a ``"classes"`` entry (classifier) or a
        ``"kinematics"`` entry (regressor). Each entry is read with ``[()]``.
        The dataset is not modified.
    is_classifier: bool
        If the model is a classifier.
    logger: CustomLogger
        The logger to use. Its ``print`` method is passed to ``model.fit`` as
        both ``log_cout`` and ``log_cerr``.

    Returns
    -------
    _CatBoostBase
        The trained CatBoost model (the same instance, fitted in place).
    """
    x_train = _flatten_3d_to_2d(dataset["emg"][()])

    if is_classifier:
        y_train = dataset["classes"][()]
    else:
        # Work on a private floating-point copy of the targets:
        # - if the dataset holds plain NumPy arrays, ``[()]`` returns a view,
        #   so writing into it would silently corrupt the caller's dataset;
        # - with an integer dtype the injected noise would be truncated to 0.
        y_train = np.array(dataset["kinematics"][()], copy=True)
        if not np.issubdtype(y_train.dtype, np.floating):
            y_train = y_train.astype(np.float64)

        # Add small noise to zero targets to avoid numerical errors
        zero_mask = y_train == 0
        y_train[zero_mask] = np.random.uniform(0.0001, 0.001, zero_mask.sum())

    model.fit(x_train, y_train, log_cerr=logger.print, log_cout=logger.print)
    return model
[docs]
def predict(
    model: _CatBoostBase, input: np.ndarray, is_classifier: bool
) -> Union[np.array, list[float]]:
    """
    Predict with a CatBoost model.

    Parameters
    ----------
    model: _CatBoostBase
        The CatBoost model to predict with.
    input: np.ndarray
        The input data to predict. The shape of the input data will be
        (1, n_features, n_samples); 3-D input is flattened to 2-D before
        being passed to ``model.predict``.
    is_classifier: bool
        If the model is a classifier.

    Returns
    -------
    Union[np.array, list[float]]
        The prediction for the first input row. If the model is a classifier,
        this is the single predicted value (``prediction[0, 0]`` for 2-D
        output, ``prediction[0]`` otherwise). If the model is a regressor,
        this is a list of floats with one entry per output; a one-element list
        when ``model.predict`` returns a 1-D array.
    """
    input_2d = _flatten_3d_to_2d(input)
    prediction = model.predict(input_2d)

    if is_classifier:
        return _extract_classifier_prediction(prediction)

    # For a 1-D ``prediction`` (e.g. a single-output regressor) ``prediction[0]``
    # is a scalar, and ``list(scalar)`` raises TypeError. ``np.atleast_1d``
    # wraps it in a one-element array and leaves a 2-D row unchanged.
    return list(np.atleast_1d(prediction[0]))