tirank.Visualization.model_predict

tirank.Visualization.model_predict(model, data_tensor, mode)

Generates predictions from a trained model based on the specified mode.

Parameters:
  • model (torch.nn.Module) – The trained PyTorch model to use for prediction.

  • data_tensor (torch.Tensor) – The input data as a PyTorch tensor.

  • mode (str) – The operational mode, determining how to interpret the model’s output. Expected values are “Cox”, “Classification”, or “Regression”.

Returns:

A tuple containing:
  • pred_label (np.ndarray): Predicted labels. For “Classification”, these are the class indices. For “Regression” and “Cox”, this is the same as pred_prob.

  • pred_prob (np.ndarray): Predicted scores. For “Classification”, this is the probability of class 1. For “Regression” and “Cox”, these are the model’s predicted scores rather than probabilities.

Return type:

tuple
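The return convention above can be illustrated with a small NumPy-only sketch. It is not the tirank implementation. It assumes that the model output has already been converted to a NumPy array, that a “Classification” model emits one unnormalized score (logit) per class that is turned into probabilities with a softmax, and that “Cox” and “Regression” models emit one score per sample. The function name sketch_model_predict is hypothetical.

```python
import numpy as np


def sketch_model_predict(raw_output: np.ndarray, mode: str):
    """Hypothetical stand-in mirroring the documented (pred_label, pred_prob) contract.

    Assumptions (not taken from tirank):
      - "Classification": raw_output has shape (n_samples, n_classes) of logits.
      - "Cox" / "Regression": raw_output has shape (n_samples,) or (n_samples, 1).
    """
    if mode == "Classification":
        # Numerically stable softmax over the class axis (assumes logits).
        shifted = raw_output - raw_output.max(axis=1, keepdims=True)
        probs = np.exp(shifted)
        probs /= probs.sum(axis=1, keepdims=True)
        pred_label = probs.argmax(axis=1)  # class indices
        pred_prob = probs[:, 1]            # probability of class 1
    elif mode in ("Cox", "Regression"):
        pred_prob = raw_output.reshape(-1)  # one continuous score per sample
        pred_label = pred_prob              # identical to pred_prob, as documented
    else:
        raise ValueError(f"Unsupported mode: {mode!r}")
    return pred_label, pred_prob


# Usage with made-up inputs:
logits = np.array([[2.0, 0.0], [0.0, 3.0]])
labels, probs = sketch_model_predict(logits, "Classification")

scores = np.array([[0.5], [1.5]])
reg_labels, reg_scores = sketch_model_predict(scores, "Regression")
```

In the “Regression” and “Cox” branches, pred_label is the same array object as pred_prob, matching the statement that the two are identical in those modes.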