Source code for myoverse.transforms.generic

"""GPU-accelerated generic transforms using PyTorch.

Array manipulation transforms that work with named tensors.

Examples
--------
>>> import torch
>>> from myoverse.transforms.tensor import Reshape, Index, Flatten
>>>
>>> x = torch.randn(64, 2048, names=('channel', 'time'))
>>> y = Reshape((8, 8, 2048), names=('row', 'col', 'time'))(x)

"""

from __future__ import annotations

from collections.abc import Callable

import torch

from myoverse.transforms.base import TensorTransform, get_dim_index


class Reshape(TensorTransform):
    """Reshape tensor with new dimension names.

    Parameters
    ----------
    shape : tuple[int, ...]
        New shape (-1 allowed for one dimension).
    names : tuple[str, ...] | None
        New dimension names.

    Examples
    --------
    >>> x = torch.randn(64, 2048, names=('channel', 'time'))
    >>> reshape = Reshape((8, 8, 2048), names=('row', 'col', 'time'))
    >>> y = reshape(x)  # Shape: (8, 8, 2048)
    """

    def __init__(
        self,
        shape: tuple[int, ...],
        names: tuple[str, ...] | None = None,
        **kwargs,
    ):
        super().__init__(dim="time", **kwargs)
        self.shape = shape
        self.names = names

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        x = x.rename(None)
        result = x.reshape(self.shape)
        if self.names is not None:
            result = result.rename(*self.names)
        return result


class Index(TensorTransform):
    """Index/slice along a dimension.

    Parameters
    ----------
    indices : int | slice | list[int]
        Indices to select.
    dim : str
        Dimension to index.

    Examples
    --------
    >>> x = torch.randn(64, 2048, names=('channel', 'time'))
    >>> # Select first 10 channels
    >>> index = Index(slice(0, 10), dim='channel')
    >>> y = index(x)  # Shape: (10, 2048)
    """

    def __init__(
        self,
        indices: int | slice | list[int],
        dim: str = "time",
        **kwargs,
    ):
        super().__init__(dim=dim, **kwargs)
        self.indices = indices

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        dim_idx = get_dim_index(x, self.dim)
        names = x.names
        x = x.rename(None)

        # Build index tuple
        index = [slice(None)] * x.ndim
        index[dim_idx] = self.indices
        result = x[tuple(index)]

        if names[0] is not None and result.ndim == len(names):
            result = result.rename(*names)
        elif names[0] is not None:
            # Dimension was squeezed
            new_names = list(names)
            if isinstance(self.indices, int):
                new_names.pop(dim_idx)
                result = result.rename(*new_names)
        return result
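

# A minimal usage sketch for the integer-index path above (hypothetical
# shapes; assumes TensorTransform.__call__ dispatches to _apply):
#
# >>> x = torch.randn(64, 2048, names=('channel', 'time'))
# >>> y = Index(0, dim='channel')(x)
# >>> y.shape, y.names  # 'channel' dimension squeezed and its name dropped
# (torch.Size([2048]), ('time',))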


class Flatten(TensorTransform):
    """Flatten dimensions of a tensor.

    Parameters
    ----------
    start_dim : int
        First dimension to flatten.
    end_dim : int
        Last dimension to flatten.

    Examples
    --------
    >>> x = torch.randn(8, 8, 2048, names=('row', 'col', 'time'))
    >>> flatten = Flatten(start_dim=0, end_dim=1)
    >>> y = flatten(x)  # Shape: (64, 2048)
    """

    def __init__(
        self,
        start_dim: int = 0,
        end_dim: int = -1,
        **kwargs,
    ):
        super().__init__(dim="time", **kwargs)
        self.start_dim = start_dim
        self.end_dim = end_dim

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        # Flattening merges dimensions, so the original names cannot be
        # preserved; the result is returned unnamed.
        x = x.rename(None)
        return torch.flatten(x, start_dim=self.start_dim, end_dim=self.end_dim)
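

# Usage sketch (hypothetical shapes): note the flattened result is unnamed,
# so downstream name-based transforms may need a Reshape with explicit
# names afterwards.
#
# >>> x = torch.randn(8, 8, 2048, names=('row', 'col', 'time'))
# >>> Flatten(start_dim=0, end_dim=1)(x).names
# (None, None)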


class Squeeze(TensorTransform):
    """Remove dimensions of size 1.

    Parameters
    ----------
    dim : int | None
        Specific dimension to squeeze, or None for all.
    """

    def __init__(self, dim: int | None = None, **kwargs):
        super().__init__(**kwargs)
        self.squeeze_dim = dim

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        x = x.rename(None)
        if self.squeeze_dim is not None:
            return x.squeeze(self.squeeze_dim)
        return x.squeeze()


class Unsqueeze(TensorTransform):
    """Add a dimension of size 1.

    Parameters
    ----------
    dim : int
        Position to insert new dimension.
    name : str | None
        Name for the new dimension.
    """

    def __init__(self, dim: int, name: str | None = None, **kwargs):
        super().__init__(**kwargs)
        self.unsqueeze_dim = dim
        self.new_name = name

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        names = list(x.names) if x.names[0] is not None else None
        x = x.rename(None)
        result = x.unsqueeze(self.unsqueeze_dim)
        # Names are only restored when the input was named and a name for
        # the new dimension was provided; otherwise the result is unnamed.
        if names is not None and self.new_name is not None:
            names.insert(self.unsqueeze_dim, self.new_name)
            result = result.rename(*names)
        return result
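

# Usage sketch (hypothetical shapes): providing ``name`` keeps the tensor
# named after the new axis is inserted.
#
# >>> x = torch.randn(64, 2048, names=('channel', 'time'))
# >>> Unsqueeze(0, name='batch')(x).names
# ('batch', 'channel', 'time')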


class Transpose(TensorTransform):
    """Transpose/permute dimensions.

    Parameters
    ----------
    dims : tuple[int, ...] | tuple[str, ...]
        New dimension order (by index or name).

    Examples
    --------
    >>> x = torch.randn(64, 2048, names=('channel', 'time'))
    >>> transpose = Transpose(('time', 'channel'))
    >>> y = transpose(x)  # Shape: (2048, 64)
    """

    def __init__(self, dims: tuple, **kwargs):
        super().__init__(**kwargs)
        self.dims = dims

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        names = x.names

        # Convert string dims to indices
        if all(isinstance(d, str) for d in self.dims):
            if names[0] is None:
                raise ValueError("Cannot use string dims on unnamed tensor")
            perm = [names.index(d) for d in self.dims]
        else:
            perm = list(self.dims)

        x = x.rename(None)
        result = x.permute(*perm)

        if names[0] is not None:
            new_names = [names[i] for i in perm]
            result = result.rename(*new_names)
        return result
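

# Usage sketch for the integer form of ``dims`` (hypothetical shapes);
# names are permuted alongside the data when the input is named:
#
# >>> x = torch.randn(64, 2048, names=('channel', 'time'))
# >>> Transpose((1, 0))(x).names
# ('time', 'channel')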


class Mean(TensorTransform):
    """Compute mean along a dimension.

    Parameters
    ----------
    dim : str
        Dimension to reduce.
    keepdim : bool
        Whether to keep the reduced dimension.
    """

    def __init__(self, dim: str = "time", keepdim: bool = False, **kwargs):
        super().__init__(dim=dim, **kwargs)
        self.keepdim = keepdim

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        dim_idx = get_dim_index(x, self.dim)
        names = x.names
        x = x.rename(None)
        result = x.mean(dim=dim_idx, keepdim=self.keepdim)

        if names[0] is not None:
            if self.keepdim:
                result = result.rename(*names)
            else:
                new_names = [n for i, n in enumerate(names) if i != dim_idx]
                if new_names:
                    result = result.rename(*new_names)
        return result
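

# Usage sketch (hypothetical shapes): reducing over 'time' drops that
# dimension and its name when keepdim=False.
#
# >>> x = torch.randn(64, 2048, names=('channel', 'time'))
# >>> Mean(dim='time')(x).shape
# torch.Size([64])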


class Sum(TensorTransform):
    """Compute sum along a dimension.

    Parameters
    ----------
    dim : str
        Dimension to reduce.
    keepdim : bool
        Whether to keep the reduced dimension.
    """

    def __init__(self, dim: str = "time", keepdim: bool = False, **kwargs):
        super().__init__(dim=dim, **kwargs)
        self.keepdim = keepdim

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        dim_idx = get_dim_index(x, self.dim)
        names = x.names
        x = x.rename(None)
        result = x.sum(dim=dim_idx, keepdim=self.keepdim)

        if names[0] is not None:
            if self.keepdim:
                result = result.rename(*names)
            else:
                new_names = [n for i, n in enumerate(names) if i != dim_idx]
                if new_names:
                    result = result.rename(*new_names)
        return result
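

# Usage sketch (hypothetical shapes): with keepdim=True the reduced
# dimension is kept with size 1 and all names are preserved.
#
# >>> x = torch.randn(64, 2048, names=('channel', 'time'))
# >>> Sum(dim='time', keepdim=True)(x).shape
# torch.Size([64, 1])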


class Stack(TensorTransform):
    """Stack multiple tensors along a new dimension.

    This is a container transform that holds multiple paths.

    Parameters
    ----------
    transforms : dict[str, Callable]
        Named transforms to apply and stack.
    dim : str
        Name for the new stacking dimension.

    Examples
    --------
    >>> from myoverse.transforms.tensor import RMS, MAV, Stack
    >>> stack = Stack({
    ...     'rms': RMS(window_size=200),
    ...     'mav': MAV(window_size=200),
    ... }, dim='feature')
    >>> y = stack(x)  # Shape: (2, channel, time_windows)
    """

    def __init__(
        self,
        transforms: dict[str, Callable],
        dim: str = "representation",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.transforms = transforms
        self.stack_dim = dim

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        results = []
        first_names = None
        for transform in self.transforms.values():
            result = transform(x)
            if first_names is None:
                first_names = result.names
            results.append(result.rename(None))

        stacked = torch.stack(results, dim=0)

        # Add dimension names
        if first_names is not None and first_names[0] is not None:
            names = (self.stack_dim,) + tuple(first_names)
        else:
            names = (self.stack_dim,) + tuple(
                f"dim_{i}" for i in range(results[0].ndim)
            )
        return stacked.rename(*names)
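

# A self-contained sketch using transforms from this module (hypothetical
# shapes; assumes the default TensorTransform constructor arguments):
#
# >>> x = torch.randn(64, 2048, names=('channel', 'time'))
# >>> stack = Stack({'raw': Identity(), 'squared': Lambda(lambda t: t ** 2)},
# ...               dim='feature')
# >>> stack(x).names
# ('feature', 'channel', 'time')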


class Concat(TensorTransform):
    """Concatenate multiple tensors along an existing dimension.

    Parameters
    ----------
    transforms : dict[str, Callable]
        Named transforms to apply and concatenate.
    dim : str
        Dimension to concatenate along.

    Examples
    --------
    >>> from myoverse.transforms.tensor import RMS, MAV, Concat
    >>> concat = Concat({
    ...     'rms': RMS(window_size=200),
    ...     'mav': MAV(window_size=200),
    ... }, dim='channel')
    >>> y = concat(x)  # Concatenated along channel dimension
    """

    def __init__(
        self,
        transforms: dict[str, Callable],
        dim: str = "channel",
        **kwargs,
    ):
        super().__init__(dim=dim, **kwargs)
        self.transforms = transforms

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        results = []
        first_names = None
        for transform in self.transforms.values():
            result = transform(x)
            if first_names is None:
                first_names = result.names
            results.append(result.rename(None))

        # Get dim index from the first result's original names; fall back
        # to the second-to-last axis for unnamed tensors.
        if first_names is not None and first_names[0] is not None:
            dim_idx = first_names.index(self.dim)
        else:
            dim_idx = -2

        concatenated = torch.cat(results, dim=dim_idx)
        if first_names is not None and first_names[0] is not None:
            concatenated = concatenated.rename(*first_names)
        return concatenated
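

# A self-contained sketch using transforms from this module (hypothetical
# shapes): both branches keep the ('channel', 'time') layout, so their
# outputs can be concatenated along 'channel'.
#
# >>> x = torch.randn(64, 2048, names=('channel', 'time'))
# >>> concat = Concat({'raw': Identity(), 'scaled': Lambda(lambda t: t * 2)},
# ...                 dim='channel')
# >>> concat(x).shape
# torch.Size([128, 2048])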


class Lambda(TensorTransform):
    """Apply a custom function.

    Parameters
    ----------
    func : Callable
        Function to apply.

    Examples
    --------
    >>> transform = Lambda(lambda x: x ** 2)
    >>> y = transform(x)
    """

    def __init__(self, func: Callable, **kwargs):
        super().__init__(**kwargs)
        self.func = func

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        return self.func(x)


class Identity(TensorTransform):
    """Identity transform (returns input unchanged)."""

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        return x


class Repeat(TensorTransform):
    """Repeat tensor along a dimension.

    Parameters
    ----------
    repeats : int
        Number of repetitions.
    dim : str
        Dimension to repeat along.
    """

    def __init__(self, repeats: int, dim: str = "channel", **kwargs):
        super().__init__(dim=dim, **kwargs)
        self.repeats = repeats

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        dim_idx = get_dim_index(x, self.dim)
        names = x.names
        x = x.rename(None)

        # Build repeat tuple
        repeat_tuple = [1] * x.ndim
        repeat_tuple[dim_idx] = self.repeats
        result = x.repeat(*repeat_tuple)

        if names[0] is not None:
            result = result.rename(*names)
        return result
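

# Usage sketch (hypothetical shapes): torch.Tensor.repeat tiles the whole
# tensor along the chosen axis, so the data is repeated block-wise rather
# than element-wise.
#
# >>> x = torch.randn(64, 2048, names=('channel', 'time'))
# >>> Repeat(2, dim='channel')(x).shape
# torch.Size([128, 2048])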


class Pad(TensorTransform):
    """Pad tensor along a dimension.

    Parameters
    ----------
    padding : tuple[int, int]
        Padding (before, after) along the dimension.
    dim : str
        Dimension to pad.
    mode : str
        Padding mode: 'constant', 'reflect', 'replicate', 'circular'.
    value : float
        Fill value for constant padding.
    """

    def __init__(
        self,
        padding: tuple[int, int],
        dim: str = "time",
        mode: str = "constant",
        value: float = 0.0,
        **kwargs,
    ):
        super().__init__(dim=dim, **kwargs)
        self.padding = padding
        self.mode = mode
        self.value = value

    def _apply(self, x: torch.Tensor) -> torch.Tensor:
        dim_idx = get_dim_index(x, self.dim)
        names = x.names
        x = x.rename(None)

        # F.pad expects padding in reverse order: (last_dim, ..., first_dim).
        # Each dim needs (left, right) padding.
        ndim = x.ndim
        pad_list = [0] * (2 * ndim)
        # dim_idx from the end
        idx_from_end = ndim - 1 - dim_idx
        pad_list[2 * idx_from_end] = self.padding[0]  # left
        pad_list[2 * idx_from_end + 1] = self.padding[1]  # right

        result = torch.nn.functional.pad(
            x, pad_list, mode=self.mode, value=self.value
        )

        if names[0] is not None:
            result = result.rename(*names)
        return result