# Source code for tirank.Loss

# loss function
import torch
import torch.nn as nn

"""
Custom loss functions for training the TiRank model.

This module defines the loss components used in the multitask learning
framework: a pairwise (Cox-style) ranking loss for survival data, a
cosine-similarity loss for spatial regularization, an MMD loss for
distribution alignment, a prototype-based contrastive loss, L1 weight
regularization, and standard classification/regression losses.
"""

def regularization_loss(feature_weights):
    """
    Compute an L1 penalty as the mean absolute value of a weight tensor.

    Args:
        feature_weights (torch.Tensor): Learnable weights to penalize
            (e.g., an encoder's weight matrix).

    Returns:
        torch.Tensor: Scalar equal to ``mean(|feature_weights|)``.
    """
    magnitudes = feature_weights.abs()
    return magnitudes.mean()
def cox_loss(pred, t, e, margin=0.1):
    """
    Pairwise ranking loss for survival analysis (Cox-style concordance surrogate).

    A pair (i, j) is *comparable* when subject i had an observed event
    (``e[i] == 1``) and ``t[i] < t[j]``. In that case subject i should be
    assigned a higher risk than subject j. Each comparable pair contributes
    the logistic penalty ``log(1 + exp(-(pred[i] - pred[j] - margin)))``.
    The loss is the mean over all comparable pairs.

    This is a smooth surrogate of the concordance index, not the Cox partial
    log-likelihood itself.

    Args:
        pred (torch.Tensor): Predicted risk scores, shape [batch_size].
        t (torch.Tensor): Observed times, shape [batch_size].
        e (torch.Tensor): Event indicators (1 if the event occurred, 0 if
            censored), shape [batch_size].
        margin (float, optional): Desired minimum gap between the risk of the
            earlier-event subject and the later subject. Defaults to 0.1.

    Returns:
        torch.Tensor: The scalar pairwise ranking loss. If the batch contains
        no comparable pair, the result is a zero that stays attached to
        ``pred``'s autograd graph, instead of NaN.
    """
    assert len(pred) == len(t) == len(e)

    # pred_diffs[i, j] = pred[i] - pred[j]; time_diffs[i, j] = t[i] - t[j]
    pred_diffs = pred.unsqueeze(1) - pred.unsqueeze(0)
    time_diffs = t.unsqueeze(1) - t.unsqueeze(0)

    # The event indicator must refer to the *row* subject i, so broadcast it
    # along rows with unsqueeze(1). A bare 1-D ``e == 1`` would broadcast
    # along columns (subject j). That would admit pairs where i was censored
    # before j's event, and such pairs are not comparable.
    mask = (e.unsqueeze(1) == 1) & (time_diffs < 0)

    selected = pred_diffs[mask]
    if selected.numel() == 0:
        # torch.mean of an empty tensor is NaN; return a graph-connected zero.
        return pred.sum() * 0.0

    # softplus(x) == log(1 + exp(x)), computed without overflow for large x.
    losses = nn.functional.softplus(margin - selected)
    return losses.mean()
def cosine_loss(embeddings, A, weight_connected=1.0, weight_unconnected=0.1, eps=1e-8):
    """
    Computes a weighted cosine similarity loss.

    The loss encourages embeddings of connected nodes (A=1) to have cosine
    similarity close to 1 and pushes unconnected nodes (A=0) toward -1.
    Connected and unconnected pairs get separate weights, which helps with
    adjacency matrices that are mostly zeros.

    Args:
        embeddings (torch.Tensor): Embeddings, shape [n_cells, embedding_dim].
        A (torch.Tensor): Adjacency matrix with 0/1 entries, shape
            [n_cells, n_cells]. It is used in dense arithmetic
            (``2 * A - 1``).
        weight_connected (float, optional): Weight for connected pairs (A=1).
            Defaults to 1.0.
        weight_unconnected (float, optional): Weight for unconnected pairs
            (A=0). Defaults to 0.1.
        eps (float, optional): Lower bound applied to each embedding's norm
            so that all-zero rows do not cause a division by zero.
            Defaults to 1e-8.

    Returns:
        torch.Tensor: The scalar weighted cosine loss.
    """
    embeddings = embeddings.to(A.device)

    # Pairwise cosine similarity. The norms are clamped so that an all-zero
    # embedding row gives similarity 0 instead of 0/0 = NaN. A NaN here would
    # otherwise poison the whole loss and its gradients.
    norms = torch.norm(embeddings, dim=1, keepdim=True).clamp_min(eps)
    B = torch.mm(embeddings, embeddings.T) / (norms * norms.T)
    # Subtract the identity. Non-zero rows have self-similarity 1, so this
    # drives the diagonal to (approximately) zero.
    B = B - torch.eye(B.shape[0], device=B.device)

    # Map adjacency values {0, 1} onto the cosine target range {-1, 1}.
    A_scaled = 2 * A - 1
    # Per-pair weights balance the few connected pairs against the many
    # unconnected ones.
    weights = A * weight_connected + (1 - A) * weight_unconnected

    return torch.mean(weights * torch.abs(B - A_scaled))
