tirank.TrainPre.Train_one_epoch

tirank.TrainPre.Train_one_epoch(model, dataloader_A, dataloader_B, mode='Cox', infer_mode='SC', adj_A=None, adj_B=None, pre_patho_labels=None, optimizer=None, alphas=[1, 1, 1, 1], device='cpu', epoch=0)

Performs a single training epoch for the TiRank multi-task model.

This function iterates through the single-cell or spatial transcriptomics (SC/ST) data in dataloader_B and computes a composite loss that combines the loss on the bulk data in dataloader_A with regularization terms (e.g., adjacency-based and MMD terms). It then performs a backward pass and an optimizer step.

Parameters:
  • model (TiRankModel) – The TiRank model to be trained.

  • dataloader_A (DataLoader) – DataLoader for the bulk training data.

  • dataloader_B (DataLoader) – DataLoader for the SC/ST data.

  • mode (str, optional) – Analysis mode (‘Cox’, ‘Classification’, ‘Regression’).

  • infer_mode (str, optional) – Inference data type (‘SC’ or ‘ST’).

  • adj_A (torch.Tensor, optional) – Adjacency matrix for the SC/ST data (e.g., derived from a neighbor-graph connectivities matrix).

  • adj_B (torch.Tensor, optional) – Adjacency matrix for spatial distance (ST only).

  • pre_patho_labels (pd.Series, optional) – Ground truth pathology labels (ST only).

  • optimizer (torch.optim.Optimizer, optional) – The optimizer used for the parameter update.

  • alphas (list, optional) – Four weights for the loss components, in the order [reg_loss, bulk_loss, cosine_loss_exp, spatial/patho_loss].

  • device (str, optional) – The compute device (‘cpu’ or ‘cuda’).

  • epoch (int, optional) – The current epoch number (for prototype loss scheduling).
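The alphas weighting can be pictured as a weighted sum of the four loss components. The sketch below is a minimal pure-Python illustration, not TiRank's implementation: the key names in COMPONENTS are hypothetical stand-ins for the documented components, and treating a component that is not computed in the current mode (for example, a spatial term in SC mode) as zero is an assumption. It uses a tuple default instead of a list, which sidesteps Python's shared-mutable-default pitfall.

```python
from typing import Mapping, Sequence

# Hypothetical key names, in the order documented for ``alphas``:
# [reg_loss, bulk_loss, cosine_loss_exp, spatial/patho_loss].
COMPONENTS = ("reg_loss", "bulk_loss", "cosine_loss_exp", "spatial_patho_loss")


def weighted_total(losses: Mapping[str, float],
                   alphas: Sequence[float] = (1, 1, 1, 1)) -> float:
    """Return sum_i alphas[i] * losses[COMPONENTS[i]]."""
    if len(alphas) != len(COMPONENTS):
        raise ValueError(f"expected {len(COMPONENTS)} weights, got {len(alphas)}")
    # Assumption: a component missing from ``losses`` contributes zero.
    return sum(a * losses.get(name, 0.0) for a, name in zip(alphas, COMPONENTS))


losses = {"reg_loss": 0.5, "bulk_loss": 2.0, "cosine_loss_exp": 0.25}
print(weighted_total(losses, alphas=[1, 2, 4, 1]))  # 0.5 + 4.0 + 1.0 + 0 = 5.5
```

Raising an error on a wrong-length weight list makes a misconfigured alphas argument fail loudly instead of silently dropping a component.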

Returns:

A dictionary of the average loss values for this epoch.

Return type:

dict
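One plausible way to build such a dictionary is a per-key arithmetic mean over the batches of the epoch. This is a sketch under assumptions: this page does not list the returned keys, so the key names below are hypothetical, and the exact averaging TiRank uses is not documented here. With PyTorch scalar loss tensors, each value would typically be converted with `.item()` before being logged.

```python
from collections import defaultdict
from typing import Dict, Iterable, Mapping


def average_epoch_losses(batch_losses: Iterable[Mapping[str, float]]) -> Dict[str, float]:
    """Average per-batch loss values key by key over one epoch."""
    sums: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for losses in batch_losses:
        for name, value in losses.items():
            sums[name] += float(value)
            counts[name] += 1
    # Divide each running sum by the number of batches that reported that key.
    return {name: sums[name] / counts[name] for name in sums}


# Hypothetical per-batch records for a two-batch epoch.
epoch_log = average_epoch_losses([
    {"total_loss": 3.0, "bulk_loss": 1.0},
    {"total_loss": 1.0, "bulk_loss": 2.0},
])
print(epoch_log)  # {'total_loss': 2.0, 'bulk_loss': 1.5}
```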