# Preprocessing functions for scRNA-seq, spatial transcriptomics (ST), and bulk RNA-seq data
import scanpy as sc
import pandas as pd
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
from scipy.stats import zscore
# Resampling utilities for class-imbalanced training data
from imblearn.over_sampling import SMOTE, RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler, TomekLinks
"""
Preprocessing utilities for scRNA-seq and ST data using Scanpy.
This module provides a pipeline of functions to load, filter, normalize,
log-transform, and cluster single-cell and spatial data. It also includes
functions for handling class imbalance in bulk training data (sampling) and
for computing similarity/distance matrices.
"""
[docs]
def merge_datasets(bulkClinical_1, bulkClinical_2, bulkExp_1, bulkExp_2):
    """
    Merges two bulk expression and clinical datasets on their shared genes.

    Expression matrices are restricted to the genes present in both cohorts
    and concatenated sample-wise. Clinical tables are stacked row-wise
    (positionally; the column labels of the first cohort are used) and
    re-indexed with the sample IDs taken from the expression columns, so the
    clinical rows are expected to be in the same order as the expression
    columns of their cohort.

    Args:
        bulkClinical_1 (pd.DataFrame): Clinical data for the first cohort.
        bulkClinical_2 (pd.DataFrame): Clinical data for the second cohort.
        bulkExp_1 (pd.DataFrame): Expression data for the first cohort (genes x samples).
        bulkExp_2 (pd.DataFrame): Expression data for the second cohort (genes x samples).

    Returns:
        tuple: A tuple containing:
            - pd.DataFrame: The merged expression DataFrame (shared genes x
              all samples). Genes keep the order in which they appear in
              ``bulkExp_1``.
            - pd.DataFrame: The merged clinical DataFrame (all samples x
              clinical variables), with per-column dtypes preserved.
        Or returns 0 if no intersecting genes are found.

    Raises:
        ValueError: If the two clinical tables have a different number of
            columns, or if the number of clinical rows does not match the
            number of expression samples.
    """
    genes_in_second = set(bulkExp_2.index)
    # Walk the first cohort's (de-duplicated) index instead of iterating a
    # set: set iteration order over strings changes between interpreter
    # runs, which made the merged row order non-reproducible.
    shared_genes = [
        gene for gene in dict.fromkeys(bulkExp_1.index) if gene in genes_in_second
    ]
    if not shared_genes:
        print("The length of interaction genes between these two bulk RNA-seq datasets was zero!")
        return 0

    exp_1 = bulkExp_1.loc[shared_genes, :]
    exp_2 = bulkExp_2.loc[shared_genes, :]
    sample_ids = list(exp_1.columns) + list(exp_2.columns)

    # Both blocks now have identical row order, so a positional hstack is safe.
    merged_exp = pd.DataFrame(
        np.hstack((exp_1, exp_2)), columns=sample_ids, index=exp_1.index)

    # Stack the clinical tables with pandas rather than np.vstack so that
    # mixed column types (e.g. numeric survival time + string labels) are not
    # collapsed into a single object dtype. Columns are matched by position,
    # as before; work on a copy so the caller's frame is not relabelled.
    clinical_2 = pd.DataFrame(bulkClinical_2).copy()
    clinical_2.columns = bulkClinical_1.columns
    merged_clinical = pd.concat([bulkClinical_1, clinical_2], axis=0)
    merged_clinical.index = sample_ids

    return merged_exp, merged_clinical
