tirank.TrainPre.objective
- tirank.TrainPre.objective(trial, n_features, nhead, nhid1, nhid2, n_output, nlayers, n_pred, dropout, n_patho, mode, encoder_type, train_loader_Bulk, val_loader_Bulk, train_loader_SC, adj_A, adj_B, pre_patho_labels, device, infer_mode, model_save_path)
The objective function for Optuna hyperparameter optimization.
This function defines the search space for hyperparameters (lr, epochs, loss weights), builds a model, trains it, and returns the validation loss.
- Parameters:
trial (optuna.trial.Trial) – An Optuna trial object.
All other args – Passed from the tune_hyperparameters wrapper.
- Returns:
The validation loss for the trial.
- Return type:
float
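The general shape of an Optuna objective like this one can be illustrated with a small, self-contained sketch. It is not TiRank's implementation. The hyperparameter names and ranges ("lr", "n_epochs", "alpha") are hypothetical, the "training" step is replaced by a toy deterministic loss, and RandomTrial is a hypothetical stand-in for the subset of optuna.trial.Trial used here, so the example runs without Optuna installed. It also shows how a wrapper can bind the fixed arguments once with functools.partial so that the objective is called with only the trial.

```python
import functools
import math
import random


class RandomTrial:
    """Hypothetical stand-in for the part of optuna.trial.Trial used below.

    It samples uniformly (log-uniformly when log=True) and records each value
    in self.params, loosely mirroring suggest_float/suggest_int in Optuna.
    """

    def __init__(self, number, rng):
        self.number = number
        self.params = {}
        self._rng = rng

    def suggest_float(self, name, low, high, log=False):
        if log:
            value = math.exp(self._rng.uniform(math.log(low), math.log(high)))
        else:
            value = self._rng.uniform(low, high)
        self.params[name] = value
        return value

    def suggest_int(self, name, low, high):
        value = self._rng.randint(low, high)  # both ends inclusive
        self.params[name] = value
        return value


def toy_objective(trial, n_features, device):
    # n_features and device stand in for the arguments forwarded by the
    # wrapper; this toy objective does not use them.
    # Search space (hypothetical names and ranges).
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    n_epochs = trial.suggest_int("n_epochs", 10, 50)
    alpha = trial.suggest_float("alpha", 0.0, 1.0)  # e.g. a loss weight

    # Placeholder for "build a model, train it, evaluate on validation data":
    # a deterministic score that is smallest near lr=1e-3 and alpha=0.5.
    val_loss = (math.log10(lr) + 3) ** 2 + (alpha - 0.5) ** 2 + 1.0 / n_epochs
    return float(val_loss)


def tune(n_trials, seed=0, **fixed_args):
    """Bind the fixed arguments once, then evaluate the objective per trial."""
    objective = functools.partial(toy_objective, **fixed_args)
    rng = random.Random(seed)
    results = []
    for i in range(n_trials):
        trial = RandomTrial(i, rng)
        results.append((objective(trial), trial.params))
    # A lower validation loss is better, so keep the minimum.
    return min(results, key=lambda r: r[0])
```

For example, tune(20, n_features=100, device="cpu") returns the lowest loss found and the parameters that produced it. With Optuna itself, the same bound callable would typically be passed to optuna.create_study(direction="minimize").optimize(..., n_trials=...), and the best values read from study.best_params.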