
Source code for quaterion.loss.center_loss

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor

from quaterion.loss.group_loss import GroupLoss
from quaterion.utils import l2_norm

class CenterLoss(GroupLoss):
    """Center Loss as defined in the paper "A Discriminative Feature Learning
    Approach for Deep Face Recognition" (Wen et al., ECCV 2016).

    It aims to minimize the intra-class variations while keeping the features of
    different classes separable. One learnable center is stored per group, and
    every embedding is pulled towards the center of its own group.

    Args:
        embedding_size: Output dimension of the encoder.
        num_groups: Number of groups (classes) in the dataset.
        lambda_c: A regularization parameter that controls the contribution of
            the center loss.
    """

    def __init__(self, embedding_size: int, num_groups: int, lambda_c: float = 0.5):
        # Zero-argument super() starts the MRO lookup right after CenterLoss.
        # The previous ``super(GroupLoss, self)`` skipped GroupLoss's own
        # initializer, if it defines one.
        super().__init__()
        self.num_groups = num_groups
        # Learnable centers, one row per group: shape (num_groups, embedding_size).
        # The randn values are overwritten by the Xavier initialization below;
        # randn is kept so the global RNG stream is consumed exactly as before.
        self.centers = nn.Parameter(torch.randn(num_groups, embedding_size))
        self.lambda_c = lambda_c
        nn.init.xavier_uniform_(self.centers)

    def forward(self, embeddings: Tensor, groups: LongTensor) -> Tensor:
        """Compute the Center Loss value.

        Embeddings are normalized with ``l2_norm`` along dim 1 before being
        compared with their group centers. The centers themselves are not
        normalized here.

        The squared differences are averaged over all
        ``batch_size * vector_length`` elements (``F.mse_loss`` with its default
        ``"mean"`` reduction). This is a mean rather than the per-sample half sum
        of squared distances used in the paper, then it is scaled by
        ``lambda_c``.

        Args:
            embeddings: shape (batch_size, vector_length) - Output embeddings
                from the encoder.
            groups: shape (batch_size,) - Group (class) ids associated with
                embeddings. Ids must lie in ``[0, num_groups)``, otherwise
                ``index_select`` raises.

        Returns:
            Tensor: scalar loss value.
        """
        embeddings = l2_norm(embeddings, 1)
        # Gather the center for each embedding's corresponding group:
        # (batch_size, embedding_size).
        centers_batch = self.centers.index_select(0, groups)
        # Mean squared difference between embeddings and their group centers.
        loss = F.mse_loss(embeddings, centers_batch)
        # Scale by the regularization weight. An out-of-place multiply avoids
        # mutating an autograd output in place.
        return self.lambda_c * loss


Learn more about the Qdrant vector search project and its ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning


Find people dealing with similar problems and get answers to your questions

Join Community