# Source code for doc_octopy.models.definitions.raul_net.online.v15

"""Model definition not used in any publication"""
import sys
from functools import reduce
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

try:
    from bayesian_torch.layers import get_kernel_size
except ImportError:
    sys.exit(
        """You need bayesian_torch!
                install it from https://github.com/IntelLabs/bayesian-torch
                or run pip install bayesian-torch."""
    )


class BaseVariationalLayer_(nn.Module):
    """Minimal base class for the variational (Bayesian) layers of this module.

    It only provides the closed-form KL divergence between two Gaussians,
    which subclasses use to build their complexity cost.
    """

    def __init__(self):
        super().__init__()

    def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
        """Calculates the KL divergence KL(Q || P) between two Gaussians.

        The element-wise divergence
        ``log(sigma_p / sigma_q) + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * sigma_p**2) - 0.5``
        is averaged over all elements.

        Parameters:
            * mu_q: torch.Tensor -> mu parameter of distribution Q
            * sigma_q: torch.Tensor -> sigma parameter of distribution Q
            * mu_p: float or torch.Tensor -> mu parameter of distribution P
            * sigma_p: float or torch.Tensor -> sigma parameter of distribution P

        Returns:
            torch.Tensor: a 0-dimensional tensor (mean over all elements).
        """
        # torch.log only accepts tensors; promote a python scalar prior sigma
        # so the documented "float" form of sigma_p actually works.
        if not torch.is_tensor(sigma_p):
            sigma_p = torch.tensor(
                sigma_p, dtype=sigma_q.dtype, device=sigma_q.device
            )
        kl = (
            torch.log(sigma_p)
            - torch.log(sigma_q)
            + (sigma_q**2 + (mu_q - mu_p) ** 2) / (2 * (sigma_p**2))
            - 0.5
        )
        return kl.mean()
