tirank.Loss.mmd_loss

tirank.Loss.mmd_loss(embeddings_A, embeddings_B, sigma=1.0)

Calculates the Maximum Mean Discrepancy (MMD) loss.

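MMD measures how far apart the distributions underlying two sets of embeddings are by comparing the average kernel similarity within each set to the average similarity across the two sets. Here the kernel is Gaussian, with bandwidth set by sigma.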
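For reference, the textbook biased estimator of the squared MMD between samples a_1, ..., a_n and b_1, ..., b_m with a Gaussian kernel of bandwidth σ is given below. This page does not say whether mmd_loss uses exactly this kernel normalization (2σ² versus σ² in the denominator), the biased or the unbiased estimator, or whether it returns MMD² or its square root. Treat the formula as an assumption rather than as the function's definition.

```latex
k(x, y) = \exp\!\left(-\frac{\lVert x - y \rVert^2}{2\sigma^2}\right)

\widehat{\mathrm{MMD}}^2(A, B) = \frac{1}{n^2}\sum_{i=1}^{n}\sum_{j=1}^{n} k(a_i, a_j)
  + \frac{1}{m^2}\sum_{i=1}^{m}\sum_{j=1}^{m} k(b_i, b_j)
  - \frac{2}{nm}\sum_{i=1}^{n}\sum_{j=1}^{m} k(a_i, b_j)
```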

Parameters:
  • embeddings_A (torch.Tensor) – Embeddings from the first distribution.

  • embeddings_B (torch.Tensor) – Embeddings from the second distribution.

  • sigma (float, optional) – Bandwidth (sigma) of the Gaussian kernel. Defaults to 1.0.

Returns:

The scalar MMD loss.

Return type:

torch.Tensor
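The following is a minimal sketch of a Gaussian-kernel MMD loss in PyTorch. It is meant to show what a function with this signature typically computes. It is not tirank's source. The helper name gaussian_mmd is hypothetical, and so are the details it assumes: the kernel exp(-d² / (2σ²)), the biased estimator, squared MMD as the return value, and inputs of shape (n, d) and (m, d).

```python
import torch


def gaussian_mmd(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: biased estimate of squared MMD with a Gaussian kernel.

    Not tirank's implementation; the kernel form exp(-d^2 / (2 * sigma^2)) is an assumption.
    x has shape (n, d), y has shape (m, d); returns a 0-dimensional tensor.
    """

    def kernel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Pairwise squared Euclidean distances via broadcasting: shape (len(a), len(b)).
        d2 = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(dim=-1)
        return torch.exp(-d2 / (2.0 * sigma ** 2))

    # Mean within-set similarity of each set minus twice the mean cross-set similarity.
    return kernel(x, x).mean() + kernel(y, y).mean() - 2.0 * kernel(x, y).mean()


# Usage: embeddings drawn from the same distribution vs. a shifted one.
torch.manual_seed(0)
a = torch.randn(200, 2)
b_same = torch.randn(200, 2)
b_shift = torch.randn(200, 2) + 2.0

loss_same = gaussian_mmd(a, b_same)
loss_shift = gaussian_mmd(a, b_shift)
print(f"same: {loss_same.item():.4f}  shifted: {loss_shift.item():.4f}")
```

The shifted set should give a clearly larger loss than the matching set. Sigma matters. If it is much smaller than typical pairwise distances, the cross-set similarities collapse toward zero and the loss stops distinguishing between inputs. If it is very large, only coarse differences register. Broadcasting materializes an (n, m, d) intermediate tensor, which is fine for mini-batches. For larger inputs, torch.cdist is an alternative way to get the pairwise distances.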