import matplotlib.pyplot as plt
from PIL import Image
import torch
import timm
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as F
import numpy as np
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from itertools import repeat
import collections.abc
import os
"""
This module provides utilities for processing spatial transcriptomics (ST)
H&E (Hematoxylin and Eosin) images. It includes functions for:
1. Cropping image tiles (patches) centered on each ST spot.
2. Generating feature embeddings from these tiles using a pre-trained
CTransPath (Swin Transformer) model.
3. Clustering these embeddings using PCA and K-means to identify
spatial "pathological classes".
"""
def _ntuple(n):
    """
    Private helper that builds a parser normalising a value to a tuple.

    Args:
        n (int): Number of repetitions used when the parsed value is a scalar.

    Returns:
        function: ``parse(x)``. A non-string iterable ``x`` is returned as a
            tuple of its own elements (its length is not checked against
            ``n``); any other value, including a string, is repeated ``n``
            times.
    """
    def parse(x):
        # Strings are iterable but stand for a single value here, so they are
        # repeated like any other scalar instead of being passed through.
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            # Materialise so callers always receive an indexable tuple, even
            # when handed a list or a one-shot generator.
            return tuple(x)
        return tuple(repeat(x, n))
    return parse
[docs]
class ConvStem(nn.Module):
    """
    Custom Convolutional Stem for the CTransPath model (replaces the patch embed layer).

    This stem uses a series of convolutions to create patch embeddings instead
    of a single large-kernel convolution: two 3x3 stride-2 convolutions (each
    followed by batch norm and ReLU) and a final 1x1 projection to
    ``embed_dim`` channels.

    Args:
        img_size (int or tuple, optional): The size of the input image. Defaults to 224.
        patch_size (int, optional): The size of the patch. Must be 4. Defaults to 4.
        in_chans (int, optional): Number of input image channels. Defaults to 3.
        embed_dim (int, optional): The dimension of the output embedding.
            Must be divisible by 8. Defaults to 768.
        norm_layer (callable, optional): Called as ``norm_layer(embed_dim)`` to
            build the normalization layer. Defaults to None (identity).
        flatten (bool, optional): Whether to flatten the output. Defaults to True.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        super().__init__()
        assert patch_size == 4, "ConvStem only supports patch_size == 4."
        assert embed_dim % 8 == 0, "embed_dim must be divisible by 8."
        to_2tuple = _ntuple(2)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        stem = []
        # Start from `in_chans` rather than a hard-coded 3 so the parameter is
        # actually honoured; the default keeps the layer shapes unchanged.
        input_dim, output_dim = in_chans, embed_dim // 8
        # Two stride-2 stages: embed_dim/8 channels, then embed_dim/4.
        for _ in range(2):
            stem.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=2, padding=1, bias=False))
            stem.append(nn.BatchNorm2d(output_dim))
            stem.append(nn.ReLU(inplace=True))
            input_dim = output_dim
            output_dim *= 2
        # 1x1 convolution projects to the final embedding width.
        stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
        self.proj = nn.Sequential(*stem)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """
        Forward pass of the convolutional stem.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W); H and W must
                equal the ``img_size`` given at construction.

        Returns:
            torch.Tensor: Output patch embeddings. With ``flatten=True`` the
                spatial dimensions are merged and moved before the channel
                dimension (BCHW -> BNC); otherwise the layout stays BCHW.
        """
        _, _, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
[docs]
def ctranspath():
    """
    Build an untrained CTransPath network.

    The network is timm's ``swin_tiny_patch4_window7_224`` architecture
    created with ``pretrained=False`` and with ``ConvStem`` passed as the
    embedding layer.

    Returns:
        torch.nn.Module: The freshly constructed model; no weights are loaded
            here.
    """
    return timm.create_model(
        'swin_tiny_patch4_window7_224',
        embed_layer=ConvStem,
        pretrained=False,
    )
[docs]
def scale_coordinate(data):
    """
    Scales ST spot coordinates to match the high-resolution image.

    Reads ``tissue_hires_scalef`` for the first library in
    ``data.uns["spatial"]``, multiplies ``data.obsm["spatial"]`` by it and
    stores column 0 as ``data.obs["imagecol"]`` and column 1 as
    ``data.obs["imagerow"]``.

    Args:
        data (anndata.AnnData): The AnnData object containing spatial info.

    Returns:
        anndata.AnnData: The same object, modified in place. When the spatial
            coordinates were stored as text/object values,
            ``data.obsm["spatial"]`` is also replaced by its float version.
    """
    library_id = list(data.uns["spatial"].keys())[0]
    # float() is a no-op for real floats and also handles str / NumPy scalars.
    scale = float(data.uns["spatial"][library_id]["scalefactors"]["tissue_hires_scalef"])

    coordinates = np.asarray(data.obsm["spatial"])
    # Decide from the dtype instead of probing element [1, 1]: the probe fails
    # for a single spot and `type(x) == str` is False for numpy.str_ values.
    if coordinates.dtype.kind in "OUS":
        coordinates = coordinates.astype(float)
        data.obsm["spatial"] = coordinates

    image_coordinates = coordinates * scale
    data.obs["imagecol"] = image_coordinates[:, 0]
    data.obs["imagerow"] = image_coordinates[:, 1]
    return data
[docs]
def crop_images(data, crop_size=25):
    """
    Crops image tiles (patches) from the H&E slide for each spot.

    Runs `scale_coordinate` on ``data`` and then uses the resulting
    'imagecol' and 'imagerow' values from `data.obs` as tile centres, cutting
    a box of (2*crop_size) x (2*crop_size) from the high-resolution image of
    the first library in ``data.uns["spatial"]``.

    Args:
        data (anndata.AnnData): The AnnData object holding the spatial
            coordinates and the "hires" image.
        crop_size (int, optional): The "radius" for cropping. A size of 25
            creates 50x50 pixel tiles. Defaults to 25.

    Returns:
        np.ndarray: A NumPy array stack of image tiles, one per spot, with
            values divided by 255. For an RGB image the shape is
            (n_spots, 2*crop_size, 2*crop_size, 3).
    """
    data = scale_coordinate(data)
    library_id = list(data.uns["spatial"].keys())[0]
    image_data = np.asarray(data.uns["spatial"][library_id]["images"]["hires"])

    # Only rescale images that are not already 8-bit: multiplying a uint8
    # image by 255 would wrap around and corrupt every pixel.
    # NOTE(review): non-uint8 images are assumed to hold intensities in
    # [0, 1] -- confirm against the loader that fills `uns`.
    if image_data.dtype != np.uint8:
        image_data = (image_data * 255).astype(np.uint8)
    img = Image.fromarray(image_data)

    tiles = [
        np.array(img.crop((col - crop_size, row - crop_size, col + crop_size, row + crop_size))) / 255
        for row, col in zip(data.obs["imagerow"], data.obs["imagecol"])
    ]
    return np.stack(tiles)
[docs]
class ImageDataset(Dataset):
    """
    PyTorch Dataset serving image tiles (patches) one at a time.

    Args:
        images_array (np.ndarray): A stack of image tiles, indexed along its
            first axis (e.g. the output of `crop_images`).
        transform (callable, optional): A torchvision transform applied to
            each tile after it has been converted to a PIL image.
            Defaults to None.
    """

    def __init__(self, images_array, transform=None):
        self.transform = transform
        self.images_array = images_array

    def __len__(self):
        """Number of tiles, i.e. the size of the first axis of the stack."""
        n_tiles = self.images_array.shape[0]
        return n_tiles

    def __getitem__(self, idx):
        """
        Fetch one tile, transformed when a transform was supplied.

        Args:
            idx (int): Position of the tile in the stack.

        Returns:
            The output of ``transform`` for that tile, or the raw array entry
            when no transform is set.
        """
        tile = self.images_array[idx]
        if not self.transform:
            return tile
        # The transform pipeline is fed a PIL image, not the raw array.
        return self.transform(F.to_pil_image(tile))

    @staticmethod
    def collate_fn(batch):
        """
        Combine individual image tensors into one batch tensor.

        Args:
            batch (list): Image tensors of identical shape.

        Returns:
            torch.Tensor: The tensors stacked along a new leading dimension.
        """
        return torch.stack(batch, dim=0)
[docs]
def infer_by_pretrain(images, pretrain_path, batch_size=32, num_workers=4):
    """
    Generates feature embeddings from image tiles using the pre-trained CTransPath.

    The model's classification head is replaced by an identity layer before
    the weights are loaded, so the forward pass returns features instead of
    class scores. Each tile is converted to a tensor, resized to 224x224 and
    normalized before inference. Inference runs on CUDA when available,
    otherwise on the CPU.

    Args:
        images (np.ndarray): A stack of image tiles (n_spots, H, W, C).
        pretrain_path (str): Path to the CTransPath model's .pth weights file.
            The file must load to a mapping with the state dict under the
            'model' key.
        batch_size (int, optional): Number of tiles per forward pass.
            Defaults to 32.
        num_workers (int, optional): Number of DataLoader worker processes;
            pass 0 to load tiles in the calling process. Defaults to 4.

    Returns:
        torch.Tensor: A CPU tensor of feature embeddings of shape
            (n_spots, n_features), in the same order as ``images``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pre_model = ctranspath().to(device)
    # Drop the head before loading so the model's keys match the checkpoint
    # under the default strict loading.
    pre_model.head = nn.Identity()
    pre_model.load_state_dict(torch.load(pretrain_path, map_location=device)['model'])
    pre_model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    images_set = ImageDataset(images, transform=transform)
    # shuffle=False keeps the embeddings aligned with the input spot order.
    dataloader = DataLoader(
        images_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=ImageDataset.collate_fn,
    )
    with torch.no_grad():
        features = [pre_model(batch.to(device)).cpu() for batch in dataloader]
    return torch.cat(features, dim=0)
