Shortcuts

Source code for quaterion.eval.accumulators.pair_accumulator

from typing import Dict

import torch

from quaterion.eval.accumulators import Accumulator


class PairAccumulator(Accumulator):
    """Accumulate embeddings, labels, pairs and subgroups for pair-based tasks.

    Keeps track of how many embeddings have been accumulated so far, so that
    the pair indices of every new batch are shifted to point into the
    concatenation of all accumulated batches.
    """

    def __init__(self):
        super().__init__()
        # Per-batch tensors; concatenated lazily by the matching properties.
        self._labels = []
        self._pairs = []
        self._subgroups = []
        # Number of embeddings accumulated so far; used to offset pair indices
        # of subsequent batches.
        self._accumulated_size = 0

    @property
    def state(self) -> Dict[str, torch.Tensor]:
        """Accumulated state.

        Returns:
            Dict[str, torch.Tensor] - the base accumulator state updated with
            concatenated ``labels``, ``pairs`` and ``subgroups`` tensors.
        """
        state = super().state
        # Use the concatenating properties for every entry so each value is a
        # tensor (previously ``subgroups`` leaked the raw per-batch list).
        state.update(
            {"labels": self.labels, "pairs": self.pairs, "subgroups": self.subgroups}
        )
        return state

    def update(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        pairs: torch.LongTensor,
        subgroups: torch.Tensor,
        device=None,
    ):
        """Update accumulator state.

        Detach the provided tensors, move them to the target device and
        append them to the accumulated state.

        Args:
            embeddings: embeddings to accumulate; ``embeddings.shape[0]`` is
                used as the number of new objects in this batch.
            labels: labels to distinguish similar and dissimilar objects.
            pairs: indices (local to this batch) determining objects of one
                pair. They are shifted by the number of previously accumulated
                embeddings before being stored.
            subgroups: subgroup numbers determining which samples can be
                considered negative.
            device: device to store the tensors on. Defaults to the device of
                ``embeddings`` when ``None``.
        """
        # Compare against None explicitly: a falsy but valid device such as
        # the CUDA ordinal ``0`` must not be silently ignored.
        if device is None:
            device = embeddings.device

        embeddings = embeddings.detach().to(device)
        labels = labels.detach().to(device)
        pairs = pairs.detach().to(device)
        subgroups = subgroups.detach().to(device)

        # NOTE(review): ``self._embeddings`` is presumably initialized and
        # cleared by the ``Accumulator`` base class; not visible here.
        self._embeddings.append(embeddings)
        self._labels.append(labels)
        # Shift batch-local pair indices so they address the global
        # (concatenated) embedding positions.
        self._pairs.append(pairs + self._accumulated_size)
        self._subgroups.append(subgroups)
        self._accumulated_size += embeddings.shape[0]

    def reset(self):
        """Reset accumulator state.

        Clear the accumulated size, labels, pairs and subgroups, and delegate
        to the base class reset for the remaining state (e.g. embeddings).
        """
        super().reset()
        self._labels = []
        self._pairs = []
        self._subgroups = []
        self._accumulated_size = 0

    @property
    def labels(self):
        """Concatenate the list of accumulated labels into a single tensor.

        Concatenation happens only on access rather than on every update.

        Returns:
            torch.Tensor: batch of labels, or an empty tensor if nothing was
            accumulated.
        """
        return torch.cat(self._labels) if self._labels else torch.Tensor()

    @property
    def subgroups(self):
        """Concatenate the list of accumulated subgroups into a single tensor.

        Concatenation happens only on access rather than on every update.

        Returns:
            torch.Tensor: batch of subgroups, or an empty tensor if nothing
            was accumulated.
        """
        return torch.cat(self._subgroups) if self._subgroups else torch.Tensor()

    @property
    def pairs(self) -> torch.LongTensor:
        """Concatenate the list of accumulated pairs into a single tensor.

        Concatenation happens only on access rather than on every update.
        Stored pair indices are already offset to global positions.

        Returns:
            torch.LongTensor: batch of pairs, or an empty LongTensor if
            nothing was accumulated.
        """
        return torch.cat(self._pairs) if self._pairs else torch.LongTensor()

Qdrant

Learn more about the Qdrant vector search project and its ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem-solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community