from __future__ import annotations
import os
import pickle
from datetime import datetime
from functools import partial
from typing import TYPE_CHECKING
import numpy as np
from PySide6.QtCore import QObject, QThread, Qt, Signal
from PySide6.QtWidgets import (
QAbstractItemView,
QCheckBox,
QComboBox,
QDialog,
QDoubleSpinBox,
QFileDialog,
QHBoxLayout,
QHeaderView,
QLabel,
QLineEdit,
QMessageBox,
QPushButton,
QScrollArea,
QSpinBox,
QTableWidget,
QTableWidgetItem,
QVBoxLayout,
QWidget,
)
from myogestic.gui.widgets.logger import LoggerLevel
from myogestic.models.config import (
FEATURES_MAP,
FUNCTIONS_MAP,
MODELS_MAP,
PARAMETERS_MAP,
UnchangeableParameter,
)
from myogestic.models.interface import MyoGesticModelInterface
if TYPE_CHECKING:
from myogestic.gui.myogestic import MyoGestic
[docs]
class PyQtThread(QThread):
    """
    Custom PyQt thread class for running
    a target function in a separate thread.

    Parameters
    ----------
    target : function
        The target function to run in the thread. It is called without
        arguments.
    parent : QObject | None
        The parent object of the thread.

    Attributes
    ----------
    has_finished_signal : Signal
        Signal emitted once the target function has returned or raised.
    progress_bar_signal : Signal
        Signal for emitting the progress of the thread.
    """

    has_finished_signal = Signal()
    progress_bar_signal = Signal(int)

    def __init__(self, target, parent=None) -> None:
        super().__init__(parent)
        self.t = target

    def run(self):
        """Call the target function and always announce completion."""
        try:
            self.t()
        finally:
            # Emit even if the target raised: listeners that wait for this
            # signal (e.g. to re-enable UI controls) must not be left
            # waiting forever. The exception itself still propagates.
            self.has_finished_signal.emit()

    def quit(self) -> None:
        """Leave the thread's event loop with return code 0."""
        self.exit(0)
[docs]
class TrainingProtocol(QObject):
"""
Class for handling the training protocol of the MyoGestic application.
Parameters
----------
parent : MyoGestic | None
The parent object of the protocol object.
Attributes
----------
main_window : MyoGestic
The main window of the MyoGestic application.
model_interface : MyoGesticModelInterface
Interface for the models.
selected_recordings : dict[str, dict] | None
Selected recordings, keyed by the capitalized task name of each recording.
selected_dataset_filepath : str | None
File path of the selected dataset (None while no dataset is selected).
create_dataset_thread : PyQtThread
Thread for creating a dataset.
train_model_thread : PyQtThread
Thread for training a model.
recordings_dir_path : str
Path for the recordings directory.
models_dir_path : str
Path for the models directory.
datasets_dir_path : str
Path for the datasets directory.
no_dataset_selected_info : str
Information for no dataset selected.
"""
def __init__(self, parent: MyoGestic | None = ...) -> None:
    """
    Wire the protocol to the main window and prepare its working state.

    Parameters
    ----------
    parent : MyoGestic | None
        The main window; it is stored as ``main_window`` and must provide
        ``ui``, ``device_widget``, ``base_path`` and ``logger``.
    """
    super().__init__(parent)
    self.main_window = parent

    # The widgets are looked up and connected before any other state is set.
    self._setup_protocol_ui()

    # Nothing is selected until the user picks recordings / a dataset.
    self.selected_recordings: dict[str, dict] | None = None
    self.selected_dataset_filepath: str | None = None

    # The model interface is only built once the device is configured.
    self.model_interface = None
    self.selected_model_name = "CatBoost Classifier"
    model_entry = MODELS_MAP[self.selected_model_name]
    self.selected_model, self.model_is_classifier = model_entry
    self.model_changeable_parameters = {}
    self.selected_features = ["Root Mean Square"]

    # Get notified about device configuration changes.
    device_widget = self.main_window.device_widget
    device_widget.configure_toggled.connect(self._update_device_configuration)

    # Worker threads are created on demand.
    self.create_dataset_thread = None
    self.train_model_thread = None

    # File management: one sub-directory per artifact type.
    self.recordings_dir_path: str = os.path.join(
        self.main_window.base_path, "recordings"
    )
    self.models_dir_path: str = os.path.join(self.main_window.base_path, "models")
    self.datasets_dir_path: str = os.path.join(
        self.main_window.base_path, "datasets"
    )
    for directory in (
        self.recordings_dir_path,
        self.models_dir_path,
        self.datasets_dir_path,
    ):
        if not os.path.exists(directory):
            os.makedirs(directory)
