tirank.Loss

tirank.Loss.regularization_loss(feature_weights)

Calculates the L1 regularization loss (mean absolute value) for a weight matrix.

Parameters:

feature_weights (torch.Tensor) – The learnable weight matrix (e.g., from the encoder).

Returns:

The scalar L1 regularization loss.

Return type:

torch.Tensor
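The computation described above can be sketched in a few lines of PyTorch. This is a minimal re-implementation written for illustration, not TiRank's source; the name `l1_regularization_sketch` is hypothetical. The example also shows that the gradient of the mean absolute value is sign(w) divided by the number of entries, which is what pushes weights toward zero.

```python
import torch

def l1_regularization_sketch(feature_weights: torch.Tensor) -> torch.Tensor:
    # Mean of the absolute values of all entries, giving a scalar penalty.
    return feature_weights.abs().mean()

w = torch.tensor([[1.0, -2.0], [3.0, -4.0]], requires_grad=True)
reg = l1_regularization_sketch(w)
reg.backward()
print(reg.item())  # (1 + 2 + 3 + 4) / 4 = 2.5
print(w.grad)      # gradient equals sign(w) / 4
```

In training, a penalty like this is usually added to the task loss with a scaling factor chosen by the user.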

tirank.Loss.cox_loss(pred, t, e, margin=0.1)

Calculates the Cox partial log-likelihood loss for survival analysis.

This implementation uses a pairwise comparison approach with a margin.

Parameters:
  • pred (torch.Tensor) – Predicted risk scores, shape [batch_size].

  • t (torch.Tensor) – Event times, shape [batch_size].

  • e (torch.Tensor) – Event indicators (1 if the event was observed, 0 if the sample is censored), shape [batch_size].

  • margin (float, optional) – A margin parameter to stabilize the loss. Defaults to 0.1.

Returns:

The scalar Cox partial log-likelihood loss.

Return type:

torch.Tensor
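The exact pairwise, margin-based formula used by cox_loss is not spelled out here, so the sketch below shows two related constructions for orientation only: the standard negative Cox partial log-likelihood (Breslow form, with tied times only approximated) and a generic hinge over comparable pairs that illustrates how a margin can enter a pairwise comparison. Neither is TiRank's implementation; both function names are hypothetical, and both assume a higher `pred` means higher risk (an earlier expected event).

```python
import torch

def cox_nll_sketch(pred, t, e):
    """Negative Cox partial log-likelihood (Breslow form, ties only approximated)."""
    # Sort by descending time so the risk set {j : t_j >= t_i} of each sample
    # consists of itself and every sample before it in the sorted order.
    order = torch.argsort(t, descending=True)
    pred, e = pred[order], e[order].float()
    # log sum_{j in risk set} exp(pred_j), accumulated along the sorted order.
    log_risk = torch.logcumsumexp(pred, dim=0)
    # Only samples with an observed event (e == 1) contribute; average over them.
    return -((pred - log_risk) * e).sum() / e.sum().clamp(min=1.0)

def pairwise_margin_sketch(pred, t, e, margin=0.1):
    """Hinge on comparable pairs: an illustration of a margin, not TiRank's formula."""
    # comparable[i, j] is True when sample i had an observed event before t_j.
    comparable = (t.unsqueeze(1) < t.unsqueeze(0)) & (e.unsqueeze(1) == 1)
    # Sample i should receive the higher risk score by at least `margin`.
    diff = pred.unsqueeze(1) - pred.unsqueeze(0)
    hinge = torch.clamp(margin - diff, min=0.0)
    if not comparable.any():
        return pred.sum() * 0.0  # no comparable pairs: zero loss, graph kept
    return hinge[comparable].mean()

pred = torch.tensor([0.0, 0.0])
t = torch.tensor([1.0, 2.0])
e = torch.tensor([1, 1])
print(cox_nll_sketch(pred, t, e).item())  # log(2) / 2 ≈ 0.3466
print(pairwise_margin_sketch(torch.tensor([0.0, 1.0]), t, e).item())  # 0.1 - (0 - 1) ≈ 1.1
```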

tirank.Loss.cosine_loss(embeddings, A, weight_connected=1.0, weight_unconnected=0.1)

Computes a weighted cosine similarity loss.

This loss encourages embeddings of connected nodes (A=1) to have high cosine similarity (target B=1) and pushes embeddings of unconnected nodes (A=0) toward dissimilarity (target B=-1). Connected and unconnected pairs are weighted separately because, in a sparse adjacency matrix, unconnected pairs far outnumber connected ones and would otherwise dominate the loss.

Parameters:
  • embeddings (torch.Tensor) – Embeddings, shape [n_cells, embedding_dim].

  • A (torch.Tensor) – The sparse adjacency matrix, shape [n_cells, n_cells].

  • weight_connected (float, optional) – Weight for connected pairs (A=1). Defaults to 1.0.

  • weight_unconnected (float, optional) – Weight for unconnected pairs (A=0). Defaults to 0.1.

Returns:

The scalar weighted cosine loss.

Return type:

torch.Tensor
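A minimal sketch of a weighted cosine loss, assuming a binary adjacency matrix (dense or torch sparse) and a weighted mean squared error between the cosine similarity matrix S and the target B = 2A - 1. The squared-error form and the normalization by the total weight are assumptions for illustration; TiRank's actual formula may differ, and the name `weighted_cosine_loss_sketch` is hypothetical.

```python
import torch
import torch.nn.functional as F

def weighted_cosine_loss_sketch(embeddings, A, weight_connected=1.0, weight_unconnected=0.1):
    # Accept either a torch sparse tensor or a dense 0/1 matrix.
    A = (A.to_dense() if A.is_sparse else A).float()
    # Row-normalize so that z @ z.T holds pairwise cosine similarities.
    z = F.normalize(embeddings, p=2, dim=1)
    S = z @ z.T
    # Target B: +1 for connected pairs (A=1), -1 for unconnected pairs (A=0).
    B = 2.0 * A - 1.0
    # Per-pair weights: connected pairs count more than the abundant unconnected ones.
    W = torch.where(A > 0, torch.full_like(A, weight_connected),
                    torch.full_like(A, weight_unconnected))
    # Assumed form: weighted mean squared error between S and B.
    return (W * (S - B) ** 2).sum() / W.sum()

emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # two orthogonal embeddings
A = torch.eye(2)                               # each node connected only to itself
print(weighted_cosine_loss_sketch(emb, A).item())  # 0.2 / 2.2 ≈ 0.0909
```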

tirank.Loss.gaussian_kernel(a, b, sigma=1.0)

Calculates the Gaussian (RBF) kernel similarity between two tensors.

Parameters:
  • a (torch.Tensor) – First input tensor (samples x features).

  • b (torch.Tensor) – Second input tensor (samples x features).

  • sigma (float, optional) – The sigma value (bandwidth) of the Gaussian kernel. Defaults to 1.0.

Returns:

The pairwise Gaussian kernel similarity matrix, with one row per sample in a and one column per sample in b.

Return type:

torch.Tensor
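The kernel can be sketched as K[i, j] = exp(-||a_i - b_j||² / (2σ²)). The 2σ² denominator is a common convention and an assumption here (some implementations divide by σ² instead), so exact values may differ from TiRank's; `gaussian_kernel_sketch` is a hypothetical name.

```python
import torch

def gaussian_kernel_sketch(a, b, sigma=1.0):
    # Squared Euclidean distances between every row of a and every row of b: [n_a, n_b].
    sq_dist = torch.cdist(a, b, p=2.0) ** 2
    # RBF similarity; the 2 * sigma**2 denominator is an assumed convention.
    return torch.exp(-sq_dist / (2.0 * sigma ** 2))

a = torch.tensor([[0.0, 0.0]])
b = torch.tensor([[0.0, 0.0], [1.0, 0.0]])
print(gaussian_kernel_sketch(a, b))  # [[1.0, exp(-0.5) ≈ 0.6065]]
```

A smaller sigma makes the kernel more local: similarity decays faster as the distance between samples grows.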

tirank.Loss.mmd_loss(embeddings_A, embeddings_B, sigma=1.0)

Calculates the Maximum Mean Discrepancy (MMD) loss.

MMD measures the distance between the distributions underlying two sets of embeddings by comparing their mean embeddings in the feature space of a Gaussian kernel.

Parameters:
  • embeddings_A (torch.Tensor) – Embeddings from the first distribution.

  • embeddings_B (torch.Tensor) – Embeddings from the second distribution.

  • sigma (float, optional) – The sigma value (bandwidth) for the Gaussian kernel. Defaults to 1.0.

Returns:

The scalar MMD loss.

Return type:

torch.Tensor
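A common (biased) estimator of the squared MMD averages the kernel over within-set and cross-set pairs: MMD² = mean K(A, A) + mean K(B, B) - 2 mean K(A, B). The sketch below uses that estimator with an assumed RBF convention of exp(-d² / (2σ²)); whether TiRank uses this estimator, an unbiased variant, or its square root is not stated here, and both function names are hypothetical.

```python
import torch

def rbf_kernel(a, b, sigma=1.0):
    # Assumed convention: exp(-||a_i - b_j||^2 / (2 * sigma^2)).
    return torch.exp(-torch.cdist(a, b) ** 2 / (2.0 * sigma ** 2))

def mmd_sketch(embeddings_A, embeddings_B, sigma=1.0):
    # Biased estimate of squared MMD from within-set and cross-set kernel means.
    k_aa = rbf_kernel(embeddings_A, embeddings_A, sigma).mean()
    k_bb = rbf_kernel(embeddings_B, embeddings_B, sigma).mean()
    k_ab = rbf_kernel(embeddings_A, embeddings_B, sigma).mean()
    return k_aa + k_bb - 2.0 * k_ab

x = torch.tensor([[0.0, 0.0]])
y = torch.tensor([[1.0, 0.0]])
print(mmd_sketch(x, y).item())  # 1 + 1 - 2 * exp(-0.5) ≈ 0.7869
print(mmd_sketch(x, x).item())  # identical sets give 0.0
```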

tirank.Loss.CrossEntropy_loss(y_pred, y_true)

Calculates the cross-entropy loss for classification.

Wrapper for nn.CrossEntropyLoss.

Parameters:
  • y_pred (torch.Tensor) – Predicted logits, shape [N, C].

  • y_true (torch.Tensor) – True class indices, shape [N].

Returns:

The scalar cross-entropy loss.

Return type:

torch.Tensor
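Because this function wraps nn.CrossEntropyLoss, the usual PyTorch input conventions should apply: raw (unnormalized) logits of shape [N, C] and integer class indices of dtype torch.long and shape [N]. The example calls torch.nn.CrossEntropyLoss directly rather than TiRank, assuming the wrapper keeps the default mean reduction.

```python
import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss()  # default reduction="mean"

# Raw logits for N = 2 samples and C = 3 classes (no softmax applied).
y_pred = torch.zeros(2, 3)
# Class indices must be integers (torch.long), shape [N].
y_true = torch.tensor([0, 2], dtype=torch.long)

loss = ce(y_pred, y_true)
print(loss.item())  # log(3) ≈ 1.0986
```

With all-zero logits every class has probability 1/C, so the loss equals log C.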

tirank.Loss.MSE_loss(y_pred, y_true)

Calculates the Mean Squared Error (MSE) loss for regression.

Wrapper for nn.MSELoss.

Parameters:
  • y_pred (torch.Tensor) – Predicted values.

  • y_true (torch.Tensor) – True values.

Returns:

The scalar MSE loss.

Return type:

torch.Tensor
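A similar sketch for the MSE wrapper, again calling torch.nn.MSELoss directly and assuming the default mean reduction. It also guards against a common pitfall: a prediction of shape [N, 1] compared against a target of shape [N] broadcasts to [N, N], which PyTorch only flags with a UserWarning and which changes the loss value.

```python
import torch
import torch.nn as nn

mse = nn.MSELoss()  # default reduction="mean"

y_true = torch.tensor([1.0, 2.0, 5.0])
# Model outputs often carry a trailing singleton dimension: shape [N, 1].
y_pred = torch.tensor([[1.0], [2.0], [3.0]])

# Squeeze so both tensors have shape [N]; otherwise they broadcast to [N, N].
loss = mse(y_pred.squeeze(-1), y_true)
print(loss.item())  # (0 + 0 + 4) / 3 ≈ 1.3333
```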

tirank.Loss.prototype_loss(cell_embeddings, bulk_embeddings, bulk_labels, threshold=0.1, margin=1.0)

Calculates a prototype-based contrastive loss.

This loss computes prototype embeddings (class means) from the bulk data, then assigns each single-cell or spatial embedding a pseudo-label given by its closest prototype. Finally, it applies a contrastive loss to confident cells (those whose distances to the two prototypes differ by more than threshold), pulling each toward its assigned prototype and pushing it away from the other one.

Parameters:
  • cell_embeddings (torch.Tensor) – Embeddings for single-cell or spatial transcriptomics (sc/st) data.

  • bulk_embeddings (torch.Tensor) – Embeddings for bulk data.

  • bulk_labels (torch.Tensor) – Class labels (0 or 1) for bulk data.

  • threshold (float, optional) – Confidence threshold. Only cells whose distances to the two prototypes differ by more than this value contribute to the loss. Defaults to 0.1.

  • margin (float, optional) – Margin for the contrastive loss. Defaults to 1.0.

Returns:

The scalar prototype loss.

Return type:

torch.Tensor
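The steps above can be sketched as follows. This is an illustrative reading of the description, not TiRank's code: it assumes Euclidean distances, treats the "distance difference" as the absolute gap between a cell's distances to the two prototypes, and uses the classic contrastive form d_pos² + max(0, margin - d_neg)². The name `prototype_loss_sketch` is hypothetical, and both classes must be present in `bulk_labels`.

```python
import torch

def prototype_loss_sketch(cell_embeddings, bulk_embeddings, bulk_labels,
                          threshold=0.1, margin=1.0):
    labels = bulk_labels.long()
    # Prototypes: mean bulk embedding of each class, shape [2, dim].
    protos = torch.stack([bulk_embeddings[labels == k].mean(dim=0) for k in (0, 1)])
    # Euclidean distance from every cell to each prototype: shape [n_cells, 2].
    dist = torch.cdist(cell_embeddings, protos)
    # Pseudo-label: index of the closest prototype.
    pseudo = dist.argmin(dim=1)
    # Confident cells: the two distances differ by more than `threshold`.
    confident = (dist[:, 0] - dist[:, 1]).abs() > threshold
    if not confident.any():
        return cell_embeddings.sum() * 0.0  # nothing confident: zero loss, graph kept
    d, idx = dist[confident], pseudo[confident]
    d_pos = d.gather(1, idx.unsqueeze(1)).squeeze(1)        # to assigned prototype
    d_neg = d.gather(1, (1 - idx).unsqueeze(1)).squeeze(1)  # to the other prototype
    # Pull toward the assigned prototype; push the other one beyond `margin`.
    return (d_pos ** 2 + torch.clamp(margin - d_neg, min=0.0) ** 2).mean()

bulk = torch.tensor([[0.0, 0.0], [0.0, 0.0], [2.0, 0.0], [2.0, 0.0]])
labels = torch.tensor([0, 0, 1, 1])             # prototypes at (0, 0) and (2, 0)
cells = torch.tensor([[0.5, 0.0], [1.0, 0.0]])  # second cell is equidistant
print(prototype_loss_sketch(cells, bulk, labels).item())  # only the first cell counts: 0.5**2 = 0.25
```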

Functions

  • tirank.Loss.CrossEntropy_loss – Calculates the cross-entropy loss for classification.

  • tirank.Loss.MSE_loss – Calculates the Mean Squared Error (MSE) loss for regression.

  • tirank.Loss.cosine_loss – Computes a weighted cosine similarity loss.

  • tirank.Loss.cox_loss – Calculates the Cox partial log-likelihood loss for survival analysis.

  • tirank.Loss.gaussian_kernel – Calculates the Gaussian (RBF) kernel similarity between two tensors.

  • tirank.Loss.mmd_loss – Calculates the Maximum Mean Discrepancy (MMD) loss.

  • tirank.Loss.prototype_loss – Calculates a prototype-based contrastive loss.

  • tirank.Loss.regularization_loss – Calculates the L1 regularization loss (mean absolute value) for a weight matrix.