Shortcuts

Source code for quaterion.eval.samplers.group_sampler

import random
from typing import Iterable, Sized, Tuple, Union

import torch
from quaterion_models import SimilarityModel
from torch.utils.data import Dataset

from quaterion.dataset.similarity_data_loader import GroupSimilarityDataLoader
from quaterion.eval.accumulators import GroupAccumulator
from quaterion.eval.group import GroupMetric
from quaterion.eval.samplers import BaseSampler
from quaterion.utils.utils import iter_by_batch


[docs]class GroupSampler(BaseSampler): """Perform selection of embeddings and targets for group based tasks.""" def __init__( self, sample_size=-1, encode_batch_size=16, device: Union[torch.device, str, None] = None, log_progress: bool = True, ): super().__init__(sample_size, device, log_progress) self.encode_batch_size = encode_batch_size self.accumulator = GroupAccumulator()
[docs] def accumulate( self, model: SimilarityModel, dataset: Union[Sized, Iterable, Dataset] ): """Encodes objects and accumulates embeddings with the corresponding raw labels Args: model: model to encode objects dataset: Sized object, like list, tuple, torch.utils.data.Dataset, etc. to accumulate """ for input_batch in iter_by_batch( dataset, self.encode_batch_size, self.log_progress ): batch_labels = GroupSimilarityDataLoader.collate_labels(input_batch) features = [similarity_sample.obj for similarity_sample in input_batch] embeddings = model.encode( features, batch_size=self.encode_batch_size, to_numpy=False ) self.accumulator.update(embeddings, **batch_labels, device=self.device) self.accumulator.set_filled()
[docs] def reset(self): """Reset accumulated state""" self.accumulator.reset()
[docs] def sample( self, dataset: Sized, metric: GroupMetric, model: SimilarityModel ) -> Tuple[torch.Tensor, torch.Tensor]: """Sample embeddings and targets for groups based tasks. Args: dataset: Sized object, like list, tuple, torch.utils.data.Dataset, etc. to sample metric: GroupMetric instance to compute final labels representation model: model to encode objects Returns: torch.Tensor, torch.Tensor: metrics labels and computed distance matrix """ if not self.accumulator.filled: self.accumulate(model, dataset) embeddings = self.accumulator.embeddings labels = metric.prepare_labels(self.accumulator.groups) max_sample_size = embeddings.shape[0] if self.sample_size > 0: sample_size = min(self.sample_size, max_sample_size) else: sample_size = max_sample_size sample_indices = torch.LongTensor( random.sample(range(max_sample_size), k=sample_size) ) labels = labels[sample_indices] distance_matrix = metric.distance.distance_matrix( embeddings[sample_indices], embeddings ) device = embeddings.device self_mask = ( torch.arange(0, distance_matrix.shape[0]) .view(-1, 1) .repeat(1, 2) .to(device) ) distance_matrix[self_mask[:, 0], self_mask[:, 1]] = distance_matrix.max() + 1 return labels.float(), distance_matrix

Qdrant

Learn more about Qdrant vector search project and ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community