class Conv3dFlipout(BaseVariationalLayer_):
    """Bayesian 3D convolution trained with the Flipout estimator.

    Every kernel and bias element has a Gaussian posterior with mean ``mu`` and
    standard deviation ``softplus(rho)``. ``forward`` returns the convolution
    with the mean weights, plus a perturbation term.

    The perturbation uses one weight-noise sample per call (``eps_kernel`` /
    ``eps_bias``), decorrelated across elements by random +/-1 sign flips on
    the input and on the output. ``forward`` also returns the KL divergence
    between the posterior and the prior.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        prior_mean=0,
        prior_variance=1,
        posterior_mu_init=0,
        posterior_rho_init=-3.0,
    ):
        """Implements Conv3d layer with Flipout reparameterization trick.

        Inherits from the ``BaseVariationalLayer_`` defined in this module.

        Parameters:
            in_channels: int -> number of channels in the input image,
            out_channels: int -> number of channels produced by the convolution,
            kernel_size: int or tuple -> size of the convolving filter
                (expanded to 3 dimensions via ``get_kernel_size``),
            stride: int -> stride of the convolution. Default: 1,
            padding: int -> zero-padding added to both sides of the input. Default: 0,
            dilation: int -> spacing between filter elements. Default: 1,
            groups: int -> number of blocked connections from input channels to output channels,
            prior_mean: float -> mean of the prior distribution used in the complexity cost,
            prior_variance: float -> value the prior sigma buffers are filled with.
                NOTE: it is used directly as the prior *standard deviation*
                (no square root is taken), despite its name,
            posterior_mu_init: float -> init trainable mu parameter representing mean of the approximate posterior,
            posterior_rho_init: float -> init trainable rho parameter representing the sigma
                of the approximate posterior through the softplus function,
        """
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.prior_mean = prior_mean
        self.prior_variance = prior_variance
        self.posterior_mu_init = posterior_mu_init
        self.posterior_rho_init = posterior_rho_init

        kernel_size = get_kernel_size(kernel_size, 3)
        kernel_shape = (
            out_channels,
            in_channels // groups,
            kernel_size[0],
            kernel_size[1],
            kernel_size[2],
        )

        # Registration order matches the original (affects parameters() order
        # and state_dict layout).
        self.mu_kernel = nn.Parameter(torch.empty(kernel_shape))
        self.rho_kernel = nn.Parameter(torch.empty(kernel_shape))
        self.register_buffer("eps_kernel", torch.empty(kernel_shape), persistent=False)
        self.register_buffer(
            "prior_weight_mu", torch.empty(kernel_shape), persistent=False
        )
        self.register_buffer(
            "prior_weight_sigma", torch.empty(kernel_shape), persistent=False
        )

        self.mu_bias = nn.Parameter(torch.empty(out_channels))
        self.rho_bias = nn.Parameter(torch.empty(out_channels))
        self.register_buffer("eps_bias", torch.empty(out_channels), persistent=False)
        self.register_buffer(
            "prior_bias_mu", torch.empty(out_channels), persistent=False
        )
        self.register_buffer(
            "prior_bias_sigma", torch.empty(out_channels), persistent=False
        )

        self.init_parameters()

    def init_parameters(self):
        """Fill the prior buffers with constants and sample the posterior parameters."""
        # prior values
        self.prior_weight_mu.data.fill_(self.prior_mean)
        self.prior_weight_sigma.data.fill_(self.prior_variance)

        # init our weights for the deterministic and perturbated weights
        self.mu_kernel.data.normal_(mean=self.posterior_mu_init, std=0.1)
        self.rho_kernel.data.normal_(mean=self.posterior_rho_init, std=0.1)
        self.mu_bias.data.normal_(mean=self.posterior_mu_init, std=0.1)
        self.rho_bias.data.normal_(mean=self.posterior_rho_init, std=0.1)
        self.prior_bias_mu.data.fill_(self.prior_mean)
        self.prior_bias_sigma.data.fill_(self.prior_variance)

    def _kl_from_sigmas(self, sigma_weight, sigma_bias):
        """Sum of the kernel and bias KL terms for the given posterior sigmas."""
        kl = self.kl_div(
            self.mu_kernel, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma
        )
        return kl + self.kl_div(
            self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma
        )

    @staticmethod
    def _rademacher_like(tensor):
        """Tensor shaped like ``tensor`` whose entries are exactly -1 or +1."""
        # sign(rand * 2 - 1) can produce 0 when rand returns exactly 0.5,
        # which would silently drop that element's perturbation.
        return torch.empty_like(tensor).bernoulli_(0.5).mul_(2).sub_(1)

    def kl_loss(self):
        """Return the KL divergence between the posterior and the prior."""
        # softplus is the numerically stable form of log1p(exp(rho)).
        return self._kl_from_sigmas(F.softplus(self.rho_kernel), F.softplus(self.rho_bias))

    def forward(self, x):
        """Apply the Flipout convolution.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: the perturbed convolution output
            and the KL divergence (same quantity as ``kl_loss()``).
        """
        # linear outputs with the posterior means
        outputs = F.conv3d(
            x,
            weight=self.mu_kernel,
            bias=self.mu_bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

        # sampling perturbation signs
        sign_input = self._rademacher_like(x)
        sign_output = self._rademacher_like(outputs)

        # perturbation weights: one shared noise sample per forward call
        sigma_weight = F.softplus(self.rho_kernel)
        sigma_bias = F.softplus(self.rho_bias)
        delta_kernel = sigma_weight * self.eps_kernel.data.normal_()
        delta_bias = sigma_bias * self.eps_bias.data.normal_()

        # perturbed feedforward
        perturbed_outputs = (
            F.conv3d(
                x * sign_input,
                weight=delta_kernel,
                bias=delta_bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
            )
            * sign_output
        )

        return outputs + perturbed_outputs, self._kl_from_sigmas(sigma_weight, sigma_bias)
class RaulNetV15(pl.LightningModule):
    """Model definition not used in any publication

    Attributes
    ----------
    learning_rate : float
        The learning rate.
    nr_of_input_channels : int
        The number of input channels.
    input_length__samples : int
        Length of the last input axis (number of samples). Used to build the
        dummy input from which the flattened encoder size is inferred.
    nr_of_outputs : int
        The number of outputs.
    cnn_encoder_channels : Tuple[int, int, int]
        Tuple containing 3 integers defining the cnn encoder channels.
    mlp_encoder_channels : Tuple[int, int]
        Tuple containing 2 integers defining the mlp encoder channels.
    event_search_kernel_length : int
        Integer that sets the length of the kernels searching for action potentials.
    event_search_kernel_stride : int
        Integer that sets the stride of the kernels searching for action potentials.
    nr_of_electrode_grids : int
        Number of electrode grids. Input axis 2 is split into this many chunks
        of ``nr_of_electrodes_per_grid`` electrodes.
    nr_of_electrodes_per_grid : int
        Number of electrodes in each grid.
    inference_only : bool
        If True, ``forward`` returns only the predictions (no KL term).
    """

    # Number of stochastic forward passes averaged in every training step.
    _NR_OF_TRAINING_FORWARD_PASSES = 7

    def __init__(
        self,
        learning_rate: float,
        nr_of_input_channels: int,
        input_length__samples: int,
        nr_of_outputs: int,
        cnn_encoder_channels: Tuple[int, int, int],
        mlp_encoder_channels: Tuple[int, int],
        event_search_kernel_length: int,
        event_search_kernel_stride: int,
        nr_of_electrode_grids: int = 5,
        nr_of_electrodes_per_grid: int = 64,
        inference_only: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.learning_rate = learning_rate
        self.nr_of_input_channels = nr_of_input_channels
        self.nr_of_outputs = nr_of_outputs
        self.input_length__samples = input_length__samples

        self.cnn_encoder_channels = cnn_encoder_channels
        self.mlp_encoder_channels = mlp_encoder_channels
        self.event_search_kernel_length = event_search_kernel_length
        self.event_search_kernel_stride = event_search_kernel_stride

        self.nr_of_electrode_grids = nr_of_electrode_grids
        self.nr_of_electrodes_per_grid = nr_of_electrodes_per_grid

        self.inference_only = inference_only

        self.criterion = nn.L1Loss()

        self.bayesian_conv3d = Conv3dFlipout(
            self.nr_of_input_channels,
            self.cnn_encoder_channels[0],
            kernel_size=(1, 1, self.event_search_kernel_length),
            stride=(1, 1, self.event_search_kernel_stride),
            groups=self.nr_of_input_channels,
        )

        # Deterministic stand-in with the same geometry as bayesian_conv3d. It
        # is only used below to infer the flattened encoder output size and is
        # kept as a local so it is never registered as a submodule.
        dummy_conv3d = nn.Conv3d(
            self.nr_of_input_channels,
            self.cnn_encoder_channels[0],
            kernel_size=(1, 1, self.event_search_kernel_length),
            stride=(1, 1, self.event_search_kernel_stride),
            groups=self.nr_of_input_channels,
        )

        self.cnn_encoder = nn.Sequential(
            nn.GELU(approximate="tanh"),
            nn.InstanceNorm3d(self.cnn_encoder_channels[0]),
            nn.Dropout3d(p=0.20),
            nn.Conv3d(
                self.cnn_encoder_channels[0],
                self.cnn_encoder_channels[1],
                kernel_size=(
                    self.nr_of_electrode_grids,
                    int(np.floor(self.nr_of_electrodes_per_grid / 2)),
                    18,
                ),
                dilation=(1, 2, 1),
                padding=(
                    int(np.floor(self.nr_of_electrode_grids / 2)),
                    int(np.floor(self.nr_of_electrodes_per_grid / 4)),
                    0,
                ),
                padding_mode="circular",
            ),
            nn.GELU(approximate="tanh"),
            nn.InstanceNorm3d(self.cnn_encoder_channels[1]),
            nn.Conv3d(
                self.cnn_encoder_channels[1],
                self.cnn_encoder_channels[2],
                kernel_size=(
                    self.nr_of_electrode_grids,
                    int(np.floor(self.nr_of_electrodes_per_grid / 7)),
                    1,
                ),
            ),
            nn.GELU(approximate="tanh"),
            nn.InstanceNorm3d(self.cnn_encoder_channels[2]),
            nn.Flatten(),
        )

        # Shape inference only: no autograd graph is needed.
        with torch.no_grad():
            dummy_input = torch.rand(
                (
                    1,
                    self.nr_of_input_channels,
                    self.nr_of_electrode_grids,
                    self.nr_of_electrodes_per_grid,
                    self.input_length__samples,
                )
            )
            nr_of_flattened_features = self.cnn_encoder(dummy_conv3d(dummy_input)).shape[1]

        self.mlp = nn.Sequential(
            nn.Linear(nr_of_flattened_features, self.mlp_encoder_channels[0]),
            nn.GELU(approximate="tanh"),
            nn.Linear(self.mlp_encoder_channels[0], self.mlp_encoder_channels[1]),
            nn.GELU(approximate="tanh"),
            nn.Linear(self.mlp_encoder_channels[1], self.nr_of_outputs),
        )

    def forward(self, inputs) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Run the network.

        Returns only the predictions when ``inference_only`` is set; otherwise
        returns ``(predictions, kl)``, where ``kl`` comes from the Bayesian conv.
        """
        x = self._reshape_and_normalize(inputs)

        x, kl = self.bayesian_conv3d(x)
        x = self.cnn_encoder(x)
        x = self.mlp(x)

        if self.inference_only:
            return x

        return x, kl

    def _reshape_and_normalize(self, inputs):
        """Split axis 2 into electrode grids and z-score over the last two axes.

        NOTE(review): presumably ``inputs`` is
        (batch, channels, grids * electrodes_per_grid, samples); verify against
        the caller.
        """
        x = torch.stack(inputs.split(self.nr_of_electrodes_per_grid, dim=2), dim=2)
        return (x - x.mean(dim=(3, 4), keepdim=True)) / (
            x.std(dim=(3, 4), keepdim=True, unbiased=True) + 1e-15
        )

    def _reshape_and_normalize_v2(self, inputs):
        """Variant of ``_reshape_and_normalize`` using a median / IQR-style scaling.

        The median and quantiles are computed over axis 4 first, then over axis 3.
        Not used by ``forward``.
        """
        x = torch.stack(inputs.split(self.nr_of_electrodes_per_grid, dim=2), dim=2)
        return (x - x.median(dim=4, keepdim=True)[0].median(dim=3, keepdim=True)[0]) / (
            torch.quantile(
                torch.quantile(x, q=0.75, dim=4, keepdim=True),
                q=0.75,
                dim=3,
                keepdim=True,
            )
            - torch.quantile(
                torch.quantile(x, q=0.25, dim=4, keepdim=True),
                q=0.25,
                dim=3,
                keepdim=True,
            )
            + 1e-15
        )

    def _reshape_and_normalize_v3(self, inputs):
        """Variant applying a sign-preserving log1p compression before z-scoring.

        Not used by ``forward``.
        """
        x = torch.stack(inputs.split(self.nr_of_electrodes_per_grid, dim=2), dim=2)

        positive_mask = x >= 0
        negative_mask = x < 0

        positive_data = torch.log1p(x[positive_mask] + 1e-15)
        negative_data = -torch.log1p(-x[negative_mask] + 1e-15)

        # x is a fresh tensor from torch.stack, so this does not touch `inputs`.
        x[positive_mask] = positive_data
        x[negative_mask] = negative_data

        return (x - x.mean(dim=(3, 4), keepdim=True)) / (
            x.std(dim=(3, 4), keepdim=True, unbiased=True) + 1e-15
        )

    def configure_optimizers(self):
        """AdamW (amsgrad) with a per-step one-cycle cosine learning-rate schedule."""
        optimizer = optim.AdamW(
            self.parameters(), lr=self.learning_rate, amsgrad=True, weight_decay=0.32
        )

        lr_scheduler = {
            "scheduler": optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=self.learning_rate * (10**1.5),
                total_steps=self.trainer.estimated_stepping_batches,
                anneal_strategy="cos",
                three_phase=False,
                div_factor=10**1.5,
                final_div_factor=1e3,
            ),
            "name": "OneCycleLR",
            "interval": "step",
            "frequency": 1,
        }

        return [optimizer], [lr_scheduler]

    def training_step(
        self, train_batch, batch_idx: int
    ) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
        """Average several stochastic forward passes; loss = L1 + KL / batch size.

        Returns None (skipping the step) when the loss is NaN.
        """
        inputs, ground_truths = train_batch
        ground_truths = ground_truths[:, 0]

        predictions, kls = [], []
        for _ in range(self._NR_OF_TRAINING_FORWARD_PASSES):
            prediction, kl = self(inputs)
            predictions.append(prediction)
            kls.append(kl)

        prediction = torch.stack(predictions, dim=0).mean(dim=0)
        kl = torch.stack(kls, dim=0).mean(dim=0)

        scores_dict = {
            "loss": self.criterion(prediction, ground_truths)
            + (kl / prediction.shape[0]),
            "kl": kl,
        }

        if scores_dict["loss"].isnan().item():
            return None

        self.log_dict(
            scores_dict, prog_bar=True, logger=False, on_epoch=True, rank_zero_only=True
        )
        self.log_dict(
            {f"train/{k}": v for k, v in scores_dict.items()},
            prog_bar=False,
            logger=True,
            on_epoch=True,
            on_step=False,
            rank_zero_only=True,
        )

        return scores_dict

    def validation_step(
        self, batch, batch_idx
    ) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
        """Single stochastic forward pass; logs L1 loss and KL."""
        inputs, ground_truths = batch
        ground_truths = ground_truths[:, 0]

        prediction, kl = self(inputs)
        scores_dict = {
            "val_loss": self.criterion(prediction, ground_truths),
            "val_kl": kl,
        }

        self.log_dict(
            scores_dict, prog_bar=True, logger=False, on_epoch=True, sync_dist=True
        )

        return scores_dict

    def test_step(
        self, batch, batch_idx
    ) -> Optional[Union[torch.Tensor, Dict[str, Any]]]:
        """Single stochastic forward pass; logs L1 loss and KL per step."""
        inputs, ground_truths = batch
        ground_truths = ground_truths[:, 0]

        prediction, kl = self(inputs)
        scores_dict = {"loss": self.criterion(prediction, ground_truths), "kl": kl}

        self.log_dict(
            scores_dict, prog_bar=True, logger=False, on_epoch=True, sync_dist=True
        )
        self.log_dict(
            {f"test/{k}": v for k, v in scores_dict.items()},
            prog_bar=False,
            logger=True,
            on_epoch=False,
            on_step=True,
            sync_dist=True,
        )

        return scores_dict