
# Source code for quaterion.loss.contrastive_loss

from typing import Any, Dict, Type

import torch
from torch import LongTensor, Tensor

from quaterion.distances import Distance
from quaterion.loss.pairwise_loss import PairwiseLoss
from quaterion.utils import max_value_of_dtype

class ContrastiveLoss(PairwiseLoss):
    """Contrastive loss.

    Expects as input pairs of objects (given as embeddings) and a label of
    either 0 or 1 for each pair. If the label == 1, then the distance between
    the two embeddings is reduced. If the label == 0, then the distance between
    the embeddings is increased.

    Further information:
        Hadsell, Chopra, LeCun, "Dimensionality Reduction by Learning an
        Invariant Mapping" (CVPR 2006).

    Args:
        distance_metric_name: Name of the function, e.g.,
            :class:`~quaterion.distances.Distance`. Optional, defaults to
            :attr:`~quaterion.distances.Distance.COSINE`.
        margin: Negative samples (label == 0) should have a distance of at
            least the margin value.
        size_average: Average by the size of the mini-batch.
    """

    def __init__(
        self,
        distance_metric_name: Distance = Distance.COSINE,
        margin: float = 0.5,
        size_average: bool = True,
    ):
        super().__init__(distance_metric_name=distance_metric_name)
        self.margin = margin
        self.size_average = size_average

    def get_config_dict(self) -> Dict[str, Any]:
        """Config used in saving and loading purposes.

        Config object has to be JSON-serializable.

        Returns:
            Dict[str, Any]: JSON-serializable dict of params
        """
        return {
            **super().get_config_dict(),
            "margin": self.margin,
            "size_average": self.size_average,
        }

    def forward(
        self,
        embeddings: Tensor,
        pairs: LongTensor,
        labels: Tensor,
        subgroups: Tensor,
        **kwargs
    ) -> Tensor:
        """Compute loss value.

        Args:
            embeddings: Batch of embeddings, first half of embeddings are
                embeddings of first objects in pairs, second half are
                embeddings of second objects in pairs.
            pairs: Indices of corresponding objects in pairs.
            labels: Scores of positive and negative objects.
            subgroups: subgroups to distinguish objects which can and cannot be
                used as negative examples
            **kwargs: additional key-word arguments for generalization of loss
                call

        Returns:
            Tensor: averaged or summed loss value
        """
        anchor_idx = pairs[:, 0]
        other_idx = pairs[:, 1]

        # distance between the two members of every pair
        distances = self.distance_metric.distance(
            embeddings[anchor_idx], embeddings[other_idx]
        )

        # Extra penalty from the closest object of a different subgroup.
        # Stays a plain 0.0 when every object belongs to the same subgroup,
        # because then there is nothing to mine negatives from.
        negative_distances_impact = 0.0

        if len(subgroups.unique()) > 1:
            # `subgroups` is expected to hold one entry per embedding, i.e.
            # shape (embeddings_count,), covering both `obj_a` and `obj_b`.
            #
            # True where two objects belong to different subgroups and hence
            # may serve as negative examples for each other.
            # Broadcasting avoids materialising a repeated copy of `subgroups`.
            # shape (embeddings_count, embeddings_count)
            is_negative: Tensor = subgroups.unsqueeze(0) != subgroups.unsqueeze(1)

            # shape (embeddings_count, embeddings_count)
            distance_matrix = self.distance_metric.distance_matrix(embeddings)

            # Exclude objects of the same subgroup from the search for the
            # closest negative by pushing their distance to the dtype maximum.
            # Done out-of-place: writing into the tensor returned by
            # `distance_matrix` could invalidate values autograd saved for the
            # backward pass and would mutate a tensor this loss does not own.
            distance_matrix = distance_matrix.masked_fill(
                ~is_negative, max_value_of_dtype(distance_matrix.dtype)
            )

            # distance from every object to its closest negative example
            # shape (embeddings_count,)
            negative_distances, _ = distance_matrix.min(dim=1)

            # closest negatives for both members of every pair
            # shape (number of pairs,) each
            neg_dist_to_anchors = negative_distances[anchor_idx]
            neg_dist_to_other = negative_distances[other_idx]

            # penalise negatives which are closer than `margin`
            # shape (number of pairs,)
            negative_distances_impact = torch.relu(
                self.margin - neg_dist_to_anchors
            ).pow(2) + torch.relu(self.margin - neg_dist_to_other).pow(2)

        # Classic contrastive term: pull positive pairs together, push
        # negative pairs at least `margin` apart.
        pair_term = labels.float() * distances.pow(2) + (
            1 - labels
        ).float() * torch.relu(self.margin - distances).pow(2)

        losses = 0.5 * pair_term + negative_distances_impact

        return losses.mean() if self.size_average else losses.sum()