def _update_device_configuration(self, is_configured: bool) -> None:
    """
    Rebuild the model interface when the device reports being configured.

    Parameters
    ----------
    is_configured : bool
        Value of the device widget's ``configure_toggled`` signal. Nothing
        happens when it is False.
    """
    if is_configured:
        device_information = self.main_window.device_widget.get_device_information()
        self.model_interface = MyoGesticModelInterface(
            device_information=device_information,
            logger=self.main_window.logger,
        )
def _check_if_recordings_selected(self) -> None:
    """Enable the "remove selected" button only while table items are selected."""
    table = self.training_create_dataset_selected_recordings_table_widget
    has_selection = bool(table.selectedItems())
    self.training_remove_selected_recording_push_button.setEnabled(has_selection)
def _remove_selected_recording(self) -> None:
    """
    Remove the selected rows from the table and from the selection.

    The first column of every row holds the key under which the recording
    is stored in ``selected_recordings``, so that entry is dropped as well;
    otherwise a removed recording would still end up in the dataset.
    """
    table = self.training_create_dataset_selected_recordings_table_widget

    # Rows must be removed bottom-up so that the remaining indices stay
    # valid. A set has no guaranteed iteration order, hence the explicit sort.
    rows = sorted({index.row() for index in table.selectedIndexes()}, reverse=True)
    for row in rows:
        task_item = table.item(row, 0)
        if task_item is not None and self.selected_recordings:
            self.selected_recordings.pop(task_item.text(), None)
        table.removeRow(row)

    if not self.selected_recordings:
        # Nothing left to build a dataset from.
        self.training_create_dataset_push_button.setEnabled(False)
    self._check_if_recordings_selected()
def _remove_all_selected_recordings(self) -> None:
    """
    Clear the recordings table together with the underlying selection.

    The dataset is built from ``selected_recordings``, not from the table,
    so both are reset and dataset creation is disabled until new
    recordings are selected.
    """
    self.training_create_dataset_selected_recordings_table_widget.setRowCount(0)
    self.selected_recordings = {}
    self.training_create_dataset_push_button.setEnabled(False)
    self._check_if_recordings_selected()
def select_recordings(self) -> None:
    """
    Let the user pick recording files and list the valid ones in the table.

    Recordings that share the same task are merged into a single entry.
    Dataset creation is enabled only if at least one valid recording was
    selected; otherwise a warning dialog is shown.
    """
    self.training_create_dataset_push_button.setEnabled(False)

    # getOpenFileNames is a static convenience function that builds its own
    # dialog, so parent, start directory and name filter have to be passed
    # as arguments; settings made on a QFileDialog instance are not used.
    filenames, _ = QFileDialog.getOpenFileNames(
        self.main_window,
        "",
        self.recordings_dir_path,
        "Pickle files (*.pkl)",
    )

    self.selected_recordings = {}
    table = self.training_create_dataset_selected_recordings_table_widget
    table.setRowCount(0)

    for file in filenames:
        recording = self._load_recording(file)
        if recording is None:
            continue
        task = recording["task"].capitalize()
        previous = self.selected_recordings.get(task)
        if previous is not None:
            recording = self._merge_recordings(previous, recording)
        self.selected_recordings[task] = recording

    if not self.selected_recordings:
        self._open_warning_dialog("No valid recordings selected!")
        return

    for task, recording in self.selected_recordings.items():
        row = table.rowCount()
        table.insertRow(row)
        # Duration in whole seconds: samples along axis 1 / sampling rate.
        recording_time = int(
            recording["emg"].shape[1] / self.main_window.sampling_frequency
        )
        table.setItem(row, 0, QTableWidgetItem(task))
        table.setItem(row, 1, QTableWidgetItem(str(recording_time)))
        table.setItem(row, 2, QTableWidgetItem(recording["label"]))

    self.training_create_dataset_push_button.setEnabled(True)

