Source code for myoverse.transforms.augment

"""GPU-accelerated augmentation transforms using PyTorch.

All augmentations work with named tensors and run on GPU.
They are stochastic and respect torch.random state.

Example:
-------
>>> import torch
>>> from myoverse.transforms.tensor import GaussianNoise, MagnitudeWarp, TimeWarp
>>>
>>> x = torch.randn(32, 64, 200, device='cuda', names=('batch', 'channel', 'time'))
>>>
>>> # Augmentation pipeline
>>> augment = Pipeline([
...     GaussianNoise(std=0.1),
...     MagnitudeWarp(sigma=0.2),
... ])
>>> y = augment(x)  # Augmented on GPU

"""

from __future__ import annotations

import torch
import torch.nn.functional as F

from myoverse.transforms.base import TensorTransform, get_dim_index


class GaussianNoise(TensorTransform):
    """Add Gaussian noise to the signal.

    Parameters
    ----------
    std : float
        Standard deviation of the noise.
    p : float
        Probability of applying the augmentation.

    Examples
    --------
    >>> x = torch.randn(64, 2048, device='cuda', names=('channel', 'time'))
    >>> noise = GaussianNoise(std=0.1)
    >>> y = noise(x)
    """

    def __init__(self, std: float = 0.1, p: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.std = std
        self.p = p

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        if torch.rand(1).item() > self.p:
            return x

        names = x.names
        x = x.rename(None)

        noise = torch.randn_like(x) * self.std
        result = x + noise

        if names[0] is not None:
            result = result.rename(*names)
        return result
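

# A minimal usage sketch (not part of the library API): smoke-test
# GaussianNoise on a small CPU tensor. The shape and dim names follow the
# docstring example; the helper name `_demo_gaussian_noise` is illustrative.
def _demo_gaussian_noise() -> None:
    x = torch.randn(8, 500, names=("channel", "time"))
    y = GaussianNoise(std=0.05)(x)
    # Shape and names are preserved; only the values are perturbed
    assert y.shape == x.shape and y.names == x.names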


class MagnitudeWarp(TensorTransform):
    """Warp magnitude using smooth random curves.

    Creates smooth random scaling factors that vary over time.

    Parameters
    ----------
    sigma : float
        Standard deviation for the warping curves.
    n_knots : int
        Number of control points for the warping curve.
    p : float
        Probability of applying the augmentation.
    dim : str
        Dimension to warp along.

    Examples
    --------
    >>> x = torch.randn(64, 2048, device='cuda', names=('channel', 'time'))
    >>> warp = MagnitudeWarp(sigma=0.2, n_knots=4)
    >>> y = warp(x)
    """

    def __init__(
        self,
        sigma: float = 0.2,
        n_knots: int = 4,
        p: float = 1.0,
        dim: str = "time",
        **kwargs,
    ):
        super().__init__(dim=dim, **kwargs)
        self.sigma = sigma
        self.n_knots = n_knots
        self.p = p

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        if torch.rand(1).item() > self.p:
            return x

        dim_idx = get_dim_index(x, self.dim)
        names = x.names
        n_samples = x.shape[dim_idx]
        x = x.rename(None)

        # Random scaling knots centred on 1.0
        knots = (
            torch.randn(self.n_knots, device=x.device, dtype=x.dtype) * self.sigma
            + 1.0
        )

        # Linearly interpolate the knots to a smooth full-length curve
        warp = F.interpolate(
            knots.view(1, 1, -1),
            size=n_samples,
            mode="linear",
            align_corners=True,
        ).squeeze()

        # Broadcast the curve along the warped dimension
        shape = [1] * x.ndim
        shape[dim_idx] = n_samples
        warp = warp.view(*shape)

        result = x * warp

        if names[0] is not None:
            result = result.rename(*names)
        return result
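

# Sketch: MagnitudeWarp applies a smooth per-time-step gain near 1.0, so
# shape and names pass through untouched. Call convention follows the
# docstring examples; `_demo_magnitude_warp` is an illustrative name.
def _demo_magnitude_warp() -> None:
    x = torch.ones(4, 256, names=("channel", "time"))
    y = MagnitudeWarp(sigma=0.2, n_knots=4, p=1.0)(x)
    # Each time step of the constant input is rescaled by a smooth factor
    assert y.shape == x.shape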


class TimeWarp(TensorTransform):
    """Warp the time axis with smooth random curves.

    Creates smooth random time distortions using linear interpolation
    and resamples the signal along the warped time base.

    Parameters
    ----------
    sigma : float
        Standard deviation for the warping curves.
    n_knots : int
        Number of control points.
    p : float
        Probability of applying the augmentation.
    dim : str
        Time dimension to warp.

    Examples
    --------
    >>> x = torch.randn(64, 2048, device='cuda', names=('channel', 'time'))
    >>> warp = TimeWarp(sigma=0.2, n_knots=4)
    >>> y = warp(x)
    """

    def __init__(
        self,
        sigma: float = 0.2,
        n_knots: int = 4,
        p: float = 1.0,
        dim: str = "time",
        **kwargs,
    ):
        super().__init__(dim=dim, **kwargs)
        self.sigma = sigma
        self.n_knots = n_knots
        self.p = p

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        if torch.rand(1).item() > self.p:
            return x

        dim_idx = get_dim_index(x, self.dim)
        names = x.names
        n_samples = x.shape[dim_idx]
        x = x.rename(None)

        # Random cumulative distortions with pinned endpoints
        distortions = torch.randn(self.n_knots + 2, device=x.device, dtype=x.dtype)
        distortions[0] = 0
        distortions[-1] = 0
        distortions = distortions.cumsum(0) * self.sigma

        # Distorted knot positions along the original index axis
        orig_indices = torch.linspace(
            0, n_samples - 1, self.n_knots + 2, device=x.device, dtype=x.dtype
        )
        warped_indices = orig_indices + distortions * (n_samples / self.n_knots)
        warped_indices = torch.clamp(warped_indices, 0, n_samples - 1)

        # Interpolate the warping function to full length
        warp_func = F.interpolate(
            warped_indices.view(1, 1, -1),
            size=n_samples,
            mode="linear",
            align_corners=True,
        ).squeeze()

        # Move the time dimension to the last position for grid_sample
        if dim_idx != x.ndim - 1:
            perm = list(range(x.ndim))
            perm[dim_idx], perm[-1] = perm[-1], perm[dim_idx]
            x = x.permute(*perm)
            moved = True
        else:
            moved = False

        # Flatten to (N, C=1, H=1, W=n_samples) for grid_sample
        original_shape = x.shape
        x_flat = x.reshape(-1, 1, 1, n_samples)

        # Build the sampling grid; grid_sample expects (N, H_out, W_out, 2)
        # with (x, y) coordinates in [-1, 1]. The y coordinate is zero since
        # the height dimension has size 1.
        grid_x = warp_func / (n_samples - 1) * 2 - 1
        grid_y = torch.zeros_like(grid_x)
        grid = torch.stack([grid_x, grid_y], dim=-1).view(1, 1, n_samples, 2)
        grid = grid.expand(x_flat.shape[0], 1, n_samples, 2)

        # Resample the signal along the warped time base
        warped = F.grid_sample(
            x_flat,
            grid,
            mode="bilinear",
            padding_mode="border",
            align_corners=True,
        )
        result = warped.reshape(original_shape)

        # Restore the original dimension order (perm is its own inverse)
        if moved:
            result = result.permute(*perm)

        if names[0] is not None:
            result = result.rename(*names)
        return result
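

# Sketch: resample a sine wave through TimeWarp and check that shape and
# names survive. With sigma=0 the warp reduces to the identity (up to
# interpolation), which makes a cheap extra sanity check;
# `_demo_time_warp` is an illustrative name.
def _demo_time_warp() -> None:
    t = torch.linspace(0, 6.28, 512)
    x = torch.sin(t).repeat(4, 1).rename("channel", "time")
    y = TimeWarp(sigma=0.1, n_knots=4, p=1.0)(x)
    assert y.shape == x.shape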


class Dropout(TensorTransform):
    """Randomly zero out elements.

    Parameters
    ----------
    p : float
        Probability of zeroing each element.
    dim : str
        If specified, drops entire slices along this dimension.
        If None, drops individual elements.

    Examples
    --------
    >>> x = torch.randn(64, 2048, device='cuda', names=('channel', 'time'))
    >>> # Element-wise dropout
    >>> dropout = Dropout(p=0.1)
    >>> # Channel dropout (drop entire channels)
    >>> channel_dropout = Dropout(p=0.1, dim='channel')
    """

    def __init__(self, p: float = 0.1, dim: str | None = None, **kwargs):
        super().__init__(**kwargs)
        self.p = p
        self.drop_dim = dim

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training_mode:
            return x

        names = x.names

        # Resolve the dimension index while the tensor still carries names
        if self.drop_dim is not None:
            dim_idx = get_dim_index(x, self.drop_dim) if names[0] is not None else -2

        x = x.rename(None)

        if self.drop_dim is None:
            # Element-wise dropout (rescaled by 1 / (1 - p))
            result = F.dropout(x, p=self.p, training=True)
        else:
            # Structured dropout: zero out whole slices along the dimension
            # (no rescaling is applied here)
            mask_shape = [1] * x.ndim
            mask_shape[dim_idx] = x.shape[dim_idx]
            mask = (torch.rand(mask_shape, device=x.device) > self.p).to(x.dtype)
            result = x * mask

        if names[0] is not None:
            result = result.rename(*names)
        return result

    @property
    def training_mode(self) -> bool:
        """Check if in training mode (dropout is applied only during training)."""
        return getattr(self, "_training", True)

    def train(self):
        """Set to training mode."""
        self._training = True
        return self

    def eval(self):
        """Set to evaluation mode."""
        self._training = False
        return self
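

# Sketch: structured channel dropout is active only in training mode; in
# eval mode the input passes through unchanged. Call convention follows the
# docstring examples; `_demo_dropout_modes` is an illustrative name.
def _demo_dropout_modes() -> None:
    x = torch.randn(16, 100, names=("channel", "time"))
    drop = Dropout(p=0.5, dim="channel").train()
    y = drop(x)         # some channels zeroed
    z = drop.eval()(x)  # identity in eval mode
    assert torch.equal(z.rename(None), x.rename(None))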


class ChannelShuffle(TensorTransform):
    """Randomly shuffle channel order.

    Parameters
    ----------
    p : float
        Probability of applying the shuffle.

    Examples
    --------
    >>> x = torch.randn(64, 2048, device='cuda', names=('channel', 'time'))
    >>> shuffle = ChannelShuffle(p=0.5)
    >>> y = shuffle(x)
    """

    def __init__(self, p: float = 0.5, **kwargs):
        super().__init__(dim="channel", **kwargs)
        self.p = p

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        if torch.rand(1).item() > self.p:
            return x

        dim_idx = get_dim_index(x, self.dim)
        names = x.names
        x = x.rename(None)

        # Draw a random permutation of the channel indices
        n_channels = x.shape[dim_idx]
        perm = torch.randperm(n_channels, device=x.device)

        result = x.index_select(dim_idx, perm)

        if names[0] is not None:
            result = result.rename(*names)
        return result
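

# Sketch: shuffling permutes whole channels, so per-channel statistics are
# preserved even though the channel order changes. `_demo_channel_shuffle`
# is an illustrative name.
def _demo_channel_shuffle() -> None:
    x = torch.randn(8, 200, names=("channel", "time"))
    y = ChannelShuffle(p=1.0)(x)
    # The multiset of per-channel means is unchanged by a permutation
    mx = x.rename(None).mean(dim=1).sort().values
    my = y.rename(None).mean(dim=1).sort().values
    assert torch.allclose(mx, my)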


class TimeShift(TensorTransform):
    """Randomly shift the signal in time.

    Parameters
    ----------
    max_shift : int | float
        Maximum shift amount. If a float < 1, interpreted as a fraction
        of the signal length.
    p : float
        Probability of applying the shift.
    fill : str
        How to fill shifted regions: 'zero', 'wrap', or 'edge'.

    Examples
    --------
    >>> x = torch.randn(64, 2048, device='cuda', names=('channel', 'time'))
    >>> shift = TimeShift(max_shift=100, p=0.5)
    >>> y = shift(x)
    """

    def __init__(
        self,
        max_shift: int | float = 100,
        p: float = 0.5,
        fill: str = "zero",
        **kwargs,
    ):
        super().__init__(dim="time", **kwargs)
        self.max_shift = max_shift
        self.p = p
        self.fill = fill

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        if torch.rand(1).item() > self.p:
            return x

        dim_idx = get_dim_index(x, self.dim)
        names = x.names
        n_samples = x.shape[dim_idx]
        x = x.rename(None)

        # Resolve the maximum shift in samples
        if isinstance(self.max_shift, float) and self.max_shift < 1:
            max_shift = int(n_samples * self.max_shift)
        else:
            max_shift = int(self.max_shift)

        shift = torch.randint(-max_shift, max_shift + 1, (1,)).item()
        if shift == 0:
            if names[0] is not None:
                x = x.rename(*names)
            return x

        # Shift by rolling; 'wrap' is roll's native behaviour
        result = torch.roll(x, shifts=shift, dims=dim_idx)

        if self.fill == "zero":
            # Zero out the wrapped-around region
            idx = [slice(None)] * x.ndim
            if shift > 0:
                idx[dim_idx] = slice(0, shift)
            else:
                idx[dim_idx] = slice(shift, None)
            result[tuple(idx)] = 0
        elif self.fill == "edge":
            # Replicate the nearest valid sample into the wrapped region
            idx = [slice(None)] * x.ndim
            edge_idx = [slice(None)] * x.ndim
            if shift > 0:
                idx[dim_idx] = slice(0, shift)
                edge_idx[dim_idx] = 0
            else:
                idx[dim_idx] = slice(shift, None)
                edge_idx[dim_idx] = -1
            result[tuple(idx)] = x[tuple(edge_idx)].unsqueeze(dim_idx)

        if names[0] is not None:
            result = result.rename(*names)
        return result
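

# Sketch: with fill='edge' a positive shift pads the leading gap with the
# first original sample. Forcing p=1.0 makes the augmentation fire on every
# call; `_demo_time_shift` is an illustrative name.
def _demo_time_shift() -> None:
    x = torch.arange(10, dtype=torch.float32).repeat(2, 1)
    x = x.rename("channel", "time")
    y = TimeShift(max_shift=3, p=1.0, fill="edge")(x)
    assert y.shape == x.shape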


class Scale(TensorTransform):
    """Random amplitude scaling.

    Parameters
    ----------
    scale_range : tuple[float, float]
        Range of scale factors (min, max).
    p : float
        Probability of applying the scaling.

    Examples
    --------
    >>> x = torch.randn(64, 2048, device='cuda', names=('channel', 'time'))
    >>> scale = Scale(scale_range=(0.8, 1.2))
    >>> y = scale(x)
    """

    def __init__(
        self,
        scale_range: tuple[float, float] = (0.8, 1.2),
        p: float = 1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.scale_range = scale_range
        self.p = p

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        if torch.rand(1).item() > self.p:
            return x

        names = x.names
        x = x.rename(None)

        # Draw a single scale factor uniformly from the range
        scale = torch.empty(1, device=x.device, dtype=x.dtype).uniform_(
            *self.scale_range
        )
        result = x * scale

        if names[0] is not None:
            result = result.rename(*names)
        return result
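

# Sketch: Scale multiplies the whole tensor by one factor, so ratios between
# samples are preserved exactly. `_demo_scale` is an illustrative name; the
# ratio check assumes no element of x is exactly zero (almost sure for randn).
def _demo_scale() -> None:
    x = torch.randn(4, 128, names=("channel", "time"))
    y = Scale(scale_range=(0.5, 2.0), p=1.0)(x)
    # The elementwise ratio y / x is constant across the tensor
    ratio = (y.rename(None) / x.rename(None)).flatten()
    assert torch.allclose(ratio, ratio[0].expand_as(ratio))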


class Cutout(TensorTransform):
    """Randomly zero out contiguous regions.

    Parameters
    ----------
    n_holes : int
        Number of regions to cut out.
    length : int | float
        Length of each cutout. If a float < 1, interpreted as a fraction
        of the total length.
    p : float
        Probability of applying the cutout.
    dim : str
        Dimension to cut along.

    Examples
    --------
    >>> x = torch.randn(64, 2048, device='cuda', names=('channel', 'time'))
    >>> cutout = Cutout(n_holes=3, length=50)
    >>> y = cutout(x)
    """

    def __init__(
        self,
        n_holes: int = 1,
        length: int | float = 50,
        p: float = 0.5,
        dim: str = "time",
        **kwargs,
    ):
        super().__init__(dim=dim, **kwargs)
        self.n_holes = n_holes
        self.length = length
        self.p = p

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        if torch.rand(1).item() > self.p:
            return x

        dim_idx = get_dim_index(x, self.dim)
        names = x.names
        n_samples = x.shape[dim_idx]
        x = x.rename(None)
        result = x.clone()

        # Resolve the hole length in samples (never longer than the signal)
        if isinstance(self.length, float) and self.length < 1:
            hole_length = int(n_samples * self.length)
        else:
            hole_length = int(self.length)
        hole_length = min(hole_length, n_samples)

        for _ in range(self.n_holes):
            # Pick a random start position and zero out the region
            start = torch.randint(0, n_samples - hole_length + 1, (1,)).item()
            idx = [slice(None)] * x.ndim
            idx[dim_idx] = slice(start, start + hole_length)
            result[tuple(idx)] = 0

        if names[0] is not None:
            result = result.rename(*names)
        return result
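

# Sketch: each hole zeroes a contiguous window, so at most n_holes * length
# samples per channel can end up zero (holes may overlap). `_demo_cutout`
# is an illustrative name.
def _demo_cutout() -> None:
    x = torch.ones(4, 300, names=("channel", "time"))
    y = Cutout(n_holes=2, length=30, p=1.0)(x)
    zeros_per_channel = (y.rename(None) == 0).sum(dim=1)
    assert (zeros_per_channel <= 60).all()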