tirank.Imageprocessing

class tirank.Imageprocessing.ConvStem(img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True)

Bases: Module

Custom Convolutional Stem for the CTransPath model (replaces the patch embed layer).

This stem uses a series of convolutions to create patch embeddings instead of a single large-kernel convolution.

Parameters:
  • img_size (int, optional) – The size of the input image. Defaults to 224.

  • patch_size (int, optional) – The size of the patch. Must be 4. Defaults to 4.

  • in_chans (int, optional) – Number of input image channels. Defaults to 3.

  • embed_dim (int, optional) – The dimension of the output embedding. Defaults to 768.

  • norm_layer (nn.Module, optional) – Normalization layer to use. Defaults to None.

  • flatten (bool, optional) – Whether to flatten the output. Defaults to True.

forward(x)

Forward pass of the convolutional stem.

Parameters:

x (torch.Tensor) – Input tensor of shape (B, C, H, W).

Returns:

Output patch embeddings.

Return type:

torch.Tensor
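The arithmetic below illustrates how a stack of strided convolutions can reach the same 4x downsampling as a single patch-embedding convolution with patch_size 4. It is a minimal sketch, assuming two 3x3 convolutions with stride 2 and padding 1; those kernel, stride, and padding values are assumptions and are not stated in this reference. The helper conv_out is hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution (standard floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


# Documented defaults of ConvStem.
img_size, patch_size, embed_dim = 224, 4, 768

# Assumption: two 3x3, stride-2, padding-1 convolutions, each halving H and W.
side = img_size
for _ in range(2):
    side = conv_out(side, kernel=3, stride=2, padding=1)

# Number of patch positions on the resulting feature map.
num_patches = side * side
print(side, num_patches)  # 56 3136
```

Under these assumptions the feature map side equals img_size // patch_size. If flatten behaves as in common patch-embedding layers, the output would have shape (B, num_patches, embed_dim), but that layout is not confirmed here.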

tirank.Imageprocessing.ctranspath()

Factory function to create the CTransPath model.

Initializes a Swin Transformer (swin_tiny_patch4_window7_224) with the custom ConvStem as the embedding layer.

Returns:

The CTransPath model instance.

Return type:

torch.nn.Module

tirank.Imageprocessing.scale_coordinate(data)

Scales ST spot coordinates to match the high-resolution image.

This function reads the scale factor from the AnnData object and applies it to the spatial coordinates, adding the results to data.obs as ‘imagecol’ and ‘imagerow’.

Parameters:

data (anndata.AnnData) – The AnnData object containing spatial info.

Returns:

The AnnData object, modified in place.

Return type:

anndata.AnnData
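The scaling step can be sketched with plain NumPy arrays. This is an illustration only: the array spatial, the value of scale, and the helper scale_coords are hypothetical stand-ins for whatever fields the function actually reads from the AnnData object, and the (x, y) column order is an assumption.

```python
import numpy as np

# Hypothetical stand-ins for the AnnData fields (names and layout are assumptions).
spatial = np.array([[1000.0, 2000.0],   # full-resolution (x, y) of spot 0
                    [1500.0, 2500.0]])  # full-resolution (x, y) of spot 1
scale = 0.25  # scale factor mapping full-resolution pixels to the hi-res image


def scale_coords(spatial_xy: np.ndarray, scale_factor: float) -> dict:
    """Return image-space column and row positions for each spot."""
    scaled = spatial_xy * scale_factor
    # Assumption: column 0 is x (image column), column 1 is y (image row).
    return {"imagecol": scaled[:, 0], "imagerow": scaled[:, 1]}


obs = scale_coords(spatial, scale)
print(obs["imagecol"], obs["imagerow"])
```

In the real function the two arrays are written to data.obs rather than returned in a dictionary.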

tirank.Imageprocessing.crop_images(data, crop_size=25)

Crops image tiles (patches) from the H&E slide for each spot.

Uses the ‘imagecol’ and ‘imagerow’ coordinates from data.obs to crop square patches of (2*crop_size) x (2*crop_size) from the high-resolution image.

Parameters:
  • data (anndata.AnnData) – The AnnData object, after running scale_coordinate.

  • crop_size (int, optional) – Half the side length of each tile, in pixels (the “radius” of the crop). A value of 25 produces 50x50 pixel tiles. Defaults to 25.

Returns:

A stacked NumPy array of image tiles with shape (n_spots, 2*crop_size, 2*crop_size, 3).

Return type:

np.ndarray
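The cropping logic can be sketched as array slicing around each spot centre. This is a minimal sketch, assuming the image is a (height, width, 3) array indexed rows first; the helper crop_tiles and the synthetic image are hypothetical and do not reproduce the function's handling of AnnData.

```python
import numpy as np


def crop_tiles(image: np.ndarray, cols, rows, crop_size: int = 25) -> np.ndarray:
    """Cut a square of side 2*crop_size around each (col, row) centre."""
    tiles = []
    for col, row in zip(cols, rows):
        c, r = int(round(col)), int(round(row))
        # Image arrays are indexed [row, column, channel].
        tile = image[r - crop_size:r + crop_size, c - crop_size:c + crop_size]
        tiles.append(tile)
    return np.stack(tiles)


# Synthetic 400x600 RGB image standing in for the high-resolution H&E slide.
image = np.zeros((400, 600, 3), dtype=np.uint8)
tiles = crop_tiles(image, cols=[100.0, 300.5], rows=[50.0, 200.0])
print(tiles.shape)  # (2, 50, 50, 3)
```

Spots closer than crop_size pixels to the image border would yield smaller tiles in this sketch and np.stack would fail, so a real pipeline needs padding or filtering for such spots.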

class tirank.Imageprocessing.ImageDataset(images_array, transform=None)

Bases: Dataset

PyTorch Dataset for loading image tiles (patches).

Parameters:
  • images_array (np.ndarray) – A stack of image tiles from crop_images.

  • transform (callable, optional) – A torchvision transform to apply to each image. Defaults to None.

static collate_fn(batch)

Static collate function to stack images into a batch.

Parameters:

batch (list) – A list of image tensors.

Returns:

A stacked batch of images.

Return type:

torch.Tensor
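A dataset of this kind is typically paired with a DataLoader. The sketch below uses a hypothetical TileDataset class as a stand-in for ImageDataset; the fallback conversion applied when no transform is given is an assumption, not documented behaviour.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class TileDataset(Dataset):
    """Hypothetical stand-in for ImageDataset, for illustration only."""

    def __init__(self, images_array, transform=None):
        self.images_array = images_array
        self.transform = transform

    def __len__(self):
        return len(self.images_array)

    def __getitem__(self, idx):
        tile = self.images_array[idx]
        if self.transform is not None:
            return self.transform(tile)
        # Assumed fallback: HWC uint8 array -> CHW float tensor in [0, 1].
        return torch.from_numpy(tile).permute(2, 0, 1).float() / 255.0

    @staticmethod
    def collate_fn(batch):
        # Stack a list of (C, H, W) tensors into one (B, C, H, W) tensor.
        return torch.stack(batch, dim=0)


tiles = np.zeros((5, 50, 50, 3), dtype=np.uint8)
dataset = TileDataset(tiles)
loader = DataLoader(dataset, batch_size=2, collate_fn=TileDataset.collate_fn)
first_batch = next(iter(loader))
print(first_batch.shape)  # torch.Size([2, 3, 50, 50])
```

Stacking requires every tile to have the same shape after the transform, which holds when all tiles come from one crop_images call.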

tirank.Imageprocessing.infer_by_pretrain(images, pretrain_path)

Generates feature embeddings from image tiles using the pre-trained CTransPath.

Parameters:
  • images (np.ndarray) – A stack of image tiles (n_spots, H, W, C).

  • pretrain_path (str) – Path to the CTransPath model’s .pth weights file.

Returns:

A tensor of feature embeddings of shape (n_spots, n_features).

Return type:

torch.Tensor
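Embedding extraction of this kind usually runs the model in evaluation mode, without gradients, one batch at a time. The sketch below shows that pattern with a tiny stand-in encoder; it does not load CTransPath or any weights file, and the helper embed_batches is hypothetical.

```python
import torch
from torch import nn

# Stand-in encoder (NOT CTransPath): global average pool, then a linear layer.
encoder = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 8))


def embed_batches(model: nn.Module, batches, device: str = "cpu") -> torch.Tensor:
    """Run batches through the model and concatenate the features on the CPU."""
    model = model.to(device).eval()  # eval() disables dropout / batch-norm updates
    outputs = []
    with torch.no_grad():  # no gradients are needed for feature extraction
        for batch in batches:
            outputs.append(model(batch.to(device)).cpu())
    return torch.cat(outputs, dim=0)


# Two batches of random (B, C, H, W) tiles standing in for cropped patches.
batches = [torch.rand(2, 3, 50, 50), torch.rand(1, 3, 50, 50)]
embeddings = embed_batches(encoder, batches)
print(embeddings.shape)  # torch.Size([3, 8])
```

With the real model, n_features is determined by the network rather than by the 8 chosen here.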

tirank.Imageprocessing.process_embeddings(embeddings, n_components, n_clusters)

Performs PCA dimensionality reduction and K-means clustering.

Parameters:
  • embeddings (torch.Tensor or np.ndarray) – The feature embeddings.

  • n_components (int) – The number of principal components to keep.

  • n_clusters (int) – The number of clusters to find (k).

Returns:

A tuple containing:
  • np.ndarray: The PCA-transformed embeddings.

  • np.ndarray: The cluster labels for each spot.

Return type:

tuple
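The two steps can be sketched with scikit-learn. This is an illustration under assumptions: the helper reduce_and_cluster, the seed argument, and the n_init value are hypothetical, and the library and settings used by process_embeddings itself are not stated in this reference.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA


def reduce_and_cluster(embeddings, n_components, n_clusters, seed=0):
    """PCA-reduce the embeddings, then cluster the reduced vectors with K-means."""
    features = np.asarray(embeddings, dtype=np.float32)
    reduced = PCA(n_components=n_components, random_state=seed).fit_transform(features)
    labels = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit_predict(reduced)
    return reduced, labels


# Random features standing in for CTransPath embeddings (60 spots, 32 features).
rng = np.random.default_rng(0)
fake_embeddings = rng.normal(size=(60, 32))
pcs, labels = reduce_and_cluster(fake_embeddings, n_components=5, n_clusters=3)
print(pcs.shape, labels.shape)  # (60, 5) (60,)
```

A torch.Tensor that lives on a GPU or tracks gradients would need .detach().cpu().numpy() before being passed to this sketch.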

tirank.Imageprocessing.plot_patho_class_heatmap(data, save_path)

Generates and saves a spatial scatter plot of pathological classes.

Parameters:
  • data (anndata.AnnData) – The AnnData object with ‘patho_class’ in .obs.

  • save_path (str) – The file path to save the resulting plot.

Returns:

None

tirank.Imageprocessing.GetPathoClass(adata, pretrain_path, n_components=50, n_clusters=6, plot_classes=True, image_save_path=None)

Orchestrates the full image processing pipeline.

Crops tiles, generates embeddings using CTransPath, performs PCA and K-means, and adds the results to the AnnData object.

Parameters:
  • adata (anndata.AnnData) – The AnnData object for an ST sample.

  • pretrain_path (str) – Path to the CTransPath model’s .pth weights file.

  • n_components (int, optional) – Number of PCA components. Defaults to 50.

  • n_clusters (int, optional) – Number of K-means clusters. Defaults to 6.

  • plot_classes (bool, optional) – Whether to generate and save the spatial heatmap. Defaults to True.

  • image_save_path (str, optional) – File path to save the heatmap. Required if plot_classes is True. Defaults to None.

Returns:

The AnnData object, modified in place with:
  • adata.obs[“patho_class”]: Cluster labels for each spot.

  • adata.obsm[“patho_emd”]: PCA embeddings for each spot.

Return type:

anndata.AnnData

Functions

tirank.Imageprocessing.GetPathoClass

Orchestrates the full image processing pipeline.

tirank.Imageprocessing.crop_images

Crops image tiles (patches) from the H&E slide for each spot.

tirank.Imageprocessing.ctranspath

Factory function to create the CTransPath model.

tirank.Imageprocessing.infer_by_pretrain

Generates feature embeddings from image tiles using the pre-trained CTransPath.

tirank.Imageprocessing.plot_patho_class_heatmap

Generates and saves a spatial scatter plot of pathological classes.

tirank.Imageprocessing.process_embeddings

Performs PCA dimensionality reduction and K-means clustering.

tirank.Imageprocessing.scale_coordinate

Scales ST spot coordinates to match the high-resolution image.

Classes

tirank.Imageprocessing.ConvStem

Custom Convolutional Stem for the CTransPath model (replaces the patch embed layer).

tirank.Imageprocessing.ImageDataset

PyTorch Dataset for loading image tiles (patches).