def _load_recording(self, file: str) -> dict | None:
    """Load one recording file; return None (after logging) if it is unusable."""
    required_keys = (
        "emg",
        "kinematics",
        "timings_kinematics",
        "timings_emg",
        "label",
        "task",
        "device",
        "bad_channels",
    )
    logger = self.main_window.logger

    try:
        # NOTE(security): unpickling can execute arbitrary code; only open
        # recording files that come from a trusted source.
        with open(file, "rb") as f:
            recording = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError) as error:
        logger.print(f" {file} could not be loaded! \n {error}", LoggerLevel.ERROR)
        return None

    # Check the type first: truth-testing an arbitrary unpickled object
    # (e.g. a numpy array) may raise.
    if not isinstance(recording, dict) or not recording:
        return None

    missing_keys = set(required_keys) - set(recording)
    if missing_keys:
        logger.print(
            f" {file} is an invalid recording! \n Missing keys: {missing_keys}",
            LoggerLevel.ERROR,
        )
        return None

    if recording["device"] != self.main_window.device_name:
        logger.print(
            f" {file} is not recorded with the current device! \n Please select {recording['device']}.",
            LoggerLevel.ERROR,
        )
        return None

    return recording

@staticmethod
def _merge_recordings(current: dict, recording: dict) -> dict:
    """Append ``recording`` to the already selected ``current`` one of the same task."""
    recording["emg"] = np.hstack([current["emg"], recording["emg"]])
    recording["kinematics"] = np.concatenate(
        [current["kinematics"], recording["kinematics"]], axis=-1
    )
    recording["timings_kinematics"] = np.concatenate(
        [current["timings_kinematics"], recording["timings_kinematics"]]
    )
    recording["timings_emg"] = np.concatenate(
        [current["timings_emg"], recording["timings_emg"]]
    )
    recording["label"] += "; " + current["label"]
    # A channel is bad for the merged recording if it is bad in either part.
    recording["bad_channels"] = list(
        set(recording["bad_channels"]) | set(current["bad_channels"])
    )
    # "use_kinematics" is not among the validated keys, so it may be absent.
    # NOTE(review): a missing flag is treated as False here - confirm that
    # this matches what dataset creation expects.
    recording["use_kinematics"] = recording.get(
        "use_kinematics", False
    ) and current.get("use_kinematics", False)
    return recording
def _create_dataset(self) -> None:
    """
    Start dataset creation from the selected recordings in a worker thread.

    Nothing is started (a warning dialog is shown instead) when there are
    no selected recordings or when no model interface exists yet.
    """
    if not self.selected_recordings:
        self._open_warning_dialog("No recordings selected!")
        return
    if self.model_interface is None:
        # The interface is only created once the device has been configured;
        # without it the worker would fail and the UI would stay locked.
        self._open_warning_dialog("Please configure the device first!")
        return

    # Lock the controls while the worker is running.
    self.training_create_dataset_push_button.setEnabled(False)
    self.training_create_datasets_select_recordings_push_button.setEnabled(False)

    self.create_dataset_thread = PyQtThread(
        target=self._create_dataset_thread, parent=self.main_window
    )
    self.create_dataset_thread.has_finished_signal.connect(
        self.__create_dataset_thread_finished
    )
    self.create_dataset_thread.start()
