Source code for myoverse.models.components.utils
"""Utility layers for neural network models."""
from __future__ import annotations
import torch
from torch import nn
[docs]
class WeightedSum(nn.Module):
"""Learnable weighted sum of two tensors.
Computes alpha * x + (1 - alpha) * y where alpha is a learnable parameter.
Parameters
----------
alpha : float, optional
Initial weight for the first input. Default is 0.5.
"""
[docs]
def __init__(self, alpha: float = 0.5):
super().__init__()
self.alpha = nn.Parameter(torch.tensor(alpha), requires_grad=True)
[docs]
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return self.alpha * x + (1 - self.alpha) * y
class CircularPad(nn.Module):
"""Circular padding layer used in Sîmpetru et al. [1]_.
Applies fixed circular padding to 4D input tensors along dimensions 2 and 3.
References
----------
.. [1] Sîmpetru, R.C., Osswald, M., Braun, D.I., Oliveira, D.S., Cakici, A.L., Del Vecchio, A., 2022.
Accurate Continuous Prediction of 14 Degrees of Freedom of the Hand from Myoelectrical Signals
through Convolutive Deep Learning, in: 2022 44th Annual International Conference of the IEEE
Engineering in Medicine & Biology Society (EMBC), pp. 702-706.
https://doi.org/10.1109/EMBC48229.2022.9870937
"""
def __init__(self):
super().__init__()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.cat([torch.narrow(x, 2, 3, 2), x, torch.narrow(x, 2, 0, 2)], dim=2)
x = torch.cat([torch.narrow(x, 3, 48, 16), x, torch.narrow(x, 3, 0, 16)], dim=3)
return x