tirank.Loss.cosine_loss

tirank.Loss.cosine_loss(embeddings, A, weight_connected=1.0, weight_unconnected=0.1)

Computes a weighted cosine similarity loss.

This loss pulls the embeddings of connected node pairs (A=1) toward high cosine similarity (target B=1) and pushes unconnected pairs (A=0) toward dissimilarity (target B=-1). A sparse adjacency matrix contains far more unconnected pairs than connected ones, so the two groups are weighted separately to keep the unconnected pairs from dominating the loss.
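The idea can be illustrated with a minimal sketch. It assumes that the target is B = 2A - 1 (+1 for connected pairs, -1 for unconnected pairs) and that the loss is a weighted mean squared error between pairwise cosine similarity and B. Both the squared-error form and the function name `weighted_cosine_loss_sketch` are hypothetical. This is an illustration of the technique, not tirank's actual source.

```python
import torch
import torch.nn.functional as F


def weighted_cosine_loss_sketch(
    embeddings: torch.Tensor,
    A: torch.Tensor,
    weight_connected: float = 1.0,
    weight_unconnected: float = 0.1,
) -> torch.Tensor:
    """Hypothetical re-implementation for illustration; not tirank's code."""
    # Densify so elementwise ops also work if A is a torch sparse tensor.
    A = A.to_dense() if A.is_sparse else A
    A = A.float()
    # L2-normalise each row so a matrix product gives pairwise cosine similarity.
    z = F.normalize(embeddings, p=2, dim=1)
    cos_sim = z @ z.T  # shape [n_cells, n_cells], values in [-1, 1]
    # Assumed target B: +1 where A=1 (connected), -1 where A=0 (unconnected).
    B = 2.0 * A - 1.0
    # Per-pair weights: connected pairs get weight_connected, the rest weight_unconnected.
    W = torch.where(
        A > 0,
        torch.full_like(A, weight_connected),
        torch.full_like(A, weight_unconnected),
    )
    # Assumed loss form: weighted mean squared error over all pairs -> scalar tensor.
    return (W * (cos_sim - B) ** 2).mean()
```

Normalising once and taking a single matrix product yields every pairwise similarity at once, at the cost of an [n_cells, n_cells] intermediate, which is manageable for moderate cell counts.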

Parameters:
  • embeddings (torch.Tensor) – Embeddings, shape [n_cells, embedding_dim].

  • A (torch.Tensor) – Sparse binary adjacency matrix (1 = connected, 0 = unconnected), shape [n_cells, n_cells].

  • weight_connected (float, optional) – Weight for connected pairs (A=1). Defaults to 1.0.

  • weight_unconnected (float, optional) – Weight for unconnected pairs (A=0). Defaults to 0.1.
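To see why the two weights differ, consider a hypothetical graph of 1,000 cells in which each cell has 10 neighbours. The counts below are purely illustrative and assume the loss sums over all n_cells × n_cells pairs; they only show the arithmetic of how the default weights rebalance the two groups.

```python
# Hypothetical sparse graph: 1,000 cells, 10 connections per cell (kNN-style).
n_cells = 1_000
k = 10

total_pairs = n_cells * n_cells          # every (i, j) pair
connected = n_cells * k                  # pairs with A=1
unconnected = total_pairs - connected    # pairs with A=0

# Default weights from the signature above.
weight_connected = 1.0
weight_unconnected = 0.1

# Total weight each group contributes to the loss.
connected_mass = weight_connected * connected
unconnected_mass = weight_unconnected * unconnected

# Ratio of unconnected to connected contribution, before and after weighting.
raw_ratio = unconnected / connected
weighted_ratio = unconnected_mass / connected_mass
print(raw_ratio, weighted_ratio)
```

Here the unweighted ratio is 99:1, and the default weights reduce it to about 9.9:1, so unconnected pairs still contribute more in total but no longer swamp the connected ones.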

Returns:

The scalar weighted cosine loss.

Return type:

torch.Tensor