tirank.TrainPre

tirank.TrainPre.Train_one_epoch(model, dataloader_A, dataloader_B, mode='Cox', infer_mode='SC', adj_A=None, adj_B=None, pre_patho_labels=None, optimizer=None, alphas=[1, 1, 1, 1], device='cpu', epoch=0)

Performs a single training epoch for the TiRank multi-task model.

This function iterates through the sc/st data (dataloader_B) and computes a composite loss that combines the bulk-data objective (dataloader_A) with regularization terms (adjacency matrices, MMD, etc.). It then performs a backward pass and an optimizer step.

Parameters:
  • model (TiRankModel) – The TiRank model to be trained.

  • dataloader_A (DataLoader) – DataLoader for the bulk data (train).

  • dataloader_B (DataLoader) – DataLoader for the sc/st data.

  • mode (str, optional) – Analysis mode (‘Cox’, ‘Classification’, ‘Regression’).

  • infer_mode (str, optional) – Inference data type (‘SC’ or ‘ST’).

  • adj_A (torch.Tensor, optional) – Adjacency matrix for sc/st data (e.g., from connectivities).

  • adj_B (torch.Tensor, optional) – Adjacency matrix for spatial distance (ST only).

  • pre_patho_labels (pd.Series, optional) – Ground truth pathology labels (ST only).

  • optimizer (torch.optim.Optimizer, optional) – The optimizer.

  • alphas (list, optional) – List of floats to weight the loss components [reg_loss, bulk_loss, cosine_loss_exp, spatial/patho_loss].

  • device (str, optional) – The compute device (‘cpu’ or ‘cuda’).

  • epoch (int, optional) – The current epoch number (for prototype loss scheduling).

Returns:

A dictionary of the average loss values for this epoch.

Return type:

dict
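The alphas weighting and the per-epoch averaging described above can be illustrated in plain Python. This is a minimal sketch, not TiRank's code: the component key names and the assumption that each batch yields one float per component are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Sequence

# Hypothetical component names, in the order documented for `alphas`.
COMPONENTS = ["reg_loss", "bulk_loss", "cosine_loss_exp", "spatial_patho_loss"]


def weighted_total(batch_losses: Dict[str, float], alphas: Sequence[float]) -> float:
    """Combine one batch's loss components using the alphas weights."""
    if len(alphas) != len(COMPONENTS):
        raise ValueError("expected one alpha per loss component")
    return sum(a * batch_losses[name] for a, name in zip(alphas, COMPONENTS))


def epoch_summary(per_batch: List[Dict[str, float]], alphas: Sequence[float]) -> Dict[str, float]:
    """Average every component, and the weighted total, over all batches."""
    if not per_batch:
        raise ValueError("no batches")
    sums: Dict[str, float] = defaultdict(float)
    for losses in per_batch:
        for name in COMPONENTS:
            sums[name] += losses[name]
        sums["total_loss"] += weighted_total(losses, alphas)
    return {name: value / len(per_batch) for name, value in sums.items()}
```

Setting an alpha to 0 switches that term off without changing the code path, which is one reason to keep the weights as a list rather than hard-coding them.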

tirank.TrainPre.Validate_model(model, dataloader_A, dataloader_B, mode='Cox', infer_mode='SC', adj_A=None, adj_B=None, pre_patho_labels=None, alphas=[1, 1, 1], device='cpu')

Performs a single validation step.

This function iterates through the dataloaders, performs a forward pass, and calculates the total validation loss (without backpropagation).

Parameters:
  • model (TiRankModel) – The TiRank model to be evaluated.

  • dataloader_A (DataLoader) – DataLoader for the bulk data (validation).

  • dataloader_B (DataLoader) – DataLoader for the sc/st data.

  • mode (str, optional) – Analysis mode (‘Cox’, ‘Classification’, ‘Regression’).

  • infer_mode (str, optional) – Inference data type (‘SC’ or ‘ST’).

  • adj_A (torch.Tensor, optional) – Adjacency matrix for sc/st data.

  • adj_B (torch.Tensor, optional) – Adjacency matrix for spatial distance (ST only).

  • pre_patho_labels (pd.Series, optional) – Ground truth pathology labels (ST only).

  • alphas (list, optional) – List of floats to weight the loss components.

  • device (str, optional) – The compute device (‘cpu’ or ‘cuda’).

Returns:

The average validation loss for this epoch.

Return type:

float
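In PyTorch, a pass like this is typically wrapped in model.eval() and torch.no_grad() so that no gradients are tracked. The sketch below leaves the framework out and only shows how a single averaged float can be produced. The forward_losses callable (a stand-in for the model's forward pass plus loss computation) and the pairing of bulk and sc/st batches as tuples are assumptions.

```python
from typing import Callable, Iterable, Sequence, Tuple


def average_validation_loss(
    batches: Iterable[Tuple[object, object]],
    forward_losses: Callable[[object, object], Sequence[float]],
    alphas: Sequence[float],
) -> float:
    """Mean alpha-weighted loss over paired (bulk, sc/st) batches.

    Nothing is updated here: there is no backward pass and no optimizer step.
    """
    totals = []
    for bulk_batch, sc_batch in batches:
        components = forward_losses(bulk_batch, sc_batch)  # forward pass only
        if len(components) != len(alphas):
            raise ValueError("expected one alpha per loss component")
        totals.append(sum(a * c for a, c in zip(alphas, components)))
    if not totals:
        raise ValueError("no validation batches")
    return sum(totals) / len(totals)
```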

tirank.TrainPre.Reject_With_GMM_Bio(pred_bulk, pred_sc, tolerance, min_components, max_components)

Performs GMM-based rejection for Classification and Cox modes.

This function identifies phenotype-associated clusters in two steps. It fits one GMM to the bulk scores (to find the target means 0 and 1) and another GMM to the sc/st scores, then selects the sc/st clusters whose means lie within the given tolerance of a bulk target mean.

Parameters:
  • pred_bulk (np.ndarray) – Predicted scores from the bulk data (n_samples, 1).

  • pred_sc (np.ndarray) – Predicted scores from the sc/st data (n_cells, 1).

  • tolerance (float) – The maximum distance a sc/st cluster mean can be from a bulk target mean to be considered aligned.

  • min_components (int) – The minimum number of GMM components to try.

  • max_components (int) – The maximum number of GMM components to try.

Returns:

A binary mask (n_cells, 1) where 1 indicates a cell to be rejected (phenotype-independent) and 0 indicates a cell to be kept.

Return type:

np.ndarray
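The alignment step can be sketched with NumPy, assuming both GMMs have already been fit (for example with scikit-learn's GaussianMixture), so that the bulk target means, the sc/st component means, and each cell's component assignment are available. How the number of components is chosen between min_components and max_components is not shown, and the function name is hypothetical.

```python
import numpy as np


def reject_unaligned_components(
    bulk_means: np.ndarray,
    sc_component_means: np.ndarray,
    sc_component_labels: np.ndarray,
    tolerance: float,
) -> np.ndarray:
    """Return an (n_cells, 1) mask: 1 = reject, 0 = keep.

    A sc/st component counts as aligned when its mean lies within
    `tolerance` of at least one bulk target mean.
    """
    bulk_means = np.asarray(bulk_means, dtype=float).ravel()
    comp_means = np.asarray(sc_component_means, dtype=float).ravel()
    # Distance from every sc/st component mean to every bulk target mean.
    dist = np.abs(comp_means[:, None] - bulk_means[None, :])
    aligned = (dist <= tolerance).any(axis=1)  # one flag per component
    labels = np.asarray(sc_component_labels, dtype=int).ravel()
    keep = aligned[labels]  # look up each cell's component
    return (~keep).astype(int).reshape(-1, 1)
```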

tirank.TrainPre.Reject_With_GMM_Reg(pred_bulk, pred_sc, tolerance)

Performs GMM-based rejection for Regression mode.

Fits a single-component GMM to both bulk and sc/st scores to find their means. If the means are too divergent, rejects all cells. Otherwise, rejects cells that fall outside a tolerance range around the bulk mean.

Parameters:
  • pred_bulk (np.ndarray) – Predicted scores from the bulk data (n_samples, 1).

  • pred_sc (np.ndarray) – Predicted scores from the sc/st data (n_cells, 1).

  • tolerance (float) – The tolerance (std dev or max value) to define the acceptance range around the bulk mean.

Returns:

A binary mask (n_cells, 1) where 1 indicates rejection.

Return type:

np.ndarray
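The description above does not spell out how tolerance maps to the acceptance range or what counts as "too divergent". The sketch below assumes a band of tolerance bulk standard deviations around the bulk mean for both checks; that rule and the function name are hypothetical. It relies on the fact that a single-component Gaussian fit by maximum likelihood reduces to the sample mean and population standard deviation, so no GMM library is needed.

```python
import numpy as np


def reject_regression(pred_bulk, pred_sc, tolerance: float) -> np.ndarray:
    """Return an (n_cells, 1) mask where 1 = reject.

    Assumed rule: accept cells within mean_bulk +/- tolerance * std_bulk,
    and reject everything if the sc/st mean itself falls outside that band.
    """
    bulk = np.asarray(pred_bulk, dtype=float).ravel()
    sc = np.asarray(pred_sc, dtype=float).ravel()
    # One-component Gaussian MLE: sample mean and population std (ddof=0).
    mu_bulk, sd_bulk = bulk.mean(), bulk.std()
    mu_sc = sc.mean()
    half_width = tolerance * sd_bulk
    if abs(mu_sc - mu_bulk) > half_width:
        return np.ones((sc.size, 1), dtype=int)  # distributions too divergent
    outside = np.abs(sc - mu_bulk) > half_width
    return outside.astype(int).reshape(-1, 1)
```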

tirank.TrainPre.Reject_With_StrictNumber(pred_bulk, pred_sc, tolerance)

Performs rejection based on a strict percentile range.

Fits a 2-component GMM to bulk scores to find means and std deviations. It then defines an acceptance range based on the percentile (tolerance) of a normal distribution (e.g., tolerance=0.95 keeps the central 95% of each bulk cluster).

Parameters:
  • pred_bulk (np.ndarray) – Predicted scores from the bulk data (n_samples, 1).

  • pred_sc (np.ndarray) – Predicted scores from the sc/st data (n_cells, 1).

  • tolerance (float) – The percentile of the distribution to keep (e.g., 0.95).

Returns:

A binary mask (n_cells, 1) where 1 indicates rejection.

Return type:

np.ndarray
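Keeping "the central 95%" of a normal distribution corresponds to the interval mean ± z·std, where z is the (1 + tolerance)/2 quantile of the standard normal (about 1.96 for 0.95). The sketch below assumes the means and standard deviations of the two bulk components are already known and that a cell is kept if it falls inside the range of either component. The function name and signature are hypothetical; the documented function takes pred_bulk and fits the GMM itself.

```python
from statistics import NormalDist

import numpy as np


def reject_strict(pred_sc, bulk_means, bulk_stds, tolerance: float = 0.95) -> np.ndarray:
    """Return an (n_cells, 1) mask: 1 = reject, 0 = keep.

    Each bulk component k accepts scores in mean_k +/- z * std_k, where z is
    the normal quantile covering the central `tolerance` fraction.
    """
    if not 0 < tolerance < 1:
        raise ValueError("tolerance must be in (0, 1)")
    z = NormalDist().inv_cdf((1 + tolerance) / 2)  # about 1.96 for 0.95
    sc = np.asarray(pred_sc, dtype=float).ravel()
    means = np.asarray(bulk_means, dtype=float).ravel()
    stds = np.asarray(bulk_stds, dtype=float).ravel()
    lower, upper = means - z * stds, means + z * stds
    # A cell is kept if it falls inside the range of any bulk component.
    inside = ((sc[:, None] >= lower) & (sc[:, None] <= upper)).any(axis=1)
    return (~inside).astype(int).reshape(-1, 1)
```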

tirank.TrainPre.objective(trial, n_features, nhead, nhid1, nhid2, n_output, nlayers, n_pred, dropout, n_patho, mode, encoder_type, train_loader_Bulk, val_loader_Bulk, train_loader_SC, adj_A, adj_B, pre_patho_labels, device, infer_mode, model_save_path)

The objective function for Optuna hyperparameter optimization.

This function defines the search space for hyperparameters (lr, epochs, loss weights), builds a model, trains it, and returns the validation loss.

Parameters:
  • trial (optuna.trial.Trial) – An Optuna trial object.

  • (all other parameters) – Passed through from the tune_hyperparameters wrapper.

Returns:

The validation loss for the trial.

Return type:

float
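Optuna's Trial.suggest_float and Trial.suggest_int are the usual way to declare a search space like the one described above. The sketch below shows its general shape only: the parameter names and ranges are hypothetical, and train_and_validate is a placeholder for building, training, and validating a TiRankModel. Binding fixed arguments with functools.partial is one way a wrapper can hand Optuna the one-argument callable it expects.

```python
import functools
from typing import Any, Callable, Dict


def suggest_search_space(trial) -> Dict[str, Any]:
    """Sample one configuration from an Optuna trial (names/ranges hypothetical)."""
    return {
        "lr": trial.suggest_float("lr", 1e-5, 1e-2, log=True),
        "n_epochs": trial.suggest_int("n_epochs", 50, 300),
        # One weight per loss component, mirroring the `alphas` argument.
        "alphas": [trial.suggest_float(f"alpha_{i}", 0.0, 1.0) for i in range(4)],
    }


def objective_sketch(trial, train_and_validate: Callable[..., float]) -> float:
    """Sample parameters, then delegate training/validation to a placeholder."""
    params = suggest_search_space(trial)
    return train_and_validate(**params)  # should return the validation loss


def make_objective(train_and_validate: Callable[..., float]) -> Callable[[Any], float]:
    """Bind the fixed argument so Optuna only has to pass `trial`."""
    return functools.partial(objective_sketch, train_and_validate=train_and_validate)
```

With Optuna installed, such an objective would be run via `study = optuna.create_study(direction="minimize")` followed by `study.optimize(make_objective(fn), n_trials=50)`.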

tirank.TrainPre.tune_hyperparameters(savePath, device='cpu', n_trials=50)

Runs the Optuna hyperparameter tuning study.

This function loads all necessary data and model parameters, then initializes and runs an Optuna study to find the best hyperparameters by minimizing the validation loss.

Parameters:
  • savePath (str) – The main project directory path.

  • device (str, optional) – The compute device. Defaults to “cpu”.

  • n_trials (int, optional) – The number of Optuna trials to run. Defaults to 50.

Returns:

None

tirank.TrainPre.get_best_model(savePath)

Loads the best performing model from the hyperparameter tuning.

This function reads the ‘best_params.pkl’ file, reconstructs the corresponding model filename, initializes a new TiRankModel with the saved parameters, and loads the weights.

Parameters:

savePath (str) – The main project directory path.

Returns:

The trained TiRank model with the best weights loaded.

Return type:

TiRankModel

tirank.TrainPre.Predict(savePath, mode, do_reject=True, tolerance=0.05, reject_mode='GMM')

Performs inference using the best trained TiRank model.

This function loads the best model, loads the full bulk and sc/st gene pair matrices, and predicts scores for all samples. It then applies the chosen rejection (filtering) method to classify cells/spots as ‘Rank+’, ‘Rank-’, or ‘Background’. The final results are saved to ‘spot_predict_score.csv’ and ‘final_anndata.h5ad’.

Parameters:
  • savePath (str) – The main project directory path.

  • mode (str) – The analysis mode (‘Cox’, ‘Classification’, ‘Regression’).

  • do_reject (bool, optional) – Whether to perform the GMM-based rejection. Defaults to True.

  • tolerance (float, optional) – The tolerance parameter for the rejection method. Defaults to 0.05.

  • reject_mode (str, optional) – The rejection method to use (‘GMM’ or ‘Strict’). Defaults to “GMM”.

Returns:

None
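Turning scores and a rejection mask into the three labels can be sketched as follows. The rule is an assumption for illustration: rejected cells become 'Background' and the rest are split at a threshold of 0.5, which is a hypothetical cut-off and not a documented TiRank value. The function name is also hypothetical.

```python
import numpy as np


def assign_rank_labels(scores, reject_mask, threshold: float = 0.5) -> np.ndarray:
    """Label each cell/spot 'Rank+', 'Rank-' or 'Background'.

    Assumed rule: rejected cells are 'Background'; the rest are split at
    `threshold` (a hypothetical cut-off).
    """
    scores = np.asarray(scores, dtype=float).ravel()
    rejected = np.asarray(reject_mask).ravel().astype(bool)
    # object dtype so that the longer 'Background' string is not truncated.
    labels = np.where(scores > threshold, "Rank+", "Rank-").astype(object)
    labels[rejected] = "Background"
    return labels
```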

tirank.TrainPre.permute_once(Rank_Labels, Labels, unique_labels)

Helper function for a single permutation test shuffle.

Parameters:
  • Rank_Labels (list) – The list of all ‘Rank_Label’ assignments.

  • Labels (list) – The list of all cluster assignments.

  • unique_labels (set) – The set of unique cluster labels.

Returns:

A dictionary with permuted counts for each cluster.

Return type:

dict
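A single shuffle can be sketched in plain Python, assuming the permutation acts on the rank labels while the cluster assignments stay fixed. The per-cluster Counter return layout and the function name are assumptions.

```python
import random
from collections import Counter
from typing import Dict, Hashable, Iterable, List, Optional


def permute_once_sketch(
    rank_labels: List[str],
    cluster_labels: List[Hashable],
    unique_clusters: Iterable[Hashable],
    rng: Optional[random.Random] = None,
) -> Dict[Hashable, Counter]:
    """Shuffle the rank labels once and count them per cluster.

    Shuffling breaks any link between rank label and cluster while keeping
    the overall number of each rank label fixed.
    """
    rng = rng or random.Random()
    shuffled = list(rank_labels)
    rng.shuffle(shuffled)
    counts: Dict[Hashable, Counter] = {c: Counter() for c in unique_clusters}
    for rank, cluster in zip(shuffled, cluster_labels):
        counts[cluster][rank] += 1
    return counts
```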

tirank.TrainPre.AssignPcluster(df_p_values)

Assigns a final phenotype (‘Rank+’, ‘Rank-’, ‘Background’) to clusters.

This assignment is based on the p-values from the permutation test.

Parameters:

df_p_values (pd.DataFrame) – DataFrame with p-values for ‘Rank+’, ‘Rank-’, and ‘Background’ enrichment for each cluster.

Returns:

A dictionary mapping each final assignment (‘Rank+’, ‘Rank-’, ‘Background’) to a list of cluster IDs.

Return type:

dict
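One possible assignment rule can be sketched as follows. It assumes a significance level of 0.05 and that a cluster takes whichever of 'Rank+' or 'Rank-' has the smaller p-value when that p-value is significant, falling back to 'Background' otherwise. Both the rule and the function name are hypothetical. If clusters form the index of df_p_values, `df_p_values.to_dict(orient="index")` yields the nested mapping used here.

```python
from typing import Dict, Hashable, List, Mapping


def assign_clusters(
    p_values: Mapping[Hashable, Mapping[str, float]], alpha: float = 0.05
) -> Dict[str, List[Hashable]]:
    """Map 'Rank+', 'Rank-' and 'Background' to lists of cluster IDs.

    Assumed rule: a cluster takes whichever of 'Rank+' / 'Rank-' has the
    smaller p-value if it is below `alpha`; otherwise it is 'Background'.
    """
    result: Dict[str, List[Hashable]] = {"Rank+": [], "Rank-": [], "Background": []}
    for cluster, p in p_values.items():
        best = min(("Rank+", "Rank-"), key=lambda k: p[k])
        result[best if p[best] < alpha else "Background"].append(cluster)
    return result
```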

tirank.TrainPre.Pcluster(savePath, clusterColName, perm_n=1001)

Performs a permutation test to identify significantly enriched clusters.

This function tests whether the observed number of ‘Rank+’ or ‘Rank-’ labels within any given cluster (e.g., ‘leiden_clusters’) is significantly higher than expected by chance.

Parameters:
  • savePath (str) – The main project directory path.

  • clusterColName (str) – The column name in ‘spot_predict_score.csv’ that contains the cluster labels (e.g., ‘leiden_clusters’, ‘patho_class’).

  • perm_n (int, optional) – The number of permutations to run. Defaults to 1001.

Returns:

None
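How the perm_n shuffles are turned into p-values is not specified here. A common convention for a one-sided permutation test is sketched below. Treat it as an assumption rather than TiRank's exact formula. Under this convention, with perm_n=1001 the smallest attainable p-value is 1/1002.

```python
from typing import Sequence


def empirical_p_value(observed: int, permuted: Sequence[int]) -> float:
    """One-sided permutation p-value for 'observed count is unusually high'.

    Uses the (1 + #{permuted >= observed}) / (1 + n_permutations) convention,
    which never returns exactly zero.
    """
    if not permuted:
        raise ValueError("need at least one permuted count")
    exceed = sum(1 for value in permuted if value >= observed)
    return (1 + exceed) / (1 + len(permuted))
```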

tirank.TrainPre.IdenHub(savePath, cateCol1, cateCol2, min_spots)

Identifies “hubs” by combining two categorical cluster labels.

This function creates a new ‘combine_cluster’ column by merging two existing cluster labels (e.g., ‘patho_class’ and ‘leiden_clusters’). It also filters out any combined clusters that have fewer than min_spots members, labeling them as ‘NA’.

Parameters:
  • savePath (str) – The main project directory path.

  • cateCol1 (str) – The name of the first categorical column in ‘spot_predict_score.csv’.

  • cateCol2 (str) – The name of the second categorical column.

  • min_spots (int) – The minimum number of spots required to keep a combined cluster.

Returns:

None
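The combine-and-filter step maps naturally onto pandas string concatenation and value_counts. In this sketch the "_" separator and the function name are assumptions, and the DataFrame stands in for the contents of 'spot_predict_score.csv'.

```python
import pandas as pd


def combine_clusters(df: pd.DataFrame, col1: str, col2: str, min_spots: int) -> pd.Series:
    """Join two categorical columns; combinations with fewer than `min_spots` members become 'NA'."""
    combined = df[col1].astype(str) + "_" + df[col2].astype(str)  # '_' separator is an assumption
    sizes = combined.map(combined.value_counts())  # size of each row's combined group
    return combined.where(sizes >= min_spots, "NA").rename("combine_cluster")
```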

Functions

  • tirank.TrainPre.AssignPcluster – Assigns a final phenotype ('Rank+', 'Rank-', 'Background') to clusters.

  • tirank.TrainPre.IdenHub – Identifies "hubs" by combining two categorical cluster labels.

  • tirank.TrainPre.Pcluster – Performs a permutation test to identify significantly enriched clusters.

  • tirank.TrainPre.Predict – Performs inference using the best trained TiRank model.

  • tirank.TrainPre.Reject_With_GMM_Bio – Performs GMM-based rejection for Classification and Cox modes.

  • tirank.TrainPre.Reject_With_GMM_Reg – Performs GMM-based rejection for Regression mode.

  • tirank.TrainPre.Reject_With_StrictNumber – Performs rejection based on a strict percentile range.

  • tirank.TrainPre.Train_one_epoch – Performs a single training epoch for the TiRank multi-task model.

  • tirank.TrainPre.Validate_model – Performs a single validation step.

  • tirank.TrainPre.get_best_model – Loads the best performing model from the hyperparameter tuning.

  • tirank.TrainPre.objective – The objective function for Optuna hyperparameter optimization.

  • tirank.TrainPre.permute_once – Helper function for a single permutation test shuffle.

  • tirank.TrainPre.tune_hyperparameters – Runs the Optuna hyperparameter tuning study.