Source code for amplfi.train.losses

import torch
import torch.nn.functional as F
from torch import Tensor


class VICRegLoss(torch.nn.Module):
    """Implementation of the VICReg loss [0].

    This implementation is based on the code published by the authors [1].

    - [0] VICReg, 2022, https://arxiv.org/abs/2105.04906
    - [1] https://github.com/facebookresearch/vicreg/

    Attributes:
        lambda_param: Scaling coefficient for the invariance term of the loss.
        mu_param: Scaling coefficient for the variance term of the loss.
        nu_param: Scaling coefficient for the covariance term of the loss.
        eps: Epsilon for numerical stability, forwarded to ``variance_loss``.
        max_std: Standard-deviation threshold forwarded to ``variance_loss``;
            features whose standard deviation falls below it are penalized.
    """

    def __init__(
        self,
        lambda_param: float = 25.0,
        mu_param: float = 25.0,
        nu_param: float = 1.0,
        eps: float = 0.0001,
        max_std: float = 1.0,
    ) -> None:
        """Initializes the VICRegLoss module with the specified parameters."""
        super().__init__()
        self.lambda_param = lambda_param
        self.mu_param = mu_param
        self.nu_param = nu_param
        self.max_std = max_std
        self.eps = eps

    def forward(
        self, z_a: Tensor, z_b: Tensor
    ) -> tuple[Tensor, tuple[Tensor, Tensor, Tensor]]:
        """Computes the VICReg loss between two batches of embeddings.

        Args:
            z_a: Tensor with shape (batch_size, ..., dim).
            z_b: Tensor with shape (batch_size, ..., dim).

        Returns:
            A tuple ``(loss, (inv_loss, var_loss, cov_loss))``. ``loss`` is
            the weighted total
            ``lambda_param * inv_loss + mu_param * var_loss
            + nu_param * cov_loss``; the inner tuple holds the three
            unweighted terms.
        """
        # Invariance term of the loss
        inv_loss = invariance_loss(x=z_a, y=z_b)

        # Variance and covariance terms of the loss. Note the variance terms
        # of the two branches are averaged while the covariance terms are
        # summed -- presumably mirroring the reference implementation [1];
        # verify before changing.
        var_loss = 0.5 * (
            variance_loss(x=z_a, eps=self.eps, max_std=self.max_std)
            + variance_loss(x=z_b, eps=self.eps, max_std=self.max_std)
        )
        cov_loss = covariance_loss(x=z_a) + covariance_loss(x=z_b)

        # Total VICReg loss
        loss = (
            self.lambda_param * inv_loss
            + self.mu_param * var_loss
            + self.nu_param * cov_loss
        )
        return loss, (inv_loss, var_loss, cov_loss)
def invariance_loss(x: Tensor, y: Tensor) -> Tensor:
    """Returns VICReg invariance loss.

    The invariance term is the mean squared error between the two
    embeddings, averaged over every element. The two inputs are expected
    to have identical shapes; ``F.mse_loss`` broadcasts (with a warning)
    when they do not.

    Args:
        x: Tensor with shape (batch_size, ..., dim).
        y: Tensor with shape (batch_size, ..., dim).

    Returns:
        The computed VICReg invariance loss as a scalar tensor.
    """
    mean_squared_error = F.mse_loss(x, y, reduction="mean")
    return mean_squared_error
def variance_loss(
    x: Tensor, eps: float = 0.0001, max_std: float = 1.0
) -> Tensor:
    """Returns VICReg variance loss.

    The standard deviation of every feature is taken over the batch axis
    (dim 0) using ``Tensor.var``'s default (Bessel-corrected) estimate, so
    a batch of a single sample yields NaN. Each standard deviation below
    ``max_std`` contributes the hinge penalty ``max_std - std``; the
    penalties are averaged over all non-batch positions.

    Args:
        x: Tensor with shape (batch_size, ..., dim).
        eps: Added to the variance before the square root for numerical
            stability.
        max_std: Hinge threshold on the per-feature standard deviation.

    Returns:
        The computed VICReg variance loss as a scalar tensor.
    """
    std = x.var(dim=0).add(eps).sqrt()
    shortfall = F.relu(max_std - std)
    return shortfall.mean()
def covariance_loss(x: Tensor) -> Tensor:
    """Returns VICReg covariance loss.

    Generalized version of the covariance loss with support for tensors with
    more than two dimensions: a separate ``dim x dim`` covariance matrix
    (taken over the batch axis) is formed for every index of the middle
    ``...`` axes. For each matrix, the squared off-diagonal entries are summed
    and divided by ``dim``; the result is then averaged over the middle axes.
    The covariance divides by ``batch_size - 1``, so a batch of a single
    sample yields NaN.

    Adapted from VICRegL:
    https://github.com/facebookresearch/VICRegL/blob/803ae4c8cd1649a820f03afb4793763e95317620/main_vicregl.py#L299

    Args:
        x: Tensor with shape (batch_size, ..., dim).

    Returns:
        The computed VICReg covariance loss as a scalar tensor.
    """  # noqa
    n_samples = x.size(0)
    n_features = x.size(-1)
    centered = x - x.mean(dim=0)

    # Covariance over the batch axis, shape (..., dim, dim).
    cov = torch.einsum("b...c,b...d->...cd", centered, centered) / (
        n_samples - 1
    )

    # Boolean (dim, dim) mask that is True everywhere except the diagonal,
    # so only cross-feature covariances are penalized.
    off_diagonal = ~torch.eye(n_features, device=x.device, dtype=torch.bool)
    per_slice = cov[..., off_diagonal].pow(2).sum(dim=-1) / n_features
    return per_slice.mean()