Source code for quaterion.loss.center_loss
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor
from quaterion.loss.group_loss import GroupLoss
from quaterion.utils import l2_norm
class CenterLoss(GroupLoss):
    """
    Center Loss as defined in the paper "A Discriminative Feature Learning Approach
    for Deep Face Recognition" (http://ydwen.github.io/papers/WenECCV16.pdf).
    It aims to minimize the intra-class variations while keeping the features of
    different classes separable.

    A learnable center vector is kept for every group; the loss pulls each
    normalized embedding towards the center of its own group.

    Args:
        embedding_size: Output dimension of the encoder.
        num_groups: Number of groups (classes) in the dataset.
        lambda_c: A regularization parameter that controls the contribution of the
            center loss. Although annotated ``Optional`` (kept for backward
            compatibility), it must be a number: it is used as a multiplier of the
            loss in :meth:`forward`.
    """

    def __init__(
        self, embedding_size: int, num_groups: int, lambda_c: Optional[float] = 0.5
    ):
        # Zero-argument ``super()`` runs ``GroupLoss.__init__`` itself. The former
        # ``super(GroupLoss, self)`` began the MRO lookup *after* GroupLoss, silently
        # skipping GroupLoss's own initializer (if it defines one).
        super().__init__()
        self.num_groups = num_groups
        # One learnable center per group, shape (num_groups, embedding_size).
        # The ``randn`` values are overwritten by ``xavier_uniform_`` below; ``randn``
        # is kept so the global RNG stream is consumed exactly as before, which
        # preserves reproducibility of seeded runs.
        self.centers = nn.Parameter(torch.randn(num_groups, embedding_size))
        self.lambda_c = lambda_c
        nn.init.xavier_uniform_(self.centers)

    def forward(self, embeddings: Tensor, groups: LongTensor) -> Tensor:
        """
        Compute the Center Loss value.

        Args:
            embeddings: shape (batch_size, vector_length) - Output embeddings from the encoder.
            groups: shape (batch_size,) - Group (class) ids associated with embeddings.

        Returns:
            Tensor: scalar loss value - ``lambda_c`` times the mean squared error
            between the normalized embeddings and their group centers, averaged over
            all ``batch_size * vector_length`` elements.
        """
        # Normalize embeddings with the project helper (presumably along dim 1).
        embeddings = l2_norm(embeddings, 1)
        # Gather the center for each embedding's group -> (batch_size, embedding_size).
        # Note: the centers themselves are not normalized.
        centers_batch = self.centers.index_select(0, groups)
        # F.mse_loss defaults to reduction="mean": it averages over every element,
        # rather than the paper's 1/2 * sum of per-sample squared distances.
        loss = F.mse_loss(embeddings, centers_batch)
        # Scale out-of-place instead of mutating the autograd output with ``*=``.
        return loss * self.lambda_c