tirank.Loss
- tirank.Loss.regularization_loss(feature_weights)
Calculates the L1 regularization loss (mean absolute value) for a weight matrix.
- Parameters:
feature_weights (torch.Tensor) – The learnable weight matrix (e.g., from the encoder).
- Returns:
The scalar L1 regularization loss.
- Return type:
torch.Tensor
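A minimal sketch of the computation described above. It assumes the mean is taken over every entry of the matrix. `l1_regularization` is a hypothetical stand-in, not tirank's code.

```python
import torch

def l1_regularization(feature_weights: torch.Tensor) -> torch.Tensor:
    # Mean of |w| over every entry of the weight matrix (assumed reduction).
    return feature_weights.abs().mean()

W = torch.tensor([[1.0, -2.0], [0.5, -0.5]])
penalty = l1_regularization(W)  # (1 + 2 + 0.5 + 0.5) / 4 = 1.0

# Typical use: add a scaled penalty on an encoder layer's weights to the main loss.
encoder = torch.nn.Linear(8, 4)
reg = 1e-3 * l1_regularization(encoder.weight)
```

Taking the mean rather than the sum keeps the penalty's scale independent of the size of the weight matrix.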
- tirank.Loss.cox_loss(pred, t, e, margin=0.1)
Calculates the Cox partial log-likelihood loss for survival analysis.
This implementation uses a pairwise comparison approach with a margin.
- Parameters:
pred (torch.Tensor) – Predicted risk scores, shape [batch_size].
t (torch.Tensor) – Event times, shape [batch_size].
e (torch.Tensor) – Event indicators (1 if event occurred, 0 otherwise), shape [batch_size].
margin (float, optional) – A margin parameter to stabilize the loss. Defaults to 0.1.
- Returns:
The scalar Cox partial log-likelihood loss.
- Return type:
torch.Tensor
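The docstring does not spell out the pairwise, margin-based formulation, so the sketch below shows only the textbook Cox negative partial log-likelihood that the function is named after, with no margin and no tie correction. It is a swapped-in reference formulation, not tirank's implementation, and `cox_nll` is a hypothetical name.

```python
import torch

def cox_nll(pred, t, e):
    # Sort by descending time so the risk set of sample i is samples 0..i
    # (ties are not handled exactly in this sketch).
    order = torch.argsort(t, descending=True)
    pred, e = pred[order], e[order].float()
    # log sum_{j: t_j >= t_i} exp(pred_j), computed stably.
    log_risk = torch.logcumsumexp(pred, dim=0)
    # Only uncensored samples (e == 1) contribute terms.
    ll = (pred - log_risk) * e
    # Average the negative log-likelihood over observed events.
    return -ll.sum() / e.sum().clamp(min=1.0)

pred = torch.tensor([2.0, 0.0])   # sample 0 is predicted to be higher risk
t = torch.tensor([1.0, 2.0])      # ...and it does have the earlier event
e = torch.tensor([1, 1])          # both events observed
good = cox_nll(pred, t, e)
bad = cox_nll(pred.flip(0), t, e)  # risk ordering reversed -> larger loss
```

The loss is lower when higher predicted risk coincides with earlier events, which is the ordering a pairwise formulation also rewards.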
- tirank.Loss.cosine_loss(embeddings, A, weight_connected=1.0, weight_unconnected=0.1)
Computes a weighted cosine similarity loss.
This loss pushes the cosine similarity of connected node pairs (A=1) toward a target of B=1 and that of unconnected pairs (A=0) toward B=-1. Because adjacency matrices are typically sparse, connected and unconnected pairs receive different weights so that the many unconnected pairs do not dominate the loss.
- Parameters:
embeddings (torch.Tensor) – Embeddings, shape [n_cells, embedding_dim].
A (torch.Tensor) – The sparse adjacency matrix, shape [n_cells, n_cells].
weight_connected (float, optional) – Weight for connected pairs (A=1). Defaults to 1.0.
weight_unconnected (float, optional) – Weight for unconnected pairs (A=0). Defaults to 0.1.
- Returns:
The scalar weighted cosine loss.
- Return type:
torch.Tensor
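A minimal sketch of the weighted cosine loss. The target mapping B = 2A - 1 is inferred from the docstring's B=1 / B=-1 description; the weighted squared-error penalty and the name `weighted_cosine_loss` are assumptions, not tirank's code.

```python
import torch
import torch.nn.functional as F

def weighted_cosine_loss(embeddings, A, weight_connected=1.0, weight_unconnected=0.1):
    # Pairwise cosine similarity matrix S, shape [n_cells, n_cells].
    z = F.normalize(embeddings, p=2, dim=1)
    S = z @ z.T
    A = (A.to_dense() if A.is_sparse else A).float()
    # Target B = 2A - 1: +1 for connected pairs, -1 for unconnected pairs.
    B = 2.0 * A - 1.0
    # Per-pair weight: weight_connected where A = 1, weight_unconnected where A = 0.
    W = weight_unconnected + (weight_connected - weight_unconnected) * A
    # Assumed penalty: weighted squared error between S and B.
    return (W * (S - B) ** 2).mean()

emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
A = torch.ones(2, 2)  # both cells connected (self-loops included)
loss = weighted_cosine_loss(emb, A)  # off-diagonal S = 0 vs target 1 -> 2 / 4 = 0.5
```

With the default weights, each unconnected pair counts one tenth as much as a connected pair.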
- tirank.Loss.gaussian_kernel(a, b, sigma=1.0)
Calculates the Gaussian (RBF) kernel similarity between two tensors.
- Parameters:
a (torch.Tensor) – First input tensor (samples x features).
b (torch.Tensor) – Second input tensor (samples x features).
sigma (float, optional) – The sigma value (bandwidth) of the Gaussian kernel. Defaults to 1.0.
- Returns:
The pairwise Gaussian kernel similarity matrix.
- Return type:
torch.Tensor
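A sketch of a pairwise Gaussian (RBF) kernel. It assumes the common convention k(x, y) = exp(-||x - y||^2 / (2 sigma^2)); some implementations drop the factor of 2. `rbf_kernel` is a hypothetical name.

```python
import torch

def rbf_kernel(a, b, sigma=1.0):
    # Squared Euclidean distance between every row of a and every row of b: [n_a, n_b].
    sq_dist = (a[:, None, :] - b[None, :, :]).pow(2).sum(dim=-1)
    # Assumed convention: exp(-||x - y||^2 / (2 * sigma^2)).
    return torch.exp(-sq_dist / (2 * sigma ** 2))

a = torch.zeros(1, 2)
b = torch.tensor([[0.0, 0.0], [1.0, 0.0]])
K = rbf_kernel(a, b)  # shape [1, 2]: [[1.0, exp(-0.5)]]
```

Broadcasting builds an [n_a, n_b, n_features] intermediate; for large inputs, `torch.cdist` computes the distances without it.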
- tirank.Loss.mmd_loss(embeddings_A, embeddings_B, sigma=1.0)
Calculates the Maximum Mean Discrepancy (MMD) loss.
MMD measures the distance between the distributions of two sets of embeddings.
- Parameters:
embeddings_A (torch.Tensor) – Embeddings from the first distribution.
embeddings_B (torch.Tensor) – Embeddings from the second distribution.
sigma (float, optional) – The sigma value (bandwidth) for the Gaussian kernel. Defaults to 1.0.
- Returns:
The scalar MMD loss.
- Return type:
torch.Tensor
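A sketch of an MMD loss with a single Gaussian kernel. It assumes the biased (V-statistic) estimate of squared MMD. Whether tirank uses this estimator, the squared form, or the same kernel convention is not stated above. The function names are hypothetical.

```python
import torch

def rbf_kernel(a, b, sigma=1.0):
    sq_dist = (a[:, None, :] - b[None, :, :]).pow(2).sum(dim=-1)
    return torch.exp(-sq_dist / (2 * sigma ** 2))

def mmd2(x, y, sigma=1.0):
    # Biased estimate of squared MMD:
    # mean k(x, x') + mean k(y, y') - 2 * mean k(x, y)
    return (rbf_kernel(x, x, sigma).mean()
            + rbf_kernel(y, y, sigma).mean()
            - 2 * rbf_kernel(x, y, sigma).mean())

x = torch.zeros(4, 2)
same = mmd2(x, x)         # identical samples -> 0
far = mmd2(x, x + 3.0)    # distant clouds -> 2 - 2 * exp(-9), close to 2
```

The bandwidth sets the length scale at which distributional differences register, so it usually needs to match the scale of the embeddings.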
- tirank.Loss.CrossEntropy_loss(y_pred, y_true)
Calculates the cross-entropy loss for classification.
Wrapper for nn.CrossEntropyLoss.
- Parameters:
y_pred (torch.Tensor) – Predicted logits, shape [N, C].
y_true (torch.Tensor) – True class labels, shape [N].
- Returns:
The scalar cross-entropy loss.
- Return type:
torch.Tensor
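Because the function is documented as a wrapper for nn.CrossEntropyLoss, the sketch below illustrates the input contract of that class. It assumes the default reduction='mean' and no class weights; `cross_entropy` is a hypothetical name.

```python
import torch
import torch.nn as nn

def cross_entropy(y_pred, y_true):
    # y_pred: raw logits [N, C] (no softmax applied); y_true: int64 class indices [N].
    return nn.CrossEntropyLoss()(y_pred, y_true)

logits = torch.tensor([[2.0, 0.0], [0.0, 2.0]])
labels = torch.tensor([0, 1])          # dtype int64
loss = cross_entropy(logits, labels)   # log(1 + exp(-2)) ≈ 0.1269 for each row
```

nn.CrossEntropyLoss applies log-softmax internally, so passing probabilities instead of logits would apply the softmax twice.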
- tirank.Loss.MSE_loss(y_pred, y_true)
Calculates the Mean Squared Error (MSE) loss for regression.
Wrapper for nn.MSELoss.
- Parameters:
y_pred (torch.Tensor) – Predicted values.
y_true (torch.Tensor) – True values.
- Returns:
The scalar MSE loss.
- Return type:
torch.Tensor
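A sketch of the nn.MSELoss wrapper. It assumes the default reduction='mean'; `mse` is a hypothetical name. The example also illustrates a common shape pitfall with regression heads.

```python
import torch
import torch.nn as nn

def mse(y_pred, y_true):
    # Mean of squared differences over all elements (default reduction='mean').
    return nn.MSELoss()(y_pred, y_true)

head_out = torch.tensor([[1.0], [2.0], [3.0]])  # a regression head often outputs [N, 1]
target = torch.tensor([1.0, 2.0, 5.0])          # targets stored as [N]
loss = mse(head_out.squeeze(-1), target)        # (0 + 0 + 4) / 3 ≈ 1.3333
```

Without the squeeze, PyTorch broadcasts the [N, 1] and [N] tensors to [N, N], emits a UserWarning, and returns an incorrect value.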
- tirank.Loss.prototype_loss(cell_embeddings, bulk_embeddings, bulk_labels, threshold=0.1, margin=1.0)
Calculates a prototype-based contrastive loss.
This loss first computes one prototype embedding per class as the mean of the bulk embeddings with that label. It then assigns each single-cell embedding a pseudo-label given by its nearest prototype. Finally, it applies a contrastive loss to confident cells (those whose distances to the two prototypes differ by a large amount), pulling each toward its assigned prototype and pushing it away from the other.
- Parameters:
cell_embeddings (torch.Tensor) – Embeddings for the single-cell (sc) or spatial transcriptomics (st) data.
bulk_embeddings (torch.Tensor) – Embeddings for bulk data.
bulk_labels (torch.Tensor) – Class labels (0 or 1) for bulk data.
threshold (float, optional) – Confidence threshold. Only cells whose distances to the two prototypes differ by more than this value contribute to the loss. Defaults to 0.1.
margin (float, optional) – Margin for the contrastive loss. Defaults to 1.0.
- Returns:
The scalar prototype loss.
- Return type:
torch.Tensor
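The steps described above can be sketched as follows. The docstring names the steps (class-mean prototypes, nearest-prototype pseudo-labels, a confidence filter, a contrastive term) but not the distance or the exact penalty. This sketch therefore assumes Euclidean distance and a hinge-style term of the form d_assigned + max(0, margin - d_other). `prototype_contrastive_loss` is a hypothetical name, not tirank's code.

```python
import torch
import torch.nn.functional as F

def prototype_contrastive_loss(cell_emb, bulk_emb, bulk_labels, threshold=0.1, margin=1.0):
    # Prototypes: mean bulk embedding of each class (labels 0 and 1).
    protos = torch.stack([bulk_emb[bulk_labels == c].mean(dim=0) for c in (0, 1)])
    # Euclidean distance of every cell to both prototypes: [n_cells, 2].
    d = torch.cdist(cell_emb, protos)
    # Pseudo-label = index of the nearest prototype.
    pseudo = d.argmin(dim=1)
    # Confident cells: the two distances differ by more than `threshold`.
    confident = (d[:, 0] - d[:, 1]).abs() > threshold
    if not confident.any():
        return cell_emb.sum() * 0.0  # zero loss that keeps the autograd graph
    d_near = d.gather(1, pseudo[:, None]).squeeze(1)
    d_far = d.gather(1, (1 - pseudo)[:, None]).squeeze(1)
    # Assumed contrastive term: pull to assigned prototype, push the other beyond `margin`.
    loss = d_near + F.relu(margin - d_far)
    return loss[confident].mean()

bulk = torch.tensor([[0.0, 0.0], [0.0, 0.0], [4.0, 0.0], [4.0, 0.0]])
labels = torch.tensor([0, 0, 1, 1])
cells = torch.tensor([[0.0, 0.0], [4.0, 0.0], [2.0, 0.0]])  # third cell is equidistant, so excluded
loss_small_margin = prototype_contrastive_loss(cells, bulk, labels, margin=1.0)  # 0.0
loss_large_margin = prototype_contrastive_loss(cells, bulk, labels, margin=5.0)  # 5 - 4 = 1.0
```

Ambiguous cells that sit between the two prototypes are left out of the loss, so early, noisy pseudo-labels do not pull the embedding in either direction.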
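Functions
- CrossEntropy_loss(y_pred, y_true) – Calculates the cross-entropy loss for classification.
- MSE_loss(y_pred, y_true) – Calculates the Mean Squared Error (MSE) loss for regression.
- cosine_loss(embeddings, A, weight_connected=1.0, weight_unconnected=0.1) – Computes a weighted cosine similarity loss.
- cox_loss(pred, t, e, margin=0.1) – Calculates the Cox partial log-likelihood loss for survival analysis.
- gaussian_kernel(a, b, sigma=1.0) – Calculates the Gaussian (RBF) kernel similarity between two tensors.
- mmd_loss(embeddings_A, embeddings_B, sigma=1.0) – Calculates the Maximum Mean Discrepancy (MMD) loss.
- prototype_loss(cell_embeddings, bulk_embeddings, bulk_labels, threshold=0.1, margin=1.0) – Calculates a prototype-based contrastive loss.
- regularization_loss(feature_weights) – Calculates the L1 regularization loss (mean absolute value) for a weight matrix.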