def gaussian_kernel(a, b, sigma=1.0):
    """
    Pairwise Gaussian (RBF) kernel between the rows of two matrices.

    Entry (i, j) of the result is ``exp(-||a[i] - b[j]||^2 / sigma^2)``.

    Args:
        a (torch.Tensor): First input, shape [n_a, n_features].
        b (torch.Tensor): Second input, shape [n_b, n_features].
        sigma (float, optional): Kernel bandwidth. Defaults to 1.0.

    Returns:
        torch.Tensor: Kernel matrix of shape [n_a, n_b].
    """
    n_a, n_b = a.shape[0], b.shape[0]
    n_feat = a.shape[1]
    # [n_a, 1, d] - [1, n_b, d] broadcasts to all pairwise differences.
    diff = a.view(n_a, 1, n_feat) - b.view(1, n_b, n_feat)
    scaled_sq_dist = diff.pow(2).sum(2) / (sigma ** 2)
    return torch.exp(-scaled_sq_dist)
def mmd_loss(embeddings_A, embeddings_B, sigma=1.0):
    """
    Squared Maximum Mean Discrepancy (MMD) between two sets of embeddings.

    Uses the biased estimator
    ``mean(K(A, A)) + mean(K(B, B)) - 2 * mean(K(A, B))``,
    where ``K`` is ``gaussian_kernel``.

    Args:
        embeddings_A (torch.Tensor): Samples from the first distribution.
        embeddings_B (torch.Tensor): Samples from the second distribution.
        sigma (float, optional): Gaussian kernel bandwidth. Defaults to 1.0.

    Returns:
        torch.Tensor: The scalar MMD loss.
    """
    within_a = gaussian_kernel(embeddings_A, embeddings_A, sigma).mean()
    within_b = gaussian_kernel(embeddings_B, embeddings_B, sigma).mean()
    cross = gaussian_kernel(embeddings_A, embeddings_B, sigma).mean()
    return within_a + within_b - 2 * cross
def CrossEntropy_loss(y_pred, y_true):
    """
    Mean cross-entropy loss for classification.

    Args:
        y_pred (torch.Tensor): Unnormalized logits, shape [N, C].
        y_true (torch.Tensor): Integer class labels, shape [N].

    Returns:
        torch.Tensor: The scalar cross-entropy loss (mean reduction).
    """
    return nn.functional.cross_entropy(y_pred, y_true)
def MSE_loss(y_pred, y_true):
    """
    Mean squared error loss for regression.

    Args:
        y_pred (torch.Tensor): Predicted values.
        y_true (torch.Tensor): Target values.

    Returns:
        torch.Tensor: The scalar MSE (mean reduction).
    """
    return nn.functional.mse_loss(y_pred, y_true)
def prototype_loss(cell_embeddings, bulk_embeddings, bulk_labels, threshold=0.1, margin=1.0):
    """
    Calculates a prototype-based contrastive loss.

    Class prototypes are the mean bulk embeddings of label 0 ('Rank+') and
    label 1 ('Rank-'). Each single-cell embedding is pseudo-labelled with its
    nearest prototype. Confident cells are those whose two distances differ
    by more than ``threshold``. For each confident cell, the loss pulls it
    toward its assigned prototype and pushes the other prototype out to at
    least ``margin``.

    Args:
        cell_embeddings (torch.Tensor): Embeddings for sc/st data,
            shape [n_cells, dim].
        bulk_embeddings (torch.Tensor): Embeddings for bulk data,
            shape [n_bulk, dim].
        bulk_labels (torch.Tensor): Class labels (0 or 1) for bulk data.
        threshold (float, optional): Confidence threshold. Only cells with a
            distance difference greater than this are used. Defaults to 0.1.
        margin (float, optional): Margin for the contrastive loss.
            Defaults to 1.0.

    Returns:
        torch.Tensor: The scalar prototype loss. The mean is taken over all
        cells, and non-confident cells contribute 0. If either class is
        absent from ``bulk_labels``, a graph-connected zero is returned
        instead of NaN.
    """
    plus_mask = bulk_labels == 0  # 0 encodes 'Rank+'
    minus_mask = bulk_labels == 1  # 1 encodes 'Rank-'
    if not plus_mask.any() or not minus_mask.any():
        # The mean of an empty selection is NaN, so no valid prototype exists
        # for a missing class. Contribute nothing, but stay graph-connected.
        return cell_embeddings.sum() * 0.0

    rank_plus_proto = bulk_embeddings[plus_mask].mean(dim=0)
    rank_minus_proto = bulk_embeddings[minus_mask].mean(dim=0)

    dist_to_plus = torch.norm(cell_embeddings - rank_plus_proto, dim=1)
    dist_to_minus = torch.norm(cell_embeddings - rank_minus_proto, dim=1)

    # Confidence: how clearly a cell prefers one prototype over the other.
    confidence = torch.abs(dist_to_plus - dist_to_minus)
    mask = (confidence > threshold).float()

    # Each cell's pseudo-label is its *nearest* prototype. The distance to the
    # assigned prototype is therefore the smaller of the two distances, and
    # the distance to the other prototype is the larger one.
    correct_dist = torch.minimum(dist_to_plus, dist_to_minus)
    incorrect_dist = torch.maximum(dist_to_plus, dist_to_minus)

    # Minimize correct_dist; maximize incorrect_dist up to `margin`.
    return torch.mean(mask * (correct_dist + torch.relu(margin - incorrect_dist)))