Source code for myoverse.models.components.losses

"""Loss functions for neural network training."""

from __future__ import annotations

import torch
from torch import nn


[docs] class EuclideanDistance(nn.Module): """Euclidean distance loss for 3D joint positions. Computes the mean Euclidean distance between predicted and ground truth 3D joint positions. Expects input tensors to be reshaped to (batch, joints, xyz). Parameters ---------- n_joints : int Number of joints in the skeleton. Default is 20. n_dims : int Number of dimensions per joint (typically 3 for x, y, z). Default is 3. Examples -------- >>> loss_fn = EuclideanDistance(n_joints=20) >>> pred = torch.randn(32, 60) # batch_size=32, 20 joints * 3 dims >>> target = torch.randn(32, 60) >>> loss = loss_fn(pred, target) """
[docs] def __init__(self, n_joints: int = 20, n_dims: int = 3): super().__init__() self.n_joints = n_joints self.n_dims = n_dims
[docs] def forward( self, prediction: torch.Tensor, ground_truth: torch.Tensor ) -> torch.Tensor: """Compute the mean Euclidean distance loss. Parameters ---------- prediction : torch.Tensor Predicted joint positions, shape (batch, n_joints * n_dims). ground_truth : torch.Tensor Ground truth joint positions, shape (batch, n_joints * n_dims). Returns ------- torch.Tensor Scalar loss value. """ pred_reshaped = prediction.reshape(-1, self.n_joints, self.n_dims) gt_reshaped = ground_truth.reshape(-1, self.n_joints, self.n_dims) # Compute per-joint Euclidean distances and average distances = torch.sqrt( torch.sum(torch.square(pred_reshaped - gt_reshaped), dim=-1) ) return distances.mean()
# Backward compatibility alias (deprecated spelling) EuclidianDistance = EuclideanDistance