def __create_dataset_thread_finished(self) -> None:
    """Reset the dataset-creation widgets and forget the used recordings."""
    table = self.training_create_dataset_selected_recordings_table_widget
    table.setRowCount(0)
    self.training_create_dataset_label_line_edit.setText("")
    # Only recording selection is unlocked again; the recordings are cleared.
    select_button = self.training_create_datasets_select_recordings_push_button
    select_button.setEnabled(True)
    self.selected_recordings = None
def _create_dataset_thread(self) -> None:
    """
    Build the dataset from the selected recordings and pickle it to disk.

    The file is written to the datasets directory; its name contains a
    timestamp and the lower-cased label from the line edit ("default" if
    the line edit is empty). The path is also stored in the dataset itself.
    """
    label = self.training_create_dataset_label_line_edit.text() or "default"
    formatted_now = datetime.now().strftime("%Y%m%d_%H%M%S%f")

    dataset_dict = self.model_interface.create_dataset(
        self.selected_recordings, self.selected_features
    )

    file_name = f"MindMove_Dataset_{formatted_now}_{label.lower()}.pkl"
    dataset_file_path = os.path.join(self.datasets_dir_path, file_name)
    dataset_dict["dataset_file_path"] = dataset_file_path

    with open(dataset_file_path, "wb") as f:
        pickle.dump(dataset_dict, f)

    self.main_window.logger.print(f"Dataset {file_name} created!", LoggerLevel.INFO)
def _select_dataset(self) -> None:
    """
    Let the user pick a dataset file and enable training for it.

    If the dialog is cancelled, any previous selection is dropped so that
    the internal state matches the "no dataset selected" label.
    """
    # getOpenFileName is a static convenience function that builds its own
    # dialog, so parent, start directory and name filter have to be passed
    # as arguments; settings made on a QFileDialog instance are not used.
    filename, _ = QFileDialog.getOpenFileName(
        self.main_window,
        "",
        self.datasets_dir_path,
        "Pickle files (*.pkl)",
    )

    if not filename:
        self.selected_dataset_filepath = None
        self.train_model_push_button.setEnabled(False)
        self._open_warning_dialog(self.no_dataset_selected_info)
        self.training_selected_dataset_label.setText(self.no_dataset_selected_info)
        return

    self.selected_dataset_filepath = filename

    # Show the part after the last underscore of the file name (not of the
    # whole path, whose directories may contain underscores as well).
    stem = os.path.splitext(os.path.basename(filename))[0]
    self.training_selected_dataset_label.setText(stem.split("_")[-1].title())

    self.train_model_push_button.setEnabled(True)
def _open_warning_dialog(self, info: str) -> None:
    """
    Show a warning message box on top of the main window.

    Parameters
    ----------
    info : str
        The text displayed in the message box.
    """
    title = "Warning"
    buttons = QMessageBox.Ok
    QMessageBox.warning(self.main_window, title, info, buttons)
def _train_model(self) -> None:
    """Start model training on the selected dataset in a worker thread."""
    if not self.selected_dataset_filepath:
        self._open_warning_dialog(self.no_dataset_selected_info)
        return

    # Lock the controls while the worker is running.
    for button in (
        self.train_model_push_button,
        self.training_select_dataset_push_button,
    ):
        button.setEnabled(False)

    worker = PyQtThread(target=self._train_model_thread, parent=self.main_window)
    self.train_model_thread = worker
    worker.has_finished_signal.connect(self._train_model_finished)
    worker.start()
