quaterion.loss.extras.pytorch_metric_learning_wrapper module
- class PytorchMetricLearningWrapper(loss: BaseMetricLossFunction, miner: BaseMiner | None = None)
Bases: GroupLoss
A simple wrapper that lets you use losses and miners from pytorch-metric-learning in Quaterion.
Create the loss instance (and, optionally, the miner instance) yourself, then pass those instances to the constructor of this wrapper.
Note
This is an experimental feature that may be subject to change, deprecation or removal.
Note
See below for a quick usage example of this wrapper. For details on individual losses and miners, refer to the pytorch-metric-learning documentation.
- Parameters:
loss – An instance of a loss object subclassing pytorch_metric_learning.losses.BaseMetricLossFunction.
miner – Optional. An instance of a miner object subclassing pytorch_metric_learning.miners.BaseMiner.
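Example:

    from pytorch_metric_learning import losses, miners
    from quaterion import TrainableModel
    from quaterion.loss.extras.pytorch_metric_learning_wrapper import PytorchMetricLearningWrapper

    class MyTrainableModel(TrainableModel):
        ...

        def configure_loss(self):
            # Build the pytorch-metric-learning objects, then hand them to the wrapper
            loss = losses.TripletMarginLoss()
            miner = miners.MultiSimilarityMiner()
            return PytorchMetricLearningWrapper(loss, miner)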
- forward(embeddings, groups)
- Parameters:
embeddings – shape: (batch_size, vector_length)
groups – shape: (batch_size,); group labels associated with the embeddings
- Returns:
Tensor – zero-dimensional (scalar) tensor holding the loss value
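The wrapper's role is to route a miner's output into a loss. The dependency-free sketch below models that contract in plain Python, with lists of floats standing in for tensors. It is an illustration under assumptions, not the code of Quaterion or pytorch-metric-learning: ToyTripletMarginLoss, ToyHardTripletMiner, ToyWrapper and all_triplets are hypothetical names, the loss is a simplified triplet margin loss (not PML's exact TripletMarginLoss), and the assumed behavior is that forward calls the miner on (embeddings, groups) and passes the mined triplets to the loss, or lets the loss use every valid triplet when no miner is given.

```python
import math
from itertools import permutations
from typing import Callable, List, Optional, Sequence, Tuple

Vector = Sequence[float]
Triplet = Tuple[int, int, int]  # (anchor, positive, negative) row indices


def euclidean(a: Vector, b: Vector) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def all_triplets(groups: Sequence[int]) -> List[Triplet]:
    """Every (anchor, positive, negative) triple allowed by the group labels."""
    n = len(groups)
    return [
        (a, p, neg)
        for a, p in permutations(range(n), 2)
        if groups[a] == groups[p]
        for neg in range(n)
        if groups[neg] != groups[a]
    ]


class ToyTripletMarginLoss:
    """Hypothetical stand-in for a PML loss: mean hinge over the given triplets."""

    def __init__(self, margin: float = 0.2):
        self.margin = margin

    def __call__(self, embeddings: Sequence[Vector], groups: Sequence[int],
                 triplets: Optional[List[Triplet]] = None) -> float:
        if triplets is None:
            # No mined indices: use every valid triplet in the batch.
            triplets = all_triplets(groups)
        if not triplets:
            return 0.0
        terms = [
            max(0.0, euclidean(embeddings[a], embeddings[p])
                - euclidean(embeddings[a], embeddings[n]) + self.margin)
            for a, p, n in triplets
        ]
        return sum(terms) / len(terms)


class ToyHardTripletMiner:
    """Hypothetical stand-in for a PML miner: keeps triplets that violate the margin."""

    def __init__(self, margin: float = 0.2):
        self.margin = margin

    def __call__(self, embeddings: Sequence[Vector], groups: Sequence[int]) -> List[Triplet]:
        return [
            (a, p, n)
            for a, p, n in all_triplets(groups)
            if euclidean(embeddings[a], embeddings[n])
            < euclidean(embeddings[a], embeddings[p]) + self.margin
        ]


class ToyWrapper:
    """Hypothetical model of the documented contract: forward(embeddings, groups) -> scalar."""

    def __init__(self, loss: Callable, miner: Optional[Callable] = None):
        self.loss = loss
        self.miner = miner

    def forward(self, embeddings: Sequence[Vector], groups: Sequence[int]) -> float:
        if self.miner is not None:
            # The miner picks the triplets; the loss only scores those.
            return self.loss(embeddings, groups, self.miner(embeddings, groups))
        return self.loss(embeddings, groups)


# Usage: four 1-D "embeddings" in two groups.
embeddings = [[0.0], [1.0], [0.5], [3.0]]
groups = [0, 0, 1, 1]
print(ToyWrapper(ToyTripletMarginLoss()).forward(embeddings, groups))
print(ToyWrapper(ToyTripletMarginLoss(), ToyHardTripletMiner()).forward(embeddings, groups))
```

Because the toy miner keeps only triplets with a non-zero hinge term, the mined loss averages over fewer, harder triplets than the unmined one, so on this batch it comes out larger.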
- training: bool