tirank.Imageprocessing
- class tirank.Imageprocessing.ConvStem(img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True)
Bases: Module
Custom convolutional stem for the CTransPath model; it replaces the Swin Transformer's patch embedding layer.
This stem uses a series of convolutions to create patch embeddings instead of a single large-kernel convolution.
- Parameters:
img_size (int, optional) – The size of the input image. Defaults to 224.
patch_size (int, optional) – The size of the patch. Must be 4. Defaults to 4.
in_chans (int, optional) – Number of input image channels. Defaults to 3.
embed_dim (int, optional) – The dimension of the output embedding. Defaults to 768.
norm_layer (nn.Module, optional) – Normalization layer to use. Defaults to None.
flatten (bool, optional) – Whether to flatten the output. Defaults to True.
- forward(x)
Forward pass of the convolutional stem.
- Parameters:
x (torch.Tensor) – Input tensor of shape (B, C, H, W).
- Returns:
Output patch embeddings.
- Return type:
torch.Tensor
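The sketch below traces shapes through a convolutional stem to show why patch_size must be 4. It assumes the layout of the published CTransPath ConvStem: two 3×3, stride-2, padding-1 convolutions whose channels start at embed_dim // 8 and double, followed by a 1×1 projection to embed_dim. That layout is an assumption, not something this reference states. The helpers conv_out, stem_shapes and flatten_patches are hypothetical, and NumPy stands in for torch in the flatten step.

```python
import numpy as np


def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of one spatial axis after a 2D convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_shapes(img_size: int = 224, in_chans: int = 3, embed_dim: int = 768):
    """Trace (C, H, W) through an assumed two-stage conv stem plus 1x1 projection."""
    assert embed_dim % 8 == 0
    shapes = [(in_chans, img_size, img_size)]
    channels, size = embed_dim // 8, img_size
    for _ in range(2):  # each 3x3, stride-2 convolution halves H and W
        size = conv_out(size)
        shapes.append((channels, size, size))
        channels *= 2
    # The 1x1 projection changes the channel count only, not the spatial size.
    shapes.append((embed_dim, size, size))
    return shapes


def flatten_patches(x: np.ndarray) -> np.ndarray:
    """(B, C, H, W) -> (B, H*W, C): one embedding row per patch (flatten=True)."""
    b, c, h, w = x.shape
    return x.reshape(b, c, h * w).transpose(0, 2, 1)


print(stem_shapes())  # [(3, 224, 224), (96, 112, 112), (192, 56, 56), (768, 56, 56)]
# Swin-tiny uses embed_dim=96, giving a 56 x 56 grid of 96-dim patch tokens.
tokens = flatten_patches(np.zeros((1, *stem_shapes(embed_dim=96)[-1]), dtype=np.float32))
print(tokens.shape)  # (1, 3136, 96)
```

Two stride-2 stages downsample by 2 × 2 = 4 in total, which matches the fixed patch size of 4.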
- tirank.Imageprocessing.ctranspath()
Factory function to create the CTransPath model.
Initializes a Swin Transformer (swin_tiny_patch4_window7_224) with the custom ConvStem as the embedding layer.
- Returns:
The CTransPath model instance.
- Return type:
torch.nn.Module
- tirank.Imageprocessing.scale_coordinate(data)
Scales spatial transcriptomics (ST) spot coordinates to match the high-resolution H&E image.
This function reads the scale factor from the AnnData object and applies it to the spatial coordinates, adding the results to data.obs as ‘imagecol’ and ‘imagerow’.
- Parameters:
data (anndata.AnnData) – The AnnData object containing spatial info.
- Returns:
The AnnData object, modified in place.
- Return type:
anndata.AnnData
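A minimal sketch of the scaling step follows. It assumes the layout that scanpy.read_visium produces: full-resolution pixel coordinates in adata.obsm["spatial"] (x first, then y) and the scale factor in adata.uns["spatial"][library_id]["scalefactors"]["tissue_hires_scalef"]. Whether TiRank reads exactly these keys is an assumption, and scale_coordinates is a hypothetical helper that works on plain arrays so it runs without AnnData.

```python
from typing import Optional, Tuple

import numpy as np


def scale_coordinates(
    spatial: np.ndarray, uns_spatial: dict, library_id: Optional[str] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """Map full-resolution spot coordinates onto the high-resolution image."""
    if library_id is None:
        # Assumption: a single-sample object, so take the only library ID present.
        library_id = next(iter(uns_spatial))
    scale = uns_spatial[library_id]["scalefactors"]["tissue_hires_scalef"]
    imagecol = spatial[:, 0] * scale  # x (horizontal position) -> image column
    imagerow = spatial[:, 1] * scale  # y (vertical position) -> image row
    return imagecol, imagerow


spatial = np.array([[1000.0, 2000.0], [3000.0, 500.0]])  # toy full-resolution coordinates
uns_spatial = {"sample1": {"scalefactors": {"tissue_hires_scalef": 0.25}}}
col, row = scale_coordinates(spatial, uns_spatial)
print(col, row)  # [250. 750.] [500. 125.]
```

On a real object the two arrays would then be stored as adata.obs["imagecol"] and adata.obs["imagerow"].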
- tirank.Imageprocessing.crop_images(data, crop_size=25)
Crops image tiles (patches) from the H&E slide for each spot.
Uses the ‘imagecol’ and ‘imagerow’ coordinates from data.obs to crop square patches of (2*crop_size) x (2*crop_size) from the high-resolution image.
- Parameters:
data (anndata.AnnData) – The AnnData object, after running scale_coordinate.
crop_size (int, optional) – Half the side length of each square tile, in pixels. A size of 25 creates 50x50 pixel tiles. Defaults to 25.
- Returns:
A NumPy array of stacked image tiles with shape (n_spots, 2*crop_size, 2*crop_size, 3).
- Return type:
np.ndarray
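The cropping rule can be sketched with NumPy slicing as follows. crop_tiles is a hypothetical helper. Zero-padding the slide so that spots near the edge still yield full-size tiles is an assumed design choice; this reference does not say how crop_images handles borders.

```python
import numpy as np


def crop_tiles(image: np.ndarray, rows, cols, crop_size: int = 25) -> np.ndarray:
    """Cut a (2*crop_size) x (2*crop_size) tile centered on each (row, col) spot."""
    # Assumption: zero-pad the slide so edge spots still produce full tiles.
    padded = np.pad(image, ((crop_size, crop_size), (crop_size, crop_size), (0, 0)))
    tiles = []
    for r, c in zip(rows, cols):
        r, c = int(round(float(r))), int(round(float(c)))
        # Original pixel (r, c) sits at (r + crop_size, c + crop_size) in the padded
        # image, so the window [r - crop_size, r + crop_size) becomes [r, r + 2 * crop_size).
        tiles.append(padded[r : r + 2 * crop_size, c : c + 2 * crop_size])
    return np.stack(tiles)


slide = np.random.default_rng(0).integers(0, 256, size=(200, 300, 3), dtype=np.uint8)
tiles = crop_tiles(slide, rows=[100.4, 3.0], cols=[150.0, 298.7])
print(tiles.shape)  # (2, 50, 50, 3)
```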
- class tirank.Imageprocessing.ImageDataset(images_array, transform=None)
Bases: Dataset
PyTorch Dataset for loading image tiles (patches).
- Parameters:
images_array (np.ndarray) – A stack of image tiles from crop_images.
transform (callable, optional) – A torchvision transform to apply to each image. Defaults to None.
- static collate_fn(batch)
Collate function that stacks a list of image tensors into a single batch tensor.
- Parameters:
batch (list) – A list of image tensors.
- Returns:
A stacked batch of images.
- Return type:
torch.Tensor
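A map-style PyTorch Dataset is indexed through __len__ and __getitem__, and a DataLoader passes each list of samples to collate_fn. The torch-free sketch below mirrors that contract with NumPy so it runs without PyTorch; in the real class the items would be tensors and the stacking would use torch.stack. TileDataset and the to_chw_float transform are hypothetical names.

```python
import numpy as np


class TileDataset:
    """Map-style dataset over a (n_spots, H, W, 3) tile stack (torch-free sketch)."""

    def __init__(self, images_array: np.ndarray, transform=None):
        self.images_array = images_array
        self.transform = transform

    def __len__(self) -> int:
        return len(self.images_array)

    def __getitem__(self, idx: int):
        image = self.images_array[idx]
        return self.transform(image) if self.transform is not None else image

    @staticmethod
    def collate_fn(batch):
        # Stack a list of equally shaped images into one (batch, ...) array.
        return np.stack(batch)


def to_chw_float(tile: np.ndarray) -> np.ndarray:
    """Hypothetical transform: HWC uint8 -> CHW float32 in [0, 1]."""
    return tile.transpose(2, 0, 1).astype(np.float32) / 255.0


tiles = np.zeros((10, 50, 50, 3), dtype=np.uint8)
ds = TileDataset(tiles, transform=to_chw_float)
batch = TileDataset.collate_fn([ds[i] for i in range(4)])
print(len(ds), batch.shape)  # 10 (4, 3, 50, 50)
```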
- tirank.Imageprocessing.infer_by_pretrain(images, pretrain_path)
Generates feature embeddings from image tiles using the pre-trained CTransPath.
- Parameters:
images (np.ndarray) – A stack of image tiles (n_spots, H, W, C).
pretrain_path (str) – Path to the CTransPath model’s .pth weights file.
- Returns:
A tensor of feature embeddings of shape (n_spots, n_features).
- Return type:
torch.Tensor
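The shape contract (n_spots, H, W, C) in, (n_spots, n_features) out can be sketched as a batched loop. The example assumes ImageNet mean/std normalization, the usual preprocessing for CTransPath, which this reference does not state. It replaces the CTransPath forward pass with a hypothetical mean_pool stand-in, and it omits any resizing the real function may perform.

```python
import numpy as np

# Assumption: ImageNet channel statistics, commonly used with CTransPath.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def embed_tiles(images: np.ndarray, embed_fn, batch_size: int = 64) -> np.ndarray:
    """Normalize (n_spots, H, W, 3) uint8 tiles and embed them batch by batch."""
    features = []
    for start in range(0, len(images), batch_size):
        batch = images[start : start + batch_size].astype(np.float32) / 255.0
        batch = (batch - MEAN) / STD         # per-channel normalization (broadcast over HWC)
        batch = batch.transpose(0, 3, 1, 2)  # NHWC -> NCHW, the layout torch models expect
        features.append(embed_fn(batch))
    return np.concatenate(features, axis=0)  # (n_spots, n_features)


def mean_pool(batch: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the model: per-channel means, shape (B, C)."""
    return batch.mean(axis=(2, 3))


tiles = np.random.default_rng(0).integers(0, 256, size=(130, 50, 50, 3), dtype=np.uint8)
emb = embed_tiles(tiles, mean_pool, batch_size=64)
print(emb.shape)  # (130, 3)
```

Batching only bounds memory use; the features for each tile do not depend on the batch size.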
- tirank.Imageprocessing.process_embeddings(embeddings, n_components, n_clusters)
Performs PCA dimensionality reduction and K-means clustering.
- Parameters:
embeddings (torch.Tensor or np.ndarray) – The feature embeddings.
n_components (int) – The number of principal components to keep.
n_clusters (int) – The number of clusters to find (k).
- Returns:
- A tuple containing:
np.ndarray: The PCA-transformed embeddings, of shape (n_spots, n_components).
np.ndarray: The cluster label for each spot, of shape (n_spots,).
- Return type:
tuple
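Below is a NumPy-only sketch of the two steps: PCA by SVD of the centered data, then Lloyd's k-means with random initial centers. TiRank's implementation may instead use a library such as scikit-learn, with different initialization and seeds, so labels will not match it exactly. The helper names pca and kmeans are hypothetical.

```python
import numpy as np


def pca(x: np.ndarray, n_components: int) -> np.ndarray:
    """Project centered data onto its top principal components via SVD."""
    centered = x - x.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


def kmeans(x: np.ndarray, n_clusters: int, n_iter: int = 100, seed: int = 0) -> np.ndarray:
    """Lloyd's algorithm with random initial centers; returns one label per row."""
    rng = np.random.default_rng(seed)
    centers = x[rng.choice(len(x), size=n_clusters, replace=False)]
    labels = np.zeros(len(x), dtype=int)
    for _ in range(n_iter):
        # Assign each point to its nearest center (squared Euclidean distance).
        dists = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Move each center to the mean of its points; keep it if the cluster is empty.
        new_centers = np.array([
            x[labels == k].mean(axis=0) if np.any(labels == k) else centers[k]
            for k in range(n_clusters)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


rng = np.random.default_rng(1)
x = np.vstack([rng.normal(0.0, 0.1, (50, 10)), rng.normal(5.0, 0.1, (50, 10))])
reduced = pca(x, n_components=2)
labels = kmeans(reduced, n_clusters=2)
print(reduced.shape, np.bincount(labels))  # (100, 2) [50 50]
```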
- tirank.Imageprocessing.plot_patho_class_heatmap(data, save_path)
Generates and saves a spatial scatter plot of pathological classes.
- Parameters:
data (anndata.AnnData) – The AnnData object with ‘patho_class’ in .obs.
save_path (str) – The file path to save the resulting plot.
- Returns:
None
- tirank.Imageprocessing.GetPathoClass(adata, pretrain_path, n_components=50, n_clusters=6, plot_classes=True, image_save_path=None)
Orchestrates the full image processing pipeline.
Crops tiles, generates embeddings using CTransPath, performs PCA and K-means, and adds the results to the AnnData object.
- Parameters:
adata (anndata.AnnData) – The AnnData object for an ST sample.
pretrain_path (str) – Path to the CTransPath model’s .pth weights file.
n_components (int, optional) – Number of PCA components. Defaults to 50.
n_clusters (int, optional) – Number of K-means clusters. Defaults to 6.
plot_classes (bool, optional) – Whether to generate and save the spatial heatmap. Defaults to True.
image_save_path (str, optional) – File path to save the heatmap. Required if plot_classes is True. Defaults to None.
- Returns:
- The AnnData object, modified in place with:
adata.obs[“patho_class”]: Cluster labels for each spot.
adata.obsm[“patho_emd”]: PCA embeddings for each spot.
- Return type:
anndata.AnnData
Functions
| Function | Description |
| --- | --- |
| GetPathoClass | Orchestrates the full image processing pipeline. |
| crop_images | Crops image tiles (patches) from the H&E slide for each spot. |
| ctranspath | Factory function to create the CTransPath model. |
| infer_by_pretrain | Generates feature embeddings from image tiles using the pre-trained CTransPath. |
| plot_patho_class_heatmap | Generates and saves a spatial scatter plot of pathological classes. |
| process_embeddings | Performs PCA dimensionality reduction and K-means clustering. |
| scale_coordinate | Scales ST spot coordinates to match the high-resolution image. |
Classes
| Class | Description |
| --- | --- |
| ConvStem | Custom convolutional stem for the CTransPath model (replaces the patch embedding layer). |
| ImageDataset | PyTorch Dataset for loading image tiles (patches). |