Source code for doc_octopy.models.definitions.raul_net.online.v13
"""Model definition not used in any publication"""importpicklefrompathlibimportPathimporttorchfromtorchimportnnfromdoc_octopy.models.definitions.raul_net.online.v9_5_grids_compilableimportRaulNetV9_Compilable
[docs]classRaulNetV13(nn.Module):"""Wrapper to make one RaulNetV9_Compilable per finger Attributes ---------- learning_rate : float The learning rate. nr_of_input_channels : int The number of input channels. nr_of_outputs : int The number of outputs. cnn_encoder_channels : Tuple[int, int, int] Tuple containing 3 integers defining the cnn encoder channels. mlp_encoder_channels : Tuple[int, int] Tuple containing 2 integers defining the mlp encoder channels. event_search_kernel_length : int Integer that sets the length of the kernels searching for action potentials. event_search_kernel_stride : int Integer that sets the stride of the kernels searching for action potentials. """def__init__(self,models_path:Path):super(RaulNetV13,self).__init__()self.cnns=nn.ModuleList()self.mlps=nn.ModuleList()model_paths={}forname_file_pathinlist(models_path.rglob("mlflow.runName")):model_paths[name_file_path.open("r").read().split(" ")[-1]]=name_file_pathforname_file_pathin[model_paths["thumb"],model_paths["index"],model_paths["middle"],model_paths["ring"],model_paths["pinky"],]:temp=RaulNetV9_Compilable(**pickle.load((Path(name_file_path).parent.parent/"artifacts"/"model_hparams.pkl").open("rb")))temp.load_state_dict({k.replace("model._orig_mod.",""):vfork,vintorch.load(name_file_path.parent.parent/"artifacts"/"model.pt").items()})self.cnns.append(temp.cnn_encoder)self.mlps.append(temp.mlp_encoder)