tirank.Loss.regularization_loss

tirank.Loss.regularization_loss(feature_weights)

Computes the L1 regularization loss of a weight matrix, defined as the mean of the absolute values of its entries.
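Written as a formula for a weight matrix W with m rows and n columns, the description above corresponds to the following. The formula assumes the mean is taken over all m·n entries, since the description does not name a reduction axis.

```latex
\mathcal{L}_{\mathrm{L1}}(W) = \frac{1}{mn} \sum_{i=1}^{m} \sum_{j=1}^{n} \lvert W_{ij} \rvert
```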

Parameters:

feature_weights (torch.Tensor) – The learnable weight matrix (e.g., from the encoder).

Returns:

The scalar L1 regularization loss.

Return type:

torch.Tensor
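The PyTorch sketch below reproduces the documented behaviour (mean absolute value, scalar result). It is not the tirank source. The function name l1_regularization_sketch is hypothetical, and averaging over every entry is an assumption drawn from the description above. The example also shows that the result is a 0-dim tensor that supports autograd, so it can be added to a training loss.

```python
import torch


def l1_regularization_sketch(feature_weights: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for tirank.Loss.regularization_loss:
    # take the absolute value of every entry, then average over all entries.
    # The result is a 0-dim (scalar) tensor that stays in the autograd graph.
    return feature_weights.abs().mean()


# A small learnable 2x2 matrix with known entries.
w = torch.tensor([[1.0, -2.0], [3.0, -4.0]], requires_grad=True)

reg = l1_regularization_sketch(w)
print(reg.item())  # (1 + 2 + 3 + 4) / 4 = 2.5

# Backpropagating gives sign(w) / number_of_entries as the gradient.
reg.backward()
print(w.grad)  # tensor([[ 0.2500, -0.2500], [ 0.2500, -0.2500]])
```

In a training loop, this term would typically be scaled by a small coefficient and added to the task loss before calling backward(). Because of the mean, the penalty does not grow with the size of the matrix.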