Source code for quaterion.loss.contrastive_loss

from typing import Any, Dict

import torch
from torch import LongTensor, Tensor

from quaterion.distances import Distance
from quaterion.loss.pairwise_loss import PairwiseLoss
from quaterion.utils import max_value_of_dtype


class ContrastiveLoss(PairwiseLoss):
    """Contrastive loss.

    Expects as input two texts and a label of either 0 or 1. If the label == 1,
    then the distance between the two embeddings is reduced. If the label == 0,
    then the distance between the embeddings is increased.

    Further information:
        http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    Args:
        distance_metric_name: Name of the distance function, one of the options of
            :class:`~quaterion.distances.Distance`. Optional, defaults to
            :attr:`~quaterion.distances.Distance.COSINE`.
        margin: Negative samples (label == 0) should have a distance of at least
            the margin value.
        size_average: Average the loss by the size of the mini-batch.
    """

    def __init__(
        self,
        distance_metric_name: Distance = Distance.COSINE,
        margin: float = 0.5,
        size_average: bool = True,
    ):
        super().__init__(distance_metric_name=distance_metric_name)
        self.margin = margin
        self.size_average = size_average

    def get_config_dict(self) -> Dict[str, Any]:
        """Config used for saving and loading purposes.

        The config object has to be JSON-serializable.

        Returns:
            Dict[str, Any]: JSON-serializable dict of params
        """
        return {
            **super().get_config_dict(),
            "margin": self.margin,
            "size_average": self.size_average,
        }

    def forward(
        self,
        embeddings: Tensor,
        pairs: LongTensor,
        labels: Tensor,
        subgroups: Tensor,
        **kwargs,
    ) -> Tensor:
        """Compute the loss value.

        Args:
            embeddings: Batch of embeddings; the first half are embeddings of the
                first objects in pairs, the second half are embeddings of the
                second objects in pairs.
            pairs: Indices of corresponding objects in pairs.
            labels: Scores of positive and negative objects.
            subgroups: Subgroups to distinguish objects which can and cannot be
                used as negative examples.
            **kwargs: Additional keyword arguments for generalization of the
                loss call.

        Returns:
            Tensor: Averaged or summed loss value.
        """
        rep_anchor = embeddings[pairs[:, 0]]
        rep_other = embeddings[pairs[:, 1]]
        distances = self.distance_metric.distance(rep_anchor, rep_other)

        negative_distances_impact = 0.0

        if len(subgroups.unique()) > 1:
            # `embeddings` has shape (2 * batch_size, embeddings_size):
            # `embeddings_count` covers the embeddings of both `obj_a` and `obj_b`
            embeddings_count = embeddings.shape[0]

            # `subgroups` has shape (embeddings_count,)
            # shape (embeddings_count, embeddings_count)
            subgroup_matrix: Tensor = subgroups.repeat(embeddings_count, 1)
            # a mask to take into account only distances to negative examples,
            # i.e. to examples which don't belong to the current subgroup
            # shape (embeddings_count, embeddings_count)
            comp_matrix: Tensor = subgroup_matrix != subgroup_matrix.T

            # shape (embeddings_count, embeddings_count)
            distance_matrix = self.distance_metric.distance_matrix(embeddings)
            distance_matrix[~comp_matrix] = max_value_of_dtype(distance_matrix.dtype)

            # find the negative examples which are the closest to the positive ones
            # shape (embeddings_count,)
            negative_distances, _ = distance_matrix.min(dim=1)
            # shape (embeddings_count // 2,)
            neg_dist_to_anchors = negative_distances[pairs[:, 0]]
            # shape (embeddings_count // 2,)
            neg_dist_to_other = negative_distances[pairs[:, 1]]
            # shape (embeddings_count // 2,)
            negative_distances_impact = torch.relu(
                self.margin - neg_dist_to_anchors
            ).pow(2) + torch.relu(self.margin - neg_dist_to_other).pow(2)

        # shape (embeddings_count // 2,)
        losses = (
            0.5
            * (
                labels.float() * distances.pow(2)
                + (1 - labels).float() * torch.relu(self.margin - distances).pow(2)
            )
            + negative_distances_impact
        )

        return losses.mean() if self.size_average else losses.sum()
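For reference, a minimal instantiation sketch. It assumes ContrastiveLoss is re-exported from quaterion.loss (as in the package's public API) and that the Distance enum includes a EUCLIDEAN option; both constructor arguments shown are optional.

from quaterion.distances import Distance
from quaterion.loss import ContrastiveLoss

# a sketch: switch from the default cosine distance to Euclidean
# and widen the margin; the defaults are Distance.COSINE and 0.5
loss_fn = ContrastiveLoss(
    distance_metric_name=Distance.EUCLIDEAN,
    margin=1.0,
    size_average=True,
)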
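Because get_config_dict merges the parent's config with this loss's own hyperparameters, the result is plain JSON data. The exact key contributed by PairwiseLoss is an assumption here, but for a default instance it looks roughly like this:

loss_fn = ContrastiveLoss()
print(loss_fn.get_config_dict())
# e.g. {'distance_metric_name': 'cosine', 'margin': 0.5, 'size_average': True}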
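The per-pair loss computed in forward follows Hadsell et al.: 0.5 * (label * d^2 + (1 - label) * max(0, margin - d)^2), plus the hard-negative term when more than one subgroup is present in the batch. A toy invocation is sketched below; all values are made up for illustration, and with a single subgroup the negative-mining branch is skipped.

import torch

loss_fn = ContrastiveLoss(margin=0.5)

batch_size = 2
embedding_dim = 8  # arbitrary toy dimension
# first half: embeddings of obj_a, second half: embeddings of obj_b
embeddings = torch.randn(2 * batch_size, embedding_dim)
pairs = torch.LongTensor([[0, 2], [1, 3]])  # (obj_a index, obj_b index) per pair
labels = torch.Tensor([1.0, 0.0])  # 1 = similar pair, 0 = dissimilar pair
subgroups = torch.zeros(2 * batch_size)  # one subgroup: no in-batch negatives mined

loss = loss_fn(embeddings, pairs, labels, subgroups)  # scalar: mean over pairs
print(loss.item())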