def _train_model_thread(self) -> None:
    """
    Train the selected model on the selected dataset and save it to disk.

    Errors raised while training or while saving are logged and end the
    run early. The model file is written to the models directory; its name
    contains a timestamp and the lower-cased label from the line edit
    ("default" if the line edit is empty).
    """
    label = self.training_model_label_line_edit.text() or "default"

    assert self.selected_dataset_filepath is not None

    with open(self.selected_dataset_filepath, "rb") as file:
        dataset = pickle.load(file)

    # Fall back to the default values when no parameters are set yet.
    if len(self.model_changeable_parameters) == 0:
        changeable = PARAMETERS_MAP[self.selected_model_name]["changeable"]
        self.model_changeable_parameters = {
            name: spec["default_value"] for name, spec in changeable.items()
        }

    try:
        train = self.model_interface.train_model
        # Changeable parameters take precedence over the unchangeable ones.
        model_parameters = dict(
            PARAMETERS_MAP[self.selected_model_name]["unchangeable"]
        )
        model_parameters.update(self.model_changeable_parameters)
        train(
            dataset=dataset,
            model_name=self.selected_model_name,
            model_parameters=model_parameters,
            selected_features=self.selected_features,
            **FUNCTIONS_MAP[self.selected_model_name],
        )
    except Exception as e:
        self.main_window.logger.print(
            f"Error during training: {e}", LoggerLevel.ERROR
        )
        return

    formatted_now = datetime.now().strftime("%Y%m%d_%H%M%S%f")
    file_name = f"MindMove_Model_{formatted_now}_{label.lower()}.pkl"
    model_filepath = os.path.join(self.models_dir_path, file_name)

    try:
        model_save_dict = self.model_interface.save_model(model_filepath)
    except Exception as e:
        self.main_window.logger.print(
            f"Error during saving models: {e}", LoggerLevel.ERROR
        )
        return

    with open(model_filepath, "wb") as file:
        pickle.dump(model_save_dict, file)

    self.main_window.logger.print(f"Model trained: {file_name}", LoggerLevel.INFO)
def _train_model_finished(self) -> None:
    """Reset the training widgets and forget the used dataset."""
    info = self.no_dataset_selected_info
    self.training_selected_dataset_label.setText(info)
    # Only dataset selection is unlocked again; the dataset path is cleared.
    self.training_select_dataset_push_button.setEnabled(True)
    self.selected_dataset_filepath = None
    self.training_model_label_line_edit.setText("")
def _update_model_selection(self) -> None:
    """
    Adopt the model chosen in the combo box.

    The changeable parameters are reset to the defaults of that model, and
    the parameters button is enabled only if there is at least one of them.
    """
    model_name = self.training_model_selection_combo_box.currentText()
    self.selected_model_name = model_name
    self.selected_model, self.model_is_classifier = MODELS_MAP[model_name]

    defaults = {}
    for key, value in PARAMETERS_MAP[model_name]["changeable"].items():
        defaults[key] = value["default_value"]
    self.model_changeable_parameters = defaults

    self.training_model_parameters_push_button.setEnabled(len(defaults) > 0)
def _open_model_parameters_popup(self) -> None:
    """Show the parameter popup for the selected model and read it back when it finishes."""
    popup = PopupWindowParameters(self.selected_model_name)
    # Keep the popup referenced on the instance; the result handler reads it.
    self.popup_window = popup
    popup.show()
    popup.finished.connect(self._get_model_parameters)
def _get_model_parameters(self) -> None:
    """Take over the changeable parameters held by the popup window."""
    popup = self.popup_window
    self.model_changeable_parameters = popup.model_changeable_parameters
def _open_feature_selection_popup(self) -> None:
    """Show the feature popup for the current features and read it back when it finishes."""
    popup = PopupWindowFeatures(self.selected_features)
    # Keep the popup referenced on the instance; the result handler reads it.
    self.popup_window = popup
    popup.show()
    popup.finished.connect(self._get_features)
def _get_features(self) -> None:
    """Store a list copy of the features selected in the popup window."""
    chosen_features = self.popup_window.selected_features
    self.selected_features = list(chosen_features)
