Shortcuts

Source code for quaterion.loss.circle_loss

from typing import Optional

import torch
import torch.nn.functional as F
from torch import LongTensor, Tensor

from quaterion.distances import Distance
from quaterion.loss.group_loss import GroupLoss
from quaterion.utils import get_anchor_negative_mask, get_anchor_positive_mask


class CircleLoss(GroupLoss):
    """Implements Circle Loss as defined in https://arxiv.org/abs/2002.10857.

    Each within-class similarity `sp` and between-class similarity `sn` is
    re-weighted by how far it currently is from its optimum, so that poorly
    optimized pairs contribute more to the loss than well optimized ones.

    Args:
        margin: Relaxation margin `m`. It determines the optima
            (`1 + m` for positive pairs, `-m` for negative pairs) and the
            decision margins (`1 - m` for positive pairs, `m` for negative pairs).
        scale_factor: Scale factor γ determines the largest scale of each
            similarity score.
        distance_metric_name: Distance metric used to build the pairwise
            similarity matrix of the batch.

    Note:
        Refer to sections 4.1 and 4.5 in the paper for default values and
        evaluation of the `margin` and `scale_factor` hyperparameters.
    """

    def __init__(
        self,
        margin: Optional[float] = 0.25,
        scale_factor: Optional[float] = 256,
        distance_metric_name: Optional[Distance] = Distance.COSINE,
    ):
        # Zero-argument `super()` resolves to the direct parent (GroupLoss), so
        # its initializer is no longer skipped as with `super(GroupLoss, self)`.
        super().__init__(distance_metric_name=distance_metric_name)
        self.margin = margin
        self.scale_factor = scale_factor

        # Optima of the positive / negative similarities: O_p = 1 + m, O_n = -m.
        self.op = 1 + self.margin
        self.on = -self.margin

        # Decision margins: delta_p = 1 - m, delta_n = m.
        self.delta_positive = 1 - self.margin
        self.delta_negative = self.margin

    def forward(
        self,
        embeddings: Tensor,
        groups: LongTensor,
    ) -> Tensor:
        """Compute loss value.

        Args:
            embeddings: shape: (batch_size, vector_length) - Batch of embeddings.
            groups: shape: (batch_size,) - Batch of labels associated with `embeddings`

        Returns:
            Tensor: Scalar loss value.
        """
        # Shape: (batch_size, batch_size)
        similarity_matrix = self.distance_metric.similarity_matrix(embeddings)
        # Shape: (batch_size * batch_size,)
        similarity_matrix = torch.reshape(similarity_matrix, [-1])

        # Keep only the strict upper triangle so that every unordered pair is
        # counted once and the diagonal (self-pairs) is excluded.
        # Shape: (batch_size, batch_size)
        pos_mask = get_anchor_positive_mask(groups, groups).triu(diagonal=1)
        # Shape: (batch_size, batch_size)
        neg_mask = get_anchor_negative_mask(groups, groups).triu(diagonal=1)

        # Shape: (batch_size * batch_size,)
        pos_mask = torch.reshape(pos_mask, [-1])
        # Shape: (batch_size * batch_size,)
        neg_mask = torch.reshape(neg_mask, [-1])

        # Similarities of positive pairs and of negative pairs.
        sp = similarity_matrix[pos_mask]
        sn = similarity_matrix[neg_mask]

        # Alpha-positive and alpha-negative weights as described in
        # https://arxiv.org/abs/2002.10857:
        #   alpha_p = [O_p - s_p]_+    alpha_n = [s_n - O_n]_+
        # The similarities are detached so that the weights act as constants
        # and receive no gradient.
        ap = torch.clamp_min(self.op - sp.detach(), min=0)
        an = torch.clamp_min(sn.detach() - self.on, min=0)

        exp_p = -ap * self.scale_factor * (sp - self.delta_positive)
        exp_n = an * self.scale_factor * (sn - self.delta_negative)

        # softplus(lse(n) + lse(p)) == log(1 + sum(exp(n)) * sum(exp(p))),
        # evaluated in a numerically stable way.
        circle_loss = F.softplus(
            torch.logsumexp(exp_n, dim=0) + torch.logsumexp(exp_p, dim=0)
        )

        return circle_loss

Qdrant

Learn more about the Qdrant vector search project and its ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community