# Source code for tirank.GPextractor
# Gene-pairs (GP) extractor
import numpy as np
import pandas as pd
import os, pickle
from lifelines import CoxPHFitter
from scipy.stats import pearsonr, ttest_ind
from statsmodels.stats.multitest import multipletests
from .Dataloader import transform_test_exp
from .Visualization import plot_genepair
"""
Gene Pair Extractor module for TiRank.
This module contains the GenePairExtractor class, which is the core component
for implementing the Relative Expression Ordering (REO) transformation.
It identifies phenotype-associated genes from bulk data, forms gene pairs,
filters them, and then applies this transformation to both bulk and
single-cell/spatial data.
"""
class GenePairExtractor:
    """
    A class to extract and filter phenotype-associated gene pairs (PGPs).

    This class loads bulk and sc/st expression data, identifies genes associated
    with a clinical phenotype (via t-test, Cox regression, or Pearson correlation),
    creates all possible pairs between positive and negative-associated genes,
    filters these pairs based on co-occurrence and variance, and finally
    transforms both bulk and sc/st datasets into gene pair matrices.

    Args:
        savePath (str): The main project directory path.
        analysis_mode (str): The analysis mode ('Classification', 'Cox', 'Regression').
        top_var_genes (int, optional): The number of top variable genes to
            pre-filter from the sc/st data. Defaults to 500.
        top_gene_pairs (int, optional): The number of top variable gene pairs
            to select after filtering. Defaults to 2000.
        p_value_threshold (float, optional): P-value threshold for selecting
            phenotype-associated genes. Defaults to None; it must be set to a
            number before any gene-selection method runs, otherwise those
            methods raise ValueError.
        max_cutoff (float, optional): Upper bound (exclusive) on the proportion
            of bulk samples in which the first gene of a pair is the
            higher-expressed one. Pairs at or above it are removed. Defaults to 0.8.
        min_cutoff (float, optional): Lower bound (exclusive) on that same
            proportion. Pairs at or below it are removed. Defaults to 0.2.
    """

    # Column name used for the gene covariate in the per-gene Cox fits. A fixed
    # internal name cannot clash with the clinical duration/event column names.
    _GENE_COVARIATE = "_tirank_gene_expression"

    def __init__(
        self,
        savePath,
        analysis_mode,
        top_var_genes=500,
        top_gene_pairs=2000,
        p_value_threshold=None,
        max_cutoff=0.8,
        min_cutoff=0.2,
    ):
        self.savePath = savePath
        self.analysis_mode = analysis_mode
        self.top_var_genes = top_var_genes
        self.top_gene_pairs = top_gene_pairs
        self.p_value_threshold = p_value_threshold
        self.max_cutoff = max_cutoff
        self.min_cutoff = min_cutoff

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at ``path``."""
        # NOTE: pickle can execute arbitrary code while loading. Only use this on
        # files written by this pipeline, never on untrusted input.
        with open(path, "rb") as f:
            return pickle.load(f)

    def load_data(self):
        """
        Loads the required expression and clinical data from disk.

        Reads 'bulkExp_train.pkl' and 'bulkClinical_train.pkl' from
        '<savePath>/2_preprocessing/split_data', and 'scAnndata.pkl' from
        '<savePath>/2_preprocessing'. Stores them as ``bulk_expression``,
        ``clinical_data`` and ``single_cell_expression``. The last one is a
        dense DataFrame indexed by ``var_names`` with ``obs`` names as columns.

        Returns:
            None
        """
        print("Starting load data for gene pair transformation.")
        savePath_2 = os.path.join(self.savePath, "2_preprocessing")
        savePath_splitData = os.path.join(savePath_2, "split_data")

        ## bulk
        self.bulk_expression = self._load_pickle(
            os.path.join(savePath_splitData, "bulkExp_train.pkl")
        )
        self.clinical_data = self._load_pickle(
            os.path.join(savePath_splitData, "bulkClinical_train.pkl")
        )

        ## sc
        scAnndata = self._load_pickle(os.path.join(savePath_2, "scAnndata.pkl"))
        X = scAnndata.X
        # Sparse matrices expose .toarray(). Dense inputs, including np.matrix
        # (which has no .toarray()), go through np.asarray.
        dense_X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
        # X is obs x var; transpose so genes (var_names) are rows and obs are columns.
        self.single_cell_expression = pd.DataFrame(
            dense_X.T, index=scAnndata.var_names, columns=scAnndata.obs.index
        )
        return None

    def save_data(self):
        """
        Saves the generated gene pair matrices to disk.

        Writes 'train_bulk_gene_pairs_mat.pkl', 'val_bulkExp_gene_pairs_mat.pkl'
        and 'sc_gene_pairs_mat.pkl' to the '2_preprocessing' directory. Each
        matrix is transposed to samples/cells x gene pairs before saving. The
        validation matrix is produced by ``transform_test_exp`` from the
        validation bulk expression ('split_data/bulkExp_val.pkl') and the
        training gene pair matrix.

        Returns:
            None
        """
        print("Starting save gene pair matrices.")
        savePath_2 = os.path.join(self.savePath, "2_preprocessing")
        savePath_splitData = os.path.join(savePath_2, "split_data")

        ## Load val bulk
        bulkExp_val = self._load_pickle(
            os.path.join(savePath_splitData, "bulkExp_val.pkl")
        )

        train_bulk_gene_pairs_mat = pd.DataFrame(self.bulk_gene_pairs_mat.T)
        val_bulkExp_gene_pairs_mat = transform_test_exp(
            train_exp=train_bulk_gene_pairs_mat, test_exp=bulkExp_val
        )
        sc_gene_pairs_mat = pd.DataFrame(self.single_cell_gene_pairs_mat.T)

        # Written in this order: training bulk, validation bulk, single cell.
        outputs = {
            "train_bulk_gene_pairs_mat.pkl": train_bulk_gene_pairs_mat,
            "val_bulkExp_gene_pairs_mat.pkl": val_bulkExp_gene_pairs_mat,
            "sc_gene_pairs_mat.pkl": sc_gene_pairs_mat,
        }
        for file_name, matrix in outputs.items():
            with open(os.path.join(savePath_2, file_name), "wb") as f:
                pickle.dump(matrix, f)

        print("Save gene pair matrices done.")
        return None

    def run_extraction(self):
        """
        Main orchestration function to run the full gene pair extraction pipeline.

        Steps:
            1. Finds intersecting genes between bulk and sc/st data.
            2. Selects the top variable genes from the sc/st data.
            3. Subsets all expression data to these genes.
            4. Identifies phenotype-associated gene sets (e.g., risk/protective)
               based on ``analysis_mode``.
            5. Transforms the bulk expression data into a gene pair matrix.
            6. Filters the gene pair matrix by co-occurrence and variance.
            7. Transforms the sc/st expression data using the filtered gene pairs.
            8. Stores the final matrices as attributes and plots them.

        Returns:
            None

        Raises:
            ValueError: If ``analysis_mode`` is unsupported, if either
                phenotype-associated gene set is empty, or if no gene pair
                survives filtering.
        """
        print("Starting gene pair extraction.")

        # Rank genes present in both datasets by their variance in the sc/st data.
        intersect_genes = np.intersect1d(
            self.single_cell_expression.index, self.bulk_expression.index
        )
        gene_variances = np.var(
            self.single_cell_expression.loc[intersect_genes], axis=1
        )
        sorted_genes = gene_variances.sort_values(ascending=False)
        top_variable_genes = sorted_genes.iloc[: self.top_var_genes].index.tolist()

        self.bulk_expression, self.single_cell_expression = (
            self.extract_candidate_genes(top_variable_genes)
        )
        print("Get candidate genes done.")

        # Obtain the two phenotype-associated gene sets for the requested mode.
        if self.analysis_mode == "Classification":
            regulated_genes_r, regulated_genes_p = self.calculate_binomial_gene_pairs()
            print(
                f"There are {len(regulated_genes_r)} genes up-regulated in Group 0 and {len(regulated_genes_p)} genes up-regulated in Group 1."
            )
        elif self.analysis_mode == "Cox":
            regulated_genes_r, regulated_genes_p = self.calculate_survival_gene_pairs()
            print(
                f"There are {len(regulated_genes_r)} Risk genes and {len(regulated_genes_p)} Protective genes."
            )
        elif self.analysis_mode == "Regression":
            regulated_genes_r, regulated_genes_p = (
                self.calculate_regression_gene_pairs()
            )
            print(
                f"There are {len(regulated_genes_r)} positive-associated genes and {len(regulated_genes_p)} negative-associated genes."
            )
        else:
            raise ValueError(f"Unsupported mode: {self.analysis_mode}")

        if len(regulated_genes_r) == 0 or len(regulated_genes_p) == 0:
            raise ValueError(
                "A set of genes is empty. Try increasing the 'top_var_genes' value or loosening the 'p_value_threshold' value."
            )
        print("Get candidate gene pairs done.")

        bulk_gene_pairs = self.transform_bulk_gene_pairs(
            regulated_genes_r, regulated_genes_p
        )
        bulk_gene_pairs_mat = self.filter_gene_pairs(bulk_gene_pairs)
        if bulk_gene_pairs_mat.shape[0] == 0:
            raise ValueError(
                "No gene pairs passed the 'min_cutoff'/'max_cutoff' filter. Try widening the cut-offs or increasing 'top_var_genes'."
            )
        single_cell_gene_pairs_mat = self.transform_single_cell_gene_pairs(
            bulk_gene_pairs_mat
        )
        print("Profile transformation done.")

        self.bulk_gene_pairs_mat = bulk_gene_pairs_mat
        self.single_cell_gene_pairs_mat = single_cell_gene_pairs_mat

        # Visualize the gene pairs
        plot_genepair(self.bulk_gene_pairs_mat, "bulk", self.savePath)
        plot_genepair(self.single_cell_gene_pairs_mat, "sc", self.savePath)
        return None

    def extract_candidate_genes(self, gene_names):
        """
        Subsets the expression matrices to a list of candidate genes.

        Genes that are zero in every bulk sample are dropped from both matrices.

        Args:
            gene_names (list): A list of gene names to keep.

        Returns:
            tuple: A tuple containing:
                - pd.DataFrame: The subsetted bulk expression matrix.
                - pd.DataFrame: The subsetted single-cell expression matrix.
        """
        bulk_gene_subset = self.bulk_expression.loc[gene_names, :]
        # Remove genes whose bulk expression is 0 in every sample.
        bulk_gene_subset = bulk_gene_subset.loc[(bulk_gene_subset != 0).any(axis=1)]
        kept_genes = bulk_gene_subset.index.tolist()
        single_cell_gene_subset = self.single_cell_expression.loc[kept_genes]
        return bulk_gene_subset, single_cell_gene_subset

    def _require_p_value_threshold(self):
        """Return ``p_value_threshold``, raising ValueError if it was never set."""
        if self.p_value_threshold is None:
            raise ValueError(
                "p_value_threshold is None; set it to a number (e.g. 0.05) before selecting phenotype-associated genes."
            )
        return self.p_value_threshold

    def calculate_binomial_gene_pairs(self):
        """
        Finds phenotype-associated genes for 'Classification' mode.

        Runs a two-sample t-test per gene, comparing the samples labelled 0
        with those labelled 1 in the first clinical column. Genes whose test
        yields NaN are discarded. Raw (unadjusted) p-values are compared
        against ``p_value_threshold``.

        Returns:
            tuple: A tuple containing:
                - list: Genes up-regulated in group 0 (t-stat > 0).
                - list: Genes up-regulated in group 1 (t-stat < 0).

        Raises:
            ValueError: If ``p_value_threshold`` is None.
        """
        p_threshold = self._require_p_value_threshold()

        group_labels = self.clinical_data.iloc[:, 0]
        group_0 = self.bulk_expression.loc[:, group_labels == 0]
        group_1 = self.bulk_expression.loc[:, group_labels == 1]

        # One t-test per gene (row), computed in a single vectorized call.
        t_stats, p_values = ttest_ind(
            group_0.to_numpy(dtype=float), group_1.to_numpy(dtype=float), axis=1
        )

        DEGs = pd.DataFrame(
            {
                "AveExpr": self.bulk_expression.mean(axis=1),
                "t": t_stats,
                "P.Value": p_values,
                "gene": self.bulk_expression.index,
            }
        )
        # Drop genes whose statistics are undefined (e.g. constant expression).
        DEGs = DEGs.dropna()
        DEGs = DEGs[DEGs["P.Value"] < p_threshold]

        regulated_genes_in_g0 = DEGs.loc[DEGs["t"] > 0, "gene"].tolist()
        regulated_genes_in_g1 = DEGs.loc[DEGs["t"] < 0, "gene"].tolist()
        return regulated_genes_in_g0, regulated_genes_in_g1

    def calculate_survival_gene_pairs(self):
        """
        Finds phenotype-associated genes for 'Cox' survival mode.

        Fits a univariate Cox proportional hazards model (lifelines
        ``CoxPHFitter``) per gene. The first clinical column is the duration
        and the second is the event indicator. Other clinical columns are not
        used as covariates. Genes whose model cannot be fitted are skipped and
        counted. Raw (unadjusted) p-values are compared against
        ``p_value_threshold``.

        Returns:
            tuple: A tuple containing:
                - list: Risk genes (Hazard Ratio > 1).
                - list: Protective genes (Hazard Ratio < 1).

        Raises:
            ValueError: If ``p_value_threshold`` is None.
        """
        p_threshold = self._require_p_value_threshold()

        duration_col = self.clinical_data.columns[0]
        event_col = self.clinical_data.columns[1]
        # Only duration and event, so each fit is univariate in the gene.
        survival_data = self.clinical_data.iloc[:, :2]

        records = []
        n_failed = 0
        for gene, expression in self.bulk_expression.iterrows():
            clinical_temp = pd.concat(
                [survival_data, expression.astype(float).rename(self._GENE_COVARIATE)],
                axis=1,
            )
            cph = CoxPHFitter()
            try:
                cph.fit(clinical_temp, duration_col=duration_col, event_col=event_col)
            except Exception:
                # Best effort: a gene whose model cannot be fitted is skipped.
                n_failed += 1
                continue
            gene_summary = cph.summary.loc[self._GENE_COVARIATE]
            records.append(
                {
                    "gene": gene,
                    "HR": gene_summary["exp(coef)"],
                    "p_value": gene_summary["p"],
                }
            )
        if n_failed:
            print(f"Cox model fitting failed for {n_failed} gene(s); they were skipped.")

        # Build the results table once instead of appending row by row.
        survival_results = pd.DataFrame(records, columns=["gene", "HR", "p_value"])
        survival_results = survival_results.dropna()
        survival_results["HR"] = survival_results["HR"].astype(float)
        survival_results["p_value"] = survival_results["p_value"].astype(float)
        survival_results = survival_results[survival_results["p_value"] < p_threshold]

        regulated_genes_r = survival_results.loc[survival_results["HR"] > 1, "gene"].tolist()
        regulated_genes_p = survival_results.loc[survival_results["HR"] < 1, "gene"].tolist()
        return regulated_genes_r, regulated_genes_p

    def calculate_regression_gene_pairs(self):
        """
        Finds phenotype-associated genes for 'Regression' mode.

        Computes a Pearson correlation between each gene and the first clinical
        column. The two are paired positionally, so the clinical rows are
        assumed to follow the bulk sample (column) order. Genes with undefined
        correlation are discarded. Raw (unadjusted) p-values are compared
        against ``p_value_threshold``.

        Returns:
            tuple: A tuple containing:
                - list: Positively correlated genes.
                - list: Negatively correlated genes.

        Raises:
            ValueError: If ``p_value_threshold`` is None.
        """
        p_threshold = self._require_p_value_threshold()

        outcome = self.clinical_data.iloc[:, 0]
        records = []
        for gene, expression in self.bulk_expression.iterrows():
            correlation, pvalue = pearsonr(expression.astype(float), outcome)
            records.append(
                {"gene": gene, "correlation": float(correlation), "pvalue": float(pvalue)}
            )

        # Build the results table once instead of concatenating row by row.
        correlation_results = pd.DataFrame(
            records, columns=["gene", "correlation", "pvalue"]
        )
        correlation_results = correlation_results.dropna()
        correlation_results["correlation"] = correlation_results["correlation"].astype(float)
        correlation_results["pvalue"] = correlation_results["pvalue"].astype(float)
        correlation_results = correlation_results[correlation_results["pvalue"] < p_threshold]

        positive_correlation_genes = correlation_results.loc[
            correlation_results["correlation"] > 0, "gene"
        ].tolist()
        negative_correlation_genes = correlation_results.loc[
            correlation_results["correlation"] < 0, "gene"
        ].tolist()
        return positive_correlation_genes, negative_correlation_genes

    def transform_bulk_gene_pairs(self, genes_r, genes_p):
        """
        Transforms the bulk expression matrix into a gene pair matrix (REO).

        Creates all pairs ``r__p`` with r from ``genes_r`` and p from
        ``genes_p``. Rows are ordered with ``genes_r`` as the outer loop. An
        entry is 1 if gene r is strictly higher than gene p in that sample,
        else -1 (ties give -1).

        Args:
            genes_r (list): The genes for the "positive" set (e.g., risk genes).
            genes_p (list): The genes for the "negative" set (e.g., protective genes).

        Returns:
            pd.DataFrame: The transformed bulk gene pair matrix (gene pairs x samples).
        """
        genes_r = list(genes_r)
        genes_p = list(genes_p)
        exp_r = self.bulk_expression.loc[genes_r].to_numpy()
        exp_p = self.bulk_expression.loc[genes_p].to_numpy()

        # (n_r, 1, S) vs (1, n_p, S) broadcasts to (n_r, n_p, S). Flattening the
        # first two axes gives row index i * n_p + j, matching row_names below.
        comparison = exp_r[:, None, :] > exp_p[None, :, :]
        result_values = np.where(comparison, 1, -1).reshape(-1, exp_r.shape[1])

        row_names = [f"{r}__{p}" for r in genes_r for p in genes_p]
        return pd.DataFrame(
            result_values, index=row_names, columns=self.bulk_expression.columns
        )

    def filter_gene_pairs(self, bulk_GPMat):
        """
        Filters the bulk gene pair matrix based on co-occurrence and variance.

        A pair is kept when the proportion of samples in which its first gene
        is the higher-expressed one lies strictly between ``min_cutoff`` and
        ``max_cutoff``. If at least ``top_gene_pairs`` pairs remain, only the
        ``top_gene_pairs`` pairs with the largest variance across samples are
        kept.

        Args:
            bulk_GPMat (pd.DataFrame): The raw bulk gene pair matrix
                (+1/-1 entries, gene pairs x samples).

        Returns:
            pd.DataFrame: The filtered bulk gene pair matrix.
        """
        # Count the +1 entries directly. The row *sum* of a +1/-1 matrix equals
        # 2 * count - n_samples, so comparing it with cutoff * n_samples does not
        # filter on a proportion. ``> 0`` also works for a 1/0 encoding.
        higher_fraction = (bulk_GPMat > 0).mean(axis=1)
        keep = (higher_fraction < self.max_cutoff) & (higher_fraction > self.min_cutoff)
        bulk_GPMat = bulk_GPMat[keep]

        if bulk_GPMat.shape[0] >= self.top_gene_pairs:
            gene_pair_variances = np.var(bulk_GPMat, axis=1)
            sorted_gene_pairs = gene_pair_variances.sort_values(ascending=False)
            top_var_gene_pairs = sorted_gene_pairs.iloc[: self.top_gene_pairs].index.tolist()
            bulk_GPMat = bulk_GPMat.loc[top_var_gene_pairs]
        return bulk_GPMat

    def transform_single_cell_gene_pairs(self, bulk_GPMat):
        """
        Transforms the sc/st expression matrix into a gene pair matrix.

        Uses exactly the gene pairs in the index of ``bulk_GPMat``. An entry is
        1 if the first gene is strictly higher than the second in that cell,
        else -1.

        Args:
            bulk_GPMat (pd.DataFrame): The filtered bulk gene pair matrix.
                Its index (``"GENE1__GENE2"`` strings) defines the pairs to use.

        Returns:
            pd.DataFrame: The transformed sc/st gene pair matrix (gene pairs x cells).
        """
        genes_1, genes_2 = self.split_gene_pairs(bulk_GPMat.index.tolist())
        exp1 = self.single_cell_expression.loc[genes_1].to_numpy()
        exp2 = self.single_cell_expression.loc[genes_2].to_numpy()
        result = np.where(exp1 > exp2, 1, -1)
        return pd.DataFrame(
            result, index=bulk_GPMat.index, columns=self.single_cell_expression.columns
        )

    def split_gene_pairs(self, gene_pairs):
        """
        Helper function to split gene pair names.

        Args:
            gene_pairs (list): A list of gene pair strings (e.g., "GENE1__GENE2").

        Returns:
            tuple: A tuple containing:
                - list: The list of first genes (e.g., "GENE1").
                - list: The list of second genes (e.g., "GENE2").
        """
        parts = [pair.split("__") for pair in gene_pairs]
        gene1 = [p[0] for p in parts]
        gene2 = [p[1] for p in parts]
        return gene1, gene2