def _setup_protocol_ui(self) -> None:
    """
    Look up the training widgets on the main window UI and connect them.

    Every widget is stored as an attribute of this object, its signals are
    connected to the matching handlers and its initial state is set.
    """
    ui = self.main_window.ui

    # ----- Create datasets -----
    self.training_create_dataset_group_box = ui.trainingCreateDatasetGroupBox

    self.training_create_datasets_select_recordings_push_button = (
        ui.trainingCreateDatasetsSelectRecordingsPushButton
    )
    self.training_create_datasets_select_recordings_push_button.clicked.connect(
        self.select_recordings
    )

    # Table listing the selected recordings; whole rows are selected.
    self.training_create_dataset_selected_recordings_table_widget: QTableWidget = (
        ui.trainingCreateDatasetSelectedRecordingsTableWidget
    )
    table = self.training_create_dataset_selected_recordings_table_widget
    table.setSelectionBehavior(QAbstractItemView.SelectRows)
    table.itemSelectionChanged.connect(self._check_if_recordings_selected)
    table.verticalHeader().setVisible(False)

    # First column stretches, the other two shrink to their contents.
    header = table.horizontalHeader()
    header.setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
    header.setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
    header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)

    # Start with an empty table.
    table.setRowCount(0)

    self.training_remove_selected_recording_push_button: QPushButton = (
        ui.trainingRemoveSelectedRecordingPushButton
    )
    self.training_remove_selected_recording_push_button.clicked.connect(
        self._remove_selected_recording
    )
    self.training_remove_selected_recording_push_button.setEnabled(False)

    self.training_remove_all_selected_recordings_push_button: QPushButton = (
        ui.trainingRemoveAllSelectedRecordingsPushButton
    )
    self.training_remove_all_selected_recordings_push_button.clicked.connect(
        self._remove_all_selected_recordings
    )

    self.training_create_dataset_select_features_push_button = (
        ui.trainingCreateDatasetSelectFeaturesPushButton
    )
    self.training_create_dataset_select_features_push_button.clicked.connect(
        self._open_feature_selection_popup
    )

    self.training_create_dataset_label_line_edit = (
        ui.trainingCreateDatasetLabelLineEdit
    )

    # Dataset creation stays disabled until recordings are selected.
    self.training_create_dataset_push_button = ui.trainingCreateDatasetPushButton
    self.training_create_dataset_push_button.clicked.connect(self._create_dataset)
    self.training_create_dataset_push_button.setEnabled(False)

    # ----- Train model -----
    self.training_train_model_group_box = ui.trainingTrainModelGroupBox

    self.training_select_dataset_push_button = ui.trainingSelectDatasetPushButton
    self.training_select_dataset_push_button.clicked.connect(self._select_dataset)

    self.training_selected_dataset_label = ui.trainingSelectedDatasetLabel
    self.no_dataset_selected_info = "No dataset selected!"
    self.training_selected_dataset_label.setText(self.no_dataset_selected_info)

    # The combo box is filled before its change signal is connected.
    self.training_model_selection_combo_box = ui.trainingModelSelectionComboBox
    self.training_model_selection_combo_box.addItems(list(MODELS_MAP.keys()))
    self.training_model_selection_combo_box.currentIndexChanged.connect(
        self._update_model_selection
    )
    self.training_model_selection_combo_box.setCurrentIndex(0)

    self.training_model_label_line_edit = ui.trainingModelLabelLineEdit

    # The parameters button opens the model parameters popup.
    self.training_model_parameters_push_button = (
        ui.trainingModelParametersPushButton
    )
    self.training_model_parameters_push_button.clicked.connect(
        self._open_model_parameters_popup
    )

    # Training stays disabled until a dataset is selected.
    self.train_model_push_button = ui.trainingTrainModelPushButton
    self.train_model_push_button.clicked.connect(self._train_model)
    self.train_model_push_button.setEnabled(False)