tirank.Loss.cosine_loss
- tirank.Loss.cosine_loss(embeddings, A, weight_connected=1.0, weight_unconnected=0.1)
Computes a weighted cosine similarity loss.
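This loss pulls the embeddings of connected nodes (A=1) toward a cosine similarity of 1 (target B=1). It pushes the embeddings of unconnected nodes (A=0) toward a cosine similarity of -1 (target B=-1). A sparse adjacency matrix contains far more unconnected pairs than connected ones, so the two kinds of pair are weighted separately to keep the unconnected pairs from dominating the loss.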
- Parameters:
embeddings (torch.Tensor) – Node (cell) embeddings, one row per cell, shape [n_cells, embedding_dim].
A (torch.Tensor) – Binary adjacency matrix (A=1 for connected pairs, A=0 otherwise), typically sparse, shape [n_cells, n_cells].
weight_connected (float, optional) – Weight for connected pairs (A=1). Defaults to 1.0.
weight_unconnected (float, optional) – Weight for unconnected pairs (A=0). Defaults to 0.1.
- Returns:
The scalar weighted cosine loss.
- Return type:
torch.Tensor
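The description above does not give the exact formula. The following is a minimal sketch, not tirank's source. It assumes the loss is a weighted mean squared error between the pairwise cosine-similarity matrix and the target B = 2A - 1, which is +1 where A=1 and -1 where A=0. The function name `weighted_cosine_loss_sketch` and the choice to normalize by the total weight are hypothetical.

```python
import torch
import torch.nn.functional as F


def weighted_cosine_loss_sketch(
    embeddings: torch.Tensor,
    A: torch.Tensor,
    weight_connected: float = 1.0,
    weight_unconnected: float = 0.1,
) -> torch.Tensor:
    """Hypothetical illustration only; NOT tirank's implementation.

    Assumption: loss = weighted mean of (cos_sim - B)^2 with B = 2A - 1.
    """
    # Accept torch sparse tensors by densifying (O(n_cells^2) memory).
    if A.is_sparse:
        A = A.to_dense()
    A = A.to(embeddings.dtype)

    # Pairwise cosine similarity: L2-normalize each row, then take inner products.
    z = F.normalize(embeddings, p=2, dim=1)
    S = z @ z.T  # shape [n_cells, n_cells], values in [-1, 1]

    # Target similarity B: +1 for connected pairs, -1 for unconnected pairs.
    B = 2.0 * A - 1.0

    # Per-pair weights: connected pairs weigh more than the abundant unconnected ones.
    W = A * weight_connected + (1.0 - A) * weight_unconnected

    # Weighted mean of the squared deviation from the target.
    return (W * (S - B) ** 2).sum() / W.sum()


# Usage: a tiny random graph where only cells 0 and 1 are connected.
torch.manual_seed(0)
emb = torch.randn(5, 8, requires_grad=True)
adj = torch.zeros(5, 5)
adj[0, 1] = adj[1, 0] = 1.0
loss = weighted_cosine_loss_sketch(emb, adj)
loss.backward()  # gradients flow back into emb
```

Dividing by the summed weights keeps the loss on a comparable scale across graphs of different size and density. With the default weights, each unconnected pair contributes one tenth as much as a connected pair.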