Source code for quaterion.dataset.similarity_data_loader

from typing import Any, Dict, Generic, List, Tuple, Union

import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.utils.data.dataloader import T_co

from quaterion.dataset.indexing_dataset import IndexingDataset, IndexingIterableDataset
from quaterion.dataset.label_cache_dataset import (
    LabelCacheDataset,
    LabelCacheIterableDataset,
    LabelCacheMode,
)
from quaterion.dataset.similarity_samples import (
    SimilarityGroupSample,
    SimilarityPairSample,
)


class SimilarityDataLoader(DataLoader, Generic[T_co]):
    """Special version of :class:`~torch.utils.data.DataLoader` which works with
    similarity samples.

    SimilarityDataLoader will automatically assign a dummy collate_fn for debug
    purposes; it will be overwritten once the dataloader is used for training.

    The required collate function should be defined individually for each encoder
    by overriding :meth:`~quaterion_models.encoders.encoder.Encoder.get_collate_fn`

    Args:
        dataset: Dataset which outputs similarity samples
        **kwargs: Parameters passed directly into :meth:`~torch.utils.data.DataLoader.__init__`
    """

    def __init__(self, dataset: Dataset, **kwargs):
        if "collate_fn" not in kwargs:
            kwargs["collate_fn"] = self.__class__.pre_collate_fn

        self._original_dataset = dataset
        self._original_params = kwargs
        self._indexing_dataset_layer = self._wrap_indexing_dataset(dataset)
        self._label_cache_layer = self._wrap_label_cache_dataset(
            self._indexing_dataset_layer
        )

        super().__init__(self._label_cache_layer, **kwargs)

    @property
    def full_cache_used(self):
        return self._label_cache_layer.mode != LabelCacheMode.transparent
    def set_salt(self, salt):
        """Assigns a new salt to the IndexingDataset.

        Might be useful to distinguish cache sequential keys for train and
        validation datasets.

        Args:
            salt: salt for index generation
        """
        self._indexing_dataset_layer.set_salt(salt)
    def set_skip_read(self, skip: bool):
        """Disable reading items in IndexingDataset.

        If the cache is already filled and a sequential key is used, it is not
        necessary to read dataset items a second time.

        Args:
            skip: if True, do not read items, only indexes
        """
        self._indexing_dataset_layer.set_skip_read(skip)
    def set_label_cache_mode(self, mode: LabelCacheMode):
        """Manages how label caching works."""
        self._label_cache_layer.set_mode(mode)
    def save_label_cache(self, path: str):
        self._label_cache_layer.save(path)
    def load_label_cache(self, path: str):
        self._label_cache_layer.load(path)
    @classmethod
    def _wrap_label_cache_dataset(
        cls, dataset: Union[IndexingIterableDataset, IndexingDataset]
    ) -> Union[LabelCacheDataset, LabelCacheIterableDataset]:
        if isinstance(dataset, IndexingDataset):
            return LabelCacheDataset(dataset)
        if isinstance(dataset, IndexingIterableDataset):
            return LabelCacheIterableDataset(dataset)
        raise NotImplementedError()

    @classmethod
    def _wrap_indexing_dataset(
        cls, dataset: Dataset
    ) -> Union[IndexingIterableDataset, IndexingDataset]:
        if isinstance(dataset, IterableDataset):
            return IndexingIterableDataset(dataset)
        else:
            return IndexingDataset(dataset)

    @property
    def original_params(self) -> Dict[str, Any]:
        """Initialization params of the original dataset."""
        return self._original_params
    @classmethod
    def pre_collate_fn(cls, batch: List[T_co]):
        """Function applied to batch before actual collate.

        Splits the batch into features (arguments of prediction) and labels (targets).
        Encoder-specific `collate_fn` will then be applied to the feature list only.
        Loss functions consume labels from this function without any additional
        transformations.

        Args:
            batch: List of similarity samples

        Returns:
            - ids of the features
            - features batch
            - labels batch
        """
        sample_ids, similarity_samples = list(zip(*batch))
        sample_ids = list(sample_ids)
        labels = cls.collate_labels(similarity_samples)
        features, feature_ids = cls.flatten_objects(
            batch=similarity_samples, hash_ids=sample_ids
        )
        return feature_ids, features, labels
    @classmethod
    def collate_labels(cls, batch: List[T_co]) -> Dict[str, torch.Tensor]:
        """Collate function for labels.

        Converts labels into tensors suitable for passing directly into loss
        functions and metric estimators.

        Args:
            batch: List of similarity samples

        Returns:
            Collated labels
        """
        raise NotImplementedError()
    @classmethod
    def flatten_objects(
        cls, batch: List[T_co], hash_ids: List[int]
    ) -> Tuple[List[Any], List[int]]:
        """Retrieve and enumerate objects from similarity samples.

        Each individual object should be used as input for the encoder.
        Additionally, associates a hash_id with each feature; if there is more
        than one feature in the sample, generates new unique ids based on the
        input one.

        Args:
            batch: List of similarity samples
            hash_ids: pseudo-random ids of the similarity samples

        Returns:
            - List of input features for encoder collate
            - List of ids associated with each feature
        """
        raise NotImplementedError()
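
# --- Illustrative sketch, not part of the module ----------------------------
# A minimal subclass only needs to implement `collate_labels` and
# `flatten_objects`; the inherited `pre_collate_fn` then combines them into a
# (feature_ids, features, labels) triple. `_ToySample` and `_ToyLoader` are
# hypothetical names used only for this demonstration.
from dataclasses import dataclass


@dataclass
class _ToySample:
    obj: Any
    label: float


class _ToyLoader(SimilarityDataLoader[_ToySample]):
    @classmethod
    def collate_labels(cls, batch: List[_ToySample]) -> Dict[str, torch.Tensor]:
        return {"labels": torch.Tensor([sample.label for sample in batch])}

    @classmethod
    def flatten_objects(
        cls, batch: List[_ToySample], hash_ids: List[int]
    ) -> Tuple[List[Any], List[int]]:
        return [sample.obj for sample in batch], hash_ids


# `pre_collate_fn` receives (sample_id, sample) tuples from the indexing layer:
_ids, _features, _labels = _ToyLoader.pre_collate_fn(
    [(11, _ToySample("a", 1.0)), (22, _ToySample("b", 0.0))]
)
assert _ids == [11, 22] and _features == ["a", "b"]
assert torch.equal(_labels["labels"], torch.Tensor([1.0, 0.0]))
# ----------------------------------------------------------------------------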
class PairsSimilarityDataLoader(SimilarityDataLoader[SimilarityPairSample]):
    """DataLoader designed to work with data represented as
    :class:`~quaterion.dataset.similarity_samples.SimilarityPairSample`.
    """

    def __init__(self, dataset: Dataset[SimilarityPairSample], **kwargs):
        super().__init__(dataset, **kwargs)
    @classmethod
    def collate_labels(
        cls, batch: List[SimilarityPairSample]
    ) -> Dict[str, torch.Tensor]:
        """Collate function for labels of :class:`~quaterion.dataset.similarity_samples.SimilarityPairSample`

        Converts labels into tensors suitable for passing directly into loss
        functions and metric estimators.

        Args:
            batch: List of :class:`~quaterion.dataset.similarity_samples.SimilarityPairSample`

        Returns:
            Collated labels:

            - labels - tensor of scores for each input pair
            - pairs - pairs of id offsets of features, associated with respective labels
            - subgroups - subgroup id for each feature

        Examples:

            >>> labels_batch = PairsSimilarityDataLoader.collate_labels(
            ...     [
            ...         SimilarityPairSample(
            ...             obj_a="1st_pair_1st_obj", obj_b="1st_pair_2nd_obj", score=1.0, subgroup=0
            ...         ),
            ...         SimilarityPairSample(
            ...             obj_a="2nd_pair_1st_obj", obj_b="2nd_pair_2nd_obj", score=0.0, subgroup=1
            ...         ),
            ...     ]
            ... )
            >>> labels_batch['labels']
            tensor([1., 0.])
            >>> labels_batch['pairs']
            tensor([[0, 2],
                    [1, 3]])
            >>> labels_batch['subgroups']
            tensor([0., 1., 0., 1.])
        """
        labels = {
            "pairs": torch.LongTensor(
                [[i, i + len(batch)] for i in range(len(batch))]
            ),
            "labels": torch.Tensor([record.score for record in batch]),
            "subgroups": torch.Tensor([record.subgroup for record in batch] * 2),
        }
        return labels
    @classmethod
    def flatten_objects(
        cls, batch: List[SimilarityPairSample], hash_ids: List[int]
    ) -> Tuple[List[Any], List[int]]:
        res_features = []
        res_ids = []

        # Preserve same order as in `collate_labels`
        for hash_id, item in zip(hash_ids, batch):
            res_features.append(item.obj_a)
            res_ids.append(hash_id)

        for hash_id, item in zip(hash_ids, batch):
            res_features.append(item.obj_b)
            res_ids.append(hash_id + 1)  # Hashes are pseudo-random, so it is ok

        return res_features, res_ids
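
# Illustrative check, not part of the module: `flatten_objects` emits all
# `obj_a` first, then all `obj_b`, matching the `[i, i + len(batch)]` offsets
# produced by `collate_labels` above. The hash ids below are arbitrary values
# chosen only for this demonstration.
_pairs = [
    SimilarityPairSample(obj_a="cheese", obj_b="baguette", score=1.0, subgroup=0),
    SimilarityPairSample(obj_a="nachos", obj_b="salsa", score=0.0, subgroup=1),
]
_features, _ids = PairsSimilarityDataLoader.flatten_objects(_pairs, hash_ids=[100, 200])
assert _features == ["cheese", "nachos", "baguette", "salsa"]
assert _ids == [100, 200, 101, 201]  # obj_b ids are shifted by one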
class GroupSimilarityDataLoader(SimilarityDataLoader[SimilarityGroupSample]):
    """DataLoader designed to work with data represented as
    :class:`~quaterion.dataset.similarity_samples.SimilarityGroupSample`.
    """

    def __init__(self, dataset: Dataset[SimilarityGroupSample], **kwargs):
        super().__init__(dataset, **kwargs)
    @classmethod
    def collate_labels(
        cls, batch: List[SimilarityGroupSample]
    ) -> Dict[str, torch.Tensor]:
        """Collate function for labels.

        Converts labels into tensors suitable for passing directly into loss
        functions and metric estimators.

        Args:
            batch: List of :class:`~quaterion.dataset.similarity_samples.SimilarityGroupSample`

        Returns:
            Collated labels:

            - groups -- id of the group for each feature object

        Examples:

            >>> GroupSimilarityDataLoader.collate_labels(
            ...     [
            ...         SimilarityGroupSample(obj="orange", group=0),
            ...         SimilarityGroupSample(obj="lemon", group=0),
            ...         SimilarityGroupSample(obj="apple", group=1)
            ...     ]
            ... )
            {'groups': tensor([0, 0, 1])}
        """
        labels = {"groups": torch.LongTensor([record.group for record in batch])}
        return labels
    @classmethod
    def flatten_objects(
        cls, batch: List[SimilarityGroupSample], hash_ids: List[int]
    ) -> Tuple[List[Any], List[int]]:
        return [sample.obj for sample in batch], hash_ids
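
# End-to-end sketch, not part of the module: wrapping a tiny in-memory
# dataset. Since no collate_fn is passed, the debug `pre_collate_fn` is
# assigned, so each batch is a (feature_ids, features, labels) triple.
# `_FruitDataset` is a hypothetical toy dataset; the feature order below
# assumes the default (transparent) label-cache mode and no shuffling.
class _FruitDataset(Dataset):
    def __init__(self):
        self.samples = [
            SimilarityGroupSample(obj="orange", group=0),
            SimilarityGroupSample(obj="lemon", group=0),
            SimilarityGroupSample(obj="apple", group=1),
        ]

    def __getitem__(self, index: int) -> SimilarityGroupSample:
        return self.samples[index]

    def __len__(self) -> int:
        return len(self.samples)


_loader = GroupSimilarityDataLoader(_FruitDataset(), batch_size=3)
for _feature_ids, _features, _labels in _loader:
    # Feature ids are pseudo-random hashes produced by the indexing layer.
    assert _features == ["orange", "lemon", "apple"]
    assert torch.equal(_labels["groups"], torch.LongTensor([0, 0, 1]))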
