
quaterion.loss.online_contrastive_loss module

class OnlineContrastiveLoss(margin: float | None = 0.5, distance_metric_name: Distance = Distance.COSINE, mining: str | None = 'hard')

Bases: GroupLoss

Implements Contrastive Loss as defined in http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
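For reference, this is the per-pair loss in the notation of the cited paper. Y = 0 marks a similar (positive) pair and Y = 1 a dissimilar (negative) pair, D_W is the distance between the two embeddings, and m > 0 is the margin. How this class averages the term over mined pairs is not specified on this page.

$$
L(W, Y, \vec{X}_1, \vec{X}_2) = (1 - Y)\,\frac{1}{2}\,D_W^2 + Y\,\frac{1}{2}\,\bigl\{\max(0,\; m - D_W)\bigr\}^2
$$

Positive pairs are penalized for any nonzero distance. Negative pairs contribute only while they are closer than the margin m.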

Unlike ContrastiveLoss, this loss supports online pair mining: it builds positive and negative pairs on the fly, so you don't need to form such pairs yourself. It first computes all possible pairs in a batch, then separately filters the valid positive pairs (embeddings sharing a group label) and the valid negative pairs (embeddings with different group labels). Two mining strategies are supported: batch-all and batch-hard.
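The pair-building step can be illustrated with a framework-free sketch. It enumerates every unordered pair in a batch and splits the pairs by whether their group labels match. The helper name `mine_pairs` is hypothetical, and this is not quaterion's implementation, which operates on tensors.

```python
from itertools import combinations
from typing import List, Sequence, Tuple

Pair = Tuple[int, int]


def mine_pairs(groups: Sequence[int]) -> Tuple[List[Pair], List[Pair]]:
    """Hypothetical helper: enumerate all unordered pairs (i, j) with i < j
    and split them into positive pairs (same group) and negative pairs
    (different groups)."""
    positives: List[Pair] = []
    negatives: List[Pair] = []
    for i, j in combinations(range(len(groups)), 2):
        if groups[i] == groups[j]:
            positives.append((i, j))
        else:
            negatives.append((i, j))
    return positives, negatives


positives, negatives = mine_pairs([0, 0, 1, 1, 2])
print(positives)  # [(0, 1), (2, 3)]
print(negatives)  # [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 4), (3, 4)]
```

A batch of n embeddings yields n(n - 1)/2 candidate pairs. Mining then decides which of these pairs actually contribute to the loss.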

Parameters:
  • margin – Margin value to push negative examples apart. Optional, defaults to 0.5.

  • distance_metric_name – Name of the distance function, given as a member of the Distance enum. Optional, defaults to Distance.COSINE.

  • mining (str, optional) – Pair mining strategy: “all” (batch-all) or “hard” (batch-hard). Defaults to “hard”.
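To show what the default distance_metric_name = Distance.COSINE implies for the margin, here is a minimal sketch of a pairwise cosine distance matrix. It assumes cosine distance is defined as 1 minus cosine similarity. That definition is an assumption, not confirmed by this page, and the helper `cosine_distance_matrix` is hypothetical.

```python
import math
from typing import List, Sequence


def cosine_distance_matrix(embeddings: Sequence[Sequence[float]]) -> List[List[float]]:
    """Hypothetical helper: pairwise distances, ASSUMING
    cosine distance = 1 - cosine similarity."""
    norms = [math.sqrt(sum(x * x for x in v)) for v in embeddings]
    n = len(embeddings)
    matrix = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            dot = sum(a * b for a, b in zip(embeddings[i], embeddings[j]))
            # Guard against division by zero for all-zero vectors.
            denom = max(norms[i] * norms[j], 1e-12)
            matrix[i][j] = 1.0 - dot / denom
    return matrix


d = cosine_distance_matrix([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
print(d[0][1], d[0][2])  # 1.0 2.0
```

Under this assumed definition, distances fall in [0, 2]. The default margin of 0.5 would then push negative pairs apart until their cosine similarity drops to 0.5 or below.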

forward(embeddings: Tensor, groups: LongTensor) → Tensor

Calculates Contrastive Loss by making pairs on-the-fly.

Parameters:
  • embeddings – Shape: (batch_size, vector_length). Batch of embeddings.

  • groups – Shape: (batch_size,). Batch of group labels associated with the embeddings.

Returns:

torch.Tensor – Scalar loss value.
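The following is a framework-free sketch of the kind of scalar a forward pass like this might compute. It rests on several assumptions: cosine distance as 1 minus cosine similarity, the per-pair term from the cited paper, and a simple mean reduction. The function name `online_contrastive_loss` and the reductions are hypothetical. This is not quaterion's implementation, which works on torch tensors. With mining="all", every positive and negative pair contributes. With mining="hard", each anchor contributes only its farthest positive and its closest negative.

```python
import math
from typing import List, Sequence


def _cosine_distance(u: Sequence[float], v: Sequence[float]) -> float:
    # ASSUMED definition: 1 - cosine similarity.
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / max(norm, 1e-12)


def online_contrastive_loss(
    embeddings: Sequence[Sequence[float]],
    groups: Sequence[int],
    margin: float = 0.5,
    mining: str = "hard",
) -> float:
    """Hypothetical sketch: contrastive loss with online pair mining."""
    if len(embeddings) != len(groups):
        raise ValueError("embeddings and groups must have the same batch size")
    if mining not in ("all", "hard"):
        raise ValueError("mining must be 'all' or 'hard'")
    n = len(groups)
    dist = [[_cosine_distance(embeddings[i], embeddings[j]) for j in range(n)]
            for i in range(n)]

    if mining == "all":
        # Batch-all: average the paper's terms over every valid pair.
        pos: List[float] = []
        neg: List[float] = []
        for i in range(n):
            for j in range(i + 1, n):
                if groups[i] == groups[j]:
                    pos.append(dist[i][j] ** 2)
                else:
                    neg.append(max(0.0, margin - dist[i][j]) ** 2)
        pos_term = sum(pos) / len(pos) if pos else 0.0
        neg_term = sum(neg) / len(neg) if neg else 0.0
        return 0.5 * (pos_term + neg_term)

    # Batch-hard: per anchor, the farthest positive and the closest negative.
    per_anchor: List[float] = []
    for i in range(n):
        pos_d = [dist[i][j] for j in range(n) if j != i and groups[j] == groups[i]]
        neg_d = [dist[i][j] for j in range(n) if groups[j] != groups[i]]
        pos_term = max(pos_d) ** 2 if pos_d else 0.0
        neg_term = max(0.0, margin - min(neg_d)) ** 2 if neg_d else 0.0
        per_anchor.append(0.5 * (pos_term + neg_term))
    return sum(per_anchor) / n if n else 0.0


# Two well-separated groups: no positive distance and no margin violation.
print(online_contrastive_loss([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],
                              [0, 0, 1, 1], mining="all"))  # 0.0
```

Anchors with no positive (or no negative) partner in the batch simply contribute zero for that term here. A real implementation may treat such anchors differently.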

get_config_dict()

Config used for saving and loading purposes.

Config object has to be JSON-serializable.

Returns:

Dict[str, Any] – JSON-serializable dict of params
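"JSON-serializable" means the returned dict must survive a round trip through the json module unchanged. The sketch below checks that property on a hypothetical config. The key names mirror the constructor parameters, and the string "cosine" standing in for Distance.COSINE is an assumption. Neither is confirmed to be what this method actually returns.

```python
import json
from typing import Any, Dict


def example_config_dict(
    margin: float = 0.5,
    distance_metric_name: str = "cosine",  # ASSUMED string form of Distance.COSINE
    mining: str = "hard",
) -> Dict[str, Any]:
    """Hypothetical config mirroring the constructor parameters."""
    return {
        "margin": margin,
        "distance_metric_name": distance_metric_name,
        "mining": mining,
    }


config = example_config_dict()
restored = json.loads(json.dumps(config))
print(restored == config)  # True
```

Keeping only plain types (floats, strings, None) in the dict is what makes this round trip lossless.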

training: bool
