Source code for quaterion.loss.arcface_loss

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor

from quaterion.loss.group_loss import GroupLoss
from quaterion.utils.utils import l2_norm

class ArcFaceLoss(GroupLoss):
    """Additive Angular Margin Loss (ArcFace).

    As defined in "ArcFace: Additive Angular Margin Loss for Deep Face
    Recognition" (Deng et al., https://arxiv.org/abs/1801.07698).

    Every group is represented by one learnable column of ``kernel``. Logits
    are the cosines between the L2-normalized embeddings and the normalized
    group vectors. The angle to each sample's own group is enlarged by
    ``margin``, and the logits are multiplied by ``scale`` before cross
    entropy.

    Args:
        embedding_size: Output dimension of the encoder.
        num_groups: Number of groups in the dataset.
        scale: Scaling value to make cross entropy work.
        margin: Margin value to push groups apart.
    """

    # Cosines are clamped to [-1 + _EPS, 1 - _EPS] before ``acos``: the
    # derivative of ``acos`` is infinite at exactly +-1, which would turn the
    # gradients into inf/NaN as soon as an embedding aligns with (or points
    # away from) a group vector.
    _EPS = 1e-7

    def __init__(
        self,
        embedding_size: int,
        num_groups: int,
        scale: float = 64.0,
        margin: float = 0.5,
    ):
        # NOTE(review): ``super(GroupLoss, self)`` starts the MRO lookup after
        # GroupLoss, so a ``GroupLoss.__init__`` (if any) is skipped. Kept
        # unchanged; confirm this is intended.
        super(GroupLoss, self).__init__()
        # Shape: (embedding_size, num_groups) - one learnable vector per group.
        self.kernel = nn.Parameter(
            torch.empty(embedding_size, num_groups, dtype=torch.float32)
        )
        nn.init.normal_(self.kernel, std=0.01)
        self.scale = scale
        self.margin = margin

    def forward(
        self,
        embeddings: Tensor,
        groups: LongTensor,
    ) -> Tensor:
        """Compute loss value.

        Args:
            embeddings: shape: (batch_size, vector_length) - Output embeddings
                from the encoder.
            groups: shape: (batch_size,) - Group ids associated with
                embeddings.

        Returns:
            Tensor: loss value.
        """
        num_groups = self.kernel.shape[1]
        assert (
            groups.min() >= 0 and groups.max() < num_groups
        ), f"Invalid group ids: all the values must be between 0 (inclusive) and num_groups (exclusive), but given: {groups}"

        embeddings = l2_norm(embeddings, 1)
        kernel_norm = l2_norm(self.kernel, 0)

        # Shape: (batch_size, num_groups)
        cos_theta = torch.mm(embeddings, kernel_norm)
        # Ensure numerical stability of acos and of its gradient.
        cos_theta = cos_theta.clamp(-1 + self._EPS, 1 - self._EPS)

        # Angle between every embedding and every group vector.
        # Shape: (batch_size, num_groups)
        theta = torch.acos(cos_theta)

        # Additive angular margin, applied only to each sample's own group.
        # Every row gets one because the assert above guarantees valid ids.
        # Shape: (batch_size, num_groups)
        margin_hot = (
            F.one_hot(groups, num_classes=num_groups).to(theta.dtype) * self.margin
        )
        logits = torch.cos(theta + margin_hot) * self.scale

        # calculate scalar loss
        loss = F.cross_entropy(logits, groups)
        return loss


Learn more about the Qdrant vector search project and its ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning


Find people dealing with similar problems and get answers to your questions

Join Community