Shortcuts

Source code for quaterion.distances.cosine

import torch
import torch.nn.functional as F
from torch import Tensor

from quaterion.distances.base_distance import BaseDistance


[docs]class Cosine(BaseDistance): """Compute cosine similarities (and its interpretation as distances). Note: The output range of this metric is `0 -> 1`. """
[docs] @staticmethod def similarity(x: Tensor, y: Tensor) -> Tensor: return torch.cosine_similarity(x, y)
[docs] @staticmethod def distance(x: Tensor, y: Tensor) -> Tensor: return 1 - Cosine.similarity(x, y)
[docs] @staticmethod def similarity_matrix(x: Tensor, y: Tensor = None) -> Tensor: x_norm = F.normalize(x, p=2, dim=1) if y is None: y_norm = x_norm.transpose(0, 1) else: y_norm = F.normalize(y, p=2, dim=1).transpose(0, 1) return (torch.mm(x_norm, y_norm) + 1) / 2
[docs] @staticmethod def distance_matrix(x: Tensor, y: Tensor = None) -> Tensor: return 1 - Cosine.similarity_matrix(x, y)