
Source code for quaterion.eval.samplers.pair_sampler

import random
from collections.abc import Sized
from typing import Tuple, Union

import torch
from quaterion_models import SimilarityModel

from quaterion.dataset.similarity_data_loader import PairsSimilarityDataLoader
from quaterion.eval.accumulators import PairAccumulator
from quaterion.eval.pair import PairMetric
from quaterion.eval.samplers import BaseSampler
from quaterion.utils.utils import iter_by_batch


class PairSampler(BaseSampler):
    """Performs selection of embeddings and targets for pair-based tasks.

    The sampler reduces the amount of time and resources needed to calculate a
    distance matrix. Instead of calculating a square matrix of shape
    (num_embeddings, num_embeddings), it samples embeddings and computes a
    rectangular matrix.

    Args:
        sample_size: number of objects to select; -1 (the default) selects all
        distinguish: if True, compare only `obj_a` embeddings to `obj_b`
            embeddings instead of comparing all objects each-to-each.
            Significantly reduces the matrix size.
        encode_batch_size: batch size to use during encoding
        device: device to store accumulated embeddings on
        log_progress: whether to log progress while iterating over the dataset
    """

    def __init__(
        self,
        sample_size: int = -1,
        distinguish: bool = False,
        encode_batch_size: int = 16,
        device: Union[torch.device, str, None] = None,
        log_progress: bool = True,
    ):
        super().__init__(sample_size, device, log_progress)
        self.encode_batch_size = encode_batch_size
        self.distinguish = distinguish
        self.accumulator = PairAccumulator()
    def accumulate(self, model: SimilarityModel, dataset: Sized):
        """Encode objects and accumulate embeddings with the corresponding raw labels.

        Args:
            model: model to encode objects
            dataset: Sized object, like list, tuple, torch.utils.data.Dataset,
                etc., to accumulate
        """
        # Each pair contributes two objects, so iterate over half-size batches
        # to keep the effective encoding batch at `encode_batch_size`.
        for input_batch in iter_by_batch(
            dataset, self.encode_batch_size // 2, self.log_progress
        ):
            batch_labels = PairsSimilarityDataLoader.collate_labels(input_batch)

            objects_a, objects_b = [], []
            for similarity_sample in input_batch:
                objects_a.append(similarity_sample.obj_a)
                objects_b.append(similarity_sample.obj_b)

            # Stack all `obj_a` objects first, then all `obj_b`, and encode together.
            features = objects_a + objects_b
            embeddings = model.encode(
                features, batch_size=self.encode_batch_size, to_numpy=False
            )
            self.accumulator.update(embeddings, **batch_labels, device=self.device)

        self.accumulator.set_filled()
    def reset(self):
        """Reset accumulated state"""
        self.accumulator.reset()
    def sample(
        self, dataset: Sized, metric: PairMetric, model: SimilarityModel
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample embeddings and targets for pair-based tasks.

        Args:
            dataset: Sized object, like list, tuple, torch.utils.data.Dataset,
                etc., to sample
            metric: PairMetric instance to compute the final labels representation
            model: model to encode objects

        Returns:
            torch.Tensor, torch.Tensor: metric labels and the computed distance matrix
        """
        if not self.accumulator.filled:
            self.accumulate(model, dataset)

        embeddings = self.accumulator.embeddings
        pairs = self.accumulator.pairs
        labels = metric.prepare_labels(
            self.accumulator.labels, pairs, self.accumulator.subgroups
        )

        # With `distinguish` the sample is drawn from pairs, otherwise from
        # individual embeddings.
        embeddings_num = embeddings.shape[0]
        max_sample_size = embeddings_num if not self.distinguish else pairs.shape[0]

        if self.sample_size > 0:
            sample_size = min(self.sample_size, max_sample_size)
        else:
            sample_size = max_sample_size

        sample_indices = torch.LongTensor(
            random.sample(range(max_sample_size), k=sample_size)
        )
        labels = labels[sample_indices]

        if self.distinguish:
            # Compare sampled `obj_a` embeddings only against `obj_b` embeddings.
            ref_embeddings = embeddings[pairs[sample_indices][:, 0]]
            embeddings = embeddings[pairs[:, 1]]
            labels = labels[:, pairs[:, 1]]
            distance_matrix = metric.distance.distance_matrix(
                ref_embeddings, embeddings
            )
        else:
            ref_embeddings = embeddings[sample_indices]
            distance_matrix = metric.distance.distance_matrix(
                ref_embeddings, embeddings
            )
            # Each sampled embedding also appears as a column of the matrix.
            # Push its self-distance above the current maximum so it cannot be
            # retrieved as its own nearest neighbor.
            device = embeddings.device
            self_mask = (
                torch.arange(
                    0, distance_matrix.shape[0], dtype=torch.long, device=device
                )
                .view(-1, 1)
                .to(device)
            )
            self_mask = torch.cat(
                [self_mask, sample_indices.view(-1, 1).to(device)], dim=1
            )
            distance_matrix[self_mask[:, 0], self_mask[:, 1]] = (
                distance_matrix.max() + 1
            )

        return labels.float(), distance_matrix
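
To make the rectangular-matrix optimization from the class docstring concrete, here is a minimal sketch with random embeddings standing in for a real model. The sizes (500 pairs, 128 dimensions, a sample of 100) are made up for illustration, and torch.cdist stands in for metric.distance.distance_matrix.

import torch

num_pairs = 500
embeddings = torch.randn(2 * num_pairs, 128)  # obj_a and obj_b stacked

# Full comparison: every embedding against every other embedding.
square = torch.cdist(embeddings, embeddings)  # shape (1000, 1000)

# With sample_size=100, only sampled rows are kept as references --
# what PairSampler does when `distinguish` is False.
sample_indices = torch.randperm(2 * num_pairs)[:100]
rectangular = torch.cdist(embeddings[sample_indices], embeddings)  # (100, 1000)

# With distinguish=True, references come from obj_a embeddings and columns
# from obj_b embeddings, shrinking both dimensions.
distinguished = torch.cdist(embeddings[:num_pairs][:100], embeddings[num_pairs:])  # (100, 500)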
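
A usage sketch, with the caveat that everything outside this module is an assumption: RetrievalPrecision is taken to be one of the pair metrics shipped in quaterion.eval.pair, `model` and `pairs_dataset` are placeholders for a trained SimilarityModel and a Sized collection of SimilarityPairSample objects, and raw_compute(distance_matrix, labels) is assumed to be the PairMetric method that consumes the sampler's output, as in recent Quaterion versions.

from quaterion.eval.pair import RetrievalPrecision
from quaterion.eval.samplers.pair_sampler import PairSampler

# `model` is a trained SimilarityModel and `pairs_dataset` a Sized collection
# of SimilarityPairSample objects -- both placeholders for this sketch.
metric = RetrievalPrecision(k=1)
sampler = PairSampler(sample_size=1000, distinguish=False, encode_batch_size=32)

labels, distance_matrix = sampler.sample(pairs_dataset, metric, model)
score = metric.raw_compute(distance_matrix, labels)

sampler.reset()  # clear the accumulator before evaluating another dataset

In the library itself this orchestration is normally handled by the Evaluator in quaterion.eval, which pairs each metric with a sampler in exactly this way.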