tirank.Model.TiRankModel

class tirank.Model.TiRankModel(*args: Any, **kwargs: Any)

Bases: Module

The main TiRank multi-task learning model.

This model combines one of the available encoders (MLP, Transformer, or DenseNet) with a primary prediction head (for Cox, Classification, or Regression) and an optional auxiliary head for pathology prediction (used in ST mode). It also includes a learnable feature weight layer whose weights are regularized with an L1 penalty.
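A learnable feature weight layer with an L1 penalty can be pictured as one trainable scale per input feature. The sketch below is not TiRank's implementation: the class name FeatureWeight, the element-wise weighting, the initialisation to ones, and the l1_penalty helper are all hypothetical, chosen only to illustrate the idea in PyTorch.

```python
import torch
import torch.nn as nn


class FeatureWeight(nn.Module):
    """Hypothetical learnable per-feature weight layer (one weight per gene pair)."""

    def __init__(self, n_features: int):
        super().__init__()
        # Start every feature at weight 1 so the layer is initially the identity.
        self.weight = nn.Parameter(torch.ones(n_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Element-wise rescaling; x has shape (batch, n_features).
        return x * self.weight

    def l1_penalty(self) -> torch.Tensor:
        # Sum of absolute weights; adding it to the loss pushes weights toward 0.
        return self.weight.abs().sum()


layer = FeatureWeight(n_features=4)
x = torch.randn(2, 4)
out = layer(x)
penalty = layer.l1_penalty()
```

In training, a term such as `lam * layer.l1_penalty()` (with `lam` a hypothetical regularization strength) would be added to the task loss, so that uninformative gene pairs are driven toward zero weight.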

Parameters:
  • n_features (int) – Input feature size (number of gene pairs).

  • nhead (int) – Number of heads for Transformer.

  • nhid1 (int) – Hidden dimension for the encoder.

  • nhid2 (int) – Hidden dimension for the predictor heads.

  • nlayers (int) – Number of layers in the encoder.

  • n_output (int) – Output dimension of the encoder (embedding size).

  • n_pred (int, optional) – Output dimension of the primary predictor. Defaults to 1.

  • n_patho (int, optional) – Output dimension of the pathology predictor (number of classes). Defaults to 0.

  • dropout (float, optional) – Dropout probability. Defaults to 0.5.

  • mode (str, optional) – Analysis mode (“Cox”, “Classification”, or “Regression”). Defaults to “Cox”.

  • encoder_type (str, optional) – Type of encoder to use (MLP, Transformer, or DenseNet). Defaults to “MLP”.
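To show how these hyperparameters could size the layers, here is a simplified, hypothetical stand-in (TiRankSketch) covering only the MLP encoder, so nhead and encoder_type are omitted. The ReLU activations, the layer ordering, and returning None for the pathology output when n_patho is 0 are assumptions of this sketch; the real model's internals are not documented on this page.

```python
import torch
import torch.nn as nn


class TiRankSketch(nn.Module):
    """Hypothetical simplified stand-in for TiRankModel (MLP encoder only)."""

    def __init__(self, n_features, nhid1, nhid2, nlayers, n_output,
                 n_pred=1, n_patho=0, dropout=0.5):
        super().__init__()
        # Learnable per-feature weights (the target of the L1 penalty).
        self.feature_weight = nn.Parameter(torch.ones(n_features))
        # MLP encoder: nlayers hidden blocks of width nhid1, then a projection to n_output.
        layers, in_dim = [], n_features
        for _ in range(nlayers):
            layers += [nn.Linear(in_dim, nhid1), nn.ReLU(), nn.Dropout(dropout)]
            in_dim = nhid1
        layers.append(nn.Linear(in_dim, n_output))
        self.encoder = nn.Sequential(*layers)
        # Primary head mapping the embedding to n_pred outputs.
        self.predictor = nn.Sequential(
            nn.Linear(n_output, nhid2), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(nhid2, n_pred))
        # Auxiliary pathology head; in this sketch it exists only when n_patho > 0.
        self.patho = (nn.Sequential(nn.Linear(n_output, nhid2), nn.ReLU(),
                                    nn.Linear(nhid2, n_patho))
                      if n_patho > 0 else None)

    def forward(self, x):
        embedding = self.encoder(x * self.feature_weight)
        pred = self.predictor(embedding)
        patho = self.patho(embedding) if self.patho is not None else None
        return embedding, pred, patho


model = TiRankSketch(n_features=10, nhid1=32, nhid2=16, nlayers=2,
                     n_output=8, n_pred=1, n_patho=3)
model.eval()  # disable dropout for a deterministic pass
emb, pred, patho = model(torch.randn(5, 10))
```

With these settings the embedding has shape (5, 8), the primary prediction (5, 1), and the pathology prediction (5, 3).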

forward(x)

The main forward pass for the TiRank model.

Parameters:

x (torch.Tensor) – Input gene pair feature tensor.

Returns:

A tuple containing:
  • torch.Tensor: The learned embedding.

  • torch.Tensor: The primary prediction (risk score, class, etc.).

  • torch.Tensor: The auxiliary pathology prediction.

Return type:

tuple
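Because forward returns a 3-tuple, a multi-task training step typically unpacks it and combines one loss per head. The example below uses random stand-in tensors in place of a real model call, and assumes Classification mode with a pathology target; the labels, the cross-entropy losses, and the weight alpha on the auxiliary task are hypothetical and not taken from TiRank.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_output, n_classes, n_patho = 4, 8, 2, 3

# Stand-ins for the tuple returned by forward(x); with a real model:
# embedding, pred, patho_pred = model(x)
embedding = torch.randn(batch, n_output)          # not used by the loss below
pred = torch.randn(batch, n_classes, requires_grad=True)
patho_pred = torch.randn(batch, n_patho, requires_grad=True)

y = torch.tensor([0, 1, 1, 0])        # hypothetical primary class labels
y_patho = torch.tensor([2, 0, 1, 2])  # hypothetical pathology labels

alpha = 0.5  # hypothetical weight on the auxiliary pathology task
loss = F.cross_entropy(pred, y) + alpha * F.cross_entropy(patho_pred, y_patho)
loss.backward()  # gradients flow into both prediction heads
```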

init_weights(m)

Applies Xavier uniform initialization to linear layers.

Parameters:

m (nn.Module) – A module (or layer) from the network.
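A function with this signature is normally passed to nn.Module.apply, which calls it on every submodule. The sketch below is a plausible reimplementation, not TiRank's code: it uses torch.nn.init.xavier_uniform_ on nn.Linear weights and leaves biases at PyTorch's defaults, since the documentation does not say how biases are handled. The small network is a hypothetical example.

```python
import math
import torch.nn as nn


def init_weights(m: nn.Module) -> None:
    # Only nn.Linear weights are re-initialised; biases and other modules are untouched.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)


net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 1))
# Module.apply visits every submodule recursively, then the module itself.
net.apply(init_weights)

# Xavier uniform samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).
bound = math.sqrt(6 / (20 + 10))  # ≈ 0.447 for the first layer
```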