Source code for quaterion.loss.multiple_negatives_ranking_loss

from typing import Any, Dict, Type

import torch
import torch.nn.functional as F
from torch import LongTensor, Tensor

from quaterion.distances import Distance
from quaterion.loss.pairwise_loss import PairwiseLoss


class MultipleNegativesRankingLoss(PairwiseLoss):
    """Implements Multiple Negatives Ranking Loss as described in
    https://arxiv.org/pdf/1705.00652.pdf

    This loss function works only with positive pairs, e.g., an `anchor` and a `positive`.
    For each pair, it uses the `positive` of the other pairs in the batch as negatives, so
    you don't need to worry about specifying negative examples. It is great for retrieval
    tasks such as question-answer retrieval, duplicate sentence retrieval, and cross-modal
    retrieval. It accepts pairs of anchor and positive embeddings and calculates a
    similarity matrix between them. Then, it minimizes the negative log-likelihood of the
    softmax-normalized similarity scores. This optimizes retrieval of the correct positive
    pair when an anchor is given.

    Note:
        :attr:`~quaterion.dataset.similarity_samples.SimilarityPairSample.score` and
        :attr:`~quaterion.dataset.similarity_samples.SimilarityPairSample.subgroup` values
        are ignored for this loss, assuming
        :attr:`~quaterion.dataset.similarity_samples.SimilarityPairSample.obj_a` and
        :attr:`~quaterion.dataset.similarity_samples.SimilarityPairSample.obj_b` form a
        positive pair, e.g., `label = 1`.

    Args:
        scale: Scaling value to multiply similarity scores by so that cross-entropy
            works well.
        distance_metric_name: Name of the metric to calculate similarities between
            embeddings, e.g., :class:`~quaterion.distances.Distance`. Optional, defaults
            to :attr:`~quaterion.distances.Distance.COSINE`. If
            :attr:`~quaterion.distances.Distance.DOT_PRODUCT`, `scale` must be `1`.
        symmetric: If True, the loss is symmetric, i.e., it also accounts for retrieval
            of the correct anchor when a positive is given.
    """

    def __init__(
        self,
        scale: float = 20.0,
        distance_metric_name: Distance = Distance.COSINE,
        symmetric: bool = False,
    ):
        super().__init__(distance_metric_name=distance_metric_name)
        self._scale = scale
        self._symmetric = symmetric
    def get_config_dict(self) -> Dict[str, Any]:
        """Config used for saving and loading purposes.

        Config object has to be JSON-serializable.

        Returns:
            Dict[str, Any]: JSON-serializable dict of params
        """
        # start from the parent config (calling self.get_config_dict() here
        # would recurse infinitely)
        config = super().get_config_dict()
        config.update(
            {
                "scale": self._scale,
                "symmetric": self._symmetric,
            }
        )
        return config
    def forward(
        self,
        embeddings: Tensor,
        pairs: LongTensor,
        labels: Tensor,
        subgroups: Tensor,
        **kwargs,
    ) -> Tensor:
        """Compute loss value.

        Args:
            embeddings: Batch of embeddings. The first half of the embeddings are the
                embeddings of the first objects in the pairs, the second half are the
                embeddings of the second objects in the pairs.
            pairs: Indices of corresponding objects in pairs.
            labels: Ignored for this loss. Labels will be automatically formed from
                `pairs`.
            subgroups: Ignored for this loss.
            **kwargs: Additional keyword arguments for generalization of the loss call.

        Returns:
            Tensor: Scalar loss value
        """
        _warn = (
            "You seem to be using non-positive pairs. "
            "Make sure that `SimilarityPairSample.obj_a` and `SimilarityPairSample.obj_b` "
            "are positive pairs with a score of `1`"
        )
        assert labels is None or labels.sum() == labels.size()[0], _warn

        rep_anchor = embeddings[pairs[:, 0]]
        rep_positive = embeddings[pairs[:, 1]]

        # get similarity matrix to be used as logits
        # shape: (batch_size, batch_size)
        logits = self.distance_metric.similarity_matrix(rep_anchor, rep_positive)
        logits *= self._scale

        # create integer label IDs: the correct positive for the i-th anchor
        # sits on the diagonal, so the target class index is simply i
        labels = torch.arange(
            start=0, end=logits.size()[0], dtype=torch.long, device=logits.device
        )

        # calculate loss
        loss = F.cross_entropy(logits, labels)

        if self._symmetric:
            loss += F.cross_entropy(logits.transpose(0, 1), labels)
            loss /= 2

        return loss
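
As a usage reference for the class above, here is a minimal instantiation sketch. It assumes `MultipleNegativesRankingLoss` is re-exported from `quaterion.loss`, which this module itself does not show:

from quaterion.distances import Distance
from quaterion.loss import MultipleNegativesRankingLoss

# default configuration: cosine similarity, scale of 20, asymmetric loss
loss_fn = MultipleNegativesRankingLoss()

# per the docstring above, `scale` must be 1 when using DOT_PRODUCT
dot_loss_fn = MultipleNegativesRankingLoss(
    scale=1.0,
    distance_metric_name=Distance.DOT_PRODUCT,
    symmetric=True,
)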
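
The `get_config_dict` override merges this loss's own parameters into the config inherited from `PairwiseLoss`. A sketch of the result; the exact keys contributed by the parent class are an assumption here:

config = loss_fn.get_config_dict()
# hypothetical output, assuming PairwiseLoss contributes the metric name:
# {"distance_metric_name": "cosine", "scale": 20.0, "symmetric": False}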
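
Finally, a sketch of calling `forward` directly with dummy tensors to illustrate the expected shapes. In real training, Quaterion assembles these inputs from `SimilarityPairSample` batches, so the manual tensor construction below is illustrative only:

import torch

batch_size, dim = 4, 8

# 2 * batch_size embeddings: rows 0..3 are anchors, rows 4..7 are positives
embeddings = torch.randn(2 * batch_size, dim)

# pairs[i] = (index of i-th anchor, index of i-th positive)
pairs = torch.stack(
    [torch.arange(batch_size), torch.arange(batch_size, 2 * batch_size)],
    dim=1,
)

labels = torch.ones(batch_size)      # all positive pairs, score = 1
subgroups = torch.zeros(batch_size)  # ignored by this loss

loss = loss_fn(
    embeddings=embeddings, pairs=pairs, labels=labels, subgroups=subgroups
)
print(loss.item())  # scalar loss value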
