diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index ecb5fe69..4ca3cd0f 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -80,7 +80,7 @@ def print_full_help(): "usage: flexynesis [-h] --data_path DATA_PATH --model_class " "{DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN," "RandomForest,SVM,XGBoost,RandomSurvivalForest} " - "[--gnn_conv_type {GC,GCN,SAGE}] [--target_variables TARGET_VARIABLES] " + "[--gnn_conv_type {GC,GCN,SAGE,GAT}] [--target_variables TARGET_VARIABLES] " "[--covariates COVARIATES] [--surv_event_var SURV_EVENT_VAR] " "[--surv_time_var SURV_TIME_VAR] [--config_path CONFIG_PATH] " "[--fusion_type {early,intermediate}] [--hpo_iter HPO_ITER] " @@ -130,7 +130,7 @@ def print_full_help(): "CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" ) print(" (Required) The kind of model class to instantiate") - print(" --gnn_conv_type {GC,GCN,SAGE}") + print(" --gnn_conv_type {GC,GCN,SAGE,GAT}") print( " If model_class is set to GNN, choose which graph " "convolution type to use" @@ -337,7 +337,7 @@ def main(): **Required** for training. --gnn_conv_type (str): - If `--model_class=GNN`, choose graph convolution: ["GC", "GCN", "SAGE"]. + If `--model_class=GNN`, choose graph convolution: ["GC", "GCN", "SAGE", "GAT"]. --target_variables (str): Comma-separated target variables from `clin.csv`. Optional if survival @@ -534,7 +534,7 @@ def main(): parser.add_argument( "--gnn_conv_type", type=str, - choices=["GC", "GCN", "SAGE"], + choices=["GC", "GCN", "SAGE", "GAT"], help="If model_class is set to GNN, choose which graph convolution type to use", ) parser.add_argument( @@ -1047,7 +1047,7 @@ def main(): if args.model_class == "GNN": if not args.gnn_conv_type: warnings.warn( - "\n\n!!! When running GNN, set --gnn_conv_type (GC/GCN/SAGE). " + "\n\n!!! When running GNN, set --gnn_conv_type (GC/GCN/SAGE/GAT). " "Falling back to GC !!!\n" ) time.sleep(3)