Source code for quaterion.train.cache_mixin

import os
import warnings
from typing import Dict, Optional, Union

import pytorch_lightning as pl
import torch.cuda
from loguru import logger
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from quaterion_models.encoders import Encoder
from torch.utils.data import DataLoader

from quaterion.dataset.label_cache_dataset import LabelCacheMode
from quaterion.dataset.similarity_data_loader import SimilarityDataLoader
from quaterion.train.cache import (
    CacheConfig,
    CacheEncoder,
    CacheType,
    InMemoryCacheEncoder,
)
from quaterion.train.cache.cache_encoder import CacheMode
from quaterion.train.cache.cache_model import CacheModel
from quaterion.train.cache.cache_train_collator import CacheTrainCollator


class CacheMixin:
    @classmethod
    def _apply_cache_config(
        cls,
        encoders: Union[Encoder, Dict[str, Encoder]],
        cache_config: CacheConfig,
    ) -> Union[Encoder, Dict[str, Encoder]]:
        """Apply the received cache configuration to cached encoders, leaving
        non-cached encoders as they are.

        Args:
            encoders: all model's encoders
            cache_config: CacheConfig instance defined in the `configure_caches`
                method of the model

        Returns:
            Union[Encoder, Dict[str, Encoder]]: encoder or dict of encoders
                wrapped into CacheEncoder instances according to the received
                cache config. The result type depends on how the encoders were
                defined in the model: with or without an explicit mapping.

        Raises:
            KeyError: an encoder's name in the cache config is not among the
                model's encoders
            ValueError: if the CacheConfig instance does not have the required
                options set, e.g. neither `mapping` nor `cache_type` is set
        """
        if cache_config.cache_type == CacheType.NONE:
            return encoders

        if not cache_config.cache_type and not cache_config.mapping:
            raise ValueError(
                "If cache is configured, cache_type or mapping have to be set"
            )

        if isinstance(encoders, Encoder):
            return cls._wrap_encoder(
                encoders,
                cache_config=cache_config,
            )

        cached_encoders = {}
        for encoder_name, encoder in encoders.items():
            cached_encoders[encoder_name] = cls._wrap_encoder(
                encoder, cache_config=cache_config, encoder_name=encoder_name
            )

        return {**encoders, **cached_encoders}

    @staticmethod
    def _check_cuda(cache_type: CacheType, encoder_name: str) -> None:
        if cache_type == CacheType.GPU and not torch.cuda.is_available():
            raise ValueError(
                f"`CacheType.GPU` has been chosen for `{encoder_name}` "
                "encoder, but cuda is not available"
            )

    @classmethod
    def _wrap_encoder(
        cls, encoder: Encoder, cache_config: CacheConfig, encoder_name: str = ""
    ) -> Encoder:
        """Wrap the encoder into a CacheEncoder instance if the config requires it.

        Args:
            encoder: raw model's encoder
            cache_config: cache configuration, including the type of tensor storage
            encoder_name: name of the encoder in the cache config mapping

        Returns:
            wrapped CacheEncoder or the original encoder

        Raises:
            ValueError: if encoder layers are not frozen. Cache can be applied
                only to the outputs of fully frozen encoders.
        """
        if isinstance(encoder, CacheEncoder):
            return encoder

        if encoder.trainable:
            if encoder_name in cache_config.mapping:
                raise ValueError(
                    f"Can't configure cache for encoder {encoder_name}. "
                    "Encoder must be frozen to cache it"
                )
            return encoder

        cache_type = cache_config.mapping.get(encoder_name) or cache_config.cache_type
        if cache_type is None:
            logger.info(
                f"{encoder_name} hasn't been cached, "
                "but could be, as it is a non-trainable encoder"
            )
            return encoder

        cls._check_cuda(cache_type, encoder_name)
        return InMemoryCacheEncoder(encoder, cache_type)

    @classmethod
    def _cache(
        cls,
        trainer: pl.Trainer,
        encoders: Dict[str, Encoder],
        train_dataloader: SimilarityDataLoader,
        val_dataloader: Optional[SimilarityDataLoader],
        cache_config: CacheConfig,
    ) -> bool:
        """Fill the cache for the model's cache encoders.

        Args:
            trainer: Lightning Trainer holding the parameters required to
                launch the model (GPU, etc.)
            encoders: mapping of all model's encoders and their names
            train_dataloader: model's train dataloader
            val_dataloader: model's val dataloader
            cache_config: cache config instance to configure the cache batch
                size and the number of workers to use for caching

        Returns:
            True if the cache was filled with data,
            False if the cache is not used
        """
        cache_encoders = {
            name: encoder
            for name, encoder in encoders.items()
            if isinstance(encoder, CacheEncoder)
        }

        if not cache_encoders:
            return False

        # Check if all encoders are cachable and no custom key extractor is used.
        # If so, we can also cache the whole dataset and avoid reading from it.
        is_full_cache_possible = (
            len(cache_encoders) == len(encoders)
            and not cache_config.key_extractors
            and train_dataloader.num_workers == 0
            and (
                val_dataloader.num_workers == 0
                if val_dataloader is not None
                else True
            )
        )

        if is_full_cache_possible:
            logger.debug("Using full cache")

        if cache_config.key_extractors and not isinstance(
            cache_config.key_extractors, dict
        ):
            # If only one function is specified, use it for all encoders
            key_extractors = {
                name: cache_config.key_extractors for name in cache_encoders.keys()
            }
        else:
            key_extractors = cache_config.key_extractors

        cache_collator = CacheTrainCollator(
            pre_collate_fn=train_dataloader.pre_collate_fn,
            encoder_collates={
                name: encoder.get_collate_fn() for name, encoder in encoders.items()
            },
            meta_extractors={
                name: encoder.get_meta_extractor()
                for name, encoder in encoders.items()
            },
            key_extractors=key_extractors,
            cachable_encoders=list(cache_encoders.keys()),
            mode=CacheMode.TRAIN,
        )

        train_dataloader.collate_fn = cache_collator
        if val_dataloader is not None:
            val_dataloader.collate_fn = cache_collator

        # Set up a different cache key salt for train and val
        train_dataloader.set_salt("train")
        if val_dataloader is not None:
            val_dataloader.set_salt("val")

        is_persisted = cls.check_encoders_persisted(
            cache_config.save_dir, cache_encoders
        )

        if not is_persisted:
            if is_full_cache_possible:
                cls._label_cache_train_mode(train_dataloader, val_dataloader)

            cache_collator.mode = CacheMode.FILL
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore", category=PossibleUserWarning, message="The dataloader, .*"
                )
                cls._fill_cache(
                    trainer=trainer,
                    cache_encoders=cache_encoders,
                    train_dataloader=train_dataloader,
                    val_dataloader=val_dataloader,
                    cache_config=cache_config,
                )
            cache_collator.mode = CacheMode.TRAIN

            logger.debug("Caching has been successfully finished")
            cls.save_cache(
                cache_config.save_dir,
                cache_encoders,
                train_dataloader,
                val_dataloader,
            )
        else:
            cls.load_cache(
                cache_config.save_dir, cache_encoders, train_dataloader, val_dataloader
            )

        if is_full_cache_possible:
            cls._enable_label_cache(train_dataloader, val_dataloader)

        return True

    @classmethod
    def _fill_cache(
        cls,
        trainer: pl.Trainer,
        cache_encoders: Dict[str, CacheEncoder],
        train_dataloader: SimilarityDataLoader,
        val_dataloader: SimilarityDataLoader,
        cache_config: CacheConfig,
    ) -> None:
        """Fill the cache and restore the trainer state for the further
        training process.

        Args:
            trainer: performs one training and validation epoch
            cache_encoders: mapping of encoders whose inputs should be cached
            train_dataloader: model's train dataloader
            val_dataloader: model's val dataloader
            cache_config: cache config to retrieve the number of workers and
                the batch size for caching
        """
        cache_train_dataloader = cls._wrap_cache_dataloader(
            dataloader=train_dataloader, cache_config=cache_config
        )

        cache_val_dataloader = None
        if val_dataloader is not None:
            cache_val_dataloader = cls._wrap_cache_dataloader(
                dataloader=val_dataloader, cache_config=cache_config
            )

        # The actual caching
        trainer.predict(
            CacheModel(
                cache_encoders,
            ),
            [cache_train_dataloader, cache_val_dataloader],
            return_predictions=True,
        )

    @classmethod
    def _unwrap_cache_encoders(cls, encoders: Dict[str, Encoder]) -> Dict[str, Encoder]:
        unwrapped_encoders = {}
        for key, encoder in encoders.items():
            if isinstance(encoder, CacheEncoder):
                unwrapped_encoders[key] = encoder.wrapped_encoder
            else:
                unwrapped_encoders[key] = encoder
        return unwrapped_encoders

    @classmethod
    def _wrap_cache_dataloader(
        cls,
        dataloader: SimilarityDataLoader,
        cache_config: CacheConfig,
    ) -> DataLoader:
        """Create a dataloader for caching.
        Args:
            dataloader: dataloader to be wrapped
            cache_config: cache config to retrieve the number of workers and
                the batch size

        Returns:
            DataLoader: dataloader for caching
        """
        num_workers = (
            cache_config.num_workers
            if cache_config.num_workers is not None
            else dataloader.num_workers
        )

        # We need to avoid random sampling and repeated calculations to make
        # caching as fast as possible. Thus, we recreate the dataloader and
        # set the batch size explicitly.
        params = {
            **dataloader.original_params,
            "num_workers": num_workers,
            "batch_size": cache_config.batch_size,
            "shuffle": False,
            "sampler": None,
        }
        params.pop("collate_fn")  # Explicitly override collate
        cache_dl = DataLoader(
            dataset=dataloader.dataset, collate_fn=dataloader.collate_fn, **params
        )
        return cache_dl

    @classmethod
    def _label_cache_train_mode(
        cls,
        train_dataloader: SimilarityDataLoader,
        val_dataloader: Optional[SimilarityDataLoader],
    ):
        train_dataloader.set_label_cache_mode(LabelCacheMode.learn)
        if val_dataloader:
            val_dataloader.set_label_cache_mode(LabelCacheMode.learn)

    @classmethod
    def _enable_label_cache(
        cls,
        train_dataloader: SimilarityDataLoader,
        val_dataloader: Optional[SimilarityDataLoader],
    ):
        train_dataloader.set_skip_read(True)
        train_dataloader.set_label_cache_mode(LabelCacheMode.read)
        if val_dataloader:
            val_dataloader.set_skip_read(True)
            val_dataloader.set_label_cache_mode(LabelCacheMode.read)

    @classmethod
    def _encoders_cache_path(cls, dir_path: Optional[str]):
        return os.path.join(dir_path, "encoders") if dir_path else None
    @classmethod
    def check_encoders_persisted(
        cls, dir_path: Optional[str], encoders: Dict[str, Encoder]
    ):
        if not dir_path:
            return False

        encoders_path = cls._encoders_cache_path(dir_path)
        for key, encoder in encoders.items():
            if not os.path.exists(os.path.join(encoders_path, key)):
                return False

        return True
    @classmethod
    def save_cache(
        cls,
        dir_path: Optional[str],
        encoders: Dict[str, Encoder],
        train_dataloader: SimilarityDataLoader,
        val_dataloader: Optional[SimilarityDataLoader],
    ):
        if not dir_path:
            return

        encoders_path = cls._encoders_cache_path(dir_path)
        os.makedirs(encoders_path, exist_ok=True)
        for key, encoder in encoders.items():
            if isinstance(encoder, CacheEncoder):
                encoder.save_cache(os.path.join(encoders_path, key))

        train_dataloader.save_label_cache(os.path.join(dir_path, "train_labels"))
        if val_dataloader:
            val_dataloader.save_label_cache(os.path.join(dir_path, "val_labels"))

        logger.debug(f"Cache saved to {dir_path}")
    @classmethod
    def load_cache(
        cls,
        dir_path: str,
        encoders: Dict[str, Encoder],
        train_dataloader: SimilarityDataLoader,
        val_dataloader: Optional[SimilarityDataLoader],
    ):
        if not dir_path:
            return

        encoders_path = cls._encoders_cache_path(dir_path)
        for key, encoder in encoders.items():
            if isinstance(encoder, CacheEncoder):
                encoder_cache_path = os.path.join(encoders_path, key)
                if not os.path.exists(encoder_cache_path):
                    raise RuntimeError(
                        "Encoder cache was configured, but not found. "
                        f"Expected to find cache at {encoder_cache_path}, "
                        "but the file does not exist!"
                    )
                encoder.load_cache(encoder_cache_path)

        train_dataloader.load_label_cache(os.path.join(dir_path, "train_labels"))
        if val_dataloader:
            val_dataloader.load_label_cache(os.path.join(dir_path, "val_labels"))

        logger.debug(f"Cache loaded from: {dir_path}")
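
The methods above are driven by the CacheConfig instance that a model returns from its `configure_caches` hook. Below is a minimal sketch of how a model could opt into caching; the model class `MyModel`, the encoder setup, and the parameter values are illustrative assumptions, not part of this module.

from typing import Optional

from quaterion.train.trainable_model import TrainableModel
from quaterion.train.cache import CacheConfig, CacheType


class MyModel(TrainableModel):
    # configure_encoders, configure_loss, etc. are omitted for brevity;
    # only the caching hook is sketched here.

    def configure_caches(self) -> Optional[CacheConfig]:
        # CacheType.AUTO stores cached embeddings on GPU when CUDA is
        # available and falls back to CPU otherwise. Setting `save_dir`
        # persists the filled cache, so a subsequent run takes the
        # `check_encoders_persisted` / `load_cache` branch of `_cache`
        # instead of re-encoding the whole dataset.
        return CacheConfig(
            cache_type=CacheType.AUTO,
            batch_size=32,
            save_dir="cache_dir",  # illustrative path
        )

Per-encoder control is also possible: passing `mapping={"encoder_name": CacheType.CPU}` restricts caching to specific frozen encoders, while `key_extractors` replaces the default hash-based cache keys; as `_cache` shows, a single callable is applied to all cacheable encoders, and a dict assigns an extractor per encoder.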
