tirank.Loss.CrossEntropy_loss

tirank.Loss.CrossEntropy_loss(y_pred, y_true)

Calculates the cross-entropy loss for classification.

Wrapper around torch.nn.CrossEntropyLoss.
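For reference, this is the quantity torch.nn.CrossEntropyLoss computes under its default settings (reduction="mean", no class weights, no label smoothing). It is assumed, not confirmed by this page, that the wrapper uses those defaults. Here z denotes y_pred (N x C logits) and y_i is the true class index of sample i:

```latex
\mathcal{L}(z, y) = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(z_{i, y_i})}{\sum_{c=1}^{C} \exp(z_{i, c})}
```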

Parameters:
  • y_pred (torch.Tensor) – Predicted unnormalized logits of shape (N, C).

  • y_true (torch.Tensor) – True class labels (class indices) of shape (N,).

Returns:

The scalar cross-entropy loss.

Return type:

torch.Tensor
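The sketch below is a dependency-free illustration of what the loss computes. It is not tirank's implementation. The function name cross_entropy_reference is hypothetical, and the sketch assumes the defaults of torch.nn.CrossEntropyLoss (mean reduction, no class weights, no label smoothing). It uses a numerically stable log-sum-exp, so large logits do not overflow.

```python
import math
from typing import Sequence


def cross_entropy_reference(y_pred: Sequence[Sequence[float]], y_true: Sequence[int]) -> float:
    """Hypothetical reference: mean cross-entropy of logits (N x C) vs. class indices (N).

    Assumes the torch.nn.CrossEntropyLoss defaults: reduction="mean",
    no class weights, no label smoothing.
    """
    if len(y_pred) == 0 or len(y_pred) != len(y_true):
        raise ValueError("y_pred and y_true must share the same non-zero batch size N")
    total = 0.0
    for logits, label in zip(y_pred, y_true):
        if not 0 <= label < len(logits):
            raise ValueError(f"label {label} out of range for {len(logits)} classes")
        # Stable log-sum-exp: subtract the largest logit before exponentiating.
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        # -log softmax(logits)[label] == logsumexp(logits) - logits[label]
        total += log_sum_exp - logits[label]
    # Mean over the batch, matching reduction="mean".
    return total / len(y_true)


# Uniform logits over C classes give a loss of log(C).
print(cross_entropy_reference([[0.0, 0.0, 0.0]], [1]))  # ~1.0986 (= log 3)
# Confident, correct predictions give a small loss.
print(cross_entropy_reference([[2.0, 0.0], [0.0, 2.0]], [0, 1]))  # ~0.1269
```

With PyTorch installed, these values should agree (up to floating-point precision) with torch.nn.CrossEntropyLoss()(torch.tensor(y_pred), torch.tensor(y_true)).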