tirank.Imageprocessing.ConvStem

class tirank.Imageprocessing.ConvStem(img_size=224, patch_size=4, in_chans=3, embed_dim=768, norm_layer=None, flatten=True)

Bases: torch.nn.Module

Custom convolutional stem for the CTransPath model. It replaces the model's patch-embedding layer.

Instead of the usual patch embedding (a single convolution whose kernel size and stride both equal the patch size), this stem builds the patch embeddings with a short stack of convolutions.
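To make the downsampling concrete, the pure-Python sketch below traces (channels, height, width) through a stem of this kind. It assumes the layout of the publicly released CTransPath stem: two 3×3 convolutions with stride 2 and padding 1 (channel widths embed_dim // 8, then embed_dim // 4), followed by a 1×1 projection to embed_dim. TiRank's exact layer configuration is not stated on this page, so treat the schedule and the helper names `conv_out` and `trace_stem_shapes` as hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a 2-D convolution along one axis."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_stem_shapes(img_size: int = 224, in_chans: int = 3, embed_dim: int = 768):
    """Return (channels, height, width) after each stage of an assumed CTransPath-style stem."""
    shapes = [(in_chans, img_size, img_size)]
    size = img_size
    out_channels = embed_dim // 8
    # Two 3x3 stride-2 convolutions (each followed by BatchNorm + ReLU in the assumed design):
    # together they reduce the resolution by 4, matching patch_size=4.
    for _ in range(2):
        size = conv_out(size, kernel=3, stride=2, padding=1)
        shapes.append((out_channels, size, size))
        out_channels *= 2
    # A final 1x1 convolution projects to embed_dim without changing the resolution.
    size = conv_out(size, kernel=1, stride=1, padding=0)
    shapes.append((embed_dim, size, size))
    return shapes


for stage in trace_stem_shapes():
    print(stage)
# (3, 224, 224)
# (96, 112, 112)
# (192, 56, 56)
# (768, 56, 56)
```

Under these assumptions the stem reaches the same 56 × 56 grid as a 4 × 4 patchify convolution would. It gets there gradually, with non-linearities between the steps.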

Parameters:
  • img_size (int, optional) – The size of the input image. Defaults to 224.

  • patch_size (int, optional) – The size of the patch. Must be 4. Defaults to 4.

  • in_chans (int, optional) – Number of input image channels. Defaults to 3.

  • embed_dim (int, optional) – The dimension of the output embedding. Defaults to 768.

  • norm_layer (nn.Module, optional) – Normalization layer to use. Defaults to None.

  • flatten (bool, optional) – Whether to flatten the spatial dimensions of the output into a sequence of patch tokens. Defaults to True.
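How these parameters fit together can be illustrated with a PyTorch sketch of such a stem. It is modeled on the stem released with CTransPath (two 3×3 stride-2 convolutions with BatchNorm and ReLU, then a 1×1 projection). It is not TiRank's source, and the class name `ConvStemSketch` is hypothetical.

```python
import torch
from torch import nn


class ConvStemSketch(nn.Module):
    """Hypothetical stem modeled on the public CTransPath design (not TiRank's code)."""

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=768,
                 norm_layer=None, flatten=True):
        super().__init__()
        # Two stride-2 convolutions give an overall stride of 4, so only patch_size=4 works.
        assert patch_size == 4, "this stem only supports patch_size=4"
        assert embed_dim % 8 == 0, "embed_dim must be divisible by 8"
        self.grid_size = (img_size // patch_size, img_size // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        layers = []
        in_dim, out_dim = in_chans, embed_dim // 8
        for _ in range(2):
            layers += [
                nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU(inplace=True),
            ]
            in_dim, out_dim = out_dim, out_dim * 2
        # The 1x1 convolution projects the final feature map to embed_dim channels.
        layers.append(nn.Conv2d(in_dim, embed_dim, kernel_size=1))
        self.proj = nn.Sequential(*layers)
        # norm_layer is called with embed_dim and acts on the last dimension,
        # so it suits the flattened (B, N, embed_dim) layout.
        self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)  # (B, embed_dim, H/4, W/4)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # (B, N, embed_dim), N = (H/4) * (W/4)
        return self.norm(x)


stem = ConvStemSketch(embed_dim=96, norm_layer=nn.LayerNorm)
tokens = stem(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 3136, 96])
```

Passing `nn.LayerNorm` (a class, not an instance) mirrors how the sketch instantiates `norm_layer(embed_dim)` internally. That calling convention is an assumption borrowed from the CTransPath code.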

forward(x)

Forward pass of the convolutional stem.

Parameters:

x (torch.Tensor) – Input tensor of shape (B, C, H, W).

Returns:

Output patch embeddings.

Return type:

torch.Tensor
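The exact output shape depends on the input size, embed_dim, and flatten. The hypothetical helper `expected_output_shape` below predicts it. It assumes the CTransPath convention: an overall stride of 4, with flatten turning (B, embed_dim, H/4, W/4) into (B, N, embed_dim).

```python
def expected_output_shape(input_shape, patch_size=4, embed_dim=768, flatten=True):
    """Hypothetical helper: predicted forward() output shape for a (B, C, H, W) input."""
    batch, _channels, height, width = input_shape
    # Assumes H and W are multiples of patch_size (true for the default 224 x 224 input).
    grid_h, grid_w = height // patch_size, width // patch_size
    if flatten:
        # One token per 4 x 4 patch: (B, N, embed_dim) with N = grid_h * grid_w.
        return (batch, grid_h * grid_w, embed_dim)
    # Without flattening, the spatial grid is kept: (B, embed_dim, grid_h, grid_w).
    return (batch, embed_dim, grid_h, grid_w)


print(expected_output_shape((8, 3, 224, 224)))                 # (8, 3136, 768)
print(expected_output_shape((8, 3, 224, 224), flatten=False))  # (8, 768, 56, 56)
```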