tirank.Loss.prototype_loss

tirank.Loss.prototype_loss(cell_embeddings, bulk_embeddings, bulk_labels, threshold=0.1, margin=1.0)

Calculates a prototype-based contrastive loss.

This loss first computes one prototype embedding per class, the mean of the bulk embeddings carrying that label. It then assigns each single-cell or spatial embedding a pseudo-label given by its nearest prototype. Finally, it applies a contrastive loss to the confident cells only, those whose distances to the two prototypes differ by more than threshold: each such cell is pulled toward its assigned (nearest) prototype and pushed away from the other one.
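The three steps can be illustrated with a minimal PyTorch sketch. It is not TiRank's code: the function name prototype_loss_sketch is hypothetical. The sketch also assumes Euclidean distances (via torch.cdist) and treats a cell's confidence as the absolute gap between its two prototype distances. For the contrastive term it uses a classic margin form: the squared distance to the assigned prototype plus a squared hinge on the distance to the other prototype. The library's actual distance and loss form may differ.

```python
import torch


def prototype_loss_sketch(cell_embeddings, bulk_embeddings, bulk_labels,
                          threshold=0.1, margin=1.0):
    """Hypothetical illustration of a prototype-based contrastive loss.

    Assumptions (not confirmed by TiRank): Euclidean distances and a
    squared-distance / squared-hinge contrastive term.
    """
    # Step 1: class prototypes = mean bulk embedding of class 0 and class 1.
    proto0 = bulk_embeddings[bulk_labels == 0].mean(dim=0)
    proto1 = bulk_embeddings[bulk_labels == 1].mean(dim=0)
    prototypes = torch.stack([proto0, proto1])  # shape (2, d)

    # Distance from every cell to each prototype: shape (n_cells, 2).
    dists = torch.cdist(cell_embeddings, prototypes)

    # Step 2: pseudo-label = index of the nearest prototype.
    pseudo_labels = dists.argmin(dim=1)

    # Keep only confident cells: the two distances differ by more than threshold.
    confident = (dists[:, 0] - dists[:, 1]).abs() > threshold
    if not confident.any():
        # No confident cells: zero loss that stays attached to the autograd graph.
        return dists.sum() * 0.0

    d = dists[confident]
    labels = pseudo_labels[confident].unsqueeze(1)
    d_pos = d.gather(1, labels).squeeze(1)      # distance to assigned prototype
    d_neg = d.gather(1, 1 - labels).squeeze(1)  # distance to the other prototype

    # Step 3: pull toward the assigned prototype, push the other beyond margin.
    loss = d_pos.pow(2) + torch.clamp(margin - d_neg, min=0).pow(2)
    return loss.mean()


# Toy example: class-0 prototype at (0, 0), class-1 prototype at (2, 0).
bulk = torch.tensor([[0.0, 0.0], [0.0, 0.0], [2.0, 0.0], [2.0, 0.0]])
labels = torch.tensor([0, 0, 1, 1])
cells = torch.tensor([[0.5, 0.0]])
print(prototype_loss_sketch(cells, bulk, labels).item())  # 0.25
```

In the toy example the cell sits 0.5 from prototype 0 and 1.5 from prototype 1. It is therefore confident and labeled 0, and only the pull term contributes, because 1.5 already exceeds the margin of 1.0.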

Parameters:
  • cell_embeddings (torch.Tensor) – Embeddings of the single-cell (sc) or spatial transcriptomics (st) data.

  • bulk_embeddings (torch.Tensor) – Embeddings for bulk data.

  • bulk_labels (torch.Tensor) – Class labels (0 or 1) for bulk data.

  • threshold (float, optional) – Confidence threshold. Only cells whose distances to the two prototypes differ by more than this value contribute to the loss. Defaults to 0.1.

  • margin (float, optional) – Margin for the contrastive loss. Defaults to 1.0.
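The threshold parameter works as a filter. In the toy example below, four cells have made-up distances to prototypes 0 and 1 (hypothetical values, not TiRank output). The example applies the confidence rule assumed above: a cell is kept when its two distances differ by more than threshold. The helper confident_mask is hypothetical.

```python
import torch

# Hypothetical distances from four cells (rows) to prototypes 0 and 1 (columns).
dists = torch.tensor([[0.20, 1.80],
                      [0.95, 1.00],
                      [1.50, 0.40],
                      [0.70, 0.75]])


def confident_mask(dists, threshold):
    # A cell is confident when its two prototype distances differ by more than threshold.
    return (dists[:, 0] - dists[:, 1]).abs() > threshold


for t in (0.01, 0.1, 1.2):
    print(t, confident_mask(dists, t).tolist())
# 0.01 keeps all four cells, 0.1 keeps cells 0 and 2, 1.2 keeps only cell 0.
```

With the default of 0.1, cells 1 and 3 are dropped because they sit almost equidistant from both prototypes, so their pseudo-labels are unreliable. Raising the threshold trades coverage for cleaner pseudo-labels.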

Returns:

The scalar prototype loss.

Return type:

torch.Tensor