Shortcuts

Source code for quaterion.train.xbm.xbm_buffer

from typing import Tuple

import torch
from torch import LongTensor, Tensor

from quaterion.train.xbm.xbm_config import XbmConfig, XbmDevice


class XbmBuffer:
    """A buffer implementation to hold recent N embeddings and target values.

    Inspired by
    https://github.com/msight-tech/research-xbm/blob/master/ret_benchmark/modeling/xbm.py

    Storage is preallocated in ``__init__`` and written circularly by
    :meth:`queue`: once the write position reaches the end of the buffer, new
    entries overwrite the slots at the beginning.

    Args:
        config: Config class to configure XBM settings.
        embedding_size: Output dimension of `EncoderHead` configured for this model.
    """

    def __init__(self, config: "XbmConfig", embedding_size: int):
        if config.device == XbmDevice.AUTO:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            device = config.device

        self._cfg = config
        self._embeddings = torch.zeros((self._cfg.buffer_size, embedding_size)).to(
            device
        )
        self._targets = torch.zeros(self._cfg.buffer_size, dtype=torch.long).to(device)
        # Index of the next slot `queue` will write to.
        self._pointer = 0
        # Becomes True once the write position has reached the end of the buffer.
        self._is_full = False

    @property
    def is_full(self) -> bool:
        """Whether every slot of the buffer has been written at least once."""
        return self._is_full

    def get(self) -> Tuple[Tensor, LongTensor]:
        """Return copies of the stored embeddings and targets.

        Returns:
            Tuple[Tensor, LongTensor]: Clones of the whole buffer when it is full,
            otherwise clones of only the slots written so far.
        """
        if self.is_full:
            return self._embeddings.clone(), self._targets.clone()

        filled = slice(None, self._pointer)
        return self._embeddings[filled].clone(), self._targets[filled].clone()

    def queue(self, embeddings: Tensor, targets: LongTensor) -> None:
        """Queue batch embeddings and targets in the buffer.

        If the batch holds more entries than the buffer can store, only the last
        ``buffer_size`` entries of the batch are kept. Embeddings are detached
        before being stored so that the buffer never becomes part of an autograd
        graph.

        Args:
            embeddings: Output embeddings in the batch.
            targets: Target values in the batch.
        """
        capacity = self._cfg.buffer_size
        batch_size = len(targets)

        # An in-place copy from a tensor that requires grad would make the buffer
        # itself require grad and keep the producing graph alive.
        embeddings = embeddings.detach()

        if batch_size > capacity:
            # More rows than slots: writing them all would overflow the head slice.
            # `batch_size - capacity` (not `-capacity`) stays correct for capacity 0.
            embeddings = embeddings[batch_size - capacity :]
            targets = targets[batch_size - capacity :]
            batch_size = capacity

        start = self._pointer
        end = start + batch_size

        if end <= capacity:
            # The batch fits without wrapping.
            self._embeddings[start:end] = embeddings
            self._targets[start:end] = targets
        else:
            # Wrap around: the last rows of the batch fill the remaining tail
            # slots, the first `excess` rows overwrite the start of the buffer.
            excess = end - capacity
            self._embeddings[start:] = embeddings[excess:]
            self._targets[start:] = targets[excess:]
            self._embeddings[:excess] = embeddings[:excess]
            self._targets[:excess] = targets[:excess]

        if end >= capacity:
            self._is_full = True
            self._pointer = end - capacity
        else:
            self._pointer = end

Qdrant

Learn more about the Qdrant vector search project and ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community