from typing import Optional

import torch
import torch.nn.functional as F
from torch import LongTensor, Tensor

from quaterion.distances import Distance
from quaterion.loss.group_loss import GroupLoss
from quaterion.utils import get_anchor_positive_mask, max_value_of_dtype

class OnlineContrastiveLoss(GroupLoss):
    """Implements Contrastive Loss with online pair mining.

    Unlike :class:`~quaterion.loss.contrastive_loss.ContrastiveLoss`, this one
    supports online pair mining, i.e., it makes positive and negative pairs
    on-the-fly, so you don't need to form such pairs yourself. Instead, it first
    calculates all possible pairs in a batch, and then filters valid positive pairs
    and valid negative pairs separately. Batch-all and batch-hard strategies for
    online pair mining are supported.

    Args:
        margin: Margin value to push negative examples apart.
            Optional, defaults to `0.5`.
        distance_metric_name: Name of the distance function, e.g.,
            :class:`~quaterion.distances.Distance`. Optional, defaults to
            :attr:`~quaterion.distances.Distance.COSINE`.
        mining (str, optional): Pair mining strategy. One of
            `"all"`, `"hard"`. Defaults to `"hard"`.

    Raises:
        ValueError: If `mining` is not one of the supported strategies.
    """

    # NOTE(review): the original docstring read "as defined in <reference>", but the
    # reference is missing from this copy (presumably Hadsell, Chopra & LeCun, 2006).
    # Restore the link once confirmed.

    def __init__(
        self,
        margin: Optional[float] = 0.5,
        distance_metric_name: Distance = Distance.COSINE,
        mining: Optional[str] = "hard",
    ):
        mining_types = ["all", "hard"]
        if mining not in mining_types:
            raise ValueError(
                f"Unrecognized mining strategy: {mining}. Must be one of {', '.join(mining_types)}"
            )
        super(OnlineContrastiveLoss, self).__init__(
            distance_metric_name=distance_metric_name
        )

        self._margin = margin
        self._mining = mining

    def get_config_dict(self):
        """Return the parent's config dict extended with `margin` and `mining`."""
        config = super().get_config_dict()
        config.update(
            {
                "margin": self._margin,
                "mining": self._mining,
            }
        )
        return config

    def forward(
        self,
        embeddings: Tensor,
        groups: LongTensor,
    ) -> Tensor:
        """Calculates Contrastive Loss by making pairs on-the-fly.

        Args:
            embeddings: Shape: (batch_size, vector_length) - Batch of embeddings
            groups: Shape (batch_size,) Batch of labels associated with `embeddings`

        Returns:
            torch.Tensor: Scalar loss value.
        """
        # Shape: (batch_size, batch_size)
        dists = self.distance_metric.distance_matrix(embeddings)

        # Mask of valid anchor-positive pairs; multiplying by it zeroes out the
        # distances of every invalid pair.
        anchor_positive_mask = get_anchor_positive_mask(groups, groups)
        anchor_positive_dists = anchor_positive_mask.float() * dists

        # A valid anchor-negative pair is a pair of samples from different groups.
        # The mask is built directly from `groups` rather than by inverting the
        # positive mask: if the positive-mask helper excludes self-pairs (the
        # diagonal), its complement would admit every sample as its own negative.
        anchor_negative_mask = groups.unsqueeze(0) != groups.unsqueeze(1)

        # Invalid negatives are pushed to the largest representable value so that they
        # never win a `min` and contribute `relu(margin - huge) == 0` to the loss.
        # `masked_fill` is out-of-place on purpose: writing into `dists` itself would
        # mutate a tensor that is part of the autograd graph.
        invalid_dist = max_value_of_dtype(dists.dtype)
        anchor_negative_dists = dists.masked_fill(~anchor_negative_mask, invalid_dist)

        if self._mining == "all":
            # batch-all pair mining: average over every valid pair
            num_positive_pairs = anchor_positive_mask.sum()
            positive_total = anchor_positive_dists.sum()

            num_negative_pairs = anchor_negative_mask.float().sum()
            negative_total = F.relu(self._margin - anchor_negative_dists).sum()
        else:
            # batch-hard pair mining
            # hardest (farthest) positive for each anchor, shape: (batch_size,)
            hardest_positive_dists = anchor_positive_dists.max(dim=1)[0]
            num_positive_pairs = torch.count_nonzero(hardest_positive_dists)
            positive_total = hardest_positive_dists.sum()

            # hardest (closest) negative for each anchor, shape: (batch_size,)
            hardest_negative_dists = anchor_negative_dists.min(dim=1)[0]
            # An anchor has a valid negative only where the row minimum is below
            # the sentinel used to mark invalid pairs.
            num_negative_pairs = (hardest_negative_dists < invalid_dist).float().sum()
            negative_total = F.relu(self._margin - hardest_negative_dists).sum()

        # Guard against division by zero when a batch has no valid pairs of a kind;
        # the corresponding total is 0 in that case, so the term evaluates to 0.
        eps = torch.tensor(1e-16, device=dists.device)
        positive_loss = positive_total / torch.max(num_positive_pairs, eps)
        negative_loss = negative_total / torch.max(num_negative_pairs, eps)

        total_loss = 0.5 * (positive_loss + negative_loss)
        return total_loss


