
Source code for quaterion.loss.fast_ap_loss

from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from quaterion.distances import Distance
from quaterion.loss.group_loss import GroupLoss
from quaterion.utils import get_anchor_negative_mask, get_anchor_positive_mask


class FastAPLoss(GroupLoss):
    """FastAP Loss.

    Adaptation from https://github.com/kunhe/FastAP-metric-learning.

    Further information:
        https://cs-people.bu.edu/fcakir/papers/fastap_cvpr2019.pdf
        "Deep Metric Learning to Rank"
        Fatih Cakir(*), Kun He(*), Xide Xia, Brian Kulis, and Stan Sclaroff
        IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019

    Args:
        num_bins: The number of soft histogram bins for calculating average precision.
            The paper suggests using 10.
    """

    def __init__(self, num_bins: Optional[int] = 10):
        # Euclidean distance is the only compatible distance metric for FastAP Loss
        super(GroupLoss, self).__init__(distance_metric_name=Distance.EUCLIDEAN)
        self.num_bins = num_bins
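A minimal instantiation sketch. The class is defined in quaterion.loss.fast_ap_loss; importing it directly from `quaterion.loss` is an assumption about the package's re-exports:

from quaterion.loss import FastAPLoss

# FastAP is a group loss: each sample only needs a group (class) label,
# so no explicit pair or triplet mining is required.
loss_fn = FastAPLoss(num_bins=10)  # 10 bins, as the paper suggests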
    def get_config_dict(self) -> Dict[str, Any]:
        """Config used for saving and loading purposes.

        Config object has to be JSON-serializable.

        Returns:
            Dict[str, Any]: JSON-serializable dict of params
        """
        config = super().get_config_dict()
        config.update({"num_bins": self.num_bins})
        return config
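A short sketch of the config round-trip. The "num_bins" key is set by the method above; any additional keys contributed by the parent class are an assumption and may differ between versions:

loss_fn = FastAPLoss(num_bins=20)
config = loss_fn.get_config_dict()

assert config["num_bins"] == 20  # added by FastAPLoss.get_config_dict
# Any remaining keys come from the parent class's config (assumption).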
    def forward(
        self,
        embeddings: Tensor,
        groups: Tensor,
    ) -> Tensor:
        """Compute loss value.

        Args:
            embeddings: shape: (batch_size, vector_length) - Batch of embeddings.
            groups: shape: (batch_size,) - Batch of labels associated with `embeddings`.

        Returns:
            Tensor: Scalar loss value.
        """
        _warn = "Batch size of embeddings and groups don't match."
        batch_size = groups.size()[0]  # batch size
        assert embeddings.size()[0] == batch_size, _warn
        device = embeddings.device  # get the device of the embeddings tensor

        # 1. get positive and negative masks
        pos_mask = get_anchor_positive_mask(groups).to(device)  # (batch_size, batch_size)
        neg_mask = get_anchor_negative_mask(groups).to(device)  # (batch_size, batch_size)
        n_pos = torch.sum(pos_mask, dim=1)  # sum over all columns (for each row)

        # 2. compute squared Euclidean distance matrix from normalized embeddings
        embeddings = F.normalize(embeddings, p=2, dim=1).to(device)  # normalize embeddings
        dist_matrix = (
            self.distance_metric.distance_matrix(embeddings).to(device) ** 2
        )  # (batch_size, batch_size)

        # 3. estimate discrete histograms
        histogram_delta = torch.tensor(4.0 / self.num_bins, device=device)
        mid_points = torch.linspace(
            0.0, 4.0, steps=self.num_bins + 1, device=device
        ).view(-1, 1, 1)
        pulse = F.relu(
            input=1 - torch.abs(dist_matrix - mid_points) / histogram_delta
        ).to(device)  # max(0, input)
        pos_hist = torch.t(torch.sum(pulse * pos_mask, dim=2)).to(device)  # positive histograms
        neg_hist = torch.t(torch.sum(pulse * neg_mask, dim=2)).to(device)  # negative histograms

        total_pos_hist = torch.cumsum(pos_hist, dim=1).to(device)
        total_hist = torch.cumsum(pos_hist + neg_hist, dim=1).to(device)

        # 4. compute FastAP
        FastAP = pos_hist * total_pos_hist / total_hist
        FastAP[torch.isnan(FastAP) | torch.isinf(FastAP)] = 0
        FastAP = torch.sum(FastAP, 1) / n_pos
        FastAP = FastAP[~torch.isnan(FastAP)]
        loss = 1 - torch.mean(FastAP)

        return loss
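An end-to-end usage sketch on a toy batch. The shapes and labels are illustrative, and in a real project Quaterion's training loop calls the loss for you; per the code above, the result is 1 minus the mean soft average precision, so it lies in [0, 1]:

import torch

from quaterion.loss import FastAPLoss

loss_fn = FastAPLoss(num_bins=10)

# Toy batch: 8 embeddings of dimension 16, two groups of 4 samples each.
embeddings = torch.randn(8, 16, requires_grad=True)
groups = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

loss = loss_fn(embeddings, groups)  # scalar loss tensor
loss.backward()  # gradients flow back to the embeddings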
