tirank.Visualization.model_predict
- tirank.Visualization.model_predict(model, data_tensor, mode)
Generates predictions from a trained model based on the specified mode.
- Parameters:
model (torch.nn.Module) – The trained PyTorch model to use for prediction.
data_tensor (torch.Tensor) – The input data as a PyTorch tensor.
mode (str) – The operational mode, determining how to interpret the model’s output. Expected values are “Cox”, “Classification”, or “Regression”.
- Returns:
- A tuple containing:
pred_label (np.ndarray): Predicted labels. For “Classification”, these are the class indices. For “Regression” and “Cox”, this is the same as pred_prob.
pred_prob (np.ndarray): Predicted probability scores. For “Classification”, this is the probability of class 1.
- Return type:
tuple
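
The return contract described above can be illustrated with a small stand-in. This is only a sketch, not TiRank's implementation: `model_predict_sketch` is a hypothetical function, the `nn.Linear` models are untrained toys, and it assumes that a classification model emits one logit per class and that "Cox" and "Regression" models emit a single score per sample. The real function may post-process model outputs differently.

```python
import numpy as np
import torch
import torch.nn as nn


def model_predict_sketch(model, data_tensor, mode):
    """Hypothetical stand-in that mirrors the documented return contract."""
    if mode not in ("Cox", "Classification", "Regression"):
        raise ValueError(f"Unsupported mode: {mode!r}")

    model.eval()  # switch off dropout / batch-norm updates
    with torch.no_grad():  # gradients are not needed for inference
        output = model(data_tensor)

    if mode == "Classification":
        # Assumption: the model returns one logit per class (two classes here).
        probs = torch.softmax(output, dim=1)
        pred_label = probs.argmax(dim=1).numpy()  # class indices
        pred_prob = probs[:, 1].numpy()           # probability of class 1
    else:
        # Assumption: "Cox" and "Regression" yield one score per sample.
        # As documented, pred_label is the same as pred_prob in these modes.
        pred_prob = output.reshape(-1).numpy()
        pred_label = pred_prob
    return pred_label, pred_prob


torch.manual_seed(0)
features = torch.randn(5, 4)  # toy data: 5 samples, 4 features

classifier = nn.Linear(4, 2)  # untrained toy classifier
labels, probs = model_predict_sketch(classifier, features, "Classification")

regressor = nn.Linear(4, 1)   # untrained toy regressor
reg_labels, reg_scores = model_predict_sketch(regressor, features, "Regression")
```

With the actual library, the equivalent call would be `pred_label, pred_prob = model_predict(model, data_tensor, mode)`, with `mode` set to one of the three documented strings.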