tirank.Model.PathologyPredictor

class tirank.Model.PathologyPredictor(*args: Any, **kwargs: Any)

Bases: torch.nn.Module

Auxiliary prediction head that predicts the pathology class of each spatial location.

Used by the WSI-guided (whole-slide image), spatial location-aware module in ST (spatial transcriptomics) mode.

Parameters:
  • n_features (int) – Dimension of the input embedding produced by the encoder.

  • nhid (int) – Hidden dimension of the predictor MLP.

  • nclass (int) – Number of pathology classes to predict.

  • dropout (float, optional) – Dropout probability. Defaults to 0.5.
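A minimal sketch of a module with this constructor interface may help show how the parameters fit together. It assumes a two-layer MLP (Linear, ReLU, Dropout, Linear) followed by a softmax over classes. The class name PathologyPredictorSketch, the ReLU activation, and the dropout placement are illustrative assumptions, not TiRank's actual implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PathologyPredictorSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented parameters.

    Assumed layout: Linear -> ReLU -> Dropout -> Linear -> softmax.
    The real TiRank layer layout may differ.
    """

    def __init__(self, n_features: int, nhid: int, nclass: int, dropout: float = 0.5):
        super().__init__()
        self.fc1 = nn.Linear(n_features, nhid)  # encoder embedding -> hidden
        self.fc2 = nn.Linear(nhid, nclass)      # hidden -> one score per class
        self.dropout = nn.Dropout(dropout)      # only active in train() mode

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        h = F.relu(self.fc1(embedding))
        h = self.dropout(h)
        logits = self.fc2(h)
        # Normalize over the class dimension so each row sums to 1.
        return F.softmax(logits, dim=-1)


# Usage: 8 spatial locations with 32-dim embeddings, 3 pathology classes.
model = PathologyPredictorSketch(n_features=32, nhid=16, nclass=3)
model.eval()  # disable dropout for deterministic inference
probs = model(torch.randn(8, 32))
print(probs.shape)  # torch.Size([8, 3])
```

Calling model.eval() matters here because nn.Dropout zeroes activations at random in training mode; at inference it becomes an identity.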

forward(embedding)

Forward pass for the pathology predictor.

Parameters:

embedding (torch.Tensor) – Input embedding tensor from the encoder.

Returns:

Predicted pathology class probabilities.

Return type:

torch.Tensor
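Because the return value is documented as class probabilities rather than raw logits, downstream code should treat it accordingly. The sketch below uses a random softmax output as a stand-in for the forward() result, and the labels tensor is hypothetical. It shows how to take a hard class per location with argmax, and how to compute a training loss with nn.NLLLoss on log-probabilities. nn.CrossEntropyLoss would be the wrong choice here because it expects unnormalized logits and applies log-softmax internally.

```python
import torch
import torch.nn as nn

# Stand-in for forward() output: probabilities for 4 locations over 3 classes.
torch.manual_seed(0)
probs = torch.softmax(torch.randn(4, 3), dim=-1)

# Hard assignment: index of the most probable pathology class per location.
pred_class = probs.argmax(dim=-1)

# Loss on already-normalized probabilities: NLLLoss over log-probs.
labels = torch.tensor([0, 2, 1, 2])  # hypothetical pathology labels
eps = 1e-8                           # guards against log(0)
loss = nn.NLLLoss()(torch.log(probs + eps), labels)
```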