tirank.TrainPre.Validate_model
- tirank.TrainPre.Validate_model(model, dataloader_A, dataloader_B, mode='Cox', infer_mode='SC', adj_A=None, adj_B=None, pre_patho_labels=None, alphas=[1, 1, 1], device='cpu')
Performs a single validation step.
This function iterates through the dataloaders, performs a forward pass on each batch, and accumulates the total validation loss without performing backpropagation.
- Parameters:
model (TiRankModel) – The TiRank model to be evaluated.
dataloader_A (DataLoader) – DataLoader for the bulk data (validation).
dataloader_B (DataLoader) – DataLoader for the sc/st data.
mode (str, optional) – Analysis mode (‘Cox’, ‘Classification’, ‘Regression’).
infer_mode (str, optional) – Inference data type (‘SC’ or ‘ST’).
adj_A (torch.Tensor, optional) – Adjacency matrix for sc/st data.
adj_B (torch.Tensor, optional) – Adjacency matrix for spatial distance (ST only).
pre_patho_labels (pd.Series, optional) – Ground truth pathology labels (ST only).
alphas (list, optional) – List of floats to weight the loss components.
device (str, optional) – The compute device (‘cpu’ or ‘cuda’).
- Returns:
The average validation loss for this epoch.
- Return type:
float
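The description above says the loss components are weighted by alphas and that the return value is an average. A minimal sketch of how those two ideas could combine, assuming the returned value is the mean over batches of the alpha-weighted sum of each batch's loss components. The helper `weighted_average_loss` is hypothetical and is not part of TiRank; the actual components and aggregation used by `Validate_model` are not documented here.

```python
from typing import Sequence


def weighted_average_loss(
    batch_losses: Sequence[Sequence[float]],
    alphas: Sequence[float] = (1, 1, 1),
) -> float:
    """Hypothetical helper (not a TiRank API).

    Assumes each batch yields one loss value per component, in the same
    order as ``alphas``, and that the epoch loss is the mean over batches
    of the alpha-weighted sum of those components.
    """
    if not batch_losses:
        raise ValueError("need at least one batch")
    total = 0.0
    for components in batch_losses:
        if len(components) != len(alphas):
            raise ValueError("expected one loss component per alpha")
        # Alpha-weighted sum of this batch's loss components.
        total += sum(a * c for a, c in zip(alphas, components))
    # Mean over batches, mirroring the "average validation loss" return value.
    return total / len(batch_losses)


losses = [(1.0, 2.0, 3.0), (3.0, 2.0, 1.0)]
print(weighted_average_loss(losses))               # 6.0
print(weighted_average_loss(losses, [1, 0.5, 0]))  # 3.0
```

Setting an alpha to 0 drops that component from the total. The sketch uses an immutable tuple as its default, which avoids the shared-mutable-default pitfall of a list default.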