tirank.TrainPre.Validate_model

tirank.TrainPre.Validate_model(model, dataloader_A, dataloader_B, mode='Cox', infer_mode='SC', adj_A=None, adj_B=None, pre_patho_labels=None, alphas=[1, 1, 1], device='cpu')

Runs one validation pass (one epoch) over the supplied dataloaders.

This function iterates through the dataloaders, performs a forward pass on each batch, and computes the validation loss. No backpropagation is performed.

Parameters:
  • model (TiRankModel) – The TiRank model to be evaluated.

  • dataloader_A (DataLoader) – DataLoader for the bulk validation data.

  • dataloader_B (DataLoader) – DataLoader for the single-cell (SC) or spatial transcriptomics (ST) data.

  • mode (str, optional) – Analysis mode (‘Cox’, ‘Classification’, ‘Regression’).

  • infer_mode (str, optional) – Inference data type (‘SC’ or ‘ST’).

  • adj_A (torch.Tensor, optional) – Adjacency matrix for sc/st data.

  • adj_B (torch.Tensor, optional) – Adjacency matrix for spatial distance (ST only).

  • pre_patho_labels (pd.Series, optional) – Ground truth pathology labels (ST only).

  • alphas (list, optional) – List of floats to weight the loss components.

  • device (str, optional) – The compute device (‘cpu’ or ‘cuda’).
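
The role of alphas can be pictured with a small sketch. It assumes the loss components are combined as a plain weighted sum and then averaged over batches; the page does not state the exact formula, so treat this as an illustration rather than TiRank's implementation. The names combine_losses and average_over_batches, and the numbers used, are hypothetical.

```python
def combine_losses(components, alphas=(1, 1, 1)):
    """Weighted sum of per-batch loss components (assumed combination rule)."""
    if len(components) != len(alphas):
        raise ValueError("need exactly one weight per loss component")
    return sum(a * c for a, c in zip(alphas, components))


def average_over_batches(batch_losses):
    """Mean of the per-batch losses, i.e. an epoch-level average."""
    return sum(batch_losses) / len(batch_losses)


# Made-up loss components for two batches, three components each.
batches = [(0.5, 0.2, 0.1), (0.3, 0.4, 0.1)]

# Equal weights, matching the default alphas=[1, 1, 1].
equal = average_over_batches([combine_losses(b) for b in batches])

# Doubling the first component and switching off the second.
reweighted = average_over_batches(
    [combine_losses(b, alphas=(2, 0, 1)) for b in batches]
)
```

Under this assumption, setting an entry of alphas to 0 removes that component from the reported loss, and larger entries make the corresponding component dominate it.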

Returns:

The average validation loss for this epoch.

Return type:

float
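
Example:

The pattern described above (forward passes only, no backpropagation, averaged loss returned as a float) can be sketched in plain PyTorch as follows. This is a generic illustration and not the TiRank source: validate_epoch, the toy linear model, and the MSE loss are all stand-ins, and a single dataloader is used because this page does not say how dataloader_A and dataloader_B are paired.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def validate_epoch(model, dataloader, loss_fn, device="cpu"):
    """Return the mean per-batch loss over one pass of `dataloader`."""
    model.eval()  # switch layers such as dropout to inference behaviour
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():  # no autograd graph is built, so nothing can be backpropagated
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            total_loss += loss_fn(model(x), y).item()
            n_batches += 1
    # Guard against an empty dataloader.
    return total_loss / max(n_batches, 1)


# Toy stand-ins for the model and the validation data.
torch.manual_seed(0)
toy_model = nn.Linear(4, 1)
toy_data = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
val_loss = validate_epoch(
    toy_model, DataLoader(toy_data, batch_size=8), nn.MSELoss()
)
```

Because every forward pass runs inside torch.no_grad(), no gradients are stored on the parameters and the model weights are left unchanged by the call.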