Shortcuts

Source code for quaterion.train.cache.cache_encoder

from enum import Enum
from typing import Any, Callable, Hashable, List, Tuple, Union

from quaterion_models.encoders import Encoder
from quaterion_models.types import CollateFnType, MetaExtractorFnType, TensorInterchange
from torch import Tensor

# Callable that maps an arbitrary object to a hashable value
# (presumably the key under which that object's embedding is cached - confirm at call sites).
KeyExtractorType = Callable[[Any], Hashable]

# Return type of ``CacheEncoder.cache_collate``: either the list of cache keys
# alone, or the keys paired with the wrapped encoder's collated features.
CacheCollateReturnType = Union[
    List[Hashable], Tuple[List[Hashable], "TensorInterchange"]
]


class CacheMode(str, Enum):
    """Enumeration of cache operating modes.

    Members also subclass ``str``, so each member compares equal to its raw
    string value (``CacheMode.FILL == "fill"``).
    """

    # Presumably the phase in which embeddings are computed and stored
    # (cf. the "cache filling phase" branch of ``CacheEncoder.cache_collate``).
    FILL = "fill"
    # Presumably the phase in which only cache keys are passed around
    # (cf. the "training phase" branch of ``CacheEncoder.cache_collate``).
    TRAIN = "train"
class CacheEncoder(Encoder):
    """Wrapper for encoders to avoid repeated calculations.

    Encoder results can be calculated one time and reused afterwards in
    situations when the encoder's layers are frozen and provide deterministic
    embeddings for input data.

    This class only defines the caching interface: the storage-related methods
    (``forward``, ``fill_cache``, ``is_filled`` and friends) raise
    ``NotImplementedError`` here and have to be provided by subclasses.

    Args:
        encoder: Encoder object to be wrapped.

    Raises:
        ValueError: if the wrapped encoder is trainable.

    :meta private:
    """

    def __init__(self, encoder: Encoder) -> None:
        # Check before initialising the base class so that an invalid wrapper
        # is never partially constructed.
        if encoder.trainable:
            raise ValueError("Trainable encoder can't be cached")
        super().__init__()
        self._encoder = encoder

    @property
    def wrapped_encoder(self) -> Encoder:
        """The original encoder instance wrapped by this cache.

        Returns:
            Encoder: the encoder passed to the constructor
        """
        return self._encoder

    @property
    def trainable(self) -> bool:
        """Defines if encoder is trainable.

        This flag affects caching and checkpoint saving of the encoder.
        A cached encoder always reports ``False``.

        Returns:
            bool: whether encoder trainable or not
        """
        return False

    @property
    def embedding_size(self) -> int:
        """Size of output embedding, delegated to the wrapped encoder.

        Returns:
            int: Size of resulting embedding.
        """
        return self._encoder.embedding_size

    def cache_extract_meta(self, batch: List[Any]) -> List[dict]:
        """Extracts meta information from batch.

        Args:
            batch: batch of data

        Returns:
            List[dict]: list of meta information

        Raises:
            NotImplementedError: always; not implemented in this base class
        """
        raise NotImplementedError()

    def get_meta_extractor(self) -> MetaExtractorFnType:
        """Provides function that extracts meta information from batch.

        Returns:
            MetaExtractorFnType: meta extractor function
        """
        return self.cache_extract_meta

    def cache_collate(
        self, batch: Union[Tuple[List[Hashable], List[Any]], List[Hashable]]
    ) -> "CacheCollateReturnType":
        """Converts raw data batch into suitable model input and keys for caching.

        Args:
            batch: either a ``(keys, features)`` tuple, or the cache keys alone

        Returns:
            In case only cache keys are provided: return keys

            If keys and actual features are provided - return result of
            original collate along with cache keys
        """
        if isinstance(batch, tuple):
            # Cache filling phase: raw features accompany the keys, so collate
            # them with the wrapped encoder's own collate function.
            keys, features = batch
            collated_features = self._encoder.get_collate_fn()(features)
            return keys, collated_features

        # Assume training phase.
        # Only keys are provided here
        return batch

    def get_collate_fn(self) -> "CollateFnType":
        """Provides function that converts raw data batch into suitable model input.

        Returns:
            CollateFnType: cache collate function (:meth:`cache_collate`)
        """
        return self.cache_collate

    def forward(self, batch: "TensorInterchange") -> Tensor:
        """Infer encoder.

        Convert input batch to embeddings

        Args:
            batch: collated batch (currently, it can be only batch of keys)

        Returns:
            Tensor: shape: (batch_size, embedding_size) - embeddings

        Raises:
            NotImplementedError: always; not implemented in this base class
        """
        raise NotImplementedError()

    def save(self, output_path: str) -> None:
        """Persist current state to the provided directory.

        Saving is delegated to the wrapped encoder.

        Args:
            output_path: path to save to
        """
        self._encoder.save(output_path)

    @classmethod
    def load(cls, input_path: str) -> Encoder:
        """CachedEncoder classes wrap already instantiated encoders and don't
        provide loading support.

        Args:
            input_path: path to load from

        Raises:
            ValueError: always; loading is not supported
        """
        raise ValueError("Cached encoder does not support loading")

    def is_filled(self) -> bool:
        """Check if cache already filled.

        Raises:
            NotImplementedError: always; not implemented in this base class
        """
        raise NotImplementedError()

    def fill_cache(
        self, keys: List[Hashable], data: "TensorInterchange", meta: List[Any]
    ) -> None:
        """Apply wrapped encoder to data and store processed data on
        corresponding device.

        Args:
            keys: Hash keys which should be associated with resulting vectors
            data: Collated batch suitable for the wrapped encoder
            meta: List of batch meta information

        Raises:
            NotImplementedError: always; not implemented in this base class
        """
        raise NotImplementedError()

    def finish_fill(self):
        """Notify cache that fill is complete.

        Raises:
            NotImplementedError: always; not implemented in this base class
        """
        raise NotImplementedError()

    def reset_cache(self):
        """Reset all stored data.

        Raises:
            NotImplementedError: always; not implemented in this base class
        """
        raise NotImplementedError()

    def save_cache(self, path):
        """Persists cache state on disk.

        Args:
            path: location to write the cache state to

        Raises:
            NotImplementedError: always; not implemented in this base class
        """
        raise NotImplementedError()

    def load_cache(self, path):
        """Loads cache state from disk.

        Args:
            path: location to read the cache state from

        Raises:
            NotImplementedError: always; not implemented in this base class
        """
        raise NotImplementedError()

Qdrant

Learn more about Qdrant vector search project and ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community