quaterion.utils.utils module
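- get_anchor_negative_mask(labels_a: Tensor, labels_b: Tensor | None = None) → BoolTensor
Creates a 2D mask of valid anchor-negative pairs, i.e. pairs whose labels differ.
- Parameters:
labels_a – Labels associated with embeddings in batch A. Shape: (batch_size_a,)
labels_b – Labels associated with embeddings in batch B. Shape: (batch_size_b,). If None, labels_a is used in its place.
- Returns:
torch.Tensor – Anchor-negative mask. Shape: (batch_size_a, batch_size_b)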
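A minimal sketch of how such a mask can be built with broadcasting, assuming the labels are 1D integer tensors. `anchor_negative_mask_sketch` is a hypothetical name used for illustration; it is not the library's implementation.

```python
from typing import Optional

import torch


def anchor_negative_mask_sketch(
    labels_a: torch.Tensor, labels_b: Optional[torch.Tensor] = None
) -> torch.Tensor:
    # Hypothetical helper: when labels_b is omitted, compare the batch with itself
    if labels_b is None:
        labels_b = labels_a
    # (batch_size_a, 1) != (1, batch_size_b) broadcasts to (batch_size_a, batch_size_b);
    # entry [i, j] is True when labels_a[i] and labels_b[j] differ
    return labels_a.unsqueeze(1) != labels_b.unsqueeze(0)


labels_a = torch.tensor([0, 1, 0])
labels_b = torch.tensor([1, 1])
print(anchor_negative_mask_sketch(labels_a, labels_b))
```

The comparison produces a `torch.bool` tensor directly, so no explicit cast is needed.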
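- get_anchor_positive_mask(labels_a: Tensor, labels_b: Tensor | None = None) → BoolTensor
Creates a 2D mask of valid anchor-positive pairs, i.e. pairs that share a label.
- Parameters:
labels_a – Labels associated with embeddings in batch A. Shape: (batch_size_a,)
labels_b – Labels associated with embeddings in batch B. Shape: (batch_size_b,). If None, labels_a is used in its place.
- Returns:
torch.Tensor – Anchor-positive mask. Shape: (batch_size_a, batch_size_b)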
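A minimal sketch of an anchor-positive mask, assuming 1D integer labels. The helper name is hypothetical, and excluding the diagonal when a batch is compared with itself (so an embedding is never its own positive) is an assumption of this sketch, not documented behavior of the library function.

```python
from typing import Optional

import torch


def anchor_positive_mask_sketch(
    labels_a: torch.Tensor, labels_b: Optional[torch.Tensor] = None
) -> torch.Tensor:
    same_batch = labels_b is None
    if same_batch:
        labels_b = labels_a
    # Entry [i, j] is True when labels_a[i] == labels_b[j]
    mask = labels_a.unsqueeze(1) == labels_b.unsqueeze(0)
    if same_batch:
        # Assumption: drop self-pairs on the diagonal when comparing a batch with itself
        mask &= ~torch.eye(labels_a.shape[0], dtype=torch.bool, device=labels_a.device)
    return mask


print(anchor_positive_mask_sketch(torch.tensor([0, 0, 1])))
```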
- get_masked_maximum(dists: Tensor, mask: Tensor, dim: int = 1) → Tensor
Utility function for semi-hard mining.
- Parameters:
dists – Tiled distance matrix.
mask – Tiled mask.
dim – Dimension to operate on.
- Returns:
torch.Tensor – Masked maximums.
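One common way to compute a masked maximum, used in semi-hard triplet mining implementations, is to shift the values so they are all non-negative, zero out the masked entries, take the maximum, and undo the shift. The sketch below assumes this technique; the helper name is hypothetical and the library's code may differ.

```python
import torch


def masked_maximum_sketch(dists: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    mask = mask.to(dists.dtype)
    # Shift so the smallest value along `dim` becomes 0; all shifted values are then >= 0
    axis_minimums = dists.min(dim, keepdim=True).values
    # Masked-out entries become 0, so they can never exceed a kept (non-negative) entry
    shifted = (dists - axis_minimums) * mask
    # Take the maximum over the kept entries and undo the shift
    return shifted.max(dim, keepdim=True).values + axis_minimums


dists = torch.tensor([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]])
mask = torch.tensor([[1, 0, 1], [0, 1, 1]])
print(masked_maximum_sketch(dists, mask))
```

Note that with this trick a row whose mask is entirely zero yields that row's plain minimum rather than an error.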
- get_masked_minimum(dists, mask, dim=1)
Utility function for semi-hard mining.
- Parameters:
dists – Tiled distance matrix.
mask – Tiled mask.
dim – Dimension to operate on.
- Returns:
torch.Tensor – Masked minimums.
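The masked minimum can be sketched with the mirror-image shift trick: subtract the maximum so every value is non-positive, zero out the masked entries, take the minimum, and add the maximum back. This is an assumed technique with a hypothetical helper name, not necessarily how the library implements it.

```python
import torch


def masked_minimum_sketch(dists: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    mask = mask.to(dists.dtype)
    # Shift so the largest value along `dim` becomes 0; all shifted values are then <= 0
    axis_maximums = dists.max(dim, keepdim=True).values
    # Masked-out entries become 0, so they can never be below a kept (non-positive) entry
    shifted = (dists - axis_maximums) * mask
    # Take the minimum over the kept entries and undo the shift
    return shifted.min(dim, keepdim=True).values + axis_maximums


dists = torch.tensor([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]])
mask = torch.tensor([[1, 0, 1], [0, 1, 1]])
print(masked_minimum_sketch(dists, mask))
```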
- get_triplet_mask(labels: Tensor) → Tensor
Creates a 3D mask of valid triplets for the batch-all strategy.
Given a batch of labels with shape (batch_size,), batch_size^3 triplets (i, j, k) can be formed, which can be represented as a tensor of shape (batch_size, batch_size, batch_size). However, a triplet is valid only if labels[i] == labels[j], labels[i] != labels[k], and i, j and k are distinct indices. This function computes a mask marking which of all possible triplets are valid under these criteria.
- Parameters:
labels (Tensor) – Labels associated with embeddings in the batch. Shape: (batch_size,)
- Returns:
torch.Tensor – Triplet mask. Shape: (batch_size, batch_size, batch_size)
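The validity criteria above translate directly into broadcast comparisons over a (batch_size, batch_size, batch_size) grid. The sketch below is one way to express them; `triplet_mask_sketch` is a hypothetical name and not the library's code.

```python
import torch


def triplet_mask_sketch(labels: torch.Tensor) -> torch.Tensor:
    n = labels.shape[0]
    not_eye = ~torch.eye(n, dtype=torch.bool, device=labels.device)
    # Index distinctness: i != j, i != k, j != k, each broadcast to shape (n, n, n)
    distinct = not_eye.unsqueeze(2) & not_eye.unsqueeze(1) & not_eye.unsqueeze(0)
    # same_label[a, b] is True when labels[a] == labels[b]
    same_label = labels.unsqueeze(0) == labels.unsqueeze(1)
    # labels[i] == labels[j] and labels[i] != labels[k]
    valid_labels = same_label.unsqueeze(2) & ~same_label.unsqueeze(1)
    return distinct & valid_labels


mask = triplet_mask_sketch(torch.tensor([0, 0, 1, 1]))
print(mask.shape, int(mask.sum()))
```

Because labels[i] != labels[k] already forces i != k and j != k, only the i != j check actually removes anything here; the full distinctness test is kept to mirror the stated criteria.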
- info_value_of_dtype(dtype: dtype) → finfo | iinfo
Returns the finfo or iinfo object of a given PyTorch data type.
Does not allow torch.bool.
- Parameters:
dtype – Data type for which to return the info object.
- Returns:
Union[torch.finfo, torch.iinfo] – Info about the given data type.
- Raises:
TypeError – if torch.bool is passed
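A minimal sketch of the documented behavior, built on `torch.finfo`, `torch.iinfo` and the `dtype.is_floating_point` attribute. The helper name is hypothetical, and how the real function treats complex dtypes is not covered here.

```python
import torch


def info_value_of_dtype_sketch(dtype: torch.dtype):
    # torch.bool has neither finfo nor a meaningful iinfo, so reject it explicitly
    if dtype == torch.bool:
        raise TypeError("Does not support torch.bool")
    if dtype.is_floating_point:
        return torch.finfo(dtype)
    return torch.iinfo(dtype)


print(info_value_of_dtype_sketch(torch.float32).max)
print(info_value_of_dtype_sketch(torch.int8).min)
```

A typical use for such a helper is filling masked positions with the smallest representable value of a tensor's dtype, e.g. `dists.masked_fill(~mask, info.min)`.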
- iter_by_batch(sequence: Sized | Iterable | Dataset, batch_size: int, log_progress: bool = True)
Iterates through an indexable or iterable object in batches.
First tries to iterate by indices; if that fails, falls back to the iterable interface.
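A minimal stdlib sketch of the index-first, iterator-fallback strategy described above. The helper names are hypothetical, batches are assumed to be plain lists, and the progress logging controlled by `log_progress` is omitted.

```python
from typing import Any, Iterator, List


def _iter_items(sequence: Any) -> Iterator[Any]:
    try:
        size = len(sequence)
        if size > 0:
            sequence[0]  # probe: does the object support integer indexing?
        indexable = True
    except (TypeError, KeyError, IndexError, NotImplementedError):
        indexable = False
    if indexable:
        for i in range(size):
            yield sequence[i]
    else:
        # Fall back to the iterable interface (sets, generators, iterable datasets, ...)
        yield from sequence


def iter_by_batch_sketch(sequence: Any, batch_size: int) -> Iterator[List[Any]]:
    batch: List[Any] = []
    for item in _iter_items(sequence):
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # emit the final, possibly smaller, batch
        yield batch


print(list(iter_by_batch_sketch(list(range(5)), 2)))
```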
- l2_norm(inputs: Tensor, dim: int = 0) → Tensor
Applies L2 normalization to a tensor.
- Parameters:
inputs – Input tensor.
dim – Dimension to operate on.
- Returns:
torch.Tensor – L2-normalized tensor.
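A minimal sketch of L2 normalization using `torch.nn.functional.normalize`, which divides by the L2 norm along `dim` and clamps that norm to at least `eps` (1e-12 by default) to avoid division by zero. Whether the library's `l2_norm` applies the same clamp is not stated here; the helper name is hypothetical.

```python
import torch
import torch.nn.functional as F


def l2_norm_sketch(inputs: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Divide each slice along `dim` by its L2 norm (norm clamped to at least eps)
    return F.normalize(inputs, p=2.0, dim=dim)


embeddings = torch.tensor([[3.0, 4.0], [0.0, 5.0]])
# dim=1 normalizes each row (one embedding per row); the default dim=0 normalizes columns
print(l2_norm_sketch(embeddings, dim=1))
```

For a batch of embeddings shaped (batch_size, embedding_dim), the default `dim=0` normalizes across the batch rather than per embedding, so `dim=1` is usually what is wanted.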