Source code for quaterion.utils.utils

from typing import Iterable, Optional, Sized, Union

import torch
import torch.nn.functional as F
import tqdm
from torch.utils.data import Dataset


def info_value_of_dtype(dtype: torch.dtype) -> Union[torch.finfo, torch.iinfo]:
    """Returns the `finfo` or `iinfo` object of a given PyTorch data type.

    Does not allow torch.bool.

    Args:
        dtype: dtype for which to return info value

    Returns:
        Union[torch.finfo, torch.iinfo]: info about given data type

    Raises:
        TypeError: if torch.bool is passed
    """
    if dtype == torch.bool:
        raise TypeError("Does not support torch.bool")
    elif dtype.is_floating_point:
        return torch.finfo(dtype)
    else:
        return torch.iinfo(dtype)

def min_value_of_dtype(dtype: torch.dtype) -> Union[int, float]:
    """Returns the minimum value of a given PyTorch data type.

    Does not allow torch.bool.

    Args:
        dtype: dtype for which to return min value

    Returns:
        Union[int, float]: min value of dtype
    """
    return info_value_of_dtype(dtype).min

def max_value_of_dtype(dtype: torch.dtype) -> Union[int, float]:
    """Returns the maximum value of a given PyTorch data type.

    Does not allow torch.bool.

    Args:
        dtype: dtype for which to return max value

    Returns:
        Union[int, float]: max value of dtype
    """
    return info_value_of_dtype(dtype).max

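# Illustrative usage sketch for the dtype helpers above; not part of the
# original module source. The exact values come from torch.finfo / torch.iinfo
# for the given dtype.
fp32_max = max_value_of_dtype(torch.float32)   # finfo(torch.float32).max
int64_min = min_value_of_dtype(torch.int64)    # iinfo(torch.int64).min
# info_value_of_dtype(torch.bool) would raise TypeError
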
def get_triplet_mask(labels: torch.Tensor) -> torch.Tensor:
    """Creates a 3D mask of valid triplets for the batch-all strategy.

    Given a batch of labels with `shape = (batch_size,)`, the number of possible
    triplets that can be formed is `batch_size^3`, i.e. the cube of `batch_size`,
    which can be represented as a tensor with
    `shape = (batch_size, batch_size, batch_size)`. However, a triplet is valid only if
    `labels[i] == labels[j] and labels[i] != labels[k]` and `i`, `j` and `k` are
    distinct indices. This function calculates a mask indicating which of all the
    possible triplets are actually valid based on the criteria above.

    Args:
        labels (torch.Tensor): Labels associated with embeddings in the batch.
            Shape: (batch_size,)

    Returns:
        torch.Tensor: Triplet mask. Shape: (batch_size, batch_size, batch_size)
    """
    # get a mask for distinct indices
    # Shape: (batch_size, batch_size)
    indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device)
    indices_not_equal = torch.logical_not(indices_equal)
    # Shape: (batch_size, batch_size, 1)
    i_not_equal_j = indices_not_equal.unsqueeze(2)
    # Shape: (batch_size, 1, batch_size)
    i_not_equal_k = indices_not_equal.unsqueeze(1)
    # Shape: (1, batch_size, batch_size)
    j_not_equal_k = indices_not_equal.unsqueeze(0)
    # Shape: (batch_size, batch_size, batch_size)
    distinct_indices = torch.logical_and(
        torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k
    )

    # get a mask for:
    # labels[i] == labels[j] and labels[i] != labels[k]
    # Shape: (batch_size, batch_size)
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    # Shape: (batch_size, batch_size, 1)
    i_equal_j = labels_equal.unsqueeze(2)
    # Shape: (batch_size, 1, batch_size)
    i_equal_k = labels_equal.unsqueeze(1)
    # Shape: (batch_size, batch_size, batch_size)
    valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

    # combine masks
    mask = torch.logical_and(distinct_indices, valid_indices)

    return mask

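# Illustrative usage sketch for get_triplet_mask; not part of the original
# module source. The labels are a toy example.
labels = torch.tensor([0, 0, 1, 1])
triplet_mask = get_triplet_mask(labels)    # Shape: (4, 4, 4), dtype: torch.bool
# triplet_mask[0, 1, 2] is True: labels[0] == labels[1] and labels[0] != labels[2]
# triplet_mask[0, 0, 2] is False: anchor and positive indices must be distinct
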
def get_anchor_positive_mask(
    labels_a: torch.Tensor, labels_b: Optional[torch.Tensor] = None
) -> torch.BoolTensor:
    """Creates a 2D mask of valid anchor-positive pairs.

    Args:
        labels_a (torch.Tensor): Labels associated with embeddings in batch A.
            Shape: (batch_size_a,)
        labels_b (torch.Tensor): Labels associated with embeddings in batch B.
            Shape: (batch_size_b,). If `labels_b is None`, `labels_a` is used
            in its place.

    Returns:
        torch.Tensor: Anchor-positive mask. Shape: (batch_size_a, batch_size_b)
    """
    if labels_b is None:
        labels_b = labels_a
    # Shape: (batch_size_a, batch_size_b)
    mask = labels_a.expand(labels_b.shape[0], labels_a.shape[0]).t() == labels_b.expand(
        labels_a.shape[0], labels_b.shape[0]
    )

    if torch.equal(
        labels_a, labels_b
    ):  # handle identical batches of labels for regular loss
        # Shape: (batch_size_a, batch_size_a)
        indices_equal = torch.eye(
            labels_a.size()[0], dtype=torch.bool, device=labels_a.device
        )
        indices_not_equal = torch.logical_not(indices_equal)
        mask = torch.logical_and(indices_not_equal, mask)

    return mask

def get_anchor_negative_mask(
    labels_a: torch.Tensor, labels_b: Optional[torch.Tensor] = None
) -> torch.BoolTensor:
    """Creates a 2D mask of valid anchor-negative pairs.

    Args:
        labels_a (torch.Tensor): Labels associated with embeddings in batch A.
            Shape: (batch_size_a,)
        labels_b (torch.Tensor): Labels associated with embeddings in batch B.
            Shape: (batch_size_b,). If `labels_b is None`, `labels_a` is used
            in its place.

    Returns:
        torch.Tensor: Anchor-negative mask. Shape: (batch_size_a, batch_size_b)
    """
    if labels_b is None:
        labels_b = labels_a
    # Shape: (batch_size_a, batch_size_b)
    mask = labels_a.expand(labels_b.shape[0], labels_a.shape[0]).t() != labels_b.expand(
        labels_a.shape[0], labels_b.shape[0]
    )

    return mask

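# Illustrative usage sketch for the pair masks above; not part of the original
# module source. With a single batch of labels, positives are same-label pairs
# with the diagonal excluded, and negatives are different-label pairs.
labels = torch.tensor([0, 0, 1])
pos_mask = get_anchor_positive_mask(labels)    # Shape: (3, 3)
neg_mask = get_anchor_negative_mask(labels)    # Shape: (3, 3)
# pos_mask[0, 1] is True, pos_mask[0, 0] is False (an anchor is not its own positive)
# neg_mask[0, 2] is True, neg_mask[0, 1] is False
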
def iter_by_batch(
    sequence: Union[Sized, Iterable, Dataset],
    batch_size: int,
    log_progress: bool = True,
):
    """Iterate through an index-able or iterable object in batches.

    Try to iterate by indices first; if that fails, fall back to the iterable
    interface.
    """
    try:
        sequence.__getitem__(0)
        size = len(sequence)
        step = batch_size if batch_size < size else size
        if log_progress:
            iterator = tqdm.tqdm(range(0, size, step), total=(size + step - 1) // step)
        else:
            iterator = range(0, size, step)

        for slice_start_index in iterator:
            slice_end_index = slice_start_index + step
            slice_end_index = slice_end_index if slice_end_index < size else size
            input_batch = [
                sequence[index] for index in range(slice_start_index, slice_end_index)
            ]
            yield input_batch
    except (AttributeError, NotImplementedError, IndexError):
        batch = []
        if log_progress:
            iterator = tqdm.tqdm(sequence)
        else:
            iterator = sequence

        for item in iterator:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                batch = []

        if len(batch) > 0:
            yield batch

        return

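# Illustrative usage sketch for iter_by_batch; not part of the original module
# source. A list takes the indexed path, a generator takes the iterable fallback.
items = list(range(10))
for batch in iter_by_batch(items, batch_size=4, log_progress=False):
    print(batch)            # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]

stream = (i * i for i in range(5))
for batch in iter_by_batch(stream, batch_size=2, log_progress=False):
    print(batch)            # [0, 1], then [4, 9], then [16]
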
def get_masked_maximum(
    dists: torch.Tensor, mask: torch.Tensor, dim: int = 1
) -> torch.Tensor:
    """Utility function for semi-hard mining.

    Args:
        dists: Tiled distance matrix.
        mask: Tiled mask.
        dim: Dimension to operate on.

    Returns:
        torch.Tensor: Masked maximums.
    """
    axis_minimums, _ = dists.min(dim, keepdim=True)
    masked_maximums = (dists - axis_minimums) * mask
    masked_maximums, _ = masked_maximums.max(dim, keepdim=True)
    masked_maximums += axis_minimums
    return masked_maximums

def get_masked_minimum(
    dists: torch.Tensor, mask: torch.Tensor, dim: int = 1
) -> torch.Tensor:
    """Utility function for semi-hard mining.

    Args:
        dists: Tiled distance matrix.
        mask: Tiled mask.
        dim: Dimension to operate on.

    Returns:
        torch.Tensor: Masked minimums.
    """
    axis_maximums, _ = dists.max(dim, keepdim=True)
    masked_minimums = (dists - axis_maximums) * mask
    masked_minimums, _ = masked_minimums.min(dim, keepdim=True)
    masked_minimums += axis_maximums
    return masked_minimums

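# Illustrative usage sketch for the masked reductions above; not part of the
# original module source. The mask selects which entries along `dim` take part
# in the max/min: shifting by the row minimum (or maximum) before multiplying
# by the mask pushes masked-out entries out of contention.
dists = torch.tensor([[1.0, 5.0, 3.0],
                      [2.0, 0.5, 4.0]])
mask = torch.tensor([[1.0, 0.0, 1.0],
                     [0.0, 1.0, 1.0]])
get_masked_maximum(dists, mask)   # tensor([[3.0], [4.0]])
get_masked_minimum(dists, mask)   # tensor([[1.0], [0.5]])
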
def l2_norm(inputs: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Apply L2 normalization to a tensor.

    Args:
        inputs: Input tensor.
        dim: Dimension to operate on.

    Returns:
        torch.Tensor: L2-normalized tensor
    """
    return F.normalize(inputs, p=2, dim=dim)
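
# Illustrative usage sketch for l2_norm; not part of the original module source.
# With dim=1, each row of a batch of embeddings is scaled to unit length.
embeddings = torch.tensor([[3.0, 4.0],
                           [0.0, 2.0]])
l2_norm(embeddings, dim=1)        # tensor([[0.6, 0.8], [0.0, 1.0]])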
