tirank.Loss.mmd_loss
- tirank.Loss.mmd_loss(embeddings_A, embeddings_B, sigma=1.0)
Calculates the Maximum Mean Discrepancy (MMD) loss.
MMD measures the distance between the distributions underlying two sets of embeddings, here using a Gaussian kernel whose bandwidth is set by sigma.
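For reference, the standard biased empirical estimate of squared MMD with a Gaussian kernel is shown below, for embeddings A = {a_1, ..., a_n} and B = {b_1, ..., b_m}. This is an assumption for illustration: the exact estimator (biased or unbiased, squared or not) and the kernel normalization (2σ² versus σ² in the denominator) used by tirank are not stated on this page.

```latex
k(x, y) = \exp\!\left(-\frac{\lVert x - y \rVert^{2}}{2\sigma^{2}}\right)

\widehat{\mathrm{MMD}}^{2}(A, B)
  = \frac{1}{n^{2}} \sum_{i=1}^{n} \sum_{j=1}^{n} k(a_i, a_j)
  + \frac{1}{m^{2}} \sum_{i=1}^{m} \sum_{j=1}^{m} k(b_i, b_j)
  - \frac{2}{nm} \sum_{i=1}^{n} \sum_{j=1}^{m} k(a_i, b_j)
```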
- Parameters:
embeddings_A (torch.Tensor) – Embeddings from the first distribution.
embeddings_B (torch.Tensor) – Embeddings from the second distribution.
sigma (float, optional) – The sigma value (bandwidth) for the Gaussian kernel. Defaults to 1.0.
- Returns:
The scalar MMD loss.
- Return type:
torch.Tensor
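A minimal PyTorch sketch of a Gaussian-kernel MMD loss is shown below. It is not tirank's source: the helper names `gaussian_kernel` and `mmd_loss_sketch` are hypothetical, and the sketch assumes 2-D inputs of shape (n, d) and (m, d), the kernel convention exp(-||x - y||² / (2σ²)), and the biased (V-statistic) estimator of squared MMD.

```python
import torch


def gaussian_kernel(x: torch.Tensor, y: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Gaussian (RBF) kernel matrix between the rows of x (n, d) and y (m, d)."""
    # Squared Euclidean distances via ||x||^2 + ||y||^2 - 2 x.y, shape (n, m).
    # Avoiding a sqrt keeps gradients finite when two rows coincide.
    x_sq = (x * x).sum(dim=1, keepdim=True)        # (n, 1)
    y_sq = (y * y).sum(dim=1, keepdim=True).T      # (1, m)
    sq_dists = (x_sq + y_sq - 2.0 * x @ y.T).clamp_min(0.0)
    # Assumed bandwidth convention: exp(-||x - y||^2 / (2 * sigma^2)).
    return torch.exp(-sq_dists / (2.0 * sigma ** 2))


def mmd_loss_sketch(embeddings_A: torch.Tensor,
                    embeddings_B: torch.Tensor,
                    sigma: float = 1.0) -> torch.Tensor:
    """Hypothetical biased estimate of squared MMD (not tirank's implementation)."""
    k_aa = gaussian_kernel(embeddings_A, embeddings_A, sigma).mean()
    k_bb = gaussian_kernel(embeddings_B, embeddings_B, sigma).mean()
    k_ab = gaussian_kernel(embeddings_A, embeddings_B, sigma).mean()
    # Returns a 0-dim tensor that autograd can backpropagate through.
    return k_aa + k_bb - 2.0 * k_ab


if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(200, 2)
    same = torch.randn(200, 2)           # drawn from the same distribution as a
    shifted = torch.randn(200, 2) + 3.0  # mean shifted by 3 in every dimension
    print(mmd_loss_sketch(a, same).item())     # small value
    print(mmd_loss_sketch(a, shifted).item())  # noticeably larger value
```

Because the result is a scalar tensor, it can be added to another training objective and minimized to pull the two embedding distributions together. The choice of sigma matters: a bandwidth far smaller or larger than typical pairwise distances makes the kernel values nearly constant and the loss less informative.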