- class MultipleNegativesRankingLoss(scale: float = 20.0, distance_metric_name: Distance = Distance.COSINE, symmetric: bool = False)[source]¶
Implement Multiple Negatives Ranking Loss as described in https://arxiv.org/pdf/1705.00652.pdf
This loss function works only with positive pairs, e.g., an anchor and a positive. For each pair, it uses the positives of the other pairs in the batch as negatives, so you don’t need to worry about specifying negative examples. It is great for retrieval tasks such as question-answer retrieval, duplicate sentence retrieval, and cross-modal retrieval. It accepts pairs of anchor and positive embeddings and calculates a similarity matrix between them. Then, it minimizes the negative log-likelihood of the softmax-normalized similarity scores. This optimizes retrieval of the correct positive pair when an anchor is given.
labels and subgroups values are ignored for this loss, assuming
obj_a and obj_b form a positive pair, e.g., label = 1.
scale – Scaling value for multiplying with similarity scores to make cross-entropy work.
distance_metric_name – Name of the metric to calculate similarities between embeddings, e.g.,
Distance.COSINE. Optional, defaults to
Distance.COSINE. If Distance.DOT_PRODUCT is used, scale must be 1.
symmetric – If True, the loss is symmetric, i.e., it also accounts for retrieval of the correct anchor when a positive is given.
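The computation described above can be sketched in plain NumPy. This is not the library's implementation, just a minimal illustration under the stated assumptions (cosine distance, in-batch negatives, optional symmetric term); the names `mnr_loss` and `_logsumexp` are hypothetical.

```python
import numpy as np

def _logsumexp(x, axis):
    # Numerically stable log-sum-exp along the given axis
    m = x.max(axis=axis, keepdims=True)
    return m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

def mnr_loss(anchors, positives, scale=20.0, symmetric=False):
    # L2-normalize rows so a plain dot product equals cosine similarity
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    sim = scale * (a @ p.T)  # (batch, batch): each anchor vs. every positive
    # Row i's true positive sits on the diagonal;
    # all other columns of row i act as in-batch negatives
    log_probs = sim - _logsumexp(sim, axis=1)
    loss = -np.mean(np.diag(log_probs))
    if symmetric:
        # Also score retrieval of the correct anchor given a positive
        log_probs_t = sim.T - _logsumexp(sim.T, axis=1)
        loss = 0.5 * (loss - np.mean(np.diag(log_probs_t)))
    return loss
```

With orthogonal, perfectly matched pairs (e.g., one-hot embeddings), the diagonal dominates each row after scaling, and the loss approaches zero, which is the intended optimum.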
- forward(embeddings: Tensor, pairs: LongTensor, labels: Tensor, subgroups: Tensor, **kwargs) → Tensor [source]¶
Compute loss value.
embeddings – Batch of embeddings, first half of embeddings are embeddings of first objects in pairs, second half are embeddings of second objects in pairs.
pairs – Indices of corresponding objects in pairs.
labels – Ignored for this loss. Labels will be automatically formed from pairs.
subgroups – Ignored for this loss.
**kwargs – Additional keyword arguments for generalization of the loss call
Tensor – Scalar loss value
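The batch layout that forward expects can be unpacked as follows. This is an illustrative sketch only, using NumPy instead of torch tensors; the index construction for pairs is an assumption consistent with the description above, not code from the library.

```python
import numpy as np

batch = 3  # number of pairs
dim = 4    # embedding dimensionality (arbitrary for this sketch)
rng = np.random.default_rng(0)

# 2 * batch rows: the first half holds the first objects of each pair,
# the second half holds the second objects, as described above
embeddings = rng.normal(size=(2 * batch, dim))

# pairs[i] = (index of first object, index of second object) for pair i
pairs = np.column_stack([np.arange(batch), np.arange(batch, 2 * batch)])

anchors = embeddings[pairs[:, 0]]    # equivalent to embeddings[:batch]
positives = embeddings[pairs[:, 1]]  # equivalent to embeddings[batch:]
```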
- get_config_dict() → Dict[str, Any] [source]¶
Config used for saving and loading purposes.
Config object has to be JSON-serializable.
Dict[str, Any] – JSON-serializable dict of params
- training: bool¶