[docs]
def normalize_data(exp):
    """
    Normalize gene expression data using z-score normalization (row-wise).

    Each gene (row) is centred and scaled across samples. Rows whose z-scores
    are undefined (for example genes with zero variance, or rows containing
    missing values) come out as NaN and are dropped.

    Args:
        exp (pd.DataFrame): A pandas DataFrame with genes as rows and samples
            as columns. Values must be convertible to float.

    Returns:
        pd.DataFrame: A z-score normalized DataFrame with the same columns as
        ``exp`` and the subset of rows for which the z-score is defined.
    """
    # Compute all rows in one vectorized call and rebuild the frame
    # explicitly. `exp.apply(zscore, axis=1)` hands pandas one ndarray per
    # row, which it wraps as a Series of arrays instead of a DataFrame, so the
    # subsequent dropna() could not remove the undefined rows.
    scores = zscore(exp.to_numpy(dtype=float), axis=1)
    normalized_exp = pd.DataFrame(scores, index=exp.index, columns=exp.columns)
    return normalized_exp.dropna()
[docs]
def is_imbalanced(bulkClinical, threshold):
    """
    Report whether the rarest class of the primary clinical variable is
    under-represented.

    Args:
        bulkClinical (pd.DataFrame): Clinical table; the variable of interest
            is read from its first column.
        threshold (float): Smallest class proportion that still counts as
            'balanced'.

    Returns:
        bool: True when the smallest class proportion is strictly below
        ``threshold``, False otherwise.
    """
    primary_variable = bulkClinical.iloc[:, 0]
    # Relative frequency of every distinct value in the first column.
    class_fractions = primary_variable.value_counts(normalize=True)
    rarest_fraction = class_fractions.min()
    return rarest_fraction < threshold
# Perform standard workflow on ST or SC
# Filtering cells or spots
[docs]
def FilteringAnndata(adata, max_count=35000, min_count=5000, MT_propor=10, min_cell=10, imgPath="./"):
    """
    Filters an AnnData object based on QC metrics.

    Computes QC metrics, saves a violin plot of them as ``qc_violins.png``,
    then filters cells/spots on total counts and mitochondrial percentage and
    filters genes on the number of cells/spots expressing them.

    Note:
        The QC annotation and the total-count filters are applied to the
        passed-in ``adata`` in place; the returned object is a new, further
        filtered copy. Mitochondrial genes are detected by the ``"MT-"``
        name prefix only.

    Args:
        adata (sc.AnnData): The AnnData object to filter.
        max_count (int, optional): Maximum total counts per cell/spot.
            Defaults to 35000.
        min_count (int, optional): Minimum total counts per cell/spot.
            Defaults to 5000.
        MT_propor (int, optional): Upper bound (exclusive) on the percentage
            of mitochondrial gene counts. Defaults to 10.
        min_cell (int, optional): Minimum number of cells/spots a gene must
            be expressed in. Defaults to 10.
        imgPath (str, optional): Directory in which to save the QC violin
            plot; created if it does not exist. Defaults to "./".

    Returns:
        sc.AnnData: The filtered AnnData object.
    """
    adata.var_names_make_unique()
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

    # Plot the QC distributions before anything is removed.
    if imgPath:
        # Make sure the target directory exists so savefig cannot fail on it.
        os.makedirs(imgPath, exist_ok=True)
    sc.pl.violin(
        adata,
        ['n_genes_by_counts', 'total_counts', 'pct_counts_mt'],
        jitter=0.4,
        multi_panel=True,
        show=False,
    )
    plt.savefig(os.path.join(imgPath, "qc_violins.png"))
    plt.close()

    # Filtering
    sc.pp.filter_cells(adata, min_counts=min_count)
    sc.pp.filter_cells(adata, max_counts=max_count)
    # Boolean subsetting yields a view; materialize it with .copy() so the
    # in-place gene filter below operates on a real object instead of
    # implicitly converting a view.
    adata = adata[adata.obs["pct_counts_mt"] < MT_propor].copy()
    sc.pp.filter_genes(adata, min_cells=min_cell)
    return adata
# Normalization
[docs]
def Normalization(adata):
    """
    Apply total-count normalization to an AnnData object.

    Every cell/spot is rescaled to a total of 10,000 counts
    (``target_sum=1e4``). The object is modified in place and also returned
    for convenience.

    Args:
        adata (sc.AnnData): The AnnData object to normalize.

    Returns:
        sc.AnnData: The same AnnData object, after normalization.
    """
    sc.pp.normalize_total(adata, inplace=True, target_sum=1e4)
    return adata
