tirank.TrainPre.objective

tirank.TrainPre.objective(trial, n_features, nhead, nhid1, nhid2, n_output, nlayers, n_pred, dropout, n_patho, mode, encoder_type, train_loader_Bulk, val_loader_Bulk, train_loader_SC, adj_A, adj_B, pre_patho_labels, device, infer_mode, model_save_path)

The objective function for Optuna hyperparameter optimization.

This function defines the search space for the hyperparameters (learning rate lr, number of training epochs, and loss weights), builds a model with the sampled values, trains it, and returns the resulting validation loss.
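The four steps described above (sample hyperparameters, build a model, train it, return the validation loss) can be sketched as follows. This is a minimal, hypothetical illustration, not tirank's implementation: the parameter names (lr, n_epochs, alpha), their ranges, and the one-weight toy model are all assumptions. StubTrial is a hypothetical stand-in that mimics only suggest_float and suggest_int from optuna.trial.Trial, so the sketch runs without Optuna installed.

```python
from typing import Callable, Dict


class StubTrial:
    """Hypothetical stand-in for optuna.trial.Trial that returns preset values."""

    def __init__(self, preset: Dict[str, float]):
        self._preset = preset
        self.params: Dict[str, float] = {}  # records every value that was "suggested"

    def _take(self, name, low, high):
        value = self._preset[name]
        if not low <= value <= high:
            raise ValueError(f"{name}={value} is outside [{low}, {high}]")
        self.params[name] = value
        return value

    def suggest_float(self, name, low, high, *, log=False):
        # log only affects real sampling; a preset value ignores it.
        return float(self._take(name, low, high))

    def suggest_int(self, name, low, high):
        return int(self._take(name, low, high))


def objective_sketch(trial, build_model: Callable, train_one_epoch: Callable,
                     validate: Callable) -> float:
    # 1. Search space (hypothetical names and ranges).
    lr = trial.suggest_float("lr", 1e-4, 0.5, log=True)
    n_epochs = trial.suggest_int("n_epochs", 5, 100)
    alpha = trial.suggest_float("alpha", 0.0, 1.0)  # hypothetical loss weight
    # 2. Build a fresh model for this trial.
    model = build_model()
    # 3. Train with the sampled settings.
    for _ in range(n_epochs):
        train_one_epoch(model, lr, alpha)
    # 4. Report the validation loss back to the optimizer.
    return float(validate(model))


# Toy model: a single weight w trained on (w - 3)^2 + alpha * w^2.
def build_toy_model():
    return {"w": 0.0}


def toy_train_step(model, lr, alpha):
    w = model["w"]
    grad = 2.0 * (w - 3.0) + 2.0 * alpha * w  # gradient of the weighted training loss
    model["w"] = w - lr * grad


def toy_validate(model):
    return (model["w"] - 3.0) ** 2  # unweighted validation error


trial = StubTrial({"lr": 0.1, "n_epochs": 50, "alpha": 0.0})
val_loss = objective_sketch(trial, build_toy_model, toy_train_step, toy_validate)
```

With Optuna installed, real trials would instead come from a study, for example optuna.create_study(direction="minimize") followed by study.optimize, which calls the objective once per trial. In this toy setup, alpha=0.0 leads to a lower validation loss than alpha=0.5, which is the kind of comparison the study makes across trials.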

Parameters:
  • trial (optuna.trial.Trial) – An Optuna trial object.

  • All other arguments – Supplied by the tune_hyperparameters wrapper.

Returns:

The validation loss for the trial.

Return type:

float
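The parameter list says that every argument after trial is supplied by the tune_hyperparameters wrapper. Optuna's study.optimize calls its objective with a single trial argument, so a wrapper has to bind the fixed arguments before handing the objective over. One common way to do that is functools.partial; the sketch below assumes that approach and is not tirank's actual wrapper. The names tune_sketch, toy_objective, and RandomTrial are hypothetical, and a plain random-search loop stands in for an Optuna study so the example runs without Optuna.

```python
import functools
import random


class RandomTrial:
    """Hypothetical stand-in for optuna.trial.Trial that samples uniformly."""

    def __init__(self, rng):
        self._rng = rng
        self.params = {}

    def suggest_float(self, name, low, high):
        value = self._rng.uniform(low, high)
        self.params[name] = value
        return value


def toy_objective(trial, target, penalty):
    # Same calling shape as objective(): trial first, fixed arguments after.
    x = trial.suggest_float("x", -5.0, 5.0)
    return (x - target) ** 2 + penalty


def tune_sketch(objective, n_trials, seed=0, **fixed_args):
    """Hypothetical wrapper: bind the fixed arguments, then search."""
    rng = random.Random(seed)
    bound = functools.partial(objective, **fixed_args)  # bound(trial) -> float
    best_value, best_params = float("inf"), None
    for _ in range(n_trials):
        trial = RandomTrial(rng)
        value = bound(trial)
        if value < best_value:  # lower validation loss is better
            best_value, best_params = value, dict(trial.params)
    return best_value, best_params


best_value, best_params = tune_sketch(toy_objective, n_trials=200, seed=0,
                                      target=1.0, penalty=0.5)
```

With Optuna, the same binding is usually written as study.optimize(functools.partial(objective, ...), n_trials=...) or with a lambda that forwards trial together with the fixed arguments.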