tirank.Model.TiRankModel
- class tirank.Model.TiRankModel(*args: Any, **kwargs: Any)
Bases: Module
The main TiRank multi-task learning model.
This model combines one of the available encoders (MLP, Transformer, or DenseNet) with a primary prediction head (for the Cox, Classification, or Regression mode) and an optional auxiliary head for pathology prediction, which is used in spatial transcriptomics (ST) mode. It also includes a learnable feature-weight layer that is subject to L1 regularization.
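The following is a minimal sketch of a learnable feature-weight layer with an L1 penalty. It assumes that the layer scales each gene-pair feature by its own trainable weight. The class name `FeatureWeightLayer` and the initialization to ones are hypothetical and are not taken from TiRank's source.

```python
import torch
import torch.nn as nn


class FeatureWeightLayer(nn.Module):
    """Hypothetical element-wise learnable weight per input feature (gene pair)."""

    def __init__(self, n_features: int):
        super().__init__()
        # One trainable scalar per feature, initialized to 1 so the layer starts as identity.
        self.weight = nn.Parameter(torch.ones(n_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast the (n_features,) weights over a (batch, n_features) input.
        return x * self.weight

    def l1_penalty(self) -> torch.Tensor:
        # Sum of absolute weights; adding a scaled copy of this to the loss
        # pushes the weights of uninformative features toward zero.
        return self.weight.abs().sum()


layer = FeatureWeightLayer(n_features=4)
x = torch.randn(8, 4)
out = layer(x)
penalty = layer.l1_penalty()
print(out.shape)       # torch.Size([8, 4])
print(penalty.item())  # 4.0
```

During training, the penalty would be multiplied by a regularization coefficient and added to the task loss. The name and default value of that coefficient in TiRank are not documented on this page.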
- Parameters:
n_features (int) – Input feature size (number of gene pairs).
nhead (int) – Number of attention heads for the Transformer encoder.
nhid1 (int) – Hidden dimension for the encoder.
nhid2 (int) – Hidden dimension for the predictor heads.
nlayers (int) – Number of layers in the encoder.
n_output (int) – Output dimension of the encoder (embedding size).
n_pred (int, optional) – Output dimension of the primary predictor. Defaults to 1.
n_patho (int, optional) – Output dimension of the pathology predictor (number of classes). Defaults to 0.
dropout (float, optional) – Dropout probability. Defaults to 0.5.
mode (str, optional) – Analysis mode (“Cox”, “Classification”, or “Regression”). Defaults to “Cox”.
encoder_type (str, optional) – Type of encoder to use (MLP, Transformer, or DenseNet). Defaults to “MLP”.
- forward(x)
The main forward pass for the TiRank model.
- Parameters:
x (torch.Tensor) – Input gene pair feature tensor.
- Returns:
- A tuple containing:
torch.Tensor: The learned embedding.
torch.Tensor: The primary prediction (e.g. the risk score in Cox mode, the class prediction in Classification mode, or the predicted value in Regression mode).
torch.Tensor: The auxiliary pathology prediction.
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
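The sketch below illustrates the documented data flow and the three-element return value of forward. The data flow is feature weighting, then the encoder, then the primary head, with an optional pathology head. This is not TiRank's implementation. The following are all assumptions:

- the class name `TiRankSketch`;
- the restriction to an MLP encoder;
- the ReLU activations;
- the layout of both heads;
- the empty tensor returned when `n_patho` is 0.

`nhead` is omitted because it applies to the Transformer encoder.

```python
import torch
import torch.nn as nn


class TiRankSketch(nn.Module):
    """Hypothetical MLP-encoder stand-in mirroring the documented TiRankModel layout."""

    def __init__(self, n_features: int, nhid1: int, nhid2: int, nlayers: int,
                 n_output: int, n_pred: int = 1, n_patho: int = 0,
                 dropout: float = 0.5):
        super().__init__()
        # Learnable per-feature weights (the target of the L1 penalty).
        self.feature_weights = nn.Parameter(torch.ones(n_features))

        # Encoder: (nlayers - 1) hidden blocks of width nhid1, then a projection
        # to the n_output-dimensional embedding.
        blocks, in_dim = [], n_features
        for _ in range(nlayers - 1):
            blocks += [nn.Linear(in_dim, nhid1), nn.ReLU(), nn.Dropout(dropout)]
            in_dim = nhid1
        blocks.append(nn.Linear(in_dim, n_output))
        self.encoder = nn.Sequential(*blocks)

        # Primary head producing n_pred outputs per sample.
        self.predictor = nn.Sequential(
            nn.Linear(n_output, nhid2), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(nhid2, n_pred),
        )
        # Auxiliary pathology head, built only when n_patho > 0.
        self.patho_predictor = (
            nn.Sequential(nn.Linear(n_output, nhid2), nn.ReLU(),
                          nn.Linear(nhid2, n_patho))
            if n_patho > 0 else None
        )

    def forward(self, x: torch.Tensor):
        embedding = self.encoder(x * self.feature_weights)
        pred = self.predictor(embedding)
        if self.patho_predictor is not None:
            patho = self.patho_predictor(embedding)
        else:
            # Assumption: return an empty (batch, 0) tensor when there is no pathology head.
            patho = embedding.new_zeros(x.shape[0], 0)
        return embedding, pred, patho


model = TiRankSketch(n_features=100, nhid1=64, nhid2=32, nlayers=2,
                     n_output=16, n_pred=1, n_patho=3)
model.eval()  # disable dropout for a deterministic pass
x = torch.randn(5, 100)
with torch.no_grad():
    emb, pred, patho = model(x)
print(emb.shape, pred.shape, patho.shape)
# torch.Size([5, 16]) torch.Size([5, 1]) torch.Size([5, 3])
```

Callers unpack the tuple in the documented order: embedding, primary prediction, then pathology prediction. In bulk-only settings the third element can simply be ignored.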