# Clustering (HVG selection, PCA, neighbor graph, UMAP, Leiden)
[docs]
def Clustering(ann_data, infer_mode, savePath):
    """
    Performs standard clustering (HVGs, PCA, neighbors, UMAP, Leiden).

    If a neighbor graph is already present (``'connectivities'`` in
    ``ann_data.obsp``) and ``'leiden'`` is present in ``ann_data.uns``
    (presumably left by an earlier Leiden run), only Leiden is re-run.
    Otherwise the full pipeline is executed. Cluster labels are written to
    ``ann_data.obs['leiden_clusters']``. A plot is saved to
    ``<savePath>/2_preprocessing/leiden cluster.png``.

    Args:
        ann_data (sc.AnnData): The AnnData object. It is modified in place.
        infer_mode (str): The inference data type for plotting: 'SC' saves a
            UMAP plot, 'ST' saves a spatial plot next to a UMAP plot. Any
            other value skips plotting.
        savePath (str): The main project directory path to save plots.

    Returns:
        sc.AnnData: The clustered AnnData object.
    """
    savePath_2 = os.path.join(savePath, "2_preprocessing")
    # Make sure the output directory exists so savefig cannot fail on it.
    os.makedirs(savePath_2, exist_ok=True)
    plot_file = os.path.join(savePath_2, "leiden cluster.png")

    if ('connectivities' in ann_data.obsp) and ('leiden' in ann_data.uns):
        # A neighbor graph is already available: only recompute the clusters.
        sc.tl.leiden(ann_data, key_added="leiden_clusters")
    else:
        # Identify highly variable genes
        sc.pp.highly_variable_genes(ann_data, flavor="seurat", n_top_genes=2000)
        # Perform PCA dimension reduction
        sc.tl.pca(ann_data)
        # Computing the neighborhood graph
        sc.pp.neighbors(ann_data, use_rep='X_pca')
        # UMAP and leiden
        sc.tl.umap(ann_data)
        sc.tl.leiden(ann_data, key_added="leiden_clusters")

    # Both plotting modes draw a UMAP; when a pre-computed graph was reused
    # the embedding may be missing, so compute it on demand.
    if infer_mode in ("SC", "ST") and "X_umap" not in ann_data.obsm:
        sc.tl.umap(ann_data)

    if infer_mode == "SC":
        sc.pl.umap(ann_data, color=['leiden_clusters'], show=False)
        plt.savefig(plot_file)
        plt.close()  # release the figure so repeated calls do not accumulate
    elif infer_mode == "ST":
        # Create a 1x2 grid: spatial plot on the left, UMAP on the right.
        fig, axs = plt.subplots(1, 2, figsize=(8, 4))
        # Coordinates are cast to float before being handed to the spatial plot.
        ann_data.obsm["spatial"] = np.array(ann_data.obsm["spatial"], dtype=float)
        sc.pl.spatial(ann_data, img_key="hires", color=["leiden_clusters"], show=False, ax=axs[0])
        sc.pl.umap(ann_data, color=["leiden_clusters"], show=False, ax=axs[1])
        plt.tight_layout()  # Ensure proper spacing between the two plots
        plt.savefig(plot_file)
        plt.close(fig)
    return ann_data
# This function computes the cell similarity network for single-cell or spatial transcriptomics data.
[docs]
def _nearest_neighbor_adjacency(coords, k):
    """Return a binary (n x n) matrix marking each row's k nearest other points."""
    coords = np.asarray(coords, dtype=float)
    n_obs = coords.shape[0]
    adjacency = np.zeros((n_obs, n_obs), dtype=int)
    # A point can have at most n_obs - 1 neighbours other than itself.
    k = min(k, n_obs - 1)
    if k <= 0:
        return adjacency
    distances = cdist(coords, coords, metric='euclidean')
    # Exclude self-matches explicitly. Relying on "index 0 of the argsort is
    # the point itself" breaks when two points share the same coordinates.
    np.fill_diagonal(distances, np.inf)
    # A stable sort makes tie-breaking between equidistant points deterministic.
    nearest = np.argsort(distances, axis=1, kind="stable")[:, :k]
    adjacency[np.arange(n_obs)[:, None], nearest] = 1
    return adjacency


