quaterion.dataset.similarity_data_loader module

class GroupSimilarityDataLoader(dataset: Dataset[SimilarityGroupSample], **kwargs)[source]

Bases: SimilarityDataLoader[SimilarityGroupSample]

DataLoader designed to work with data represented as SimilarityGroupSample.

classmethod collate_labels(batch: List[SimilarityGroupSample]) → Dict[str, Tensor][source]

Collate function for labels

Convert labels into tensors suitable for passing directly into loss functions and metric estimators.

Parameters:

batch – List of SimilarityGroupSample

Returns:

Collated labels

  • groups – id of the group for each feature object

Examples

>>> GroupSimilarityDataLoader.collate_labels(
...     [
...         SimilarityGroupSample(obj="orange", group=0),
...         SimilarityGroupSample(obj="lemon", group=0),
...         SimilarityGroupSample(obj="apple", group=1)
...     ]
... )
{'groups': tensor([0, 0, 1])}
classmethod flatten_objects(batch: List[SimilarityGroupSample], hash_ids: List[int]) → Tuple[List[Any], List[int]][source]

Retrieve and enumerate objects from similarity samples.

Each individual object should be used as an input to the encoder. Additionally, associates a hash_id with each feature; if a sample contains more than one feature, new unique ids are generated from the input one.

Parameters:
  • batch – List of similarity samples

  • hash_ids – pseudo-random ids of the similarity samples

Returns:

  • List of input features for encoder collate

  • List of ids, associated with each feature
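
Since each SimilarityGroupSample carries a single object, flattening a group batch is a one-to-one mapping. A minimal usage sketch (the hash_ids values are arbitrary illustrative inputs; derived ids follow the contract above):

samples = [
    SimilarityGroupSample(obj="orange", group=0),
    SimilarityGroupSample(obj="lemon", group=0),
]
features, ids = GroupSimilarityDataLoader.flatten_objects(samples, hash_ids=[101, 102])
# features -> ["orange", "lemon"]  (one feature per sample)
# ids      -> one id per feature, derived from the input hash_ids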

batch_size: int | None
dataset: Dataset[T_co]
drop_last: bool
num_workers: int
pin_memory: bool
pin_memory_device: str
prefetch_factor: int | None
sampler: Sampler | Iterable
timeout: float
class PairsSimilarityDataLoader(dataset: Dataset[SimilarityPairSample], **kwargs)[source]

Bases: SimilarityDataLoader[SimilarityPairSample]

DataLoader designed to work with data represented as SimilarityPairSample.

classmethod collate_labels(batch: List[SimilarityPairSample]) → Dict[str, Tensor][source]

Collate function for labels of SimilarityPairSample

Convert labels into tensors suitable for passing directly into loss functions and metric estimators.

Parameters:

batch – List of SimilarityPairSample

Returns:

Collated labels

  • labels – tensor of scores for each input pair

  • pairs – id offsets of the features associated with the respective labels

  • subgroups – subgroup id for each feature

Examples

>>> labels_batch = PairsSimilarityDataLoader.collate_labels(
...     [
...         SimilarityPairSample(
...             obj_a="1st_pair_1st_obj", obj_b="1st_pair_2nd_obj", score=1.0, subgroup=0
...         ),
...         SimilarityPairSample(
...             obj_a="2nd_pair_1st_obj", obj_b="2nd_pair_2nd_obj", score=0.0, subgroup=1
...         ),
...     ]
... )
>>> labels_batch['labels']
tensor([1., 0.])
>>> labels_batch['pairs']
tensor([[0, 2],
        [1, 3]])
>>> labels_batch['subgroups']
tensor([0., 1., 0., 1.])
classmethod flatten_objects(batch: List[SimilarityPairSample], hash_ids: List[int]) → Tuple[List[Any], List[int]][source]

Retrieve and enumerate objects from similarity samples.

Each individual object should be used as an input to the encoder. Additionally, associates a hash_id with each feature; if a sample contains more than one feature, new unique ids are generated from the input one.

Parameters:
  • batch – List of similarity samples

  • hash_ids – pseudo-random ids of the similarity samples

Returns:

  • List of input features for encoder collate

  • List of ids, associated with each feature
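
For pair samples, each sample contributes two features. The pairs tensor in the collate example above ([[0, 2], [1, 3]]) shows the resulting layout: all obj_a objects first, then all obj_b objects. A hedged usage sketch (the ids shown for obj_b features are placeholders; new unique ids are generated from the input ones):

samples = [
    SimilarityPairSample(obj_a="1st_pair_1st_obj", obj_b="1st_pair_2nd_obj", score=1.0),
    SimilarityPairSample(obj_a="2nd_pair_1st_obj", obj_b="2nd_pair_2nd_obj", score=0.0),
]
features, ids = PairsSimilarityDataLoader.flatten_objects(samples, hash_ids=[201, 202])
# features -> ["1st_pair_1st_obj", "2nd_pair_1st_obj",
#              "1st_pair_2nd_obj", "2nd_pair_2nd_obj"]
# ids      -> [201, 202, <id derived from 201>, <id derived from 202>]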

batch_size: int | None
dataset: Dataset[T_co]
drop_last: bool
num_workers: int
pin_memory: bool
pin_memory_device: str
prefetch_factor: int | None
sampler: Sampler | Iterable
timeout: float
class SimilarityDataLoader(dataset: Dataset, **kwargs)[source]

Bases: DataLoader, Generic[T_co]

Special version of DataLoader which works with similarity samples.

SimilarityDataLoader automatically assigns a dummy collate_fn for debugging purposes; it is overwritten once the dataloader is used for training.

The required collate function should be defined individually for each encoder by overriding get_collate_fn().

Parameters:
  • dataset – Dataset which outputs similarity samples

  • **kwargs – Parameters passed directly into the parent DataLoader's __init__()
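
A hedged construction sketch, assuming a torch Dataset that yields SimilarityGroupSample instances (FruitDataset is a hypothetical example class; import paths follow the module names used in this reference and may be re-exported elsewhere in your version):

from torch.utils.data import Dataset

from quaterion.dataset.similarity_data_loader import GroupSimilarityDataLoader
from quaterion.dataset.similarity_samples import SimilarityGroupSample

class FruitDataset(Dataset):
    # Hypothetical dataset producing similarity samples
    def __init__(self):
        self.items = [("orange", 0), ("lemon", 0), ("apple", 1)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        obj, group = self.items[index]
        return SimilarityGroupSample(obj=obj, group=group)

# Extra kwargs (batch_size, shuffle, ...) are forwarded to DataLoader
loader = GroupSimilarityDataLoader(FruitDataset(), batch_size=2, shuffle=True)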

classmethod collate_labels(batch: List[T_co]) → Dict[str, Tensor][source]

Collate function for labels

Convert labels into tensors suitable for passing directly into loss functions and metric estimators.

Parameters:

batch – List of similarity samples

Returns:

Collated labels

classmethod flatten_objects(batch: List[T_co], hash_ids: List[int]) → Tuple[List[Any], List[int]][source]

Retrieve and enumerate objects from similarity samples.

Each individual object should be used as an input to the encoder. Additionally, associates a hash_id with each feature; if a sample contains more than one feature, new unique ids are generated from the input one.

Parameters:
  • batch – List of similarity samples

  • hash_ids – pseudo-random ids of the similarity samples

Returns:

  • List of input features for encoder collate

  • List of ids, associated with each feature

load_label_cache(path: str)[source]

Load a previously saved label cache from the given path.

classmethod pre_collate_fn(batch: List[T_co])[source]

Function applied to batch before actual collate.

Splits the batch into features (inputs for prediction) and labels (targets). The encoder-specific collate_fn will then be applied to the feature list only. Loss functions consume labels from this function without any additional transformations.

Parameters:

batch – List of similarity samples

Returns:

  • ids of the features

  • features batch

  • labels batch
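
A conceptual sketch of this split for a pair batch (the dataloader calls this internally; batch items may arrive paired with hash ids from the indexing dataset, so treat the direct call below as illustrative only):

ids, features, labels = PairsSimilarityDataLoader.pre_collate_fn(batch)
# ids      -> one id per flattened feature (see flatten_objects above)
# features -> raw objects in flattened order; the encoder-specific
#             collate_fn from get_collate_fn() is applied to these only
# labels   -> dict of tensors from collate_labels, e.g.
#             {'labels': ..., 'pairs': ..., 'subgroups': ...}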

save_label_cache(path: str)[source]

Save the collected label cache to the given path.

set_label_cache_mode(mode: LabelCacheMode)[source]

Manages how label caching works.

set_salt(salt)[source]

Assigns a new salt to the IndexingDataset. Might be useful to distinguish sequential cache keys between train and validation datasets.

Parameters:

salt – salt for index generation

set_skip_read(skip: bool)[source]

Disable reading items in IndexingDataset. If the cache is already filled and sequential keys are used, there is no need to read dataset items a second time.

Parameters:

skip – if True, do not read items, only indexes
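
A hedged sketch of a label-caching workflow built from the methods above (the LabelCacheMode member name and import path are assumptions; check the quaterion.dataset sources for the actual enum):

from quaterion.dataset.label_cache_dataset import LabelCacheMode  # assumed path

loader = GroupSimilarityDataLoader(FruitDataset(), batch_size=2)  # FruitDataset from the sketch above
loader.set_salt("train")                            # separate cache keys per split
loader.set_label_cache_mode(LabelCacheMode.learn)   # assumed member: fill cache while iterating

for _ in loader:  # first pass populates the label cache
    pass

loader.save_label_cache("./labels_train.cache")

# In a later run, reuse the cache and skip re-reading raw items:
loader.load_label_cache("./labels_train.cache")
loader.set_skip_read(True)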

batch_size: int | None
dataset: Dataset[T_co]
drop_last: bool
property full_cache_used
num_workers: int
property original_params: Dict[str, Any]

Initialization params of the original dataset.

pin_memory: bool
pin_memory_device: str
prefetch_factor: int | None
sampler: Sampler | Iterable
timeout: float