Shortcuts

Source code for quaterion.eval.samplers.base_sampler

from typing import Sized, Tuple, Union

import torch
from quaterion_models import SimilarityModel
from torch import Tensor

from quaterion.eval.base_metric import BaseMetric


class BaseSampler:
    """Sample part of embeddings and targets to perform metric calculation on a part of the data.

    A sampler reduces the time and resources needed to calculate a distance matrix.
    Instead of calculating a square matrix with shape (num_embeddings, num_embeddings),
    it selects a subset of embeddings and computes a matrix of rectangular shape.

    This base class only stores its configuration; concrete samplers implement
    :meth:`sample` (and :meth:`reset` if they accumulate state).

    Args:
        sample_size: amount of objects to select. Defaults to ``-1``; how a negative
            value is interpreted is up to subclasses (presumably "use all objects" —
            confirm against concrete samplers).
        device: stored as ``self.device``; presumably the device subclasses compute
            on — confirm in concrete samplers. ``None`` by default.
        log_progress: stored as ``self.log_progress``; presumably toggles progress
            logging in subclasses. ``True`` by default.
    """

    def __init__(
        self,
        sample_size: int = -1,
        device: Union[torch.device, str, None] = None,
        log_progress: bool = True,
    ):
        self.log_progress = log_progress
        self.sample_size = sample_size
        self.device = device

    def sample(
        self, dataset: Sized, metric: BaseMetric, model: SimilarityModel
    ) -> Tuple[Tensor, Tensor]:
        """Sample objects and labels to calculate metrics.

        Must be overridden by subclasses.

        Args:
            dataset: Sized object, like list, tuple, torch.utils.data.Dataset, etc. to sample
            metric: metric instance to compute final labels representation
            model: model to encode objects

        Returns:
            Tensor, Tensor: metrics labels and computed distance matrix

        Raises:
            NotImplementedError: always, in the base class.
        """
        # Raise instead of silently returning None: a subclass that forgets to
        # override this would otherwise fail later with an obscure unpacking error.
        raise NotImplementedError(
            f"{type(self).__name__} must implement `sample`"
        )

    def reset(self):
        """Reset accumulated state if any.

        The base sampler keeps no accumulated state, so this is a no-op;
        subclasses that accumulate state should override it.
        """
        pass