
quaterion.loss.online_contrastive_loss module

class OnlineContrastiveLoss(margin: float | None = 0.5, distance_metric_name: Distance = Distance.COSINE, mining: str | None = 'hard')

Bases: GroupLoss

Implements Contrastive Loss as defined in http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
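For reference, the per-pair loss defined in that paper (Hadsell, Chopra and LeCun, 2006) is reproduced below. Here $Y = 0$ marks a similar pair, $Y = 1$ a dissimilar pair, $D_W$ is the distance between the two embeddings, and $m > 0$ is the margin. How this class averages the per-pair terms over a batch is not stated on this page, so treat the formula as the per-pair definition only.

```latex
L(W, Y, \vec{X}_1, \vec{X}_2) = (1 - Y)\,\tfrac{1}{2}\,(D_W)^2 + Y\,\tfrac{1}{2}\,\bigl\{\max(0,\, m - D_W)\bigr\}^2
```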

Unlike ContrastiveLoss, this loss supports online pair mining, i.e., it forms positive and negative pairs on the fly, so you don’t need to build such pairs yourself. It first computes all possible pairs in a batch and then separates them into valid positive pairs (embeddings with the same group label) and valid negative pairs (embeddings with different group labels). Both batch-all and batch-hard mining strategies are supported.
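The mining step can be illustrated with a plain-Python sketch. It is a minimal illustration, not Quaterion's implementation, and rests on these assumptions: cosine distance is 1 minus cosine similarity; positives are pairs sharing a group label and negatives are pairs with different labels; batch-hard mining keeps, per anchor, the farthest positive and the closest negative; and the positive and negative terms are each averaged, then combined as ½(pos + neg) in the spirit of the paper's per-pair formula. `cosine_distance` and `online_contrastive_loss` are hypothetical names.

```python
import math
from typing import Dict, List, Sequence


def cosine_distance(u: Sequence[float], v: Sequence[float]) -> float:
    # Assumption: cosine distance = 1 - cosine similarity (vectors must be non-zero).
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def online_contrastive_loss(
    embeddings: Sequence[Sequence[float]],
    groups: Sequence[int],
    margin: float = 0.5,
    mining: str = "hard",
) -> float:
    """Hypothetical sketch of contrastive loss with online pair mining."""
    n = len(embeddings)
    # Step 1: distances for all possible pairs in the batch.
    dist = [[cosine_distance(embeddings[i], embeddings[j]) for j in range(n)] for i in range(n)]

    # Step 2: per anchor i, split the other samples into positives (same group)
    # and negatives (different group).
    positives: Dict[int, List[float]] = {
        i: [dist[i][j] for j in range(n) if j != i and groups[j] == groups[i]] for i in range(n)
    }
    negatives: Dict[int, List[float]] = {
        i: [dist[i][j] for j in range(n) if groups[j] != groups[i]] for i in range(n)
    }

    # Step 3: choose which pairs contribute to the loss.
    if mining == "all":
        # Batch-all: every valid pair contributes.
        pos_d = [d for i in range(n) for d in positives[i]]
        neg_d = [d for i in range(n) for d in negatives[i]]
    elif mining == "hard":
        # Batch-hard: per anchor, the farthest positive and the closest negative.
        pos_d = [max(positives[i]) for i in range(n) if positives[i]]
        neg_d = [min(negatives[i]) for i in range(n) if negatives[i]]
    else:
        raise ValueError(f"unknown mining strategy: {mining!r}")

    # Step 4: squared distance pulls positives together; squared hinge pushes
    # negatives apart until they are at least `margin` away.
    pos_loss = sum(d ** 2 for d in pos_d) / max(len(pos_d), 1)
    neg_loss = sum(max(0.0, margin - d) ** 2 for d in neg_d) / max(len(neg_d), 1)
    return 0.5 * (pos_loss + neg_loss)
```

With margin 0.5, a negative pair at distance 0.3 adds (0.5 − 0.3)² = 0.04 to the negative sum in this sketch, while one at distance 0.7 adds nothing: negatives already beyond the margin stop producing gradient.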

Parameters:
  • margin – Margin value to push negative examples apart. Optional, defaults to 0.5.

  • distance_metric_name – Name of the distance function, given as a member of the Distance enum (e.g., Distance.COSINE). Optional, defaults to Distance.COSINE.

  • mining (str, optional) – Pair mining strategy: “all” (batch-all) or “hard” (batch-hard). Defaults to “hard”.

forward(embeddings: Tensor, groups: LongTensor) → Tensor

Calculates Contrastive Loss by making pairs on-the-fly.

Parameters:
  • embeddings – Shape: (batch_size, vector_length). Batch of embeddings.

  • groups – Shape: (batch_size,). Batch of group labels associated with the embeddings.

Returns:

torch.Tensor – Scalar loss value.
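Before calling forward, it can help to confirm that a batch matches the documented shapes and actually contains both kinds of pairs; a batch in which every group label is unique, for example, has no positive pairs to mine. The helper below is hypothetical, works on plain Python lists rather than tensors, and counts ordered pairs (i, j) with i ≠ j.

```python
from collections import Counter
from typing import Dict, Sequence


def check_batch(embeddings: Sequence[Sequence[float]], groups: Sequence[int]) -> Dict[str, int]:
    """Hypothetical helper: validate shapes and count ordered (i, j) pairs, i != j."""
    batch_size = len(embeddings)
    # groups must have shape (batch_size,): one label per embedding.
    if len(groups) != batch_size:
        raise ValueError(f"groups has length {len(groups)}, expected {batch_size}")
    # embeddings must have shape (batch_size, vector_length): equal-length rows.
    lengths = {len(e) for e in embeddings}
    if len(lengths) > 1:
        raise ValueError(f"embeddings have mixed lengths: {sorted(lengths)}")
    vector_length = lengths.pop() if lengths else 0
    # A group with c members yields c * (c - 1) ordered positive pairs.
    counts = Counter(groups)
    positive_pairs = sum(c * (c - 1) for c in counts.values())
    # Every remaining ordered pair of distinct samples is a negative pair.
    negative_pairs = batch_size * (batch_size - 1) - positive_pairs
    return {
        "batch_size": batch_size,
        "vector_length": vector_length,
        "positive_pairs": positive_pairs,
        "negative_pairs": negative_pairs,
    }
```

If positive_pairs comes out as 0, there is nothing for the positive term to average; sampling at least two examples per group avoids that.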

get_config_dict()

Config used for saving and loading purposes.

Config object has to be JSON-serializable.

Returns:

Dict[str, Any] – JSON-serializable dict of params
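As a sketch of the JSON-serializability requirement, the hypothetical function below builds a dict whose keys mirror the constructor parameters and checks that it survives a json.dumps / json.loads round trip. The key names and the choice to store the distance as the plain string "cosine" are assumptions, not the documented output of get_config_dict().

```python
import json
from typing import Any, Dict, Optional


def example_config_dict(
    margin: Optional[float] = 0.5,
    distance_metric_name: str = "cosine",
    mining: Optional[str] = "hard",
) -> Dict[str, Any]:
    # Hypothetical: keys mirror the constructor parameters. The distance is kept as
    # a plain string because arbitrary objects (e.g. plain enum members) are not JSON types.
    return {
        "margin": margin,
        "distance_metric_name": distance_metric_name,
        "mining": mining,
    }


def is_json_round_trippable(config: Dict[str, Any]) -> bool:
    # True if the dict serializes to JSON and deserializes back to an equal dict.
    try:
        return json.loads(json.dumps(config)) == config
    except TypeError:
        # json.dumps raises TypeError for values it cannot serialize.
        return False
```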

training: bool