Source code for myogestic.models.definitions.catboost_models
"""This module contains the functions to save, load and train CatBoost models."""importnumpyasnpfromcatboost.coreimport_CatBoostBasefrommyogestic.gui.widgets.loggerimportCustomLogger
def save(model_path: str, model: _CatBoostBase) -> str:
    """
    Save a CatBoost model.

    The output file lives in the same directory as ``model_path``. Its name is
    the basename of ``model_path`` truncated at its first dot, followed by
    ``"_model.cbm"``. For example, ``"./models/run.pkl"`` becomes
    ``"./models/run_model.cbm"``.

    Parameters
    ----------
    model_path: str
        The path to save the model.
    model: catboost.core._CatBoostBase
        The CatBoost model to save.

    Returns
    -------
    str
        The path where the model was saved.
    """
    import os

    # Only strip the extension from the file name. Splitting the whole path on
    # "." would also cut at dots in directory names (e.g. "./models/x.pkl" or
    # "/data/v1.2/x.pkl") and write the model to the wrong place.
    directory, file_name = os.path.split(model_path)
    stem = file_name.split(".")[0]
    output_model_path: str = os.path.join(directory, stem + "_model" + ".cbm")
    model.save_model(output_model_path)
    return output_model_path
def load(model_path: str, model: _CatBoostBase) -> _CatBoostBase:
    """
    Restore a persisted CatBoost model into an existing instance.

    Parameters
    ----------
    model_path: str
        Location of the serialized model file to read.
    model: _CatBoostBase
        A fresh CatBoost model instance; the stored model is loaded into it.

    Returns
    -------
    _CatBoostBase
        The same ``model`` instance that was passed in, now holding the loaded
        model.
    """
    # Whatever load_model returns is ignored; the caller gets back the very
    # instance it passed in.
    restored = model
    restored.load_model(model_path)
    return restored
def train(
    model: _CatBoostBase,
    x_train: np.ndarray,
    y_train: np.ndarray,
    logger: CustomLogger,
) -> _CatBoostBase:
    """
    Train a CatBoost model.

    If ``model`` is a ``CatBoostRegressor`` (or a subclass of it), every target
    value equal to 0 is replaced with uniform noise in ``[0.0001, 0.001)``
    before fitting. This is done on a floating-point copy of ``y_train``, so
    the caller's array is never modified and integer targets still receive the
    noise. Other model types receive ``y_train`` unchanged.

    Parameters
    ----------
    model: _CatBoostBase
        The CatBoost model to train.
    x_train: np.ndarray
        The training data.
    y_train: np.ndarray
        The training ground truth. Not modified by this function.
    logger: CustomLogger
        The logger to use. Its ``print`` method receives CatBoost's stdout
        and stderr output.

    Returns
    -------
    _CatBoostBase
        The trained CatBoost model (the same instance as ``model``).
    """
    # Compare by class name so catboost does not need to be imported here;
    # walk the MRO so subclasses of CatBoostRegressor are treated the same.
    is_regressor = any(
        klass.__name__ == "CatBoostRegressor" for klass in type(model).__mro__
    )
    if is_regressor:
        # Add small noise to zero targets to avoid errors. NOTE(review): the
        # exact CatBoost error this avoids is not visible here (presumably a
        # loss that rejects zero targets) - confirm.
        y_train = np.asarray(y_train)
        # Work on a copy so the caller's array is untouched. Use a float dtype:
        # assigning noise into an int array would truncate it back to 0.
        if np.issubdtype(y_train.dtype, np.floating):
            y_train = y_train.copy()
        else:
            y_train = y_train.astype(np.float64)
        zero_mask = y_train == 0
        y_train[zero_mask] = np.random.uniform(
            0.0001, 0.001, np.count_nonzero(zero_mask)
        )
    model.fit(x_train, y_train, log_cerr=logger.print, log_cout=logger.print)
    return model