[docs]
def process_embeddings(embeddings, n_components, n_clusters):
    """
    Performs PCA dimensionality reduction and K-means clustering.

    Both steps use a fixed ``random_state`` of 0 so repeated runs on the same
    embeddings give the same result.

    Args:
        embeddings (torch.Tensor or np.ndarray): The feature embeddings, one
            row per spot.
        n_components (int): The number of principal components to keep.
        n_clusters (int): The number of clusters to find (k).

    Returns:
        tuple: A tuple containing:
            - np.ndarray: The PCA-transformed embeddings.
            - np.ndarray: The cluster labels for each spot.
    """
    if isinstance(embeddings, torch.Tensor):
        # detach()/cpu() make the conversion work for tensors that track
        # gradients or live on a GPU; plain .numpy() raises for both.
        embeddings = embeddings.detach().cpu().numpy()
    # Seed PCA as well: its randomized solver is otherwise non-deterministic,
    # which would undo the fixed seed used for K-means below.
    pca = PCA(n_components=n_components, random_state=0)
    pca_embeddings = pca.fit_transform(embeddings)
    cluster_labels = KMeans(n_clusters=n_clusters, random_state=0).fit_predict(pca_embeddings)
    return pca_embeddings, cluster_labels
[docs]
def plot_patho_class_heatmap(data, save_path):
    """
    Generates and saves a spatial scatter plot of pathological classes.

    Spots are placed at their ``array_col`` / ``array_row`` positions from
    ``data.obs`` and coloured by their ``patho_class`` label. The figure is
    written to ``save_path`` and then closed.

    Args:
        data (anndata.AnnData): The AnnData object with 'array_col',
            'array_row' and 'patho_class' in `.obs`.
        save_path (str): The file path to save the resulting plot.

    Returns:
        None
    """
    x_coords = np.asarray(data.obs["array_col"].values)
    y_coords = np.asarray(data.obs["array_row"].values)
    # isinstance also matches numpy.str_ (which `type(x) == str` misses), and
    # the size guard avoids indexing into an empty array.
    if x_coords.size and isinstance(x_coords[0], str):
        x_coords = x_coords.astype("int")
        y_coords = y_coords.astype("int")
    labels = np.asarray(data.obs["patho_class"].values)

    unique_labels = np.unique(labels)
    num_labels = len(unique_labels)

    # Define a list of distinct colors. Add more colors if you have more classes.
    colors = ['red', 'blue', 'green', 'yellow', 'purple', 'orange', 'cyan', 'magenta', 'brown', 'pink', 'lime', 'violet', 'indigo', 'gold', 'crimson']
    if num_labels > len(colors):
        print(f"Warning: Not enough unique colors for {num_labels} clusters. Some colors will be repeated.")
        colors = colors * (num_labels // len(colors) + 1)

    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        # One scatter call per class so each class gets its own legend entry.
        for color, label in zip(colors, unique_labels):
            mask = labels == label
            ax.scatter(x_coords[mask], y_coords[mask], color=color, label=label, s=10)

        ax.legend()
        # NOTE(review): the axes show array_col / array_row, not image pixel
        # coordinates; the label text is kept as-is for output compatibility.
        ax.set_xlabel("Image Column")
        ax.set_ylabel("Image Row")
        ax.set_title("Patho Class Heatmap")
        ax.invert_yaxis()  # Invert y-axis for typical image display
        fig.savefig(save_path, bbox_inches="tight", pad_inches=1)
    finally:
        # Release the figure even if saving fails; otherwise every call leaves
        # another open figure behind in pyplot's registry.
        plt.close(fig)
    return None
[docs]
def GetPathoClass(adata, pretrain_path, n_components=50, n_clusters=6, plot_classes=True, image_save_path=None):
    """
    Orchestrates the full image processing pipeline.

    Crops tiles, generates embeddings using CTransPath, performs PCA and
    K-means, and adds the results to the AnnData object.

    Args:
        adata (anndata.AnnData): The AnnData object for an ST sample.
        pretrain_path (str): Path to the CTransPath model's .pth weights file.
        n_components (int, optional): Number of PCA components. Defaults to 50.
        n_clusters (int, optional): Number of K-means clusters. Defaults to 6.
        plot_classes (bool, optional): Whether to generate and save the spatial
            heatmap. Defaults to True.
        image_save_path (str, optional): File path to save the heatmap.
            Required if `plot_classes` is True. Defaults to None.

    Returns:
        anndata.AnnData: The AnnData object, modified in place with:
            - `adata.obs["patho_class"]`: Cluster labels for each spot.
            - `adata.obsm["patho_emd"]`: PCA embeddings for each spot.

    Raises:
        ValueError: If `plot_classes` is True and `image_save_path` is None.
            Raised before any processing starts.
    """
    # Validate up front so a missing path is reported immediately instead of
    # after cropping, model inference and clustering have already run.
    if plot_classes and image_save_path is None:
        raise ValueError("'image_save_path' must be provided if 'plot_classes' is True.")

    images = crop_images(adata)
    features = infer_by_pretrain(images, pretrain_path)
    pca_embeddings, cluster_labels = process_embeddings(features, n_components, n_clusters)

    adata.obs["patho_class"] = cluster_labels
    adata.obsm["patho_emd"] = pca_embeddings

    if plot_classes:
        plot_patho_class_heatmap(adata, image_save_path)
    return adata