"""Model definition not used in any publication"""
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import lightning as L
import torch
import torch.nn as nn
import torch.optim as optim
class RaulNetV16(L.LightningModule):
"""Model definition not used in any publication
Attributes
----------
    learning_rate : float
        The learning rate.
    nr_of_input_channels : int
        The number of input channels.
    input_length__samples : int
        The temporal length of the input, in samples.
    nr_of_outputs : int
        The number of outputs.
    cnn_encoder_channels : Tuple[int, int, int]
        Tuple containing 3 integers defining the CNN encoder channels.
    mlp_encoder_channels : Tuple[int, int]
        Tuple containing 2 integers defining the MLP encoder channels.
    event_search_kernel_length : int
        Length of the kernels searching for action potentials.
    event_search_kernel_stride : int
        Stride of the kernels searching for action potentials.
    nr_of_electrode_grids : int
        The number of electrode grids. Defaults to 3.
    nr_of_electrodes_per_grid : int
        The number of electrodes per grid. Defaults to 36.
    inference_only : bool
        If True, the model is intended for inference only. Defaults to False.
    apply_softmax : bool
        If True, a softmax is applied to the network output in ``forward``. Defaults to False.
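
    Examples
    --------
    A minimal smoke test; the hyperparameter values below are illustrative
    assumptions only (this model definition is not tied to any publication):

    >>> model = RaulNetV16(
    ...     learning_rate=1e-4, nr_of_input_channels=1, input_length__samples=192,
    ...     nr_of_outputs=5, cnn_encoder_channels=(4, 8, 8),
    ...     mlp_encoder_channels=(32, 16), event_search_kernel_length=9,
    ...     event_search_kernel_stride=5,
    ... )
    >>> x = torch.rand(2, 1, 3 * 36, 192)  # (batch, channels, grids * electrodes, samples)
    >>> model(x).shape
    torch.Size([2, 5])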
"""
def __init__(
self,
learning_rate: float,
nr_of_input_channels: int,
input_length__samples: int,
nr_of_outputs: int,
cnn_encoder_channels: Tuple[int, int, int],
mlp_encoder_channels: Tuple[int, int],
event_search_kernel_length: int,
event_search_kernel_stride: int,
nr_of_electrode_grids: int = 3,
nr_of_electrodes_per_grid: int = 36,
inference_only: bool = False,
apply_softmax: bool = False,
):
        super().__init__()
self.save_hyperparameters()
self.learning_rate = learning_rate
self.nr_of_input_channels = nr_of_input_channels
self.nr_of_outputs = nr_of_outputs
self.input_length__samples = input_length__samples
self.cnn_encoder_channels = cnn_encoder_channels
self.mlp_encoder_channels = mlp_encoder_channels
self.event_search_kernel_length = event_search_kernel_length
self.event_search_kernel_stride = event_search_kernel_stride
self.nr_of_electrode_grids = nr_of_electrode_grids
self.nr_of_electrodes_per_grid = nr_of_electrodes_per_grid
self.inference_only = inference_only
self.apply_softmax = apply_softmax
self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
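        # CNN encoder over 5-D tensors laid out as
        # (batch, channels, electrode_grids, electrodes_per_grid, time_samples);
        # the first convolution scans the time axis for action potentials.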
self.cnn_encoder = nn.Sequential(
nn.Conv3d(
self.nr_of_input_channels,
self.cnn_encoder_channels[0],
kernel_size=(1, 1, self.event_search_kernel_length),
stride=(1, 1, self.event_search_kernel_stride),
groups=self.nr_of_input_channels,
),
nn.GELU(approximate="tanh"),
nn.InstanceNorm3d(self.cnn_encoder_channels[0]),
nn.Dropout3d(p=0.20),
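            # Mix information across grids and electrodes; circular padding
            # treats the two spatial axes as wrapping around.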
nn.Conv3d(
self.cnn_encoder_channels[0],
self.cnn_encoder_channels[1],
kernel_size=(
self.nr_of_electrode_grids,
int(np.floor(self.nr_of_electrodes_per_grid / 2)),
18,
),
dilation=(1, 2, 1),
padding=(
int(np.floor(self.nr_of_electrode_grids / 2)),
int(np.floor(self.nr_of_electrodes_per_grid / 4)),
0,
),
padding_mode="circular",
),
nn.GELU(approximate="tanh"),
nn.InstanceNorm3d(self.cnn_encoder_channels[1]),
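            # Collapse the grid axis (the kernel spans all grids, no padding)
            # and a coarse electrode neighbourhood before flattening.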
nn.Conv3d(
self.cnn_encoder_channels[1],
self.cnn_encoder_channels[2],
kernel_size=(
self.nr_of_electrode_grids,
int(np.floor(self.nr_of_electrodes_per_grid / 7)),
1,
),
),
nn.GELU(approximate="tanh"),
nn.InstanceNorm3d(self.cnn_encoder_channels[2]),
nn.Flatten(),
)
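        # Infer the flattened CNN feature size by pushing one dummy batch
        # through the encoder at construction time.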
self.mlp = nn.Sequential(
nn.Linear(
self.cnn_encoder(
torch.rand(
(
1,
self.nr_of_input_channels,
self.nr_of_electrode_grids,
self.nr_of_electrodes_per_grid,
self.input_length__samples,
)
)
)
.detach()
.shape[1],
self.mlp_encoder_channels[0],
),
nn.GELU(approximate="tanh"),
nn.Linear(self.mlp_encoder_channels[0], self.mlp_encoder_channels[1]),
nn.GELU(approximate="tanh"),
nn.Linear(self.mlp_encoder_channels[1], self.nr_of_outputs),
)
self.softmax = nn.Softmax(dim=-1)
    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
x = self._reshape_and_normalize(inputs)
x = self.cnn_encoder(x)
x = self.mlp(x)
if self.apply_softmax:
x = self.softmax(x)
return x
def _reshape_and_normalize(self, inputs):
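        # Split the flat electrode axis into (grids, electrodes_per_grid) and
        # z-score each grid over its electrode and time dimensions.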
x = torch.stack(inputs.split(self.nr_of_electrodes_per_grid, dim=2), dim=2)
return (x - x.mean(dim=(3, 4), keepdim=True)) / (
x.std(dim=(3, 4), keepdim=True, unbiased=True) + 1e-15
)
def _reshape_and_normalize_v2(self, inputs):
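        # Robust variant: centre with nested per-axis medians and scale with an
        # IQR-style spread built from nested per-axis quantiles.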
x = torch.stack(inputs.split(self.nr_of_electrodes_per_grid, dim=2), dim=2)
return (x - x.median(dim=4, keepdim=True)[0].median(dim=3, keepdim=True)[0]) / (
torch.quantile(
torch.quantile(x, q=0.75, dim=4, keepdim=True),
q=0.75,
dim=3,
keepdim=True,
)
- torch.quantile(
torch.quantile(x, q=0.25, dim=4, keepdim=True),
q=0.25,
dim=3,
keepdim=True,
)
+ 1e-15
)
def _reshape_and_normalize_v3(self, inputs):
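        # Signed-log variant: compress magnitudes with log1p (sign preserved),
        # then apply the same per-grid standardization as _reshape_and_normalize.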
x = torch.stack(inputs.split(self.nr_of_electrodes_per_grid, dim=2), dim=2)
positive_mask = x >= 0
negative_mask = x < 0
positive_data = torch.log1p(x[positive_mask] + 1e-15)
negative_data = -torch.log1p(-x[negative_mask] + 1e-15)
x[positive_mask] = positive_data
x[negative_mask] = negative_data
return (x - x.mean(dim=(3, 4), keepdim=True)) / (
x.std(dim=(3, 4), keepdim=True, unbiased=True) + 1e-15
)
def training_step(
self, train_batch, batch_idx: int
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
inputs, ground_truths = train_batch
prediction = self(inputs)
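        # Targets are class probabilities (soft labels), hence the float cast
        # before CrossEntropyLoss.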
scores_dict = {
"loss": self.criterion(prediction, ground_truths.to(float)),
}
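        # Returning None makes Lightning skip the optimization step for this batch.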
if scores_dict["loss"].isnan().item():
return None
self.log_dict(
scores_dict, prog_bar=True, logger=False, on_epoch=True, sync_dist=True
)
self.log_dict(
{f"train/{k}": v for k, v in scores_dict.items()},
prog_bar=False,
logger=True,
on_epoch=True,
on_step=False,
sync_dist=True,
)
return scores_dict
def validation_step(
self, batch, batch_idx
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
inputs, ground_truths = batch
prediction = self(inputs)
scores_dict = {
"val_loss": self.criterion(prediction, ground_truths.to(float)),
}
self.log_dict(
scores_dict, prog_bar=True, logger=False, on_epoch=True, sync_dist=True
)
return scores_dict
def test_step(
self, batch, batch_idx
) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
inputs, ground_truths = batch
prediction = self(inputs)
scores_dict = {"loss": self.criterion(prediction, ground_truths.to(float))}
self.log_dict(
scores_dict, prog_bar=True, logger=False, on_epoch=True, sync_dist=True
)
self.log_dict(
{f"test/{k}": v for k, v in scores_dict.items()},
prog_bar=False,
logger=True,
on_epoch=False,
on_step=True,
sync_dist=True,
)
return scores_dict
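

# A minimal, self-contained smoke test: the hyperparameters below are
# illustrative assumptions, not values from any publication or experiment.
if __name__ == "__main__":
    model = RaulNetV16(
        learning_rate=1e-4,
        nr_of_input_channels=1,
        input_length__samples=192,
        nr_of_outputs=5,
        cnn_encoder_channels=(4, 8, 8),
        mlp_encoder_channels=(32, 16),
        event_search_kernel_length=9,
        event_search_kernel_stride=5,
    )
    inputs = torch.rand(2, 1, 3 * 36, 192)  # (batch, channels, grids * electrodes, samples)
    logits = model(inputs)
    targets = torch.rand(2, 5).softmax(dim=-1)  # illustrative soft labels
    print(logits.shape, model.criterion(logits, targets).item())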