tirank.TrainPre.Train_one_epoch
- tirank.TrainPre.Train_one_epoch(model, dataloader_A, dataloader_B, mode='Cox', infer_mode='SC', adj_A=None, adj_B=None, pre_patho_labels=None, optimizer=None, alphas=[1, 1, 1, 1], device='cpu', epoch=0)
Performs a single training epoch for the TiRank multi-task model.
This function iterates through the sc/st data (dataloader_B) and computes a composite loss that combines a term computed against the bulk data (dataloader_A) with regularization terms (adjacency-based, MMD, etc.). It then performs a backward pass and an optimizer step.
- Parameters:
model (TiRankModel) – The TiRank model to be trained.
dataloader_A (DataLoader) – DataLoader for the bulk data (train).
dataloader_B (DataLoader) – DataLoader for the sc/st data.
mode (str, optional) – Analysis mode (‘Cox’, ‘Classification’, ‘Regression’).
infer_mode (str, optional) – Inference data type (‘SC’ or ‘ST’).
adj_A (torch.Tensor, optional) – Adjacency matrix for sc/st data (e.g., from connectivities).
adj_B (torch.Tensor, optional) – Adjacency matrix for spatial distance (ST only).
pre_patho_labels (pd.Series, optional) – Ground truth pathology labels (ST only).
optimizer (torch.optim.Optimizer, optional) – The optimizer.
alphas (list, optional) – Four weights for the loss components, in the order [reg_loss, bulk_loss, cosine_loss_exp, spatial/patho_loss].
device (str, optional) – The compute device (‘cpu’ or ‘cuda’).
epoch (int, optional) – The current epoch number (for prototype loss scheduling).
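The alphas entries scale the four loss components before they are combined. The sketch below illustrates that weighting under the assumption that the total is a plain weighted sum in the order listed above; the helper composite_loss and the component key names are hypothetical, and TiRank itself may combine the terms differently. It also assumes that a component that is not computed (for example the spatial/pathology term when infer_mode='SC', since adj_B and pre_patho_labels are ST only) contributes zero.

```python
from typing import Dict, Sequence

# Hypothetical names for the four components, in the order alphas expects.
COMPONENTS = ("reg_loss", "bulk_loss", "cosine_loss_exp", "spatial_patho_loss")


def composite_loss(losses: Dict[str, float], alphas: Sequence[float]) -> float:
    """Return the alpha-weighted sum of the four loss components.

    Assumption: the total loss is a plain weighted sum; components missing
    from `losses` (e.g. the spatial/pathology term in SC mode) count as 0.
    """
    if len(alphas) != len(COMPONENTS):
        raise ValueError(f"expected {len(COMPONENTS)} alphas, got {len(alphas)}")
    return sum(a * losses.get(name, 0.0) for a, name in zip(alphas, COMPONENTS))


batch_losses = {"reg_loss": 0.5, "bulk_loss": 2.0, "cosine_loss_exp": 0.25}
print(composite_loss(batch_losses, [1, 1, 1, 1]))  # 2.75
print(composite_loss(batch_losses, [0, 1, 2, 1]))  # 2.5
```

Under this weighted-sum assumption, setting an entry of alphas to 0 removes that term from the objective, which is a simple way to ablate a single component.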
- Returns:
A dictionary of the average loss values for this epoch.
- Return type:
dict
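The returned dictionary holds per-epoch averages of the loss values. One plausible way to build such a dictionary is sketched below, assuming each training batch yields a dict of scalar losses and that the epoch value is an unweighted mean over batches; the helper average_losses and the keys in the example are hypothetical, not the actual keys TiRank returns.

```python
from collections import defaultdict
from typing import Dict, Iterable


def average_losses(batch_losses: Iterable[Dict[str, float]]) -> Dict[str, float]:
    """Average each loss key over the batches in which it was reported."""
    totals: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for losses in batch_losses:
        for name, value in losses.items():
            totals[name] += float(value)
            counts[name] += 1
    return {name: totals[name] / counts[name] for name in totals}


history = [
    {"total": 3.0, "bulk_loss": 2.0},
    {"total": 1.0, "bulk_loss": 1.0},
]
print(average_losses(history))  # {'total': 2.0, 'bulk_loss': 1.5}
```

In a PyTorch training loop, each per-batch value would typically be taken with loss.item(), so the averages are plain Python floats rather than tensors that keep the autograd graph alive.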