tirank.Loss.cox_loss

tirank.Loss.cox_loss(pred, t, e, margin=0.1)

Calculates the Cox partial log-likelihood loss for survival analysis.
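For reference, the standard Cox partial log-likelihood (Breslow form, with no ties correction) is shown below. It writes $\eta_i$ for `pred[i]`, $t_i$ for `t[i]` and $e_i$ for `e[i]`. A loss is usually taken as its negative, often divided by the number of events. The docstring does not say exactly how `cox_loss` combines this quantity with `margin`, so treat this as background rather than the function's exact formula.

```latex
\ell(\eta) = \sum_{i:\, e_i = 1} \left( \eta_i - \log \sum_{j:\, t_j \ge t_i} \exp(\eta_j) \right),
\qquad
\mathcal{L}(\eta) = -\frac{1}{\sum_i e_i}\, \ell(\eta)
```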

This implementation uses a pairwise comparison approach with a margin.
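The docstring does not spell out the pairwise formulation. One common reading, assumed here and not confirmed by the source, is a hinge (margin ranking) loss over comparable pairs. Take a subject `i` with an observed event and any subject `j` with a later time `t[j] > t[i]`. Subject `i` should then receive a risk score at least `margin` higher than `j`. The function name `pairwise_margin_loss` is hypothetical, and plain Python lists stand in for the `[batch_size]` tensors.

```python
def pairwise_margin_loss(pred, t, e, margin=0.1):
    """Hypothetical pairwise hinge loss over comparable survival pairs.

    pred, t, e: equal-length sequences of risk scores, times, event flags.
    Returns the mean hinge penalty over comparable pairs (0.0 if none).
    """
    total = 0.0
    n_pairs = 0
    for i in range(len(pred)):
        if e[i] != 1:
            continue  # censored subjects cannot anchor a comparable pair
        for j in range(len(pred)):
            if t[j] > t[i]:
                # i failed before j, so pred[i] should exceed pred[j] by >= margin
                total += max(0.0, margin - (pred[i] - pred[j]))
                n_pairs += 1
    return total / n_pairs if n_pairs else 0.0


# Correctly ordered scores with gaps larger than the margin incur no penalty.
print(pairwise_margin_loss([2.0, 1.0, 0.0], [1.0, 2.0, 3.0], [1, 1, 0]))
```

A larger `margin` demands a wider gap between the scores of earlier and later subjects before a pair stops contributing. This is one way a margin can regularize a ranking-style survival objective.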

Parameters:
  • pred (torch.Tensor) – Predicted risk scores, shape [batch_size].

  • t (torch.Tensor) – Event times, shape [batch_size].

  • e (torch.Tensor) – Event indicators (1 if event occurred, 0 otherwise), shape [batch_size].

  • margin (float, optional) – A margin parameter to stabilize the loss. Defaults to 0.1.

Returns:

The scalar Cox partial log-likelihood loss.

Return type:

torch.Tensor
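The sketch below computes a scalar negative Cox partial log-likelihood (Breslow form, averaged over events). It illustrates the kind of quantity returned. It is not the `tirank` implementation, and it omits the `margin` term because the source does not define how that term enters. The name `cox_neg_partial_log_likelihood` is hypothetical, and plain lists replace tensors. A tensor version would typically sort by descending time and use `torch.logcumsumexp` for the risk-set sums.

```python
import math


def _logsumexp(values):
    # Shift by the maximum for numerical stability before exponentiating.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cox_neg_partial_log_likelihood(pred, t, e):
    """Hypothetical Breslow negative partial log-likelihood, mean over events."""
    total = 0.0
    n_events = 0
    for i in range(len(pred)):
        if e[i] != 1:
            continue  # censored subjects contribute only through risk sets
        # Risk set: every subject still under observation at time t[i].
        risk_scores = [pred[j] for j in range(len(pred)) if t[j] >= t[i]]
        total += pred[i] - _logsumexp(risk_scores)
        n_events += 1
    return -total / n_events if n_events else 0.0


# Earlier events with higher risk scores give a lower loss.
good = cox_neg_partial_log_likelihood([2.0, 1.0, 0.0], [1.0, 2.0, 3.0], [1, 1, 1])
bad = cox_neg_partial_log_likelihood([0.0, 1.0, 2.0], [1.0, 2.0, 3.0], [1, 1, 1])
print(good < bad)
```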