Source code for quaterion.eval.samplers.pair_sampler

import random
from collections.abc import Sized
from typing import Tuple, Union

import torch
from quaterion_models import SimilarityModel

from quaterion.dataset.similarity_data_loader import PairsSimilarityDataLoader
from quaterion.eval.accumulators import PairAccumulator
from quaterion.eval.pair import PairMetric
from quaterion.eval.samplers import BaseSampler
from quaterion.utils.utils import iter_by_batch


class PairSampler(BaseSampler):
    """Selects embeddings and targets for pair-based tasks.

    The sampler reduces the time and resources needed to calculate a distance
    matrix. Instead of computing a square matrix of shape
    (num_embeddings, num_embeddings), it selects a subset of embeddings and
    computes a rectangular matrix.

    Args:
        sample_size: int - number of objects to select
        distinguish: bool - whether to compare all objects each-to-each, or to
            compare only `obj_a` to `obj_b`. If true, compares only `obj_a` to
            `obj_b`, which significantly reduces the matrix size.
        encode_batch_size: int - batch size to use during encoding
    """

    def __init__(
        self,
        sample_size: int = -1,
        distinguish: bool = False,
        encode_batch_size: int = 16,
        device: Union[torch.device, str, None] = None,
        log_progress: bool = True,
    ):
        super().__init__(sample_size, device, log_progress)
        self.encode_batch_size = encode_batch_size
        self.distinguish = distinguish
        self.accumulator = PairAccumulator()
    def accumulate(self, model: SimilarityModel, dataset: Sized):
        """Encodes objects and accumulates embeddings with the corresponding raw labels.

        Args:
            model: model to encode objects
            dataset: Sized object, like list, tuple, torch.utils.data.Dataset, etc.
                to accumulate
        """
        # Each pair contributes two objects, so iterate over half the encode
        # batch size to keep the number of objects encoded per step constant.
        for input_batch in iter_by_batch(
            dataset, self.encode_batch_size // 2, self.log_progress
        ):
            batch_labels = PairsSimilarityDataLoader.collate_labels(input_batch)

            objects_a, objects_b = [], []
            for similarity_sample in input_batch:
                objects_a.append(similarity_sample.obj_a)
                objects_b.append(similarity_sample.obj_b)

            features = objects_a + objects_b
            embeddings = model.encode(
                features, batch_size=self.encode_batch_size, to_numpy=False
            )
            self.accumulator.update(embeddings, **batch_labels, device=self.device)

        self.accumulator.set_filled()
    def reset(self):
        """Reset accumulated state"""
        self.accumulator.reset()
    def sample(
        self, dataset: Sized, metric: PairMetric, model: SimilarityModel
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample embeddings and targets for pair-based tasks.

        Args:
            dataset: Sized object, like list, tuple, torch.utils.data.Dataset, etc.
                to sample
            metric: PairMetric instance to compute final labels representation
            model: model to encode objects

        Returns:
            torch.Tensor, torch.Tensor: metric labels and computed distance matrix
        """
        if not self.accumulator.filled:
            self.accumulate(model, dataset)

        embeddings = self.accumulator.embeddings
        pairs = self.accumulator.pairs

        labels = metric.prepare_labels(
            self.accumulator.labels, pairs, self.accumulator.subgroups
        )

        embeddings_num = embeddings.shape[0]
        max_sample_size = embeddings_num if not self.distinguish else pairs.shape[0]

        if self.sample_size > 0:
            sample_size = min(self.sample_size, max_sample_size)
        else:
            sample_size = max_sample_size

        sample_indices = torch.LongTensor(
            random.sample(range(max_sample_size), k=sample_size)
        )
        labels = labels[sample_indices]

        if self.distinguish:
            # Compare only the sampled `obj_a` embeddings to `obj_b` embeddings.
            ref_embeddings = embeddings[pairs[sample_indices][:, 0]]
            embeddings = embeddings[pairs[:, 1]]
            labels = labels[:, pairs[:, 1]]
            distance_matrix = metric.distance.distance_matrix(
                ref_embeddings, embeddings
            )
        else:
            ref_embeddings = embeddings[sample_indices]
            distance_matrix = metric.distance.distance_matrix(
                ref_embeddings, embeddings
            )

        # Exclude each row's own column (the one it was sampled from) by setting
        # that distance above the current maximum, so it is never retrieved.
        device = embeddings.device
        self_mask = (
            torch.arange(0, distance_matrix.shape[0], dtype=torch.long, device=device)
            .view(-1, 1)
            .to(device)
        )
        self_mask = torch.cat(
            [self_mask, sample_indices.view(-1, 1).to(device)], dim=1
        )
        distance_matrix[self_mask[:, 0], self_mask[:, 1]] = distance_matrix.max() + 1

        return labels.float(), distance_matrix
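
The following is not part of the module source, but a minimal usage sketch of the sampler. It assumes `model` is an already trained SimilarityModel, `eval_pairs` is a Sized collection of SimilarityPairSample objects, and RetrievalPrecision is one of the pair metrics provided by quaterion.eval.pair; these names are assumptions of the sketch, not defined in this file.

    from quaterion.eval.pair import RetrievalPrecision
    from quaterion.eval.samplers.pair_sampler import PairSampler

    metric = RetrievalPrecision(k=1)  # assumed pair metric; any PairMetric works
    sampler = PairSampler(sample_size=1000, encode_batch_size=32)

    # Encodes `eval_pairs` with `model`, caches the embeddings in the
    # accumulator, then draws up to `sample_size` reference rows.
    labels, distance_matrix = sampler.sample(eval_pairs, metric, model)

    # With distinguish=False the matrix is rectangular:
    # (sample_size, num_embeddings), where num_embeddings == 2 * number of pairs.
    # With distinguish=True it shrinks to (sample_size, number of pairs), since
    # only `obj_a` embeddings are compared to `obj_b` embeddings.
    print(labels.shape, distance_matrix.shape)

In typical evaluation code the sampler is not called directly; it is passed together with the metrics to quaterion's evaluator, which internally performs essentially the flow sketched above.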
