# Source code for quaterion.dataset.similarity_dataset

from typing import Sized

from torch.utils.data import Dataset

from quaterion.dataset.similarity_samples import SimilarityGroupSample

class SimilarityGroupDataset(Dataset[SimilarityGroupSample]):
    """Wrapper, which converts standard dataset of classification task into dataset,
    compatible with :class:`~quaterion.dataset.similarity_data_loader.GroupSimilarityDataLoader`.

    Every item of the wrapped dataset is unpacked as a ``(record, label)`` pair
    and re-emitted as a :class:`SimilarityGroupSample` whose ``obj`` is the
    record and whose ``group`` is the label.

    Args:
        dataset: a dataset, which returns data in format: ``(record, label)``
    """

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def __len__(self) -> int:
        """Return the number of items in the wrapped dataset.

        Raises:
            NotImplementedError: if the wrapped dataset does not define
                ``__len__``, so no length can be reported for the wrapper.
        """
        # torch's base ``Dataset`` does not require ``__len__``, so only
        # delegate when the wrapped object actually provides it.
        if not isinstance(self._dataset, Sized):
            raise NotImplementedError(
                f"{type(self._dataset).__name__} does not define __len__, "
                "so the length of SimilarityGroupDataset is unknown"
            )
        return len(self._dataset)

    def __getitem__(self, index) -> SimilarityGroupSample:
        """Fetch item ``index`` from the wrapped dataset as a group sample.

        Args:
            index: key forwarded unchanged to the wrapped dataset
                (presumably an integer position — depends on the wrapped dataset).

        Returns:
            SimilarityGroupSample with ``obj=record`` and ``group=label``.
        """
        # The wrapped dataset must yield exactly two values: (record, label).
        record, label = self._dataset[index]
        return SimilarityGroupSample(obj=record, group=label)


