tirank.Loss.prototype_loss
- tirank.Loss.prototype_loss(cell_embeddings, bulk_embeddings, bulk_labels, threshold=0.1, margin=1.0)
Calculates a prototype-based contrastive loss.
This loss computes a prototype embedding (the class mean) for each class in the bulk data. It then assigns each single-cell embedding a pseudo-label given by its closest prototype. Finally, it applies a contrastive loss to the confident cells, those whose distances to the two prototypes differ by more than threshold, pulling each one closer to the prototype of its pseudo-label and pushing it away from the other prototype.
- Parameters:
cell_embeddings (torch.Tensor) – Embeddings for the single-cell (sc) or spatial transcriptomics (st) data.
bulk_embeddings (torch.Tensor) – Embeddings for bulk data.
bulk_labels (torch.Tensor) – Class labels (0 or 1) for bulk data.
threshold (float, optional) – Confidence threshold. Only cells whose distances to the two prototypes differ by more than this value are used. Defaults to 0.1.
margin (float, optional) – Margin for the contrastive loss. Defaults to 1.0.
- Returns:
The scalar prototype loss.
- Return type:
torch.Tensor
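The steps described above can be sketched in PyTorch as follows. This is a minimal sketch under stated assumptions, not TiRank's actual implementation: it assumes both classes are present in bulk_labels, measures distances with Euclidean distance (torch.cdist), treats the "distance difference" as the absolute difference between a cell's distances to the two prototypes, and uses a hinge-style contrastive term (squared distance to the assigned prototype plus squared relu(margin - distance) to the other one). The function name prototype_loss_sketch is hypothetical.

```python
import torch
import torch.nn.functional as F


def prototype_loss_sketch(cell_embeddings, bulk_embeddings, bulk_labels,
                          threshold=0.1, margin=1.0):
    """Hypothetical sketch of a prototype-based contrastive loss.

    Assumptions (not taken from TiRank): binary labels with both classes
    present, Euclidean distance, absolute distance difference as the
    confidence score, and a hinge-style contrastive term.
    """
    labels = bulk_labels.long().view(-1)

    # Prototypes: mean bulk embedding of each class, stacked to shape (2, d).
    prototypes = torch.stack([
        bulk_embeddings[labels == 0].mean(dim=0),
        bulk_embeddings[labels == 1].mean(dim=0),
    ])

    # Euclidean distance from every cell to both prototypes: (n_cells, 2).
    dists = torch.cdist(cell_embeddings, prototypes)

    # Pseudo-label = index of the nearest prototype.
    pseudo = dists.argmin(dim=1)

    # Confidence = how much closer a cell is to one prototype than the other.
    confident = (dists[:, 0] - dists[:, 1]).abs() > threshold
    if not confident.any():
        # No confident cells: return a zero that stays attached to the graph.
        return cell_embeddings.sum() * 0.0

    d = dists[confident]
    p = pseudo[confident]
    rows = torch.arange(d.shape[0])
    d_pos = d[rows, p]       # distance to the assigned prototype (pull)
    d_neg = d[rows, 1 - p]   # distance to the other prototype (push)

    loss = d_pos.pow(2) + F.relu(margin - d_neg).pow(2)
    return loss.mean()
```

For example, with bulk prototypes at (0, 0) and (4, 0), a cell at (2, 0) is equidistant from both and is ignored, while cells at (0.5, 0) and (3.5, 0) are confident and are pulled toward their nearest prototype. In this sketch the prototypes are not detached, so gradients also reach bulk_embeddings; detaching them would instead treat the prototypes as fixed targets for the cells.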