def compute_similarity(savePath, ann_data, calculate_distance=False, n_neighbors=6):
    """
    Extracts and saves the cell/spot similarity matrix (connectivities).

    The connectivities are densified and pickled as a DataFrame to
    ``<savePath>/2_preprocessing/similarity_df.pkl``. Optionally, a spatial
    nearest-neighbour adjacency matrix is computed from
    ``ann_data.obsm['spatial']`` and pickled to
    ``<savePath>/2_preprocessing/distance_df.pkl``. The output directory is
    created if it does not exist.

    Args:
        savePath (str): The main project directory path.
        ann_data (sc.AnnData): A clustered AnnData object (must have
            `ann_data.obsp['connectivities']`).
        calculate_distance (bool, optional): Whether to compute the spatial
            adjacency matrix (ST only). Defaults to False.
        n_neighbors (int, optional): Number of nearest spots (excluding the
            spot itself) marked as neighbours of each spot. Capped at the
            number of other spots. Defaults to 6.

    Returns:
        None
    """
    savePath_2 = os.path.join(savePath, "2_preprocessing")
    os.makedirs(savePath_2, exist_ok=True)

    # Obtain the cell-cell similarity matrix as a labelled dense table.
    # NOTE: this materializes an (n_obs x n_obs) dense array.
    dense_similarity_matrix = ann_data.obsp['connectivities'].toarray()
    similarity_df = pd.DataFrame(
        dense_similarity_matrix, columns=ann_data.obs_names, index=ann_data.obs_names)

    if calculate_distance:
        # Mark, for every spot, its closest spots by Euclidean distance
        # between spatial positions.
        adjacency_matrix = _nearest_neighbor_adjacency(
            ann_data.obsm['spatial'], n_neighbors)
        distance_df = pd.DataFrame(
            adjacency_matrix, columns=ann_data.obs_names, index=ann_data.obs_names)
        with open(os.path.join(savePath_2, 'distance_df.pkl'), 'wb') as f:
            pickle.dump(distance_df, f)

    with open(os.path.join(savePath_2, 'similarity_df.pkl'), 'wb') as f:
        pickle.dump(similarity_df, f)
    return None
# This function calculates the cell subpopulation mean rank.
[docs]
def calculate_populations_meanRank(input_data, category):
    """
    Calculates the mean feature values for each cell subpopulation (category).

    Labels are matched to the rows of ``input_data`` by position: the i-th
    label belongs to the i-th row, regardless of the labels' own index. The
    caller's ``category`` object is left unmodified.

    Args:
        input_data (pd.DataFrame): Input DataFrame (samples x features).
        category (pd.Series): A Series (or other 1-D array-like) giving the
            category (e.g., cluster) of each sample, in the same order as
            the rows of `input_data`.

    Returns:
        pd.DataFrame: A DataFrame where rows are categories (index named
        'Category') and columns are the mean of features for that category.

    Raises:
        ValueError: If ``category`` and ``input_data`` differ in length.
    """
    # Work on a private copy: assigning to `category.index` directly would
    # silently re-index the Series owned by the caller.
    labels = pd.Series(category, copy=True)
    labels.name = 'Category'
    labels.index = input_data.index

    # Group by the aligned label Series instead of concatenating it as an
    # extra column; this avoids copying the feature table and cannot clash
    # with an existing 'Category' column.
    meanrank_df = input_data.groupby(labels).mean()
    print("Cell subpopulation mean feature calculation done!")
    return meanrank_df