Tutorial 1: Spatial Transcriptomics (ST) + Cox Survival Analysis

This tutorial demonstrates TiRank survival analysis (Cox mode) using spatial transcriptomics as inference data and bulk survival labels as supervision.

Example resources

Example datasets are hosted on Zenodo; download them and arrange the files as shown below.

Recommended placement:

TiRank/data/ExampleData/CRC_ST_Prog/
├── GSE39582_clinical_os.csv
├── GSE39582_exp_os.csv
└── SN048_A121573_Rep1/        (ST folder)

Run the example script (Python)

From the repository root:

python Example/ST-Cox-CRC.py

Notes

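  • If your local data paths differ, edit the savePath (section 1.1), dataPath (section 1.2), and pretrain_path (section 2.3) variables in the example script. The image-processing step also requires the pretrained CTransPath weights at data/pretrainModel/ctranspath.pth.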

  • For a fully automated run with environment management, see the Snakemake Workflow page.

Example script (for reference)

  1# tirank Analysis Pipeline Example
  2# This script demonstrates how to use the tirank library to integrate spatial transcriptomics (ST)
  3# and bulk transcriptomics data to identify phenotype-associated spots and determine significant clusters.
  4
  5# Import necessary libraries and modules
  6import warnings
  7warnings.filterwarnings("ignore")
  8
  9import torch
 10import pickle
 11import os
 12
 13from tirank.Model import setup_seed, initial_model_para
 14from tirank.LoadData import *
 15from tirank.SCSTpreprocess import *
 16from tirank.Imageprocessing import GetPathoClass
 17from tirank.GPextractor import GenePairExtractor
 18from tirank.Dataloader import generate_val, PackData
 19from tirank.TrainPre import tune_hyperparameters, Predict, Pcluster, IdenHub
 20from tirank.Visualization import plot_score_distribution, DEG_analysis, DEG_volcano, Pathway_Enrichment
 21from tirank.Visualization import plot_score_umap, plot_label_distribution_among_conditions,plot_STmap
 22
 23# Set random seed for reproducibility
 24setup_seed(619)
 25
 26# --------------------------------------------
 27# 1. Load Data
 28# --------------------------------------------
 29
 30## 1.1 Select a path to save the results
 31savePath = "./ST_Survival_CRC"  # Main directory for saving results
 32savePath_1 = os.path.join(savePath, "1_loaddata")
 33if not os.path.exists(savePath_1):
 34    os.makedirs(savePath_1, exist_ok=True)
 35
 36## 1.2 Load clinical data
 37dataPath = "./data/ExampleData/CRC_ST_Prog/"  # Directory containing your data
 38path_to_bulk_cli = os.path.join(dataPath, "GSE39582_clinical_os.csv")
 39bulkClinical = load_bulk_clinical(path_to_bulk_cli)
 40view_dataframe(bulkClinical)  # Optional: view the clinical data DataFrame
 41
 42## 1.3 Load bulk expression profile
 43path_to_bulk_exp = os.path.join(dataPath, "GSE39582_exp_os.csv")
 44bulkExp = load_bulk_exp(path_to_bulk_exp)
 45view_dataframe(bulkExp)  # Optional: view the bulk expression DataFrame
 46
 47## 1.4 Check consistency between bulk expression and clinical data
 48check_bulk(savePath, bulkExp, bulkClinical)
 49
 50## 1.5 Load spatial transcriptomics (ST) data
 51path_to_st_folder = os.path.join(dataPath, "SN048_A121573_Rep1")
 52scAnndata = load_st_data(path_to_st_folder, savePath)
 53st_exp_df = transfer_exp_profile(scAnndata)
 54view_dataframe(st_exp_df)  # Optional: view the ST expression DataFrame
 55
 56# --------------------------------------------
 57# 2. Preprocessing
 58# --------------------------------------------
 59
 60## 2.1 Select a path to save preprocessing results
 61savePath_2 = os.path.join(savePath, "2_preprocessing")
 62if not os.path.exists(savePath_2):
 63    os.makedirs(savePath_2, exist_ok=True)
 64
 65## 2.2 Load the saved AnnData object from step 1
 66with open(os.path.join(savePath_1, "anndata.pkl"), "rb") as f:
 67    scAnndata = pickle.load(f)
 68
 69## 2.3 Preprocess the ST data
 70# Define the inference mode (e.g., "ST" for spatial transcriptomics)
 71infer_mode = "ST"  # Optional parameter
 72
 73# Filtering the data based on counts and mitochondrial gene proportion
 74scAnndata = FilteringAnndata(
 75    scAnndata,
 76    max_count=35000,    # Maximum total counts per cell
 77    min_count=5000,     # Minimum total counts per cell
 78    MT_propor=10,       # Maximum percentage of mitochondrial genes
 79    min_cell=10,        # Minimum number of cells expressing the gene
 80    imgPath=savePath_2  # Path to save images/results
 81)
 82# Optional parameters: max_count, min_count, MT_propor, min_cell
 83
 84# Normalize the data
 85scAnndata = Normalization(scAnndata)
 86
 87# Log-transform the data
 88scAnndata = Logtransformation(scAnndata)
 89
 90# Perform clustering on the data
 91scAnndata = Clustering(scAnndata, infer_mode=infer_mode, savePath=savePath)
 92
 93# Compute similarity matrix (optional distance calculation)
 94compute_similarity(
 95    savePath=savePath,
 96    ann_data=scAnndata,
 97    calculate_distance=False  # Set to True if distance calculation is needed
 98)
 99
100# Path to the pre-trained image processing model (ensure this file is in the package)
101# Note: Ensure you have downloaded ctranspath.pth into data/pretrainModel/
102pretrain_path = "./data/pretrainModel/ctranspath.pth"
103
104# Number of pathological clusters to identify
105n_patho_cluster = 7  # Optional variable (adjust based on your data)
106
107# Perform image processing to get pathological classifications
108scAnndata = GetPathoClass(
109    adata=scAnndata,
110    pretrain_path=pretrain_path,
111    n_clusters=n_patho_cluster,
112    image_save_path=os.path.join(savePath_2, "patho_label.png")
113    # Advanced parameters: n_components (PCA components), n_clusters
114)
115
116# Save the processed AnnData object
117with open(os.path.join(savePath_2, "scAnndata.pkl"), "wb") as f:
118    pickle.dump(scAnndata, f)
119
120## 2.4 Clinical data preparation and splitting bulk data
121# Define the analysis mode (e.g., "Cox" for survival analysis)
122mode = "Cox"
123
124# Split data into training and validation sets
125generate_val(
126    savePath=savePath,
127    validation_proportion=0.15,  # Optional parameter: proportion of data for validation
128    mode=mode
129)
130
131## 2.5 Gene pair transformation
132# Initialize the GenePairExtractor with parameters
133GPextractor = GenePairExtractor(
134    savePath=savePath,
135    analysis_mode=mode,
136    top_var_genes=2000,       # Optional: number of top variable genes to select
137    top_gene_pairs=1000,      # Optional: number of top gene pairs to select
138    p_value_threshold=0.05,   # Optional: p-value threshold for gene pair selection
139    max_cutoff=0.8,           # Optional: upper cutoff for correlation coefficient
140    min_cutoff=-0.8           # Optional: lower cutoff for correlation coefficient
141)
142
143# Load data for gene pair extraction
144GPextractor.load_data()
145
146# Run the gene pair extraction process
147GPextractor.run_extraction()
148
149# Save the extracted gene pairs
150GPextractor.save_data()
151
152# --------------------------------------------
153# 3. Analysis
154# --------------------------------------------
155
156## 3.1 tirank Analysis
157# Define paths for saving analysis results
158savePath_3 = os.path.join(savePath, "3_Analysis")
159if not os.path.exists(savePath_3):
160    os.makedirs(savePath_3, exist_ok=True)
161
162### 3.1.1 Data Loading and Preparation
163# Ensure the 'mode' variable is consistent throughout the analysis
164mode = "Cox"          # Analysis mode (e.g., "Cox" for survival analysis)
165infer_mode = "ST"     # Inference mode (e.g., "ST" for spatial transcriptomics)
166device = "cuda" if torch.cuda.is_available() else "cpu"  # Use GPU if available
167
168# Pack the data into DataLoader objects for training and validation
169PackData(
170    savePath=savePath,
171    mode=mode,
172    infer_mode=infer_mode,
173    batch_size=1024   # Optional parameter: batch size for DataLoader
174)
175
176### 3.1.2 Model Training
177# Set the encoder type for the model (e.g., "MLP" for multi-layer perceptron)
178encoder_type = "MLP"  # Optional parameter (options: "MLP", "Transformer", etc.)
179
180# Initialize model parameters
181initial_model_para(
182    savePath=savePath,
183    nhead=2,           # Optional: number of heads in multi-head attention (if using Transformer)
184    nhid1=96,          # Optional: hidden layer size 1
185    nhid2=8,           # Optional: hidden layer size 2
186    n_output=32,       # Optional: output size
187    nlayers=3,         # Optional: number of layers
188    n_pred=1,          # Optional: number of predictions (e.g., 1 for regression)
189    dropout=0.5,       # Optional: dropout rate
190    mode=mode,
191    encoder_type=encoder_type,
192    infer_mode=infer_mode
193)
194
195# Tune hyperparameters using Optuna or other optimization libraries
196tune_hyperparameters(
197    savePath=savePath,
198    device=device,
199    n_trials=5    # Optional parameter: number of hyperparameter tuning trials
200)
201
202### 3.1.3 Model Inference
203# Predict phenotype-associated spots and perform rejection (uncertainty estimation)
204Predict(
205    savePath=savePath,
206    mode=mode,
207    do_reject=True,        # Optional: whether to perform rejection
208    tolerance=0.05,        # Optional: tolerance level for rejection
209    reject_mode="GMM"      # Optional: rejection mode (e.g., "GMM" for Gaussian Mixture Model)
210)
211
212### 3.1.4 Identify Hubs and Significant Clusters
213# Identify hub spots based on categorical columns
214IdenHub(
215    savePath=savePath,
216    cateCol1="patho_class",        # First categorical column (e.g., pathological class)
217    cateCol2="leiden_clusters",    # Second categorical column (e.g., clustering result)
218    min_spots=10                   # Optional: minimum number of spots to consider a hub
219)
220
221# Perform permutation tests to identify significant clusters
222Pcluster(savePath=savePath, clusterColName="patho_class", perm_n=1001)
223Pcluster(savePath=savePath, clusterColName="leiden_clusters", perm_n=1001)
224Pcluster(savePath=savePath, clusterColName="combine_cluster", perm_n=1001)
225
226### 3.1.5 Visualization
227# Plot the distribution of prediction scores
228plot_score_distribution(savePath)  # Displays the probability score distribution
229
230# Plot UMAP embedding colored by prediction scores
231plot_score_umap(savePath, infer_mode)
232
233# Plot the distribution of labels among different conditions
234plot_label_distribution_among_conditions(savePath, group="patho_class")
235plot_label_distribution_among_conditions(savePath, group="leiden_clusters")
236plot_label_distribution_among_conditions(savePath, group="combine_cluster")
237
238# Plot spatial maps of the spots with cluster labels
239plot_STmap(savePath=savePath, group="combine_cluster")
240
241## 3.2 Differential Expression and Pathway Enrichment Analysis
242# Set thresholds for differential expression analysis
243fc_threshold = 2          # Optional: fold-change threshold
244Pvalue_threshold = 0.05   # Optional: p-value threshold
245do_p_adjust = True        # Optional: whether to adjust p-values for multiple testing
246
247# Perform differential expression analysis
248DEG_analysis(
249    savePath=savePath,
250    fc_threshold=fc_threshold,
251    Pvalue_threshold=Pvalue_threshold,
252    do_p_adjust=do_p_adjust
253)
254
255# Plot volcano plots for differential expression results
256DEG_volcano(
257    savePath=savePath,
258    fc_threshold=fc_threshold,
259    Pvalue_threshold=Pvalue_threshold,
260    do_p_adjust=do_p_adjust
261)
262
263# Perform pathway enrichment analysis using specified databases
264# Available databases can be found at: https://maayanlab.cloud/Enrichr/#libraries
265Pathway_Enrichment(
266    savePath=savePath,
267    database=["GO_Biological_Process_2023"]  # Optional: replace with desired databases
268)