From 36bafaf4f157cef007d1d8f026548f5a820f8469 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 16:27:51 +0200 Subject: [PATCH 01/32] use black to reformat the scripts --- flexynesis/__init__.py | 47 +- flexynesis/__main__.py | 1176 ++++++++++++++----- flexynesis/config.py | 76 +- flexynesis/data.py | 754 ++++++++---- flexynesis/feature_selection.py | 139 ++- flexynesis/generate_coexpression_network.py | 194 +-- flexynesis/inference.py | 99 +- flexynesis/main.py | 445 ++++--- flexynesis/models/__init__.py | 9 +- flexynesis/models/crossmodal_pred.py | 472 +++++--- flexynesis/models/direct_pred.py | 351 ++++-- flexynesis/models/gnn_early.py | 365 ++++-- flexynesis/models/supervised_vae.py | 457 ++++--- flexynesis/models/triplet_encoder.py | 477 +++++--- flexynesis/modules.py | 141 ++- flexynesis/utils.py | 980 ++++++++++------ tests/test_mps_device.py | 38 +- 17 files changed, 4140 insertions(+), 2080 deletions(-) diff --git a/flexynesis/__init__.py b/flexynesis/__init__.py index eba3fd9d..fe80b0bf 100644 --- a/flexynesis/__init__.py +++ b/flexynesis/__init__.py @@ -13,52 +13,60 @@ except PackageNotFoundError: # Happens in some dev scenarios before the package is installed __version__ = "0+unknown" - + + class LazyModule: """Lazy module that only imports when accessed.""" - + def __init__(self, module_name): self._module_name = module_name self._module = None self._import_error = None - + def _import_module(self): if self._module is None and self._import_error is None: try: import importlib - self._module = importlib.import_module(f'.{self._module_name}', package=__name__) + + self._module = importlib.import_module( + f".{self._module_name}", package=__name__ + ) except ImportError as e: self._import_error = e - raise ImportError(f"Failed to import {self._module_name} module: {e}. " - f"This usually means some dependencies are missing. " - f"Try installing required packages or check your environment.") + raise ImportError( + f"Failed to import {self._module_name} module: {e}. " + f"This usually means some dependencies are missing. " + f"Try installing required packages or check your environment." + ) elif self._import_error is not None: raise self._import_error return self._module - + def __getattr__(self, name): module = self._import_module() return getattr(module, name) - + def __dir__(self): if self._module is None: return [] module = self._import_module() return dir(module) - + def __repr__(self): if self._module is None: return f"" else: return f"" + # Create lazy module proxies - these are NOT imported yet -modules = LazyModule('modules') -data = LazyModule('data') -main = LazyModule('main') -models = LazyModule('models') -feature_selection = LazyModule('feature_selection') -utils = LazyModule('utils') +modules = LazyModule("modules") +data = LazyModule("data") +main = LazyModule("main") +models = LazyModule("models") +feature_selection = LazyModule("feature_selection") +utils = LazyModule("utils") + # Import commonly used classes directly for easy access # These will be imported lazily when first accessed @@ -66,19 +74,20 @@ def _get_data_importer(): """Lazy getter for DataImporter class.""" return data.DataImporter + def _get_models(): """Lazy getter for model classes.""" return models + # Export all modules and commonly used classes __all__ = [ "search_spaces", "modules", - "data", + "data", "main", "models", "feature_selection", "utils", - "DataImporter" + "DataImporter", ] - diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index a5d9c04f..f6cc00b2 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -16,46 +16,68 @@ def print_test_installation(): print("Test Installation:") print(" # Download and extract test dataset") - print(" curl -L -o dataset1.tgz https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis-benchmark-datasets/dataset1.tgz") + print( + " curl -L -o dataset1.tgz https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis-benchmark-datasets/dataset1.tgz" + ) print(" tar -xzvf dataset1.tgz") print() print(" # Test the installation (should finish within a minute on a typical CPU)") - print(" flexynesis --data_path dataset1 --model_class DirectPred --target_variables Erlotinib --hpo_iter 1 --features_top_percentile 5 --data_types gex,cnv") + print( + " flexynesis --data_path dataset1 --model_class DirectPred --target_variables Erlotinib --hpo_iter 1 --features_top_percentile 5 --data_types gex,cnv" + ) def print_help(): - print("usage: flexynesis [-h] --data_path DATA_PATH --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest} --data_types DATA_TYPES") + print( + "usage: flexynesis [-h] --data_path DATA_PATH --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest} --data_types DATA_TYPES" + ) print() print("Flexynesis model training interface") print() print("options:") print(" -h, --help show complete help with all options") print(" --data_path DATA_PATH") - print(" (Required) Path to the folder with train/test data files") - print(" --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}") + print( + " (Required) Path to the folder with train/test data files" + ) + print( + " --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" + ) print(" (Required) The kind of model class to instantiate") print(" --data_types DATA_TYPES") - print(" (Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'") - print(" --hpo_iter HPO_ITER Number of iterations for hyperparameter optimisation (default: 100)") + print( + " (Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'" + ) + print( + " --hpo_iter HPO_ITER Number of iterations for hyperparameter optimisation (default: 100)" + ) print(" --device {auto,cuda,mps,cpu}") - print(" Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)") - print(" --use_gpu (Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available.") + print( + " Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)" + ) + print( + " --use_gpu (Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available." + ) print() print_test_installation() print() - print(" See the documentation for more details at https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/getting_started/.") + print( + " See the documentation for more details at https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/getting_started/." + ) def print_full_help(): - print("usage: flexynesis [-h] --data_path DATA_PATH --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest} " - "[--gnn_conv_type {GC,GCN,SAGE}] [--target_variables TARGET_VARIABLES] [--covariates COVARIATES] [--surv_event_var SURV_EVENT_VAR] [--surv_time_var SURV_TIME_VAR] " - "[--config_path CONFIG_PATH] [--fusion_type {early,intermediate}] [--hpo_iter HPO_ITER] [--finetuning_samples FINETUNING_SAMPLES] " - "[--variance_threshold VARIANCE_THRESHOLD] [--correlation_threshold CORRELATION_THRESHOLD] [--restrict_to_features RESTRICT_TO_FEATURES] " - "[--subsample SUBSAMPLE] [--features_min FEATURES_MIN] [--features_top_percentile FEATURES_TOP_PERCENTILE] --data_types DATA_TYPES " - "[--input_layers INPUT_LAYERS] [--output_layers OUTPUT_LAYERS] [--outdir OUTDIR] [--prefix PREFIX] [--log_transform {True,False}] " - "[--early_stop_patience EARLY_STOP_PATIENCE] [--hpo_patience HPO_PATIENCE] [--val_size VAL_SIZE] [--use_cv] [--use_loss_weighting {True,False}] " - "[--evaluate_baseline_performance] [--threads THREADS] [--num_workers NUM_WORKERS] [--device {auto,cuda,mps,cpu}] [--use_gpu] [--feature_importance_method {IntegratedGradients,GradientShap,Both}] " - "[--disable_marker_finding] [--string_organism STRING_ORGANISM] [--string_node_name {gene_name,gene_id}] [--user_graph USER_GRAPH] [--safetensors]") + print( + "usage: flexynesis [-h] --data_path DATA_PATH --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest} " + "[--gnn_conv_type {GC,GCN,SAGE}] [--target_variables TARGET_VARIABLES] [--covariates COVARIATES] [--surv_event_var SURV_EVENT_VAR] [--surv_time_var SURV_TIME_VAR] " + "[--config_path CONFIG_PATH] [--fusion_type {early,intermediate}] [--hpo_iter HPO_ITER] [--finetuning_samples FINETUNING_SAMPLES] " + "[--variance_threshold VARIANCE_THRESHOLD] [--correlation_threshold CORRELATION_THRESHOLD] [--restrict_to_features RESTRICT_TO_FEATURES] " + "[--subsample SUBSAMPLE] [--features_min FEATURES_MIN] [--features_top_percentile FEATURES_TOP_PERCENTILE] --data_types DATA_TYPES " + "[--input_layers INPUT_LAYERS] [--output_layers OUTPUT_LAYERS] [--outdir OUTDIR] [--prefix PREFIX] [--log_transform {True,False}] " + "[--early_stop_patience EARLY_STOP_PATIENCE] [--hpo_patience HPO_PATIENCE] [--val_size VAL_SIZE] [--use_cv] [--use_loss_weighting {True,False}] " + "[--evaluate_baseline_performance] [--threads THREADS] [--num_workers NUM_WORKERS] [--device {auto,cuda,mps,cpu}] [--use_gpu] [--feature_importance_method {IntegratedGradients,GradientShap,Both}] " + "[--disable_marker_finding] [--string_organism STRING_ORGANISM] [--string_node_name {gene_name,gene_id}] [--user_graph USER_GRAPH] [--safetensors]" + ) print() print("Flexynesis model training interface") print() @@ -63,7 +85,9 @@ def print_full_help(): # --- NEW: inference-only flags (keep in full help) --- print(" --pretrained_model PRETRAINED_MODEL") - print(" Use a saved .pth/.safetensors model for inference (skip training)") + print( + " Use a saved .pth/.safetensors model for inference (skip training)" + ) print(" --artifacts ARTIFACTS") print(" Path to training-time artifacts .joblib or .json") print(" --data_path_test DATA_PATH_TEST") @@ -73,83 +97,159 @@ def print_full_help(): # --- existing flags (keep full list) --- print(" -h, --help show this help message and exit") print(" --data_path DATA_PATH") - print(" (Required) Path to the folder with train/test data files") - print(" --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}") + print( + " (Required) Path to the folder with train/test data files" + ) + print( + " --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" + ) print(" (Required) The kind of model class to instantiate") print(" --gnn_conv_type {GC,GCN,SAGE}") - print(" If model_class is set to GNN, choose which graph convolution type to use") + print( + " If model_class is set to GNN, choose which graph convolution type to use" + ) print(" --target_variables TARGET_VARIABLES") - print(" (Optional if survival variables are not set to None). Which variables in 'clin.csv' to use for predictions, comma-separated if multiple") + print( + " (Optional if survival variables are not set to None). Which variables in 'clin.csv' to use for predictions, comma-separated if multiple" + ) print(" --covariates COVARIATES") - print(" Which variables in 'clin.csv' to be used as feature covariates, comma-separated if multiple") + print( + " Which variables in 'clin.csv' to be used as feature covariates, comma-separated if multiple" + ) print(" --surv_event_var SURV_EVENT_VAR") - print(" Which column in 'clin.csv' to use as event/status indicator for survival modeling") + print( + " Which column in 'clin.csv' to use as event/status indicator for survival modeling" + ) print(" --surv_time_var SURV_TIME_VAR") - print(" Which column in 'clin.csv' to use as time/duration indicator for survival modeling") + print( + " Which column in 'clin.csv' to use as time/duration indicator for survival modeling" + ) print(" --config_path CONFIG_PATH") - print(" Optional path to an external hyperparameter configuration file in YAML format.") + print( + " Optional path to an external hyperparameter configuration file in YAML format." + ) print(" --fusion_type {early,intermediate}") - print(" How to fuse the omics layers (default: intermediate)") - print(" --hpo_iter HPO_ITER Number of iterations for hyperparameter optimisation (default: 100)") + print( + " How to fuse the omics layers (default: intermediate)" + ) + print( + " --hpo_iter HPO_ITER Number of iterations for hyperparameter optimisation (default: 100)" + ) print(" --finetuning_samples FINETUNING_SAMPLES") - print(" Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning (default: 0)") + print( + " Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning (default: 0)" + ) print(" --variance_threshold VARIANCE_THRESHOLD") - print(" Variance threshold (as percentile) to drop low variance features (default is 1; set to 0 for no variance filtering)") + print( + " Variance threshold (as percentile) to drop low variance features (default is 1; set to 0 for no variance filtering)" + ) print(" --correlation_threshold CORRELATION_THRESHOLD") - print(" Correlation threshold to drop highly redundant features (default is 0.8; set to 1 for no redundancy filtering)") + print( + " Correlation threshold to drop highly redundant features (default is 0.8; set to 1 for no redundancy filtering)" + ) print(" --restrict_to_features RESTRICT_TO_FEATURES") - print(" Restrict the analysis to the list of features provided by the user (default is None)") + print( + " Restrict the analysis to the list of features provided by the user (default is None)" + ) print(" --subsample SUBSAMPLE") - print(" Downsample training set to randomly drawn N samples for training. Disabled when set to 0 (default: 0)") + print( + " Downsample training set to randomly drawn N samples for training. Disabled when set to 0 (default: 0)" + ) print(" --features_min FEATURES_MIN") - print(" Minimum number of features to retain after feature selection (default: 500)") + print( + " Minimum number of features to retain after feature selection (default: 500)" + ) print(" --features_top_percentile FEATURES_TOP_PERCENTILE") - print(" Top percentile features (among the features remaining after variance filtering and data cleanup) to retain after feature selection (default: 20)") + print( + " Top percentile features (among the features remaining after variance filtering and data cleanup) to retain after feature selection (default: 20)" + ) print(" --data_types DATA_TYPES") - print(" (Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'") + print( + " (Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'" + ) print(" --input_layers INPUT_LAYERS") - print(" If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers. Comma-separated if multiple") + print( + " If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers. Comma-separated if multiple" + ) print(" --output_layers OUTPUT_LAYERS") - print(" If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers. Comma-separated if multiple") - print(" --outdir OUTDIR Path to the output folder to save the model outputs (default: current working directory)") + print( + " If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers. Comma-separated if multiple" + ) + print( + " --outdir OUTDIR Path to the output folder to save the model outputs (default: current working directory)" + ) print(" --prefix PREFIX Job prefix to use for output files (default: 'job')") print(" --log_transform {True,False}") - print(" whether to apply log-transformation to input data matrices (default: False)") + print( + " whether to apply log-transformation to input data matrices (default: False)" + ) print(" --early_stop_patience EARLY_STOP_PATIENCE") - print(" How many epochs to wait when no improvements in validation loss is observed (default 10; set to -1 to disable early stopping)") + print( + " How many epochs to wait when no improvements in validation loss is observed (default 10; set to -1 to disable early stopping)" + ) print(" --hpo_patience HPO_PATIENCE") - print(" How many hyperparameter optimisation iterations to wait for when no improvements are observed (default is 10; set to 0 to disable early stopping)") - print(" --val_size VAL_SIZE Proportion of training data to be used as validation split (default: 0.2)") - print(" --use_cv (Optional) If set, a 5-fold cross-validation training will be done. Otherwise, a single training on 80 percent of the dataset is done.") + print( + " How many hyperparameter optimisation iterations to wait for when no improvements are observed (default is 10; set to 0 to disable early stopping)" + ) + print( + " --val_size VAL_SIZE Proportion of training data to be used as validation split (default: 0.2)" + ) + print( + " --use_cv (Optional) If set, a 5-fold cross-validation training will be done. Otherwise, a single training on 80 percent of the dataset is done." + ) print(" --use_loss_weighting {True,False}") - print(" whether to apply loss-balancing using uncertainty weights method (default: True)") + print( + " whether to apply loss-balancing using uncertainty weights method (default: True)" + ) print(" --evaluate_baseline_performance") - print(" whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset") - print(" --threads THREADS (Optional) How many threads to use when using CPU (default is 4)") + print( + " whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset" + ) + print( + " --threads THREADS (Optional) How many threads to use when using CPU (default is 4)" + ) print(" --num_workers NUM_WORKERS") - print(" (Optional) How many workers to use for model training (default is 0)") + print( + " (Optional) How many workers to use for model training (default is 0)" + ) print(" --device {auto,cuda,mps,cpu}") - print(" Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)") - print(" --use_gpu (Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available.") + print( + " Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)" + ) + print( + " --use_gpu (Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available." + ) print(" --feature_importance_method {IntegratedGradients,GradientShap,Both}") - print(" Choose feature importance score method (default: IntegratedGradients)") + print( + " Choose feature importance score method (default: IntegratedGradients)" + ) print(" --disable_marker_finding") - print(" (Optional) If set, marker discovery after model training is disabled.") + print( + " (Optional) If set, marker discovery after model training is disabled." + ) print(" --string_organism STRING_ORGANISM") print(" STRING DB organism id. (default: 9606)") print(" --string_node_name {gene_name,gene_id}") print(" Type of node name. (default: gene_name)") print(" --user_graph USER_GRAPH") - print(" Path to user-provided gene-gene interaction network file.") + print( + " Path to user-provided gene-gene interaction network file." + ) print(" Must have at least 3 columns: GeneA, GeneB, Score.") - print(" If provided, this will be used instead of STRING DB.") + print( + " If provided, this will be used instead of STRING DB." + ) print(" --safetensors") - print(" If set, the model will be saved in the SafeTensors format and the artifacts saved as JSON.") + print( + " If set, the model will be saved in the SafeTensors format and the artifacts saved as JSON." + ) print(" Default is False.") print() print_test_installation() print() - print(" See the documentation for more details at https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/getting_started/.") + print( + " See the documentation for more details at https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/getting_started/." + ) def main(): @@ -327,112 +427,289 @@ def main(): if len(sys.argv) == 1: print_help() return - if any(arg in ['-h', '--help'] for arg in sys.argv): + if any(arg in ["-h", "--help"] for arg in sys.argv): print_full_help() return # ------------- Parser (lightweight) ------------- parser = argparse.ArgumentParser( description="Flexynesis model training interface", - formatter_class=argparse.ArgumentDefaultsHelpFormatter + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("-v", "--version", action="version", version=f"%(prog)s {__version__}") + parser.add_argument( + "-v", "--version", action="version", version=f"%(prog)s {__version__}" + ) # Existing core flags (made not-required here; enforced conditionally below) - parser.add_argument("--data_path", type=str, required=False, - help="Path to the folder with train/test data files") - parser.add_argument("--model_class", type=str, required=False, - choices=["DirectPred", "supervised_vae", "MultiTripletNetwork", "CrossModalPred", "GNN", "RandomForest", "SVM", "XGBoost", "RandomSurvivalForest"], - help="The kind of model class to instantiate") - parser.add_argument("--gnn_conv_type", type=str, choices=["GC", "GCN", "SAGE"], - help="If model_class is set to GNN, choose which graph convolution type to use") - parser.add_argument("--target_variables", type=str, default=None, - help="(Optional if survival variables are not set to None). Which variables in 'clin.csv' to use for predictions, comma-separated if multiple") - parser.add_argument("--covariates", type=str, default=None, - help="Which variables in 'clin.csv' to be used as feature covariates, comma-separated if multiple") - parser.add_argument("--surv_event_var", type=str, default=None, - help="Which column in 'clin.csv' to use as event/status indicator for survival modeling") - parser.add_argument("--surv_time_var", type=str, default=None, - help="Which column in 'clin.csv' to use as time/duration indicator for survival modeling") - parser.add_argument('--config_path', type=str, default=None, - help='Optional path to an external hyperparameter configuration file in YAML format.') - parser.add_argument("--fusion_type", type=str, choices=["early", "intermediate"], default='intermediate', - help="How to fuse the omics layers") - parser.add_argument("--hpo_iter", type=int, default=100, - help="Number of iterations for hyperparameter optimisation") - parser.add_argument("--finetuning_samples", type=int, default=0, - help="Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning") - parser.add_argument("--variance_threshold", type=float, default=1, - help="Variance threshold (as percentile) to drop low variance features (default is 1; set to 0 for no variance filtering)") - parser.add_argument("--correlation_threshold", type=float, default=0.8, - help="Correlation threshold to drop highly redundant features (default is 0.8; set to 1 for no redundancy filtering)") - parser.add_argument("--restrict_to_features", type=str, default=None, - help="Restrict the analyis to the list of features provided by the user (default is None)") - parser.add_argument("--subsample", type=int, default=0, - help="Downsample training set to randomly drawn N samples for training. Disabled when set to 0") - parser.add_argument("--features_min", type=int, default=500, - help="Minimum number of features to retain after feature selection") - parser.add_argument("--features_top_percentile", type=float, default=20, - help="Top percentile features (among the features remaining after variance filtering and data cleanup) to retain after feature selection") - parser.add_argument("--data_types", type=str, required=False, - help="Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'") - parser.add_argument("--input_layers", type=str, default=None, - help="If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers. Comma-separated if multiple") - parser.add_argument("--output_layers", type=str, default=None, - help="If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers. Comma-separated if multiple") - parser.add_argument("--outdir", type=str, default=os.getcwd(), - help="Path to the output folder to save the model outputs") - parser.add_argument("--prefix", type=str, default='job', - help="Job prefix to use for output files") - parser.add_argument("--log_transform", type=str, choices=['True', 'False'], default='False', - help="whether to apply log-transformation to input data matrices") - parser.add_argument("--early_stop_patience", type=int, default=10, - help="How many epochs to wait when no improvements in validation loss is observed (default 10; set to -1 to disable early stopping)") - parser.add_argument("--hpo_patience", type=int, default=20, - help="How many hyperparamater optimisation iterations to wait for when no improvements are observed (default is 10; set to 0 to disable early stopping)") - parser.add_argument("--val_size", type=float, default=0.2, - help="Proportion of training data to be used as validation split (default: 0.2)") - parser.add_argument("--use_cv", action="store_true", - help="(Optional) If set, the a 5-fold cross-validation training will be done. Otherwise, a single trainig on 80 percent of the dataset is done.") - parser.add_argument("--use_loss_weighting", type=str, choices=['True', 'False'], default='True', - help="whether to apply loss-balancing using uncertainty weights method") - parser.add_argument("--evaluate_baseline_performance", action="store_true", - help="whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset") - parser.add_argument("--threads", type=int, default=4, - help="(Optional) How many threads to use when using CPU (default is 4)") - parser.add_argument("--num_workers", type=int, default=0, - help="(Optional) How many workers to use for model training (default is 0)") - parser.add_argument("--use_gpu", action="store_true", - help="(Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available.") - parser.add_argument("--device", type=str, - choices=["auto", "cuda", "mps", "cpu"], default="auto", - help="Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu'") - parser.add_argument("--feature_importance_method", type=str, - choices=["IntegratedGradients", "GradientShap", "Both"], default="IntegratedGradients", - help="Choose feature importance score method") - parser.add_argument("--disable_marker_finding", action="store_true", - help="(Optional) If set, marker discovery after model training is disabled.") + parser.add_argument( + "--data_path", + type=str, + required=False, + help="Path to the folder with train/test data files", + ) + parser.add_argument( + "--model_class", + type=str, + required=False, + choices=[ + "DirectPred", + "supervised_vae", + "MultiTripletNetwork", + "CrossModalPred", + "GNN", + "RandomForest", + "SVM", + "XGBoost", + "RandomSurvivalForest", + ], + help="The kind of model class to instantiate", + ) + parser.add_argument( + "--gnn_conv_type", + type=str, + choices=["GC", "GCN", "SAGE"], + help="If model_class is set to GNN, choose which graph convolution type to use", + ) + parser.add_argument( + "--target_variables", + type=str, + default=None, + help="(Optional if survival variables are not set to None). Which variables in 'clin.csv' to use for predictions, comma-separated if multiple", + ) + parser.add_argument( + "--covariates", + type=str, + default=None, + help="Which variables in 'clin.csv' to be used as feature covariates, comma-separated if multiple", + ) + parser.add_argument( + "--surv_event_var", + type=str, + default=None, + help="Which column in 'clin.csv' to use as event/status indicator for survival modeling", + ) + parser.add_argument( + "--surv_time_var", + type=str, + default=None, + help="Which column in 'clin.csv' to use as time/duration indicator for survival modeling", + ) + parser.add_argument( + "--config_path", + type=str, + default=None, + help="Optional path to an external hyperparameter configuration file in YAML format.", + ) + parser.add_argument( + "--fusion_type", + type=str, + choices=["early", "intermediate"], + default="intermediate", + help="How to fuse the omics layers", + ) + parser.add_argument( + "--hpo_iter", + type=int, + default=100, + help="Number of iterations for hyperparameter optimisation", + ) + parser.add_argument( + "--finetuning_samples", + type=int, + default=0, + help="Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning", + ) + parser.add_argument( + "--variance_threshold", + type=float, + default=1, + help="Variance threshold (as percentile) to drop low variance features (default is 1; set to 0 for no variance filtering)", + ) + parser.add_argument( + "--correlation_threshold", + type=float, + default=0.8, + help="Correlation threshold to drop highly redundant features (default is 0.8; set to 1 for no redundancy filtering)", + ) + parser.add_argument( + "--restrict_to_features", + type=str, + default=None, + help="Restrict the analyis to the list of features provided by the user (default is None)", + ) + parser.add_argument( + "--subsample", + type=int, + default=0, + help="Downsample training set to randomly drawn N samples for training. Disabled when set to 0", + ) + parser.add_argument( + "--features_min", + type=int, + default=500, + help="Minimum number of features to retain after feature selection", + ) + parser.add_argument( + "--features_top_percentile", + type=float, + default=20, + help="Top percentile features (among the features remaining after variance filtering and data cleanup) to retain after feature selection", + ) + parser.add_argument( + "--data_types", + type=str, + required=False, + help="Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'", + ) + parser.add_argument( + "--input_layers", + type=str, + default=None, + help="If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers. Comma-separated if multiple", + ) + parser.add_argument( + "--output_layers", + type=str, + default=None, + help="If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers. Comma-separated if multiple", + ) + parser.add_argument( + "--outdir", + type=str, + default=os.getcwd(), + help="Path to the output folder to save the model outputs", + ) + parser.add_argument( + "--prefix", type=str, default="job", help="Job prefix to use for output files" + ) + parser.add_argument( + "--log_transform", + type=str, + choices=["True", "False"], + default="False", + help="whether to apply log-transformation to input data matrices", + ) + parser.add_argument( + "--early_stop_patience", + type=int, + default=10, + help="How many epochs to wait when no improvements in validation loss is observed (default 10; set to -1 to disable early stopping)", + ) + parser.add_argument( + "--hpo_patience", + type=int, + default=20, + help="How many hyperparamater optimisation iterations to wait for when no improvements are observed (default is 10; set to 0 to disable early stopping)", + ) + parser.add_argument( + "--val_size", + type=float, + default=0.2, + help="Proportion of training data to be used as validation split (default: 0.2)", + ) + parser.add_argument( + "--use_cv", + action="store_true", + help="(Optional) If set, the a 5-fold cross-validation training will be done. Otherwise, a single trainig on 80 percent of the dataset is done.", + ) + parser.add_argument( + "--use_loss_weighting", + type=str, + choices=["True", "False"], + default="True", + help="whether to apply loss-balancing using uncertainty weights method", + ) + parser.add_argument( + "--evaluate_baseline_performance", + action="store_true", + help="whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset", + ) + parser.add_argument( + "--threads", + type=int, + default=4, + help="(Optional) How many threads to use when using CPU (default is 4)", + ) + parser.add_argument( + "--num_workers", + type=int, + default=0, + help="(Optional) How many workers to use for model training (default is 0)", + ) + parser.add_argument( + "--use_gpu", + action="store_true", + help="(Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available.", + ) + parser.add_argument( + "--device", + type=str, + choices=["auto", "cuda", "mps", "cpu"], + default="auto", + help="Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu'", + ) + parser.add_argument( + "--feature_importance_method", + type=str, + choices=["IntegratedGradients", "GradientShap", "Both"], + default="IntegratedGradients", + help="Choose feature importance score method", + ) + parser.add_argument( + "--disable_marker_finding", + action="store_true", + help="(Optional) If set, marker discovery after model training is disabled.", + ) # GNN args. - parser.add_argument("--string_organism", type=int, default=9606, - help="STRING DB organism id.") - parser.add_argument("--string_node_name", type=str, choices=["gene_name", "gene_id"], default="gene_name", - help="Type of node name.") - parser.add_argument("--user_graph", type=str, default=None, - help="Path to user-provided gene-gene interaction network file. " - "Must have at least 3 columns: GeneA, GeneB, Score. " - "If provided, this will be used instead of STRING DB.") + parser.add_argument( + "--string_organism", type=int, default=9606, help="STRING DB organism id." + ) + parser.add_argument( + "--string_node_name", + type=str, + choices=["gene_name", "gene_id"], + default="gene_name", + help="Type of node name.", + ) + parser.add_argument( + "--user_graph", + type=str, + default=None, + help="Path to user-provided gene-gene interaction network file. " + "Must have at least 3 columns: GeneA, GeneB, Score. " + "If provided, this will be used instead of STRING DB.", + ) # safetensors args - parser.add_argument("--safetensors", action="store_true", - help="If set, use SafeTensors + JSON artifacts for save/load (training and inference). Default is False.") + parser.add_argument( + "--safetensors", + action="store_true", + help="If set, use SafeTensors + JSON artifacts for save/load (training and inference). Default is False.", + ) # NEW: inference flags - parser.add_argument("--pretrained_model", type=str, default=None, - help="Path to a saved model (.pth/.safetensors) to use for inference") - parser.add_argument("--artifacts", type=str, default=None, - help="Path to artifacts .joblib or .json saved during training") - parser.add_argument("--data_path_test", type=str, default=None, - help="Folder with test-only dataset for inference") - parser.add_argument("--join_key", type=str, default="JoinKey", - help="Column name in 'clin.csv' (test metadata) used to join sample IDs") + parser.add_argument( + "--pretrained_model", + type=str, + default=None, + help="Path to a saved model (.pth/.safetensors) to use for inference", + ) + parser.add_argument( + "--artifacts", + type=str, + default=None, + help="Path to artifacts .joblib or .json saved during training", + ) + parser.add_argument( + "--data_path_test", + type=str, + default=None, + help="Folder with test-only dataset for inference", + ) + parser.add_argument( + "--join_key", + type=str, + default="JoinKey", + help="Column name in 'clin.csv' (test metadata) used to join sample IDs", + ) args = parser.parse_args() @@ -445,10 +722,16 @@ def main(): # Only require core training flags if NOT doing inference in_infer = bool(args.pretrained_model) if not in_infer: - missing = [k for k in ("data_path", "model_class", "data_types") if not getattr(args, k, None)] + missing = [ + k + for k in ("data_path", "model_class", "data_types") + if not getattr(args, k, None) + ] if missing: - parser.error("the following arguments are required in training mode: " + - ", ".join(f"--{m}" for m in missing)) + parser.error( + "the following arguments are required in training mode: " + + ", ".join(f"--{m}" for m in missing) + ) # ---------- Inference mode: early exit path ---------- if args.pretrained_model and args.artifacts and args.data_path_test: @@ -457,16 +740,22 @@ def main(): # quick existence checks if not os.path.exists(args.pretrained_model): - raise FileNotFoundError(f"--pretrained_model not found: {args.pretrained_model}") + raise FileNotFoundError( + f"--pretrained_model not found: {args.pretrained_model}" + ) if not os.path.exists(args.artifacts): raise FileNotFoundError(f"--artifacts not found: {args.artifacts}") # Handle device selection for inference (same logic as training) if args.use_gpu: - warnings.warn("--use_gpu is deprecated. Use --device instead.", DeprecationWarning) + warnings.warn( + "--use_gpu is deprecated. Use --device instead.", DeprecationWarning + ) if args.device != "auto": device_preference = args.device - print(f"[WARN] Both --use_gpu and --device {args.device} specified. Using --device {args.device}.") + print( + f"[WARN] Both --use_gpu and --device {args.device} specified. Using --device {args.device}." + ) else: # Let auto-detection find the best GPU device (CUDA or MPS) device_preference = "auto" @@ -479,21 +768,25 @@ def main(): # Check model file content type from .inference import check_model_type + model_format = check_model_type(args.pretrained_model) if args.safetensors and model_format != "safetensors": - raise ValueError(f"[ERROR] The file {args.pretrained_model} is not a valid safetensors file.") + raise ValueError( + f"[ERROR] The file {args.pretrained_model} is not a valid safetensors file." + ) # Route to safetensors reconstruction or standard torch.load if model_format == "safetensors": from .inference import reconstruct_model + # Derive config path from model basename model_base = os.path.splitext(args.pretrained_model)[0] - config_path = model_base + '_config.json' + config_path = model_base + "_config.json" if not os.path.exists(config_path): # Try alongside the safetensors file with standard naming - base = args.pretrained_model.replace('.final_model.safetensors', '') - config_path = base + '.final_model_config.json' + base = args.pretrained_model.replace(".final_model.safetensors", "") + config_path = base + ".final_model_config.json" if not os.path.exists(config_path): raise FileNotFoundError( f"Cannot find config JSON for safetensors model. " @@ -508,12 +801,16 @@ def main(): else: # Standard .pth load — robust across PyTorch versions try: - model = torch.load(args.pretrained_model, map_location=device, weights_only=False) + model = torch.load( + args.pretrained_model, map_location=device, weights_only=False + ) except TypeError: model = torch.load(args.pretrained_model, map_location=device) except Exception: try: - model = torch.load(args.pretrained_model, map_location=device, weights_only=True) + model = torch.load( + args.pretrained_model, map_location=device, weights_only=True + ) except TypeError: model = torch.load(args.pretrained_model, map_location=device) @@ -524,41 +821,58 @@ def main(): # Load test data using DataImporterInference from .data import DataImporterInference - print('[INFO] Loading test data for inference...') + + print("[INFO] Loading test data for inference...") importer = DataImporterInference( test_data_path=args.data_path_test, artifacts_path=args.artifacts, - verbose=True + verbose=True, ) test_dataset = importer.import_data() # Convert to GNN dataset if needed - if args.model_class == 'GNN': + if args.model_class == "GNN": print("[INFO] Overlaying the dataset with network data from STRINGDB") from .main import STRING from .data import MultiOmicDatasetNW + # Get STRING organism from artifacts - string_organism = importer.artifacts.get('string_organism', 9606) # default human - string_node_name = importer.artifacts.get('string_node_name', 'HGNC') + string_organism = importer.artifacts.get( + "string_organism", 9606 + ) # default human + string_node_name = importer.artifacts.get("string_node_name", "HGNC") # Load STRING network obj = STRING( - os.path.join(args.data_path_test, '_'.join(['processed', args.prefix])), + os.path.join(args.data_path_test, "_".join(["processed", args.prefix])), string_organism, - string_node_name + string_node_name, ) # Get modality order from artifacts for consistent feature stacking - modality_order = importer.artifacts.get("original_modalities", importer.artifacts.get("data_types")) - test_dataset = MultiOmicDatasetNW(test_dataset, obj.graph_df, modality_order=modality_order) + modality_order = importer.artifacts.get( + "original_modalities", importer.artifacts.get("data_types") + ) + test_dataset = MultiOmicDatasetNW( + test_dataset, obj.graph_df, modality_order=modality_order + ) train_dataset = None # No training data in inference mode # Move dataset to same device as model - if hasattr(test_dataset, 'to_device'): + if hasattr(test_dataset, "to_device"): test_dataset.to_device(device) - print(f'[INFO] Test dataset loaded: {len(test_dataset.samples)} samples') + print(f"[INFO] Test dataset loaded: {len(test_dataset.samples)} samples") # Continue to evaluation section (skip training) # ------------- Heavy imports only when training ------------- - from .utils import evaluate_baseline_performance, evaluate_baseline_survival_performance, get_predicted_labels, evaluate_wrapper, get_optimal_device, get_device_memory_info, create_device_from_string + from .utils import ( + evaluate_baseline_performance, + evaluate_baseline_survival_performance, + get_predicted_labels, + evaluate_wrapper, + get_optimal_device, + get_device_memory_info, + create_device_from_string, + ) + if not (args.pretrained_model and args.artifacts and args.data_path_test): import flexynesis from lightning import seed_everything @@ -584,29 +898,38 @@ def main(): # --------- Sanity checks on args --------- # 1. survival variables consistency if (args.surv_event_var is None) != (args.surv_time_var is None): - parser.error("Both --surv_event_var and --surv_time_var must be provided together or left as None.") + parser.error( + "Both --surv_event_var and --surv_time_var must be provided together or left as None." + ) # 2. required variables for model classes if args.model_class not in ("supervised_vae", "CrossModalPred"): if not any([args.target_variables, args.surv_event_var]): - parser.error("When selecting a model other than 'supervised_vae' or 'CrossModalPred', you must provide at least one of --target_variables, or survival variables (--surv_event_var and --surv_time_var)") + parser.error( + "When selecting a model other than 'supervised_vae' or 'CrossModalPred', you must provide at least one of --target_variables, or survival variables (--surv_event_var and --surv_time_var)" + ) # 3. Check for compatibility of fusion_type with CrossModalPred if args.fusion_type == "early": - if args.model_class == 'CrossModalPred': - parser.error("The 'CrossModalPred' model cannot be used with early fusion type. " - "Use --fusion_type intermediate instead.") - + if args.model_class == "CrossModalPred": + parser.error( + "The 'CrossModalPred' model cannot be used with early fusion type. " + "Use --fusion_type intermediate instead." + ) # 4. Handle device selection with MPS support # Support legacy --use_gpu flag for backward compatibility if args.use_gpu: - warnings.warn("--use_gpu is deprecated. Use --device instead.", DeprecationWarning) + warnings.warn( + "--use_gpu is deprecated. Use --device instead.", DeprecationWarning + ) # If --device is not explicitly set (still at default auto), let auto-detection handle it if args.device != "auto": # If both --use_gpu and explicit --device are provided, respect --device but warn device_preference = args.device - print(f"[WARN] Both --use_gpu and --device {args.device} specified. Using --device {args.device}.") + print( + f"[WARN] Both --use_gpu and --device {args.device} specified. Using --device {args.device}." + ) else: # Let auto-detection find the best GPU device (CUDA or MPS) device_preference = "auto" @@ -618,18 +941,20 @@ def main(): # Print device information print(f"[INFO] Using device: {device_str}") - if device_str != 'cpu': + if device_str != "cpu": memory_info = get_device_memory_info(device_str) print(f"[INFO] Device name: {memory_info['device_name']}") - if device_str == 'cuda': + if device_str == "cuda": print(f"[INFO] Available CUDA devices: {memory_info['device_count']}") # gnn - if args.model_class == 'GNN': + if args.model_class == "GNN": if not args.gnn_conv_type: - warnings.warn("\n\n!!! When running GNN, set --gnn_conv_type (GC/GCN/SAGE). Falling back to GC !!!\n") + warnings.warn( + "\n\n!!! When running GNN, set --gnn_conv_type (GC/GCN/SAGE). Falling back to GC !!!\n" + ) time.sleep(3) - gnn_conv_type = 'GC' + gnn_conv_type = "GC" else: gnn_conv_type = args.gnn_conv_type else: @@ -638,20 +963,26 @@ def main(): # CrossModalPred IO layers input_layers = args.input_layers output_layers = args.output_layers - datatypes = args.data_types.strip().split(',') - if args.model_class == 'CrossModalPred': + datatypes = args.data_types.strip().split(",") + if args.model_class == "CrossModalPred": if args.input_layers: - input_layers = input_layers.strip().split(',') + input_layers = input_layers.strip().split(",") if not all(layer in datatypes for layer in input_layers): - raise ValueError(f"Input layers {input_layers} are not a valid subset of the data types: ({datatypes}).") + raise ValueError( + f"Input layers {input_layers} are not a valid subset of the data types: ({datatypes})." + ) if args.output_layers: - output_layers = output_layers.strip().split(',') + output_layers = output_layers.strip().split(",") if not all(layer in datatypes for layer in output_layers): - raise ValueError(f"Output layers {output_layers} are not a valid subset of the data types: ({datatypes}).") + raise ValueError( + f"Output layers {output_layers} are not a valid subset of the data types: ({datatypes})." + ) # paths if not os.path.exists(args.data_path): - raise FileNotFoundError(f"Input --data_path doesn't exist at: {args.data_path}") + raise FileNotFoundError( + f"Input --data_path doesn't exist at: {args.data_path}" + ) if not os.path.exists(args.outdir): raise FileNotFoundError(f"Path to --outdir doesn't exist at: {args.outdir}") @@ -673,16 +1004,18 @@ def main(): model_class, config_name = model_info # Set concatenate to True to use early fusion, otherwise intermediate - concatenate = args.fusion_type == 'early' and args.model_class != 'GNN' + concatenate = args.fusion_type == "early" and args.model_class != "GNN" # covariates if args.covariates: - if args.model_class == 'GNN': # Covariates not yet supported for GNNs - warnings.warn("\n\n!!! Covariates are currently not supported for GNN models, they will be ignored. !!!\n") + if args.model_class == "GNN": # Covariates not yet supported for GNNs + warnings.warn( + "\n\n!!! Covariates are currently not supported for GNN models, they will be ignored. !!!\n" + ) time.sleep(3) covariates = None else: - covariates = args.covariates.strip().split(',') + covariates = args.covariates.strip().split(",") else: covariates = None @@ -691,14 +1024,14 @@ def main(): data_types=datatypes, covariates=covariates, concatenate=concatenate, - log_transform=args.log_transform == 'True', + log_transform=args.log_transform == "True", variance_threshold=args.variance_threshold / 100, correlation_threshold=args.correlation_threshold, restrict_to_features=args.restrict_to_features, min_features=args.features_min, top_percentile=args.features_top_percentile, - processed_dir='_'.join(['processed', args.prefix]), - downsample=args.subsample + processed_dir="_".join(["processed", args.prefix]), + downsample=args.subsample, ) # import data @@ -720,72 +1053,125 @@ def main(): ) if args.model_class in ["RandomForest", "SVM", "XGBoost"]: if args.target_variables: - var = args.target_variables.strip().split(',')[0] + var = args.target_variables.strip().split(",")[0] print(f"Training {args.model_class} on variable: {var}") metrics, predictions = evaluate_baseline_performance( - train_dataset, test_dataset, variable_name=var, methods=[args.model_class], n_folds=5, n_jobs=args.threads + train_dataset, + test_dataset, + variable_name=var, + methods=[args.model_class], + n_folds=5, + n_jobs=args.threads, + ) + metrics.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "stats.csv"])), + header=True, + index=False, + ) + predictions.to_csv( + os.path.join( + args.outdir, ".".join([args.prefix, "predicted_labels.csv"]) + ), + header=True, + index=False, ) - metrics.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'stats.csv'])), header=True, index=False) - predictions.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'predicted_labels.csv'])), header=True, index=False) print(f"{args.model_class} evaluation complete. Results saved.") sys.exit(0) else: - raise ValueError("At least one target variable is required to run RandomForest/SVM/XGBoost models. Set --target_variables") + raise ValueError( + "At least one target variable is required to run RandomForest/SVM/XGBoost models. Set --target_variables" + ) if args.model_class == "RandomSurvivalForest": if args.surv_event_var and args.surv_time_var: - print(f"Training {args.model_class} on survival variables: {args.surv_event_var} and {args.surv_time_var}") + print( + f"Training {args.model_class} on survival variables: {args.surv_event_var} and {args.surv_time_var}" + ) metrics, predictions = evaluate_baseline_survival_performance( - train_dataset, test_dataset, args.surv_time_var, args.surv_event_var, n_folds=5, n_jobs=int(args.threads) + train_dataset, + test_dataset, + args.surv_time_var, + args.surv_event_var, + n_folds=5, + n_jobs=int(args.threads), + ) + metrics.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "stats.csv"])), + header=True, + index=False, + ) + predictions.to_csv( + os.path.join( + args.outdir, ".".join([args.prefix, "predicted_labels.csv"]) + ), + header=True, + index=False, ) - metrics.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'stats.csv'])), header=True, index=False) - predictions.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'predicted_labels.csv'])), header=True, index=False) print(f"{args.model_class} evaluation complete. Results saved.") sys.exit(0) else: - raise ValueError("Missing survival variables. Set --surv_event_var --surv_time_var") + raise ValueError( + "Missing survival variables. Set --surv_event_var --surv_time_var" + ) # GNN overlay (training mode only - inference mode already handled this) - if args.model_class == 'GNN' and train_dataset is not None: + if args.model_class == "GNN" and train_dataset is not None: # Check if user provided a custom graph if args.user_graph is not None: print(f"[INFO] Using user-provided network from: {args.user_graph}") from flexynesis.data import read_user_graph + graph_df = read_user_graph(args.user_graph) print(f"[INFO] Loaded {len(graph_df)} interactions from user graph") else: # Fallback to STRING DB (default behavior) print("[INFO] No user graph provided. Using STRING DB network (default)") - obj = STRING(os.path.join(args.data_path, '_'.join(['processed', args.prefix])), - args.string_organism, args.string_node_name) + obj = STRING( + os.path.join(args.data_path, "_".join(["processed", args.prefix])), + args.string_organism, + args.string_node_name, + ) graph_df = obj.graph_df # Use data_types order from args for consistent modality ordering - modality_order = args.data_types.split(',') + modality_order = args.data_types.split(",") # Overlay the graph onto datasets - train_dataset = MultiOmicDatasetNW(train_dataset, graph_df, modality_order=modality_order) + train_dataset = MultiOmicDatasetNW( + train_dataset, graph_df, modality_order=modality_order + ) train_dataset.print_stats() - test_dataset = MultiOmicDatasetNW(test_dataset, graph_df, modality_order=modality_order) + test_dataset = MultiOmicDatasetNW( + test_dataset, graph_df, modality_order=modality_order + ) # Training only happens when train_dataset exists (not in inference mode) if train_dataset is not None: # feature logs feature_logs = data_importer.feature_logs for key in feature_logs.keys(): - feature_logs[key].to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'feature_logs', key, 'csv'])), - header=True, index=False) + feature_logs[key].to_csv( + os.path.join( + args.outdir, ".".join([args.prefix, "feature_logs", key, "csv"]) + ), + header=True, + index=False, + ) # tuner tuner = HyperparameterTuning( dataset=train_dataset, model_class=model_class, - target_variables=args.target_variables.strip().split(',') if args.target_variables is not None else [], + target_variables=( + args.target_variables.strip().split(",") + if args.target_variables is not None + else [] + ), batch_variables=None, surv_event_var=args.surv_event_var, surv_time_var=args.surv_time_var, config_name=config_name, config_path=args.config_path, n_iter=int(args.hpo_iter), - use_loss_weighting=args.use_loss_weighting == 'True', + use_loss_weighting=args.use_loss_weighting == "True", val_size=args.val_size, use_cv=args.use_cv, early_stop_patience=int(args.early_stop_patience), @@ -793,7 +1179,7 @@ def main(): gnn_conv_type=gnn_conv_type, input_layers=input_layers, output_layers=output_layers, - num_workers=args.num_workers + num_workers=args.num_workers, ) # do a hyperparameter search training multiple models and get the best configuration @@ -809,6 +1195,7 @@ def main(): # split test dataset into finetuning and holdout datasets all_indices = range(len(test_dataset)) import random as _random + finetune_indices = _random.sample(list(all_indices), finetuneSampleN) holdout_indices = list(set(all_indices) - set(finetune_indices)) finetune_dataset = test_dataset.subset(finetune_indices) @@ -827,105 +1214,194 @@ def main(): print("[INFO] Extracting sample embeddings") if train_dataset is not None: embeddings_train = model.transform(train_dataset) - embeddings_train.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'embeddings_train.csv'])), header=True) + embeddings_train.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "embeddings_train.csv"])), + header=True, + ) embeddings_test = model.transform(test_dataset) - embeddings_test.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'embeddings_test.csv'])), header=True) + embeddings_test.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "embeddings_test.csv"])), + header=True, + ) # evaluate predictions; (if any supervised learning happened) if any([args.target_variables, args.surv_event_var]): - if not args.disable_marker_finding and train_dataset is not None: # unless marker discovery is disabled + if ( + not args.disable_marker_finding and train_dataset is not None + ): # unless marker discovery is disabled # compute feature importance values - if args.feature_importance_method == 'Both': - explainers = ['IntegratedGradients', 'GradientShap'] + if args.feature_importance_method == "Both": + explainers = ["IntegratedGradients", "GradientShap"] else: explainers = [args.feature_importance_method] for explainer in explainers: - print("[INFO] Computing variable importance scores using explainer:", explainer) + print( + "[INFO] Computing variable importance scores using explainer:", + explainer, + ) for var in model.target_variables: - model.compute_feature_importance(train_dataset, var, steps_or_samples=25, method=explainer) + model.compute_feature_importance( + train_dataset, var, steps_or_samples=25, method=explainer + ) import pandas as pd - df_imp = pd.concat([model.feature_importances[x] for x in model.target_variables], ignore_index=True) - df_imp['explainer'] = explainer - df_imp.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'feature_importance', explainer, 'csv'])), header=True, index=False) + + df_imp = pd.concat( + [model.feature_importances[x] for x in model.target_variables], + ignore_index=True, + ) + df_imp["explainer"] = explainer + df_imp.to_csv( + os.path.join( + args.outdir, + ".".join([args.prefix, "feature_importance", explainer, "csv"]), + ), + header=True, + index=False, + ) # print known/predicted labels if train_dataset is not None: - predicted_labels = pd.concat([ - get_predicted_labels(model.predict(train_dataset), train_dataset, 'train', args.model_class), - get_predicted_labels(model.predict(test_dataset), test_dataset, 'test', args.model_class) - ], ignore_index=True) + predicted_labels = pd.concat( + [ + get_predicted_labels( + model.predict(train_dataset), + train_dataset, + "train", + args.model_class, + ), + get_predicted_labels( + model.predict(test_dataset), + test_dataset, + "test", + args.model_class, + ), + ], + ignore_index=True, + ) else: - predicted_labels = get_predicted_labels(model.predict(test_dataset), test_dataset, 'test', args.model_class) - predicted_labels.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'predicted_labels.csv'])), header=True, index=False) + predicted_labels = get_predicted_labels( + model.predict(test_dataset), test_dataset, "test", args.model_class + ) + predicted_labels.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "predicted_labels.csv"])), + header=True, + index=False, + ) print("[INFO] Computing model evaluation metrics") metrics_df = evaluate_wrapper( - args.model_class, model.predict(test_dataset), test_dataset, + args.model_class, + model.predict(test_dataset), + test_dataset, surv_event_var=model.surv_event_var, - surv_time_var=model.surv_time_var + surv_time_var=model.surv_time_var, + ) + metrics_df.to_csv( + os.path.join(args.outdir, ".".join([args.prefix, "stats.csv"])), + header=True, + index=False, ) - metrics_df.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'stats.csv'])), header=True, index=False) # for architectures with decoders; print decoded output layers - if args.model_class == 'CrossModalPred': + if args.model_class == "CrossModalPred": print("[INFO] Printing decoded output layers") # In inference mode, only decode test dataset if train_dataset is not None: output_layers_train = model.decode(train_dataset) for layer in output_layers_train.keys(): output_layers_train[layer].to_csv( - os.path.join(args.outdir, '.'.join([args.prefix, 'train_decoded', layer, 'csv'])), - header=True + os.path.join( + args.outdir, + ".".join([args.prefix, "train_decoded", layer, "csv"]), + ), + header=True, ) output_layers_test = model.decode(test_dataset) for layer in output_layers_test.keys(): output_layers_test[layer].to_csv( - os.path.join(args.outdir, '.'.join([args.prefix, 'test_decoded', layer, 'csv'])), - header=True + os.path.join( + args.outdir, ".".join([args.prefix, "test_decoded", layer, "csv"]) + ), + header=True, ) # evaluate off-the-shelf methods on the main target variable if args.evaluate_baseline_performance: - print("[INFO] Computing off-the-shelf method performance on first target variable:",model.target_variables[0]) + print( + "[INFO] Computing off-the-shelf method performance on first target variable:", + model.target_variables[0], + ) var = model.target_variables[0] metrics = pd.DataFrame() # in the case when GNNEarly was used, the we use the initial multiomicdataset for train/test # because GNNEarly requires a modified dataset structure to fit the networks (temporary solution) - if args.model_class == 'GNN': - train = getattr(train_dataset, 'multiomic_dataset', train_dataset) - test = getattr(test_dataset, 'multiomic_dataset', test_dataset) + if args.model_class == "GNN": + train = getattr(train_dataset, "multiomic_dataset", train_dataset) + test = getattr(test_dataset, "multiomic_dataset", test_dataset) else: train = train_dataset test = test_dataset if var != model.surv_event_var: - metrics, predictions = evaluate_baseline_performance(train, test, - variable_name = var, - methods = ['RandomForest', 'SVM', 'XGBoost'], - n_folds = 5, - n_jobs = int(args.threads)) - predictions.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'baseline.predicted_labels.csv'])), header=True, index=False) + metrics, predictions = evaluate_baseline_performance( + train, + test, + variable_name=var, + methods=["RandomForest", "SVM", "XGBoost"], + n_folds=5, + n_jobs=int(args.threads), + ) + predictions.to_csv( + os.path.join( + args.outdir, + ".".join([args.prefix, "baseline.predicted_labels.csv"]), + ), + header=True, + index=False, + ) if model.surv_event_var and model.surv_time_var: - print("[INFO] Computing off-the-shelf method performance on survival variable:",model.surv_time_var) - metrics_baseline_survival = evaluate_baseline_survival_performance(train, test, - model.surv_time_var, - model.surv_event_var, - n_folds = 5, - n_jobs = int(args.threads)) - metrics = pd.concat([metrics, metrics_baseline_survival], axis = 0, ignore_index = True) + print( + "[INFO] Computing off-the-shelf method performance on survival variable:", + model.surv_time_var, + ) + metrics_baseline_survival = evaluate_baseline_survival_performance( + train, + test, + model.surv_time_var, + model.surv_event_var, + n_folds=5, + n_jobs=int(args.threads), + ) + metrics = pd.concat( + [metrics, metrics_baseline_survival], axis=0, ignore_index=True + ) if not metrics.empty: - metrics.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'baseline.stats.csv'])), header=True, index=False) + metrics.to_csv( + os.path.join( + args.outdir, ".".join([args.prefix, "baseline.stats.csv"]) + ), + header=True, + index=False, + ) # save the trained model in file (skip in inference mode) if not in_infer: if not args.safetensors: - torch.save(model, os.path.join(args.outdir, '.'.join([args.prefix, 'final_model.pth']))) + torch.save( + model, + os.path.join(args.outdir, ".".join([args.prefix, "final_model.pth"])), + ) else: - save_file(model.state_dict(), os.path.join(args.outdir, '.'.join([args.prefix, 'final_model.safetensors']))) + save_file( + model.state_dict(), + os.path.join( + args.outdir, ".".join([args.prefix, "final_model.safetensors"]) + ), + ) # save model config as JSON config = { "model_class": model.__class__.__name__, @@ -933,50 +1409,90 @@ def main(): } # Common attributes to save common_attrs = [ - 'input_dims', 'layers', 'input_layers', 'output_layers', - 'device_type', 'target_variables', - 'surv_event_var', 'surv_time_var', - 'config', 'current_epoch', 'num_layers' + "input_dims", + "layers", + "input_layers", + "output_layers", + "device_type", + "target_variables", + "surv_event_var", + "surv_time_var", + "config", + "current_epoch", + "num_layers", ] for attr in common_attrs: if hasattr(model, attr): config[attr] = getattr(model, attr) - if hasattr(model, 'layers'): - config['num_layers'] = len(model.layers) + if hasattr(model, "layers"): + config["num_layers"] = len(model.layers) - if hasattr(model, 'config'): + if hasattr(model, "config"): model_specific_config = model.config config.update(model_specific_config) - with open(os.path.join(args.outdir, '.'.join([args.prefix, 'final_model_config.json'])), 'w') as f: + with open( + os.path.join( + args.outdir, ".".join([args.prefix, "final_model_config.json"]) + ), + "w", + ) as f: json.dump(config, f, indent=2, default=str) # --- write inference artifacts joblib (auto-generated after training) --- if train_dataset is not None: # Only save artifacts in training mode try: import joblib + artifacts = { - 'schema_version': 1, - 'data_types': list(data_importer.train_features.keys()) if hasattr(data_importer, 'train_features') else args.data_types.split(','), # Use actual data structure keys (e.g. ['all'] for early fusion) - 'original_modalities': args.data_types.split(','), # Original modalities from CLI before concatenation - 'target_variables': args.target_variables.split(',') if args.target_variables else [], - 'feature_lists': data_importer.train_features if hasattr(data_importer, 'train_features') else {}, - 'transforms': data_importer.scalers if hasattr(data_importer, 'scalers') else {}, - 'label_encoders': data_importer.label_encoders if hasattr(data_importer, 'label_encoders') else {}, - 'covariate_vars': covariates if covariates is not None else [], # Store covariate variable names - 'join_key': args.join_key, - 'string_organism': args.string_organism, - 'string_node_name': args.string_node_name, + "schema_version": 1, + "data_types": ( + list(data_importer.train_features.keys()) + if hasattr(data_importer, "train_features") + else args.data_types.split(",") + ), # Use actual data structure keys (e.g. ['all'] for early fusion) + "original_modalities": args.data_types.split( + "," + ), # Original modalities from CLI before concatenation + "target_variables": ( + args.target_variables.split(",") if args.target_variables else [] + ), + "feature_lists": ( + data_importer.train_features + if hasattr(data_importer, "train_features") + else {} + ), + "transforms": ( + data_importer.scalers if hasattr(data_importer, "scalers") else {} + ), + "label_encoders": ( + data_importer.label_encoders + if hasattr(data_importer, "label_encoders") + else {} + ), + "covariate_vars": ( + covariates if covariates is not None else [] + ), # Store covariate variable names + "join_key": args.join_key, + "string_organism": args.string_organism, + "string_node_name": args.string_node_name, } if not args.safetensors: - joblib_path = os.path.join(args.outdir, '.'.join([args.prefix, 'artifacts.joblib'])) + joblib_path = os.path.join( + args.outdir, ".".join([args.prefix, "artifacts.joblib"]) + ) joblib.dump(artifacts, joblib_path) - print(f'[INFO] Wrote inference artifacts to {joblib_path}') + print(f"[INFO] Wrote inference artifacts to {joblib_path}") elif args.safetensors: import numpy as np - from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, StandardScaler + from sklearn.preprocessing import ( + LabelEncoder, + OrdinalEncoder, + StandardScaler, + ) + json_ready = { "schema_version": artifacts["schema_version"], "data_types": artifacts["data_types"], @@ -986,7 +1502,10 @@ def main(): "join_key": artifacts["join_key"], "string_organism": artifacts["string_organism"], "string_node_name": artifacts["string_node_name"], - "feature_lists": {modality: list(features) for modality, features in artifacts["feature_lists"].items()}, + "feature_lists": { + modality: list(features) + for modality, features in artifacts["feature_lists"].items() + }, "transforms": {}, "label_encoders": {}, } @@ -997,16 +1516,33 @@ def main(): json_ready["transforms"][modality] = None continue if not isinstance(scaler, StandardScaler): - raise ValueError(f"Unsupported scaler type for modality '{modality}': {type(scaler).__name__}.") - scaler_dict = {"type": "StandardScaler", "with_mean": scaler.with_mean, "with_std": scaler.with_std} - if hasattr(scaler, "mean_"): scaler_dict["mean"] = scaler.mean_.tolist() - if hasattr(scaler, "scale_"): scaler_dict["scale"] = scaler.scale_.tolist() - if hasattr(scaler, "var_"): scaler_dict["var"] = scaler.var_.tolist() - if hasattr(scaler, "n_features_in_"): scaler_dict["n_features_in"] = int(scaler.n_features_in_) - if hasattr(scaler, "feature_names_in_"): scaler_dict["feature_names_in"] = scaler.feature_names_in_.tolist() + raise ValueError( + f"Unsupported scaler type for modality '{modality}': {type(scaler).__name__}." + ) + scaler_dict = { + "type": "StandardScaler", + "with_mean": scaler.with_mean, + "with_std": scaler.with_std, + } + if hasattr(scaler, "mean_"): + scaler_dict["mean"] = scaler.mean_.tolist() + if hasattr(scaler, "scale_"): + scaler_dict["scale"] = scaler.scale_.tolist() + if hasattr(scaler, "var_"): + scaler_dict["var"] = scaler.var_.tolist() + if hasattr(scaler, "n_features_in_"): + scaler_dict["n_features_in"] = int(scaler.n_features_in_) + if hasattr(scaler, "feature_names_in_"): + scaler_dict["feature_names_in"] = ( + scaler.feature_names_in_.tolist() + ) if hasattr(scaler, "n_samples_seen_"): n_samples = scaler.n_samples_seen_ - scaler_dict["n_samples_seen"] = n_samples.tolist() if isinstance(n_samples, np.ndarray) else int(n_samples) + scaler_dict["n_samples_seen"] = ( + n_samples.tolist() + if isinstance(n_samples, np.ndarray) + else int(n_samples) + ) json_ready["transforms"][modality] = scaler_dict # Convert LabelEncoder/OrdinalEncoder objects @@ -1015,7 +1551,10 @@ def main(): json_ready["label_encoders"][variable] = None continue if isinstance(encoder, LabelEncoder): - encoder_dict = {"type": "LabelEncoder", "classes": encoder.classes_.tolist()} + encoder_dict = { + "type": "LabelEncoder", + "classes": encoder.classes_.tolist(), + } elif isinstance(encoder, OrdinalEncoder): encoder_dict = { "type": "OrdinalEncoder", @@ -1025,21 +1564,40 @@ def main(): } if hasattr(encoder, "encoded_missing_value"): val = encoder.encoded_missing_value - encoder_dict["encoded_missing_value"] = "__NaN__" if isinstance(val, float) and np.isnan(val) else val - if hasattr(encoder, "n_features_in_"): encoder_dict["n_features_in"] = int(encoder.n_features_in_) - if hasattr(encoder, "feature_names_in_"): encoder_dict["feature_names_in"] = encoder.feature_names_in_.tolist() + encoder_dict["encoded_missing_value"] = ( + "__NaN__" + if isinstance(val, float) and np.isnan(val) + else val + ) + if hasattr(encoder, "n_features_in_"): + encoder_dict["n_features_in"] = int(encoder.n_features_in_) + if hasattr(encoder, "feature_names_in_"): + encoder_dict["feature_names_in"] = ( + encoder.feature_names_in_.tolist() + ) if hasattr(encoder, "_missing_indices"): missing_indices = encoder._missing_indices - encoder_dict["_missing_indices"] = {str(k): v for k, v in missing_indices.items()} if isinstance(missing_indices, dict) else missing_indices - if hasattr(encoder, "_infrequent_enabled"): encoder_dict["_infrequent_enabled"] = encoder._infrequent_enabled + encoder_dict["_missing_indices"] = ( + {str(k): v for k, v in missing_indices.items()} + if isinstance(missing_indices, dict) + else missing_indices + ) + if hasattr(encoder, "_infrequent_enabled"): + encoder_dict["_infrequent_enabled"] = ( + encoder._infrequent_enabled + ) else: - raise ValueError(f"Unknown encoder type: {type(encoder).__name__}") + raise ValueError( + f"Unknown encoder type: {type(encoder).__name__}" + ) json_ready["label_encoders"][variable] = encoder_dict - json_path = os.path.join(args.outdir, '.'.join([args.prefix, 'artifacts.json'])) + json_path = os.path.join( + args.outdir, ".".join([args.prefix, "artifacts.json"]) + ) with open(json_path, "w") as f: json.dump(json_ready, f, indent=2) - print(f'[INFO] Wrote inference artifacts to {json_path}') + print(f"[INFO] Wrote inference artifacts to {json_path}") except Exception as e: - print(f'[WARN] Could not write inference artifacts: {e}') + print(f"[WARN] Could not write inference artifacts: {e}") diff --git a/flexynesis/config.py b/flexynesis/config.py index 817347bd..1a77c1d1 100644 --- a/flexynesis/config.py +++ b/flexynesis/config.py @@ -4,41 +4,49 @@ epochs = [500] search_spaces = { - 'DirectPred': [ - Integer(16, 128, name='latent_dim'), - Real(0.2, 0.5, name='hidden_dim_factor'), # relative size of the hidden_dim w.r.t input_dim - Real(0.0001, 0.01, prior='log-uniform', name='lr'), - Integer(8, 32, name='supervisor_hidden_dim'), - Categorical(epochs, name='epochs') - ], - 'supervised_vae': [ - Integer(16, 128, name='latent_dim'), - Real(0.2, 0.5, name='hidden_dim_factor'), # relative size of the hidden_dim w.r.t input_dim - Integer(8, 32, name='supervisor_hidden_dim'), - Real(0.0001, 0.01, prior='log-uniform', name='lr'), - Categorical(epochs, name='epochs') + "DirectPred": [ + Integer(16, 128, name="latent_dim"), + Real( + 0.2, 0.5, name="hidden_dim_factor" + ), # relative size of the hidden_dim w.r.t input_dim + Real(0.0001, 0.01, prior="log-uniform", name="lr"), + Integer(8, 32, name="supervisor_hidden_dim"), + Categorical(epochs, name="epochs"), ], - 'CrossModalPred': [ - Integer(16, 128, name='latent_dim'), - Real(0.2, 0.5, name='hidden_dim_factor'), # relative size of the hidden_dim w.r.t input_dim - Integer(8, 32, name='supervisor_hidden_dim'), - Real(0.0001, 0.01, prior='log-uniform', name='lr'), - Categorical(epochs, name='epochs') + "supervised_vae": [ + Integer(16, 128, name="latent_dim"), + Real( + 0.2, 0.5, name="hidden_dim_factor" + ), # relative size of the hidden_dim w.r.t input_dim + Integer(8, 32, name="supervisor_hidden_dim"), + Real(0.0001, 0.01, prior="log-uniform", name="lr"), + Categorical(epochs, name="epochs"), ], - 'MultiTripletNetwork': [ - Integer(16, 128, name='latent_dim'), - Real(0.2, 0.5, name='hidden_dim_factor'), # relative size of the hidden_dim w.r.t input_dim - Integer(8, 32, name='supervisor_hidden_dim'), - Real(0.0001, 0.01, prior='log-uniform', name='lr'), - Categorical(epochs, name='epochs') + "CrossModalPred": [ + Integer(16, 128, name="latent_dim"), + Real( + 0.2, 0.5, name="hidden_dim_factor" + ), # relative size of the hidden_dim w.r.t input_dim + Integer(8, 32, name="supervisor_hidden_dim"), + Real(0.0001, 0.01, prior="log-uniform", name="lr"), + Categorical(epochs, name="epochs"), + ], + "MultiTripletNetwork": [ + Integer(16, 128, name="latent_dim"), + Real( + 0.2, 0.5, name="hidden_dim_factor" + ), # relative size of the hidden_dim w.r.t input_dim + Integer(8, 32, name="supervisor_hidden_dim"), + Real(0.0001, 0.01, prior="log-uniform", name="lr"), + Categorical(epochs, name="epochs"), + ], + "GNN": [ + Integer(16, 128, name="latent_dim"), + Integer(4, 32, name="node_embedding_dim"), # node embedding dimensions + Integer(1, 4, name="num_convs"), # number of convolutional layers + Real(0.0001, 0.01, prior="log-uniform", name="lr"), + Integer(8, 32, name="supervisor_hidden_dim"), + Categorical(epochs, name="epochs"), + Categorical(["relu"], name="activation"), ], - 'GNN': [ - Integer(16, 128, name='latent_dim'), - Integer(4, 32, name='node_embedding_dim'), # node embedding dimensions - Integer(1, 4, name='num_convs'), # number of convolutional layers - Real(0.0001, 0.01, prior='log-uniform', name='lr'), - Integer(8, 32, name='supervisor_hidden_dim'), - Categorical(epochs, name='epochs'), - Categorical(['relu'], name="activation") - ] } diff --git a/flexynesis/data.py b/flexynesis/data.py index 52545a6b..feee43b1 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -16,7 +16,12 @@ from tqdm import tqdm -from sklearn.preprocessing import OrdinalEncoder, StandardScaler, MinMaxScaler, PowerTransformer +from sklearn.preprocessing import ( + OrdinalEncoder, + StandardScaler, + MinMaxScaler, + PowerTransformer, +) from .feature_selection import filter_by_laplacian from .utils import get_variable_types, create_covariate_matrix @@ -92,8 +97,22 @@ class DataImporter: Encodes categorical labels in the annotation dataframe. """ - def __init__(self, path, data_types, covariates = None, processed_dir="processed", log_transform = False, concatenate = False, restrict_to_features = None, min_features=None, - top_percentile=20, correlation_threshold = 0.9, variance_threshold=0.01, na_threshold=0.1, downsample=0): + def __init__( + self, + path, + data_types, + covariates=None, + processed_dir="processed", + log_transform=False, + concatenate=False, + restrict_to_features=None, + min_features=None, + top_percentile=20, + correlation_threshold=0.9, + variance_threshold=0.01, + na_threshold=0.1, + downsample=0, + ): self.path = path self.data_types = data_types self.processed_dir = os.path.join(self.path, processed_dir) @@ -105,7 +124,7 @@ def __init__(self, path, data_types, covariates = None, processed_dir="processed self.na_threshold = na_threshold self.log_transform = log_transform # Initialize a dictionary to store the label encoders - self.encoders = {} # used if labels are categorical + self.encoders = {} # used if labels are categorical # initialize data scalers self.scalers = None # initialize data transformers @@ -124,8 +143,12 @@ def __init__(self, path, data_types, covariates = None, processed_dir="processed self.feature_logs = {} # NEW: Storage for inference mode artifacts - self.train_features = {} # Stores final feature names per modality after selection - self.label_encoders = {} # Stores fitted label encoders for categorical variables + self.train_features = ( + {} + ) # Stores final feature names per modality after selection + self.label_encoders = ( + {} + ) # Stores fitted label encoders for categorical variables # Note: self.scalers already exists, so we'll use that! def get_user_features(self): @@ -136,9 +159,11 @@ def get_user_features(self): if not os.path.isfile(self.restrict_to_features): raise FileNotFoundError(f"File not found: {self.restrict_to_features}") try: - with open(self.restrict_to_features, 'r') as fp: + with open(self.restrict_to_features, "r") as fp: # Read and process the file - feature_list = [x.strip() for x in fp.read().splitlines() if x.strip()] + feature_list = [ + x.strip() for x in fp.read().splitlines() if x.strip() + ] # Ensure uniqueness and assign self.restrict_to_features = np.unique(feature_list) except Exception as e: @@ -148,8 +173,8 @@ def get_user_features(self): def import_data(self): print("\n[INFO] ================= Importing Data =================") - training_path = os.path.join(self.path, 'train') - testing_path = os.path.join(self.path, 'test') + training_path = os.path.join(self.path, "train") + testing_path = os.path.join(self.path, "test") self.validate_data_folders(training_path, testing_path) @@ -158,7 +183,7 @@ def import_data(self): test_dat = self.read_data(testing_path) if self.downsample > 0: - print("[INFO] Randomly drawing",self.downsample,"samples for training") + print("[INFO] Randomly drawing", self.downsample, "samples for training") train_dat = self.subsample(train_dat, self.downsample) if self.restrict_to_features is not None: @@ -169,8 +194,12 @@ def import_data(self): self.validate_input_data(train_dat, test_dat) # cleanup uninformative features/samples, subset annotation data, do feature selection on training data - train_dat, train_ann, train_samples, train_features = self.process_data(train_dat, split = 'train') - test_dat, test_ann, test_samples, test_features = self.process_data(test_dat, split = 'test') + train_dat, train_ann, train_samples, train_features = self.process_data( + train_dat, split="train" + ) + test_dat, test_ann, test_samples, test_features = self.process_data( + test_dat, split="test" + ) # harmonize feature sets in train/test train_dat, test_dat = self.harmonize(train_dat, test_dat) @@ -188,9 +217,16 @@ def import_data(self): # if covariates are defined, create a covariate matrix and add to the dictionary of data matrices if self.covariates: - print("[INFO] Attempting to create a covariate matrix for the covariates:",self.covariates) - train_dat['covariates'] = create_covariate_matrix(self.covariates, get_variable_types(train_ann), train_ann) - test_dat['covariates'] = create_covariate_matrix(self.covariates, get_variable_types(test_ann), test_ann) + print( + "[INFO] Attempting to create a covariate matrix for the covariates:", + self.covariates, + ) + train_dat["covariates"] = create_covariate_matrix( + self.covariates, get_variable_types(train_ann), train_ann + ) + test_dat["covariates"] = create_covariate_matrix( + self.covariates, get_variable_types(test_ann), test_ann + ) # harmonize again to match the covariate features train_dat, test_dat = self.harmonize(train_dat, test_dat) @@ -202,27 +238,48 @@ def import_data(self): if self.concatenate: # Use data_types order for consistent concatenation modality_order = self.data_types - training_dataset.dat = {'all': torch.cat([training_dataset.dat[x] for x in modality_order], dim = 1)} - training_dataset.features = {'all': list(chain(*[training_dataset.features[x] for x in modality_order]))} + training_dataset.dat = { + "all": torch.cat( + [training_dataset.dat[x] for x in modality_order], dim=1 + ) + } + training_dataset.features = { + "all": list( + chain(*[training_dataset.features[x] for x in modality_order]) + ) + } - testing_dataset.dat = {'all': torch.cat([testing_dataset.dat[x] for x in modality_order], dim = 1)} - testing_dataset.features = {'all': list(chain(*[testing_dataset.features[x] for x in modality_order]))} + testing_dataset.dat = { + "all": torch.cat( + [testing_dataset.dat[x] for x in modality_order], dim=1 + ) + } + testing_dataset.features = { + "all": list( + chain(*[testing_dataset.features[x] for x in modality_order]) + ) + } # Save final feature lists AFTER concatenation (for inference mode) self.train_features = training_dataset.features.copy() - print("[INFO] Training Data Stats: ", training_dataset.get_dataset_stats()) print("[INFO] Test Data Stats: ", testing_dataset.get_dataset_stats()) print("[INFO] Merging Feature Logs...") logs = self.feature_logs - if 'select_features' in logs: - self.feature_logs = {x: pd.merge(logs['cleanup'][x], - logs['select_features'][x], - on = 'feature', how = 'outer', - suffixes=['_cleanup', '_laplacian']) for x in self.data_types} + if "select_features" in logs: + self.feature_logs = { + x: pd.merge( + logs["cleanup"][x], + logs["select_features"][x], + on="feature", + how="outer", + suffixes=["_cleanup", "_laplacian"], + ) + for x in self.data_types + } else: # Feature selection was skipped (top_percentile=0), so just use cleanup logs - self.feature_logs = logs['cleanup'] + self.feature_logs = logs["cleanup"] print("[INFO] Data import successful.") return training_dataset, testing_dataset @@ -232,19 +289,23 @@ def validate_data_folders(self, training_path, testing_path): training_files = set(os.listdir(training_path)) testing_files = set(os.listdir(testing_path)) - required_files = {'clin.csv'} | {f"{dt}.csv" for dt in self.data_types} + required_files = {"clin.csv"} | {f"{dt}.csv" for dt in self.data_types} if not required_files.issubset(training_files): missing_files = required_files - training_files - raise ValueError(f"Missing files in training folder: {', '.join(missing_files)}") + raise ValueError( + f"Missing files in training folder: {', '.join(missing_files)}" + ) if not required_files.issubset(testing_files): missing_files = required_files - testing_files - raise ValueError(f"Missing files in testing folder: {', '.join(missing_files)}") + raise ValueError( + f"Missing files in testing folder: {', '.join(missing_files)}" + ) def read_data(self, folder_path): data = {} - required_files = {'clin.csv'} | {f"{dt}.csv" for dt in self.data_types} + required_files = {"clin.csv"} | {f"{dt}.csv" for dt in self.data_types} print("\n[INFO] ----------------- Reading Data ----------------- ") for file in required_files: file_path = os.path.join(folder_path, file) @@ -255,12 +316,11 @@ def read_data(self, folder_path): # randomly draw N samples; return subset of dat (output of read_data) def subsample(self, dat, N): - clin = dat['clin'].sample(N) + clin = dat["clin"].sample(N) dat_sub = {x: dat[x][clin.index] for x in self.data_types} - dat_sub['clin'] = clin + dat_sub["clin"] = clin return dat_sub - def filter_by_features(self, dat, features): """ If the user has provided list of features to restrict the analysis to, @@ -271,20 +331,26 @@ def filter_by_features(self, dat, features): for key, df in dat.items() } - print("[INFO] The initial features are filtered to include user-provided features only") + print( + "[INFO] The initial features are filtered to include user-provided features only" + ) for key, df in dat_filtered.items(): remaining_features = len(df.index) - print(f"In layer '{key}', {remaining_features} features are remaining after filtering.") + print( + f"In layer '{key}', {remaining_features} features are remaining after filtering." + ) return dat_filtered - def process_data(self, data, split = 'train'): - print(f"\n[INFO] ----------------- Processing Data ({split}) ----------------- ") + def process_data(self, data, split="train"): + print( + f"\n[INFO] ----------------- Processing Data ({split}) ----------------- " + ) # remove uninformative features and samples with no information (from data matrices) dat = self.cleanup_data({x: data[x] for x in self.data_types}) - ann = data['clin'] + ann = data["clin"] dat, ann, samples = self.get_labels(dat, ann) # do feature selection: only applied to training data - if split == 'train': + if split == "train": if self.top_percentile: dat = self.select_features(dat) features = {x: dat[x].index for x in dat.keys()} @@ -295,10 +361,10 @@ def cleanup_data(self, df_dict): cleaned_dfs = {} sample_masks = [] - feature_logs = {} # keep track of feature variation/NA value scores + feature_logs = {} # keep track of feature variation/NA value scores # First pass: remove near-zero-variation features and create masks for informative samples for key, df in df_dict.items(): - print("\n[INFO] working on layer: ",key) + print("\n[INFO] working on layer: ", key) original_features_count = df.shape[0] # Compute variances and NA percentages for each feature in the DataFrame @@ -306,13 +372,29 @@ def cleanup_data(self, df_dict): na_percentages = df.isna().mean(axis=1) # Combine variances and NA percentages into a single DataFrame for logging - log_df = pd.DataFrame({ 'feature': df.index, 'na_percent': na_percentages, 'variance': feature_variances, 'selected': False}) + log_df = pd.DataFrame( + { + "feature": df.index, + "na_percent": na_percentages, + "variance": feature_variances, + "selected": False, + } + ) # Filter based on both variance and NA percentage thresholds # Identify features that meet both criteria - df = df.loc[(feature_variances >= feature_variances.quantile(self.variance_threshold)) & (na_percentages < self.na_threshold)] + df = df.loc[ + ( + feature_variances + >= feature_variances.quantile(self.variance_threshold) + ) + & (na_percentages < self.na_threshold) + ] # set selected features to True - log_df['selected'] = (log_df['variance'] >= feature_variances.quantile(self.variance_threshold)) & (log_df['na_percent'] < self.na_threshold) + log_df["selected"] = ( + log_df["variance"] + >= feature_variances.quantile(self.variance_threshold) + ) & (log_df["na_percent"] < self.na_threshold) feature_logs[key] = log_df # Step 3: Fill NA values with the median of the feature @@ -320,7 +402,12 @@ def cleanup_data(self, df_dict): if np.sum(df.isna().sum()) > 0: missing_rows = df.isna().any(axis=1) - print("[INFO] Imputing NA values to median of features, affected # of cells in the matrix", np.sum(df.isna().sum()), " # of rows:",sum(missing_rows)) + print( + "[INFO] Imputing NA values to median of features, affected # of cells in the matrix", + np.sum(df.isna().sum()), + " # of rows:", + sum(missing_rows), + ) # Calculate medians for each 'column' (originally rows) and fill NAs # Note: After transposition, operations are more efficient @@ -329,16 +416,20 @@ def cleanup_data(self, df_dict): df_T.fillna(medians_T, inplace=True) df = df_T.T - print("[INFO] Number of NA values: ",np.sum(df.isna().sum())) + print("[INFO] Number of NA values: ", np.sum(df.isna().sum())) removed_features_count = original_features_count - df.shape[0] - print(f"[INFO] DataFrame {key} - Removed {removed_features_count} features.") + print( + f"[INFO] DataFrame {key} - Removed {removed_features_count} features." + ) # Step 2: Create masks for informative samples # Compute standard deviation of samples (along columns) sample_stdevs = df.std(axis=0) # Create mask for samples that do not have std dev of 0 or NaN - mask = np.logical_and(sample_stdevs != 0, np.logical_not(np.isnan(sample_stdevs))) + mask = np.logical_and( + sample_stdevs != 0, np.logical_not(np.isnan(sample_stdevs)) + ) sample_masks.append(mask) cleaned_dfs[key] = df @@ -351,15 +442,22 @@ def cleanup_data(self, df_dict): original_samples_count = cleaned_dfs[key].shape[1] cleaned_dfs[key] = cleaned_dfs[key].loc[:, common_mask] removed_samples_count = original_samples_count - cleaned_dfs[key].shape[1] - print(f"[INFO] DataFrame {key} - Removed {removed_samples_count} samples ({removed_samples_count / original_samples_count * 100:.2f}%).") + print( + f"[INFO] DataFrame {key} - Removed {removed_samples_count} samples ({removed_samples_count / original_samples_count * 100:.2f}%)." + ) # update feature logs from this process - self.feature_logs['cleanup'] = feature_logs + self.feature_logs["cleanup"] = feature_logs return cleaned_dfs def get_labels(self, dat, ann): # subset samples and reorder annotations for the samples - samples = list(reduce(set.intersection, [set(item) for item in [dat[x].columns for x in dat.keys()]])) + samples = list( + reduce( + set.intersection, + [set(item) for item in [dat[x].columns for x in dat.keys()]], + ) + ) samples = list(set(ann.index).intersection(samples)) dat = {x: dat[x][samples] for x in dat.keys()} ann = ann.loc[samples] @@ -367,18 +465,25 @@ def get_labels(self, dat, ann): # unsupervised feature selection using laplacian score and correlation filters (optional) def select_features(self, dat): - counts = {x: max(int(dat[x].shape[0] * self.top_percentile / 100), self.min_features) for x in dat.keys()} + counts = { + x: max(int(dat[x].shape[0] * self.top_percentile / 100), self.min_features) + for x in dat.keys() + } dat_filtered = {} - feature_logs = {} # feature log for each layer + feature_logs = {} # feature log for each layer for layer in dat.keys(): # filter features in the layer and keep a log of filtering process; notice we provide a transposed matrix - X_filt, log_df = filter_by_laplacian(X = dat[layer].T, layer = layer, - topN=counts[layer], correlation_threshold = self.correlation_threshold) - dat_filtered[layer] = X_filt.T # transpose after laplacian filtering again + X_filt, log_df = filter_by_laplacian( + X=dat[layer].T, + layer=layer, + topN=counts[layer], + correlation_threshold=self.correlation_threshold, + ) + dat_filtered[layer] = X_filt.T # transpose after laplacian filtering again # Features will be stored after concatenation feature_logs[layer] = log_df # update main feature logs with events from this function - self.feature_logs['select_features'] = feature_logs + self.feature_logs["select_features"] = feature_logs return dat_filtered def harmonize(self, dat1, dat2): @@ -386,7 +491,9 @@ def harmonize(self, dat1, dat2): # common data layers common_layers = dat1.keys() & dat2.keys() # Get common features - common_features = {x: dat1[x].index.intersection(dat2[x].index) for x in common_layers} + common_features = { + x: dat1[x].index.intersection(dat2[x].index) for x in common_layers + } # Subset both datasets to only include common features dat1 = {x: dat1[x].loc[common_features[x]] for x in common_layers} dat2 = {x: dat2[x].loc[common_features[x]] for x in common_layers} @@ -411,10 +518,14 @@ def normalize_data(self, data, scaler_type="standard", fit=True): else: raise ValueError("Invalid scaler_type. Choose 'standard' or 'min_max'.") - normalized_data = {x: pd.DataFrame(self.scalers[x].transform(data[x].T), - index=data[x].columns, - columns=data[x].index).T - for x in data.keys()} + normalized_data = { + x: pd.DataFrame( + self.scalers[x].transform(data[x].T), + index=data[x].columns, + columns=data[x].index, + ).T + for x in data.keys() + } return normalized_data def get_torch_dataset(self, dat, ann, samples): @@ -425,19 +536,32 @@ def get_torch_dataset(self, dat, ann, samples): ann, variable_types, label_mappings = self.encode_labels(ann) # Convert DataFrame to tensor with MPS-compatible dtypes - ann = {col: torch.from_numpy(ann[col].values.copy()).float() if ann[col].dtype in ['float64', 'float32'] - else torch.from_numpy(ann[col].values.copy()) for col in ann.columns} - return MultiOmicDataset(dat, ann, variable_types, features, samples, label_mappings) + ann = { + col: ( + torch.from_numpy(ann[col].values.copy()).float() + if ann[col].dtype in ["float64", "float32"] + else torch.from_numpy(ann[col].values.copy()) + ) + for col in ann.columns + } + return MultiOmicDataset( + dat, ann, variable_types, features, samples, label_mappings + ) def encode_labels(self, df): label_mappings = {} + def encode_column(series): nonlocal label_mappings # Declare as nonlocal so that we can modify it # Fill NA values with 'missing' # series = series.fillna('missing') if series.name not in self.encoders: - self.encoders[series.name] = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1) - encoded_series = self.encoders[series.name].fit_transform(series.to_frame()) + self.encoders[series.name] = OrdinalEncoder( + handle_unknown="use_encoded_value", unknown_value=-1 + ) + encoded_series = self.encoders[series.name].fit_transform( + series.to_frame() + ) # NEW: Store encoder for inference mode self.label_encoders[series.name] = self.encoders[series.name] else: @@ -445,60 +569,82 @@ def encode_column(series): # also save label mappings label_mappings[series.name] = { - int(code): label for code, label in enumerate(self.encoders[series.name].categories_[0]) - } + int(code): label + for code, label in enumerate(self.encoders[series.name].categories_[0]) + } return encoded_series.ravel() # Select only the categorical columns - df_categorical = df.select_dtypes(include=['object', 'category']).apply(encode_column) + df_categorical = df.select_dtypes(include=["object", "category"]).apply( + encode_column + ) # Combine the encoded categorical data with the numerical data - df_encoded = pd.concat([df.select_dtypes(exclude=['object', 'category']), df_categorical], axis=1) + df_encoded = pd.concat( + [df.select_dtypes(exclude=["object", "category"]), df_categorical], axis=1 + ) # Store the variable types - variable_types = {col: 'categorical' for col in df_categorical.columns} - variable_types.update({col: 'numerical' for col in df.select_dtypes(exclude=['object', 'category']).columns}) + variable_types = {col: "categorical" for col in df_categorical.columns} + variable_types.update( + { + col: "numerical" + for col in df.select_dtypes(exclude=["object", "category"]).columns + } + ) return df_encoded, variable_types, label_mappings - def validate_input_data(self, train_dat, test_dat): - print("\n[INFO] ----------------- Checking for problems with the input data ----------------- ") + print( + "\n[INFO] ----------------- Checking for problems with the input data ----------------- " + ) errors = [] warnings = [] + def check_rownames(dat, split): # Check 1: Validate first columns are unique for file_name, df in dat.items(): if not df.index.is_unique: - identifier_type = "Sample labels" if file_name == 'clin' else "Feature names" - errors.append(f"Error in {split}/{file_name}.csv: {identifier_type} in the first column must be unique.") + identifier_type = ( + "Sample labels" if file_name == "clin" else "Feature names" + ) + errors.append( + f"Error in {split}/{file_name}.csv: {identifier_type} in the first column must be unique." + ) def check_sample_labels(dat, split): - clin_samples = set(dat['clin'].index) + clin_samples = set(dat["clin"].index) for file_name, df in dat.items(): - if file_name != 'clin': + if file_name != "clin": omics_samples = set(df.columns) matching_samples = clin_samples.intersection(omics_samples) if not matching_samples: - errors.append(f"Error: No matching sample labels found between {split}/clin.csv and {split}/{file_name}.csv.") + errors.append( + f"Error: No matching sample labels found between {split}/clin.csv and {split}/{file_name}.csv." + ) elif len(matching_samples) < len(clin_samples): missing_samples = clin_samples - matching_samples - warnings.append(f"Warning: Some sample labels in {split}/clin.csv are missing in {split}/{file_name}.csv: {missing_samples}") + warnings.append( + f"Warning: Some sample labels in {split}/clin.csv are missing in {split}/{file_name}.csv: {missing_samples}" + ) def check_common_features(train_dat, test_dat): for file_name in train_dat: - if file_name != 'clin' and file_name in test_dat: + if file_name != "clin" and file_name in test_dat: train_features = set(train_dat[file_name].index) test_features = set(test_dat[file_name].index) common_features = train_features.intersection(test_features) if not common_features: - errors.append(f"Error: No common features found between train/{file_name}.csv and test/{file_name}.csv.") + errors.append( + f"Error: No common features found between train/{file_name}.csv and test/{file_name}.csv." + ) - check_rownames(train_dat, 'train') - check_rownames(test_dat, 'test') + check_rownames(train_dat, "train") + check_rownames(test_dat, "test") - check_sample_labels(train_dat, 'train') - check_sample_labels(test_dat, 'test') + check_sample_labels(train_dat, "train") + check_sample_labels(test_dat, "test") check_common_features(train_dat, test_dat) @@ -514,12 +660,10 @@ def check_common_features(train_dat, test_dat): print(f"[ERROR] {i}. {error}") raise Exception("[ERROR] Please correct the above errors and try again.") - if not warnings and not errors: print("[INFO] Data structure is valid with no errors or warnings.") - """ DataImporterInference class for flexynesis inference mode. Add this to flexynesis/data.py @@ -529,7 +673,6 @@ def check_common_features(train_dat, test_dat): import numpy as np - class DataImporterInference: """ Data importer for inference mode. @@ -538,21 +681,28 @@ class DataImporterInference: def __init__(self, test_data_path, artifacts_path, verbose=True): from .inference import _load_artifacts + self.test_data_path = test_data_path self.verbose = verbose self.artifacts = _load_artifacts(artifacts_path) # Map artifact keys to expected names (compatibility layer) - self.feature_names = self.artifacts.get("feature_lists", self.artifacts.get("feature_names", {})) - self.scalers = self.artifacts.get("transforms", self.artifacts.get("scalers", {})) + self.feature_names = self.artifacts.get( + "feature_lists", self.artifacts.get("feature_names", {}) + ) + self.scalers = self.artifacts.get( + "transforms", self.artifacts.get("scalers", {}) + ) self.label_encoders = self.artifacts.get("label_encoders", {}) - self.modalities = self.artifacts.get("data_types", self.artifacts.get("modalities", [])) + self.modalities = self.artifacts.get( + "data_types", self.artifacts.get("modalities", []) + ) self.target_variables = self.artifacts.get("target_variables", []) # For early fusion, we need feature lists for original modalities # but artifacts only have 'all'. We need to reconstruct from transforms. - if self.modalities == ['all']: - original_modalities = self.artifacts.get('original_modalities', []) + if self.modalities == ["all"]: + original_modalities = self.artifacts.get("original_modalities", []) # Feature lists for original modalities can be inferred from transforms # since transforms are keyed by original modality names if original_modalities: @@ -563,7 +713,6 @@ def __init__(self, test_data_path, artifacts_path, verbose=True): if self.verbose: print(f"[INFO] Loaded artifacts for modalities: {self.modalities}") - def import_data(self): """Returns MultiOmicDataset object""" test_data = {} @@ -571,7 +720,7 @@ def import_data(self): labels_df = None # Load clinical data - clin_path = os.path.join(self.test_data_path, 'clin.csv') + clin_path = os.path.join(self.test_data_path, "clin.csv") if os.path.exists(clin_path): labels_df = pd.read_csv(clin_path, index_col=0) @@ -579,20 +728,22 @@ def import_data(self): # For early fusion: load original modalities before concatenation # For covariates: load only the omics data (covariates come from clin.csv) modalities_to_load = self.modalities - if self.modalities == ['all']: - modalities_to_load = self.artifacts.get('original_modalities', []) + if self.modalities == ["all"]: + modalities_to_load = self.artifacts.get("original_modalities", []) if not modalities_to_load: - raise ValueError('[ERROR] Early fusion mode but original_modalities not found in artifacts') + raise ValueError( + "[ERROR] Early fusion mode but original_modalities not found in artifacts" + ) else: # Filter out 'covariates' from modalities_to_load # Covariates will be created from clinical data later - modalities_to_load = [m for m in self.modalities if m != 'covariates'] + modalities_to_load = [m for m in self.modalities if m != "covariates"] # Load each modality (skip 'covariates' - it's in clin.csv) for modality in modalities_to_load: - if modality == 'covariates': + if modality == "covariates": continue # Covariates are in clin.csv, not a separate file - file_path = os.path.join(self.test_data_path, f'{modality}.csv') + file_path = os.path.join(self.test_data_path, f"{modality}.csv") if not os.path.exists(file_path): raise FileNotFoundError(f"[ERROR] Required file not found: {file_path}") @@ -610,18 +761,24 @@ def import_data(self): extra = set(df.columns) - set(expected_features) if missing: - raise ValueError(f"[ERROR] {modality}: Missing {len(missing)} features required by model. " - f"Test data must have same preprocessing as training data.") + raise ValueError( + f"[ERROR] {modality}: Missing {len(missing)} features required by model. " + f"Test data must have same preprocessing as training data." + ) if extra and self.verbose: - print(f"[INFO] {modality}: Ignoring {len(extra)} extra features not in training") + print( + f"[INFO] {modality}: Ignoring {len(extra)} extra features not in training" + ) # Select only expected features in correct order df = df[expected_features] # Apply scaling scaler = self.scalers[modality] - df_scaled = pd.DataFrame(scaler.transform(df.values), index=df.index, columns=df.columns) + df_scaled = pd.DataFrame( + scaler.transform(df.values), index=df.index, columns=df.columns + ) # Store as DataFrame for cross-modality intersection test_data[modality] = df_scaled @@ -639,16 +796,21 @@ def import_data(self): ).float() # Create covariates matrix if needed - if 'covariates' in self.modalities and labels_df is not None: + if "covariates" in self.modalities and labels_df is not None: from flexynesis.utils import create_covariate_matrix, get_variable_types - covariate_vars = self.artifacts.get('covariate_vars', []) + + covariate_vars = self.artifacts.get("covariate_vars", []) if covariate_vars: if self.verbose: print(f"[INFO] Creating covariate matrix for: {covariate_vars}") variable_types_clin = get_variable_types(labels_df) - covariates_df = create_covariate_matrix(covariate_vars, variable_types_clin, labels_df) + covariates_df = create_covariate_matrix( + covariate_vars, variable_types_clin, labels_df + ) # Convert to torch tensor and add to test_data - test_data['covariates'] = torch.from_numpy(covariates_df.T.values).float() + test_data["covariates"] = torch.from_numpy( + covariates_df.T.values + ).float() if samples is None: samples = covariates_df.T.index.tolist() @@ -668,8 +830,7 @@ def import_data(self): # Create mapping from old sample order to new order old_samples = samples # Original order from data CSV df_reordered = pd.DataFrame( - test_data[modality].numpy(), - index=old_samples + test_data[modality].numpy(), index=old_samples ).loc[common_samples] test_data[modality] = torch.from_numpy(df_reordered.values).float() @@ -681,35 +842,50 @@ def import_data(self): valid_mask = ~labels_df[col].isna() encoded = np.full(len(labels_df), -1, dtype=np.int64) - if hasattr(encoder, 'classes_'): # LabelEncoder + if hasattr(encoder, "classes_"): # LabelEncoder if valid_mask.sum() > 0: - encoded[valid_mask] = encoder.transform(labels_df[col][valid_mask].values) + encoded[valid_mask] = encoder.transform( + labels_df[col][valid_mask].values + ) ann_dict[col] = torch.from_numpy(encoded) - variable_types[col] = 'categorical' - label_mappings[col] = {int(c): l for c, l in enumerate(encoder.classes_)} + variable_types[col] = "categorical" + label_mappings[col] = { + int(c): l for c, l in enumerate(encoder.classes_) + } else: # OrdinalEncoder if valid_mask.sum() > 0: - encoded[valid_mask] = encoder.transform(labels_df[col][valid_mask].values.reshape(-1, 1)).ravel() + encoded[valid_mask] = encoder.transform( + labels_df[col][valid_mask].values.reshape(-1, 1) + ).ravel() ann_dict[col] = torch.from_numpy(encoded) - variable_types[col] = 'categorical' - label_mappings[col] = {int(c): l for c, l in enumerate(encoder.categories_[0])} + variable_types[col] = "categorical" + label_mappings[col] = { + int(c): l for c, l in enumerate(encoder.categories_[0]) + } - label_mappings[col][-1] = 'Unknown' # For missing values + label_mappings[col][-1] = "Unknown" # For missing values else: ann_dict[col] = torch.from_numpy(labels_df[col].values).float() - variable_types[col] = 'numerical' + variable_types[col] = "numerical" # Create features dict # For early fusion, get features from scalers since feature_lists only has 'all' - if self.modalities == ['all']: - modalities_for_features = self.artifacts.get('original_modalities', []) + if self.modalities == ["all"]: + modalities_for_features = self.artifacts.get("original_modalities", []) # Get features from scalers for each modality - features = {modality: list(self.scalers[modality].feature_names_in_) for modality in modalities_for_features} + features = { + modality: list(self.scalers[modality].feature_names_in_) + for modality in modalities_for_features + } else: - features = {modality: self.feature_names[modality] for modality in self.modalities} + features = { + modality: self.feature_names[modality] for modality in self.modalities + } # CRITICAL: Reorder test_data dict to match self.modalities order (model expects specific order) - test_data_ordered = {mod: test_data[mod] for mod in self.modalities if mod in test_data} + test_data_ordered = { + mod: test_data[mod] for mod in self.modalities if mod in test_data + } # Create MultiOmicDataset object dataset = MultiOmicDataset( @@ -718,15 +894,18 @@ def import_data(self): variable_types=variable_types, features=features, samples=samples, - label_mappings=label_mappings + label_mappings=label_mappings, ) # Concatenate for early fusion if needed - if self.modalities == ['all']: + if self.modalities == ["all"]: from itertools import chain + # For early fusion, we already loaded original modalities into test_data # Now concatenate them in the SAME ORDER as training - modality_order = self.artifacts.get('original_modalities', list(test_data.keys())) + modality_order = self.artifacts.get( + "original_modalities", list(test_data.keys()) + ) # Concatenate the data tensors concatenated_data = torch.cat([test_data[x] for x in modality_order], dim=1) @@ -735,12 +914,14 @@ def import_data(self): all_features = list(chain(*[dataset.features[x] for x in modality_order])) # Filter to expected features from artifacts - expected_all_features = self.feature_names['all'] - feature_indices = [i for i, f in enumerate(all_features) if f in expected_all_features] + expected_all_features = self.feature_names["all"] + feature_indices = [ + i for i, f in enumerate(all_features) if f in expected_all_features + ] # Update dataset with concatenated data - dataset.dat = {'all': concatenated_data[:, feature_indices]} - dataset.features = {'all': [all_features[i] for i in feature_indices]} + dataset.dat = {"all": concatenated_data[:, feature_indices]} + dataset.features = {"all": [all_features[i] for i in feature_indices]} return dataset @@ -758,7 +939,16 @@ class MultiOmicDataset(Dataset): A PyTorch dataset that can be used for training or evaluation. """ - def __init__(self, dat, ann, variable_types, features, samples, label_mappings, feature_ann=None): + def __init__( + self, + dat, + ann, + variable_types, + features, + samples, + label_mappings, + feature_ann=None, + ): """Initialize the dataset.""" self.dat = dat self.ann = ann @@ -783,7 +973,7 @@ def __getitem__(self, index): subset_ann = {x: self.ann[x][index] for x in self.ann.keys()} return subset_dat, subset_ann, self.samples[index] - def __len__ (self): + def __len__(self): """Get the total number of samples in the dataset. Returns: @@ -805,8 +995,15 @@ def subset(self, indices): subset_samples = [self.samples[idx] for idx in indices] # Create a new dataset object - return MultiOmicDataset(subset_dat, subset_ann, self.variable_types, self.features, - subset_samples, self.label_mappings, self.feature_ann) + return MultiOmicDataset( + subset_dat, + subset_ann, + self.variable_types, + self.features, + subset_samples, + self.label_mappings, + self.feature_ann, + ) def get_feature_subset(self, feature_df): """Get a subset of data matrices corresponding to specified features and concatenate them into a pandas DataFrame. @@ -818,19 +1015,32 @@ def get_feature_subset(self, feature_df): A pandas DataFrame that concatenates the data matrices for the specified features from all layers. """ # Convert the DataFrame to a dictionary - feature_dict = feature_df.groupby('layer')['name'].apply(list).to_dict() + feature_dict = feature_df.groupby("layer")["name"].apply(list).to_dict() dfs = [] for layer, features in feature_dict.items(): if layer in self.dat: # Create a dictionary to look up indices by feature name for each layer - feature_index_dict = {feature: i for i, feature in enumerate(self.features[layer])} + feature_index_dict = { + feature: i for i, feature in enumerate(self.features[layer]) + } # Get the indices for the requested features - indices = [feature_index_dict[feature] for feature in features if feature in feature_index_dict] + indices = [ + feature_index_dict[feature] + for feature in features + if feature in feature_index_dict + ] # Subset the data matrix for the current layer using the indices subset = self.dat[layer][:, indices] # Convert the subset to a pandas DataFrame, add the layer name as a prefix to each column name - df = pd.DataFrame(subset, columns=[f'{layer}_{feature}' for feature in features if feature in feature_index_dict]) + df = pd.DataFrame( + subset, + columns=[ + f"{layer}_{feature}" + for feature in features + if feature in feature_index_dict + ], + ) dfs.append(df) else: print(f"Layer {layer} not found in the dataset.") @@ -844,9 +1054,12 @@ def get_feature_subset(self, feature_df): return result def get_dataset_stats(self): - stats = {': '.join(['feature_count in', x]): self.dat[x].shape[1] for x in self.dat.keys()} - stats['sample_count'] = len(self.samples) - return(stats) + stats = { + ": ".join(["feature_count in", x]): self.dat[x].shape[1] + for x in self.dat.keys() + } + stats["sample_count"] = len(self.samples) + return stats # given a MultiOmicDataset object, convert to Triplets (anchor,positive,negative) @@ -863,7 +1076,9 @@ def __init__(self, mydataset, main_var): self.labels_set, self.label_to_indices = self.get_label_indices(labels) # Valid anchor indices are those without NA labels - self.valid_indices = [i for i, label in enumerate(labels) if not np.isnan(label)] + self.valid_indices = [ + i for i, label in enumerate(labels) if not np.isnan(label) + ] def __getitem__(self, index): # We only use valid non-NA indices for anchors @@ -881,11 +1096,12 @@ def __getitem__(self, index): # choose another sample with a different label # possible negative labels include NA import random + negative_label = random.choice(list(self.labels_set - set([label]))) negative_index = np.random.choice(self.label_to_indices[negative_label]) - pos = self.dataset[positive_index][0] # positive example - neg = self.dataset[negative_index][0] # negative example + pos = self.dataset[positive_index][0] # positive example + neg = self.dataset[negative_index][0] # negative example return anchor, pos, neg, y_dict def __len__(self): @@ -896,8 +1112,9 @@ def get_label_indices(self, labels_array): valid_labels = [l for l in labels_array if not np.isnan(l)] labels_set = set(valid_labels) - label_to_indices = {label: np.where(labels_array == label)[0] - for label in labels_set} + label_to_indices = { + label: np.where(labels_array == label)[0] for label in labels_set + } # Handle NA as a single separate group (if any exist) na_indices = np.where(np.isnan(labels_array))[0] @@ -912,11 +1129,15 @@ class MultiOmicDatasetNW(Dataset): def __init__(self, multiomic_dataset, interaction_df, modality_order=None): self.multiomic_dataset = multiomic_dataset self.interaction_df = interaction_df - self.modality_order = modality_order if modality_order else sorted(multiomic_dataset.dat.keys()) + self.modality_order = ( + modality_order if modality_order else sorted(multiomic_dataset.dat.keys()) + ) # Compute union of features in the data matrices that also appear in the network self.common_features = self.find_union_features() - self.gene_to_index = {gene: idx for idx, gene in enumerate(self.common_features)} + self.gene_to_index = { + gene: idx for idx, gene in enumerate(self.common_features) + } self.edge_index = self.create_edge_index() self.samples = self.multiomic_dataset.samples self.variable_types = self.multiomic_dataset.variable_types @@ -927,30 +1148,42 @@ def __init__(self, multiomic_dataset, interaction_df, modality_order=None): self.node_features_tensor = self.precompute_node_features() # Store labels for all samples - self.labels = {target_name: labels for target_name, labels in self.multiomic_dataset.ann.items()} + self.labels = { + target_name: labels + for target_name, labels in self.multiomic_dataset.ann.items() + } def find_union_features(self): # Find the union of all features in the multiomic dataset - all_omic_features = set().union(*(set(features) for features in self.multiomic_dataset.features.values())) + all_omic_features = set().union( + *(set(features) for features in self.multiomic_dataset.features.values()) + ) # Find the union of proteins involved in interactions - interaction_genes = set(self.interaction_df['protein1']).union(set(self.interaction_df['protein2'])) + interaction_genes = set(self.interaction_df["protein1"]).union( + set(self.interaction_df["protein2"]) + ) # Return the intersection of omic features and interaction genes return sorted(list(all_omic_features.intersection(interaction_genes))) def create_edge_index(self): # Create edges only if both proteins are within the available features filtered_df = self.interaction_df[ - (self.interaction_df['protein1'].isin(self.common_features)) & - (self.interaction_df['protein2'].isin(self.common_features)) + (self.interaction_df["protein1"].isin(self.common_features)) + & (self.interaction_df["protein2"].isin(self.common_features)) + ] + edge_list = [ + (self.gene_to_index[row["protein1"]], self.gene_to_index[row["protein2"]]) + for index, row in filtered_df.iterrows() ] - edge_list = [(self.gene_to_index[row['protein1']], self.gene_to_index[row['protein2']]) for index, row in filtered_df.iterrows()] return torch.tensor(edge_list, dtype=torch.long).t() def precompute_node_features(self): num_samples = len(self.samples) num_nodes = len(self.common_features) num_data_types = len(self.multiomic_dataset.dat) - all_features = torch.full((num_samples, num_nodes, num_data_types), float('nan'), dtype=torch.float) + all_features = torch.full( + (num_samples, num_nodes, num_data_types), float("nan"), dtype=torch.float + ) # CRITICAL: Use sorted keys to ensure consistent order across training/inference data_types_ordered = sorted(self.multiomic_dataset.dat.keys()) @@ -958,16 +1191,24 @@ def precompute_node_features(self): data_matrix = self.multiomic_dataset.dat[data_type] feature_indices = { gene: self.multiomic_dataset.features[data_type].get_loc(gene) - for gene in self.common_features if gene in self.multiomic_dataset.features[data_type] + for gene in self.common_features + if gene in self.multiomic_dataset.features[data_type] } - valid_indices = torch.tensor(list(feature_indices.values()), dtype=torch.long) - feature_positions = torch.tensor([self.gene_to_index[gene] for gene in feature_indices.keys()], dtype=torch.long) + valid_indices = torch.tensor( + list(feature_indices.values()), dtype=torch.long + ) + feature_positions = torch.tensor( + [self.gene_to_index[gene] for gene in feature_indices.keys()], + dtype=torch.long, + ) # Fill in the available data all_features[:, feature_positions, i] = data_matrix[:, valid_indices] # Precompute medians for all data types, ignoring NaN values - medians = torch.nanmedian(all_features, dim=1, keepdim=True).values # Use .values to get the actual median tensor + medians = torch.nanmedian( + all_features, dim=1, keepdim=True + ).values # Use .values to get the actual median tensor # Replace all NaN values in all_features with their corresponding median values isnan = torch.isnan(all_features) @@ -982,10 +1223,11 @@ def subset(self, indices): # Create a new instance of MultiOmicDatasetNW with the subsetted multiomic dataset return MultiOmicDatasetNW(dataset_subset, self.interaction_df.copy()) - def __getitem__(self, idx): node_features_tensor = self.node_features_tensor[idx] - y_dict = {target_name: self.labels[target_name][idx] for target_name in self.labels} + y_dict = { + target_name: self.labels[target_name][idx] for target_name in self.labels + } return node_features_tensor, y_dict, self.samples[idx] def __len__(self): @@ -1002,13 +1244,19 @@ def print_stats(self): # Calculate degree for each node degrees = torch.zeros(num_nodes, dtype=torch.long) degrees.index_add_(0, self.edge_index[0], torch.ones_like(self.edge_index[0])) - degrees.index_add_(0, self.edge_index[1], torch.ones_like(self.edge_index[1])) # For undirected graphs + degrees.index_add_( + 0, self.edge_index[1], torch.ones_like(self.edge_index[1]) + ) # For undirected graphs num_singletons = torch.sum(degrees == 0).item() non_singletons = degrees[degrees > 0] - mean_edges_per_node = non_singletons.float().mean().item() if len(non_singletons) > 0 else 0 - median_edges_per_node = non_singletons.median().item() if len(non_singletons) > 0 else 0 + mean_edges_per_node = ( + non_singletons.float().mean().item() if len(non_singletons) > 0 else 0 + ) + median_edges_per_node = ( + non_singletons.median().item() if len(non_singletons) > 0 else 0 + ) max_edges = degrees.max().item() print("Dataset Statistics:") @@ -1016,10 +1264,15 @@ def print_stats(self): print(f"Total number of edges: {num_edges}") print(f"Number of node features per node: {num_node_features}") print(f"Number of singletons (nodes with no edges): {num_singletons}") - print(f"Mean number of edges per node (excluding singletons): {mean_edges_per_node:.2f}") - print(f"Median number of edges per node (excluding singletons): {median_edges_per_node}") + print( + f"Mean number of edges per node (excluding singletons): {mean_edges_per_node:.2f}" + ) + print( + f"Median number of edges per node (excluding singletons): {median_edges_per_node}" + ) print(f"Max number of edges per node: {max_edges}") + def get_flexynesis_cache_dir() -> Path: """Resolve a writable cache directory for Flexynesis.""" env_cache = os.getenv("FLEXYNESIS_CACHE") @@ -1046,26 +1299,41 @@ class STRING(PYGDataset): - Downloads once, processes once per (version, organism, node_name). - Safe for concurrent jobs via a file lock. """ + base_folder = "STRING" version = "12.0" files = ("links", "aliases") - url = ("https://stringdb-downloads.org/download/" - "protein.{data}.v{version}/" - "{organism}.protein.{data}.v{version}.txt.gz") - - def __init__(self, root: str | None = None, organism: int = 9606, node_name: str = "gene_name") -> None: + url = ( + "https://stringdb-downloads.org/download/" + "protein.{data}.v{version}/" + "{organism}.protein.{data}.v{version}.txt.gz" + ) + + def __init__( + self, + root: str | None = None, + organism: int = 9606, + node_name: str = "gene_name", + ) -> None: self.organism = organism self.node_name = node_name # ---------- resolve central cache ---------- - cache_root = get_flexynesis_cache_dir() / self.base_folder / f"v{self.version}" / str(self.organism) + cache_root = ( + get_flexynesis_cache_dir() + / self.base_folder + / f"v{self.version}" + / str(self.organism) + ) # layout self._cache_root = cache_root self._cache_raw_dir = cache_root / "raw" self._cache_processed_dir = cache_root / "processed" self._processed_basename = f"graph_{self.node_name}.csv" - self._cache_processed_path = self._cache_processed_dir / self._processed_basename + self._cache_processed_path = ( + self._cache_processed_dir / self._processed_basename + ) self._cache_raw_dir.mkdir(parents=True, exist_ok=True) self._cache_processed_dir.mkdir(parents=True, exist_ok=True) @@ -1081,7 +1349,9 @@ def __init__(self, root: str | None = None, organism: int = 9606, node_name: str super().__init__(str(self._cache_root)) # Load processed data - self.graph_df = pd.read_csv(self.processed_paths[0], sep=",", header=0, index_col=0) + self.graph_df = pd.read_csv( + self.processed_paths[0], sep=",", header=0, index_col=0 + ) # -------------------- PyG stubs -------------------- def len(self) -> int: @@ -1123,7 +1393,7 @@ def cache_dir(self) -> Path: return self._cache_root -def read_user_graph(fpath, sep=None, header='infer', **pd_read_csv_kw): +def read_user_graph(fpath, sep=None, header="infer", **pd_read_csv_kw): """Read user-provided gene-gene interaction network file. This function reads a custom network file and validates that it contains @@ -1159,18 +1429,19 @@ def read_user_graph(fpath, sep=None, header='infer', **pd_read_csv_kw): if sep is None: # Use CSV Sniffer for robust separator detection import csv - with open(fpath, 'r') as f: + + with open(fpath, "r") as f: # Read first few lines for better detection sample = f.read(4096) try: sniffer = csv.Sniffer() - dialect = sniffer.sniff(sample, delimiters='\t,| ') + dialect = sniffer.sniff(sample, delimiters="\t,| ") sep = dialect.delimiter print(f"[INFO] Auto-detected separator using CSV Sniffer: {repr(sep)}") except csv.Error: # Fallback to tab if Sniffer fails - sep = '\t' + sep = "\t" print(f"[INFO] CSV Sniffer failed, using default separator: {repr(sep)}") # Read the file @@ -1186,21 +1457,43 @@ def read_user_graph(fpath, sep=None, header='infer', **pd_read_csv_kw): # Get column names (handle cases with or without header) if header is None or (isinstance(header, int) and header < 0): # No header - assign default names - print("[INFO] No header detected. Assuming first 3 columns are: GeneA, GeneB, Score") - df.columns = [f'col_{i}' for i in range(len(df.columns))] - col_gene_a = 'col_0' - col_gene_b = 'col_1' - col_score = 'col_2' + print( + "[INFO] No header detected. Assuming first 3 columns are: GeneA, GeneB, Score" + ) + df.columns = [f"col_{i}" for i in range(len(df.columns))] + col_gene_a = "col_0" + col_gene_b = "col_1" + col_score = "col_2" else: # Header present - use hybrid scoring to intelligently identify columns from difflib import SequenceMatcher # Define candidate keywords for each column type candidates = { - 'gene_a': ['genea', 'gene_a', 'gene1', 'protein1', 'node1', 'source', 'from'], - 'gene_b': ['geneb', 'gene_b', 'gene2', 'protein2', 'node2', 'target', 'to'], - 'score': ['score', 'weight', 'combined_score', 'correlation', 'confidence', - 'value', 'strength', 'coef', 'coefficient', 'corr', 'pvalue', 'p_value'] + "gene_a": [ + "genea", + "gene_a", + "gene1", + "protein1", + "node1", + "source", + "from", + ], + "gene_b": ["geneb", "gene_b", "gene2", "protein2", "node2", "target", "to"], + "score": [ + "score", + "weight", + "combined_score", + "correlation", + "confidence", + "value", + "strength", + "coef", + "coefficient", + "corr", + "pvalue", + "p_value", + ], } def score_column_match(col, col_idx, category, total_cols): @@ -1230,15 +1523,15 @@ def score_column_match(col, col_idx, category, total_cols): # Signal 4: Position hints (10 points) # First column likely gene_a, second likely gene_b, third+ likely score - if category == 'gene_a' and col_idx == 0: + if category == "gene_a" and col_idx == 0: total_score += 10 - elif category == 'gene_b' and col_idx == 1: + elif category == "gene_b" and col_idx == 1: total_score += 10 - elif category == 'score' and col_idx >= 2: + elif category == "score" and col_idx >= 2: total_score += 10 # Signal 5: Data type hint for score (5 points) - if category == 'score': + if category == "score": if pd.api.types.is_numeric_dtype(df[col]): total_score += 5 @@ -1248,7 +1541,7 @@ def score_column_match(col, col_idx, category, total_cols): total_cols = len(df.columns) matches = {} - for category in ['gene_a', 'gene_b', 'score']: + for category in ["gene_a", "gene_b", "score"]: best_col = None best_score = 0 @@ -1261,29 +1554,37 @@ def score_column_match(col, col_idx, category, total_cols): matches[category] = (best_col, best_score) # Extract matched columns and scores - col_gene_a, score_a = matches['gene_a'] - col_gene_b, score_b = matches['gene_b'] - col_score, score_s = matches['score'] + col_gene_a, score_a = matches["gene_a"] + col_gene_b, score_b = matches["gene_b"] + col_score, score_s = matches["score"] # Check if confidence is too low min_threshold = 30 - if score_a < min_threshold or score_b < min_threshold or score_s < min_threshold: - print(f"[WARNING] Low confidence in column detection. Using first 3 columns as fallback.") + if ( + score_a < min_threshold + or score_b < min_threshold + or score_s < min_threshold + ): + print( + f"[WARNING] Low confidence in column detection. Using first 3 columns as fallback." + ) col_gene_a = df.columns[0] col_gene_b = df.columns[1] col_score = df.columns[2] else: - print(f"[INFO] Detected columns - GeneA: '{col_gene_a}', GeneB: '{col_gene_b}', Score: '{col_score}'") + print( + f"[INFO] Detected columns - GeneA: '{col_gene_a}', GeneB: '{col_gene_b}', Score: '{col_score}'" + ) # Extract only the required 3 columns and standardize names result_df = df[[col_gene_a, col_gene_b, col_score]].copy() - result_df.columns = ['protein1', 'protein2', 'combined_score'] + result_df.columns = ["protein1", "protein2", "combined_score"] # Validate data types # Convert score to numeric if not already - if not pd.api.types.is_numeric_dtype(result_df['combined_score']): + if not pd.api.types.is_numeric_dtype(result_df["combined_score"]): try: - result_df['combined_score'] = pd.to_numeric(result_df['combined_score']) + result_df["combined_score"] = pd.to_numeric(result_df["combined_score"]) except ValueError: raise ValueError( f"Score column must contain numeric values. " @@ -1294,14 +1595,19 @@ def score_column_match(col, col_idx, category, total_cols): original_len = len(result_df) result_df = result_df.dropna() if len(result_df) < original_len: - print(f"[WARNING] Removed {original_len - len(result_df)} rows with missing values") + print( + f"[WARNING] Removed {original_len - len(result_df)} rows with missing values" + ) print(f"[INFO] Successfully loaded user graph: {len(result_df)} interactions") - print(f"[INFO] Score range: [{result_df['combined_score'].min():.4f}, {result_df['combined_score'].max():.4f}]") + print( + f"[INFO] Score range: [{result_df['combined_score'].min():.4f}, {result_df['combined_score'].max():.4f}]" + ) return result_df -def read_stringdb_links(fname, top_neighbors = 5): + +def read_stringdb_links(fname, top_neighbors=5): """ Reads and processes a STRING database file to extract and rank protein-protein interactions. @@ -1327,19 +1633,26 @@ def read_stringdb_links(fname, top_neighbors = 5): """ df = pd.read_csv(fname, header=0, sep=" ") df = df[df.combined_score > 400] - df_expanded = pd.concat([ - df.rename(columns={'protein1': 'protein', 'protein2': 'partner'}), - df.rename(columns={'protein2': 'protein', 'protein1': 'partner'}) - ]) + df_expanded = pd.concat( + [ + df.rename(columns={"protein1": "protein", "protein2": "partner"}), + df.rename(columns={"protein2": "protein", "protein1": "partner"}), + ] + ) # Sort the expanded DataFrame by 'combined_score' in descending order - df_expanded_sorted = df_expanded.sort_values(by='combined_score', ascending=False) - # Reduce to unique interactions to avoid counting duplicates - df_expanded_unique = df_expanded_sorted.drop_duplicates(subset=['protein', 'partner']) - top_interactions = df_expanded_unique.groupby('protein').head(top_neighbors) - df = top_interactions.rename(columns={'protein': 'protein1', 'partner': 'protein2'}) - df[["protein1", "protein2"]] = df[["protein1", "protein2"]].map(lambda a: a.split(".")[-1]) + df_expanded_sorted = df_expanded.sort_values(by="combined_score", ascending=False) + # Reduce to unique interactions to avoid counting duplicates + df_expanded_unique = df_expanded_sorted.drop_duplicates( + subset=["protein", "partner"] + ) + top_interactions = df_expanded_unique.groupby("protein").head(top_neighbors) + df = top_interactions.rename(columns={"protein": "protein1", "partner": "protein2"}) + df[["protein1", "protein2"]] = df[["protein1", "protein2"]].map( + lambda a: a.split(".")[-1] + ) return df + def read_stringdb_aliases(fname: str, node_name: str) -> dict[str, str]: if node_name == "gene_id": source = ("Ensembl_HGNC_ensembl_gene_id", "Ensembl_gene") @@ -1385,6 +1698,7 @@ def fn(a): graph_df[["protein1", "protein2"]] = graph_df[["protein1", "protein2"]].map(fn) return graph_df + def stringdb_links_to_list(df): lst = df[["protein1", "protein2"]].to_numpy().tolist() return lst @@ -1401,7 +1715,7 @@ def split_by_median(tensor_dict): # Convert to categorical, but preserve NaNs tensor_cat = (tensor > median_val).float() - tensor_cat[torch.isnan(tensor)] = float('nan') + tensor_cat[torch.isnan(tensor)] = float("nan") new_dict[key] = tensor_cat else: # If tensor is not numerical, leave it as it is diff --git a/flexynesis/feature_selection.py b/flexynesis/feature_selection.py index 5b072d79..763dc7f9 100644 --- a/flexynesis/feature_selection.py +++ b/flexynesis/feature_selection.py @@ -1,4 +1,4 @@ -# Tools to do feature selection +# Tools to do feature selection import numpy as np import pandas as pd @@ -9,15 +9,16 @@ from scipy.sparse import csr_matrix, diags from tqdm import tqdm + def laplacian_score(X, k=5, t=None): """ Calculate Laplacian Score for each feature in the dataset. - + Parameters: X: numpy array (n_samples, n_features) - Input data k: int - Number of nearest neighbors to consider t: float - Heat kernel parameter (optional) - + Returns: scores: (n_features,) - Laplacian Scores for each feature """ @@ -27,11 +28,11 @@ def laplacian_score(X, k=5, t=None): dist_matrix = squareform(pdist(X)) # Calculate the k-nearest neighbors adjacency matrix - W = kneighbors_graph(X, k, mode='connectivity', include_self=True) + W = kneighbors_graph(X, k, mode="connectivity", include_self=True) # Compute the heat kernel (optional) if t is not None: - W = csr_matrix(np.exp(-dist_matrix ** 2 / t)) + W = csr_matrix(np.exp(-(dist_matrix**2) / t)) # Normalize the adjacency matrix D = np.array(W.sum(axis=1)).flatten() @@ -46,8 +47,8 @@ def laplacian_score(X, k=5, t=None): # Calculate Laplacian Score for each feature scores = [] D = diags(D) # Convert to sparse diagonal matrix for efficient dot product - for i in tqdm(range(n_features), desc='Calculating Laplacian scores'): - fi = X[:,i] + for i in tqdm(range(n_features), desc="Calculating Laplacian scores"): + fi = X[:, i] fi = fi - np.dot(S, fi).sum() / n_samples L_score = np.dot(fi.T, L @ fi) / np.dot(fi.T, D @ fi) scores.append(L_score) @@ -57,14 +58,14 @@ def laplacian_score(X, k=5, t=None): def remove_redundant_features(X, laplacian_scores, threshold, topN=None): """ - Removes features from the dataset based on correlation and Laplacian scores, optionally + Removes features from the dataset based on correlation and Laplacian scores, optionally keeping only the top N features based on their scores, and manages redundant features. - This function evaluates features in a dataset for redundancy by measuring the correlation - between them. If the absolute correlation between two features exceeds a specified threshold, - the feature with the higher Laplacian score (a lower score is better) is considered redundant - and is marked for removal. - If the number of selected features is less than a specified top N, additional features are + This function evaluates features in a dataset for redundancy by measuring the correlation + between them. If the absolute correlation between two features exceeds a specified threshold, + the feature with the higher Laplacian score (a lower score is better) is considered redundant + and is marked for removal. + If the number of selected features is less than a specified top N, additional features are included from the redundant set based on their Laplacian scores until the top N count is reached. Parameters @@ -98,13 +99,15 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): indicate features more important for preserving data structure. - Features initially marked as redundant but included during the top-up process to meet the topN requirement are removed from the redundant_features_df before it is returned. - """ + """ correlation_matrix = np.corrcoef(X.T) - ranked_indices = np.argsort(laplacian_scores) # Assuming minimizing laplacian_scores + ranked_indices = np.argsort( + laplacian_scores + ) # Assuming minimizing laplacian_scores selected_features = [] redundant_features = {} - for idx in tqdm(ranked_indices, desc='Filtering redundant features'): + for idx in tqdm(ranked_indices, desc="Filtering redundant features"): correlated = False for selected_idx in selected_features: correlation = np.abs(correlation_matrix[idx, selected_idx]) @@ -115,7 +118,10 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): break if correlated: - redundant_features[idx] = {'correlated_with': correlated_feature, 'correlation_score': correlation_score} + redundant_features[idx] = { + "correlated_with": correlated_feature, + "correlation_score": correlation_score, + } else: selected_features.append(idx) @@ -123,7 +129,9 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): if topN is not None and len(selected_features) < topN: shortfall = topN - len(selected_features) # Sort redundant features by their laplacian score, prioritizing lower scores - sorted_redundant_indices = sorted(redundant_features.keys(), key=lambda x: laplacian_scores[x]) + sorted_redundant_indices = sorted( + redundant_features.keys(), key=lambda x: laplacian_scores[x] + ) topped_up_features = [] for idx in sorted_redundant_indices: if len(selected_features) < topN: @@ -135,24 +143,29 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): # Remove topped-up features from the redundant_features dictionary for idx in topped_up_features: del redundant_features[idx] - + if len(redundant_features) > 0: # Convert redundant_features dictionary to DataFrame - redundant_features_df = pd.DataFrame([ - { - "feature": X.columns[idx], - "correlated_with": X.columns[redundant_features[idx]['correlated_with']], - "correlation_score": redundant_features[idx]['correlation_score'] - } - for idx in redundant_features - ]) + redundant_features_df = pd.DataFrame( + [ + { + "feature": X.columns[idx], + "correlated_with": X.columns[ + redundant_features[idx]["correlated_with"] + ], + "correlation_score": redundant_features[idx]["correlation_score"], + } + for idx in redundant_features + ] + ) return X.columns[selected_features], redundant_features_df else: return X.columns[selected_features], pd.DataFrame() + def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0.9): """ - Filters features in a dataset based on Laplacian scores and removes highly correlated features, + Filters features in a dataset based on Laplacian scores and removes highly correlated features, retaining only the top N features with the lowest scores and optionally considering correlation. This function computes Laplacian scores for each feature in the dataset to measure its importance @@ -170,13 +183,13 @@ def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0 k : int, optional The number of nearest neighbors to consider for computing the Laplacian score (default is 5). t : float, optional - The heat kernel parameter for Laplacian score computation. If None, the default behavior + The heat kernel parameter for Laplacian score computation. If None, the default behavior applies without a heat kernel (default is None). topN : int, optional The number of top features to keep based on the lowest Laplacian scores (default is 100). correlation_threshold : float, optional - The Pearson correlation coefficient threshold for identifying and removing highly correlated - features. Features with a correlation above this threshold are considered redundant + The Pearson correlation coefficient threshold for identifying and removing highly correlated + features. Features with a correlation above this threshold are considered redundant (default is 0.9). Returns @@ -198,21 +211,32 @@ def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0 correlated with any already selected feature. - The process may select additional features beyond `topN` before correlation filtering to ensure that the best candidates are considered. The final number of features, however, is pruned to `topN`. - """ - print("[INFO] Implementing feature selection using laplacian score for layer:",layer,"with ",X.shape[1],"features"," and ",X.shape[0], " samples ") - - feature_log = pd.DataFrame({'feature': X.columns, 'laplacian_score': np.nan}) + """ + print( + "[INFO] Implementing feature selection using laplacian score for layer:", + layer, + "with ", + X.shape[1], + "features", + " and ", + X.shape[0], + " samples ", + ) + + feature_log = pd.DataFrame({"feature": X.columns, "laplacian_score": np.nan}) # only apply filtering if topN < n_features - if topN >= X.shape[1]: - print("[INFO] No feature selection applied. Returning original matrix. Demanded # of features is ", - "larger than existing number of features") + if topN >= X.shape[1]: + print( + "[INFO] No feature selection applied. Returning original matrix. Demanded # of features is ", + "larger than existing number of features", + ) return X, feature_log - + # compute laplacian scores scores = laplacian_score(X.values, k, t) - - feature_log = pd.DataFrame({'feature': X.columns, 'laplacian_score': scores}) - + + feature_log = pd.DataFrame({"feature": X.columns, "laplacian_score": scores}) + # Sort the features based on their Laplacian Scores sorted_indices = np.argsort(scores) selected_feature_indices = sorted_indices[:topN] @@ -220,27 +244,34 @@ def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0 if correlation_threshold < 1: # Choose the topN + 10% features with the lowest Laplacian Scores - # this is done to avoid unnecessary computation of correlation for all features. + # this is done to avoid unnecessary computation of correlation for all features. topN_extended = int(topN + 0.10 * X.shape[1]) - topN_extended = min(topN_extended, X.shape[1]) # Ensure we don't exceed the number of features + topN_extended = min( + topN_extended, X.shape[1] + ) # Ensure we don't exceed the number of features selected_features = sorted_indices[:topN_extended] # Remove redundancy from topN + 10% features - selected_features, redundant_features_df = remove_redundant_features(X[X.columns[selected_feature_indices]], - scores[selected_feature_indices], correlation_threshold, - topN) + selected_features, redundant_features_df = remove_redundant_features( + X[X.columns[selected_feature_indices]], + scores[selected_feature_indices], + correlation_threshold, + topN, + ) # Prune down to topN features selected_features = selected_features[:topN] - - # if any redundant features found, merge feature log with info from this. + + # if any redundant features found, merge feature log with info from this. if not redundant_features_df.empty: # record the table of features which were removed due to redundancy - feature_log = pd.merge(feature_log, redundant_features_df, on = 'feature', how = 'outer') + feature_log = pd.merge( + feature_log, redundant_features_df, on="feature", how="outer" + ) # Extract the selected features from the dataset X_selected = X[selected_features] - - feature_log['selected'] = False - feature_log.loc[feature_log['feature'].isin(selected_features),'selected'] = True - + + feature_log["selected"] = False + feature_log.loc[feature_log["feature"].isin(selected_features), "selected"] = True + return X_selected, feature_log diff --git a/flexynesis/generate_coexpression_network.py b/flexynesis/generate_coexpression_network.py index ecf2150d..24c98f4e 100644 --- a/flexynesis/generate_coexpression_network.py +++ b/flexynesis/generate_coexpression_network.py @@ -30,39 +30,41 @@ import torch -def build_network(expr_df, method='spearman', min_correlation=0.3, top_k=10, device=None): +def build_network( + expr_df, method="spearman", min_correlation=0.3, top_k=10, device=None +): """ Build co-expression network without storing full correlation matrix. Computes correlations in batches and immediately filters to save memory. - + Args: expr_df: DataFrame with genes as rows, samples as columns method: 'spearman' or 'pearson' min_correlation: Minimum absolute correlation threshold top_k: Keep top K neighbors per gene device: torch device (cuda/mps/cpu). Auto-detected if None. - + Returns: List of edge dictionaries with GeneA, GeneB, Score """ # Auto-detect device if not specified if device is None: if torch.cuda.is_available(): - device = torch.device('cuda') + device = torch.device("cuda") elif torch.backends.mps.is_available(): - device = torch.device('mps') + device = torch.device("mps") else: - device = torch.device('cpu') - + device = torch.device("cpu") + print(f"Using device: {device}") print(f"Calculating {method} correlations for {len(expr_df)} genes...") - + # Convert to torch tensor and move to GPU data = torch.tensor(expr_df.values, dtype=torch.float32, device=device) n_genes = data.shape[0] gene_names = expr_df.index.tolist() - - if method == 'spearman': + + if method == "spearman": # Convert to ranks for Spearman - process in batches with progress bar print("Computing ranks...") batch_size = 5000 @@ -71,83 +73,91 @@ def build_network(expr_df, method='spearman', min_correlation=0.3, top_k=10, dev for i in range(0, n_genes, batch_size): end_i = min(i + batch_size, n_genes) batch = data[i:end_i] - ranks[i:end_i] = torch.argsort(torch.argsort(batch, dim=1), dim=1).float() + ranks[i:end_i] = torch.argsort( + torch.argsort(batch, dim=1), dim=1 + ).float() pbar.update(end_i - i) data = ranks - elif method != 'pearson': + elif method != "pearson": raise ValueError(f"Unknown method: {method}. Use 'spearman' or 'pearson'") - + # Standardize for correlation computation print("Standardizing data...") data_mean = data.mean(dim=1, keepdim=True) data_std = data.std(dim=1, keepdim=True, unbiased=False) data_normalized = (data - data_mean) / (data_std + 1e-8) - + # Compute correlations in batches and extract top-k on the fly batch_size = 1000 # Process 1000 genes at a time edges = [] - - print(f"Computing correlations and building network (min |r| = {min_correlation}, top {top_k} per gene)...") + + print( + f"Computing correlations and building network (min |r| = {min_correlation}, top {top_k} per gene)..." + ) with tqdm(total=n_genes, desc="Processing genes") as pbar: for i in range(0, n_genes, batch_size): end_i = min(i + batch_size, n_genes) batch = data_normalized[i:end_i] - + # Compute correlation for this batch with all genes corr_batch = torch.mm(batch, data_normalized.T) / data.shape[1] - + # Process each gene in the batch for local_idx, global_idx in enumerate(range(i, end_i)): gene_corr = corr_batch[local_idx] gene_name = gene_names[global_idx] - + # Remove self-correlation gene_corr[global_idx] = 0 - + # Get absolute correlations abs_corr = gene_corr.abs() - + # Filter by threshold mask = abs_corr >= min_correlation - + # Get top-k if mask.sum() > top_k: # Get indices of top-k values - top_k_values, top_k_indices = torch.topk(abs_corr, min(top_k, len(abs_corr))) + top_k_values, top_k_indices = torch.topk( + abs_corr, min(top_k, len(abs_corr)) + ) # Filter to only those above threshold valid_mask = top_k_values >= min_correlation top_k_indices = top_k_indices[valid_mask] else: # All values above threshold top_k_indices = torch.where(mask)[0] - + # Add edges for neighbor_idx in top_k_indices: neighbor_idx = neighbor_idx.item() score = abs_corr[neighbor_idx].item() - edges.append({ - 'GeneA': gene_name, - 'GeneB': gene_names[neighbor_idx], - 'Score': score - }) - + edges.append( + { + "GeneA": gene_name, + "GeneB": gene_names[neighbor_idx], + "Score": score, + } + ) + pbar.update(end_i - i) - + return edges def generate_coexpression_network( input_file, output_file, - method='spearman', + method="spearman", min_correlation=0.3, top_k=10, remove_self_loops=True, - remove_duplicates=True + remove_duplicates=True, ): """ Main function to generate co-expression network. - + Args: input_file: Path to gene expression CSV/TSV file output_file: Path to output network file (CSV or TSV) @@ -160,23 +170,23 @@ def generate_coexpression_network( print("=" * 70) print("Co-expression Network Generator") print("=" * 70) - + # Load expression data print(f"\n[1/3] Loading expression data from: {input_file}") try: - sep = '\t' if input_file.endswith('.tsv') else ',' + sep = "\t" if input_file.endswith(".tsv") else "," expr_df = pd.read_csv(input_file, sep=sep, index_col=0) except Exception as e: # Fallback: try the other separator try: - sep = ',' if sep == '\t' else '\t' + sep = "," if sep == "\t" else "\t" expr_df = pd.read_csv(input_file, sep=sep, index_col=0) except Exception as e2: print(f"[ERROR] Failed to load file: {e2}") sys.exit(1) - + print(f" Expression matrix: {expr_df.shape[0]} genes × {expr_df.shape[1]} samples") - + # Check for missing values na_count = expr_df.isna().sum().sum() if na_count > 0: @@ -184,41 +194,39 @@ def generate_coexpression_network( print(f" [WARNING] Found {na_count} missing values in {genes_with_na} genes") print(f" [INFO] Removing genes with missing data.") expr_df = expr_df.dropna() - print(f" [INFO] Retained {expr_df.shape[0]} genes ({genes_with_na} genes removed)") - + print( + f" [INFO] Retained {expr_df.shape[0]} genes ({genes_with_na} genes removed)" + ) + # Build network directly print(f"\n[2/3] Building network...") edges = build_network( - expr_df, - method=method, - min_correlation=min_correlation, - top_k=top_k + expr_df, method=method, min_correlation=min_correlation, top_k=top_k ) - + # Create dataframe network_df = pd.DataFrame(edges) - + if len(network_df) == 0: print("[WARNING] No edges found! Try lowering min_correlation threshold.") print("[ERROR] No edges in network! Exiting.") sys.exit(1) - + # Remove duplicate edges (A-B and B-A are the same) if remove_duplicates: print("\nRemoving duplicate edges...") - network_df['pair'] = network_df.apply( - lambda row: tuple(sorted([row['GeneA'], row['GeneB']])), - axis=1 + network_df["pair"] = network_df.apply( + lambda row: tuple(sorted([row["GeneA"], row["GeneB"]])), axis=1 ) original_len = len(network_df) - network_df = network_df.drop_duplicates(subset='pair').drop(columns='pair') + network_df = network_df.drop_duplicates(subset="pair").drop(columns="pair") print(f" Removed {original_len - len(network_df)} duplicate edges") - + # Save network (auto-detect format from extension) print(f"\n[3/3] Saving network to: {output_file}") - sep = '\t' if output_file.endswith('.tsv') else ',' + sep = "\t" if output_file.endswith(".tsv") else "," network_df.to_csv(output_file, sep=sep, index=False) - + # Print statistics print("\n" + "=" * 70) print("Network Generation Complete!") @@ -227,15 +235,19 @@ def generate_coexpression_network( print(f" Total edges: {len(network_df):,}") print(f" Unique genes (GeneA): {network_df['GeneA'].nunique():,}") print(f" Unique genes (GeneB): {network_df['GeneB'].nunique():,}") - print(f" All unique genes: {len(set(network_df['GeneA']) | set(network_df['GeneB'])):,}") - print(f" Score range: [{network_df['Score'].min():.4f}, {network_df['Score'].max():.4f}]") + print( + f" All unique genes: {len(set(network_df['GeneA']) | set(network_df['GeneB'])):,}" + ) + print( + f" Score range: [{network_df['Score'].min():.4f}, {network_df['Score'].max():.4f}]" + ) print(f" Mean score: {network_df['Score'].mean():.4f}") print(f" Median score: {network_df['Score'].median():.4f}") - + # Show sample print(f"\nSample edges (first 5):") print(network_df.head().to_string(index=False)) - + print(f"\n{'=' * 70}") print("Usage with Flexynesis:") print(f"{'=' * 70}") @@ -250,7 +262,7 @@ def generate_coexpression_network( def main(): parser = argparse.ArgumentParser( - description='Generate gene co-expression network from expression data', + description="Generate gene co-expression network from expression data", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: @@ -283,56 +295,58 @@ def main(): TP53,5.2,6.1,5.8 BRCA1,7.3,6.9,7.1 EGFR,8.1,8.3,8.0 -""" +""", ) - + parser.add_argument( - '--input', '-i', + "--input", + "-i", required=True, - help='Input gene expression file (CSV/TSV supported)' + help="Input gene expression file (CSV/TSV supported)", ) - + parser.add_argument( - '--output', '-o', - required=True, - help='Output network file (CSV/TSV supported)' + "--output", "-o", required=True, help="Output network file (CSV/TSV supported)" ) - + parser.add_argument( - '--method', '-m', - default='spearman', - choices=['spearman', 'pearson'], - help='Correlation method (default: spearman)' + "--method", + "-m", + default="spearman", + choices=["spearman", "pearson"], + help="Correlation method (default: spearman)", ) - + parser.add_argument( - '--min_correlation', '-c', + "--min_correlation", + "-c", type=float, default=0.3, - help='Minimum absolute correlation to include (default: 0.3)' + help="Minimum absolute correlation to include (default: 0.3)", ) - + parser.add_argument( - '--top_k', '-k', + "--top_k", + "-k", type=int, default=10, - help='Keep top K neighbors per gene (default: 10)' + help="Keep top K neighbors per gene (default: 10)", ) - + parser.add_argument( - '--keep_self_loops', - action='store_true', - help='Keep self-correlations (gene-gene) - not recommended' + "--keep_self_loops", + action="store_true", + help="Keep self-correlations (gene-gene) - not recommended", ) - + parser.add_argument( - '--keep_duplicates', - action='store_true', - help='Keep duplicate edges (A-B and B-A) - not recommended' + "--keep_duplicates", + action="store_true", + help="Keep duplicate edges (A-B and B-A) - not recommended", ) - + args = parser.parse_args() - + # Generate network generate_coexpression_network( input_file=args.input, @@ -341,7 +355,7 @@ def main(): min_correlation=args.min_correlation, top_k=args.top_k, remove_self_loops=not args.keep_self_loops, - remove_duplicates=not args.keep_duplicates + remove_duplicates=not args.keep_duplicates, ) diff --git a/flexynesis/inference.py b/flexynesis/inference.py index 5b1dde5f..3b0eb67f 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -12,19 +12,20 @@ import torch from safetensors.torch import load_file - MODEL_REGISTRY = { - "DirectPred": ("flexynesis.models.direct_pred", "DirectPred"), - "supervised_vae": ("flexynesis.models.supervised_vae", "supervised_vae"), - "CrossModalPred": ("flexynesis.models.crossmodal_pred","CrossModalPred"), - "MultiTripletNetwork": ("flexynesis.models.triplet_encoder","MultiTripletNetwork"), - "GNN": ("flexynesis.models.gnn_early", "GNN"), + "DirectPred": ("flexynesis.models.direct_pred", "DirectPred"), + "supervised_vae": ("flexynesis.models.supervised_vae", "supervised_vae"), + "CrossModalPred": ("flexynesis.models.crossmodal_pred", "CrossModalPred"), + "MultiTripletNetwork": ("flexynesis.models.triplet_encoder", "MultiTripletNetwork"), + "GNN": ("flexynesis.models.gnn_early", "GNN"), } + def check_model_type(file_path): import struct import json - with open(file_path, 'rb') as f: + + with open(file_path, "rb") as f: header_start = f.read(8) if len(header_start) < 8: @@ -32,12 +33,12 @@ def check_model_type(file_path): # 1. Try SafeTensors check first try: - header_size = struct.unpack(' 0): + + if ( + platform.system() == "Darwin" + and os.environ.get("GITHUB_ACTIONS") == "true" + and num_workers > 0 + ): import warnings + warnings.warn( f"Detected macOS in GHA: Setting num_workers=0 (was {num_workers}) to avoid permission error in GHA.", - UserWarning + UserWarning, ) num_workers = 0 self.num_workers = num_workers - - self.DataLoader = torch.utils.data.DataLoader # use torch data loader by default - - if self.model_class.__name__ == 'MultiTripletNetwork': - self.loader_dataset = TripletMultiOmicDataset(self.dataset, self.target_variables[0]) + + self.DataLoader = ( + torch.utils.data.DataLoader + ) # use torch data loader by default + + if self.model_class.__name__ == "MultiTripletNetwork": + self.loader_dataset = TripletMultiOmicDataset( + self.dataset, self.target_variables[0] + ) # If config_path is provided, use it if config_path: @@ -135,55 +169,65 @@ def __init__(self, dataset, model_class, config_name, target_variables, if self.config_name in external_config: self.space = external_config[self.config_name] else: - raise ValueError(f"'{self.config_name}' not found in the provided config file.") + raise ValueError( + f"'{self.config_name}' not found in the provided config file." + ) else: if self.config_name in search_spaces: self.space = search_spaces[self.config_name] # get batch sizes (a function of dataset size) self.space.append(self.get_batch_space()) else: - raise ValueError(f"'{self.config_name}' not found in the default config.") + raise ValueError( + f"'{self.config_name}' not found in the default config." + ) - def get_batch_space(self, min_size = 32, max_size = 128): + def get_batch_space(self, min_size=32, max_size=128): m = int(np.log2(len(self.dataset) * 0.8)) st = int(np.log2(min_size)) end = int(np.log2(max_size)) if m < end: end = m - s = Categorical([np.power(2, x) for x in range(st, end+1)], name = 'batch_size') + s = Categorical([np.power(2, x) for x in range(st, end + 1)], name="batch_size") return s - - def setup_trainer(self, params, current_step, total_steps, full_train = False): + + def setup_trainer(self, params, current_step, total_steps, full_train=False): # Configure callbacks and trainer for the current fold mycallbacks = [] if self.plot_losses: - mycallbacks.append(LiveLossPlot(hyperparams=params, current_step=current_step, total_steps=total_steps)) + mycallbacks.append( + LiveLossPlot( + hyperparams=params, + current_step=current_step, + total_steps=total_steps, + ) + ) else: mycallbacks.append(self.progress_bar) - # when training on a full dataset; no cross-validation or no validation splits; + # when training on a full dataset; no cross-validation or no validation splits; # we don't do early stopping early_stop_callback = None if self.early_stop_patience > 0 and full_train == False: early_stop_callback = self.init_early_stopping() mycallbacks.append(early_stop_callback) - + trainer = pl.Trainer( - #deterministic = True, - #precision = '16-mixed', # mixed precision training - max_epochs=int(params['epochs']), - gradient_clip_val=1.0, - gradient_clip_algorithm='norm', + # deterministic = True, + # precision = '16-mixed', # mixed precision training + max_epochs=int(params["epochs"]), + gradient_clip_val=1.0, + gradient_clip_algorithm="norm", log_every_n_steps=5, callbacks=mycallbacks, default_root_dir="./", logger=False, enable_checkpointing=False, devices=1, - accelerator=self.device_type + accelerator=self.device_type, ) return trainer, early_stop_callback - - def objective(self, params, current_step, total_steps, full_train = False): + + def objective(self, params, current_step, total_steps, full_train=False): # Unpack or construct specific model arguments model_args = { "config": params, @@ -195,19 +239,26 @@ def objective(self, params, current_step, total_steps, full_train = False): "use_loss_weighting": self.use_loss_weighting, "device_type": self.device_type, } - - if self.model_class.__name__ == 'GNN': - model_args['gnn_conv_type'] = self.gnn_conv_type - if self.model_class.__name__ == 'CrossModalPred': - model_args['input_layers'] = self.input_layers - model_args['output_layers'] = self.output_layers + + if self.model_class.__name__ == "GNN": + model_args["gnn_conv_type"] = self.gnn_conv_type + if self.model_class.__name__ == "CrossModalPred": + model_args["input_layers"] = self.input_layers + model_args["output_layers"] = self.output_layers if full_train: # Train on the full dataset - full_loader = self.DataLoader(self.loader_dataset, batch_size=int(params['batch_size']), - shuffle=True, pin_memory=True, drop_last=True) + full_loader = self.DataLoader( + self.loader_dataset, + batch_size=int(params["batch_size"]), + shuffle=True, + pin_memory=True, + drop_last=True, + ) model = self.model_class(**model_args) - trainer, _ = self.setup_trainer(params, current_step, total_steps, full_train = True) + trainer, _ = self.setup_trainer( + params, current_step, total_steps, full_train=True + ) trainer.fit(model, train_dataloaders=full_loader) return model # Return the trained model @@ -215,39 +266,62 @@ def objective(self, params, current_step, total_steps, full_train = False): validation_losses = [] epochs = [] - if self.use_cv: # if the user asks for cross-validation + if self.use_cv: # if the user asks for cross-validation kf = KFold(n_splits=self.n_splits, shuffle=True) split_iterator = kf.split(self.loader_dataset) - else: # otherwise do a single train/validation split + else: # otherwise do a single train/validation split # Compute the number of samples for training based on the ratio num_val = int(len(self.loader_dataset) * self.val_size) num_train = len(self.loader_dataset) - num_val - train_subset, val_subset = random_split(self.loader_dataset, [num_train, num_val]) + train_subset, val_subset = random_split( + self.loader_dataset, [num_train, num_val] + ) # create single split format similar to KFold - split_iterator = [(list(train_subset.indices), list(val_subset.indices))] + split_iterator = [ + (list(train_subset.indices), list(val_subset.indices)) + ] i = 1 - model = None # save the model if not using cross-validation + model = None # save the model if not using cross-validation for train_index, val_index in split_iterator: - print(f"[INFO] {'training cross-validation fold' if self.use_cv else 'training validation split'} {i}") + print( + f"[INFO] {'training cross-validation fold' if self.use_cv else 'training validation split'} {i}" + ) train_subset = torch.utils.data.Subset(self.loader_dataset, train_index) val_subset = torch.utils.data.Subset(self.loader_dataset, val_index) - train_loader = self.DataLoader(train_subset, batch_size=int(params['batch_size']), - pin_memory=True, shuffle=True, drop_last=True, num_workers = self.num_workers, prefetch_factor = None, - persistent_workers = self.num_workers > 0) - val_loader = self.DataLoader(val_subset, batch_size=int(params['batch_size']), - pin_memory=True, shuffle=False, num_workers = self.num_workers, prefetch_factor = None, - persistent_workers = self.num_workers > 0) + train_loader = self.DataLoader( + train_subset, + batch_size=int(params["batch_size"]), + pin_memory=True, + shuffle=True, + drop_last=True, + num_workers=self.num_workers, + prefetch_factor=None, + persistent_workers=self.num_workers > 0, + ) + val_loader = self.DataLoader( + val_subset, + batch_size=int(params["batch_size"]), + pin_memory=True, + shuffle=False, + num_workers=self.num_workers, + prefetch_factor=None, + persistent_workers=self.num_workers > 0, + ) model = self.model_class(**model_args) - trainer, early_stop_callback = self.setup_trainer(params, current_step, total_steps) + trainer, early_stop_callback = self.setup_trainer( + params, current_step, total_steps + ) print(f"[INFO] hpo config:{params}") - trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + trainer.fit( + model, train_dataloaders=train_loader, val_dataloaders=val_loader + ) if early_stop_callback.stopped_epoch: epochs.append(early_stop_callback.stopped_epoch) else: - epochs.append(int(params['epochs'])) + epochs.append(int(params["epochs"])) validation_result = trainer.validate(model, dataloaders=val_loader) - val_loss = validation_result[0]['val_loss'] + val_loss = validation_result[0]["val_loss"] validation_losses.append(val_loss) i += 1 if not self.use_cv: @@ -256,27 +330,39 @@ def objective(self, params, current_step, total_steps, full_train = False): # Calculate average validation loss across all folds avg_val_loss = np.mean(validation_losses) avg_epochs = int(np.mean(epochs)) - return avg_val_loss, avg_epochs, model - - def perform_tuning(self, hpo_patience = 0): - opt = Optimizer(dimensions=self.space, n_initial_points=10, acq_func="gp_hedge", acq_optimizer="auto") + return avg_val_loss, avg_epochs, model + + def perform_tuning(self, hpo_patience=0): + opt = Optimizer( + dimensions=self.space, + n_initial_points=10, + acq_func="gp_hedge", + acq_optimizer="auto", + ) best_loss = np.inf best_params = None best_epochs = 0 best_model = None - # keep track of the streak of HPO iterations without improvement + # keep track of the streak of HPO iterations without improvement no_improvement_count = 0 - with tqdm(total=self.n_iter, desc='Tuning Progress') as pbar: + with tqdm(total=self.n_iter, desc="Tuning Progress") as pbar: for i in range(self.n_iter): np.int = int # Ensure int type is correctly handled suggested_params_list = opt.ask() - suggested_params_dict = {param.name: value for param, value in zip(self.space, suggested_params_list)} - loss, avg_epochs, model = self.objective(suggested_params_dict, current_step=i+1, total_steps=self.n_iter) + suggested_params_dict = { + param.name: value + for param, value in zip(self.space, suggested_params_list) + } + loss, avg_epochs, model = self.objective( + suggested_params_dict, current_step=i + 1, total_steps=self.n_iter + ) if self.use_cv: - print(f"[INFO] average 5-fold cross-validation loss {loss} for params: {suggested_params_dict}") + print( + f"[INFO] average 5-fold cross-validation loss {loss} for params: {suggested_params_dict}" + ) opt.tell(suggested_params_list, loss) if loss < best_loss: @@ -289,34 +375,48 @@ def perform_tuning(self, hpo_patience = 0): no_improvement_count += 1 # Increment the no improvement counter # Print result of each iteration - pbar.set_postfix({'Iteration': i+1, 'Best Loss': best_loss}) + pbar.set_postfix({"Iteration": i + 1, "Best Loss": best_loss}) pbar.update(1) # Early stopping condition if no_improvement_count >= hpo_patience & hpo_patience > 0: - print(f"No improvement in best loss for {hpo_patience} iterations, stopping hyperparameter optimisation early.") + print( + f"No improvement in best loss for {hpo_patience} iterations, stopping hyperparameter optimisation early." + ) break # Break out of the loop - best_params_dict = {param.name: value for param, value in zip(self.space, best_params)} if best_params else None - print(f"[INFO] current best val loss: {best_loss}; best params: {best_params_dict} since {no_improvement_count} hpo iterations") - + best_params_dict = ( + {param.name: value for param, value in zip(self.space, best_params)} + if best_params + else None + ) + print( + f"[INFO] current best val loss: {best_loss}; best params: {best_params_dict} since {no_improvement_count} hpo iterations" + ) # Convert best parameters from list to dictionary and include epochs - best_params_dict = {param.name: value for param, value in zip(self.space, best_params)} - best_params_dict['epochs'] = best_epochs + best_params_dict = { + param.name: value for param, value in zip(self.space, best_params) + } + best_params_dict["epochs"] = best_epochs if self.use_cv: # Build a final model based on best parameters if using cross-validation - print(f"[INFO] Building a final model using best params: {best_params_dict}") - best_model = self.objective(best_params_dict, current_step=0, total_steps=1, full_train=True) + print( + f"[INFO] Building a final model using best params: {best_params_dict}" + ) + best_model = self.objective( + best_params_dict, current_step=0, total_steps=1, full_train=True + ) + + return best_model, best_params_dict - return best_model, best_params_dict def init_early_stopping(self): """Initialize the early stopping callback.""" return EarlyStopping( - monitor='val_loss', + monitor="val_loss", patience=self.early_stop_patience, verbose=False, - mode='min' + mode="min", ) def load_and_convert_config(self, config_path): @@ -325,8 +425,8 @@ def load_and_convert_config(self, config_path): raise ValueError(f"Config file '{config_path}' doesn't exist.") # Read the config file - if config_path.endswith('.yaml') or config_path.endswith('.yml'): - with open(config_path, 'r') as file: + if config_path.endswith(".yaml") or config_path.endswith(".yml"): + with open(config_path, "r") as file: loaded_config = yaml.safe_load(file) else: raise ValueError("Unsupported file format. Use .yaml or .yml") @@ -354,13 +454,14 @@ def load_and_convert_config(self, config_path): import numpy as np import random, copy, logging + class FineTuner(pl.LightningModule): """ - FineTuner class is designed for fine-tuning trained flexynesis models with flexible control over parameters such as + FineTuner class is designed for fine-tuning trained flexynesis models with flexible control over parameters such as learning rates and component freezing, utilizing cross-validation to optimize generalization. - This class allows the application of different configuration strategies to either freeze or unfreeze specific - model components, while also exploring different learning rates to find the optimal setting. + This class allows the application of different configuration strategies to either freeze or unfreeze specific + model components, while also exploring different learning rates to find the optimal setting. It carries out cross-validation to find the best combination of parameter freezing strategies and learning rates. Attributes: @@ -382,37 +483,55 @@ class FineTuner(pl.LightningModule): run_experiments(): Executes the finetuning process across all configurations and learning rates, evaluates using cross-validation, and selects the best configuration based on validation loss. """ - def __init__(self, model, dataset, n_splits=5, batch_size=32, learning_rates=None, max_epoch = 50, freeze_configs = None): + + def __init__( + self, + model, + dataset, + n_splits=5, + batch_size=32, + learning_rates=None, + max_epoch=50, + freeze_configs=None, + ): super().__init__() logging.getLogger("lightning.pytorch").setLevel(logging.ERROR) - self.original_model = model + self.original_model = model self.dataset = dataset # Use the entire dataset self.n_splits = n_splits self.batch_size = batch_size self.kfold = KFold(n_splits=self.n_splits, shuffle=True) - self.learning_rates = learning_rates if learning_rates else [model.config['lr'], model.config['lr']/10, model.config['lr']/100] - self.folds_data = list(self.kfold.split(np.arange(len(self.dataset)))) + self.learning_rates = ( + learning_rates + if learning_rates + else [model.config["lr"], model.config["lr"] / 10, model.config["lr"] / 100] + ) + self.folds_data = list(self.kfold.split(np.arange(len(self.dataset)))) self.max_epoch = max_epoch - self.freeze_configs = freeze_configs if freeze_configs else [ - {'encoders': True, 'supervisors': False}, - {'encoders': False, 'supervisors': True}, - {'encoders': False, 'supervisors': False} - ] - - if model.__class__.__name__ == 'MultiTripletNetwork': + self.freeze_configs = ( + freeze_configs + if freeze_configs + else [ + {"encoders": True, "supervisors": False}, + {"encoders": False, "supervisors": True}, + {"encoders": False, "supervisors": False}, + ] + ) + + if model.__class__.__name__ == "MultiTripletNetwork": # modify dataset structure to accommodate TripletNetworks self.dataset = TripletMultiOmicDataset(dataset, model.main_var) - + def apply_freeze_config(self, config): # Freeze or unfreeze encoders for encoder in self.model.encoders: for param in encoder.parameters(): - param.requires_grad = not config['encoders'] - + param.requires_grad = not config["encoders"] + # Freeze or unfreeze supervisors for mlp in self.model.MLPs.values(): for param in mlp.parameters(): - param.requires_grad = not config['supervisors'] + param.requires_grad = not config["supervisors"] def train_dataloader(self): # Override to load data for the current fold @@ -431,59 +550,96 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): # Call the model's validation step without logging - val_loss = self.model.validation_step(batch, batch_idx, log=False) + val_loss = self.model.validation_step(batch, batch_idx, log=False) self.log("val_loss", val_loss, on_epoch=True, prog_bar=True) return val_loss - + def configure_optimizers(self): - return torch.optim.Adam(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.learning_rate) + return torch.optim.Adam( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.learning_rate, + ) def run_experiments(self): val_loss_results = [] for lr in self.learning_rates: for config in self.freeze_configs: fold_losses = [] - epochs = [] # record how many epochs the training happened + epochs = [] # record how many epochs the training happened for fold in range(self.n_splits): - model_copy = copy.deepcopy(self.original_model) # Deep copy the model for each fold + model_copy = copy.deepcopy( + self.original_model + ) # Deep copy the model for each fold self.model = model_copy - self.apply_freeze_config(config) # try freezing different components + self.apply_freeze_config( + config + ) # try freezing different components self.current_fold = fold self.learning_rate = lr early_stopping = EarlyStopping( - monitor='val_loss', - patience=3, - verbose=False, - mode='min' + monitor="val_loss", patience=3, verbose=False, mode="min" + ) + trainer = pl.Trainer( + max_epochs=self.max_epoch, + devices=1, + accelerator="auto", + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + callbacks=[early_stopping], + ) + trainer.fit( + self, + train_dataloaders=self.train_dataloader(), + val_dataloaders=self.val_dataloader(), ) - trainer = pl.Trainer(max_epochs=self.max_epoch, devices=1, accelerator='auto', logger=False, enable_checkpointing=False, - enable_progress_bar = False, enable_model_summary=False, callbacks=[early_stopping]) - trainer.fit(self, train_dataloaders=self.train_dataloader(), val_dataloaders=self.val_dataloader()) stopped_epoch = early_stopping.stopped_epoch - val_loss = trainer.validate(self, dataloaders = self.val_dataloader(), verbose = False) - fold_losses.append(val_loss[0]['val_loss']) # Adjust based on your validation output format + val_loss = trainer.validate( + self, dataloaders=self.val_dataloader(), verbose=False + ) + fold_losses.append( + val_loss[0]["val_loss"] + ) # Adjust based on your validation output format epochs.append(stopped_epoch) - #print(f"[INFO] Finetuning ... training fold: {fold}, learning rate: {lr}, val_loss: {val_loss}, freeze {config}") + # print(f"[INFO] Finetuning ... training fold: {fold}, learning rate: {lr}, val_loss: {val_loss}, freeze {config}") avg_val_loss = np.mean(fold_losses) avg_epochs = int(np.mean(epochs)) - print(f"[INFO] average 5-fold cross-validation loss {avg_val_loss} for learning rate: {lr} freeze {config}, average epochs {avg_epochs}") - val_loss_results.append({'learning_rate': lr, 'average_val_loss': avg_val_loss, 'freeze': config, 'epochs': avg_epochs}) + print( + f"[INFO] average 5-fold cross-validation loss {avg_val_loss} for learning rate: {lr} freeze {config}, average epochs {avg_epochs}" + ) + val_loss_results.append( + { + "learning_rate": lr, + "average_val_loss": avg_val_loss, + "freeze": config, + "epochs": avg_epochs, + } + ) # Find the best configuration based on validation loss - best_config = min(val_loss_results, key=lambda x: x['average_val_loss']) - print(f"Best learning rate: {best_config['learning_rate']} and freeze {best_config['freeze']}", - f"with average validation loss: {best_config['average_val_loss']} and average epochs: {best_config['epochs']}") + best_config = min(val_loss_results, key=lambda x: x["average_val_loss"]) + print( + f"Best learning rate: {best_config['learning_rate']} and freeze {best_config['freeze']}", + f"with average validation loss: {best_config['average_val_loss']} and average epochs: {best_config['epochs']}", + ) # build a final model using the best setup on all samples final_model = copy.deepcopy(self.model) self.model = final_model - self.learning_rate = best_config['learning_rate'] - self.apply_freeze_config(best_config['freeze']) + self.learning_rate = best_config["learning_rate"] + self.apply_freeze_config(best_config["freeze"]) dl = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True) - final_trainer = pl.Trainer(max_epochs=best_config['epochs'], devices=1, accelerator='auto', logger=False, enable_checkpointing=False) + final_trainer = pl.Trainer( + max_epochs=best_config["epochs"], + devices=1, + accelerator="auto", + logger=False, + enable_checkpointing=False, + ) final_trainer.fit(self, train_dataloaders=dl) - - + + import matplotlib.pyplot as plt from IPython.display import display from lightning import Callback @@ -495,7 +651,7 @@ class LiveLossPlot(Callback): This class is a PyTorch Lightning callback that plots training loss and other metrics live as the model trains. It is especially useful for tracking the progress of hyperparameter optimization (HPO) steps. - + Compatible with papermill/nbclient execution by using display handles for proper output tracking. Attributes: @@ -510,6 +666,7 @@ class LiveLossPlot(Callback): on_train_epoch_end(trainer, pl_module): Updates and plots the loss after each training epoch. plot_losses(): Renders the loss plot with the current training metrics. """ + def __init__(self, hyperparams, current_step, total_steps, figsize=(8, 6)): super().__init__() self.hyperparams = hyperparams @@ -547,7 +704,9 @@ def plot_losses(self): ax.plot(epochs_range, losses_to_plot, label=key) - hyperparams_str = ', '.join(f"{key}={value}" for key, value in self.hyperparams.items()) + hyperparams_str = ", ".join( + f"{key}={value}" for key, value in self.hyperparams.items() + ) title = f"HPO Step={self.current_step} out of {self.total_steps}\n({hyperparams_str})" ax.set_title(title) @@ -555,12 +714,12 @@ def plot_losses(self): ax.set_ylabel("Loss") ax.legend() fig.tight_layout() - + # Use display with a handle for proper papermill/nbclient compatibility # This avoids the display_id assertion error if self.display_handle is None: self.display_handle = display(fig, display_id=True) else: self.display_handle.update(fig) - + plt.close(fig) diff --git a/flexynesis/models/__init__.py b/flexynesis/models/__init__.py index 2fad7361..a2582e01 100644 --- a/flexynesis/models/__init__.py +++ b/flexynesis/models/__init__.py @@ -3,4 +3,11 @@ from .triplet_encoder import MultiTripletNetwork from .crossmodal_pred import CrossModalPred from .gnn_early import GNN -__all__ = ["DirectPred", "supervised_vae", "MultiTripletNetwork", "CrossModalPred", "GNN"] + +__all__ = [ + "DirectPred", + "supervised_vae", + "MultiTripletNetwork", + "CrossModalPred", + "GNN", +] diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index 1c004e50..92c6634e 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -1,5 +1,5 @@ import torch -import itertools +import itertools from torch import nn from torch.nn import functional as F from torch.utils.data import Dataset, DataLoader, random_split @@ -18,21 +18,30 @@ class CrossModalPred(pl.LightningModule): """ - A Cross-Modality Prediction Architecture that encodes user-specified input data modalities and + A Cross-Modality Prediction Architecture that encodes user-specified input data modalities and tries to reconstruct user-specificed output data modalities. In the case where input/output data modalities - are the same, this behaves like an auto-encoder. - The network also can be connected to one or more MLPs for outcome variable prediction. - + are the same, this behaves like an auto-encoder. + The network also can be connected to one or more MLPs for outcome variable prediction. + dataset: dictionary of data matrices - input_layers: which data modalities from `dataset` to encode (use a subset of keys from `dataset`) - output_layers: which data modalities are aimed to be reconsructed via decoders (use a subset of keys from `dataset`). - + input_layers: which data modalities from `dataset` to encode (use a subset of keys from `dataset`) + output_layers: which data modalities are aimed to be reconsructed via decoders (use a subset of keys from `dataset`). + """ - def __init__(self, config, dataset, target_variables = None, batch_variables = None, - surv_event_var = None, surv_time_var = None, - input_layers = None, output_layers = None, - use_loss_weighting = True, - device_type = None): + + def __init__( + self, + config, + dataset, + target_variables=None, + batch_variables=None, + surv_event_var=None, + surv_time_var=None, + input_layers=None, + output_layers=None, + use_loss_weighting=True, + device_type=None, + ): super(CrossModalPred, self).__init__() self.config = config self.target_variables = target_variables @@ -43,56 +52,86 @@ def __init__(self, config, dataset, target_variables = None, batch_variables = if self.surv_event_var is not None and self.surv_time_var is not None: self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables - self.variables = self.target_variables + self.batch_variables if self.batch_variables else self.target_variables - self.variable_types = dataset.variable_types + self.variables = ( + self.target_variables + self.batch_variables + if self.batch_variables + else self.target_variables + ) + self.variable_types = dataset.variable_types self.ann = dataset.ann - - self.input_layers = input_layers if input_layers else list(dataset.dat.keys()) - self.output_layers = output_layers if output_layers else list(dataset.dat.keys()) - + + self.input_layers = input_layers if input_layers else list(dataset.dat.keys()) + self.output_layers = ( + output_layers if output_layers else list(dataset.dat.keys()) + ) + self.feature_importances = {} - + self.device_type = device_type self.use_loss_weighting = use_loss_weighting - + if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars = nn.ParameterDict() - for loss_type in itertools.chain(self.variables, ['mmd_loss']): + for loss_type in itertools.chain(self.variables, ["mmd_loss"]): self.log_vars[loss_type] = nn.Parameter(torch.zeros(1)) - - # create a list of Encoder instances for separately encoding each input omics layer - input_dims = [len(dataset.features[self.input_layers[i]]) for i in range(len(self.input_layers))] - self.encoders = nn.ModuleList([Encoder(input_dims[i], - # define hidden_dim size as a factor of input_dim - [int(input_dims[i] * config['hidden_dim_factor'])], - config['latent_dim']) - for i in range(len(self.input_layers))]) - + + # create a list of Encoder instances for separately encoding each input omics layer + input_dims = [ + len(dataset.features[self.input_layers[i]]) + for i in range(len(self.input_layers)) + ] + self.encoders = nn.ModuleList( + [ + Encoder( + input_dims[i], + # define hidden_dim size as a factor of input_dim + [int(input_dims[i] * config["hidden_dim_factor"])], + config["latent_dim"], + ) + for i in range(len(self.input_layers)) + ] + ) + # Fully connected layers for concatenated means and log_vars - self.FC_mean = nn.Linear(len(self.input_layers) * config['latent_dim'], config['latent_dim']) - self.FC_log_var = nn.Linear(len(self.input_layers) * config['latent_dim'], config['latent_dim']) - - # list of decoders to decode the latent layer into the target/output layers - output_dims = [len(dataset.features[self.output_layers[i]]) for i in range(len(self.output_layers))] - self.decoders = nn.ModuleList([Decoder(config['latent_dim'], - [int(output_dims[i] * config['hidden_dim_factor'])], - output_dims[i]) - for i in range(len(self.output_layers))]) + self.FC_mean = nn.Linear( + len(self.input_layers) * config["latent_dim"], config["latent_dim"] + ) + self.FC_log_var = nn.Linear( + len(self.input_layers) * config["latent_dim"], config["latent_dim"] + ) + + # list of decoders to decode the latent layer into the target/output layers + output_dims = [ + len(dataset.features[self.output_layers[i]]) + for i in range(len(self.output_layers)) + ] + self.decoders = nn.ModuleList( + [ + Decoder( + config["latent_dim"], + [int(output_dims[i] * config["hidden_dim_factor"])], + output_dims[i], + ) + for i in range(len(self.output_layers)) + ] + ) # define supervisor heads # using ModuleDict to store multiple MLPs - self.MLPs = nn.ModuleDict() + self.MLPs = nn.ModuleDict() for var in self.variables: - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.ann[var])) - self.MLPs[var] = MLP(input_dim = config['latent_dim'], - hidden_dim = config['supervisor_hidden_dim'], - output_dim = num_class) - + self.MLPs[var] = MLP( + input_dim=config["latent_dim"], + hidden_dim=config["supervisor_hidden_dim"], + output_dim=num_class, + ) + def multi_encoder(self, x_list): """ Encode each input matrix separately using the corresponding Encoder. @@ -117,7 +156,7 @@ def multi_encoder(self, x_list): mean = self.FC_mean(torch.cat(means, dim=1)) log_var = self.FC_log_var(torch.cat(log_vars, dim=1)) return mean, log_var - + def forward(self, x_list_input): """ Forward pass through the model. @@ -134,20 +173,20 @@ def forward(self, x_list_input): - y_pred (torch.Tensor): Predicted output. """ mean, log_var = self.multi_encoder(x_list_input) - + # generate latent layer z = self.reparameterization(mean, log_var) # decode the latent space to target output layer(s) x_hat_list = [self.decoders[i](z) for i in range(len(self.output_layers))] - #run the supervisor heads using the latent layer as input + # run the supervisor heads using the latent layer as input outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(z) - + return x_hat_list, z, mean, log_var, outputs - + def reparameterization(self, mean, var): """ Reparameterize the mean and variance values. @@ -159,10 +198,10 @@ def reparameterization(self, mean, var): Returns: torch.Tensor: Latent representation. """ - epsilon = torch.randn_like(var) - z = mean + var*epsilon + epsilon = torch.randn_like(var) + z = mean + var * epsilon return z - + def configure_optimizers(self): """ Configure the optimizer for the model. @@ -170,63 +209,69 @@ def configure_optimizers(self): Returns: torch.optim.Adam: Adam optimizer with learning rate 1e-3. """ - optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) return optimizer - + def compute_loss(self, var, y, y_hat): """ Computes the loss for a specific variable based on whether the variable is numerical or categorical. Handles missing labels by excluding them from the loss calculation. - + Args: var (str): The name of the variable for which the loss is being calculated. y (torch.Tensor): The true labels or values for the variable. y_hat (torch.Tensor): The predicted labels or values output by the model. - + Returns: torch.Tensor: The calculated loss tensor for the variable. If there are no valid labels or values to compute the loss (all are missing), returns a zero loss tensor with gradient enabled. - + The method first checks the type of the variable (`var`) from `variable_types`. If the variable is numerical, it computes the mean squared error loss. For categorical variables, it calculates the cross-entropy loss. The method ensures to ignore any instances where the labels are missing (NaN for numerical or -1 for categorical as assumed missing value encoding) when calculating the loss. """ - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": # Ignore instances with missing labels for numerical variables valid_indices = ~torch.isnan(y) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.mse_loss(torch.flatten(y_hat), y.float()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) # if no valid labels, set loss to 0 else: # Ignore instances with missing labels for categorical variables # Assuming that missing values were encoded as -1 valid_indices = (y != -1) & (~torch.isnan(y)) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) - else: + else: loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss - + def compute_total_loss(self, losses): """ Computes the total loss from a dictionary of individual losses. This method can compute either weighted or unweighted total loss based on the model configuration. If loss weighting is enabled and there are multiple loss components, it uses uncertainty-based weighting. See Kendall A. et al, https://arxiv.org/abs/1705.07115. - + Args: losses (dict of torch.Tensor): A dictionary where each key is a variable name and each value is the loss tensor associated with that variable. - + Returns: torch.Tensor: The total loss computed across all inputs, either weighted or unweighted. - + The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple losses to weight. If so, it computes the weighted sum of losses, where the weight involves the exponential of the negative log variance (acting as precision) associated with each loss, @@ -235,28 +280,30 @@ def compute_total_loss(self, losses): loss component, it sums up the losses directly. """ if self.use_loss_weighting and len(losses) > 1: - # Compute weighted loss for each loss + # Compute weighted loss for each loss # Weighted loss = precision * loss + log-variance - total_loss = sum(torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items()) + total_loss = sum( + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] + for name, loss in losses.items() + ) else: # Compute unweighted total loss total_loss = sum(losses.values()) return total_loss - - def training_step(self, train_batch, batch_idx, log = True): + def training_step(self, train_batch, batch_idx, log=True): """ Executes one training step using a single batch of data from a cross-modality prediction model. - + Args: train_batch (tuple): The batch data containing input features and target labels. batch_idx (int): The index of the current batch. log (bool, optional): Flag to determine if logging should occur at each step. Defaults to True. - + Returns: torch.Tensor: The total loss for the current training batch, combining MMD loss for latent space regularization, reconstruction losses for each output layer, and losses from supervisor heads for specified target variables. - + This method processes the batch by encoding input features from specified layers, decoding them to reconstruct the output layers, and calculating the Maximum Mean Discrepancy (MMD) loss for latent space regularization. It computes the reconstruction loss for each target/output layer. Additional losses are computed for other target variables in the @@ -265,50 +312,53 @@ def training_step(self, train_batch, batch_idx, log = True): """ dat, y_dict, samples = train_batch - # get input omics modalities and encode them; decode them to output layers + # get input omics modalities and encode them; decode them to output layers x_list_input = [dat[x] for x in self.input_layers] x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input) - + # compute mmd loss for the latent space + reconsruction loss for each target/output layer x_list_output = [dat[x] for x in self.output_layers] - mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) for i in range(len(self.output_layers))] + mmd_loss_list = [ + self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) + for i in range(len(self.output_layers)) + ] mmd_loss = torch.mean(torch.stack(mmd_loss_list)) - # compute loss values for the supervisor heads - losses = {'mmd_loss': mmd_loss} - + # compute loss values for the supervisor heads + losses = {"mmd_loss": mmd_loss} + for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) - else: + else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = self.compute_total_loss(losses) - # add total loss for logging - losses['train_loss'] = total_loss + # add total loss for logging + losses["train_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - - def validation_step(self, val_batch, batch_idx, log = True): + + def validation_step(self, val_batch, batch_idx, log=True): """ Executes one validation step using a single batch of data, assessing the model's performance on the validation set. - + Args: val_batch (tuple): The batch data containing input features and target labels for validation. batch_idx (int): The index of the current batch in the validation process. log (bool, optional): Indicates whether to log the validation losses during this step. Defaults to True. - + Returns: torch.Tensor: The total loss for the current validation batch, calculated by combining MMD loss, reconstruction losses, and losses from supervisor heads for specified target variables. - + In this method, the model processes input data by encoding it through specified input layers and decoding it to targeted output layers. It computes the Maximum Mean Discrepancy (MMD) loss to measure the divergence between the model's latent representations and a predefined distribution, along with reconstruction losses for output layers. @@ -321,32 +371,35 @@ def validation_step(self, val_batch, batch_idx, log = True): # get input omics modalities and encode them x_list_input = [dat[x] for x in self.input_layers] x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input) - + # compute mmd loss for the latent space + reconsruction loss for each target/output layer x_list_output = [dat[x] for x in self.output_layers] - mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) for i in range(len(self.output_layers))] + mmd_loss_list = [ + self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list_output[i]) + for i in range(len(self.output_layers)) + ] mmd_loss = torch.mean(torch.stack(mmd_loss_list)) - # compute loss values for the supervisor heads - losses = {'mmd_loss': mmd_loss} + # compute loss values for the supervisor heads + losses = {"mmd_loss": mmd_loss} for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) - else: + else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = sum(losses.values()) - losses['val_loss'] = total_loss + losses["val_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - + def transform(self, dataset): """ Transform the input dataset to latent representation. @@ -361,10 +414,10 @@ def transform(self, dataset): x_list_input = [dataset.dat[x] for x in self.input_layers] M = self.forward(x_list_input)[1].detach().numpy() z = pd.DataFrame(M) - z.columns = [''.join(['E', str(x)]) for x in z.columns] + z.columns = ["".join(["E", str(x)]) for x in z.columns] z.index = dataset.samples return z - + def predict(self, dataset): """ Evaluate the model on a dataset. @@ -376,29 +429,34 @@ def predict(self, dataset): predicted values. """ self.eval() - + x_list_input = [dataset.dat[x] for x in self.input_layers] X_hat, z, mean, log_var, outputs = self.forward(x_list_input) - - predictions = {var: [] for var in self.variables} # Initialize prediction storage + + predictions = { + var: [] for var in self.variables + } # Initialize prediction storage # Collect predictions for each variable for var in self.variables: logits = outputs[var].detach().cpu() # Raw model outputs (logits) - if dataset.variable_types[var] == 'categorical': - probs = torch.softmax(logits, dim=1).numpy() # class probabilities between 0 and 1 + if dataset.variable_types[var] == "categorical": + probs = torch.softmax( + logits, dim=1 + ).numpy() # class probabilities between 0 and 1 predictions[var].extend(probs) else: - predictions[var].extend(logits.numpy()) # return raw output for regression problems - # Convert lists to arrays + predictions[var].extend( + logits.numpy() + ) # return raw output for regression problems + # Convert lists to arrays predictions = {var: np.array(predictions[var]) for var in predictions} - + return predictions - - + def decode(self, dataset): """ - Extract the decoded values of the target/output layers + Extract the decoded values of the target/output layers """ self.eval() x_list_input = [dataset.dat[x] for x in self.input_layers] @@ -409,11 +467,10 @@ def decode(self, dataset): x = pd.DataFrame(x_hat_list[i].detach().numpy()).transpose() layer = self.output_layers[i] x.columns = dataset.samples - x.index = dataset.features[layer] + x.index = dataset.features[layer] X[layer] = x return X - def compute_kernel(self, x, y): """ Compute the Gaussian kernel matrix between two sets of vectors. @@ -428,12 +485,12 @@ def compute_kernel(self, x, y): x_size = x.size(0) y_size = y.size(0) dim = x.size(1) - x = x.unsqueeze(1) # (x_size, 1, dim) - y = y.unsqueeze(0) # (1, y_size, dim) + x = x.unsqueeze(1) # (x_size, 1, dim) + y = y.unsqueeze(0) # (1, y_size, dim) tiled_x = x.expand(x_size, y_size, dim) tiled_y = y.expand(x_size, y_size, dim) - kernel_input = (tiled_x - tiled_y).pow(2).mean(2)/float(dim) - return torch.exp(-kernel_input) # (x_size, y_size) + kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim) + return torch.exp(-kernel_input) # (x_size, y_size) def compute_mmd(self, x, y): """ @@ -449,7 +506,7 @@ def compute_mmd(self, x, y): x_kernel = self.compute_kernel(x, x) y_kernel = self.compute_kernel(y, y) xy_kernel = self.compute_kernel(x, y) - mmd = x_kernel.mean() + y_kernel.mean() - 2*xy_kernel.mean() + mmd = x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean() return mmd def MMD_loss(self, latent_dim, z, xhat, x): @@ -465,25 +522,34 @@ def MMD_loss(self, latent_dim, z, xhat, x): Returns: torch.Tensor: A scalar tensor representing the MMD loss. """ - true_samples = torch.randn(200, latent_dim, device = self.device) - mmd = self.compute_mmd(true_samples, z) # compute maximum mean discrepancy (MMD) - nll = (xhat - x).pow(2).mean() #negative log likelihood - return mmd+nll - - # Adaptor forward function for captum integrated gradients. + true_samples = torch.randn(200, latent_dim, device=self.device) + mmd = self.compute_mmd( + true_samples, z + ) # compute maximum mean discrepancy (MMD) + nll = (xhat - x).pow(2).mean() # negative log likelihood + return mmd + nll + + # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): input_data = list(args[:-2]) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest - steps = args[-1] # number of steps for IntegratedGradients().attribute + steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] for i in range(steps): # get list of tensors for each step into a list of tensors x_step = [input_data[j][i] for j in range(len(input_data))] x_hat_list, z, mean, log_var, outputs = self.forward(x_step) outputs_list.append(outputs[target_var]) - return torch.cat(outputs_list, dim = 0) - - def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64): + return torch.cat(outputs_list, dim=0) + + def compute_feature_importance( + self, + dataset, + target_var, + method="IntegratedGradients", + steps_or_samples=5, + batch_size=64, + ): """ Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. @@ -500,27 +566,41 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities. """ from ..utils import create_device_from_string, to_device_safe - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + # Force CPU for Captum feature importance, as MPS lacks the required float64 support. - if device.type == 'mps': - print("[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility.") - device = torch.device('cpu') - + if device.type == "mps": + print( + "[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility." + ) + device = torch.device("cpu") + self.to(device) - - print("[INFO] Computing feature importance for variable:",target_var,"on device:",device) - + + print( + "[INFO] Computing feature importance for variable:", + target_var, + "on device:", + device, + ) + # Choose the attribution method dynamically if method == "IntegratedGradients": explainer = IntegratedGradients(self.forward_target) elif method == "GradientShap": explainer = GradientShap(self.forward_target) else: - raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.") + raise ValueError( + f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'." + ) # Get the number of classes for the target variable - if self.variable_types[target_var] == 'numerical': + if self.variable_types[target_var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.ann[target_var])) @@ -528,49 +608,63 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) aggregated_attributions = [[] for _ in range(num_class)] - + for batch in dataloader: dat, _, _ = batch x_list = [to_device_safe(dat[x], device) for x in self.input_layers] input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) - - if method == 'IntegratedGradients': + + if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif method == 'GradientShap': # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = tuple( - torch.cat([torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0) + torch.cat( + [torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0 + ) for x in input_data ) if num_class == 1: - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_samples=steps_or_samples, + ) aggregated_attributions[0].append(attributions) else: for target_class in range(num_class): # returns a tuple of tensors (one per data modality) - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + target=target_class, + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + target=target_class, + n_samples=steps_or_samples, + ) aggregated_attributions[target_class].append(attributions) - # For each target class and for each data modality/layer, concatenate attributions accross batches + # For each target class and for each data modality/layer, concatenate attributions accross batches layers = list(self.input_layers) num_layers = len(layers) - processed_attributions = [] + processed_attributions = [] # Process each class for class_idx in range(len(aggregated_attributions)): class_attr = aggregated_attributions[class_idx] @@ -583,32 +677,44 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad attr_concat = torch.cat(layer_tensors, dim=1) layer_attributions.append(attr_concat) processed_attributions.append(layer_attributions) - + # summarize feature importances # Compute absolute attributions # Move the processed tensors to CPU for further operations that are not supported on GPU - abs_attr = [[torch.abs(a).cpu() for a in attr_class] for attr_class in processed_attributions] - # average over samples + abs_attr = [ + [torch.abs(a).cpu() for a in attr_class] + for attr_class in processed_attributions + ] + # average over samples imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr] # move the model also back to cpu (if not already on cpu) - self.to('cpu') + self.to("cpu") - # combine into a single data frame + # combine into a single data frame df_list = [] for i in range(num_class): for j in range(len(layers)): features = dataset.features[layers[j]] importances = imp[i][j][0].detach().numpy() - target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else '' - df_list.append(pd.DataFrame({'target_variable': target_var, - 'target_class': i, - 'target_class_label': target_class_label, - 'layer': layers[j], - 'name': features, 'importance': importances})) - df_imp = pd.concat(df_list, ignore_index = True) - + target_class_label = ( + dataset.label_mappings[target_var].get(i) + if target_var in dataset.label_mappings + else "" + ) + df_list.append( + pd.DataFrame( + { + "target_variable": target_var, + "target_class": i, + "target_class_label": target_class_label, + "layer": layers[j], + "name": features, + "importance": importances, + } + ) + ) + df_imp = pd.concat(df_list, ignore_index=True) + # save scores in model self.feature_importances[target_var] = df_imp - - diff --git a/flexynesis/models/direct_pred.py b/flexynesis/models/direct_pred.py index 4865994a..565f7c05 100644 --- a/flexynesis/models/direct_pred.py +++ b/flexynesis/models/direct_pred.py @@ -15,6 +15,7 @@ from ..modules import * from ..utils import to_device_safe + class DirectPred(pl.LightningModule): """ A fully connected network for multi-omics integration with supervisor heads. @@ -30,9 +31,17 @@ class DirectPred(pl.LightningModule): device_type (str, optional): Type of device to run the model ('gpu' or 'cpu'). Defaults to None. """ - def __init__(self, config, dataset, target_variables, batch_variables = None, - surv_event_var = None, surv_time_var = None, use_loss_weighting = True, - device_type = None): + def __init__( + self, + config, + dataset, + target_variables, + batch_variables=None, + surv_event_var=None, + surv_time_var=None, + use_loss_weighting=True, + device_type=None, + ): super(DirectPred, self).__init__() self.config = config self.target_variables = target_variables @@ -43,11 +52,15 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, if self.surv_event_var is not None and self.surv_time_var is not None: self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables - self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables + self.variables = ( + self.target_variables + batch_variables + if batch_variables + else self.target_variables + ) self.feature_importances = {} self.use_loss_weighting = use_loss_weighting self.device_type = device_type - + if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars = nn.ParameterDict() @@ -57,31 +70,43 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, self.variable_types = dataset.variable_types self.ann = dataset.ann self.layers = list(dataset.dat.keys()) - self.input_dims = [len(dataset.features[self.layers[i]]) for i in range(len(self.layers))] - - self.encoders = nn.ModuleList([ - MLP(input_dim=self.input_dims[i], - # define hidden_dim size relative to the input_dim size - hidden_dim=int(self.input_dims[i] * self.config['hidden_dim_factor']), - output_dim=self.config['latent_dim']) for i in range(len(self.layers))]) + self.input_dims = [ + len(dataset.features[self.layers[i]]) for i in range(len(self.layers)) + ] + + self.encoders = nn.ModuleList( + [ + MLP( + input_dim=self.input_dims[i], + # define hidden_dim size relative to the input_dim size + hidden_dim=int( + self.input_dims[i] * self.config["hidden_dim_factor"] + ), + output_dim=self.config["latent_dim"], + ) + for i in range(len(self.layers)) + ] + ) if len(self.input_dims) > 1: self.fusion_block = nn.Linear( - in_features=self.config['latent_dim'] * len(self.layers), - out_features=self.config['latent_dim'] + in_features=self.config["latent_dim"] * len(self.layers), + out_features=self.config["latent_dim"], ) else: self.fusion_block = None - + self.MLPs = nn.ModuleDict() # using ModuleDict to store multiple MLPs for var in self.variables: - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.ann[var])) - self.MLPs[var] = MLP(input_dim=self.config['latent_dim'], - hidden_dim=self.config['supervisor_hidden_dim'], - output_dim=num_class) + self.MLPs[var] = MLP( + input_dim=self.config["latent_dim"], + hidden_dim=self.config["supervisor_hidden_dim"], + output_dim=num_class, + ) def forward(self, x_list): """ @@ -98,15 +123,18 @@ def forward(self, x_list): for i, x in enumerate(x_list): embeddings_list.append(self.encoders[i](x)) embeddings_concat = torch.cat(embeddings_list, dim=1) - # if multiple embeddings, fuse them - embeddings = self.fusion_block(embeddings_concat) if self.fusion_block else embeddings_concat + # if multiple embeddings, fuse them + embeddings = ( + self.fusion_block(embeddings_concat) + if self.fusion_block + else embeddings_concat + ) outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(embeddings) - return outputs - - + return outputs + def configure_optimizers(self): """ Configure the optimizer for the DirectPred model. @@ -115,63 +143,69 @@ def configure_optimizers(self): torch.optim.Optimizer: The configured optimizer. """ - optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) return optimizer - + def compute_loss(self, var, y, y_hat): """ Computes the loss for a specific variable based on whether the variable is numerical or categorical. Handles missing labels by excluding them from the loss calculation. - + Args: var (str): The name of the variable for which the loss is being calculated. y (torch.Tensor): The true labels or values for the variable. y_hat (torch.Tensor): The predicted labels or values output by the model. - + Returns: torch.Tensor: The calculated loss tensor for the variable. If there are no valid labels or values to compute the loss (all are missing), returns a zero loss tensor with gradient enabled. - + The method first checks the type of the variable (`var`) from `variable_types`. If the variable is numerical, it computes the mean squared error loss. For categorical variables, it calculates the cross-entropy loss. The method ensures to ignore any instances where the labels are missing (NaN for numerical or -1 for categorical as assumed missing value encoding) when calculating the loss. """ - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": # Ignore instances with missing labels for numerical variables valid_indices = ~torch.isnan(y) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.mse_loss(torch.flatten(y_hat), y.float()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) # if no valid labels, set loss to 0 else: # Ignore instances with missing labels for categorical variables # Assuming that missing values were encoded as -1 valid_indices = (y != -1) & (~torch.isnan(y)) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) - else: + else: loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss - + def compute_total_loss(self, losses): """ Computes the total loss from a dictionary of individual losses. This method can compute either weighted or unweighted total loss based on the model configuration. If loss weighting is enabled and there are multiple loss components, it uses uncertainty-based weighting. See Kendall A. et al, https://arxiv.org/abs/1705.07115. - + Args: losses (dict of torch.Tensor): A dictionary where each key is a variable name and each value is the loss tensor associated with that variable. - + Returns: torch.Tensor: The total loss computed across all inputs, either weighted or unweighted. - + The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple losses to weight. If so, it computes the weighted sum of losses, where the weight involves the exponential of the negative log variance (acting as precision) associated with each loss, @@ -180,15 +214,18 @@ def compute_total_loss(self, losses): loss component, it sums up the losses directly. """ if self.use_loss_weighting and len(losses) > 1: - # Compute weighted loss for each loss + # Compute weighted loss for each loss # Weighted loss = precision * loss + log-variance - total_loss = sum(torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items()) + total_loss = sum( + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] + for name, loss in losses.items() + ) else: # Compute unweighted total loss total_loss = sum(losses.values()) return total_loss - def training_step(self, train_batch, batch_idx, log = True): + def training_step(self, train_batch, batch_idx, log=True): """ Executes one training step using a single batch from the training dataset. @@ -200,8 +237,8 @@ def training_step(self, train_batch, batch_idx, log = True): Returns: torch.Tensor: The total loss computed for the batch. """ - - dat, y_dict, samples = train_batch + + dat, y_dict, samples = train_batch layers = dat.keys() x_list = [dat[x] for x in layers] outputs = self.forward(x_list) @@ -209,23 +246,23 @@ def training_step(self, train_batch, batch_idx, log = True): for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = self.compute_total_loss(losses) # add train loss for logging - losses['train_loss'] = total_loss + losses["train_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - - def validation_step(self, val_batch, batch_idx, log = True): + + def validation_step(self, val_batch, batch_idx, log=True): """ Executes one validation step using a single batch from the validation dataset. @@ -237,7 +274,7 @@ def validation_step(self, val_batch, batch_idx, log = True): Returns: torch.Tensor: The total loss computed for the batch. """ - dat, y_dict, samples = val_batch + dat, y_dict, samples = val_batch layers = dat.keys() x_list = [dat[x] for x in layers] outputs = self.forward(x_list) @@ -245,8 +282,8 @@ def validation_step(self, val_batch, batch_idx, log = True): for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) else: y_hat = outputs[var] @@ -254,7 +291,7 @@ def validation_step(self, val_batch, batch_idx, log = True): loss = self.compute_loss(var, y, y_hat) losses[var] = loss total_loss = sum(losses.values()) - losses['val_loss'] = total_loss + losses["val_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss @@ -271,18 +308,29 @@ def predict(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device # Create a DataLoader with a practical batch size - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed - predictions = {var: [] for var in self.variables} # Initialize prediction storage + predictions = { + var: [] for var in self.variables + } # Initialize prediction storage # Process each batch for batch in dataloader: dat, y_dict, samples = batch - x_list = [to_device_safe(dat[x], device) for x in dat.keys()] # Prepare the data batch for processing + x_list = [ + to_device_safe(dat[x], device) for x in dat.keys() + ] # Prepare the data batch for processing # Perform the forward pass outputs = self.forward(x_list) @@ -290,17 +338,21 @@ def predict(self, dataset): # Collect predictions for each variable for var in self.variables: logits = outputs[var].detach().cpu() # Raw model outputs (logits) - - if dataset.variable_types[var] == 'categorical': - probs = torch.softmax(logits, dim=1).numpy() # class probabilities between 0 and 1 + + if dataset.variable_types[var] == "categorical": + probs = torch.softmax( + logits, dim=1 + ).numpy() # class probabilities between 0 and 1 predictions[var].extend(probs) else: - predictions[var].extend(logits.numpy()) # return raw output for regression problems - # Convert lists to arrays + predictions[var].extend( + logits.numpy() + ) # return raw output for regression problems + # Convert lists to arrays predictions = {var: np.array(predictions[var]) for var in predictions} return predictions - + def transform(self, dataset): """ Transforms the input data into a lower-dimensional representation using trained encoders. @@ -313,10 +365,17 @@ def transform(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed embeddings_list = [] # Initialize a list to collect all batch embeddings sample_names = [] # List to collect sample names @@ -328,40 +387,59 @@ def transform(self, dataset): # Process each input matrix with its corresponding Encoder for i, x in enumerate(dat.values()): x = to_device_safe(x, device) # Move data to GPU - encoded_x = self.encoders[i](x) # Transform data using the corresponding encoder + encoded_x = self.encoders[i]( + x + ) # Transform data using the corresponding encoder batch_embeddings.append(encoded_x) - + # Concatenate all embeddings from the current batch embeddings_batch_concat = torch.cat(batch_embeddings, dim=1) - # if multiple embeddings, fuse them - embeddings_batch = self.fusion_block(embeddings_batch_concat) if self.fusion_block else embeddings_batch_concat + # if multiple embeddings, fuse them + embeddings_batch = ( + self.fusion_block(embeddings_batch_concat) + if self.fusion_block + else embeddings_batch_concat + ) - embeddings_list.append(embeddings_batch.detach().cpu()) # Move tensor back to CPU and detach + embeddings_list.append( + embeddings_batch.detach().cpu() + ) # Move tensor back to CPU and detach sample_names.extend(samples) # Collect sample names # Concatenate all batch embeddings into one tensor embeddings_concat = torch.cat(embeddings_list, dim=0) # Converting tensor to numpy array and then to DataFrame - embeddings_df = pd.DataFrame(embeddings_concat.numpy(), - index=sample_names, # Set DataFrame index to sample names - columns=[f"E{dim}" for dim in range(embeddings_concat.shape[1])]) + embeddings_df = pd.DataFrame( + embeddings_concat.numpy(), + index=sample_names, # Set DataFrame index to sample names + columns=[f"E{dim}" for dim in range(embeddings_concat.shape[1])], + ) return embeddings_df - - # Adaptor forward function for captum integrated gradients or gradient shap + + # Adaptor forward function for captum integrated gradients or gradient shap def forward_target(self, *args): input_data = list(args[:-2]) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest - steps = args[-1] # number of steps/samples for IntegratedGradients().attribute or GradientShap.attribute + steps = args[ + -1 + ] # number of steps/samples for IntegratedGradients().attribute or GradientShap.attribute outputs_list = [] for i in range(steps): # get list of tensors for each step into a list of tensors x_step = [input_data[j][i] for j in range(len(input_data))] out = self.forward(x_step) outputs_list.append(out[target_var]) - return torch.cat(outputs_list, dim = 0) - - def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64): + return torch.cat(outputs_list, dim=0) + + def compute_feature_importance( + self, + dataset, + target_var, + method="IntegratedGradients", + steps_or_samples=5, + batch_size=64, + ): """ Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. @@ -378,13 +456,20 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities. """ from ..utils import create_device_from_string, to_device_safe - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + # Force CPU for Captum feature importance, as MPS lacks the required float64 support. - if device.type == 'mps': - print("[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility.") - device = torch.device('cpu') - + if device.type == "mps": + print( + "[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility." + ) + device = torch.device("cpu") + self.to(device) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) @@ -395,10 +480,12 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad elif method == "GradientShap": explainer = GradientShap(self.forward_target) else: - raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.") + raise ValueError( + f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'." + ) # Handle target class (numerical vs categorical) - if dataset.variable_types[target_var] == 'numerical': + if dataset.variable_types[target_var] == "numerical": num_class = 1 else: num_class = len(np.unique([y[target_var] for _, y, _ in dataset])) @@ -408,38 +495,52 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad dat, _, _ = batch x_list = [to_device_safe(dat[x], device) for x in dat.keys()] input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) - - if method == 'IntegratedGradients': + + if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif method == 'GradientShap': # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = tuple( - torch.cat([torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0) + torch.cat( + [torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0 + ) for x in input_data ) if num_class == 1: # returns a tuple of tensors (one per data modality) - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_samples=steps_or_samples, + ) aggregated_attributions[0].append(attributions) else: for target_class in range(num_class): # returns a tuple of tensors (one per data modality) - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + target=target_class, + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + target=target_class, + n_samples=steps_or_samples, + ) aggregated_attributions[target_class].append(attributions) # Post-process attributions layers = list(dataset.dat.keys()) @@ -454,9 +555,12 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad layer_attributions.append(attr_concat) processed_attributions.append(layer_attributions) - abs_attr = [[torch.abs(a).cpu() for a in attr_class] for attr_class in processed_attributions] + abs_attr = [ + [torch.abs(a).cpu() for a in attr_class] + for attr_class in processed_attributions + ] imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr] - self.to('cpu') + self.to("cpu") # Combine results into a DataFrame df_list = [] @@ -464,13 +568,22 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad for j in range(len(layers)): features = dataset.features[layers[j]] importances = imp[i][j][0].detach().numpy() - target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else '' - df_list.append(pd.DataFrame({'target_variable': target_var, - 'target_class': i, - 'target_class_label': target_class_label, - 'layer': layers[j], - 'name': features, - 'importance': importances})) + target_class_label = ( + dataset.label_mappings[target_var].get(i) + if target_var in dataset.label_mappings + else "" + ) + df_list.append( + pd.DataFrame( + { + "target_variable": target_var, + "target_class": i, + "target_class_label": target_class_label, + "layer": layers[j], + "name": features, + "importance": importances, + } + ) + ) df_imp = pd.concat(df_list, ignore_index=True) self.feature_importances[target_var] = df_imp - diff --git a/flexynesis/models/gnn_early.py b/flexynesis/models/gnn_early.py index fb4be5f8..b4fe571e 100644 --- a/flexynesis/models/gnn_early.py +++ b/flexynesis/models/gnn_early.py @@ -53,18 +53,19 @@ class GNN(pl.LightningModule): use_loss_weighting (bool, optional): Whether to use uncertainty weighting in loss calculation. Defaults to True. device_type (str, optional): Specifies the computation device ('gpu' or 'cpu'). Default is None, which uses 'cpu' if 'gpu' is not available. gnn_conv_type (str, optional): Specifies the type of graph convolutional layer to use. - """ + """ + def __init__( self, config, - dataset, # MultiomicDatasetNW object + dataset, # MultiomicDatasetNW object target_variables, batch_variables=None, surv_event_var=None, surv_time_var=None, use_loss_weighting=True, - device_type = None, - gnn_conv_type = None + device_type=None, + gnn_conv_type=None, ): super().__init__() self.config = config @@ -76,37 +77,57 @@ def __init__( if self.surv_event_var is not None and self.surv_time_var is not None: self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables - self.variables = self.target_variables + self.batch_variables if self.batch_variables else self.target_variables - self.variable_types = getattr(dataset, 'multiomic_dataset', dataset).variable_types - self.ann = getattr(dataset, 'multiomic_dataset', dataset).ann + self.variables = ( + self.target_variables + self.batch_variables + if self.batch_variables + else self.target_variables + ) + self.variable_types = getattr( + dataset, "multiomic_dataset", dataset + ).variable_types + self.ann = getattr(dataset, "multiomic_dataset", dataset).ann self.edge_index = dataset.edge_index - + self.feature_importances = {} self.use_loss_weighting = use_loss_weighting - self.device_type = device_type + self.device_type = device_type self.gnn_conv_type = gnn_conv_type - + from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - self.edge_index = to_device_safe(self.edge_index, device) # edge index is re-used across samples, so we keep it in device - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + self.edge_index = to_device_safe( + self.edge_index, device + ) # edge index is re-used across samples, so we keep it in device + if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars = nn.ParameterDict() for var in self.variables: self.log_vars[var] = nn.Parameter(torch.zeros(1)) - - self.encoders = nn.ModuleList([ - flexGCN( - node_count = dataset[0][0].shape[0], #number of nodes - node_feature_count= dataset[0][0].shape[1], # number of node features - node_embedding_dim=int(self.config["node_embedding_dim"]), - num_convs = int(self.config['num_convs']), # Number of convolutional layers - output_dim=self.config["latent_dim"], - act = self.config['activation'], - conv = self.gnn_conv_type - )]) + + self.encoders = nn.ModuleList( + [ + flexGCN( + node_count=dataset[0][0].shape[0], # number of nodes + node_feature_count=dataset[0][0].shape[ + 1 + ], # number of node features + node_embedding_dim=int(self.config["node_embedding_dim"]), + num_convs=int( + self.config["num_convs"] + ), # Number of convolutional layers + output_dim=self.config["latent_dim"], + act=self.config["activation"], + conv=self.gnn_conv_type, + ) + ] + ) # Init output layers self.MLPs = nn.ModuleDict() @@ -118,10 +139,10 @@ def __init__( self.MLPs[var] = MLP( input_dim=self.config["latent_dim"], hidden_dim=self.config["supervisor_hidden_dim"], - output_dim=num_class + output_dim=num_class, ) - - def forward(self, x, edge_index): + + def forward(self, x, edge_index): """ Defines the forward pass of the GNN. @@ -139,8 +160,7 @@ def forward(self, x, edge_index): outputs[var] = mlp(embeddings) return outputs - - def training_step(self, batch, batch_idx, log = True): + def training_step(self, batch, batch_idx, log=True): """ Performs a training step including loss calculation and logging. @@ -169,10 +189,16 @@ def training_step(self, batch, batch_idx, log = True): total_loss = self.compute_total_loss(losses) losses["train_loss"] = total_loss if log: - self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True, batch_size=len(batch)) + self.log_dict( + losses, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=len(batch), + ) return total_loss - def validation_step(self, batch, batch_idx, log = True): + def validation_step(self, batch, batch_idx, log=True): """ Performs a validation step, computing losses for a batch of data. @@ -200,9 +226,15 @@ def validation_step(self, batch, batch_idx, log = True): total_loss = sum(losses.values()) losses["val_loss"] = total_loss if log: - self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True, batch_size=len(batch)) + self.log_dict( + losses, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=len(batch), + ) return total_loss - + def configure_optimizers(self): """ Configure the optimizer for the DirectPred model. @@ -210,23 +242,23 @@ def configure_optimizers(self): Returns: torch.optim.Optimizer: The configured optimizer. """ - optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) return optimizer def compute_loss(self, var, y, y_hat): """ Computes the loss for a specific variable based on whether the variable is numerical or categorical. Handles missing labels by excluding them from the loss calculation. - + Args: var (str): The name of the variable for which the loss is being calculated. y (torch.Tensor): The true labels or values for the variable. y_hat (torch.Tensor): The predicted labels or values output by the model. - + Returns: torch.Tensor: The calculated loss tensor for the variable. If there are no valid labels or values to compute the loss (all are missing), returns a zero loss tensor with gradient enabled. - + The method first checks the type of the variable (`var`) from `variable_types`. If the variable is numerical, it computes the mean squared error loss. For categorical variables, it calculates the cross-entropy loss. The method ensures to ignore any instances where the labels are missing (NaN for @@ -235,17 +267,23 @@ def compute_loss(self, var, y, y_hat): if self.variable_types[var] == "numerical": # Ignore instances with missing labels for numerical variables valid_indices = ~torch.isnan(y) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.mse_loss(torch.flatten(y_hat), y.float()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) # if no valid labels, set loss to 0 else: # Ignore instances with missing labels for categorical variables # Assuming that missing values were encoded as -1 valid_indices = (y != -1) & (~torch.isnan(y)) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) @@ -259,14 +297,14 @@ def compute_total_loss(self, losses): either weighted or unweighted total loss based on the model configuration. If loss weighting is enabled and there are multiple loss components, it uses uncertainty-based weighting. See Kendall A. et al, https://arxiv.org/abs/1705.07115. - + Args: losses (dict of torch.Tensor): A dictionary where each key is a variable name and each value is the loss tensor associated with that variable. - + Returns: torch.Tensor: The total loss computed across all inputs, either weighted or unweighted. - + The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple losses to weight. If so, it computes the weighted sum of losses, where the weight involves the exponential of the negative log variance (acting as precision) associated with each loss, @@ -278,13 +316,14 @@ def compute_total_loss(self, losses): # Compute weighted loss for each loss # Weighted loss = precision * loss + log-variance total_loss = sum( - torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items() + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] + for name, loss in losses.items() ) else: # Compute unweighted total loss total_loss = sum(losses.values()) return total_loss - + def predict(self, dataset): """ Make predictions on an entire dataset. @@ -297,29 +336,42 @@ def predict(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device # Create a DataLoader with a practical batch size - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed edge_index = dataset.edge_index.to(device) # Move edge_index to GPU - predictions = {var: [] for var in self.variables} # Initialize prediction storage + predictions = { + var: [] for var in self.variables + } # Initialize prediction storage # Process each batch - for x, y_dict,samples in dataloader: + for x, y_dict, samples in dataloader: x = x.to(device) # Move data to GPU outputs = self.forward(x, edge_index) for var in self.variables: logits = outputs[var].detach().cpu() # Raw model outputs (logits) - if dataset.variable_types[var] == 'categorical': - probs = torch.softmax(logits, dim=1).numpy() # class probabilities between 0 and 1 + if dataset.variable_types[var] == "categorical": + probs = torch.softmax( + logits, dim=1 + ).numpy() # class probabilities between 0 and 1 predictions[var].extend(probs) else: - predictions[var].extend(logits.numpy()) # return raw output for regression problems - # Convert lists to arrays + predictions[var].extend( + logits.numpy() + ) # return raw output for regression problems + # Convert lists to arrays predictions = {var: np.array(predictions[var]) for var in predictions} return predictions @@ -336,11 +388,18 @@ def transform(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device edge_index = dataset.edge_index.to(device) # Move edge_index to GPU - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed all_embeddings = [] # List to store embeddings from all batches sample_ids = [] # List to store indices for all samples processed @@ -349,10 +408,12 @@ def transform(self, dataset): x = x.to(device) # Move data to GPU # notice we are using the first encoder (it is currently a early fusion method) - embeddings = self.encoders[0](x, edge_index).detach().cpu().numpy() # Compute embeddings and move to CPU + embeddings = ( + self.encoders[0](x, edge_index).detach().cpu().numpy() + ) # Compute embeddings and move to CPU all_embeddings.append(embeddings) - sample_ids.extend(samples) - + sample_ids.extend(samples) + # Concatenate all embeddings into a single numpy array all_embeddings = np.vstack(all_embeddings) @@ -363,22 +424,28 @@ def transform(self, dataset): columns=[f"E{dim}" for dim in range(all_embeddings.shape[1])], ) return embeddings_df - - # Adaptor forward function for captum integrated gradients. + + # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): input_data = list(args[:-2]) # expect a single tensor (early integration) target_var = args[-2] # target variable of interest - steps = args[-1] # number of steps for IntegratedGradients().attribute + steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] for i in range(steps): - x_step = input_data[0][i] - #edges_step = edge_index[i] # although, identical, they get copied. + x_step = input_data[0][i] + # edges_step = edge_index[i] # although, identical, they get copied. out = self.forward(x_step, self.dataset_edge_index) outputs_list.append(out[target_var]) - return torch.cat(outputs_list, dim = 0) + return torch.cat(outputs_list, dim=0) - - def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64): + def compute_feature_importance( + self, + dataset, + target_var, + method="IntegratedGradients", + steps_or_samples=5, + batch_size=64, + ): """ Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. @@ -394,43 +461,58 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad Returns: pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities. """ + def bytes_to_gb(bytes): - return bytes / 1024 ** 2 + return bytes / 1024**2 + from ..utils import create_device_from_string, to_device_safe - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + # Force CPU for Captum feature importance, as MPS lacks the required float64 support. - if device.type == 'mps': - print("[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility.") - device = torch.device('cpu') - + if device.type == "mps": + print( + "[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility." + ) + device = torch.device("cpu") + self.to(device) self.dataset_edge_index = to_device_safe(dataset.edge_index, device) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) - + # Choose the attribution method dynamically if method == "IntegratedGradients": explainer = IntegratedGradients(self.forward_target) elif method == "GradientShap": explainer = GradientShap(self.forward_target) else: - raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.") + raise ValueError( + f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'." + ) - if dataset.variable_types[target_var] == 'numerical': + if dataset.variable_types[target_var] == "numerical": num_class = 1 else: num_class = len(np.unique(dataset.ann[target_var])) - + # Report memory usage based on device type - if device.type == 'cuda': - print("Memory before batch processing: {:.3f} MB".format(bytes_to_gb(torch.cuda.max_memory_reserved()))) + if device.type == "cuda": + print( + "Memory before batch processing: {:.3f} MB".format( + bytes_to_gb(torch.cuda.max_memory_reserved()) + ) + ) # MPS and CPU don't have detailed memory tracking like CUDA aggregated_attributions = [[] for _ in range(num_class)] for batch in dataloader: x, y_dict, samples = batch - + # Ensure input data is on the correct device x = to_device_safe(x, device) input_data = x.unsqueeze(0).requires_grad_() @@ -439,38 +521,53 @@ def bytes_to_gb(bytes): x = to_device_safe(x, device) input_data = x.unsqueeze(0).requires_grad_() baseline = torch.zeros_like(input_data) - - if method == 'IntegratedGradients': + + if method == "IntegratedGradients": baseline = torch.zeros_like(input_data) - elif method == 'GradientShap': # provide multiple baselines for Gr.Shap - baseline = torch.cat([torch.zeros_like(input_data) for _ in range(steps_or_samples)], dim=0) - + elif method == "GradientShap": # provide multiple baselines for Gr.Shap + baseline = torch.cat( + [torch.zeros_like(input_data) for _ in range(steps_or_samples)], + dim=0, + ) + if num_class == 1: # returns a tuple of tensors (one per data modality) - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_samples=steps_or_samples, + ) aggregated_attributions[0].append(attributions) else: for target_class in range(num_class): - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + target=target_class, + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + target=target_class, + n_samples=steps_or_samples, + ) aggregated_attributions[target_class].append(attributions) - # For each target class concatenate node attributions accross batches - processed_attributions = [] + # For each target class concatenate node attributions accross batches + processed_attributions = [] # Process each class for class_idx in range(len(aggregated_attributions)): class_attr = aggregated_attributions[class_idx] @@ -478,37 +575,57 @@ def bytes_to_gb(bytes): attr_concat = torch.cat([batch_attr for batch_attr in class_attr], dim=1) processed_attributions.append(attr_concat) - # compute absolute importance and move to cpu - abs_attr = [torch.abs(attr_class).cpu() for attr_class in processed_attributions] - # average over samples + # compute absolute importance and move to cpu + abs_attr = [ + torch.abs(attr_class).cpu() for attr_class in processed_attributions + ] + # average over samples imp = [a.mean(dim=1) for a in abs_attr] # move the model also back to cpu (if not already on cpu) - self.to('cpu') - + self.to("cpu") + # Report memory usage based on device type - if device.type == 'cuda': - print("Memory after batch processing: {:.3f} MB".format(bytes_to_gb(torch.cuda.max_memory_reserved()))) + if device.type == "cuda": + print( + "Memory after batch processing: {:.3f} MB".format( + bytes_to_gb(torch.cuda.max_memory_reserved()) + ) + ) # MPS and CPU don't have detailed memory tracking like CUDA df_list = [] - layers = list(getattr(dataset, 'multiomic_dataset', dataset).dat.keys()) + layers = list(getattr(dataset, "multiomic_dataset", dataset).dat.keys()) for i in range(num_class): features = dataset.common_features - target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else '' - for l in range(len(layers)): + target_class_label = ( + dataset.label_mappings[target_var].get(i) + if target_var in dataset.label_mappings + else "" + ) + for l in range(len(layers)): # Extracting node feature attributes coming from different omic layers importances_array = imp[i].squeeze().detach().numpy() if importances_array.ndim == 1: - importances = importances_array # Use the array as is if it is 1-dimensional + importances = ( + importances_array # Use the array as is if it is 1-dimensional + ) else: - importances = importances_array[:, l] # Use the original indexing for 2D arrays - df_list.append(pd.DataFrame({'target_variable': target_var, - 'target_class': i, - 'target_class_label': target_class_label, - 'layer': layers[l], - 'name': features, - 'importance': importances})) + importances = importances_array[ + :, l + ] # Use the original indexing for 2D arrays + df_list.append( + pd.DataFrame( + { + "target_variable": target_var, + "target_class": i, + "target_class_label": target_class_label, + "layer": layers[l], + "name": features, + "importance": importances, + } + ) + ) df_imp = pd.concat(df_list, ignore_index=True) # save the computed scores in the model - self.feature_importances[target_var] = df_imp \ No newline at end of file + self.feature_importances[target_var] = df_imp diff --git a/flexynesis/models/supervised_vae.py b/flexynesis/models/supervised_vae.py index 57609e9c..b679ab39 100644 --- a/flexynesis/models/supervised_vae.py +++ b/flexynesis/models/supervised_vae.py @@ -1,6 +1,6 @@ # Supervised VAE-MMD architecture import torch -import itertools +import itertools from torch import nn from torch.nn import functional as F from torch.utils.data import Dataset, DataLoader, random_split @@ -16,9 +16,10 @@ from ..utils import to_device_safe from ..modules import * -# Supervised Variational Auto-encoder that can train one or more layers of omics datasets + +# Supervised Variational Auto-encoder that can train one or more layers of omics datasets # num_layers: number of omics layers in the input -# each layer is encoded separately, encodings are concatenated, and decoded separately +# each layer is encoded separately, encodings are concatenated, and decoded separately # depends on MLP, Encoder, Decoder classes in models_shared class supervised_vae(pl.LightningModule): """ @@ -28,7 +29,7 @@ class supervised_vae(pl.LightningModule): Each omics layer is encoded separately using an Encoder. The resulting latent representations are then concatenated and passed through a fully connected network (fusion layer) to make predictions. The model also can be attached to one ore more supervisor heads for regression/classification/survival tasks. - In the absence of supervisor heads, it can be used for unsupervised learning. + In the absence of supervisor heads, it can be used for unsupervised learning. Attributes: config (dict): Configuration settings for the model, including learning rates and dimensions. @@ -40,9 +41,18 @@ class supervised_vae(pl.LightningModule): use_loss_weighting (bool, optional): Whether to use loss weighting in the model. Defaults to True. device_type (str, optional): Type of device to run the model ('gpu' or 'cpu'). Defaults to None. """ - def __init__(self, config, dataset, target_variables, batch_variables = None, - surv_event_var = None, surv_time_var = None, use_loss_weighting = True, - device_type = None): + + def __init__( + self, + config, + dataset, + target_variables, + batch_variables=None, + surv_event_var=None, + surv_time_var=None, + use_loss_weighting=True, + device_type=None, + ): super(supervised_vae, self).__init__() self.config = config self.dataset = dataset @@ -54,50 +64,74 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, if self.surv_event_var is not None and self.surv_time_var is not None: self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables - self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables + self.variables = ( + self.target_variables + batch_variables + if batch_variables + else self.target_variables + ) self.feature_importances = {} - + # sometimes the model may have exploding/vanishing gradients leading to NaN values - self.nan_detected = False + self.nan_detected = False self.device_type = device_type self.use_loss_weighting = use_loss_weighting - + if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars = nn.ParameterDict() - for loss_type in itertools.chain(self.variables, ['mmd_loss']): + for loss_type in itertools.chain(self.variables, ["mmd_loss"]): self.log_vars[loss_type] = nn.Parameter(torch.zeros(1)) - + layers = list(dataset.dat.keys()) input_dims = [len(dataset.features[layers[i]]) for i in range(len(layers))] # create a list of Encoder instances for separately encoding each omics layer - self.encoders = nn.ModuleList([Encoder(input_dims[i], - # define hidden_dim size as a factor of input_dim; keep at least 2 units - [max(int(input_dims[i] * config['hidden_dim_factor']), 2)], - config['latent_dim']) for i in range(len(layers))]) + self.encoders = nn.ModuleList( + [ + Encoder( + input_dims[i], + # define hidden_dim size as a factor of input_dim; keep at least 2 units + [max(int(input_dims[i] * config["hidden_dim_factor"]), 2)], + config["latent_dim"], + ) + for i in range(len(layers)) + ] + ) # Fully connected layers for concatenated means and log_vars - self.FC_mean = nn.Linear(len(layers) * config['latent_dim'], config['latent_dim']) - self.FC_log_var = nn.Linear(len(layers) * config['latent_dim'], config['latent_dim']) - # list of decoders to decode each omics layer separately - self.decoders = nn.ModuleList([Decoder(config['latent_dim'], - # define hidden_dim size as a factor of input_dim; keep at least 2 units - [max(int(input_dims[i] * config['hidden_dim_factor']),2)], - input_dims[i]) for i in range(len(layers))]) + self.FC_mean = nn.Linear( + len(layers) * config["latent_dim"], config["latent_dim"] + ) + self.FC_log_var = nn.Linear( + len(layers) * config["latent_dim"], config["latent_dim"] + ) + # list of decoders to decode each omics layer separately + self.decoders = nn.ModuleList( + [ + Decoder( + config["latent_dim"], + # define hidden_dim size as a factor of input_dim; keep at least 2 units + [max(int(input_dims[i] * config["hidden_dim_factor"]), 2)], + input_dims[i], + ) + for i in range(len(layers)) + ] + ) # define supervisor heads # using ModuleDict to store multiple MLPs - self.MLPs = nn.ModuleDict() + self.MLPs = nn.ModuleDict() for var in self.variables: - if self.dataset.variable_types[var] == 'numerical': + if self.dataset.variable_types[var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.dataset.ann[var])) - self.MLPs[var] = MLP(input_dim = config['latent_dim'], - hidden_dim = config['supervisor_hidden_dim'], - output_dim = num_class) - + self.MLPs[var] = MLP( + input_dim=config["latent_dim"], + hidden_dim=config["supervisor_hidden_dim"], + output_dim=num_class, + ) + def multi_encoder(self, x_list): """ Encode each input matrix separately using the corresponding Encoder. @@ -122,7 +156,7 @@ def multi_encoder(self, x_list): mean = self.FC_mean(torch.cat(means, dim=1)) log_var = self.FC_log_var(torch.cat(log_vars, dim=1)) return mean, log_var - + def forward(self, x_list): """ Forward pass through the model. @@ -139,20 +173,20 @@ def forward(self, x_list): - y_pred (torch.Tensor): Predicted output. """ mean, log_var = self.multi_encoder(x_list) - + # generate latent layer z = self.reparameterization(mean, log_var) # Decode each latent variable with its corresponding Decoder x_hat_list = [self.decoders[i](z) for i in range(len(x_list))] - #run the supervisor heads using the latent layer as input + # run the supervisor heads using the latent layer as input outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(z) - + return x_hat_list, z, mean, log_var, outputs - + def reparameterization(self, mean, var): """ Reparameterize the mean and variance values. @@ -164,10 +198,10 @@ def reparameterization(self, mean, var): Returns: torch.Tensor: Latent representation. """ - epsilon = torch.randn_like(var) - z = mean + var*epsilon + epsilon = torch.randn_like(var) + z = mean + var * epsilon return z - + def configure_optimizers(self): """ Configure the optimizer for the model. @@ -175,63 +209,69 @@ def configure_optimizers(self): Returns: torch.optim.Adam: Adam optimizer with learning rate 1e-3. """ - optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) return optimizer - + def compute_loss(self, var, y, y_hat): """ Computes the loss for a specific variable based on whether the variable is numerical or categorical. Handles missing labels by excluding them from the loss calculation. - + Args: var (str): The name of the variable for which the loss is being calculated. y (torch.Tensor): The true labels or values for the variable. y_hat (torch.Tensor): The predicted labels or values output by the model. - + Returns: torch.Tensor: The calculated loss tensor for the variable. If there are no valid labels or values to compute the loss (all are missing), returns a zero loss tensor with gradient enabled. - + The method first checks the type of the variable (`var`) from `variable_types`. If the variable is numerical, it computes the mean squared error loss. For categorical variables, it calculates the cross-entropy loss. The method ensures to ignore any instances where the labels are missing (NaN for numerical or -1 for categorical as assumed missing value encoding) when calculating the loss. """ - if self.dataset.variable_types[var] == 'numerical': + if self.dataset.variable_types[var] == "numerical": # Ignore instances with missing labels for numerical variables valid_indices = ~torch.isnan(y) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.mse_loss(torch.flatten(y_hat), y.float()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) # if no valid labels, set loss to 0 else: # Ignore instances with missing labels for categorical variables # Assuming that missing values were encoded as -1 valid_indices = (y != -1) & (~torch.isnan(y)) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) - else: + else: loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss - + def compute_total_loss(self, losses): """ Computes the total loss from a dictionary of individual losses. This method can compute either weighted or unweighted total loss based on the model configuration. If loss weighting is enabled and there are multiple loss components, it uses uncertainty-based weighting. See Kendall A. et al, https://arxiv.org/abs/1705.07115. - + Args: losses (dict of torch.Tensor): A dictionary where each key is a variable name and each value is the loss tensor associated with that variable. - + Returns: torch.Tensor: The total loss computed across all inputs, either weighted or unweighted. - + The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple losses to weight. If so, it computes the weighted sum of losses, where the weight involves the exponential of the negative log variance (acting as precision) associated with each loss, @@ -240,16 +280,18 @@ def compute_total_loss(self, losses): loss component, it sums up the losses directly. """ if self.use_loss_weighting and len(losses) > 1: - # Compute weighted loss for each loss + # Compute weighted loss for each loss # Weighted loss = precision * loss + log-variance - total_loss = sum(torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items()) + total_loss = sum( + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] + for name, loss in losses.items() + ) else: # Compute unweighted total loss total_loss = sum(losses.values()) return total_loss - - def training_step(self, train_batch, batch_idx, log = True): + def training_step(self, train_batch, batch_idx, log=True): """ Executes one training step using a single batch from the training dataset. @@ -264,36 +306,39 @@ def training_step(self, train_batch, batch_idx, log = True): dat, y_dict, samples = train_batch layers = dat.keys() x_list = [dat[x] for x in layers] - + x_hat_list, z, mean, log_var, outputs = self.forward(x_list) - + # compute mmd loss for each layer and take average - mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list[i]) for i in range(len(layers))] + mmd_loss_list = [ + self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list[i]) + for i in range(len(layers)) + ] mmd_loss = torch.mean(torch.stack(mmd_loss_list)) - # compute loss values for the supervisor heads - losses = {'mmd_loss': mmd_loss} - + # compute loss values for the supervisor heads + losses = {"mmd_loss": mmd_loss} + for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) - else: + else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = self.compute_total_loss(losses) - # add total loss for logging - losses['train_loss'] = total_loss + # add total loss for logging + losses["train_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - - def validation_step(self, val_batch, batch_idx, log = True): + + def validation_step(self, val_batch, batch_idx, log=True): """ Executes one validation step using a single batch from the validation dataset. @@ -308,33 +353,36 @@ def validation_step(self, val_batch, batch_idx, log = True): dat, y_dict, samples = val_batch layers = dat.keys() x_list = [dat[x] for x in layers] - + x_hat_list, z, mean, log_var, outputs = self.forward(x_list) - + # compute mmd loss for each layer and take average - mmd_loss_list = [self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list[i]) for i in range(len(layers))] + mmd_loss_list = [ + self.MMD_loss(z.shape[1], z, x_hat_list[i], x_list[i]) + for i in range(len(layers)) + ] mmd_loss = torch.mean(torch.stack(mmd_loss_list)) - # compute loss values for the supervisor heads - losses = {'mmd_loss': mmd_loss} + # compute loss values for the supervisor heads + losses = {"mmd_loss": mmd_loss} for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) - else: + else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = sum(losses.values()) - losses['val_loss'] = total_loss + losses["val_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - + def transform(self, dataset): """ Transform the input dataset to latent representation using batching. @@ -347,22 +395,37 @@ def transform(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed - all_latent_representations = [] # Initialize a list to collect all batch latent representations + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed + all_latent_representations = ( + [] + ) # Initialize a list to collect all batch latent representations sample_names = [] # List to collect sample names # Process each batch for batch in dataloader: dat, _, samples = batch - x_list = [dat[x].to(device) for x in dat.keys()] # Prepare the data batch for processing + x_list = [ + dat[x].to(device) for x in dat.keys() + ] # Prepare the data batch for processing # Perform the forward pass and extract the latent representation - latent_representation = self.forward(x_list)[1].detach().cpu().numpy() # Index [1] assumes second return is the latent rep + latent_representation = ( + self.forward(x_list)[1].detach().cpu().numpy() + ) # Index [1] assumes second return is the latent rep - all_latent_representations.append(latent_representation) # Store the batch's latent representation + all_latent_representations.append( + latent_representation + ) # Store the batch's latent representation sample_names.extend(samples) # Collect sample names for this batch # Concatenate all batch latent representations into one array @@ -370,11 +433,11 @@ def transform(self, dataset): # Convert the array to a DataFrame z = pd.DataFrame(concatenated_latents) - z.columns = ['E' + str(i) for i in range(z.shape[1])] # Name columns + z.columns = ["E" + str(i) for i in range(z.shape[1])] # Name columns z.index = sample_names # Set DataFrame index to sample names return z - + def predict(self, dataset): """ Evaluate the model on a dataset using batching. @@ -387,17 +450,28 @@ def predict(self, dataset): """ self.eval() # Set the model to evaluation mode from ..utils import create_device_from_string - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) self.to(device) # Move the model to the appropriate device - dataloader = DataLoader(dataset, batch_size=64, shuffle=False) # Adjust the batch size as needed + dataloader = DataLoader( + dataset, batch_size=64, shuffle=False + ) # Adjust the batch size as needed - predictions = {var: [] for var in self.variables} # Initialize prediction storage + predictions = { + var: [] for var in self.variables + } # Initialize prediction storage # Process each batch for batch in dataloader: dat, _, _ = batch - x_list = [dat[x].to(device) for x in dat.keys()] # Prepare the data batch for processing + x_list = [ + dat[x].to(device) for x in dat.keys() + ] # Prepare the data batch for processing # Perform the forward pass X_hat, z, mean, log_var, outputs = self.forward(x_list) @@ -405,17 +479,21 @@ def predict(self, dataset): # Collect predictions for each variable for var in self.variables: logits = outputs[var].detach().cpu() # Raw model outputs (logits) - - if dataset.variable_types[var] == 'categorical': - probs = torch.softmax(logits, dim=1).numpy() # class probabilities between 0 and 1 + + if dataset.variable_types[var] == "categorical": + probs = torch.softmax( + logits, dim=1 + ).numpy() # class probabilities between 0 and 1 predictions[var].extend(probs) else: - predictions[var].extend(logits.numpy()) # return raw output for regression problems - # Convert lists to arrays + predictions[var].extend( + logits.numpy() + ) # return raw output for regression problems + # Convert lists to arrays predictions = {var: np.array(predictions[var]) for var in predictions} return predictions - + def compute_kernel(self, x, y): """ Compute the Gaussian kernel matrix between two sets of vectors. @@ -430,12 +508,12 @@ def compute_kernel(self, x, y): x_size = x.size(0) y_size = y.size(0) dim = x.size(1) - x = x.unsqueeze(1) # (x_size, 1, dim) - y = y.unsqueeze(0) # (1, y_size, dim) + x = x.unsqueeze(1) # (x_size, 1, dim) + y = y.unsqueeze(0) # (1, y_size, dim) tiled_x = x.expand(x_size, y_size, dim) tiled_y = y.expand(x_size, y_size, dim) - kernel_input = (tiled_x - tiled_y).pow(2).mean(2)/float(dim) - return torch.exp(-kernel_input) # (x_size, y_size) + kernel_input = (tiled_x - tiled_y).pow(2).mean(2) / float(dim) + return torch.exp(-kernel_input) # (x_size, y_size) def compute_mmd(self, x, y): """ @@ -451,7 +529,7 @@ def compute_mmd(self, x, y): x_kernel = self.compute_kernel(x, x) y_kernel = self.compute_kernel(y, y) xy_kernel = self.compute_kernel(x, y) - mmd = x_kernel.mean() + y_kernel.mean() - 2*xy_kernel.mean() + mmd = x_kernel.mean() + y_kernel.mean() - 2 * xy_kernel.mean() return mmd def MMD_loss(self, latent_dim, z, xhat, x): @@ -467,25 +545,34 @@ def MMD_loss(self, latent_dim, z, xhat, x): Returns: torch.Tensor: A scalar tensor representing the MMD loss. """ - true_samples = torch.randn(200, latent_dim, device = self.device) - mmd = self.compute_mmd(true_samples, z) # compute maximum mean discrepancy (MMD) - nll = (xhat - x).pow(2).mean() #negative log likelihood - return mmd+nll - - # Adaptor forward function for captum integrated gradients. + true_samples = torch.randn(200, latent_dim, device=self.device) + mmd = self.compute_mmd( + true_samples, z + ) # compute maximum mean discrepancy (MMD) + nll = (xhat - x).pow(2).mean() # negative log likelihood + return mmd + nll + + # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): input_data = list(args[:-2]) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest - steps = args[-1] # number of steps for IntegratedGradients().attribute + steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] for i in range(steps): # get list of tensors for each step into a list of tensors x_step = [input_data[j][i] for j in range(len(input_data))] x_hat_list, z, mean, log_var, outputs = self.forward(x_step) outputs_list.append(outputs[target_var]) - return torch.cat(outputs_list, dim = 0) - - def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64): + return torch.cat(outputs_list, dim=0) + + def compute_feature_importance( + self, + dataset, + target_var, + method="IntegratedGradients", + steps_or_samples=5, + batch_size=64, + ): """ Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. @@ -502,26 +589,40 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad pd.DataFrame: A DataFrame containing feature importances across different variables and data modalities. """ from ..utils import create_device_from_string, to_device_safe - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + # Force CPU for Captum feature importance, as MPS lacks the required float64 support. - if device.type == 'mps': - print("[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility.") - device = torch.device('cpu') - + if device.type == "mps": + print( + "[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility." + ) + device = torch.device("cpu") + self.to(device) - - print("[INFO] Computing feature importance for variable:",target_var,"on device:",device) + + print( + "[INFO] Computing feature importance for variable:", + target_var, + "on device:", + device, + ) # Choose the attribution method dynamically if method == "IntegratedGradients": explainer = IntegratedGradients(self.forward_target) elif method == "GradientShap": explainer = GradientShap(self.forward_target) else: - raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.") + raise ValueError( + f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'." + ) # Get the number of classes for the target variable - if self.dataset.variable_types[target_var] == 'numerical': + if self.dataset.variable_types[target_var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.dataset.ann[target_var])) @@ -529,48 +630,62 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) aggregated_attributions = [[] for _ in range(num_class)] - + for batch in dataloader: dat, _, _ = batch x_list = [to_device_safe(dat[x], device) for x in dat.keys()] input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) - - if method == 'IntegratedGradients': + + if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif method == 'GradientShap': # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = tuple( - torch.cat([torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0) + torch.cat( + [torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0 + ) for x in input_data ) - + if num_class == 1: - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + n_samples=steps_or_samples, + ) aggregated_attributions[0].append(attributions) else: for target_class in range(num_class): - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_steps=steps_or_samples) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), - target=target_class, - n_samples=steps_or_samples) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + target=target_class, + n_steps=steps_or_samples, + ) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=(target_var, steps_or_samples), + target=target_class, + n_samples=steps_or_samples, + ) aggregated_attributions[target_class].append(attributions) - # For each target class and for each data modality/layer, concatenate attributions accross batches + # For each target class and for each data modality/layer, concatenate attributions accross batches layers = list(dataset.dat.keys()) num_layers = len(layers) - processed_attributions = [] + processed_attributions = [] # Process each class for class_idx in range(len(aggregated_attributions)): class_attr = aggregated_attributions[class_idx] @@ -583,32 +698,44 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad attr_concat = torch.cat(layer_tensors, dim=1) layer_attributions.append(attr_concat) processed_attributions.append(layer_attributions) - + # summarize feature importances # Compute absolute attributions # Move the processed tensors to CPU for further operations that are not supported on GPU - abs_attr = [[torch.abs(a).cpu() for a in attr_class] for attr_class in processed_attributions] - # average over samples + abs_attr = [ + [torch.abs(a).cpu() for a in attr_class] + for attr_class in processed_attributions + ] + # average over samples imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr] # move the model also back to cpu (if not already on cpu) - self.to('cpu') + self.to("cpu") - # combine into a single data frame + # combine into a single data frame df_list = [] for i in range(num_class): for j in range(len(layers)): features = self.dataset.features[layers[j]] importances = imp[i][j][0].detach().numpy() - target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else '' - df_list.append(pd.DataFrame({'target_variable': target_var, - 'target_class': i, - 'target_class_label': target_class_label, - 'layer': layers[j], - 'name': features, 'importance': importances})) - df_imp = pd.concat(df_list, ignore_index = True) - + target_class_label = ( + dataset.label_mappings[target_var].get(i) + if target_var in dataset.label_mappings + else "" + ) + df_list.append( + pd.DataFrame( + { + "target_variable": target_var, + "target_class": i, + "target_class_label": target_class_label, + "layer": layers[j], + "name": features, + "importance": importances, + } + ) + ) + df_imp = pd.concat(df_list, ignore_index=True) + # save scores in model self.feature_importances[target_var] = df_imp - - diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index f714dcc9..6824610c 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -1,4 +1,4 @@ -# Generating encodings of multi-omic data using triplet loss +# Generating encodings of multi-omic data using triplet loss import torch from torch import nn from torch.nn import functional as F @@ -18,7 +18,6 @@ from captum.attr import IntegratedGradients, GradientShap - class MultiTripletNetwork(pl.LightningModule): """ A PyTorch Lightning module that implements a multi-triplet network architecture for handling @@ -37,11 +36,19 @@ class MultiTripletNetwork(pl.LightningModule): device_type (str, optional): Type of device ('gpu' or 'cpu') on which the model will be run. """ - def __init__(self, config, dataset, target_variables, batch_variables = None, - surv_event_var = None, surv_time_var = None, use_loss_weighting = True, - device_type = None): + def __init__( + self, + config, + dataset, + target_variables, + batch_variables=None, + surv_event_var=None, + surv_time_var=None, + use_loss_weighting=True, + device_type=None, + ): super(MultiTripletNetwork, self).__init__() - + self.config = config self.target_variables = target_variables self.surv_event_var = surv_event_var @@ -51,54 +58,73 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, if self.surv_event_var is not None and self.surv_time_var is not None: self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables - self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables + self.variables = ( + self.target_variables + batch_variables + if batch_variables + else self.target_variables + ) self.ann = dataset.ann self.variable_types = dataset.variable_types self.feature_importances = {} self.device_type = device_type - # The first target variable is the main variable that dictates the triplets - # it has to be a categorical variable - self.main_var = self.target_variables[0] - if self.variable_types[self.main_var] == 'numerical': - raise ValueError("The first target variable",self.main_var," must be a categorical variable") - + # The first target variable is the main variable that dictates the triplets + # it has to be a categorical variable + self.main_var = self.target_variables[0] + if self.variable_types[self.main_var] == "numerical": + raise ValueError( + "The first target variable", + self.main_var, + " must be a categorical variable", + ) + self.use_loss_weighting = use_loss_weighting - + if self.use_loss_weighting: # Initialize log variance parameters for uncertainty weighting self.log_vars = nn.ParameterDict() - for loss_type in itertools.chain(self.variables, ['triplet_loss']): + for loss_type in itertools.chain(self.variables, ["triplet_loss"]): self.log_vars[loss_type] = nn.Parameter(torch.zeros(1)) - + self.layers = list(dataset.dat.keys()) - self.input_dims = [len(dataset.features[self.layers[i]]) for i in range(len(self.layers))] - - self.encoders = nn.ModuleList([ - MLP(input_dim=self.input_dims[i], - # define hidden_dim size relative to the input_dim size - hidden_dim=int(self.input_dims[i] * self.config['hidden_dim_factor']), - output_dim=self.config['latent_dim']) for i in range(len(self.layers))]) - + self.input_dims = [ + len(dataset.features[self.layers[i]]) for i in range(len(self.layers)) + ] + + self.encoders = nn.ModuleList( + [ + MLP( + input_dim=self.input_dims[i], + # define hidden_dim size relative to the input_dim size + hidden_dim=int( + self.input_dims[i] * self.config["hidden_dim_factor"] + ), + output_dim=self.config["latent_dim"], + ) + for i in range(len(self.layers)) + ] + ) + if len(self.input_dims) > 1: self.fusion_block = nn.Linear( - in_features=self.config['latent_dim'] * len(self.layers), - out_features=self.config['latent_dim'] + in_features=self.config["latent_dim"] * len(self.layers), + out_features=self.config["latent_dim"], ) else: self.fusion_block = None - - # define supervisor heads for both target and batch variables - self.MLPs = nn.ModuleDict() # using ModuleDict to store multiple MLPs + # define supervisor heads for both target and batch variables + self.MLPs = nn.ModuleDict() # using ModuleDict to store multiple MLPs for var in self.variables: - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": num_class = 1 else: num_class = len(np.unique(self.ann[var])) - self.MLPs[var] = MLP(input_dim=self.config['latent_dim'], - hidden_dim=self.config['supervisor_hidden_dim'], - output_dim=num_class) - + self.MLPs[var] = MLP( + input_dim=self.config["latent_dim"], + hidden_dim=self.config["supervisor_hidden_dim"], + output_dim=num_class, + ) + def concat_embeddings(self, dat): embeddings_list = [] x_list = [dat[x] for x in dat.keys()] @@ -106,10 +132,14 @@ def concat_embeddings(self, dat): for i, x in enumerate(x_list): embeddings_list.append(self.encoders[i](x)) embeddings_concat = torch.cat(embeddings_list, dim=1) - # if multiple embeddings, fuse them - embeddings = self.fusion_block(embeddings_concat) if self.fusion_block else embeddings_concat + # if multiple embeddings, fuse them + embeddings = ( + self.fusion_block(embeddings_concat) + if self.fusion_block + else embeddings_concat + ) return embeddings - + def forward(self, anchor, positive, negative): """ Compute the forward pass of the MultiTripletNetwork and return the embeddings and predictions. @@ -126,13 +156,13 @@ def forward(self, anchor, positive, negative): anchor_embedding = self.concat_embeddings(anchor) positive_embedding = self.concat_embeddings(positive) negative_embedding = self.concat_embeddings(negative) - - #run the supervisor heads using the anchor embeddings as input + + # run the supervisor heads using the anchor embeddings as input outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(anchor_embedding) return anchor_embedding, positive_embedding, negative_embedding, outputs - + def configure_optimizers(self): """ Configure the optimizer for the MultiTripletNetwork. @@ -140,9 +170,9 @@ def configure_optimizers(self): Returns: torch.optim.Optimizer: The configured optimizer. """ - optimizer = torch.optim.Adam(self.parameters(), lr=self.config['lr']) + optimizer = torch.optim.Adam(self.parameters(), lr=self.config["lr"]) return optimizer - + def triplet_loss(self, anchor, positive, negative, margin=1.0): """ Compute the triplet loss for the given anchor, positive, and negative embeddings. @@ -159,63 +189,69 @@ def triplet_loss(self, anchor, positive, negative, margin=1.0): distance_positive = (anchor - positive).pow(2).sum(1) distance_negative = (anchor - negative).pow(2).sum(1) losses = torch.relu(distance_positive - distance_negative + margin) - return losses.mean() + return losses.mean() def compute_loss(self, var, y, y_hat): """ Computes the loss for a specific variable based on whether the variable is numerical or categorical. Handles missing labels by excluding them from the loss calculation. - + Args: var (str): The name of the variable for which the loss is being calculated. y (torch.Tensor): The true labels or values for the variable. y_hat (torch.Tensor): The predicted labels or values output by the model. - + Returns: torch.Tensor: The calculated loss tensor for the variable. If there are no valid labels or values to compute the loss (all are missing), returns a zero loss tensor with gradient enabled. - + The method first checks the type of the variable (`var`) from `variable_types`. If the variable is numerical, it computes the mean squared error loss. For categorical variables, it calculates the cross-entropy loss. The method ensures to ignore any instances where the labels are missing (NaN for numerical or -1 for categorical as assumed missing value encoding) when calculating the loss. """ - if self.variable_types[var] == 'numerical': + if self.variable_types[var] == "numerical": # Ignore instances with missing labels for numerical variables valid_indices = ~torch.isnan(y) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.mse_loss(torch.flatten(y_hat), y.float()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) # if no valid labels, set loss to 0 + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) # if no valid labels, set loss to 0 else: # Ignore instances with missing labels for categorical variables # Assuming that missing values were encoded as -1 valid_indices = (y != -1) & (~torch.isnan(y)) - if valid_indices.sum() > 0: # only calculate loss if there are valid targets + if ( + valid_indices.sum() > 0 + ): # only calculate loss if there are valid targets y_hat = y_hat[valid_indices] y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) - else: + else: loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss - + def compute_total_loss(self, losses): """ Computes the total loss from a dictionary of individual losses. This method can compute either weighted or unweighted total loss based on the model configuration. If loss weighting is enabled and there are multiple loss components, it uses uncertainty-based weighting. See Kendall A. et al, https://arxiv.org/abs/1705.07115. - + Args: losses (dict of torch.Tensor): A dictionary where each key is a variable name and each value is the loss tensor associated with that variable. - + Returns: torch.Tensor: The total loss computed across all inputs, either weighted or unweighted. - + The method checks if loss weighting is used (`use_loss_weighting`) and if there are multiple losses to weight. If so, it computes the weighted sum of losses, where the weight involves the exponential of the negative log variance (acting as precision) associated with each loss, @@ -224,42 +260,54 @@ def compute_total_loss(self, losses): loss component, it sums up the losses directly. """ if self.use_loss_weighting and len(losses) > 1: - # Compute weighted loss for each loss + # Compute weighted loss for each loss # Weighted loss = precision * loss + log-variance - total_loss = sum(torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] for name, loss in losses.items()) + total_loss = sum( + torch.exp(-self.log_vars[name]) * loss + self.log_vars[name] + for name, loss in losses.items() + ) else: # Compute unweighted total loss total_loss = sum(losses.values()) return total_loss - def training_step(self, train_batch, batch_idx, log = True): + def training_step(self, train_batch, batch_idx, log=True): """ Perform a training step using a single batch of data, including triplet components and target labels. - + Args: train_batch (tuple): The batch containing data tuples (anchor, positive, negative) and a dictionary of labels. batch_idx (int): The index of the current batch. log (bool, optional): Flag to determine if logging should occur at each step. Defaults to True. - + Returns: torch.Tensor: The total loss for the current training batch, which includes triplet loss and any additional losses from supervisor heads. - + This method computes the embedding for the anchor, positive, and negative samples and calculates the triplet loss. Additional losses are computed for other target variables in the dataset, particularly handling survival analysis if applicable. All losses are combined to compute a total loss, which is logged and returned. """ - anchor, positive, negative, y_dict = train_batch[0], train_batch[1], train_batch[2], train_batch[3] - anchor_embedding, positive_embedding, negative_embedding, outputs = self.forward(anchor, positive, negative) - triplet_loss = self.triplet_loss(anchor_embedding, positive_embedding, negative_embedding) - - # compute loss values for the supervisor heads - losses = {'triplet_loss': triplet_loss} + anchor, positive, negative, y_dict = ( + train_batch[0], + train_batch[1], + train_batch[2], + train_batch[3], + ) + anchor_embedding, positive_embedding, negative_embedding, outputs = ( + self.forward(anchor, positive, negative) + ) + triplet_loss = self.triplet_loss( + anchor_embedding, positive_embedding, negative_embedding + ) + + # compute loss values for the supervisor heads + losses = {"triplet_loss": triplet_loss} for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) else: y_hat = outputs[var] @@ -268,69 +316,78 @@ def training_step(self, train_batch, batch_idx, log = True): losses[var] = loss total_loss = self.compute_total_loss(losses) - # add total loss for logging - losses['train_loss'] = total_loss + # add total loss for logging + losses["train_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - - def validation_step(self, val_batch, batch_idx, log = True): + + def validation_step(self, val_batch, batch_idx, log=True): """ Perform a validation step using a single batch of data, including triplet components and target labels. - + Args: val_batch (tuple): The batch containing data tuples (anchor, positive, negative) and a dictionary of labels. batch_idx (int): The index of the current batch. log (bool, optional): Flag to determine if logging should occur at each step. Defaults to True. - + Returns: torch.Tensor: The total loss for the current validation batch, which includes triplet loss and any additional losses from supervisor heads. - + Similar to the training step, this method computes the embedding for the anchor, positive, and negative samples and calculates the triplet loss. It computes additional losses for other target variables in the dataset, aggregates all losses, and returns the total loss. The losses are logged if specified. """ - anchor, positive, negative, y_dict = val_batch[0], val_batch[1], val_batch[2], val_batch[3] - anchor_embedding, positive_embedding, negative_embedding, outputs = self.forward(anchor, positive, negative) - triplet_loss = self.triplet_loss(anchor_embedding, positive_embedding, negative_embedding) - - # compute loss values for the supervisor heads - losses = {'triplet_loss': triplet_loss} + anchor, positive, negative, y_dict = ( + val_batch[0], + val_batch[1], + val_batch[2], + val_batch[3], + ) + anchor_embedding, positive_embedding, negative_embedding, outputs = ( + self.forward(anchor, positive, negative) + ) + triplet_loss = self.triplet_loss( + anchor_embedding, positive_embedding, negative_embedding + ) + + # compute loss values for the supervisor heads + losses = {"triplet_loss": triplet_loss} for var in self.variables: if var == self.surv_event_var: durations = y_dict[self.surv_time_var] - events = y_dict[self.surv_event_var] - risk_scores = outputs[var] #output of MLP + events = y_dict[self.surv_event_var] + risk_scores = outputs[var] # output of MLP loss = cox_ph_loss(risk_scores, durations, events) else: y_hat = outputs[var] y = y_dict[var] loss = self.compute_loss(var, y, y_hat) losses[var] = loss - + total_loss = sum(losses.values()) - losses['val_loss'] = total_loss + losses["val_loss"] = total_loss if log: self.log_dict(losses, on_step=False, on_epoch=True, prog_bar=True) return total_loss - + # dataset: MultiOmicDataset def transform(self, dataset): """ Transforms the input dataset by generating embeddings and predictions. - + Args: dataset (MultiOmicDataset): An instance of the MultiOmicDataset class. - + Returns: z (pd.DataFrame): A dataframe containing the computed embeddings. y_pred (np.ndarray): A numpy array containing the predicted labels. """ self.eval() - # get anchor embeddings + # get anchor embeddings z = pd.DataFrame(self.concat_embeddings(dataset.dat).detach().numpy()) - z.columns = [''.join(['E', str(x)]) for x in z.columns] + z.columns = ["".join(["E", str(x)]) for x in z.columns] z.index = dataset.samples return z @@ -351,44 +408,58 @@ def predict(self, dataset): outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(anchor_embedding) - + # get predictions from the mlp outputs for each var - predictions = {var: [] for var in self.variables} # Initialize prediction storage + predictions = { + var: [] for var in self.variables + } # Initialize prediction storage # Collect predictions for each variable for var in self.variables: logits = outputs[var].detach().cpu() # Raw model outputs (logits) - if dataset.variable_types[var] == 'categorical': - probs = torch.softmax(logits, dim=1).numpy() # class probabilities between 0 and 1 + if dataset.variable_types[var] == "categorical": + probs = torch.softmax( + logits, dim=1 + ).numpy() # class probabilities between 0 and 1 predictions[var].extend(probs) else: - predictions[var].extend(logits.numpy()) # return raw output for regression problems - # Convert lists to arrays + predictions[var].extend( + logits.numpy() + ) # return raw output for regression problems + # Convert lists to arrays predictions = {var: np.array(predictions[var]) for var in predictions} - + return predictions - - # Adaptor forward function for captum integrated gradients. - # layer_sizes: number of features in each omic layer + # Adaptor forward function for captum integrated gradients. + # layer_sizes: number of features in each omic layer def forward_target(self, input_data, layer_sizes, target_var, steps): outputs_list = [] for i in range(steps): - # for each step, get anchor/positive/negative tensors + # for each step, get anchor/positive/negative tensors # (split the concatenated omics layers) - anchor = input_data[i][0].split(layer_sizes, dim = 1) - positive = input_data[i][1].split(layer_sizes, dim = 1) - negative = input_data[i][2].split(layer_sizes, dim = 1) - + anchor = input_data[i][0].split(layer_sizes, dim=1) + positive = input_data[i][1].split(layer_sizes, dim=1) + negative = input_data[i][2].split(layer_sizes, dim=1) + # convert to dict anchor = {k: anchor[k] for k in range(len(anchor))} positive = {k: anchor[k] for k in range(len(positive))} negative = {k: anchor[k] for k in range(len(negative))} - anchor_embedding, positive_embedding, negative_embedding, outputs = self.forward(anchor, positive, negative) + anchor_embedding, positive_embedding, negative_embedding, outputs = ( + self.forward(anchor, positive, negative) + ) outputs_list.append(outputs[target_var]) - return torch.cat(outputs_list, dim = 0) - - def compute_feature_importance(self, dataset, target_var, method="IntegratedGradients", steps_or_samples=5, batch_size=64): + return torch.cat(outputs_list, dim=0) + + def compute_feature_importance( + self, + dataset, + target_var, + method="IntegratedGradients", + steps_or_samples=5, + batch_size=64, + ): """ Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. @@ -406,94 +477,143 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad """ from ..utils import create_device_from_string, to_device_safe - device = create_device_from_string(self.device_type if hasattr(self, 'device_type') and self.device_type else 'auto') - + + device = create_device_from_string( + self.device_type + if hasattr(self, "device_type") and self.device_type + else "auto" + ) + # Force CPU for Captum feature importance, as MPS lacks the required float64 support. - if device.type == 'mps': - print("[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility.") - device = torch.device('cpu') - + if device.type == "mps": + print( + "[WARNING] MPS device detected. Computing feature importance on CPU due to MPS float64 incompatibility." + ) + device = torch.device("cpu") + self.to(device) - print("[INFO] Computing feature importance for variable:",target_var,"on device:",device) + print( + "[INFO] Computing feature importance for variable:", + target_var, + "on device:", + device, + ) - # define data loader + # define data loader triplet_dataset = TripletMultiOmicDataset(dataset, self.main_var) dataloader = DataLoader(triplet_dataset, batch_size=batch_size, shuffle=False) - + # Choose the attribution method dynamically if method == "IntegratedGradients": explainer = IntegratedGradients(self.forward_target) elif method == "GradientShap": explainer = GradientShap(self.forward_target) else: - raise ValueError(f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'.") + raise ValueError( + f"Unsupported method '{method}'. Choose 'IntegratedGradients' or 'GradientShap'." + ) - if self.variable_types[target_var] == 'numerical': + if self.variable_types[target_var] == "numerical": num_class = 1 else: num_class = len(np.unique([y[target_var] for _, y, _ in dataset])) aggregated_attributions = [[] for _ in range(num_class)] for batch in dataloader: - # see training_step to see how elements are accessed in batches - anchor, positive, negative, y_dict = batch[0], batch[1], batch[2], batch[3] + # see training_step to see how elements are accessed in batches + anchor, positive, negative, y_dict = batch[0], batch[1], batch[2], batch[3] # Move tensors to the specified device using MPS-safe method anchor = {k: to_device_safe(v, device) for k, v in anchor.items()} positive = {k: to_device_safe(v, device) for k, v in positive.items()} negative = {k: to_device_safe(v, device) for k, v in negative.items()} - + anchor = [data.requires_grad_() for data in list(anchor.values())] positive = [data.requires_grad_() for data in list(positive.values())] negative = [data.requires_grad_() for data in list(negative.values())] - + # concatenate multiomic layers of each list element - # then stack the anchor/positive/negative + # then stack the anchor/positive/negative # the purpose is to get a single tensor - input_data = torch.stack([torch.cat(sublist, dim = 1) for sublist in [anchor, positive, negative]]).unsqueeze(0) + input_data = torch.stack( + [torch.cat(sublist, dim=1) for sublist in [anchor, positive, negative]] + ).unsqueeze(0) - # layer sizes will be needed to revert the concatenated tensor + # layer sizes will be needed to revert the concatenated tensor # anchor/positive/negative have the same shape - layer_sizes = [anchor[i].shape[1] for i in range(len(anchor))] - - # Define a baseline - if method == 'IntegratedGradients': - baseline = torch.zeros_like(input_data) - elif method == 'GradientShap': # provide multiple baselines for Gr.Shap - baseline = torch.cat([torch.zeros_like(input_data) for _ in range(steps_or_samples)], dim=0) - + layer_sizes = [anchor[i].shape[1] for i in range(len(anchor))] + + # Define a baseline + if method == "IntegratedGradients": + baseline = torch.zeros_like(input_data) + elif method == "GradientShap": # provide multiple baselines for Gr.Shap + baseline = torch.cat( + [torch.zeros_like(input_data) for _ in range(steps_or_samples)], + dim=0, + ) + if num_class == 1: - - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(layer_sizes, target_var, steps_or_samples), - n_steps=steps_or_samples) - attributions = attributions.split(layer_sizes, dim = 3) - elif method == 'GradientShape': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(layer_sizes, target_var, steps_or_samples), - n_samples=steps_or_samples) - attributions = attributions.split(layer_sizes, dim = 3) + + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + layer_sizes, + target_var, + steps_or_samples, + ), + n_steps=steps_or_samples, + ) + attributions = attributions.split(layer_sizes, dim=3) + elif method == "GradientShape": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + layer_sizes, + target_var, + steps_or_samples, + ), + n_samples=steps_or_samples, + ) + attributions = attributions.split(layer_sizes, dim=3) aggregated_attributions[0].append(attributions) else: for target_class in range(num_class): - if method == 'IntegratedGradients': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(layer_sizes, target_var, steps_or_samples), - target=target_class, n_steps=steps_or_samples) - attributions = attributions.split(layer_sizes, dim = 3) - elif method == 'GradientShap': - attributions = explainer.attribute(input_data, baseline, - additional_forward_args=(layer_sizes, target_var, steps_or_samples), - target=target_class, n_samples=steps_or_samples) - attributions = attributions.split(layer_sizes, dim = 3) + if method == "IntegratedGradients": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + layer_sizes, + target_var, + steps_or_samples, + ), + target=target_class, + n_steps=steps_or_samples, + ) + attributions = attributions.split(layer_sizes, dim=3) + elif method == "GradientShap": + attributions = explainer.attribute( + input_data, + baseline, + additional_forward_args=( + layer_sizes, + target_var, + steps_or_samples, + ), + target=target_class, + n_samples=steps_or_samples, + ) + attributions = attributions.split(layer_sizes, dim=3) aggregated_attributions[target_class].append(attributions) - # For each target class and for each data modality/layer, concatenate attributions accross batches + # For each target class and for each data modality/layer, concatenate attributions accross batches layers = self.layers num_layers = len(layers) - processed_attributions = [] + processed_attributions = [] # Process each class for class_idx in range(len(aggregated_attributions)): class_attr = aggregated_attributions[class_idx] @@ -507,25 +627,40 @@ def compute_feature_importance(self, dataset, target_var, method="IntegratedGrad layer_attributions.append(attr_concat) processed_attributions.append(layer_attributions) - # compute absolute importance and move to cpu + # compute absolute importance and move to cpu # notice the squeeze (due to triplets) - abs_attr = [[torch.abs(a.squeeze()).cpu() for a in attr_class] for attr_class in processed_attributions] - # average over samples + abs_attr = [ + [torch.abs(a.squeeze()).cpu() for a in attr_class] + for attr_class in processed_attributions + ] + # average over samples imp = [[a.mean(dim=1) for a in attr_class] for attr_class in abs_attr] # move the model also back to cpu (if not already on cpu) - self.to('cpu') + self.to("cpu") df_list = [] for i in range(num_class): for j in range(len(layers)): features = dataset.features[layers[j]] # Ensure tensors are already on CPU before converting to numpy - importances = imp[i][j][0].detach().numpy() # 0 => extract importances only for the anchor - target_class_label = dataset.label_mappings[target_var].get(i) if target_var in dataset.label_mappings else '' - df_list.append(pd.DataFrame({'target_variable': target_var, - 'target_class': i, - 'target_class_label': target_class_label, - 'layer': layers[j], - 'name': features, - 'importance': importances})) + importances = ( + imp[i][j][0].detach().numpy() + ) # 0 => extract importances only for the anchor + target_class_label = ( + dataset.label_mappings[target_var].get(i) + if target_var in dataset.label_mappings + else "" + ) + df_list.append( + pd.DataFrame( + { + "target_variable": target_var, + "target_class": i, + "target_class_label": target_class_label, + "layer": layers[j], + "name": features, + "importance": importances, + } + ) + ) df_imp = pd.concat(df_list, ignore_index=True) self.feature_importances[target_var] = df_imp diff --git a/flexynesis/modules.py b/flexynesis/modules.py index aed8ac83..f367959f 100644 --- a/flexynesis/modules.py +++ b/flexynesis/modules.py @@ -4,66 +4,67 @@ from torch import nn from torch_geometric.nn import aggr, GCNConv, GATConv, SAGEConv, GraphConv - __all__ = ["Encoder", "Decoder", "MLP", "flexGCN", "cox_ph_loss"] class Encoder(nn.Module): """ Encoder class for a Variational Autoencoder (VAE). - + The Encoder class is responsible for taking input data and generating the mean and log variance for the latent space representation. """ + def __init__(self, input_dim, hidden_dims, latent_dim): super(Encoder, self).__init__() self.act = nn.LeakyReLU(0.2) - + hidden_layers = [] - + hidden_layers.append(nn.Linear(input_dim, hidden_dims[0])) nn.init.xavier_uniform_(hidden_layers[-1].weight) hidden_layers.append(self.act) hidden_layers.append(nn.BatchNorm1d(hidden_dims[0])) - for i in range(len(hidden_dims)-1): - hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1])) + for i in range(len(hidden_dims) - 1): + hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1])) nn.init.xavier_uniform_(hidden_layers[-1].weight) hidden_layers.append(self.act) - hidden_layers.append(nn.BatchNorm1d(hidden_dims[i+1])) + hidden_layers.append(nn.BatchNorm1d(hidden_dims[i + 1])) self.hidden_layers = nn.Sequential(*hidden_layers) - - self.FC_mean = nn.Linear(hidden_dims[-1], latent_dim) + + self.FC_mean = nn.Linear(hidden_dims[-1], latent_dim) nn.init.xavier_uniform_(self.FC_mean.weight) - self.FC_var = nn.Linear(hidden_dims[-1], latent_dim) + self.FC_var = nn.Linear(hidden_dims[-1], latent_dim) nn.init.xavier_uniform_(self.FC_var.weight) - + def forward(self, x): """ Performs a forward pass through the Encoder network. - + Args: x (torch.Tensor): The input data tensor. - + Returns: mean (torch.Tensor): The mean of the latent space representation. log_var (torch.Tensor): The log variance of the latent space representation. """ - h_ = self.hidden_layers(x) - mean = self.FC_mean(h_) - log_var = self.FC_var(h_) + h_ = self.hidden_layers(x) + mean = self.FC_mean(h_) + log_var = self.FC_var(h_) return mean, log_var - - + + class Decoder(nn.Module): """ Decoder class for a Variational Autoencoder (VAE). - + The Decoder class is responsible for taking the latent space representation and generating the reconstructed output data. """ + def __init__(self, latent_dim, hidden_dims, output_dim): super(Decoder, self).__init__() @@ -80,7 +81,7 @@ def __init__(self, latent_dim, hidden_dims, output_dim): hidden_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1])) nn.init.xavier_uniform_(hidden_layers[-1].weight) hidden_layers.append(self.act) - hidden_layers.append(nn.BatchNorm1d(hidden_dims[i+1])) + hidden_layers.append(nn.BatchNorm1d(hidden_dims[i + 1])) self.hidden_layers = nn.Sequential(*hidden_layers) @@ -90,49 +91,54 @@ def __init__(self, latent_dim, hidden_dims, output_dim): def forward(self, x): """ Performs a forward pass through the Decoder network. - + Args: x (torch.Tensor): The input tensor representing the latent space. - + Returns: x_hat (torch.Tensor): The reconstructed output tensor. """ h = self.hidden_layers(x) x_hat = torch.sigmoid(self.FC_output(h)) return x_hat - + class MLP(nn.Module): """ A Multi-Layer Perceptron (MLP) model for regression or classification tasks. - + The MLP class is a simple feed-forward neural network that can be used for regression when `output_dim` is set to 1 or for classification when `output_dim` is greater than 1. """ + def __init__(self, input_dim, hidden_dim, output_dim): """ Initializes the MLP class with the given input dimension, output dimension, and hidden layer size. - + Args: input_dim (int): The input dimension. hidden_dim (int, optional): The size of the hidden layer. Default is 32. output_dim (int): The output dimension. Set to 1 for regression tasks, and > 1 for classification tasks. """ super().__init__() - hidden_dim = max(hidden_dim, 2) # make sure there are at least 2 units + hidden_dim = max(hidden_dim, 2) # make sure there are at least 2 units self.layer_1 = nn.Linear(input_dim, hidden_dim) - self.layer_out = nn.Linear(hidden_dim, output_dim) if output_dim > 1 else nn.Linear(hidden_dim, 1, bias=False) - self.relu = nn.ReLU() + self.layer_out = ( + nn.Linear(hidden_dim, output_dim) + if output_dim > 1 + else nn.Linear(hidden_dim, 1, bias=False) + ) + self.relu = nn.ReLU() self.dropout = nn.Dropout(p=0.1) self.batchnorm = nn.BatchNorm1d(hidden_dim) def forward(self, x): """ Performs a forward pass through the MLP network. - + Args: x (torch.Tensor): The input data tensor. - + Returns: x (torch.Tensor): The output tensor after passing through the MLP network. """ @@ -142,7 +148,8 @@ def forward(self, x): x = self.dropout(x) x = self.layer_out(x) return x - + + class flexGCN(nn.Module): """ A Graph Neural Network (GNN) model using configurable convolution and activation layers. @@ -166,57 +173,74 @@ class flexGCN(nn.Module): output_dim (int): The size of the output vector, which is the final feature vector for the whole graph. num_convs (int, optional): Number of convolutional layers in the network. Defaults to 2. dropout_rate (float, optional): The dropout probability used for regularization. Defaults to 0.2. - conv (str, optional): Type of convolution layer to use. Supported types include 'GCN' for Graph Convolution Network, - 'GAT' for Graph Attention Network, 'SAGE' for GraphSAGE, and 'GC' for generic Graph Convolution. + conv (str, optional): Type of convolution layer to use. Supported types include 'GCN' for Graph Convolution Network, + 'GAT' for Graph Attention Network, 'SAGE' for GraphSAGE, and 'GC' for generic Graph Convolution. Defaults to 'GC'. - act (str, optional): Type of activation function to use. Supported types include 'relu', 'sigmoid', + act (str, optional): Type of activation function to use. Supported types include 'relu', 'sigmoid', 'leakyrelu', 'tanh', and 'gelu'. Defaults to 'relu'. Raises: ValueError: If an unsupported activation function or convolution type is specified. Example: - >>> model = flexGCN(node_count=100, node_feature_count=5, node_embedding_dim=64, output_dim=10, + >>> model = flexGCN(node_count=100, node_feature_count=5, node_embedding_dim=64, output_dim=10, num_convs=3, dropout_rate=0.3, conv='GAT', act='relu') >>> output = model(input_features, edge_index) # Where `input_features` is a tensor of shape (batch_size, num_nodes, node_feature_count) # and `edge_index` is a list of edges in the COO format (2, num_edges). """ - def __init__(self, node_count, node_feature_count, node_embedding_dim, output_dim, - num_convs = 2, dropout_rate = 0.2, conv='GC', act='relu'): + + def __init__( + self, + node_count, + node_feature_count, + node_embedding_dim, + output_dim, + num_convs=2, + dropout_rate=0.2, + conv="GC", + act="relu", + ): super().__init__() act_options = { - 'relu': nn.ReLU(), - 'sigmoid': nn.Sigmoid(), - 'leakyrelu': nn.LeakyReLU(), - 'tanh': nn.Tanh(), - 'gelu': nn.GELU() + "relu": nn.ReLU(), + "sigmoid": nn.Sigmoid(), + "leakyrelu": nn.LeakyReLU(), + "tanh": nn.Tanh(), + "gelu": nn.GELU(), } if act not in act_options: - raise ValueError("Invalid activation function string. Choose from ", list(act_options.keys())) - + raise ValueError( + "Invalid activation function string. Choose from ", + list(act_options.keys()), + ) + conv_options = { - 'GCN': GCNConv, - 'GAT': GATConv, - 'SAGE': SAGEConv, - 'GC': GraphConv + "GCN": GCNConv, + "GAT": GATConv, + "SAGE": SAGEConv, + "GC": GraphConv, } if conv not in conv_options: - raise ValueError('Unknown convolution type. Choose one of: ', list(conv_options.keys())) + raise ValueError( + "Unknown convolution type. Choose one of: ", list(conv_options.keys()) + ) self.act = act_options[act] self.convs = nn.ModuleList() self.bns = nn.ModuleList() self.dropout = nn.Dropout(dropout_rate) - + # Initialize the first convolution layer separately if different input size self.convs.append(conv_options[conv](node_feature_count, node_embedding_dim)) self.bns.append(nn.BatchNorm1d(node_embedding_dim)) # Loop to create the remaining convolution and BN layers for _ in range(1, num_convs): - self.convs.append(conv_options[conv](node_embedding_dim, node_embedding_dim)) + self.convs.append( + conv_options[conv](node_embedding_dim, node_embedding_dim) + ) self.bns.append(nn.BatchNorm1d(node_embedding_dim)) # Final fully connected layer @@ -227,13 +251,14 @@ def forward(self, x, edge_index): x = conv(x, edge_index) x = bn(x.view(-1, x.size(2))).view_as(x) x = self.act(x) - x = self.dropout(x) + x = self.dropout(x) # Flatten the output of all nodes into a single vector per graph/sample x = x.view(x.size(0), -1) x = self.fc(x) return x + def cox_ph_loss(outputs, durations, events): """ Calculate the Cox proportional hazards loss. @@ -251,22 +276,26 @@ def cox_ph_loss(outputs, durations, events): outputs = outputs[valid_indices] events = events[valid_indices] durations = durations[valid_indices] - + # Exponentiate the outputs to get the hazard ratios hazards = torch.exp(outputs) # Ensure hazards is at least 1D if hazards.dim() == 0: hazards = hazards.unsqueeze(0) # Make hazards 1D if it's a scalar # Calculate the risk set sum - log_risk_set_sum = torch.log(torch.cumsum(hazards[torch.argsort(durations, descending=True)], dim=0)) + log_risk_set_sum = torch.log( + torch.cumsum(hazards[torch.argsort(durations, descending=True)], dim=0) + ) # Get the indices that sort the durations in descending order sorted_indices = torch.argsort(durations, descending=True) events_sorted = events[sorted_indices] # Calculate the loss - uncensored_loss = torch.sum(outputs[sorted_indices][events_sorted == 1]) - torch.sum(log_risk_set_sum[events_sorted == 1]) + uncensored_loss = torch.sum( + outputs[sorted_indices][events_sorted == 1] + ) - torch.sum(log_risk_set_sum[events_sorted == 1]) total_loss = -uncensored_loss / torch.sum(events) - else: + else: total_loss = torch.tensor(0.0, device=outputs.device, requires_grad=True) if not torch.isfinite(total_loss): return torch.tensor(0.0, device=outputs.device, requires_grad=True) diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 901ac8ab..01245003 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -17,7 +17,14 @@ import matplotlib.pyplot as plt import matplotlib from sklearn.decomposition import PCA -from sklearn.metrics import balanced_accuracy_score, f1_score, cohen_kappa_score, classification_report, roc_auc_score, average_precision_score +from sklearn.metrics import ( + balanced_accuracy_score, + f1_score, + cohen_kappa_score, + classification_report, + roc_auc_score, + average_precision_score, +) from sklearn.metrics import mean_squared_error from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score from sklearn.utils import resample @@ -48,11 +55,28 @@ from lifelines.statistics import logrank_test, multivariate_logrank_test from plotnine import ( - ggplot, aes, geom_point, geom_smooth, geom_line, geom_abline, geom_step, - labs, ggtitle, annotate, theme_minimal, theme, element_text, - scale_color_manual, scale_color_gradient, scale_color_brewer, - geom_errorbarh, geom_text, - theme_bw, theme, element_blank, scale_y_discrete + ggplot, + aes, + geom_point, + geom_smooth, + geom_line, + geom_abline, + geom_step, + labs, + ggtitle, + annotate, + theme_minimal, + theme, + element_text, + scale_color_manual, + scale_color_gradient, + scale_color_brewer, + geom_errorbarh, + geom_text, + theme_bw, + theme, + element_blank, + scale_y_discrete, ) from sklearn.cluster import KMeans @@ -64,14 +88,23 @@ from sklearn.preprocessing import StandardScaler import ot - # imports import numpy as np import pandas as pd import matplotlib.pyplot as plt from plotnine import ( - ggplot, aes, geom_point, geom_step, labs, ggtitle, annotate, - theme_minimal, theme, element_text, scale_color_manual, scale_color_gradient + ggplot, + aes, + geom_point, + geom_step, + labs, + ggtitle, + annotate, + theme_minimal, + theme, + element_text, + scale_color_manual, + scale_color_gradient, ) from sklearn.decomposition import PCA from umap import UMAP @@ -91,6 +124,7 @@ def _labels_to_1d(labels): arr = np.asarray(labels) return np.ravel(arr) + def get_color_mapping(labels): """ Map categorical labels to colors using ALPHABETICAL order (deterministic). @@ -102,17 +136,17 @@ def get_color_mapping(labels): n = len(unique_labels) if n <= 10: - cmap = plt.get_cmap('tab10', n) + cmap = plt.get_cmap("tab10", n) colors = [cmap(i) for i in range(n)] elif n <= 20: - cmap = plt.get_cmap('tab20', n) + cmap = plt.get_cmap("tab20", n) colors = [cmap(i) for i in range(n)] else: # stack multiple palettes to cover many categories palettes = [ - plt.get_cmap('tab20', 20), - plt.get_cmap('Dark2', 8), - plt.get_cmap('Accent', 8), + plt.get_cmap("tab20", 20), + plt.get_cmap("Dark2", 8), + plt.get_cmap("Accent", 8), ] colors = [] for pal in palettes: @@ -122,25 +156,31 @@ def get_color_mapping(labels): colors.extend(colors) colors = colors[:n] - to_hex = lambda c: '#%02x%02x%02x' % (int(c[0]*255), int(c[1]*255), int(c[2]*255)) + to_hex = lambda c: "#%02x%02x%02x" % ( + int(c[0] * 255), + int(c[1] * 255), + int(c[2] * 255), + ) color_hex = [to_hex(c) for c in colors] return dict(zip(unique_labels, color_hex)) -def plot_dim_reduced(matrix, labels, method='pca', color_type='categorical', title=None): +def plot_dim_reduced( + matrix, labels, method="pca", color_type="categorical", title=None +): """ Plot first two dims (PCA/UMAP). Uses alphabetical label ordering + shared palette. """ method = method.lower() - if method == 'pca': + if method == "pca": transformer = PCA(n_components=2) transformed = transformer.fit_transform(matrix) var_exp = transformer.explained_variance_ratio_ * 100 xlab = f"PC1 ({var_exp[0]:.1f}%)" ylab = f"PC2 ({var_exp[1]:.1f}%)" - xcol, ycol = 'PC1', 'PC2' - elif method == 'umap': + xcol, ycol = "PC1", "PC2" + elif method == "umap": transformer = UMAP(n_components=2) m = np.array(matrix, dtype=np.float32) try: @@ -148,7 +188,7 @@ def plot_dim_reduced(matrix, labels, method='pca', color_type='categorical', tit except TypeError: transformed = transformer.fit_transform(m) xlab, ylab = "UMAP1", "UMAP2" - xcol, ycol = 'UMAP1', 'UMAP2' + xcol, ycol = "UMAP1", "UMAP2" else: raise ValueError("Invalid method. Expected 'pca' or 'umap'.") @@ -162,20 +202,20 @@ def plot_dim_reduced(matrix, labels, method='pca', color_type='categorical', tit plot_title = title if title else f"{method.upper()} Scatter Plot" - if color_type == 'categorical': + if color_type == "categorical": p = ( - ggplot(df, aes(x=xcol, y=ycol, color='Label')) + ggplot(df, aes(x=xcol, y=ycol, color="Label")) + geom_point() + scale_color_manual(values=color_mapping) + labs(title=plot_title, x=xlab, y=ylab, color="Labels") + theme_minimal() ) - elif color_type == 'numerical': + elif color_type == "numerical": # numerical coloring ignores the categorical palette df_num = df.copy() - df_num["Label"] = pd.to_numeric(lbls, errors='coerce') + df_num["Label"] = pd.to_numeric(lbls, errors="coerce") p = ( - ggplot(df_num, aes(x=xcol, y=ycol, color='Label')) + ggplot(df_num, aes(x=xcol, y=ycol, color="Label")) + geom_point() + scale_color_gradient(low="blue", high="red") + labs(title=plot_title, x=xlab, y=ylab, color="Label") @@ -191,11 +231,15 @@ def plot_kaplan_meier_curves(durations, events, categorical_variable): """ Kaplan–Meier curves with alphabetical label ordering + shared palette. """ - data = pd.DataFrame({ - 'Duration': _labels_to_1d(durations), - 'Event': _labels_to_1d(events), - 'Group': pd.Series(_labels_to_1d(categorical_variable), dtype="object").astype(str) - }) + data = pd.DataFrame( + { + "Duration": _labels_to_1d(durations), + "Event": _labels_to_1d(events), + "Group": pd.Series( + _labels_to_1d(categorical_variable), dtype="object" + ).astype(str), + } + ) # shared palette + fixed legend order color_mapping = get_color_mapping(data["Group"]) @@ -205,46 +249,54 @@ def plot_kaplan_meier_curves(durations, events, categorical_variable): kmf = KaplanMeierFitter() survival_curves = [] for g in order: # iterate in the same alphabetical order - gd = data[data['Group'] == g] + gd = data[data["Group"] == g] if len(gd) == 0: continue - kmf.fit(gd['Duration'], gd['Event'], label=g) + kmf.fit(gd["Duration"], gd["Event"], label=g) surv_df = kmf.survival_function_.reset_index() - surv_df.columns = ['Time', 'Survival'] - surv_df['Group'] = g + surv_df.columns = ["Time", "Survival"] + surv_df["Group"] = g survival_curves.append(surv_df) plot_data = pd.concat(survival_curves, ignore_index=True) - plot_data["Group"] = pd.Categorical(plot_data["Group"], categories=order, ordered=True) + plot_data["Group"] = pd.Categorical( + plot_data["Group"], categories=order, ordered=True + ) # log-rank text - categories = pd.unique(data['Group']) + categories = pd.unique(data["Group"]) if len(categories) == 2: g1, g2 = categories[0], categories[1] - grp1 = data[data['Group'] == g1] - grp2 = data[data['Group'] == g2] - result = logrank_test(grp1['Duration'], grp2['Duration'], - event_observed_A=grp1['Event'], - event_observed_B=grp2['Event']) + grp1 = data[data["Group"] == g1] + grp2 = data[data["Group"] == g2] + result = logrank_test( + grp1["Duration"], + grp2["Duration"], + event_observed_A=grp1["Event"], + event_observed_B=grp2["Event"], + ) p_text = f"Log-rank p = {result.p_value:.2e}" elif len(categories) > 2: - result = multivariate_logrank_test(data['Duration'], data['Group'], data['Event']) + result = multivariate_logrank_test( + data["Duration"], data["Group"], data["Event"] + ) p_text = f"Multivariate log-rank p = {result.p_value:.2e}" else: p_text = "Only one group — log-rank test not applicable" p = ( - ggplot(plot_data, aes(x='Time', y='Survival', color='Group')) + ggplot(plot_data, aes(x="Time", y="Survival", color="Group")) + geom_step() - + labs(x='Time', y='Survival Probability', color='Group') - + ggtitle('Kaplan-Meier Survival Curves by Group') - + annotate("text", x=0.1, y=0.1, label=p_text, size=10, ha='left') + + labs(x="Time", y="Survival Probability", color="Group") + + ggtitle("Kaplan-Meier Survival Curves by Group") + + annotate("text", x=0.1, y=0.1, label=p_text, size=10, ha="left") + theme_minimal() - + theme(legend_title=element_text(size=10, weight='bold')) + + theme(legend_title=element_text(size=10, weight="bold")) + scale_color_manual(values=color_mapping) ) return p + def plot_scatter(true_values, predicted_values): """ Plots a scatterplot of true vs predicted values, with a regression line and annotated with the Pearson correlation coefficient. @@ -267,20 +319,26 @@ def plot_scatter(true_values, predicted_values): corr_text = f"Pearson r: {corr:.2f}" # Create DataFrame - df = pd.DataFrame({"True Values": true_values, "Predicted Values": predicted_values}) + df = pd.DataFrame( + {"True Values": true_values, "Predicted Values": predicted_values} + ) # Generate scatter plot with regression line plot = ( - ggplot(df, aes(x="True Values", y="Predicted Values")) + - geom_point(alpha=0.5) + - geom_smooth(method='lm', color='red') + - annotate("text", x=min(true_values), y=max(predicted_values), label=corr_text, ha='left', va='top', size=10) + - labs( - title="True vs Predicted Values", - x="True Values", - y="Predicted Values" - ) + - theme_minimal() + ggplot(df, aes(x="True Values", y="Predicted Values")) + + geom_point(alpha=0.5) + + geom_smooth(method="lm", color="red") + + annotate( + "text", + x=min(true_values), + y=max(predicted_values), + label=corr_text, + ha="left", + va="top", + size=10, + ) + + labs(title="True vs Predicted Values", x="True Values", y="Predicted Values") + + theme_minimal() ) return plot @@ -288,7 +346,15 @@ def plot_scatter(true_values, predicted_values): from scipy.stats import mannwhitneyu, kruskal -def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Values', figsize=(10, 6), jittersize = 4): + +def plot_boxplot( + categorical_x, + numerical_y, + title_x="Categories", + title_y="Values", + figsize=(10, 6), + jittersize=4, +): df = pd.DataFrame({title_x: categorical_x, title_y: numerical_y}) # Compute p-value @@ -296,7 +362,7 @@ def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Valu if len(groups) == 2: group1 = df[df[title_x] == groups[0]][title_y] group2 = df[df[title_x] == groups[1]][title_y] - stat, p = mannwhitneyu(group1, group2, alternative='two-sided') + stat, p = mannwhitneyu(group1, group2, alternative="two-sided") test_name = "Mann-Whitney U" else: group_data = [df[df[title_x] == group][title_y] for group in groups] @@ -305,8 +371,25 @@ def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Valu # Create a boxplot with jittered points plt.figure(figsize=figsize) - sns.boxplot(x=title_x, y=title_y, hue=title_x, data=df, palette='Set2', legend=False, fill= False) - sns.stripplot(x=title_x, y=title_y, data=df, color='black', size=jittersize, jitter=True, dodge=True, alpha=0.4) + sns.boxplot( + x=title_x, + y=title_y, + hue=title_x, + data=df, + palette="Set2", + legend=False, + fill=False, + ) + sns.stripplot( + x=title_x, + y=title_y, + data=df, + color="black", + size=jittersize, + jitter=True, + dodge=True, + alpha=0.4, + ) # Labels and p-value annotation plt.xlabel(title_x) @@ -314,21 +397,23 @@ def plot_boxplot(categorical_x, numerical_y, title_x='Categories', title_y='Valu plt.text( x=-0.4, y=plt.ylim()[1], - s=f'{test_name} p = {p:.3e}', - verticalalignment='top', - horizontalalignment='left', + s=f"{test_name} p = {p:.3e}", + verticalalignment="top", + horizontalalignment="left", fontsize=12, - bbox=dict(boxstyle='round,pad=0.3', facecolor='white', edgecolor='gray') + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="gray"), ) plt.tight_layout() return plt.gcf() + # given a vector of numerical values which may contain # NAN values, return a binary grouping based on median values def split_by_median(v): return ((v - torch.nanmedian(v)) > 0).float() + def evaluate_survival(outputs, durations, events): """ Computes the concordance index (c-index) for survival predictions. @@ -358,14 +443,16 @@ def evaluate_survival(outputs, durations, events): # Compute concordance index (lifelines expects higher risk → lower survival) c_index = concordance_index(durations, -outputs, events) - return {'cindex': c_index} + return {"cindex": c_index} + def generate_bootstrap_indices(n, n_bootstraps=1000, seed=42): rng = np.random.default_rng(seed) return [rng.choice(n, size=n, replace=True) for _ in range(n_bootstraps)] + # bootstrapping function for regression/classification tasks -def bootstrap_metric(y_true, y_pred, indices_list, metric_fn, ci = 95, **kwargs): +def bootstrap_metric(y_true, y_pred, indices_list, metric_fn, ci=95, **kwargs): scores = [] y_true = np.array(y_true) y_pred = np.array(y_pred) @@ -377,6 +464,7 @@ def bootstrap_metric(y_true, y_pred, indices_list, metric_fn, ci = 95, **kwargs) upper = np.percentile(scores, 100 - (100 - ci) / 2) return scores, (np.mean(scores), lower, upper) + def evaluate_classifier(y_true, y_probs, print_report=False): """ Evaluate the performance of a classifier using multiple metrics and optionally print a detailed classification report. @@ -405,7 +493,7 @@ def evaluate_classifier(y_true, y_probs, print_report=False): balanced_acc = balanced_accuracy_score(y_true, y_pred) # F1 score (weighted) - f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0) + f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0) # Cohen's Kappa kappa = cohen_kappa_score(y_true, y_pred) @@ -415,10 +503,16 @@ def evaluate_classifier(y_true, y_probs, print_report=False): if y_probs.shape[1] == 2: # Binary classification y_probs_binary = y_probs[:, 1] # Use positive class probabilities average_auroc = roc_auc_score(y_true, y_probs_binary) - average_aupr = average_precision_score(y_true, y_probs_binary) # AUC-PR for binary case + average_aupr = average_precision_score( + y_true, y_probs_binary + ) # AUC-PR for binary case else: # Multiclass classification - average_auroc = roc_auc_score(y_true, y_probs, multi_class='ovr', average='weighted') - average_aupr = average_precision_score(y_true, y_probs, average='weighted') # Weighted AUC-PR for multiclass + average_auroc = roc_auc_score( + y_true, y_probs, multi_class="ovr", average="weighted" + ) + average_aupr = average_precision_score( + y_true, y_probs, average="weighted" + ) # Weighted AUC-PR for multiclass except ValueError: average_auroc = None # Handle cases where AUROC cannot be computed average_aupr = None # Handle cases where AUC-PR cannot be computed @@ -434,13 +528,15 @@ def evaluate_classifier(y_true, y_probs, print_report=False): "f1_score": f1, "kappa": kappa, "average_auroc": average_auroc, - "average_aupr": average_aupr # Added AUC-PR + "average_aupr": average_aupr, # Added AUC-PR } + from sklearn.metrics import roc_curve, roc_auc_score from sklearn.preprocessing import label_binarize from sklearn.metrics import precision_recall_curve, average_precision_score + def plot_roc_curves(y_true, y_probs): """ Plot ROC curves using plotnine for binary or multiclass classification. @@ -457,7 +553,15 @@ def plot_roc_curves(y_true, y_probs): # Binary classification fpr, tpr, _ = roc_curve(y_true, y_probs[:, 1]) auc_score = roc_auc_score(y_true, y_probs[:, 1]) - plot_data.append(pd.DataFrame({'fpr': fpr, 'tpr': tpr, 'label': [f'Class 1 (AUC = {auc_score:.2f})'] * len(fpr)})) + plot_data.append( + pd.DataFrame( + { + "fpr": fpr, + "tpr": tpr, + "label": [f"Class 1 (AUC = {auc_score:.2f})"] * len(fpr), + } + ) + ) else: # Multiclass classification classes = np.arange(n_classes) @@ -466,11 +570,13 @@ def plot_roc_curves(y_true, y_probs): for i in classes: fpr, tpr, _ = roc_curve(y_true_bin[:, i], y_probs[:, i]) auc_score = roc_auc_score(y_true_bin[:, i], y_probs[:, i]) - df = pd.DataFrame({ - 'fpr': fpr, - 'tpr': tpr, - 'label': [f'Class {i} (AUC = {auc_score:.2f})'] * len(fpr) - }) + df = pd.DataFrame( + { + "fpr": fpr, + "tpr": tpr, + "label": [f"Class {i} (AUC = {auc_score:.2f})"] * len(fpr), + } + ) plot_data.append(df) # Combine all data @@ -478,19 +584,16 @@ def plot_roc_curves(y_true, y_probs): # Plot using plotnine roc_plot = ( - ggplot(all_data, aes(x='fpr', y='tpr', color='label')) + - geom_line(size=1.2) + - geom_abline(intercept=0, slope=1, linetype='dashed', color='gray') + - labs( - title='ROC Curve', - x='False Positive Rate', - y='True Positive Rate' - ) + - theme_minimal() + ggplot(all_data, aes(x="fpr", y="tpr", color="label")) + + geom_line(size=1.2) + + geom_abline(intercept=0, slope=1, linetype="dashed", color="gray") + + labs(title="ROC Curve", x="False Positive Rate", y="True Positive Rate") + + theme_minimal() ) return roc_plot + def plot_pr_curves(y_true, y_probs): """ Plot Precision-Recall (PR) curves using plotnine for binary or multiclass classification. @@ -507,24 +610,32 @@ def plot_pr_curves(y_true, y_probs): # Binary classification precision, recall, _ = precision_recall_curve(y_true, y_probs[:, 1]) aupr = average_precision_score(y_true, y_probs[:, 1]) - plot_data.append(pd.DataFrame({ - 'recall': recall, - 'precision': precision, - 'label': [f'Class 1 (AUPR = {aupr:.2f})'] * len(recall) - })) + plot_data.append( + pd.DataFrame( + { + "recall": recall, + "precision": precision, + "label": [f"Class 1 (AUPR = {aupr:.2f})"] * len(recall), + } + ) + ) else: # Multiclass classification (one-vs-rest) classes = np.arange(n_classes) y_true_bin = label_binarize(y_true, classes=classes) for i in classes: - precision, recall, _ = precision_recall_curve(y_true_bin[:, i], y_probs[:, i]) + precision, recall, _ = precision_recall_curve( + y_true_bin[:, i], y_probs[:, i] + ) aupr = average_precision_score(y_true_bin[:, i], y_probs[:, i]) - df = pd.DataFrame({ - 'recall': recall, - 'precision': precision, - 'label': [f'Class {i} (AUPR = {aupr:.2f})'] * len(recall) - }) + df = pd.DataFrame( + { + "recall": recall, + "precision": precision, + "label": [f"Class {i} (AUPR = {aupr:.2f})"] * len(recall), + } + ) plot_data.append(df) # Combine all data @@ -532,14 +643,10 @@ def plot_pr_curves(y_true, y_probs): # Plot using plotnine pr_plot = ( - ggplot(all_data, aes(x='recall', y='precision', color='label')) + - geom_line(size=1.2) + - labs( - title='Precision-Recall Curve', - x='Recall', - y='Precision' - ) + - theme_minimal() + ggplot(all_data, aes(x="recall", y="precision", color="label")) + + geom_line(size=1.2) + + labs(title="Precision-Recall Curve", x="Recall", y="Precision") + + theme_minimal() ) return pr_plot @@ -564,11 +671,14 @@ def evaluate_regressor(y_true, y_pred): - 'pearson_corr': The Pearson correlation coefficient indicating the linear relationship strength between the true and predicted values. """ mse = mean_squared_error(y_true, y_pred) - slope, intercept, r_value, p_value, std_err = linregress(y_true,y_pred) + slope, intercept, r_value, p_value, std_err = linregress(y_true, y_pred) r2 = r_value**2 return {"mse": mse, "r2": r2, "pearson_corr": r_value} -def evaluate_wrapper(method, y_pred_dict, dataset, surv_event_var = None, surv_time_var = None): + +def evaluate_wrapper( + method, y_pred_dict, dataset, surv_event_var=None, surv_time_var=None +): """ Evaluates predictions for different variables within a dataset using appropriate metrics based on the variable type. Supports evaluation for numerical, categorical, and survival data. @@ -590,29 +700,34 @@ def evaluate_wrapper(method, y_pred_dict, dataset, surv_event_var = None, surv_t """ metrics_list = [] for var in y_pred_dict.keys(): - if dataset.variable_types[var] == 'numerical': + if dataset.variable_types[var] == "numerical": if var == surv_event_var: events = dataset.ann[surv_event_var] durations = dataset.ann[surv_time_var] metrics = evaluate_survival(y_pred_dict[var], durations, events) else: ind = ~torch.isnan(dataset.ann[var]) - metrics = evaluate_regressor(dataset.ann[var][ind], y_pred_dict[var][ind].flatten()) + metrics = evaluate_regressor( + dataset.ann[var][ind], y_pred_dict[var][ind].flatten() + ) else: ind = ~torch.isnan(dataset.ann[var]) metrics = evaluate_classifier(dataset.ann[var][ind], y_pred_dict[var][ind]) for metric, value in metrics.items(): - metrics_list.append({ - 'method': method, - 'var': var, - 'variable_type': dataset.variable_types[var], - 'metric': metric, - 'value': value - }) + metrics_list.append( + { + "method": method, + "var": var, + "variable_type": dataset.variable_types[var], + "metric": metric, + "value": value, + } + ) # Convert the list of metrics to a DataFrame return pd.DataFrame(metrics_list) + def get_predicted_labels(y_pred_dict, dataset, split, method_name): """ Generate a DataFrame with class probabilities and associated metadata. @@ -637,59 +752,87 @@ def get_predicted_labels(y_pred_dict, dataset, split, method_name): dfs = [] for var in y_pred_dict.keys(): - if dataset.variable_types[var] == 'categorical': + if dataset.variable_types[var] == "categorical": # Predicted probabilities probabilities = y_pred_dict[var] # Convert class indices to labels if mappings exist if var in dataset.label_mappings.keys(): - class_labels = [dataset.label_mappings[var][idx] for idx in range(probabilities.shape[1])] + class_labels = [ + dataset.label_mappings[var][idx] + for idx in range(probabilities.shape[1]) + ] else: - class_labels = [f'class_{i}' for i in range(probabilities.shape[1])] + class_labels = [f"class_{i}" for i in range(probabilities.shape[1])] # Get true labels (y_true) - y_true = [dataset.label_mappings[var][int(x.item())] if var in dataset.label_mappings.keys() and not np.isnan(x.item()) else np.nan for x in dataset.ann[var]] + y_true = [ + ( + dataset.label_mappings[var][int(x.item())] + if var in dataset.label_mappings.keys() and not np.isnan(x.item()) + else np.nan + ) + for x in dataset.ann[var] + ] # Predicted labels (argmax of probabilities) y_pred_indices = np.argmax(probabilities, axis=1) - y_pred = [dataset.label_mappings[var][idx] if var in dataset.label_mappings.keys() else idx for idx in y_pred_indices] + y_pred = [ + ( + dataset.label_mappings[var][idx] + if var in dataset.label_mappings.keys() + else idx + ) + for idx in y_pred_indices + ] # Create a DataFrame for each sample and its probabilities for i, sample_id in enumerate(dataset.samples): for j, class_label in enumerate(class_labels): - dfs.append({ - 'sample_id': sample_id, - 'variable': var, - 'class_label': class_label, - 'probability': probabilities[i, j], - 'known_label': y_true[i], - 'predicted_label': y_pred[i], - 'split': split, - 'method': method_name - }) + dfs.append( + { + "sample_id": sample_id, + "variable": var, + "class_label": class_label, + "probability": probabilities[i, j], + "known_label": y_true[i], + "predicted_label": y_pred[i], + "split": split, + "method": method_name, + } + ) else: # For numerical variables, set class_label and probability to NA y_true = [x.item() for x in dataset.ann[var]] y_pred = [x.item() for x in y_pred_dict[var]] for i, sample_id in enumerate(dataset.samples): - dfs.append({ - 'sample_id': sample_id, - 'variable': var, - 'class_label': np.nan, - 'probability': np.nan, - 'known_label': y_true[i], - 'predicted_label': y_pred[i], - 'split': split, - 'method': method_name - }) + dfs.append( + { + "sample_id": sample_id, + "variable": var, + "class_label": np.nan, + "probability": np.nan, + "known_label": y_true[i], + "predicted_label": y_pred[i], + "split": split, + "method": method_name, + } + ) # Combine all rows into a DataFrame return pd.DataFrame(dfs) - - -def evaluate_baseline_performance(train_dataset, test_dataset, variable_name, methods, n_folds=5, n_jobs=4, use_pca=False, n_components=100): +def evaluate_baseline_performance( + train_dataset, + test_dataset, + variable_name, + methods, + n_folds=5, + n_jobs=4, + use_pca=False, + n_components=100, +): """ Evaluates the performance of machine learning models on a given variable with optional PCA for dimensionality reduction. @@ -706,6 +849,7 @@ def evaluate_baseline_performance(train_dataset, test_dataset, variable_name, me Returns: pd.DataFrame: A DataFrame containing metrics for each method. """ + def prepare_data(data_object, pca_model=None, fit_pca=False): # Concatenate Data Matrices X = np.concatenate([tensor for tensor in data_object.dat.values()], axis=1) @@ -733,41 +877,57 @@ def prepare_data(data_object, pca_model=None, fit_pca=False): # Cross-Validation and Training kf = KFold(n_splits=n_folds, shuffle=True, random_state=42) - X_train, y_train, train_indices = prepare_data(train_dataset, pca_model=pca_model, fit_pca=True) + X_train, y_train, train_indices = prepare_data( + train_dataset, pca_model=pca_model, fit_pca=True + ) print("Train:", X_train.shape) - X_test, y_test, test_indices = prepare_data(test_dataset, pca_model=pca_model, fit_pca=False) + X_test, y_test, test_indices = prepare_data( + test_dataset, pca_model=pca_model, fit_pca=False + ) print("Test:", X_test.shape) metrics_list = [] predictions = [] # Collect all predictions for method in methods: - if variable_type == 'categorical': - if method == 'RandomForest': + if variable_type == "categorical": + if method == "RandomForest": model = RandomForestClassifier(random_state=42) - params = {'n_estimators': [100, 200, 300], 'max_depth': [10, 20, None]} - elif method == 'SVM': + params = {"n_estimators": [100, 200, 300], "max_depth": [10, 20, None]} + elif method == "SVM": model = SVC(probability=True, random_state=42) - params = {'C': [0.1, 1, 10], 'kernel': ['rbf', 'poly']} - elif method == 'XGBoost': + params = {"C": [0.1, 1, 10], "kernel": ["rbf", "poly"]} + elif method == "XGBoost": if XGBClassifier is None: - print("[WARNING] XGBoost is not available (on macOS, run: brew install libomp). Skipping.") + print( + "[WARNING] XGBoost is not available (on macOS, run: brew install libomp). Skipping." + ) continue - model = XGBClassifier(eval_metric='logloss', random_state=42) - params = {'n_estimators': [100, 200, 300], 'max_depth': [3, 6, 9], 'learning_rate': [0.01, 0.1, 0.2]} - elif variable_type == 'numerical': - if method == 'RandomForest': + model = XGBClassifier(eval_metric="logloss", random_state=42) + params = { + "n_estimators": [100, 200, 300], + "max_depth": [3, 6, 9], + "learning_rate": [0.01, 0.1, 0.2], + } + elif variable_type == "numerical": + if method == "RandomForest": model = RandomForestRegressor(random_state=42) - params = {'n_estimators': [100, 200, 300], 'max_depth': [10, 20, None]} - elif method == 'SVM': + params = {"n_estimators": [100, 200, 300], "max_depth": [10, 20, None]} + elif method == "SVM": model = SVR() - params = {'C': [0.1, 1, 10], 'kernel': ['rbf', 'poly']} - elif method == 'XGBoost': + params = {"C": [0.1, 1, 10], "kernel": ["rbf", "poly"]} + elif method == "XGBoost": if XGBRegressor is None: - print("[WARNING] XGBoost is not available (on macOS, run: brew install libomp). Skipping.") + print( + "[WARNING] XGBoost is not available (on macOS, run: brew install libomp). Skipping." + ) continue model = XGBRegressor(random_state=42) - params = {'n_estimators': [100, 200, 300], 'max_depth': [3, 6, 9], 'learning_rate': [0.01, 0.1, 0.2]} + params = { + "n_estimators": [100, 200, 300], + "max_depth": [3, 6, 9], + "learning_rate": [0.01, 0.1, 0.2], + } print("Training method:", method) grid_search = GridSearchCV(model, params, cv=kf, n_jobs=n_jobs) @@ -775,35 +935,41 @@ def prepare_data(data_object, pca_model=None, fit_pca=False): best_model = grid_search.best_estimator_ # Predict on test data - if variable_type == 'categorical': + if variable_type == "categorical": y_probs = best_model.predict_proba(X_test) metrics = evaluate_classifier(y_test, y_probs, print_report=True) y_pred_dict = {variable_name: y_probs} - elif variable_type == 'numerical': + elif variable_type == "numerical": y_pred = best_model.predict(X_test) metrics = evaluate_regressor(y_test, y_pred) y_pred_dict = {variable_name: y_pred} # need to get test indices to only consider samples with labels - df_preds = get_predicted_labels(y_pred_dict, test_dataset.subset(test_indices), 'test', method) + df_preds = get_predicted_labels( + y_pred_dict, test_dataset.subset(test_indices), "test", method + ) predictions.append(df_preds) for metric, value in metrics.items(): - metrics_list.append({ - 'method': method + ('Classifier' if variable_type == 'categorical' else 'Regressor'), - 'var': variable_name, - 'variable_type': variable_type, - 'metric': metric, - 'value': value - }) + metrics_list.append( + { + "method": method + + ("Classifier" if variable_type == "categorical" else "Regressor"), + "var": variable_name, + "variable_type": variable_type, + "metric": metric, + "value": value, + } + ) predictions = pd.concat(predictions, ignore_index=True) return pd.DataFrame(metrics_list), predictions - -def evaluate_baseline_survival_performance(train_dataset, test_dataset, duration_col, event_col, n_folds=5, n_jobs=4): +def evaluate_baseline_survival_performance( + train_dataset, test_dataset, duration_col, event_col, n_folds=5, n_jobs=4 +): """ Evaluates the baseline performance of a Random Survival Forest model on survival data using the Concordance Index. @@ -825,6 +991,7 @@ def evaluate_baseline_survival_performance(train_dataset, test_dataset, duration """ print(f"[INFO] Evaluating baseline survival prediction performance") + def prepare_data(data_object, duration_col, event_col): # Concatenate Data Matrices X = np.concatenate([tensor for tensor in data_object.dat.values()], axis=1) @@ -832,8 +999,10 @@ def prepare_data(data_object, duration_col, event_col): # Prepare Survival Data (Durations and Events) durations = np.array(data_object.ann[duration_col]) events = np.array(data_object.ann[event_col]) - y = np.array([(event, duration) for event, duration in zip(events, durations)], - dtype=[('Event', '?'), ('Time', '= alpha or resulting child would have < min_samples_per_group. @@ -1080,7 +1287,7 @@ def recursive_binary_split_minN(df, score='pred_risk', time='OS.time', event='OS continue left = node[node[score] <= cutoff] - right = node[node[score] > cutoff] + right = node[node[score] > cutoff] # if either resulting child is smaller than required, reject split if len(left) < min_samples_per_group or len(right) < min_samples_per_group: @@ -1092,7 +1299,7 @@ def recursive_binary_split_minN(df, score='pred_risk', time='OS.time', event='OS queue.append(left) queue.append(right) - df['auto_group'] = df.index.map(groups) + df["auto_group"] = df.index.map(groups) return df @@ -1106,65 +1313,71 @@ def plot_hazard_ratios(cox_model): cox_model = cox_model[0] # Extract summary - coef_summary = cox_model.summary[['coef', 'coef lower 95%', 'coef upper 95%', 'p']].copy() - coef_summary.columns = ['coef', 'coef_lower_95', 'coef_upper_95', 'p'] - coef_summary['variable'] = coef_summary.index + coef_summary = cox_model.summary[ + ["coef", "coef lower 95%", "coef upper 95%", "p"] + ].copy() + coef_summary.columns = ["coef", "coef_lower_95", "coef_upper_95", "p"] + coef_summary["variable"] = coef_summary.index # Sort by p-value - coef_summary_sorted = coef_summary.sort_values('p').reset_index(drop=True) + coef_summary_sorted = coef_summary.sort_values("p").reset_index(drop=True) # Add significance stars def significance(p): if p < 0.0001: - return '***' + return "***" elif p < 0.001: - return '**' + return "**" elif p < 0.05: - return '*' + return "*" elif p < 0.1: - return '.' + return "." else: - return '' + return "" - coef_summary_sorted['stars'] = coef_summary_sorted['p'].apply(significance) + coef_summary_sorted["stars"] = coef_summary_sorted["p"].apply(significance) # Reverse the order for top-to-bottom importance - coef_summary_sorted['variable'] = pd.Categorical( - coef_summary_sorted['variable'], - categories=coef_summary_sorted['variable'][::-1], - ordered=True + coef_summary_sorted["variable"] = pd.Categorical( + coef_summary_sorted["variable"], + categories=coef_summary_sorted["variable"][::-1], + ordered=True, ) c_index = cox_model.concordance_index_ # Plot p = ( - ggplot(coef_summary_sorted, aes(x='coef', y='variable')) + ggplot(coef_summary_sorted, aes(x="coef", y="variable")) + geom_errorbarh( - aes(xmin='coef_lower_95', xmax='coef_upper_95'), - height=0.2, color='skyblue' + aes(xmin="coef_lower_95", xmax="coef_upper_95"), height=0.2, color="skyblue" + ) + + geom_point(color="skyblue", size=3) + + geom_text(aes(label="stars"), nudge_y=0.1, size=10) + + annotate("vline", xintercept=0, linetype="dashed", color="gray") + + labs( + x="Log Hazard Ratio", + y="", + title=f"Log Hazard Ratios Sorted by P-Value with 95% CI\n Model C-index: {c_index:.2f}", ) - + geom_point(color='skyblue', size=3) - + geom_text(aes(label='stars'), nudge_y=0.1, size=10) - + annotate('vline', xintercept=0, linetype='dashed', color='gray') - + labs(x='Log Hazard Ratio', y='', title=f'Log Hazard Ratios Sorted by P-Value with 95% CI\n Model C-index: {c_index:.2f}') + theme_bw() + theme( axis_text_y=element_text(size=10), axis_text_x=element_text(size=10), - plot_title=element_text(weight='bold'), + plot_title=element_text(weight="bold"), ) ) return p + def build_cox_model( df: pd.DataFrame, duration_col: str, event_col: str, n_splits: int = 5, random_state: int = 42, - eval_time: float | None = None, # single horizon, same units as duration_col + eval_time: float | None = None, # single horizon, same units as duration_col low_variance_threshold: float = 0.01, - cox_penalizer = 0.05, + cox_penalizer=0.05, return_metrics: bool = True, ): """ @@ -1186,14 +1399,18 @@ def remove_low_variance_survival_features(df, duration_col, event_col, threshold for feature in df.drop(columns=[duration_col, event_col]).columns: v1 = df.loc[events, feature].var() v0 = df.loc[~events, feature].var() - if (v1 is not None and v1 < threshold) or (v0 is not None and v0 < threshold): + if (v1 is not None and v1 < threshold) or ( + v0 is not None and v0 < threshold + ): low_var.append(feature) df_f = df.drop(columns=low_var, errors="ignore") if low_var: print("Removed low variance features:", low_var) return df_f - df = remove_low_variance_survival_features(df, duration_col, event_col, low_variance_threshold) + df = remove_low_variance_survival_features( + df, duration_col, event_col, low_variance_threshold + ) metrics = { "cv_cindex_mean": None, "cv_auc_mean": None, @@ -1205,7 +1422,7 @@ def remove_low_variance_survival_features(df, duration_col, event_col, threshold for train_idx, test_idx in kf.split(df): train_df = df.iloc[train_idx] - test_df = df.iloc[test_idx] + test_df = df.iloc[test_idx] model = CoxPHFitter(penalizer=cox_penalizer) model.fit(train_df, duration_col=duration_col, event_col=event_col) @@ -1215,7 +1432,7 @@ def remove_low_variance_survival_features(df, duration_col, event_col, threshold ci = concordance_index( test_df[duration_col].values, -risk_scores, # higher hazard => higher risk - test_df[event_col].astype(int).values + test_df[event_col].astype(int).values, ) c_indices.append(ci) @@ -1223,11 +1440,11 @@ def remove_low_variance_survival_features(df, duration_col, event_col, threshold if eval_time is not None: y_train = Surv.from_arrays( event=train_df[event_col].astype(bool).values, - time=train_df[duration_col].values + time=train_df[duration_col].values, ) y_test = Surv.from_arrays( event=test_df[event_col].astype(bool).values, - time=test_df[duration_col].values + time=test_df[duration_col].values, ) # Only evaluate if the horizon lies strictly inside this test fold's follow-up @@ -1264,7 +1481,7 @@ def k_means_clustering(data, k): - kmeans: The fitted KMeans instance, which can be used to access cluster centers and other attributes. """ # Initialize the KMeans model - kmeans = KMeans(n_clusters=k, n_init='auto', random_state=42) + kmeans = KMeans(n_clusters=k, n_init="auto", random_state=42) # Fit the model to the data kmeans.fit(data) @@ -1274,6 +1491,7 @@ def k_means_clustering(data, k): return cluster_labels, kmeans + def louvain_clustering(X, threshold=None, k=None): """ Create a graph from pairwise distances within X. You can define a threshold to connect edges @@ -1293,11 +1511,11 @@ def louvain_clustering(X, threshold=None, k=None): for j in range(i + 1, distances.shape[0]): # If a threshold is defined, use it to create edges if threshold is not None and distances[i, j] < threshold: - G.add_edge(i, j, weight=1/distances[i, j]) + G.add_edge(i, j, weight=1 / distances[i, j]) # If k is defined, add an edge if j is one of i's k-nearest neighbors elif k is not None: - if np.argsort(distances[i])[:k + 1].__contains__(j): - G.add_edge(i, j, weight=1/distances[i, j]) + if np.argsort(distances[i])[: k + 1].__contains__(j): + G.add_edge(i, j, weight=1 / distances[i, j]) partition = community_louvain.best_partition(G) cluster_labels = np.full(len(X), np.nan, dtype=float) @@ -1331,33 +1549,39 @@ def get_optimal_clusters(data, min_k=2, max_k=10): cluster_labels_dict = {} # To store cluster labels for each k for k in range(min_k, max_k + 1): - kmeans = KMeans(n_clusters=k, n_init = 'auto', random_state=42) + kmeans = KMeans(n_clusters=k, n_init="auto", random_state=42) cluster_labels = kmeans.fit_predict(data) silhouette_avg = silhouette_score(data, cluster_labels) silhouette_scores.append((k, silhouette_avg)) cluster_labels_dict[k] = cluster_labels # Store cluster labels - #print(f"Number of clusters: {k}, Silhouette Score: {silhouette_avg:.4f}") + # print(f"Number of clusters: {k}, Silhouette Score: {silhouette_avg:.4f}") # Convert silhouette scores to DataFrame for easier handling and visualization - silhouette_scores_df = pd.DataFrame(silhouette_scores, columns=['k', 'silhouette_score']) + silhouette_scores_df = pd.DataFrame( + silhouette_scores, columns=["k", "silhouette_score"] + ) # Find the optimal k (number of clusters) with the highest silhouette score - optimal_k = silhouette_scores_df.loc[silhouette_scores_df['silhouette_score'].idxmax()]['k'] + optimal_k = silhouette_scores_df.loc[ + silhouette_scores_df["silhouette_score"].idxmax() + ]["k"] # Retrieve the cluster labels for the optimal k optimal_cluster_labels = cluster_labels_dict[optimal_k] return optimal_cluster_labels, optimal_k, silhouette_scores_df + # compute adjusted rand index; adjusted mutual information for two sets of paired labels def compute_ami_ari(labels1, labels2): - def convert_nan (labels): - return ['unavailable' if pd.isna(x) else x for x in labels] + def convert_nan(labels): + return ["unavailable" if pd.isna(x) else x for x in labels] + labels1 = convert_nan(labels1) labels2 = convert_nan(labels2) ami = adjusted_mutual_info_score(labels1, labels2) ari = adjusted_rand_score(labels1, labels2) - return {'ami': ami, 'ari': ari} + return {"ami": ami, "ari": ari} def plot_label_concordance_heatmap(labels1, labels2, figsize=(12, 10)): @@ -1369,18 +1593,22 @@ def plot_label_concordance_heatmap(labels1, labels2, figsize=(12, 10)): - labels2: The second set of labels. """ # Compute the cross-tabulation - ct = pd.crosstab(pd.Series(labels1, name='Labels Set 1'), pd.Series(labels2, name='Labels Set 2')) + ct = pd.crosstab( + pd.Series(labels1, name="Labels Set 1"), pd.Series(labels2, name="Labels Set 2") + ) # Normalize the cross-tabulation matrix column-wise ct_normalized = ct.div(ct.sum(axis=1), axis=0) # Plot the heatmap - plt.figure(figsize = figsize) - sns.heatmap(ct_normalized, annot=True,cmap='viridis', linewidths=.5)# col_cluster=False) - plt.title('Concordance between label groups') + plt.figure(figsize=figsize) + sns.heatmap( + ct_normalized, annot=True, cmap="viridis", linewidths=0.5 + ) # col_cluster=False) + plt.title("Concordance between label groups") return plt.gcf() -def scale_and_standardize_by_labels(data_matrix, labels): +def scale_and_standardize_by_labels(data_matrix, labels): """ Scale and standardize data_matrix by factor labels. Data is split by factors and each subset is scaled/standardized. @@ -1417,15 +1645,22 @@ def scale_and_standardize_by_labels(data_matrix, labels): return scaled_data_matrix + # df: annotation data frame ('clin.csv') # given a pandas data frame, go through each column and find out if the column is numeric or categorical def get_variable_types(df): # Select only the categorical columns - df_categorical = df.select_dtypes(include=['object', 'category']) - variable_types = {col: 'categorical' for col in df_categorical.columns} - variable_types.update({col: 'numerical' for col in df.select_dtypes(exclude=['object', 'category']).columns}) + df_categorical = df.select_dtypes(include=["object", "category"]) + variable_types = {col: "categorical" for col in df_categorical.columns} + variable_types.update( + { + col: "numerical" + for col in df.select_dtypes(exclude=["object", "category"]).columns + } + ) return variable_types + def create_covariate_matrix(covariates, variable_types, ann): """ Convert clinical variables used as covariates into a covariate matrix as a Pandas DataFrame. @@ -1444,20 +1679,26 @@ def create_covariate_matrix(covariates, variable_types, ann): feature_names = [] for var in covariates: - if variable_types.get(var) == 'categorical': + if variable_types.get(var) == "categorical": # One-hot-encode categorical variables with 0/1 encoding one_hot = pd.get_dummies(ann[var], prefix=var).astype(int) covariate_features.append(one_hot.T) # Transpose to make features rows feature_names.extend(one_hot.columns.tolist()) - elif variable_types.get(var) == 'numerical': + elif variable_types.get(var) == "numerical": # Handle numerical variables with missing values numerical_data = ann[[var]].copy() # Impute missing values using the median and assign back - numerical_data[var] = numerical_data[var].fillna(numerical_data[var].median()) - covariate_features.append(numerical_data.T) # Transpose to make features rows + numerical_data[var] = numerical_data[var].fillna( + numerical_data[var].median() + ) + covariate_features.append( + numerical_data.T + ) # Transpose to make features rows feature_names.append(var) else: - raise ValueError(f"Unknown variable type for {var}: {variable_types.get(var)}") + raise ValueError( + f"Unknown variable type for {var}: {variable_types.get(var)}" + ) # Concatenate all covariate features into a single DataFrame covariate_matrix = pd.concat(covariate_features, axis=0) @@ -1468,23 +1709,31 @@ def create_covariate_matrix(covariates, variable_types, ann): return covariate_matrix -def generate_synthetic_batches (n_samples_per_batch = 150, n_features = 50): + +def generate_synthetic_batches(n_samples_per_batch=150, n_features=50): # Generate batch 1 data (mean centered at 0, standard deviation 1) - batch1_data = np.random.normal(loc=0.0, scale=1.0, size=(n_samples_per_batch, n_features)) + batch1_data = np.random.normal( + loc=0.0, scale=1.0, size=(n_samples_per_batch, n_features) + ) # Generate batch 2 data (mean shifted by +2, standard deviation 1.5) - batch2_data = np.random.normal(loc=2.0, scale=1.5, size=(n_samples_per_batch, n_features)) + batch2_data = np.random.normal( + loc=2.0, scale=1.5, size=(n_samples_per_batch, n_features) + ) # Combine into a single dataset combined_data = np.vstack([batch1_data, batch2_data]) - batch_labels = np.array([0] * n_samples_per_batch + [1] * n_samples_per_batch) # Batch labels + batch_labels = np.array( + [0] * n_samples_per_batch + [1] * n_samples_per_batch + ) # Batch labels # Convert to Pandas DataFrame feature_columns = [f"feature_{i+1}" for i in range(n_features)] synthetic_data = pd.DataFrame(combined_data, columns=feature_columns) return synthetic_data, batch_labels -def optimal_transport_align(embeddings, batch_labels, standardize_by_labels = False): + +def optimal_transport_align(embeddings, batch_labels, standardize_by_labels=False): """ Align embeddings from two batches using Optimal Transport, preserving the order of samples. @@ -1512,7 +1761,7 @@ def optimal_transport_align(embeddings, batch_labels, standardize_by_labels = Fa batch2_embeddings = embeddings.iloc[batch2_indices].to_numpy() # Compute the cost matrix (e.g., Euclidean distances) - cost_matrix = ot.dist(batch1_embeddings, batch2_embeddings, metric='euclidean') + cost_matrix = ot.dist(batch1_embeddings, batch2_embeddings, metric="euclidean") # Solve the optimal transport problem n_samples_1 = batch1_embeddings.shape[0] @@ -1534,19 +1783,35 @@ def optimal_transport_align(embeddings, batch_labels, standardize_by_labels = Fa scaler1 = StandardScaler() scaler2 = StandardScaler() - aligned_embeddings[batch1_indices] = scaler1.fit_transform(aligned_embeddings[batch1_indices]) - aligned_embeddings[batch2_indices] = scaler2.fit_transform(aligned_embeddings[batch2_indices]) + aligned_embeddings[batch1_indices] = scaler1.fit_transform( + aligned_embeddings[batch1_indices] + ) + aligned_embeddings[batch2_indices] = scaler2.fit_transform( + aligned_embeddings[batch2_indices] + ) # Convert back to pandas DataFrame and Series, preserving indices - aligned_embeddings_df = pd.DataFrame(aligned_embeddings, columns=embeddings.columns, index=embeddings.index) - aligned_batch_labels = pd.Series(batch_labels, index=embeddings.index, name="batch_labels") + aligned_embeddings_df = pd.DataFrame( + aligned_embeddings, columns=embeddings.columns, index=embeddings.index + ) + aligned_batch_labels = pd.Series( + batch_labels, index=embeddings.index, name="batch_labels" + ) return aligned_embeddings_df, aligned_batch_labels from sklearn.neighbors import NearestNeighbors -def reciprocal_pca_mnn(embeddings, batch_labels, n_components=10, n_neighbors=5, standardize_by_labels=False, random_state=None): + +def reciprocal_pca_mnn( + embeddings, + batch_labels, + n_components=10, + n_neighbors=5, + standardize_by_labels=False, + random_state=None, +): """ Align embeddings from two batches using Reciprocal PCA (rPCA) and Mutual Nearest Neighbors (MNN). @@ -1579,8 +1844,12 @@ def reciprocal_pca_mnn(embeddings, batch_labels, n_components=10, n_neighbors=5, # Standardize embeddings separately for each batch if required if standardize_by_labels: - batch1_embeddings = (batch1_embeddings - batch1_embeddings.mean(axis=0)) / batch1_embeddings.std(axis=0) - batch2_embeddings = (batch2_embeddings - batch2_embeddings.mean(axis=0)) / batch2_embeddings.std(axis=0) + batch1_embeddings = ( + batch1_embeddings - batch1_embeddings.mean(axis=0) + ) / batch1_embeddings.std(axis=0) + batch2_embeddings = ( + batch2_embeddings - batch2_embeddings.mean(axis=0) + ) / batch2_embeddings.std(axis=0) # Perform PCA on both batches pca1 = PCA(n_components=n_components, random_state=random_state) @@ -1628,10 +1897,14 @@ def reciprocal_pca_mnn(embeddings, batch_labels, n_components=10, n_neighbors=5, aligned_embeddings[batch2_indices] = aligned_batch2 # Convert back to pandas DataFrame and Series, preserving indices - aligned_embeddings_df = pd.DataFrame(aligned_embeddings, - columns=[f"rPCA_{i+1}" for i in range(n_components)], - index=embeddings.index) - aligned_batch_labels = pd.Series(batch_labels, index=embeddings.index, name="batch_labels") + aligned_embeddings_df = pd.DataFrame( + aligned_embeddings, + columns=[f"rPCA_{i+1}" for i in range(n_components)], + index=embeddings.index, + ) + aligned_batch_labels = pd.Series( + batch_labels, index=embeddings.index, name="batch_labels" + ) return aligned_embeddings_df, aligned_batch_labels @@ -1640,6 +1913,7 @@ def reciprocal_pca_mnn(embeddings, batch_labels, n_components=10, n_neighbors=5, import requests from io import StringIO + class CBioPortalData: def __init__(self, study_id, base_url="https://datahub.assets.cbioportal.org"): self.base_url = base_url @@ -1673,7 +1947,9 @@ def extract_archive(self, archive_path): with tarfile.open(archive_path, "r:gz") as tar: tar.extractall() - self.data_files = [f for f in os.listdir(base) if f.startswith("data_") and f.endswith(".txt")] + self.data_files = [ + f for f in os.listdir(base) if f.startswith("data_") and f.endswith(".txt") + ] return base def read_data(self, files=None): @@ -1684,12 +1960,12 @@ def read_data(self, files=None): for datatype, file in files.items(): print(f"Importing {file}...") file_path = os.path.join(self.study_id, file) - df = pd.read_csv(file_path, sep='\t', comment='#', low_memory=False) + df = pd.read_csv(file_path, sep="\t", comment="#", low_memory=False) - if 'mutations' in file: + if "mutations" in file: print(f"Binarizing and converting {file} to matrix...") df = self.binarize_mutations(df) - elif 'clinical' not in file and 'drug_treatment' not in file: + elif "clinical" not in file and "drug_treatment" not in file: print(f"Converting {file} to matrix...") df = self.process_data(df) @@ -1697,12 +1973,12 @@ def read_data(self, files=None): return data def process_data(self, df): - if 'Hugo_Symbol' in df.columns and 'Entrez_Gene_Id' in df.columns: - df = df.drop(columns=['Entrez_Gene_Id'], errors='ignore') + if "Hugo_Symbol" in df.columns and "Entrez_Gene_Id" in df.columns: + df = df.drop(columns=["Entrez_Gene_Id"], errors="ignore") - if 'Hugo_Symbol' in df.columns: - df = df.drop_duplicates(subset=['Hugo_Symbol']) - df.set_index('Hugo_Symbol', inplace=True) + if "Hugo_Symbol" in df.columns: + df = df.drop_duplicates(subset=["Hugo_Symbol"]) + df.set_index("Hugo_Symbol", inplace=True) return df @@ -1711,10 +1987,18 @@ def binarize_mutations(self, df): for col in required_cols: if col not in df.columns: - raise ValueError(f"Can't map mutations to sample IDs. Column {col} not found.") + raise ValueError( + f"Can't map mutations to sample IDs. Column {col} not found." + ) - mutation_counts = df.groupby(["Hugo_Symbol", "Tumor_Sample_Barcode"]).size().reset_index(name='count') - mutation_matrix = mutation_counts.pivot(index='Hugo_Symbol', columns='Tumor_Sample_Barcode', values='count').fillna(0) + mutation_counts = ( + df.groupby(["Hugo_Symbol", "Tumor_Sample_Barcode"]) + .size() + .reset_index(name="count") + ) + mutation_matrix = mutation_counts.pivot( + index="Hugo_Symbol", columns="Tumor_Sample_Barcode", values="count" + ).fillna(0) mutation_matrix[mutation_matrix > 0] = 1 return mutation_matrix @@ -1729,23 +2013,25 @@ def get_cbioportal_data(self, study_id, files=None): if files is None: self.print_data_files() - print("\n\nPlease select a list of files to import. Example:\n get_cbioportal_data('study_id', files={'mut': 'data_mutations.txt', 'clin': 'data_clinical_patient.txt'})") + print( + "\n\nPlease select a list of files to import. Example:\n get_cbioportal_data('study_id', files={'mut': 'data_mutations.txt', 'clin': 'data_clinical_patient.txt'})" + ) return data = self.read_data(files) - if 'clin' in files: - clin = data['clin'] + if "clin" in files: + clin = data["clin"] clin = clin.drop_duplicates(subset=clin.columns[0]) clin.set_index(clin.columns[0], inplace=True) - data['clin'] = clin + data["clin"] = clin print({x: data[x].shape for x in data.keys()}) self.data = data def split_data(self, samples=None, ratio=0.7): if samples is None: - samples = self.data['clin'].index.tolist() + samples = self.data["clin"].index.tolist() train_samples = list(pd.Series(samples).sample(frac=ratio, random_state=42)) test_samples = list(set(samples) - set(train_samples)) @@ -1754,10 +2040,18 @@ def split_data(self, samples=None, ratio=0.7): test_data = {} for key, df in self.data.items(): - train_data[key] = df.loc[df.index.intersection(train_samples)] if key == 'clin' else df.loc[:, df.columns.intersection(train_samples)] - test_data[key] = df.loc[df.index.intersection(test_samples)] if key == 'clin' else df.loc[:, df.columns.intersection(test_samples)] + train_data[key] = ( + df.loc[df.index.intersection(train_samples)] + if key == "clin" + else df.loc[:, df.columns.intersection(train_samples)] + ) + test_data[key] = ( + df.loc[df.index.intersection(test_samples)] + if key == "clin" + else df.loc[:, df.columns.intersection(test_samples)] + ) - return {'train': train_data, 'test': test_data} + return {"train": train_data, "test": test_data} def print_dataset(self, dataset, outdir): if not os.path.exists(outdir): @@ -1769,8 +2063,7 @@ def print_dataset(self, dataset, outdir): os.makedirs(split_dir) for file, df in data.items(): - df.to_csv(os.path.join(split_dir, f"{file}.csv"), sep=',') - + df.to_csv(os.path.join(split_dir, f"{file}.csv"), sep=",") def compute_correlation_loss(embeddings, batch_labels): @@ -1778,7 +2071,9 @@ def compute_correlation_loss(embeddings, batch_labels): batch_labels = batch_labels.float() # Normalize embeddings - embeddings = (embeddings - embeddings.mean(dim=0, keepdim=True)) / (embeddings.std(dim=0, keepdim=True) + 1e-8) + embeddings = (embeddings - embeddings.mean(dim=0, keepdim=True)) / ( + embeddings.std(dim=0, keepdim=True) + 1e-8 + ) # Normalize batch labels batch_labels = (batch_labels - batch_labels.mean()) / (batch_labels.std() + 1e-8) @@ -1793,6 +2088,7 @@ def compute_correlation_loss(embeddings, batch_labels): loss = torch.sum(torch.abs(covariance)) return loss + # from geomloss import SamplesLoss def compute_transport_cost(embeddings, batch_labels, blur=0.5): """ @@ -1814,7 +2110,9 @@ def compute_transport_cost(embeddings, batch_labels, blur=0.5): batch2_embeddings = embeddings[batch_labels == 1] if batch1_embeddings.size(0) == 0 or batch2_embeddings.size(0) == 0: - raise ValueError("Both batches must have at least one sample for transport cost computation.") + raise ValueError( + "Both batches must have at least one sample for transport cost computation." + ) # Initialize the Sinkhorn loss function loss_fn = SamplesLoss("sinkhorn", blur=blur) @@ -1855,29 +2153,33 @@ def get_optimal_device(device_preference=None): - device_type: String indicating the device type for compatibility """ if device_preference is None: - device_preference = 'auto' + device_preference = "auto" # If specific device is requested, validate and return it - if device_preference == 'cuda': + if device_preference == "cuda": if torch.cuda.is_available(): - return 'cuda', 'gpu' + return "cuda", "gpu" else: - warnings.warn("CUDA requested but not available. Falling back to auto-detection.") - elif device_preference == 'mps': + warnings.warn( + "CUDA requested but not available. Falling back to auto-detection." + ) + elif device_preference == "mps": if torch.backends.mps.is_available(): - return 'mps', 'mps' + return "mps", "mps" else: - warnings.warn("MPS requested but not available. Falling back to auto-detection.") - elif device_preference == 'cpu': - return 'cpu', 'cpu' + warnings.warn( + "MPS requested but not available. Falling back to auto-detection." + ) + elif device_preference == "cpu": + return "cpu", "cpu" # Auto-detection logic (priority: CUDA > MPS > CPU) if torch.cuda.is_available(): - return 'cuda', 'gpu' + return "cuda", "gpu" elif torch.backends.mps.is_available(): - return 'mps', 'mps' + return "mps", "mps" else: - return 'cpu', 'cpu' + return "cpu", "cpu" def get_device_memory_info(device_str): @@ -1890,30 +2192,30 @@ def get_device_memory_info(device_str): Returns: dict: Memory information dictionary """ - if device_str == 'cuda' and torch.cuda.is_available(): + if device_str == "cuda" and torch.cuda.is_available(): return { - 'allocated': torch.cuda.memory_allocated() / (1024**2), # MB - 'reserved': torch.cuda.memory_reserved() / (1024**2), # MB - 'max_allocated': torch.cuda.max_memory_allocated() / (1024**2), # MB - 'device_name': torch.cuda.get_device_name(0), - 'device_count': torch.cuda.device_count() + "allocated": torch.cuda.memory_allocated() / (1024**2), # MB + "reserved": torch.cuda.memory_reserved() / (1024**2), # MB + "max_allocated": torch.cuda.max_memory_allocated() / (1024**2), # MB + "device_name": torch.cuda.get_device_name(0), + "device_count": torch.cuda.device_count(), } - elif device_str == 'mps' and torch.backends.mps.is_available(): + elif device_str == "mps" and torch.backends.mps.is_available(): # MPS doesn't have the same detailed memory tracking as CUDA return { - 'allocated': 'N/A (MPS)', - 'reserved': 'N/A (MPS)', - 'max_allocated': 'N/A (MPS)', - 'device_name': 'Apple Metal Performance Shaders', - 'device_count': 1 + "allocated": "N/A (MPS)", + "reserved": "N/A (MPS)", + "max_allocated": "N/A (MPS)", + "device_name": "Apple Metal Performance Shaders", + "device_count": 1, } else: return { - 'allocated': 'N/A (CPU)', - 'reserved': 'N/A (CPU)', - 'max_allocated': 'N/A (CPU)', - 'device_name': 'CPU', - 'device_count': 1 + "allocated": "N/A (CPU)", + "reserved": "N/A (CPU)", + "max_allocated": "N/A (CPU)", + "device_name": "CPU", + "device_count": 1, } @@ -1927,20 +2229,20 @@ def create_device_from_string(device_str): Returns: torch.device: PyTorch device object """ - if device_str in ['gpu', 'auto']: + if device_str in ["gpu", "auto"]: optimal_device, _ = get_optimal_device() return torch.device(optimal_device) - elif device_str == 'mps': + elif device_str == "mps": if torch.backends.mps.is_available(): - return torch.device('mps') + return torch.device("mps") else: warnings.warn("MPS not available, falling back to CPU") - return torch.device('cpu') - elif device_str == 'cuda': + return torch.device("cpu") + elif device_str == "cuda": if torch.cuda.is_available(): - return torch.device('cuda') + return torch.device("cuda") else: warnings.warn("CUDA not available, falling back to CPU") - return torch.device('cpu') + return torch.device("cpu") else: - return torch.device('cpu') + return torch.device("cpu") diff --git a/tests/test_mps_device.py b/tests/test_mps_device.py index 231822dd..02c494a1 100644 --- a/tests/test_mps_device.py +++ b/tests/test_mps_device.py @@ -5,22 +5,22 @@ def test_mps_device_detection(): """Test MPS device detection via Flexynesis.""" - device_str, device_type = get_optimal_device('mps') - + device_str, device_type = get_optimal_device("mps") + if not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") - - assert device_type == 'mps', f"Expected device type 'mps', got '{device_type}'" - assert device_str == 'mps', f"Expected device string 'mps', got '{device_str}'" + + assert device_type == "mps", f"Expected device type 'mps', got '{device_type}'" + assert device_str == "mps", f"Expected device string 'mps', got '{device_str}'" def test_mps_tensor_operations(): """Test MPS tensor operations via Flexynesis safe transfer functions.""" - device_str, device_type = get_optimal_device('mps') - - if device_str != 'mps' or not torch.backends.mps.is_available(): + device_str, device_type = get_optimal_device("mps") + + if device_str != "mps" or not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") - + device = torch.device(device_str) # Test basic tensor operations using safe transfer @@ -29,16 +29,18 @@ def test_mps_tensor_operations(): result = torch.mm(x, y) assert result.shape == (100, 25), f"Unexpected result shape: {result.shape}" - assert result.device.type == "mps", f"Result not on MPS device: {result.device.type}" + assert ( + result.device.type == "mps" + ), f"Result not on MPS device: {result.device.type}" def test_mps_memory_allocation(): """Test MPS memory allocation tracking.""" - device_str, device_type = get_optimal_device('mps') - - if device_str != 'mps' or not torch.backends.mps.is_available(): + device_str, device_type = get_optimal_device("mps") + + if device_str != "mps" or not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") - + device = torch.device(device_str) # Test memory tracking @@ -51,11 +53,11 @@ def test_mps_memory_allocation(): def test_float64_to_float32_conversion(): """Test automatic float64 to float32 conversion for MPS compatibility.""" - device_str, device_type = get_optimal_device('mps') - - if device_str != 'mps' or not torch.backends.mps.is_available(): + device_str, device_type = get_optimal_device("mps") + + if device_str != "mps" or not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") - + device = torch.device(device_str) # Test float64 tensor conversion From 43605c9d2925e6a6e5e212eb20713c106b5baccf Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 16:30:55 +0200 Subject: [PATCH 02/32] use isort to fix the order --- flexynesis/__init__.py | 3 +- flexynesis/__main__.py | 65 +++++---- flexynesis/config.py | 2 +- flexynesis/data.py | 35 ++--- flexynesis/feature_selection.py | 4 +- flexynesis/generate_coexpression_network.py | 7 +- flexynesis/inference.py | 8 +- flexynesis/main.py | 38 +++--- flexynesis/models/__init__.py | 4 +- flexynesis/models/crossmodal_pred.py | 16 +-- flexynesis/models/direct_pred.py | 20 +-- flexynesis/models/gnn_early.py | 13 +- flexynesis/models/supervised_vae.py | 18 ++- flexynesis/models/triplet_encoder.py | 21 ++- flexynesis/modules.py | 2 +- flexynesis/utils.py | 140 +++++++------------- tests/test_mps_device.py | 3 +- 17 files changed, 167 insertions(+), 232 deletions(-) diff --git a/flexynesis/__init__.py b/flexynesis/__init__.py index fe80b0bf..3a31a0d2 100644 --- a/flexynesis/__init__.py +++ b/flexynesis/__init__.py @@ -1,9 +1,10 @@ # Lazy imports to avoid slow startup # Only import essential components that are needed for basic functionality +from importlib.metadata import PackageNotFoundError, version + # Import core modules without heavy dependencies from .config import search_spaces -from importlib.metadata import PackageNotFoundError, version # Use the distribution name as published on PyPI (may differ from import name in some projects) _DISTRIBUTION_NAME = "flexynesis" diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index f6cc00b2..4cc6ebc6 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -1,13 +1,15 @@ +import argparse +import json import os +import random import sys -import argparse -import yaml import time -import random -import warnings -import json import tracemalloc +import warnings + import psutil +import yaml + from . import __version__ os.environ["OMP_NUM_THREADS"] = "1" @@ -736,7 +738,8 @@ def main(): # ---------- Inference mode: early exit path ---------- if args.pretrained_model and args.artifacts and args.data_path_test: import torch - from .utils import get_optimal_device, create_device_from_string + + from .utils import create_device_from_string, get_optimal_device # quick existence checks if not os.path.exists(args.pretrained_model): @@ -833,8 +836,8 @@ def main(): # Convert to GNN dataset if needed if args.model_class == "GNN": print("[INFO] Overlaying the dataset with network data from STRINGDB") - from .main import STRING from .data import MultiOmicDatasetNW + from .main import STRING # Get STRING organism from artifacts string_organism = importer.artifacts.get( @@ -863,37 +866,35 @@ def main(): # Continue to evaluation section (skip training) # ------------- Heavy imports only when training ------------- - from .utils import ( - evaluate_baseline_performance, - evaluate_baseline_survival_performance, - get_predicted_labels, - evaluate_wrapper, - get_optimal_device, - get_device_memory_info, - create_device_from_string, - ) + from .utils import (create_device_from_string, + evaluate_baseline_performance, + evaluate_baseline_survival_performance, + evaluate_wrapper, get_device_memory_info, + get_optimal_device, get_predicted_labels) if not (args.pretrained_model and args.artifacts and args.data_path_test): - import flexynesis - from lightning import seed_everything - import lightning as pl + import json + import tracemalloc from typing import NamedTuple - import torch + + import lightning as pl import pandas as pd + import psutil + import torch + from lightning import seed_everything from safetensors.torch import save_file + import flexynesis + + # data + utils + from .data import STRING, DataImporter, MultiOmicDatasetNW + from .main import FineTuner, HyperparameterTuning + from .models.crossmodal_pred import CrossModalPred # models from .models.direct_pred import DirectPred + from .models.gnn_early import GNN from .models.supervised_vae import supervised_vae from .models.triplet_encoder import MultiTripletNetwork - from .models.crossmodal_pred import CrossModalPred - from .models.gnn_early import GNN - - # data + utils - from .data import STRING, MultiOmicDatasetNW, DataImporter - from .main import HyperparameterTuning, FineTuner - import tracemalloc, psutil - import json # --------- Sanity checks on args --------- # 1. survival variables consistency @@ -1487,11 +1488,9 @@ def main(): elif args.safetensors: import numpy as np - from sklearn.preprocessing import ( - LabelEncoder, - OrdinalEncoder, - StandardScaler, - ) + from sklearn.preprocessing import (LabelEncoder, + OrdinalEncoder, + StandardScaler) json_ready = { "schema_version": artifacts["schema_version"], diff --git a/flexynesis/config.py b/flexynesis/config.py index 1a77c1d1..fda0444a 100644 --- a/flexynesis/config.py +++ b/flexynesis/config.py @@ -1,5 +1,5 @@ # config.py -from skopt.space import Integer, Categorical, Real +from skopt.space import Categorical, Integer, Real epochs = [500] diff --git a/flexynesis/data.py b/flexynesis/data.py index feee43b1..30174539 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -1,31 +1,23 @@ -from torch.utils.data import Dataset, DataLoader -from torch_geometric.data import download_url, extract_gz -from torch_geometric.data import Dataset as PYGDataset +import os +import shutil +from functools import reduce +from itertools import chain +from pathlib import Path import numpy as np import pandas as pd -from functools import reduce import torch -import os -import shutil - -from pathlib import Path from filelock import FileLock from platformdirs import user_cache_dir - +from sklearn.preprocessing import (MinMaxScaler, OrdinalEncoder, + PowerTransformer, StandardScaler) +from torch.utils.data import DataLoader, Dataset +from torch_geometric.data import Dataset as PYGDataset +from torch_geometric.data import download_url, extract_gz from tqdm import tqdm - -from sklearn.preprocessing import ( - OrdinalEncoder, - StandardScaler, - MinMaxScaler, - PowerTransformer, -) from .feature_selection import filter_by_laplacian -from .utils import get_variable_types, create_covariate_matrix - -from itertools import chain +from .utils import create_covariate_matrix, get_variable_types # convert_to_labels: if true, given a numeric list, convert to binary labels by median value @@ -669,8 +661,8 @@ def check_common_features(train_dat, test_dat): Add this to flexynesis/data.py """ -import pandas as pd import numpy as np +import pandas as pd class DataImporterInference: @@ -797,7 +789,8 @@ def import_data(self): # Create covariates matrix if needed if "covariates" in self.modalities and labels_df is not None: - from flexynesis.utils import create_covariate_matrix, get_variable_types + from flexynesis.utils import (create_covariate_matrix, + get_variable_types) covariate_vars = self.artifacts.get("covariate_vars", []) if covariate_vars: diff --git a/flexynesis/feature_selection.py b/flexynesis/feature_selection.py index 763dc7f9..a80ef39e 100644 --- a/flexynesis/feature_selection.py +++ b/flexynesis/feature_selection.py @@ -2,11 +2,9 @@ import numpy as np import pandas as pd +from scipy.sparse import csgraph, csr_matrix, diags from scipy.spatial.distance import pdist, squareform from sklearn.neighbors import kneighbors_graph -from scipy.sparse import csgraph - -from scipy.sparse import csr_matrix, diags from tqdm import tqdm diff --git a/flexynesis/generate_coexpression_network.py b/flexynesis/generate_coexpression_network.py index 24c98f4e..4cdf3cd8 100644 --- a/flexynesis/generate_coexpression_network.py +++ b/flexynesis/generate_coexpression_network.py @@ -22,12 +22,13 @@ """ -import pandas as pd -import numpy as np -from tqdm import tqdm import argparse import sys + +import numpy as np +import pandas as pd import torch +from tqdm import tqdm def build_network( diff --git a/flexynesis/inference.py b/flexynesis/inference.py index 3b0eb67f..7b998d6c 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -7,8 +7,8 @@ import os from types import SimpleNamespace -import numpy as np import joblib +import numpy as np import torch from safetensors.torch import load_file @@ -22,8 +22,8 @@ def check_model_type(file_path): - import struct import json + import struct with open(file_path, "rb") as f: header_start = f.read(8) @@ -139,6 +139,7 @@ def load_and_sniff_artifacts(artifacts_path): loads it appropriately, and returns (file_type, content). """ import json + import joblib def check_file_type(file_path): @@ -178,7 +179,8 @@ def check_file_type(file_path): def _deserialize_json_artifacts(artifacts): import numpy as np - from sklearn.preprocessing import StandardScaler, LabelEncoder, OrdinalEncoder + from sklearn.preprocessing import (LabelEncoder, OrdinalEncoder, + StandardScaler) # Rebuild sklearn objects expected by inference code. deserialized = dict(artifacts) diff --git a/flexynesis/main.py b/flexynesis/main.py index c79086f3..391303bb 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -2,27 +2,24 @@ seed_everything(42, workers=True) -import torch -from torch.utils.data import DataLoader, random_split -import torch_geometric +import os import lightning as pl -from lightning.pytorch.callbacks import RichProgressBar -from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme -from lightning.pytorch.callbacks import EarlyStopping - -from tqdm import tqdm - +import numpy as np +import torch +import torch_geometric +import yaml +from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar +from lightning.pytorch.callbacks.progress.rich_progress import \ + RichProgressBarTheme from skopt import Optimizer +from skopt.space import Categorical, Integer, Real from skopt.utils import use_named_args -from .config import search_spaces -from .data import TripletMultiOmicDataset - -import numpy as np +from torch.utils.data import DataLoader, random_split +from tqdm import tqdm -import os, yaml -from skopt.space import Integer, Categorical, Real -from .data import STRING +from .config import search_spaces +from .data import STRING, TripletMultiOmicDataset torch.set_float32_matmul_precision("medium") @@ -449,10 +446,13 @@ def load_and_convert_config(self, config_path): return search_space_user -from torch.utils.data import DataLoader, random_split -from sklearn.model_selection import KFold +import copy +import logging +import random + import numpy as np -import random, copy, logging +from sklearn.model_selection import KFold +from torch.utils.data import DataLoader, random_split class FineTuner(pl.LightningModule): diff --git a/flexynesis/models/__init__.py b/flexynesis/models/__init__.py index a2582e01..eb677382 100644 --- a/flexynesis/models/__init__.py +++ b/flexynesis/models/__init__.py @@ -1,8 +1,8 @@ +from .crossmodal_pred import CrossModalPred from .direct_pred import DirectPred +from .gnn_early import GNN from .supervised_vae import supervised_vae from .triplet_encoder import MultiTripletNetwork -from .crossmodal_pred import CrossModalPred -from .gnn_early import GNN __all__ = [ "DirectPred", diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index 92c6634e..6dd6dcad 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -1,16 +1,14 @@ -import torch import itertools -from torch import nn -from torch.nn import functional as F -from torch.utils.data import Dataset, DataLoader, random_split - -import pandas as pd -import numpy as np import lightning as pl +import numpy as np +import pandas as pd +import torch +from captum.attr import GradientShap, IntegratedGradients from scipy import stats - -from captum.attr import IntegratedGradients, GradientShap +from torch import nn +from torch.nn import functional as F +from torch.utils.data import DataLoader, Dataset, random_split from ..modules import * from ..utils import to_device_safe diff --git a/flexynesis/models/direct_pred.py b/flexynesis/models/direct_pred.py index 565f7c05..2e528b53 100644 --- a/flexynesis/models/direct_pred.py +++ b/flexynesis/models/direct_pred.py @@ -1,16 +1,16 @@ -import torch -from torch import nn -from torch.nn import functional as F -import lightning as pl -from torch.utils.data import Dataset, DataLoader, random_split +import argparse +import os +from functools import reduce -import pandas as pd +import lightning as pl import numpy as np -import os, argparse +import pandas as pd +import torch +from captum.attr import GradientShap, IntegratedGradients from scipy import stats -from functools import reduce - -from captum.attr import IntegratedGradients, GradientShap +from torch import nn +from torch.nn import functional as F +from torch.utils.data import DataLoader, Dataset, random_split from ..modules import * from ..utils import to_device_safe diff --git a/flexynesis/models/gnn_early.py b/flexynesis/models/gnn_early.py index b4fe571e..a7a2e260 100644 --- a/flexynesis/models/gnn_early.py +++ b/flexynesis/models/gnn_early.py @@ -1,19 +1,14 @@ +import lightning as pl import numpy as np import pandas as pd - import torch +from captum.attr import GradientShap, IntegratedGradients from torch import nn from torch.nn import functional as F -from torch.utils.data import random_split - -import lightning as pl +from torch.utils.data import DataLoader, random_split -from torch.utils.data import DataLoader - -from captum.attr import IntegratedGradients, GradientShap - -from ..utils import to_device_safe from ..modules import MLP, cox_ph_loss, flexGCN +from ..utils import to_device_safe class GNN(pl.LightningModule): diff --git a/flexynesis/models/supervised_vae.py b/flexynesis/models/supervised_vae.py index b679ab39..58857c68 100644 --- a/flexynesis/models/supervised_vae.py +++ b/flexynesis/models/supervised_vae.py @@ -1,20 +1,18 @@ # Supervised VAE-MMD architecture -import torch import itertools -from torch import nn -from torch.nn import functional as F -from torch.utils.data import Dataset, DataLoader, random_split - -import pandas as pd -import numpy as np import lightning as pl +import numpy as np +import pandas as pd +import torch +from captum.attr import GradientShap, IntegratedGradients from scipy import stats +from torch import nn +from torch.nn import functional as F +from torch.utils.data import DataLoader, Dataset, random_split -from captum.attr import IntegratedGradients, GradientShap - -from ..utils import to_device_safe from ..modules import * +from ..utils import to_device_safe # Supervised Variational Auto-encoder that can train one or more layers of omics datasets diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index 6824610c..0c964c40 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -1,21 +1,18 @@ # Generating encodings of multi-omic data using triplet loss -import torch -from torch import nn -from torch.nn import functional as F -from torch.utils.data import Dataset, DataLoader, random_split - import itertools -import pandas as pd -import numpy as np - import lightning as pl +import numpy as np +import pandas as pd +import torch +from captum.attr import GradientShap, IntegratedGradients +from torch import nn +from torch.nn import functional as F +from torch.utils.data import DataLoader, Dataset, random_split -from ..utils import to_device_safe -from ..modules import * from ..data import TripletMultiOmicDataset - -from captum.attr import IntegratedGradients, GradientShap +from ..modules import * +from ..utils import to_device_safe class MultiTripletNetwork(pl.LightningModule): diff --git a/flexynesis/modules.py b/flexynesis/modules.py index f367959f..253cfc28 100644 --- a/flexynesis/modules.py +++ b/flexynesis/modules.py @@ -2,7 +2,7 @@ import torch from torch import nn -from torch_geometric.nn import aggr, GCNConv, GATConv, SAGEConv, GraphConv +from torch_geometric.nn import GATConv, GCNConv, GraphConv, SAGEConv, aggr __all__ = ["Encoder", "Decoder", "MLP", "flexGCN", "cox_ph_loss"] diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 01245003..ade47a40 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -1,44 +1,35 @@ -from lightning import seed_everything -import pandas as pd -import numpy as np -import torch +import logging import math -import warnings -import requests -import tarfile import os -from glob import glob import re -import logging -from tqdm import tqdm +import tarfile +import warnings +from glob import glob -from umap import UMAP -import seaborn as sns -import matplotlib.pyplot as plt import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import requests +import seaborn as sns +import torch +from lightning import seed_everything +from scipy.stats import linregress, pearsonr from sklearn.decomposition import PCA -from sklearn.metrics import ( - balanced_accuracy_score, - f1_score, - cohen_kappa_score, - classification_report, - roc_auc_score, - average_precision_score, -) -from sklearn.metrics import mean_squared_error -from sklearn.metrics import adjusted_mutual_info_score, adjusted_rand_score +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.feature_selection import (SelectFromModel, mutual_info_classif, + mutual_info_regression) +from sklearn.metrics import (adjusted_mutual_info_score, adjusted_rand_score, + average_precision_score, balanced_accuracy_score, + classification_report, cohen_kappa_score, + f1_score, mean_squared_error, roc_auc_score) +from sklearn.model_selection import GridSearchCV, KFold, cross_val_score +from sklearn.svm import SVC, SVR from sklearn.utils import resample from sksurv.metrics import cumulative_dynamic_auc from sksurv.util import Surv - -from scipy.stats import pearsonr, linregress - - -from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor -from sklearn.svm import SVC, SVR -from sklearn.feature_selection import SelectFromModel -from sklearn.feature_selection import mutual_info_regression, mutual_info_classif -from sklearn.model_selection import KFold, cross_val_score, GridSearchCV +from tqdm import tqdm +from umap import UMAP try: from xgboost import XGBClassifier, XGBRegressor @@ -46,70 +37,30 @@ XGBClassifier = None XGBRegressor = None -from sksurv.ensemble import RandomSurvivalForest -from sksurv.metrics import concordance_index_censored - -from lifelines import KaplanMeierFitter -from lifelines.utils import concordance_index -from lifelines import CoxPHFitter -from lifelines.statistics import logrank_test, multivariate_logrank_test - -from plotnine import ( - ggplot, - aes, - geom_point, - geom_smooth, - geom_line, - geom_abline, - geom_step, - labs, - ggtitle, - annotate, - theme_minimal, - theme, - element_text, - scale_color_manual, - scale_color_gradient, - scale_color_brewer, - geom_errorbarh, - geom_text, - theme_bw, - theme, - element_blank, - scale_y_discrete, -) - -from sklearn.cluster import KMeans -from sklearn.metrics import silhouette_score -from sklearn.metrics.pairwise import euclidean_distances -import networkx as nx import community as community_louvain - -from sklearn.preprocessing import StandardScaler -import ot - +import matplotlib.pyplot as plt +import networkx as nx # imports import numpy as np +import ot import pandas as pd -import matplotlib.pyplot as plt -from plotnine import ( - ggplot, - aes, - geom_point, - geom_step, - labs, - ggtitle, - annotate, - theme_minimal, - theme, - element_text, - scale_color_manual, - scale_color_gradient, -) +from lifelines import CoxPHFitter, KaplanMeierFitter +from lifelines.statistics import logrank_test, multivariate_logrank_test +from lifelines.utils import concordance_index +from plotnine import (aes, annotate, element_blank, element_text, geom_abline, + geom_errorbarh, geom_line, geom_point, geom_smooth, + geom_step, geom_text, ggplot, ggtitle, labs, + scale_color_brewer, scale_color_gradient, + scale_color_manual, scale_y_discrete, theme, theme_bw, + theme_minimal) +from sklearn.cluster import KMeans from sklearn.decomposition import PCA +from sklearn.metrics import silhouette_score +from sklearn.metrics.pairwise import euclidean_distances +from sklearn.preprocessing import StandardScaler +from sksurv.ensemble import RandomSurvivalForest +from sksurv.metrics import concordance_index_censored from umap import UMAP -from lifelines import KaplanMeierFitter -from lifelines.statistics import logrank_test, multivariate_logrank_test def _labels_to_1d(labels): @@ -344,7 +295,7 @@ def plot_scatter(true_values, predicted_values): return plot -from scipy.stats import mannwhitneyu, kruskal +from scipy.stats import kruskal, mannwhitneyu def plot_boxplot( @@ -532,9 +483,9 @@ def evaluate_classifier(y_true, y_probs, print_report=False): } -from sklearn.metrics import roc_curve, roc_auc_score +from sklearn.metrics import (average_precision_score, precision_recall_curve, + roc_auc_score, roc_curve) from sklearn.preprocessing import label_binarize -from sklearn.metrics import precision_recall_curve, average_precision_score def plot_roc_curves(y_true, y_probs): @@ -1910,9 +1861,10 @@ def reciprocal_pca_mnn( import tarfile -import requests from io import StringIO +import requests + class CBioPortalData: def __init__(self, study_id, base_url="https://datahub.assets.cbioportal.org"): diff --git a/tests/test_mps_device.py b/tests/test_mps_device.py index 02c494a1..15f118df 100644 --- a/tests/test_mps_device.py +++ b/tests/test_mps_device.py @@ -1,5 +1,6 @@ -import torch import pytest +import torch + from flexynesis.utils import get_optimal_device, to_device_safe From 0e8feb87a98858c939fab5215c767419a484c76a Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 20:56:56 +0200 Subject: [PATCH 03/32] remove unused modules --- flexynesis/__main__.py | 17 +++-------------- flexynesis/data.py | 6 ++---- flexynesis/generate_coexpression_network.py | 9 ++++----- flexynesis/inference.py | 1 - flexynesis/main.py | 5 +---- flexynesis/models/crossmodal_pred.py | 4 +--- flexynesis/models/direct_pred.py | 7 +------ flexynesis/models/gnn_early.py | 2 +- flexynesis/models/supervised_vae.py | 4 +--- flexynesis/models/triplet_encoder.py | 3 +-- flexynesis/modules.py | 2 +- flexynesis/utils.py | 17 ++++------------- tests/unit/test_smoke.py | 2 +- 13 files changed, 21 insertions(+), 58 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 4cc6ebc6..2b92fe44 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -1,15 +1,9 @@ import argparse -import json import os -import random import sys import time -import tracemalloc import warnings -import psutil -import yaml - from . import __version__ os.environ["OMP_NUM_THREADS"] = "1" @@ -739,7 +733,7 @@ def main(): if args.pretrained_model and args.artifacts and args.data_path_test: import torch - from .utils import create_device_from_string, get_optimal_device + from .utils import get_optimal_device # quick existence checks if not os.path.exists(args.pretrained_model): @@ -866,8 +860,7 @@ def main(): # Continue to evaluation section (skip training) # ------------- Heavy imports only when training ------------- - from .utils import (create_device_from_string, - evaluate_baseline_performance, + from .utils import (evaluate_baseline_performance, evaluate_baseline_survival_performance, evaluate_wrapper, get_device_memory_info, get_optimal_device, get_predicted_labels) @@ -875,16 +868,12 @@ def main(): if not (args.pretrained_model and args.artifacts and args.data_path_test): import json import tracemalloc - from typing import NamedTuple - import lightning as pl import pandas as pd import psutil import torch - from lightning import seed_everything from safetensors.torch import save_file - import flexynesis # data + utils from .data import STRING, DataImporter, MultiOmicDatasetNW @@ -1047,7 +1036,7 @@ def main(): # classical ML baselines if args.model_class == "XGBoost": try: - from xgboost import XGBClassifier + from xgboost import XGBClassifier # noqa: F401 except Exception: raise ImportError( "XGBoost is not available. On macOS, install the OpenMP runtime: brew install libomp" diff --git a/flexynesis/data.py b/flexynesis/data.py index 30174539..b319a76f 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -1,5 +1,4 @@ import os -import shutil from functools import reduce from itertools import chain from pathlib import Path @@ -10,11 +9,10 @@ from filelock import FileLock from platformdirs import user_cache_dir from sklearn.preprocessing import (MinMaxScaler, OrdinalEncoder, - PowerTransformer, StandardScaler) -from torch.utils.data import DataLoader, Dataset + StandardScaler) +from torch.utils.data import Dataset from torch_geometric.data import Dataset as PYGDataset from torch_geometric.data import download_url, extract_gz -from tqdm import tqdm from .feature_selection import filter_by_laplacian from .utils import create_covariate_matrix, get_variable_types diff --git a/flexynesis/generate_coexpression_network.py b/flexynesis/generate_coexpression_network.py index 4cdf3cd8..2b2acfb6 100644 --- a/flexynesis/generate_coexpression_network.py +++ b/flexynesis/generate_coexpression_network.py @@ -15,7 +15,7 @@ Input format: CSV/TSV file with genes as rows and samples as columns First column should be gene names/IDs - + Output format: CSV/TSV file with columns: GeneA, GeneB, Score Format matches input file extension (.csv or .tsv) @@ -25,7 +25,6 @@ import argparse import sys -import numpy as np import pandas as pd import torch from tqdm import tqdm @@ -271,7 +270,7 @@ def main(): python generate_coexpression_network.py \\ --input expression_data.csv \\ --output coexpression_network.csv - + # Use Pearson correlation with stricter threshold python generate_coexpression_network.py \\ --input expression_data.csv \\ @@ -279,7 +278,7 @@ def main(): --method pearson \\ --min_correlation 0.5 \\ --top_k 15 - + # More permissive network (more edges) python generate_coexpression_network.py \\ --input expression_data.csv \\ @@ -290,7 +289,7 @@ def main(): Input file format: CSV/TSV file with genes as rows, samples as columns First column should contain gene names/IDs - + Example: ,Sample1,Sample2,Sample3 TP53,5.2,6.1,5.8 diff --git a/flexynesis/inference.py b/flexynesis/inference.py index 7b998d6c..35fd0b0e 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -7,7 +7,6 @@ import os from types import SimpleNamespace -import joblib import numpy as np import torch from safetensors.torch import load_file diff --git a/flexynesis/main.py b/flexynesis/main.py index 391303bb..130a4921 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -7,19 +7,17 @@ import lightning as pl import numpy as np import torch -import torch_geometric import yaml from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar from lightning.pytorch.callbacks.progress.rich_progress import \ RichProgressBarTheme from skopt import Optimizer from skopt.space import Categorical, Integer, Real -from skopt.utils import use_named_args from torch.utils.data import DataLoader, random_split from tqdm import tqdm from .config import search_spaces -from .data import STRING, TripletMultiOmicDataset +from .data import TripletMultiOmicDataset torch.set_float32_matmul_precision("medium") @@ -448,7 +446,6 @@ def load_and_convert_config(self, config_path): import copy import logging -import random import numpy as np from sklearn.model_selection import KFold diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index 6dd6dcad..6ce742a2 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -5,13 +5,11 @@ import pandas as pd import torch from captum.attr import GradientShap, IntegratedGradients -from scipy import stats from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import DataLoader from ..modules import * -from ..utils import to_device_safe class CrossModalPred(pl.LightningModule): diff --git a/flexynesis/models/direct_pred.py b/flexynesis/models/direct_pred.py index 2e528b53..564a15ab 100644 --- a/flexynesis/models/direct_pred.py +++ b/flexynesis/models/direct_pred.py @@ -1,16 +1,11 @@ -import argparse -import os -from functools import reduce - import lightning as pl import numpy as np import pandas as pd import torch from captum.attr import GradientShap, IntegratedGradients -from scipy import stats from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import DataLoader from ..modules import * from ..utils import to_device_safe diff --git a/flexynesis/models/gnn_early.py b/flexynesis/models/gnn_early.py index a7a2e260..aad8435f 100644 --- a/flexynesis/models/gnn_early.py +++ b/flexynesis/models/gnn_early.py @@ -5,7 +5,7 @@ from captum.attr import GradientShap, IntegratedGradients from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split +from torch.utils.data import DataLoader from ..modules import MLP, cox_ph_loss, flexGCN from ..utils import to_device_safe diff --git a/flexynesis/models/supervised_vae.py b/flexynesis/models/supervised_vae.py index 58857c68..a0f8d919 100644 --- a/flexynesis/models/supervised_vae.py +++ b/flexynesis/models/supervised_vae.py @@ -6,13 +6,11 @@ import pandas as pd import torch from captum.attr import GradientShap, IntegratedGradients -from scipy import stats from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import DataLoader from ..modules import * -from ..utils import to_device_safe # Supervised Variational Auto-encoder that can train one or more layers of omics datasets diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index 0c964c40..11b796f9 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -8,11 +8,10 @@ from captum.attr import GradientShap, IntegratedGradients from torch import nn from torch.nn import functional as F -from torch.utils.data import DataLoader, Dataset, random_split +from torch.utils.data import DataLoader from ..data import TripletMultiOmicDataset from ..modules import * -from ..utils import to_device_safe class MultiTripletNetwork(pl.LightningModule): diff --git a/flexynesis/modules.py b/flexynesis/modules.py index 253cfc28..f84066fa 100644 --- a/flexynesis/modules.py +++ b/flexynesis/modules.py @@ -2,7 +2,7 @@ import torch from torch import nn -from torch_geometric.nn import GATConv, GCNConv, GraphConv, SAGEConv, aggr +from torch_geometric.nn import GATConv, GCNConv, GraphConv, SAGEConv __all__ = ["Encoder", "Decoder", "MLP", "flexGCN", "cox_ph_loss"] diff --git a/flexynesis/utils.py b/flexynesis/utils.py index ade47a40..af544a8c 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -1,19 +1,13 @@ -import logging -import math import os -import re import tarfile import warnings -from glob import glob -import matplotlib import matplotlib.pyplot as plt import numpy as np import pandas as pd import requests import seaborn as sns import torch -from lightning import seed_everything from scipy.stats import linregress, pearsonr from sklearn.decomposition import PCA from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor @@ -23,12 +17,10 @@ average_precision_score, balanced_accuracy_score, classification_report, cohen_kappa_score, f1_score, mean_squared_error, roc_auc_score) -from sklearn.model_selection import GridSearchCV, KFold, cross_val_score +from sklearn.model_selection import GridSearchCV, KFold from sklearn.svm import SVC, SVR -from sklearn.utils import resample from sksurv.metrics import cumulative_dynamic_auc from sksurv.util import Surv -from tqdm import tqdm from umap import UMAP try: @@ -47,11 +39,11 @@ from lifelines import CoxPHFitter, KaplanMeierFitter from lifelines.statistics import logrank_test, multivariate_logrank_test from lifelines.utils import concordance_index -from plotnine import (aes, annotate, element_blank, element_text, geom_abline, +from plotnine import (aes, annotate, element_text, geom_abline, geom_errorbarh, geom_line, geom_point, geom_smooth, geom_step, geom_text, ggplot, ggtitle, labs, - scale_color_brewer, scale_color_gradient, - scale_color_manual, scale_y_discrete, theme, theme_bw, + scale_color_gradient, + scale_color_manual, theme, theme_bw, theme_minimal) from sklearn.cluster import KMeans from sklearn.decomposition import PCA @@ -1861,7 +1853,6 @@ def reciprocal_pca_mnn( import tarfile -from io import StringIO import requests diff --git a/tests/unit/test_smoke.py b/tests/unit/test_smoke.py index 701b7349..bba7119e 100644 --- a/tests/unit/test_smoke.py +++ b/tests/unit/test_smoke.py @@ -1,2 +1,2 @@ def test_smoke(): - import flexynesis + import flexynesis # noqa: F401 From 7ca24cdb9616926432ff05823c7697dbe798c048 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 20:57:37 +0200 Subject: [PATCH 04/32] run black again --- flexynesis/__main__.py | 22 ++++++++++------ flexynesis/data.py | 6 ++--- flexynesis/inference.py | 3 +-- flexynesis/main.py | 3 +-- flexynesis/utils.py | 57 ++++++++++++++++++++++++++++++---------- tests/unit/test_smoke.py | 2 +- 6 files changed, 62 insertions(+), 31 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 2b92fe44..6dbdd415 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -860,10 +860,14 @@ def main(): # Continue to evaluation section (skip training) # ------------- Heavy imports only when training ------------- - from .utils import (evaluate_baseline_performance, - evaluate_baseline_survival_performance, - evaluate_wrapper, get_device_memory_info, - get_optimal_device, get_predicted_labels) + from .utils import ( + evaluate_baseline_performance, + evaluate_baseline_survival_performance, + evaluate_wrapper, + get_device_memory_info, + get_optimal_device, + get_predicted_labels, + ) if not (args.pretrained_model and args.artifacts and args.data_path_test): import json @@ -874,11 +878,11 @@ def main(): import torch from safetensors.torch import save_file - # data + utils from .data import STRING, DataImporter, MultiOmicDatasetNW from .main import FineTuner, HyperparameterTuning from .models.crossmodal_pred import CrossModalPred + # models from .models.direct_pred import DirectPred from .models.gnn_early import GNN @@ -1477,9 +1481,11 @@ def main(): elif args.safetensors: import numpy as np - from sklearn.preprocessing import (LabelEncoder, - OrdinalEncoder, - StandardScaler) + from sklearn.preprocessing import ( + LabelEncoder, + OrdinalEncoder, + StandardScaler, + ) json_ready = { "schema_version": artifacts["schema_version"], diff --git a/flexynesis/data.py b/flexynesis/data.py index b319a76f..29ceeef3 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -8,8 +8,7 @@ import torch from filelock import FileLock from platformdirs import user_cache_dir -from sklearn.preprocessing import (MinMaxScaler, OrdinalEncoder, - StandardScaler) +from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder, StandardScaler from torch.utils.data import Dataset from torch_geometric.data import Dataset as PYGDataset from torch_geometric.data import download_url, extract_gz @@ -787,8 +786,7 @@ def import_data(self): # Create covariates matrix if needed if "covariates" in self.modalities and labels_df is not None: - from flexynesis.utils import (create_covariate_matrix, - get_variable_types) + from flexynesis.utils import create_covariate_matrix, get_variable_types covariate_vars = self.artifacts.get("covariate_vars", []) if covariate_vars: diff --git a/flexynesis/inference.py b/flexynesis/inference.py index 35fd0b0e..e40975ff 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -178,8 +178,7 @@ def check_file_type(file_path): def _deserialize_json_artifacts(artifacts): import numpy as np - from sklearn.preprocessing import (LabelEncoder, OrdinalEncoder, - StandardScaler) + from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, StandardScaler # Rebuild sklearn objects expected by inference code. deserialized = dict(artifacts) diff --git a/flexynesis/main.py b/flexynesis/main.py index 130a4921..2bfcdb8f 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -9,8 +9,7 @@ import torch import yaml from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar -from lightning.pytorch.callbacks.progress.rich_progress import \ - RichProgressBarTheme +from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme from skopt import Optimizer from skopt.space import Categorical, Integer, Real from torch.utils.data import DataLoader, random_split diff --git a/flexynesis/utils.py b/flexynesis/utils.py index af544a8c..6c272eb6 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -11,12 +11,22 @@ from scipy.stats import linregress, pearsonr from sklearn.decomposition import PCA from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor -from sklearn.feature_selection import (SelectFromModel, mutual_info_classif, - mutual_info_regression) -from sklearn.metrics import (adjusted_mutual_info_score, adjusted_rand_score, - average_precision_score, balanced_accuracy_score, - classification_report, cohen_kappa_score, - f1_score, mean_squared_error, roc_auc_score) +from sklearn.feature_selection import ( + SelectFromModel, + mutual_info_classif, + mutual_info_regression, +) +from sklearn.metrics import ( + adjusted_mutual_info_score, + adjusted_rand_score, + average_precision_score, + balanced_accuracy_score, + classification_report, + cohen_kappa_score, + f1_score, + mean_squared_error, + roc_auc_score, +) from sklearn.model_selection import GridSearchCV, KFold from sklearn.svm import SVC, SVR from sksurv.metrics import cumulative_dynamic_auc @@ -32,6 +42,7 @@ import community as community_louvain import matplotlib.pyplot as plt import networkx as nx + # imports import numpy as np import ot @@ -39,12 +50,26 @@ from lifelines import CoxPHFitter, KaplanMeierFitter from lifelines.statistics import logrank_test, multivariate_logrank_test from lifelines.utils import concordance_index -from plotnine import (aes, annotate, element_text, geom_abline, - geom_errorbarh, geom_line, geom_point, geom_smooth, - geom_step, geom_text, ggplot, ggtitle, labs, - scale_color_gradient, - scale_color_manual, theme, theme_bw, - theme_minimal) +from plotnine import ( + aes, + annotate, + element_text, + geom_abline, + geom_errorbarh, + geom_line, + geom_point, + geom_smooth, + geom_step, + geom_text, + ggplot, + ggtitle, + labs, + scale_color_gradient, + scale_color_manual, + theme, + theme_bw, + theme_minimal, +) from sklearn.cluster import KMeans from sklearn.decomposition import PCA from sklearn.metrics import silhouette_score @@ -475,8 +500,12 @@ def evaluate_classifier(y_true, y_probs, print_report=False): } -from sklearn.metrics import (average_precision_score, precision_recall_curve, - roc_auc_score, roc_curve) +from sklearn.metrics import ( + average_precision_score, + precision_recall_curve, + roc_auc_score, + roc_curve, +) from sklearn.preprocessing import label_binarize diff --git a/tests/unit/test_smoke.py b/tests/unit/test_smoke.py index bba7119e..0bcee2ab 100644 --- a/tests/unit/test_smoke.py +++ b/tests/unit/test_smoke.py @@ -1,2 +1,2 @@ def test_smoke(): - import flexynesis # noqa: F401 + import flexynesis # noqa: F401 From aa8bb8c42e9333cf8caf4dfd5a9122b456f4ae3e Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 21:13:53 +0200 Subject: [PATCH 05/32] use black to reformat lines to have length < 79 --- flexynesis/__main__.py | 222 +++++++++++---- flexynesis/data.py | 295 +++++++++++++++----- flexynesis/feature_selection.py | 20 +- flexynesis/generate_coexpression_network.py | 25 +- flexynesis/inference.py | 43 ++- flexynesis/main.py | 77 +++-- flexynesis/models/crossmodal_pred.py | 49 +++- flexynesis/models/direct_pred.py | 48 +++- flexynesis/models/gnn_early.py | 50 +++- flexynesis/models/supervised_vae.py | 49 +++- flexynesis/models/triplet_encoder.py | 72 +++-- flexynesis/modules.py | 15 +- flexynesis/utils.py | 173 +++++++++--- tests/test_mps_device.py | 21 +- 14 files changed, 872 insertions(+), 287 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 6dbdd415..7f79f5a2 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -17,7 +17,9 @@ def print_test_installation(): ) print(" tar -xzvf dataset1.tgz") print() - print(" # Test the installation (should finish within a minute on a typical CPU)") + print( + " # Test the installation (should finish within a minute on a typical CPU)" + ) print( " flexynesis --data_path dataset1 --model_class DirectPred --target_variables Erlotinib --hpo_iter 1 --features_top_percentile 5 --data_types gex,cnv" ) @@ -39,7 +41,9 @@ def print_help(): print( " --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" ) - print(" (Required) The kind of model class to instantiate") + print( + " (Required) The kind of model class to instantiate" + ) print(" --data_types DATA_TYPES") print( " (Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'" @@ -85,9 +89,13 @@ def print_full_help(): " Use a saved .pth/.safetensors model for inference (skip training)" ) print(" --artifacts ARTIFACTS") - print(" Path to training-time artifacts .joblib or .json") + print( + " Path to training-time artifacts .joblib or .json" + ) print(" --data_path_test DATA_PATH_TEST") - print(" Folder with test-only dataset for inference") + print( + " Folder with test-only dataset for inference" + ) print(" --join_key JOIN_KEY Column name in 'clin.csv' for sample IDs") # --- existing flags (keep full list) --- @@ -99,7 +107,9 @@ def print_full_help(): print( " --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" ) - print(" (Required) The kind of model class to instantiate") + print( + " (Required) The kind of model class to instantiate" + ) print(" --gnn_conv_type {GC,GCN,SAGE}") print( " If model_class is set to GNN, choose which graph convolution type to use" @@ -174,7 +184,9 @@ def print_full_help(): print( " --outdir OUTDIR Path to the output folder to save the model outputs (default: current working directory)" ) - print(" --prefix PREFIX Job prefix to use for output files (default: 'job')") + print( + " --prefix PREFIX Job prefix to use for output files (default: 'job')" + ) print(" --log_transform {True,False}") print( " whether to apply log-transformation to input data matrices (default: False)" @@ -215,7 +227,9 @@ def print_full_help(): print( " --use_gpu (Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available." ) - print(" --feature_importance_method {IntegratedGradients,GradientShap,Both}") + print( + " --feature_importance_method {IntegratedGradients,GradientShap,Both}" + ) print( " Choose feature importance score method (default: IntegratedGradients)" ) @@ -231,7 +245,9 @@ def print_full_help(): print( " Path to user-provided gene-gene interaction network file." ) - print(" Must have at least 3 columns: GeneA, GeneB, Score.") + print( + " Must have at least 3 columns: GeneA, GeneB, Score." + ) print( " If provided, this will be used instead of STRING DB." ) @@ -576,7 +592,10 @@ def main(): help="Path to the output folder to save the model outputs", ) parser.add_argument( - "--prefix", type=str, default="job", help="Job prefix to use for output files" + "--prefix", + type=str, + default="job", + help="Job prefix to use for output files", ) parser.add_argument( "--log_transform", @@ -658,7 +677,10 @@ def main(): ) # GNN args. parser.add_argument( - "--string_organism", type=int, default=9606, help="STRING DB organism id." + "--string_organism", + type=int, + default=9606, + help="STRING DB organism id.", ) parser.add_argument( "--string_node_name", @@ -746,7 +768,8 @@ def main(): # Handle device selection for inference (same logic as training) if args.use_gpu: warnings.warn( - "--use_gpu is deprecated. Use --device instead.", DeprecationWarning + "--use_gpu is deprecated. Use --device instead.", + DeprecationWarning, ) if args.device != "auto": device_preference = args.device @@ -782,7 +805,9 @@ def main(): config_path = model_base + "_config.json" if not os.path.exists(config_path): # Try alongside the safetensors file with standard naming - base = args.pretrained_model.replace(".final_model.safetensors", "") + base = args.pretrained_model.replace( + ".final_model.safetensors", "" + ) config_path = base + ".final_model_config.json" if not os.path.exists(config_path): raise FileNotFoundError( @@ -799,17 +824,23 @@ def main(): # Standard .pth load — robust across PyTorch versions try: model = torch.load( - args.pretrained_model, map_location=device, weights_only=False + args.pretrained_model, + map_location=device, + weights_only=False, ) except TypeError: model = torch.load(args.pretrained_model, map_location=device) except Exception: try: model = torch.load( - args.pretrained_model, map_location=device, weights_only=True + args.pretrained_model, + map_location=device, + weights_only=True, ) except TypeError: - model = torch.load(args.pretrained_model, map_location=device) + model = torch.load( + args.pretrained_model, map_location=device + ) # Extract model class name for metrics args.model_class = model.__class__.__name__ @@ -829,7 +860,9 @@ def main(): # Convert to GNN dataset if needed if args.model_class == "GNN": - print("[INFO] Overlaying the dataset with network data from STRINGDB") + print( + "[INFO] Overlaying the dataset with network data from STRINGDB" + ) from .data import MultiOmicDatasetNW from .main import STRING @@ -837,10 +870,14 @@ def main(): string_organism = importer.artifacts.get( "string_organism", 9606 ) # default human - string_node_name = importer.artifacts.get("string_node_name", "HGNC") + string_node_name = importer.artifacts.get( + "string_node_name", "HGNC" + ) # Load STRING network obj = STRING( - os.path.join(args.data_path_test, "_".join(["processed", args.prefix])), + os.path.join( + args.data_path_test, "_".join(["processed", args.prefix]) + ), string_organism, string_node_name, ) @@ -856,7 +893,9 @@ def main(): # Move dataset to same device as model if hasattr(test_dataset, "to_device"): test_dataset.to_device(device) - print(f"[INFO] Test dataset loaded: {len(test_dataset.samples)} samples") + print( + f"[INFO] Test dataset loaded: {len(test_dataset.samples)} samples" + ) # Continue to evaluation section (skip training) # ------------- Heavy imports only when training ------------- @@ -915,7 +954,8 @@ def main(): # Support legacy --use_gpu flag for backward compatibility if args.use_gpu: warnings.warn( - "--use_gpu is deprecated. Use --device instead.", DeprecationWarning + "--use_gpu is deprecated. Use --device instead.", + DeprecationWarning, ) # If --device is not explicitly set (still at default auto), let auto-detection handle it if args.device != "auto": @@ -939,7 +979,9 @@ def main(): memory_info = get_device_memory_info(device_str) print(f"[INFO] Device name: {memory_info['device_name']}") if device_str == "cuda": - print(f"[INFO] Available CUDA devices: {memory_info['device_count']}") + print( + f"[INFO] Available CUDA devices: {memory_info['device_count']}" + ) # gnn if args.model_class == "GNN": @@ -978,12 +1020,17 @@ def main(): f"Input --data_path doesn't exist at: {args.data_path}" ) if not os.path.exists(args.outdir): - raise FileNotFoundError(f"Path to --outdir doesn't exist at: {args.outdir}") + raise FileNotFoundError( + f"Path to --outdir doesn't exist at: {args.outdir}" + ) available_models = { "DirectPred": (DirectPred, "DirectPred"), "supervised_vae": (supervised_vae, "supervised_vae"), - "MultiTripletNetwork": (MultiTripletNetwork, "MultiTripletNetwork"), + "MultiTripletNetwork": ( + MultiTripletNetwork, + "MultiTripletNetwork", + ), "CrossModalPred": (CrossModalPred, "CrossModalPred"), "GNN": (GNN, "GNN"), "RandomForest": ("RandomForest", None), @@ -1002,7 +1049,9 @@ def main(): # covariates if args.covariates: - if args.model_class == "GNN": # Covariates not yet supported for GNNs + if ( + args.model_class == "GNN" + ): # Covariates not yet supported for GNNs warnings.warn( "\n\n!!! Covariates are currently not supported for GNN models, they will be ignored. !!!\n" ) @@ -1058,18 +1107,23 @@ def main(): n_jobs=args.threads, ) metrics.to_csv( - os.path.join(args.outdir, ".".join([args.prefix, "stats.csv"])), + os.path.join( + args.outdir, ".".join([args.prefix, "stats.csv"]) + ), header=True, index=False, ) predictions.to_csv( os.path.join( - args.outdir, ".".join([args.prefix, "predicted_labels.csv"]) + args.outdir, + ".".join([args.prefix, "predicted_labels.csv"]), ), header=True, index=False, ) - print(f"{args.model_class} evaluation complete. Results saved.") + print( + f"{args.model_class} evaluation complete. Results saved." + ) sys.exit(0) else: raise ValueError( @@ -1090,18 +1144,23 @@ def main(): n_jobs=int(args.threads), ) metrics.to_csv( - os.path.join(args.outdir, ".".join([args.prefix, "stats.csv"])), + os.path.join( + args.outdir, ".".join([args.prefix, "stats.csv"]) + ), header=True, index=False, ) predictions.to_csv( os.path.join( - args.outdir, ".".join([args.prefix, "predicted_labels.csv"]) + args.outdir, + ".".join([args.prefix, "predicted_labels.csv"]), ), header=True, index=False, ) - print(f"{args.model_class} evaluation complete. Results saved.") + print( + f"{args.model_class} evaluation complete. Results saved." + ) sys.exit(0) else: raise ValueError( @@ -1111,16 +1170,24 @@ def main(): if args.model_class == "GNN" and train_dataset is not None: # Check if user provided a custom graph if args.user_graph is not None: - print(f"[INFO] Using user-provided network from: {args.user_graph}") + print( + f"[INFO] Using user-provided network from: {args.user_graph}" + ) from flexynesis.data import read_user_graph graph_df = read_user_graph(args.user_graph) - print(f"[INFO] Loaded {len(graph_df)} interactions from user graph") + print( + f"[INFO] Loaded {len(graph_df)} interactions from user graph" + ) else: # Fallback to STRING DB (default behavior) - print("[INFO] No user graph provided. Using STRING DB network (default)") + print( + "[INFO] No user graph provided. Using STRING DB network (default)" + ) obj = STRING( - os.path.join(args.data_path, "_".join(["processed", args.prefix])), + os.path.join( + args.data_path, "_".join(["processed", args.prefix]) + ), args.string_organism, args.string_node_name, ) @@ -1144,7 +1211,8 @@ def main(): for key in feature_logs.keys(): feature_logs[key].to_csv( os.path.join( - args.outdir, ".".join([args.prefix, "feature_logs", key, "csv"]) + args.outdir, + ".".join([args.prefix, "feature_logs", key, "csv"]), ), header=True, index=False, @@ -1178,19 +1246,27 @@ def main(): # do a hyperparameter search training multiple models and get the best configuration t1 = time.time() - model, best_params = tuner.perform_tuning(hpo_patience=args.hpo_patience) + model, best_params = tuner.perform_tuning( + hpo_patience=args.hpo_patience + ) hpo_time = time.time() - t1 hpo_system_ram = process.memory_info().rss # if fine-tuning is enabled; fine tune the model on a portion of test samples if args.finetuning_samples > 0: finetuneSampleN = args.finetuning_samples - print("[INFO] Finetuning the model on ", finetuneSampleN, "test samples") + print( + "[INFO] Finetuning the model on ", + finetuneSampleN, + "test samples", + ) # split test dataset into finetuning and holdout datasets all_indices = range(len(test_dataset)) import random as _random - finetune_indices = _random.sample(list(all_indices), finetuneSampleN) + finetune_indices = _random.sample( + list(all_indices), finetuneSampleN + ) holdout_indices = list(set(all_indices) - set(finetune_indices)) finetune_dataset = test_dataset.subset(finetune_indices) holdout_dataset = test_dataset.subset(holdout_indices) @@ -1209,12 +1285,16 @@ def main(): if train_dataset is not None: embeddings_train = model.transform(train_dataset) embeddings_train.to_csv( - os.path.join(args.outdir, ".".join([args.prefix, "embeddings_train.csv"])), + os.path.join( + args.outdir, ".".join([args.prefix, "embeddings_train.csv"]) + ), header=True, ) embeddings_test = model.transform(test_dataset) embeddings_test.to_csv( - os.path.join(args.outdir, ".".join([args.prefix, "embeddings_test.csv"])), + os.path.join( + args.outdir, ".".join([args.prefix, "embeddings_test.csv"]) + ), header=True, ) @@ -1236,19 +1316,32 @@ def main(): ) for var in model.target_variables: model.compute_feature_importance( - train_dataset, var, steps_or_samples=25, method=explainer + train_dataset, + var, + steps_or_samples=25, + method=explainer, ) import pandas as pd df_imp = pd.concat( - [model.feature_importances[x] for x in model.target_variables], + [ + model.feature_importances[x] + for x in model.target_variables + ], ignore_index=True, ) df_imp["explainer"] = explainer df_imp.to_csv( os.path.join( args.outdir, - ".".join([args.prefix, "feature_importance", explainer, "csv"]), + ".".join( + [ + args.prefix, + "feature_importance", + explainer, + "csv", + ] + ), ), header=True, index=False, @@ -1275,10 +1368,15 @@ def main(): ) else: predicted_labels = get_predicted_labels( - model.predict(test_dataset), test_dataset, "test", args.model_class + model.predict(test_dataset), + test_dataset, + "test", + args.model_class, ) predicted_labels.to_csv( - os.path.join(args.outdir, ".".join([args.prefix, "predicted_labels.csv"])), + os.path.join( + args.outdir, ".".join([args.prefix, "predicted_labels.csv"]) + ), header=True, index=False, ) @@ -1315,7 +1413,8 @@ def main(): for layer in output_layers_test.keys(): output_layers_test[layer].to_csv( os.path.join( - args.outdir, ".".join([args.prefix, "test_decoded", layer, "csv"]) + args.outdir, + ".".join([args.prefix, "test_decoded", layer, "csv"]), ), header=True, ) @@ -1387,13 +1486,16 @@ def main(): if not args.safetensors: torch.save( model, - os.path.join(args.outdir, ".".join([args.prefix, "final_model.pth"])), + os.path.join( + args.outdir, ".".join([args.prefix, "final_model.pth"]) + ), ) else: save_file( model.state_dict(), os.path.join( - args.outdir, ".".join([args.prefix, "final_model.safetensors"]) + args.outdir, + ".".join([args.prefix, "final_model.safetensors"]), ), ) # save model config as JSON @@ -1449,7 +1551,9 @@ def main(): "," ), # Original modalities from CLI before concatenation "target_variables": ( - args.target_variables.split(",") if args.target_variables else [] + args.target_variables.split(",") + if args.target_variables + else [] ), "feature_lists": ( data_importer.train_features @@ -1457,7 +1561,9 @@ def main(): else {} ), "transforms": ( - data_importer.scalers if hasattr(data_importer, "scalers") else {} + data_importer.scalers + if hasattr(data_importer, "scalers") + else {} ), "label_encoders": ( data_importer.label_encoders @@ -1498,7 +1604,9 @@ def main(): "string_node_name": artifacts["string_node_name"], "feature_lists": { modality: list(features) - for modality, features in artifacts["feature_lists"].items() + for modality, features in artifacts[ + "feature_lists" + ].items() }, "transforms": {}, "label_encoders": {}, @@ -1525,7 +1633,9 @@ def main(): if hasattr(scaler, "var_"): scaler_dict["var"] = scaler.var_.tolist() if hasattr(scaler, "n_features_in_"): - scaler_dict["n_features_in"] = int(scaler.n_features_in_) + scaler_dict["n_features_in"] = int( + scaler.n_features_in_ + ) if hasattr(scaler, "feature_names_in_"): scaler_dict["feature_names_in"] = ( scaler.feature_names_in_.tolist() @@ -1552,7 +1662,9 @@ def main(): elif isinstance(encoder, OrdinalEncoder): encoder_dict = { "type": "OrdinalEncoder", - "categories": [cat.tolist() for cat in encoder.categories_], + "categories": [ + cat.tolist() for cat in encoder.categories_ + ], "handle_unknown": encoder.handle_unknown, "unknown_value": encoder.unknown_value, } @@ -1564,7 +1676,9 @@ def main(): else val ) if hasattr(encoder, "n_features_in_"): - encoder_dict["n_features_in"] = int(encoder.n_features_in_) + encoder_dict["n_features_in"] = int( + encoder.n_features_in_ + ) if hasattr(encoder, "feature_names_in_"): encoder_dict["feature_names_in"] = ( encoder.feature_names_in_.tolist() diff --git a/flexynesis/data.py b/flexynesis/data.py index 29ceeef3..3a7bbaf4 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -146,7 +146,9 @@ def get_user_features(self): """ if self.restrict_to_features is not None: if not os.path.isfile(self.restrict_to_features): - raise FileNotFoundError(f"File not found: {self.restrict_to_features}") + raise FileNotFoundError( + f"File not found: {self.restrict_to_features}" + ) try: with open(self.restrict_to_features, "r") as fp: # Read and process the file @@ -172,19 +174,27 @@ def import_data(self): test_dat = self.read_data(testing_path) if self.downsample > 0: - print("[INFO] Randomly drawing", self.downsample, "samples for training") + print( + "[INFO] Randomly drawing", + self.downsample, + "samples for training", + ) train_dat = self.subsample(train_dat, self.downsample) if self.restrict_to_features is not None: - train_dat = self.filter_by_features(train_dat, self.restrict_to_features) - test_dat = self.filter_by_features(test_dat, self.restrict_to_features) + train_dat = self.filter_by_features( + train_dat, self.restrict_to_features + ) + test_dat = self.filter_by_features( + test_dat, self.restrict_to_features + ) # check for any problems with the the input files self.validate_input_data(train_dat, test_dat) # cleanup uninformative features/samples, subset annotation data, do feature selection on training data - train_dat, train_ann, train_samples, train_features = self.process_data( - train_dat, split="train" + train_dat, train_ann, train_samples, train_features = ( + self.process_data(train_dat, split="train") ) test_dat, test_ann, test_samples, test_features = self.process_data( test_dat, split="test" @@ -201,8 +211,12 @@ def import_data(self): # Normalize the training data (for testing data, use normalisation factors # learned from training data to apply on test data (see fit = False) - train_dat = self.normalize_data(train_dat, scaler_type="standard", fit=True) - test_dat = self.normalize_data(test_dat, scaler_type="standard", fit=False) + train_dat = self.normalize_data( + train_dat, scaler_type="standard", fit=True + ) + test_dat = self.normalize_data( + test_dat, scaler_type="standard", fit=False + ) # if covariates are defined, create a covariate matrix and add to the dictionary of data matrices if self.covariates: @@ -220,8 +234,12 @@ def import_data(self): train_dat, test_dat = self.harmonize(train_dat, test_dat) # encode the variable annotations, convert data matrices and annotations pytorch datasets - training_dataset = self.get_torch_dataset(train_dat, train_ann, train_samples) - testing_dataset = self.get_torch_dataset(test_dat, test_ann, test_samples) + training_dataset = self.get_torch_dataset( + train_dat, train_ann, train_samples + ) + testing_dataset = self.get_torch_dataset( + test_dat, test_ann, test_samples + ) # for early fusion, concatenate all data matrices and feature lists if self.concatenate: @@ -234,7 +252,9 @@ def import_data(self): } training_dataset.features = { "all": list( - chain(*[training_dataset.features[x] for x in modality_order]) + chain( + *[training_dataset.features[x] for x in modality_order] + ) ) } @@ -245,13 +265,18 @@ def import_data(self): } testing_dataset.features = { "all": list( - chain(*[testing_dataset.features[x] for x in modality_order]) + chain( + *[testing_dataset.features[x] for x in modality_order] + ) ) } # Save final feature lists AFTER concatenation (for inference mode) self.train_features = training_dataset.features.copy() - print("[INFO] Training Data Stats: ", training_dataset.get_dataset_stats()) + print( + "[INFO] Training Data Stats: ", + training_dataset.get_dataset_stats(), + ) print("[INFO] Test Data Stats: ", testing_dataset.get_dataset_stats()) print("[INFO] Merging Feature Logs...") logs = self.feature_logs @@ -316,7 +341,11 @@ def filter_by_features(self, dat, features): subset train/test data to only include those features """ dat_filtered = { - key: df if key == "clin" else df.loc[df.index.intersection(features)] + key: ( + df + if key == "clin" + else df.loc[df.index.intersection(features)] + ) for key, df in dat.items() } @@ -430,7 +459,9 @@ def cleanup_data(self, df_dict): for key in cleaned_dfs.keys(): original_samples_count = cleaned_dfs[key].shape[1] cleaned_dfs[key] = cleaned_dfs[key].loc[:, common_mask] - removed_samples_count = original_samples_count - cleaned_dfs[key].shape[1] + removed_samples_count = ( + original_samples_count - cleaned_dfs[key].shape[1] + ) print( f"[INFO] DataFrame {key} - Removed {removed_samples_count} samples ({removed_samples_count / original_samples_count * 100:.2f}%)." ) @@ -455,7 +486,10 @@ def get_labels(self, dat, ann): # unsupervised feature selection using laplacian score and correlation filters (optional) def select_features(self, dat): counts = { - x: max(int(dat[x].shape[0] * self.top_percentile / 100), self.min_features) + x: max( + int(dat[x].shape[0] * self.top_percentile / 100), + self.min_features, + ) for x in dat.keys() } dat_filtered = {} @@ -468,7 +502,9 @@ def select_features(self, dat): topN=counts[layer], correlation_threshold=self.correlation_threshold, ) - dat_filtered[layer] = X_filt.T # transpose after laplacian filtering again + dat_filtered[layer] = ( + X_filt.T + ) # transpose after laplacian filtering again # Features will be stored after concatenation feature_logs[layer] = log_df # update main feature logs with events from this function @@ -476,7 +512,9 @@ def select_features(self, dat): return dat_filtered def harmonize(self, dat1, dat2): - print("\n[INFO] ----------------- Harmonizing Data Sets ----------------- ") + print( + "\n[INFO] ----------------- Harmonizing Data Sets ----------------- " + ) # common data layers common_layers = dat1.keys() & dat2.keys() # Get common features @@ -486,7 +524,9 @@ def harmonize(self, dat1, dat2): # Subset both datasets to only include common features dat1 = {x: dat1[x].loc[common_features[x]] for x in common_layers} dat2 = {x: dat2[x].loc[common_features[x]] for x in common_layers} - print("\n[INFO] ----------------- Finished Harmonizing ----------------- ") + print( + "\n[INFO] ----------------- Finished Harmonizing ----------------- " + ) return dat1, dat2 @@ -501,11 +541,17 @@ def normalize_data(self, data, scaler_type="standard", fit=True): # while scaling methods assume features to be on the columns. if fit: if scaler_type == "standard": - self.scalers = {x: StandardScaler().fit(data[x].T) for x in data.keys()} + self.scalers = { + x: StandardScaler().fit(data[x].T) for x in data.keys() + } elif scaler_type == "min_max": - self.scalers = {x: MinMaxScaler().fit(data[x].T) for x in data.keys()} + self.scalers = { + x: MinMaxScaler().fit(data[x].T) for x in data.keys() + } else: - raise ValueError("Invalid scaler_type. Choose 'standard' or 'min_max'.") + raise ValueError( + "Invalid scaler_type. Choose 'standard' or 'min_max'." + ) normalized_data = { x: pd.DataFrame( @@ -520,7 +566,9 @@ def normalize_data(self, data, scaler_type="standard", fit=True): def get_torch_dataset(self, dat, ann, samples): features = {x: dat[x].index for x in dat.keys()} - dat = {x: torch.from_numpy(np.array(dat[x].T)).float() for x in dat.keys()} + dat = { + x: torch.from_numpy(np.array(dat[x].T)).float() for x in dat.keys() + } ann, variable_types, label_mappings = self.encode_labels(ann) @@ -554,23 +602,28 @@ def encode_column(series): # NEW: Store encoder for inference mode self.label_encoders[series.name] = self.encoders[series.name] else: - encoded_series = self.encoders[series.name].transform(series.to_frame()) + encoded_series = self.encoders[series.name].transform( + series.to_frame() + ) # also save label mappings label_mappings[series.name] = { int(code): label - for code, label in enumerate(self.encoders[series.name].categories_[0]) + for code, label in enumerate( + self.encoders[series.name].categories_[0] + ) } return encoded_series.ravel() # Select only the categorical columns - df_categorical = df.select_dtypes(include=["object", "category"]).apply( - encode_column - ) + df_categorical = df.select_dtypes( + include=["object", "category"] + ).apply(encode_column) # Combine the encoded categorical data with the numerical data df_encoded = pd.concat( - [df.select_dtypes(exclude=["object", "category"]), df_categorical], axis=1 + [df.select_dtypes(exclude=["object", "category"]), df_categorical], + axis=1, ) # Store the variable types @@ -578,7 +631,9 @@ def encode_column(series): variable_types.update( { col: "numerical" - for col in df.select_dtypes(exclude=["object", "category"]).columns + for col in df.select_dtypes( + exclude=["object", "category"] + ).columns } ) @@ -596,7 +651,9 @@ def check_rownames(dat, split): for file_name, df in dat.items(): if not df.index.is_unique: identifier_type = ( - "Sample labels" if file_name == "clin" else "Feature names" + "Sample labels" + if file_name == "clin" + else "Feature names" ) errors.append( f"Error in {split}/{file_name}.csv: {identifier_type} in the first column must be unique." @@ -623,7 +680,9 @@ def check_common_features(train_dat, test_dat): if file_name != "clin" and file_name in test_dat: train_features = set(train_dat[file_name].index) test_features = set(test_dat[file_name].index) - common_features = train_features.intersection(test_features) + common_features = train_features.intersection( + test_features + ) if not common_features: errors.append( f"Error: No common features found between train/{file_name}.csv and test/{file_name}.csv." @@ -647,7 +706,9 @@ def check_common_features(train_dat, test_dat): print("[INFO] Found problems with the input data:\n") for i, error in enumerate(errors, 1): print(f"[ERROR] {i}. {error}") - raise Exception("[ERROR] Please correct the above errors and try again.") + raise Exception( + "[ERROR] Please correct the above errors and try again." + ) if not warnings and not errors: print("[INFO] Data structure is valid with no errors or warnings.") @@ -726,7 +787,9 @@ def import_data(self): else: # Filter out 'covariates' from modalities_to_load # Covariates will be created from clinical data later - modalities_to_load = [m for m in self.modalities if m != "covariates"] + modalities_to_load = [ + m for m in self.modalities if m != "covariates" + ] # Load each modality (skip 'covariates' - it's in clin.csv) for modality in modalities_to_load: @@ -734,7 +797,9 @@ def import_data(self): continue # Covariates are in clin.csv, not a separate file file_path = os.path.join(self.test_data_path, f"{modality}.csv") if not os.path.exists(file_path): - raise FileNotFoundError(f"[ERROR] Required file not found: {file_path}") + raise FileNotFoundError( + f"[ERROR] Required file not found: {file_path}" + ) df = pd.read_csv(file_path, index_col=0) # Transpose if needed: data files have features as rows, samples as columns @@ -786,12 +851,17 @@ def import_data(self): # Create covariates matrix if needed if "covariates" in self.modalities and labels_df is not None: - from flexynesis.utils import create_covariate_matrix, get_variable_types + from flexynesis.utils import ( + create_covariate_matrix, + get_variable_types, + ) covariate_vars = self.artifacts.get("covariate_vars", []) if covariate_vars: if self.verbose: - print(f"[INFO] Creating covariate matrix for: {covariate_vars}") + print( + f"[INFO] Creating covariate matrix for: {covariate_vars}" + ) variable_types_clin = get_variable_types(labels_df) covariates_df = create_covariate_matrix( covariate_vars, variable_types_clin, labels_df @@ -821,7 +891,9 @@ def import_data(self): df_reordered = pd.DataFrame( test_data[modality].numpy(), index=old_samples ).loc[common_samples] - test_data[modality] = torch.from_numpy(df_reordered.values).float() + test_data[modality] = torch.from_numpy( + df_reordered.values + ).float() samples = common_samples @@ -844,23 +916,30 @@ def import_data(self): else: # OrdinalEncoder if valid_mask.sum() > 0: encoded[valid_mask] = encoder.transform( - labels_df[col][valid_mask].values.reshape(-1, 1) + labels_df[col][valid_mask].values.reshape( + -1, 1 + ) ).ravel() ann_dict[col] = torch.from_numpy(encoded) variable_types[col] = "categorical" label_mappings[col] = { - int(c): l for c, l in enumerate(encoder.categories_[0]) + int(c): l + for c, l in enumerate(encoder.categories_[0]) } label_mappings[col][-1] = "Unknown" # For missing values else: - ann_dict[col] = torch.from_numpy(labels_df[col].values).float() + ann_dict[col] = torch.from_numpy( + labels_df[col].values + ).float() variable_types[col] = "numerical" # Create features dict # For early fusion, get features from scalers since feature_lists only has 'all' if self.modalities == ["all"]: - modalities_for_features = self.artifacts.get("original_modalities", []) + modalities_for_features = self.artifacts.get( + "original_modalities", [] + ) # Get features from scalers for each modality features = { modality: list(self.scalers[modality].feature_names_in_) @@ -868,7 +947,8 @@ def import_data(self): } else: features = { - modality: self.feature_names[modality] for modality in self.modalities + modality: self.feature_names[modality] + for modality in self.modalities } # CRITICAL: Reorder test_data dict to match self.modalities order (model expects specific order) @@ -897,20 +977,28 @@ def import_data(self): ) # Concatenate the data tensors - concatenated_data = torch.cat([test_data[x] for x in modality_order], dim=1) + concatenated_data = torch.cat( + [test_data[x] for x in modality_order], dim=1 + ) # Chain features in the same order - all_features = list(chain(*[dataset.features[x] for x in modality_order])) + all_features = list( + chain(*[dataset.features[x] for x in modality_order]) + ) # Filter to expected features from artifacts expected_all_features = self.feature_names["all"] feature_indices = [ - i for i, f in enumerate(all_features) if f in expected_all_features + i + for i, f in enumerate(all_features) + if f in expected_all_features ] # Update dataset with concatenated data dataset.dat = {"all": concatenated_data[:, feature_indices]} - dataset.features = {"all": [all_features[i] for i in feature_indices]} + dataset.features = { + "all": [all_features[i] for i in feature_indices] + } return dataset @@ -1004,14 +1092,17 @@ def get_feature_subset(self, feature_df): A pandas DataFrame that concatenates the data matrices for the specified features from all layers. """ # Convert the DataFrame to a dictionary - feature_dict = feature_df.groupby("layer")["name"].apply(list).to_dict() + feature_dict = ( + feature_df.groupby("layer")["name"].apply(list).to_dict() + ) dfs = [] for layer, features in feature_dict.items(): if layer in self.dat: # Create a dictionary to look up indices by feature name for each layer feature_index_dict = { - feature: i for i, feature in enumerate(self.features[layer]) + feature: i + for i, feature in enumerate(self.features[layer]) } # Get the indices for the requested features indices = [ @@ -1074,7 +1165,10 @@ def __getitem__(self, index): real_index = self.valid_indices[index] # get anchor sample and its label - anchor, y_dict = self.dataset[real_index][0], self.dataset[real_index][1] + anchor, y_dict = ( + self.dataset[real_index][0], + self.dataset[real_index][1], + ) # choose another sample with same label label = y_dict[self.main_var].item() @@ -1087,7 +1181,9 @@ def __getitem__(self, index): import random negative_label = random.choice(list(self.labels_set - set([label]))) - negative_index = np.random.choice(self.label_to_indices[negative_label]) + negative_index = np.random.choice( + self.label_to_indices[negative_label] + ) pos = self.dataset[positive_index][0] # positive example neg = self.dataset[negative_index][0] # negative example @@ -1119,7 +1215,9 @@ def __init__(self, multiomic_dataset, interaction_df, modality_order=None): self.multiomic_dataset = multiomic_dataset self.interaction_df = interaction_df self.modality_order = ( - modality_order if modality_order else sorted(multiomic_dataset.dat.keys()) + modality_order + if modality_order + else sorted(multiomic_dataset.dat.keys()) ) # Compute union of features in the data matrices that also appear in the network @@ -1145,7 +1243,10 @@ def __init__(self, multiomic_dataset, interaction_df, modality_order=None): def find_union_features(self): # Find the union of all features in the multiomic dataset all_omic_features = set().union( - *(set(features) for features in self.multiomic_dataset.features.values()) + *( + set(features) + for features in self.multiomic_dataset.features.values() + ) ) # Find the union of proteins involved in interactions interaction_genes = set(self.interaction_df["protein1"]).union( @@ -1161,7 +1262,10 @@ def create_edge_index(self): & (self.interaction_df["protein2"].isin(self.common_features)) ] edge_list = [ - (self.gene_to_index[row["protein1"]], self.gene_to_index[row["protein2"]]) + ( + self.gene_to_index[row["protein1"]], + self.gene_to_index[row["protein2"]], + ) for index, row in filtered_df.iterrows() ] return torch.tensor(edge_list, dtype=torch.long).t() @@ -1171,7 +1275,9 @@ def precompute_node_features(self): num_nodes = len(self.common_features) num_data_types = len(self.multiomic_dataset.dat) all_features = torch.full( - (num_samples, num_nodes, num_data_types), float("nan"), dtype=torch.float + (num_samples, num_nodes, num_data_types), + float("nan"), + dtype=torch.float, ) # CRITICAL: Use sorted keys to ensure consistent order across training/inference @@ -1192,7 +1298,9 @@ def precompute_node_features(self): ) # Fill in the available data - all_features[:, feature_positions, i] = data_matrix[:, valid_indices] + all_features[:, feature_positions, i] = data_matrix[ + :, valid_indices + ] # Precompute medians for all data types, ignoring NaN values medians = torch.nanmedian( @@ -1215,7 +1323,8 @@ def subset(self, indices): def __getitem__(self, idx): node_features_tensor = self.node_features_tensor[idx] y_dict = { - target_name: self.labels[target_name][idx] for target_name in self.labels + target_name: self.labels[target_name][idx] + for target_name in self.labels } return node_features_tensor, y_dict, self.samples[idx] @@ -1232,7 +1341,9 @@ def print_stats(self): # Calculate degree for each node degrees = torch.zeros(num_nodes, dtype=torch.long) - degrees.index_add_(0, self.edge_index[0], torch.ones_like(self.edge_index[0])) + degrees.index_add_( + 0, self.edge_index[0], torch.ones_like(self.edge_index[0]) + ) degrees.index_add_( 0, self.edge_index[1], torch.ones_like(self.edge_index[1]) ) # For undirected graphs @@ -1241,7 +1352,9 @@ def print_stats(self): non_singletons = degrees[degrees > 0] mean_edges_per_node = ( - non_singletons.float().mean().item() if len(non_singletons) > 0 else 0 + non_singletons.float().mean().item() + if len(non_singletons) > 0 + else 0 ) median_edges_per_node = ( non_singletons.median().item() if len(non_singletons) > 0 else 0 @@ -1351,7 +1464,10 @@ def get(self, idx: int): @property def raw_file_names(self) -> list[str]: - return [f"{self.organism}.protein.{f}.v{self.version}.txt" for f in self.files] + return [ + f"{self.organism}.protein.{f}.v{self.version}.txt" + for f in self.files + ] @property def processed_file_names(self) -> str: @@ -1365,7 +1481,9 @@ def _ensure_raw_in_cache(self) -> None: dest = self._cache_raw_dir / fname if dest.exists(): continue - url = self.url.format(organism=self.organism, data=d, version=self.version) + url = self.url.format( + organism=self.organism, data=d, version=self.version + ) gz_path = download_url(url, str(self._cache_raw_dir)) extract_gz(gz_path, str(self._cache_raw_dir)) os.unlink(gz_path) @@ -1427,11 +1545,15 @@ def read_user_graph(fpath, sep=None, header="infer", **pd_read_csv_kw): sniffer = csv.Sniffer() dialect = sniffer.sniff(sample, delimiters="\t,| ") sep = dialect.delimiter - print(f"[INFO] Auto-detected separator using CSV Sniffer: {repr(sep)}") + print( + f"[INFO] Auto-detected separator using CSV Sniffer: {repr(sep)}" + ) except csv.Error: # Fallback to tab if Sniffer fails sep = "\t" - print(f"[INFO] CSV Sniffer failed, using default separator: {repr(sep)}") + print( + f"[INFO] CSV Sniffer failed, using default separator: {repr(sep)}" + ) # Read the file df = pd.read_csv(fpath, sep=sep, header=header, **pd_read_csv_kw) @@ -1468,7 +1590,15 @@ def read_user_graph(fpath, sep=None, header="infer", **pd_read_csv_kw): "source", "from", ], - "gene_b": ["geneb", "gene_b", "gene2", "protein2", "node2", "target", "to"], + "gene_b": [ + "geneb", + "gene_b", + "gene2", + "protein2", + "node2", + "target", + "to", + ], "score": [ "score", "weight", @@ -1573,7 +1703,9 @@ def score_column_match(col, col_idx, category, total_cols): # Convert score to numeric if not already if not pd.api.types.is_numeric_dtype(result_df["combined_score"]): try: - result_df["combined_score"] = pd.to_numeric(result_df["combined_score"]) + result_df["combined_score"] = pd.to_numeric( + result_df["combined_score"] + ) except ValueError: raise ValueError( f"Score column must contain numeric values. " @@ -1588,7 +1720,9 @@ def score_column_match(col, col_idx, category, total_cols): f"[WARNING] Removed {original_len - len(result_df)} rows with missing values" ) - print(f"[INFO] Successfully loaded user graph: {len(result_df)} interactions") + print( + f"[INFO] Successfully loaded user graph: {len(result_df)} interactions" + ) print( f"[INFO] Score range: [{result_df['combined_score'].min():.4f}, {result_df['combined_score'].max():.4f}]" ) @@ -1629,13 +1763,19 @@ def read_stringdb_links(fname, top_neighbors=5): ] ) # Sort the expanded DataFrame by 'combined_score' in descending order - df_expanded_sorted = df_expanded.sort_values(by="combined_score", ascending=False) + df_expanded_sorted = df_expanded.sort_values( + by="combined_score", ascending=False + ) # Reduce to unique interactions to avoid counting duplicates df_expanded_unique = df_expanded_sorted.drop_duplicates( subset=["protein", "partner"] ) - top_interactions = df_expanded_unique.groupby("protein").head(top_neighbors) - df = top_interactions.rename(columns={"protein": "protein1", "partner": "protein2"}) + top_interactions = df_expanded_unique.groupby("protein").head( + top_neighbors + ) + df = top_interactions.rename( + columns={"protein": "protein1", "partner": "protein2"} + ) df[["protein1", "protein2"]] = df[["protein1", "protein2"]].map( lambda a: a.split(".")[-1] ) @@ -1658,7 +1798,10 @@ def read_stringdb_aliases(fname: str, node_name: str) -> dict[str, str]: protein_id_to_gene_id[data[0].split(".")[1]] = data[1] elif data[-1].endswith(source[1]): # TODO: Check here if the values are the same - if protein_id_to_gene_id.get(data[0].split(".")[1], None) is None: + if ( + protein_id_to_gene_id.get(data[0].split(".")[1], None) + is None + ): protein_id_to_gene_id[data[0].split(".")[1]] = data[1] else: continue @@ -1674,7 +1817,9 @@ def read_stringdb_graph(node_name, edges_data_path, nodes_data_path): if node_name in ("gene_name", "gene_id"): node_name_mapping = read_stringdb_aliases(nodes_data_path, node_name) else: - raise NotImplementedError("Node name must be either 'gene_name' or 'gene_id'.") + raise NotImplementedError( + "Node name must be either 'gene_name' or 'gene_id'." + ) def fn(a): try: @@ -1684,7 +1829,9 @@ def fn(a): out = pd.NA return out - graph_df[["protein1", "protein2"]] = graph_df[["protein1", "protein2"]].map(fn) + graph_df[["protein1", "protein2"]] = graph_df[ + ["protein1", "protein2"] + ].map(fn) return graph_df @@ -1700,7 +1847,9 @@ def split_by_median(tensor_dict): if tensor.dtype in {torch.float16, torch.float32, torch.float64}: # Remove NaNs and compute median tensor_no_nan = tensor[torch.isfinite(tensor)] - median_val = tensor_no_nan.sort().values[tensor_no_nan.numel() // 2] + median_val = tensor_no_nan.sort().values[ + tensor_no_nan.numel() // 2 + ] # Convert to categorical, but preserve NaNs tensor_cat = (tensor > median_val).float() diff --git a/flexynesis/feature_selection.py b/flexynesis/feature_selection.py index a80ef39e..3988d59b 100644 --- a/flexynesis/feature_selection.py +++ b/flexynesis/feature_selection.py @@ -151,7 +151,9 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): "correlated_with": X.columns[ redundant_features[idx]["correlated_with"] ], - "correlation_score": redundant_features[idx]["correlation_score"], + "correlation_score": redundant_features[idx][ + "correlation_score" + ], } for idx in redundant_features ] @@ -161,7 +163,9 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): return X.columns[selected_features], pd.DataFrame() -def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0.9): +def filter_by_laplacian( + X, layer, k=5, t=None, topN=100, correlation_threshold=0.9 +): """ Filters features in a dataset based on Laplacian scores and removes highly correlated features, retaining only the top N features with the lowest scores and optionally considering correlation. @@ -221,7 +225,9 @@ def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0 " samples ", ) - feature_log = pd.DataFrame({"feature": X.columns, "laplacian_score": np.nan}) + feature_log = pd.DataFrame( + {"feature": X.columns, "laplacian_score": np.nan} + ) # only apply filtering if topN < n_features if topN >= X.shape[1]: print( @@ -233,7 +239,9 @@ def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0 # compute laplacian scores scores = laplacian_score(X.values, k, t) - feature_log = pd.DataFrame({"feature": X.columns, "laplacian_score": scores}) + feature_log = pd.DataFrame( + {"feature": X.columns, "laplacian_score": scores} + ) # Sort the features based on their Laplacian Scores sorted_indices = np.argsort(scores) @@ -270,6 +278,8 @@ def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0 X_selected = X[selected_features] feature_log["selected"] = False - feature_log.loc[feature_log["feature"].isin(selected_features), "selected"] = True + feature_log.loc[ + feature_log["feature"].isin(selected_features), "selected" + ] = True return X_selected, feature_log diff --git a/flexynesis/generate_coexpression_network.py b/flexynesis/generate_coexpression_network.py index 2b2acfb6..b9dfed8c 100644 --- a/flexynesis/generate_coexpression_network.py +++ b/flexynesis/generate_coexpression_network.py @@ -79,7 +79,9 @@ def build_network( pbar.update(end_i - i) data = ranks elif method != "pearson": - raise ValueError(f"Unknown method: {method}. Use 'spearman' or 'pearson'") + raise ValueError( + f"Unknown method: {method}. Use 'spearman' or 'pearson'" + ) # Standardize for correlation computation print("Standardizing data...") @@ -185,13 +187,17 @@ def generate_coexpression_network( print(f"[ERROR] Failed to load file: {e2}") sys.exit(1) - print(f" Expression matrix: {expr_df.shape[0]} genes × {expr_df.shape[1]} samples") + print( + f" Expression matrix: {expr_df.shape[0]} genes × {expr_df.shape[1]} samples" + ) # Check for missing values na_count = expr_df.isna().sum().sum() if na_count > 0: genes_with_na = expr_df.isna().any(axis=1).sum() - print(f" [WARNING] Found {na_count} missing values in {genes_with_na} genes") + print( + f" [WARNING] Found {na_count} missing values in {genes_with_na} genes" + ) print(f" [INFO] Removing genes with missing data.") expr_df = expr_df.dropna() print( @@ -208,7 +214,9 @@ def generate_coexpression_network( network_df = pd.DataFrame(edges) if len(network_df) == 0: - print("[WARNING] No edges found! Try lowering min_correlation threshold.") + print( + "[WARNING] No edges found! Try lowering min_correlation threshold." + ) print("[ERROR] No edges in network! Exiting.") sys.exit(1) @@ -219,7 +227,9 @@ def generate_coexpression_network( lambda row: tuple(sorted([row["GeneA"], row["GeneB"]])), axis=1 ) original_len = len(network_df) - network_df = network_df.drop_duplicates(subset="pair").drop(columns="pair") + network_df = network_df.drop_duplicates(subset="pair").drop( + columns="pair" + ) print(f" Removed {original_len - len(network_df)} duplicate edges") # Save network (auto-detect format from extension) @@ -306,7 +316,10 @@ def main(): ) parser.add_argument( - "--output", "-o", required=True, help="Output network file (CSV/TSV supported)" + "--output", + "-o", + required=True, + help="Output network file (CSV/TSV supported)", ) parser.add_argument( diff --git a/flexynesis/inference.py b/flexynesis/inference.py index e40975ff..d5d570e8 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -15,7 +15,10 @@ "DirectPred": ("flexynesis.models.direct_pred", "DirectPred"), "supervised_vae": ("flexynesis.models.supervised_vae", "supervised_vae"), "CrossModalPred": ("flexynesis.models.crossmodal_pred", "CrossModalPred"), - "MultiTripletNetwork": ("flexynesis.models.triplet_encoder", "MultiTripletNetwork"), + "MultiTripletNetwork": ( + "flexynesis.models.triplet_encoder", + "MultiTripletNetwork", + ), "GNN": ("flexynesis.models.gnn_early", "GNN"), } @@ -99,7 +102,9 @@ def _build_dataset_namespace(config, artifacts): cats = ( enc.categories_[0].tolist() if hasattr(enc, "categories_") - else (enc.get("categories", [[]])[0] if isinstance(enc, dict) else []) + else ( + enc.get("categories", [[]])[0] if isinstance(enc, dict) else [] + ) ) if cats: ann[var] = cats @@ -123,11 +128,15 @@ def _resolve_input_dims(config, artifacts): """Ensure input_dims is present in config, deriving from feature_lists if needed.""" feature_lists = artifacts.get("feature_lists", {}) layers = ( - config.get("input_layers") or config.get("layers") or list(feature_lists.keys()) + config.get("input_layers") + or config.get("layers") + or list(feature_lists.keys()) ) input_dims = config.get("input_dims") if not input_dims: - input_dims = [len(feature_lists[l]) for l in layers if l in feature_lists] + input_dims = [ + len(feature_lists[l]) for l in layers if l in feature_lists + ] config["input_dims"] = input_dims return config @@ -178,7 +187,11 @@ def check_file_type(file_path): def _deserialize_json_artifacts(artifacts): import numpy as np - from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, StandardScaler + from sklearn.preprocessing import ( + LabelEncoder, + OrdinalEncoder, + StandardScaler, + ) # Rebuild sklearn objects expected by inference code. deserialized = dict(artifacts) @@ -229,7 +242,9 @@ def _deserialize_json_artifacts(artifacts): encoder_type = encoder_dict.get("type") if encoder_type == "LabelEncoder": enc = LabelEncoder() - enc.classes_ = np.array(encoder_dict.get("classes", []), dtype=object) + enc.classes_ = np.array( + encoder_dict.get("classes", []), dtype=object + ) label_encoders[variable] = enc continue @@ -269,10 +284,18 @@ def _deserialize_json_artifacts(artifacts): setattr( enc, "_missing_indices", - {int(k): v for k, v in mi.items()} if isinstance(mi, dict) else mi, + ( + {int(k): v for k, v in mi.items()} + if isinstance(mi, dict) + else mi + ), ) if "_infrequent_enabled" in encoder_dict: - setattr(enc, "_infrequent_enabled", encoder_dict["_infrequent_enabled"]) + setattr( + enc, + "_infrequent_enabled", + encoder_dict["_infrequent_enabled"], + ) label_encoders[variable] = enc continue @@ -292,7 +315,9 @@ def _load_artifacts(artifacts_path): return raw -def reconstruct_model(safetensors_path, config_path, artifacts_path, device="cpu"): +def reconstruct_model( + safetensors_path, config_path, artifacts_path, device="cpu" +): """ Reconstruct a full Flexynesis model from: - safetensors_path : .safetensors weights file diff --git a/flexynesis/main.py b/flexynesis/main.py index 2bfcdb8f..be37e315 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -9,7 +9,9 @@ import torch import yaml from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar -from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme +from lightning.pytorch.callbacks.progress.rich_progress import ( + RichProgressBarTheme, +) from skopt import Optimizer from skopt.space import Categorical, Integer, Real from torch.utils.data import DataLoader, random_split @@ -94,9 +96,7 @@ def __init__( num_workers=2, ): self.dataset = dataset # dataset for model initiation - self.loader_dataset = ( - dataset # dataset for defining data loaders (this can be model specific) - ) + self.loader_dataset = dataset # dataset for defining data loaders (this can be model specific) self.model_class = model_class self.target_variables = target_variables self.device_type = device_type @@ -110,9 +110,7 @@ def __init__( self.batch_variables = batch_variables self.config_name = config_name self.n_iter = n_iter - self.plot_losses = ( - plot_losses # Whether to show live loss plots (useful in interactive mode) - ) + self.plot_losses = plot_losses # Whether to show live loss plots (useful in interactive mode) self.val_size = val_size self.use_cv = use_cv self.n_splits = cv_splits @@ -182,10 +180,14 @@ def get_batch_space(self, min_size=32, max_size=128): end = int(np.log2(max_size)) if m < end: end = m - s = Categorical([np.power(2, x) for x in range(st, end + 1)], name="batch_size") + s = Categorical( + [np.power(2, x) for x in range(st, end + 1)], name="batch_size" + ) return s - def setup_trainer(self, params, current_step, total_steps, full_train=False): + def setup_trainer( + self, params, current_step, total_steps, full_train=False + ): # Configure callbacks and trainer for the current fold mycallbacks = [] if self.plot_losses: @@ -280,8 +282,12 @@ def objective(self, params, current_step, total_steps, full_train=False): print( f"[INFO] {'training cross-validation fold' if self.use_cv else 'training validation split'} {i}" ) - train_subset = torch.utils.data.Subset(self.loader_dataset, train_index) - val_subset = torch.utils.data.Subset(self.loader_dataset, val_index) + train_subset = torch.utils.data.Subset( + self.loader_dataset, train_index + ) + val_subset = torch.utils.data.Subset( + self.loader_dataset, val_index + ) train_loader = self.DataLoader( train_subset, batch_size=int(params["batch_size"]), @@ -308,13 +314,17 @@ def objective(self, params, current_step, total_steps, full_train=False): ) print(f"[INFO] hpo config:{params}") trainer.fit( - model, train_dataloaders=train_loader, val_dataloaders=val_loader + model, + train_dataloaders=train_loader, + val_dataloaders=val_loader, ) if early_stop_callback.stopped_epoch: epochs.append(early_stop_callback.stopped_epoch) else: epochs.append(int(params["epochs"])) - validation_result = trainer.validate(model, dataloaders=val_loader) + validation_result = trainer.validate( + model, dataloaders=val_loader + ) val_loss = validation_result[0]["val_loss"] validation_losses.append(val_loss) i += 1 @@ -351,7 +361,9 @@ def perform_tuning(self, hpo_patience=0): for param, value in zip(self.space, suggested_params_list) } loss, avg_epochs, model = self.objective( - suggested_params_dict, current_step=i + 1, total_steps=self.n_iter + suggested_params_dict, + current_step=i + 1, + total_steps=self.n_iter, ) if self.use_cv: print( @@ -364,9 +376,13 @@ def perform_tuning(self, hpo_patience=0): best_params = suggested_params_list best_epochs = avg_epochs best_model = model - no_improvement_count = 0 # Reset the no improvement counter + no_improvement_count = ( + 0 # Reset the no improvement counter + ) else: - no_improvement_count += 1 # Increment the no improvement counter + no_improvement_count += ( + 1 # Increment the no improvement counter + ) # Print result of each iteration pbar.set_postfix({"Iteration": i + 1, "Best Loss": best_loss}) @@ -379,7 +395,10 @@ def perform_tuning(self, hpo_patience=0): ) break # Break out of the loop best_params_dict = ( - {param.name: value for param, value in zip(self.space, best_params)} + { + param.name: value + for param, value in zip(self.space, best_params) + } if best_params else None ) @@ -399,7 +418,10 @@ def perform_tuning(self, hpo_patience=0): f"[INFO] Building a final model using best params: {best_params_dict}" ) best_model = self.objective( - best_params_dict, current_step=0, total_steps=1, full_train=True + best_params_dict, + current_step=0, + total_steps=1, + full_train=True, ) return best_model, best_params_dict @@ -500,7 +522,11 @@ def __init__( self.learning_rates = ( learning_rates if learning_rates - else [model.config["lr"], model.config["lr"] / 10, model.config["lr"] / 100] + else [ + model.config["lr"], + model.config["lr"] / 10, + model.config["lr"] / 100, + ] ) self.folds_data = list(self.kfold.split(np.arange(len(self.dataset)))) self.max_epoch = max_epoch @@ -533,7 +559,9 @@ def train_dataloader(self): # Override to load data for the current fold train_idx, val_idx = self.folds_data[self.current_fold] train_subset = torch.utils.data.Subset(self.dataset, train_idx) - return DataLoader(train_subset, batch_size=self.batch_size, shuffle=True) + return DataLoader( + train_subset, batch_size=self.batch_size, shuffle=True + ) def val_dataloader(self): # Override to load validation data for the current fold @@ -573,7 +601,10 @@ def run_experiments(self): self.current_fold = fold self.learning_rate = lr early_stopping = EarlyStopping( - monitor="val_loss", patience=3, verbose=False, mode="min" + monitor="val_loss", + patience=3, + verbose=False, + mode="min", ) trainer = pl.Trainer( max_epochs=self.max_epoch, @@ -614,7 +645,9 @@ def run_experiments(self): ) # Find the best configuration based on validation loss - best_config = min(val_loss_results, key=lambda x: x["average_val_loss"]) + best_config = min( + val_loss_results, key=lambda x: x["average_val_loss"] + ) print( f"Best learning rate: {best_config['learning_rate']} and freeze {best_config['freeze']}", f"with average validation loss: {best_config['average_val_loss']} and average epochs: {best_config['epochs']}", diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index 6ce742a2..b5edd955 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -46,7 +46,9 @@ def __init__( # both surv event and time variables are assumed to be numerical variables # we create only one survival variable for the pair (surv_time_var and surv_event_var) if self.surv_event_var is not None and self.surv_time_var is not None: - self.target_variables = self.target_variables + [self.surv_event_var] + self.target_variables = self.target_variables + [ + self.surv_event_var + ] self.batch_variables = batch_variables self.variables = ( self.target_variables + self.batch_variables @@ -56,7 +58,9 @@ def __init__( self.variable_types = dataset.variable_types self.ann = dataset.ann - self.input_layers = input_layers if input_layers else list(dataset.dat.keys()) + self.input_layers = ( + input_layers if input_layers else list(dataset.dat.keys()) + ) self.output_layers = ( output_layers if output_layers else list(dataset.dat.keys()) ) @@ -174,7 +178,9 @@ def forward(self, x_list_input): z = self.reparameterization(mean, log_var) # decode the latent space to target output layer(s) - x_hat_list = [self.decoders[i](z) for i in range(len(self.output_layers))] + x_hat_list = [ + self.decoders[i](z) for i in range(len(self.output_layers)) + ] # run the supervisor heads using the latent layer as input outputs = {} @@ -251,7 +257,9 @@ def compute_loss(self, var, y, y_hat): y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) return loss def compute_total_loss(self, losses): @@ -527,7 +535,9 @@ def MMD_loss(self, latent_dim, z, xhat, x): # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): - input_data = list(args[:-2]) # one or more tensors (one per omics layer) + input_data = list( + args[:-2] + ) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] @@ -607,15 +617,22 @@ def compute_feature_importance( for batch in dataloader: dat, _, _ = batch - x_list = [to_device_safe(dat[x], device) for x in self.input_layers] - input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) + x_list = [ + to_device_safe(dat[x], device) for x in self.input_layers + ] + input_data = tuple( + [data.unsqueeze(0).requires_grad_() for data in x_list] + ) if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif method == "GradientShap": # provide multiple baselines for Gr.Shap + elif ( + method == "GradientShap" + ): # provide multiple baselines for Gr.Shap baseline = tuple( torch.cat( - [torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0 + [torch.zeros_like(x) for _ in range(steps_or_samples)], + dim=0, ) for x in input_data ) @@ -643,7 +660,10 @@ def compute_feature_importance( attributions = explainer.attribute( input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), + additional_forward_args=( + target_var, + steps_or_samples, + ), target=target_class, n_steps=steps_or_samples, ) @@ -651,7 +671,10 @@ def compute_feature_importance( attributions = explainer.attribute( input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), + additional_forward_args=( + target_var, + steps_or_samples, + ), target=target_class, n_samples=steps_or_samples, ) @@ -668,7 +691,9 @@ def compute_feature_importance( # Process each layer within the class for layer_idx in range(num_layers): # Extract all batch tensors for this layer across all batches for the current class - layer_tensors = [batch_attr[layer_idx] for batch_attr in class_attr] + layer_tensors = [ + batch_attr[layer_idx] for batch_attr in class_attr + ] # Concatenate tensors along the batch dimension attr_concat = torch.cat(layer_tensors, dim=1) layer_attributions.append(attr_concat) diff --git a/flexynesis/models/direct_pred.py b/flexynesis/models/direct_pred.py index 564a15ab..b649bec7 100644 --- a/flexynesis/models/direct_pred.py +++ b/flexynesis/models/direct_pred.py @@ -45,7 +45,9 @@ def __init__( # both surv event and time variables are assumed to be numerical variables # we create only one survival variable for the pair (surv_time_var and surv_event_var) if self.surv_event_var is not None and self.surv_time_var is not None: - self.target_variables = self.target_variables + [self.surv_event_var] + self.target_variables = self.target_variables + [ + self.surv_event_var + ] self.batch_variables = batch_variables self.variables = ( self.target_variables + batch_variables @@ -66,7 +68,8 @@ def __init__( self.ann = dataset.ann self.layers = list(dataset.dat.keys()) self.input_dims = [ - len(dataset.features[self.layers[i]]) for i in range(len(self.layers)) + len(dataset.features[self.layers[i]]) + for i in range(len(self.layers)) ] self.encoders = nn.ModuleList( @@ -184,7 +187,9 @@ def compute_loss(self, var, y, y_hat): y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) return loss def compute_total_loss(self, losses): @@ -332,7 +337,9 @@ def predict(self, dataset): # Collect predictions for each variable for var in self.variables: - logits = outputs[var].detach().cpu() # Raw model outputs (logits) + logits = ( + outputs[var].detach().cpu() + ) # Raw model outputs (logits) if dataset.variable_types[var] == "categorical": probs = torch.softmax( @@ -372,7 +379,9 @@ def transform(self, dataset): dataset, batch_size=64, shuffle=False ) # Adjust the batch size as needed - embeddings_list = [] # Initialize a list to collect all batch embeddings + embeddings_list = ( + [] + ) # Initialize a list to collect all batch embeddings sample_names = [] # List to collect sample names # Process each batch @@ -414,7 +423,9 @@ def transform(self, dataset): # Adaptor forward function for captum integrated gradients or gradient shap def forward_target(self, *args): - input_data = list(args[:-2]) # one or more tensors (one per omics layer) + input_data = list( + args[:-2] + ) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest steps = args[ -1 @@ -489,14 +500,19 @@ def compute_feature_importance( for batch in dataloader: dat, _, _ = batch x_list = [to_device_safe(dat[x], device) for x in dat.keys()] - input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) + input_data = tuple( + [data.unsqueeze(0).requires_grad_() for data in x_list] + ) if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif method == "GradientShap": # provide multiple baselines for Gr.Shap + elif ( + method == "GradientShap" + ): # provide multiple baselines for Gr.Shap baseline = tuple( torch.cat( - [torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0 + [torch.zeros_like(x) for _ in range(steps_or_samples)], + dim=0, ) for x in input_data ) @@ -524,7 +540,10 @@ def compute_feature_importance( attributions = explainer.attribute( input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), + additional_forward_args=( + target_var, + steps_or_samples, + ), target=target_class, n_steps=steps_or_samples, ) @@ -532,7 +551,10 @@ def compute_feature_importance( attributions = explainer.attribute( input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), + additional_forward_args=( + target_var, + steps_or_samples, + ), target=target_class, n_samples=steps_or_samples, ) @@ -545,7 +567,9 @@ def compute_feature_importance( class_attr = aggregated_attributions[class_idx] layer_attributions = [] for layer_idx in range(num_layers): - layer_tensors = [batch_attr[layer_idx] for batch_attr in class_attr] + layer_tensors = [ + batch_attr[layer_idx] for batch_attr in class_attr + ] attr_concat = torch.cat(layer_tensors, dim=1) layer_attributions.append(attr_concat) processed_attributions.append(layer_attributions) diff --git a/flexynesis/models/gnn_early.py b/flexynesis/models/gnn_early.py index aad8435f..87ec49d2 100644 --- a/flexynesis/models/gnn_early.py +++ b/flexynesis/models/gnn_early.py @@ -70,7 +70,9 @@ def __init__( # both surv event and time variables are assumed to be numerical variables # we create only one survival variable for the pair (surv_time_var and surv_event_var) if self.surv_event_var is not None and self.surv_time_var is not None: - self.target_variables = self.target_variables + [self.surv_event_var] + self.target_variables = self.target_variables + [ + self.surv_event_var + ] self.batch_variables = batch_variables self.variables = ( self.target_variables + self.batch_variables @@ -283,7 +285,9 @@ def compute_loss(self, var, y, y_hat): y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) return loss def compute_total_loss(self, losses): @@ -355,7 +359,9 @@ def predict(self, dataset): outputs = self.forward(x, edge_index) for var in self.variables: - logits = outputs[var].detach().cpu() # Raw model outputs (logits) + logits = ( + outputs[var].detach().cpu() + ) # Raw model outputs (logits) if dataset.variable_types[var] == "categorical": probs = torch.softmax( @@ -422,7 +428,9 @@ def transform(self, dataset): # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): - input_data = list(args[:-2]) # expect a single tensor (early integration) + input_data = list( + args[:-2] + ) # expect a single tensor (early integration) target_var = args[-2] # target variable of interest steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] @@ -519,9 +527,14 @@ def bytes_to_gb(bytes): if method == "IntegratedGradients": baseline = torch.zeros_like(input_data) - elif method == "GradientShap": # provide multiple baselines for Gr.Shap + elif ( + method == "GradientShap" + ): # provide multiple baselines for Gr.Shap baseline = torch.cat( - [torch.zeros_like(input_data) for _ in range(steps_or_samples)], + [ + torch.zeros_like(input_data) + for _ in range(steps_or_samples) + ], dim=0, ) @@ -548,7 +561,10 @@ def bytes_to_gb(bytes): attributions = explainer.attribute( input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), + additional_forward_args=( + target_var, + steps_or_samples, + ), target=target_class, n_steps=steps_or_samples, ) @@ -556,7 +572,10 @@ def bytes_to_gb(bytes): attributions = explainer.attribute( input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), + additional_forward_args=( + target_var, + steps_or_samples, + ), target=target_class, n_samples=steps_or_samples, ) @@ -567,12 +586,15 @@ def bytes_to_gb(bytes): for class_idx in range(len(aggregated_attributions)): class_attr = aggregated_attributions[class_idx] # Concatenate tensors along the batch dimension - attr_concat = torch.cat([batch_attr for batch_attr in class_attr], dim=1) + attr_concat = torch.cat( + [batch_attr for batch_attr in class_attr], dim=1 + ) processed_attributions.append(attr_concat) # compute absolute importance and move to cpu abs_attr = [ - torch.abs(attr_class).cpu() for attr_class in processed_attributions + torch.abs(attr_class).cpu() + for attr_class in processed_attributions ] # average over samples imp = [a.mean(dim=1) for a in abs_attr] @@ -590,7 +612,9 @@ def bytes_to_gb(bytes): # MPS and CPU don't have detailed memory tracking like CUDA df_list = [] - layers = list(getattr(dataset, "multiomic_dataset", dataset).dat.keys()) + layers = list( + getattr(dataset, "multiomic_dataset", dataset).dat.keys() + ) for i in range(num_class): features = dataset.common_features target_class_label = ( @@ -602,9 +626,7 @@ def bytes_to_gb(bytes): # Extracting node feature attributes coming from different omic layers importances_array = imp[i].squeeze().detach().numpy() if importances_array.ndim == 1: - importances = ( - importances_array # Use the array as is if it is 1-dimensional - ) + importances = importances_array # Use the array as is if it is 1-dimensional else: importances = importances_array[ :, l diff --git a/flexynesis/models/supervised_vae.py b/flexynesis/models/supervised_vae.py index a0f8d919..bddaaacc 100644 --- a/flexynesis/models/supervised_vae.py +++ b/flexynesis/models/supervised_vae.py @@ -58,7 +58,9 @@ def __init__( # both surv event and time variables are assumed to be numerical variables # we create only one survival variable for the pair (surv_time_var and surv_event_var) if self.surv_event_var is not None and self.surv_time_var is not None: - self.target_variables = self.target_variables + [self.surv_event_var] + self.target_variables = self.target_variables + [ + self.surv_event_var + ] self.batch_variables = batch_variables self.variables = ( self.target_variables + batch_variables @@ -81,7 +83,9 @@ def __init__( self.log_vars[loss_type] = nn.Parameter(torch.zeros(1)) layers = list(dataset.dat.keys()) - input_dims = [len(dataset.features[layers[i]]) for i in range(len(layers))] + input_dims = [ + len(dataset.features[layers[i]]) for i in range(len(layers)) + ] # create a list of Encoder instances for separately encoding each omics layer self.encoders = nn.ModuleList( [ @@ -251,7 +255,9 @@ def compute_loss(self, var, y, y_hat): y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) return loss def compute_total_loss(self, losses): @@ -425,7 +431,9 @@ def transform(self, dataset): sample_names.extend(samples) # Collect sample names for this batch # Concatenate all batch latent representations into one array - concatenated_latents = np.concatenate(all_latent_representations, axis=0) + concatenated_latents = np.concatenate( + all_latent_representations, axis=0 + ) # Convert the array to a DataFrame z = pd.DataFrame(concatenated_latents) @@ -474,7 +482,9 @@ def predict(self, dataset): # Collect predictions for each variable for var in self.variables: - logits = outputs[var].detach().cpu() # Raw model outputs (logits) + logits = ( + outputs[var].detach().cpu() + ) # Raw model outputs (logits) if dataset.variable_types[var] == "categorical": probs = torch.softmax( @@ -550,7 +560,9 @@ def MMD_loss(self, latent_dim, z, xhat, x): # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): - input_data = list(args[:-2]) # one or more tensors (one per omics layer) + input_data = list( + args[:-2] + ) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] @@ -630,14 +642,19 @@ def compute_feature_importance( for batch in dataloader: dat, _, _ = batch x_list = [to_device_safe(dat[x], device) for x in dat.keys()] - input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) + input_data = tuple( + [data.unsqueeze(0).requires_grad_() for data in x_list] + ) if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif method == "GradientShap": # provide multiple baselines for Gr.Shap + elif ( + method == "GradientShap" + ): # provide multiple baselines for Gr.Shap baseline = tuple( torch.cat( - [torch.zeros_like(x) for _ in range(steps_or_samples)], dim=0 + [torch.zeros_like(x) for _ in range(steps_or_samples)], + dim=0, ) for x in input_data ) @@ -664,7 +681,10 @@ def compute_feature_importance( attributions = explainer.attribute( input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), + additional_forward_args=( + target_var, + steps_or_samples, + ), target=target_class, n_steps=steps_or_samples, ) @@ -672,7 +692,10 @@ def compute_feature_importance( attributions = explainer.attribute( input_data, baseline, - additional_forward_args=(target_var, steps_or_samples), + additional_forward_args=( + target_var, + steps_or_samples, + ), target=target_class, n_samples=steps_or_samples, ) @@ -689,7 +712,9 @@ def compute_feature_importance( # Process each layer within the class for layer_idx in range(num_layers): # Extract all batch tensors for this layer across all batches for the current class - layer_tensors = [batch_attr[layer_idx] for batch_attr in class_attr] + layer_tensors = [ + batch_attr[layer_idx] for batch_attr in class_attr + ] # Concatenate tensors along the batch dimension attr_concat = torch.cat(layer_tensors, dim=1) layer_attributions.append(attr_concat) diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index 11b796f9..8fe1a407 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -52,7 +52,9 @@ def __init__( # both surv event and time variables are assumed to be numerical variables # we create only one survival variable for the pair (surv_time_var and surv_event_var) if self.surv_event_var is not None and self.surv_time_var is not None: - self.target_variables = self.target_variables + [self.surv_event_var] + self.target_variables = self.target_variables + [ + self.surv_event_var + ] self.batch_variables = batch_variables self.variables = ( self.target_variables + batch_variables @@ -83,7 +85,8 @@ def __init__( self.layers = list(dataset.dat.keys()) self.input_dims = [ - len(dataset.features[self.layers[i]]) for i in range(len(self.layers)) + len(dataset.features[self.layers[i]]) + for i in range(len(self.layers)) ] self.encoders = nn.ModuleList( @@ -157,7 +160,12 @@ def forward(self, anchor, positive, negative): outputs = {} for var, mlp in self.MLPs.items(): outputs[var] = mlp(anchor_embedding) - return anchor_embedding, positive_embedding, negative_embedding, outputs + return ( + anchor_embedding, + positive_embedding, + negative_embedding, + outputs, + ) def configure_optimizers(self): """ @@ -231,7 +239,9 @@ def compute_loss(self, var, y, y_hat): y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) else: - loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) + loss = torch.tensor( + 0.0, device=y_hat.device, requires_grad=True + ) return loss def compute_total_loss(self, losses): @@ -442,9 +452,12 @@ def forward_target(self, input_data, layer_sizes, target_var, steps): anchor = {k: anchor[k] for k in range(len(anchor))} positive = {k: anchor[k] for k in range(len(positive))} negative = {k: anchor[k] for k in range(len(negative))} - anchor_embedding, positive_embedding, negative_embedding, outputs = ( - self.forward(anchor, positive, negative) - ) + ( + anchor_embedding, + positive_embedding, + negative_embedding, + outputs, + ) = self.forward(anchor, positive, negative) outputs_list.append(outputs[target_var]) return torch.cat(outputs_list, dim=0) @@ -498,7 +511,9 @@ def compute_feature_importance( # define data loader triplet_dataset = TripletMultiOmicDataset(dataset, self.main_var) - dataloader = DataLoader(triplet_dataset, batch_size=batch_size, shuffle=False) + dataloader = DataLoader( + triplet_dataset, batch_size=batch_size, shuffle=False + ) # Choose the attribution method dynamically if method == "IntegratedGradients": @@ -518,22 +533,38 @@ def compute_feature_importance( aggregated_attributions = [[] for _ in range(num_class)] for batch in dataloader: # see training_step to see how elements are accessed in batches - anchor, positive, negative, y_dict = batch[0], batch[1], batch[2], batch[3] + anchor, positive, negative, y_dict = ( + batch[0], + batch[1], + batch[2], + batch[3], + ) # Move tensors to the specified device using MPS-safe method anchor = {k: to_device_safe(v, device) for k, v in anchor.items()} - positive = {k: to_device_safe(v, device) for k, v in positive.items()} - negative = {k: to_device_safe(v, device) for k, v in negative.items()} + positive = { + k: to_device_safe(v, device) for k, v in positive.items() + } + negative = { + k: to_device_safe(v, device) for k, v in negative.items() + } anchor = [data.requires_grad_() for data in list(anchor.values())] - positive = [data.requires_grad_() for data in list(positive.values())] - negative = [data.requires_grad_() for data in list(negative.values())] + positive = [ + data.requires_grad_() for data in list(positive.values()) + ] + negative = [ + data.requires_grad_() for data in list(negative.values()) + ] # concatenate multiomic layers of each list element # then stack the anchor/positive/negative # the purpose is to get a single tensor input_data = torch.stack( - [torch.cat(sublist, dim=1) for sublist in [anchor, positive, negative]] + [ + torch.cat(sublist, dim=1) + for sublist in [anchor, positive, negative] + ] ).unsqueeze(0) # layer sizes will be needed to revert the concatenated tensor @@ -543,9 +574,14 @@ def compute_feature_importance( # Define a baseline if method == "IntegratedGradients": baseline = torch.zeros_like(input_data) - elif method == "GradientShap": # provide multiple baselines for Gr.Shap + elif ( + method == "GradientShap" + ): # provide multiple baselines for Gr.Shap baseline = torch.cat( - [torch.zeros_like(input_data) for _ in range(steps_or_samples)], + [ + torch.zeros_like(input_data) + for _ in range(steps_or_samples) + ], dim=0, ) @@ -617,7 +653,9 @@ def compute_feature_importance( # Process each layer within the class for layer_idx in range(num_layers): # Extract all batch tensors for this layer across all batches for the current class - layer_tensors = [batch_attr[layer_idx] for batch_attr in class_attr] + layer_tensors = [ + batch_attr[layer_idx] for batch_attr in class_attr + ] # Concatenate tensors along the batch dimension attr_concat = torch.cat(layer_tensors, dim=2) layer_attributions.append(attr_concat) diff --git a/flexynesis/modules.py b/flexynesis/modules.py index f84066fa..81f4a67b 100644 --- a/flexynesis/modules.py +++ b/flexynesis/modules.py @@ -224,7 +224,8 @@ def __init__( } if conv not in conv_options: raise ValueError( - "Unknown convolution type. Choose one of: ", list(conv_options.keys()) + "Unknown convolution type. Choose one of: ", + list(conv_options.keys()), ) self.act = act_options[act] @@ -233,7 +234,9 @@ def __init__( self.dropout = nn.Dropout(dropout_rate) # Initialize the first convolution layer separately if different input size - self.convs.append(conv_options[conv](node_feature_count, node_embedding_dim)) + self.convs.append( + conv_options[conv](node_feature_count, node_embedding_dim) + ) self.bns.append(nn.BatchNorm1d(node_embedding_dim)) # Loop to create the remaining convolution and BN layers @@ -284,7 +287,9 @@ def cox_ph_loss(outputs, durations, events): hazards = hazards.unsqueeze(0) # Make hazards 1D if it's a scalar # Calculate the risk set sum log_risk_set_sum = torch.log( - torch.cumsum(hazards[torch.argsort(durations, descending=True)], dim=0) + torch.cumsum( + hazards[torch.argsort(durations, descending=True)], dim=0 + ) ) # Get the indices that sort the durations in descending order sorted_indices = torch.argsort(durations, descending=True) @@ -296,7 +301,9 @@ def cox_ph_loss(outputs, durations, events): ) - torch.sum(log_risk_set_sum[events_sorted == 1]) total_loss = -uncensored_loss / torch.sum(events) else: - total_loss = torch.tensor(0.0, device=outputs.device, requires_grad=True) + total_loss = torch.tensor( + 0.0, device=outputs.device, requires_grad=True + ) if not torch.isfinite(total_loss): return torch.tensor(0.0, device=outputs.device, requires_grad=True) return total_loss diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 6c272eb6..4c6d531a 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -190,7 +190,9 @@ def plot_dim_reduced( + theme_minimal() ) else: - raise ValueError("Invalid color_type. Choose 'categorical' or 'numerical'.") + raise ValueError( + "Invalid color_type. Choose 'categorical' or 'numerical'." + ) return p @@ -305,7 +307,11 @@ def plot_scatter(true_values, predicted_values): va="top", size=10, ) - + labs(title="True vs Predicted Values", x="True Values", y="Predicted Values") + + labs( + title="True vs Predicted Values", + x="True Values", + y="Predicted Values", + ) + theme_minimal() ) @@ -369,7 +375,9 @@ def plot_boxplot( verticalalignment="top", horizontalalignment="left", fontsize=12, - bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="gray"), + bbox=dict( + boxstyle="round,pad=0.3", facecolor="white", edgecolor="gray" + ), ) plt.tight_layout() @@ -559,7 +567,9 @@ def plot_roc_curves(y_true, y_probs): ggplot(all_data, aes(x="fpr", y="tpr", color="label")) + geom_line(size=1.2) + geom_abline(intercept=0, slope=1, linetype="dashed", color="gray") - + labs(title="ROC Curve", x="False Positive Rate", y="True Positive Rate") + + labs( + title="ROC Curve", x="False Positive Rate", y="True Positive Rate" + ) + theme_minimal() ) @@ -676,7 +686,9 @@ def evaluate_wrapper( if var == surv_event_var: events = dataset.ann[surv_event_var] durations = dataset.ann[surv_time_var] - metrics = evaluate_survival(y_pred_dict[var], durations, events) + metrics = evaluate_survival( + y_pred_dict[var], durations, events + ) else: ind = ~torch.isnan(dataset.ann[var]) metrics = evaluate_regressor( @@ -684,7 +696,9 @@ def evaluate_wrapper( ) else: ind = ~torch.isnan(dataset.ann[var]) - metrics = evaluate_classifier(dataset.ann[var][ind], y_pred_dict[var][ind]) + metrics = evaluate_classifier( + dataset.ann[var][ind], y_pred_dict[var][ind] + ) for metric, value in metrics.items(): metrics_list.append( @@ -735,13 +749,16 @@ def get_predicted_labels(y_pred_dict, dataset, split, method_name): for idx in range(probabilities.shape[1]) ] else: - class_labels = [f"class_{i}" for i in range(probabilities.shape[1])] + class_labels = [ + f"class_{i}" for i in range(probabilities.shape[1]) + ] # Get true labels (y_true) y_true = [ ( dataset.label_mappings[var][int(x.item())] - if var in dataset.label_mappings.keys() and not np.isnan(x.item()) + if var in dataset.label_mappings.keys() + and not np.isnan(x.item()) else np.nan ) for x in dataset.ann[var] @@ -824,7 +841,9 @@ def evaluate_baseline_performance( def prepare_data(data_object, pca_model=None, fit_pca=False): # Concatenate Data Matrices - X = np.concatenate([tensor for tensor in data_object.dat.values()], axis=1) + X = np.concatenate( + [tensor for tensor in data_object.dat.values()], axis=1 + ) y = np.array(data_object.ann[variable_name]) # Filter out samples without a valid label @@ -865,7 +884,10 @@ def prepare_data(data_object, pca_model=None, fit_pca=False): if variable_type == "categorical": if method == "RandomForest": model = RandomForestClassifier(random_state=42) - params = {"n_estimators": [100, 200, 300], "max_depth": [10, 20, None]} + params = { + "n_estimators": [100, 200, 300], + "max_depth": [10, 20, None], + } elif method == "SVM": model = SVC(probability=True, random_state=42) params = {"C": [0.1, 1, 10], "kernel": ["rbf", "poly"]} @@ -884,7 +906,10 @@ def prepare_data(data_object, pca_model=None, fit_pca=False): elif variable_type == "numerical": if method == "RandomForest": model = RandomForestRegressor(random_state=42) - params = {"n_estimators": [100, 200, 300], "max_depth": [10, 20, None]} + params = { + "n_estimators": [100, 200, 300], + "max_depth": [10, 20, None], + } elif method == "SVM": model = SVR() params = {"C": [0.1, 1, 10], "kernel": ["rbf", "poly"]} @@ -926,7 +951,11 @@ def prepare_data(data_object, pca_model=None, fit_pca=False): metrics_list.append( { "method": method - + ("Classifier" if variable_type == "categorical" else "Regressor"), + + ( + "Classifier" + if variable_type == "categorical" + else "Regressor" + ), "var": variable_name, "variable_type": variable_type, "metric": metric, @@ -966,7 +995,9 @@ def evaluate_baseline_survival_performance( def prepare_data(data_object, duration_col, event_col): # Concatenate Data Matrices - X = np.concatenate([tensor for tensor in data_object.dat.values()], axis=1) + X = np.concatenate( + [tensor for tensor in data_object.dat.values()], axis=1 + ) # Prepare Survival Data (Durations and Events) durations = np.array(data_object.ann[duration_col]) @@ -986,7 +1017,9 @@ def prepare_data(data_object, duration_col, event_col): X_train, y_train, train_indices = prepare_data( train_dataset, duration_col, event_col ) - X_test, y_test, test_indices = prepare_data(test_dataset, duration_col, event_col) + X_test, y_test, test_indices = prepare_data( + test_dataset, duration_col, event_col + ) # Initialize Random Survival Forest rsf = RandomSurvivalForest( @@ -1137,7 +1170,9 @@ def subset_assays_by_features(dataset, features_dict): # data matrix for each key in features_dict subset_dat = {} for layer in features_dict.keys(): - indices = [dataset.features[layer].get_loc(x) for x in features_dict[layer]] + indices = [ + dataset.features[layer].get_loc(x) for x in features_dict[layer] + ] subset_dat[layer] = dataset.dat[layer][:, indices] # Convert subset_dat to pandas DataFrame and prepend feature names with layer names df_list = [] @@ -1146,7 +1181,9 @@ def subset_assays_by_features(dataset, features_dict): df_temp = pd.DataFrame(data) # Rename columns to prepend with layer name - df_temp.columns = [f"{layer}_{feature}" for feature in features_dict[layer]] + df_temp.columns = [ + f"{layer}_{feature}" for feature in features_dict[layer] + ] df_list.append(df_temp) # Concatenate dataframes horizontally concatenated_df = pd.concat(df_list, axis=1) @@ -1248,7 +1285,9 @@ def recursive_binary_split_minN( continue try: - cutoff, pval = find_optimal_cutoff(node[score], node[time], node[event]) + cutoff, pval = find_optimal_cutoff( + node[score], node[time], node[event] + ) except Exception: cutoff, pval = None, 1.0 @@ -1262,7 +1301,10 @@ def recursive_binary_split_minN( right = node[node[score] > cutoff] # if either resulting child is smaller than required, reject split - if len(left) < min_samples_per_group or len(right) < min_samples_per_group: + if ( + len(left) < min_samples_per_group + or len(right) < min_samples_per_group + ): groups.update({i: next_gid for i in node.index}) next_gid += 1 continue @@ -1321,7 +1363,9 @@ def significance(p): p = ( ggplot(coef_summary_sorted, aes(x="coef", y="variable")) + geom_errorbarh( - aes(xmin="coef_lower_95", xmax="coef_upper_95"), height=0.2, color="skyblue" + aes(xmin="coef_lower_95", xmax="coef_upper_95"), + height=0.2, + color="skyblue", ) + geom_point(color="skyblue", size=3) + geom_text(aes(label="stars"), nudge_y=0.1, size=10) @@ -1347,7 +1391,9 @@ def build_cox_model( event_col: str, n_splits: int = 5, random_state: int = 42, - eval_time: float | None = None, # single horizon, same units as duration_col + eval_time: ( + float | None + ) = None, # single horizon, same units as duration_col low_variance_threshold: float = 0.01, cox_penalizer=0.05, return_metrics: bool = True, @@ -1365,7 +1411,9 @@ def build_cox_model( } """ - def remove_low_variance_survival_features(df, duration_col, event_col, threshold): + def remove_low_variance_survival_features( + df, duration_col, event_col, threshold + ): events = df[event_col].astype(bool) low_var = [] for feature in df.drop(columns=[duration_col, event_col]).columns: @@ -1425,13 +1473,20 @@ def remove_low_variance_survival_features(df, duration_col, event_col, threshold eps = 1e-8 if (test_min + eps) < float(eval_time) < (test_max - eps): _, auc_val = cumulative_dynamic_auc( - y_train, y_test, risk_scores, np.asarray([float(eval_time)]) + y_train, + y_test, + risk_scores, + np.asarray([float(eval_time)]), ) auc_val = float(np.atleast_1d(auc_val)[0]) auc_per_fold.append(auc_val) # Aggregate CV metrics - metrics["cv_cindex_mean"] = float(np.mean(c_indices)) if c_indices else None - metrics["cv_auc_mean"] = float(np.mean(auc_per_fold)) if auc_per_fold else None + metrics["cv_cindex_mean"] = ( + float(np.mean(c_indices)) if c_indices else None + ) + metrics["cv_auc_mean"] = ( + float(np.mean(auc_per_fold)) if auc_per_fold else None + ) # Fit final model on full data for downstream use (forest plots, HRs, etc.) final_model = CoxPHFitter(penalizer=cox_penalizer) @@ -1493,7 +1548,9 @@ def louvain_clustering(X, threshold=None, k=None): cluster_labels = np.full(len(X), np.nan, dtype=float) # Fill the array with the cluster labels from the partition dictionary for node_id, cluster_label in partition.items(): - if node_id in range(len(X)): # Check if the node_id is a valid index in X + if node_id in range( + len(X) + ): # Check if the node_id is a valid index in X cluster_labels[node_id] = cluster_label else: # If node_id is not a valid index in X, it's already set to NaN @@ -1566,7 +1623,8 @@ def plot_label_concordance_heatmap(labels1, labels2, figsize=(12, 10)): """ # Compute the cross-tabulation ct = pd.crosstab( - pd.Series(labels1, name="Labels Set 1"), pd.Series(labels2, name="Labels Set 2") + pd.Series(labels1, name="Labels Set 1"), + pd.Series(labels2, name="Labels Set 2"), ) # Normalize the cross-tabulation matrix column-wise ct_normalized = ct.div(ct.sum(axis=1), axis=0) @@ -1654,7 +1712,9 @@ def create_covariate_matrix(covariates, variable_types, ann): if variable_types.get(var) == "categorical": # One-hot-encode categorical variables with 0/1 encoding one_hot = pd.get_dummies(ann[var], prefix=var).astype(int) - covariate_features.append(one_hot.T) # Transpose to make features rows + covariate_features.append( + one_hot.T + ) # Transpose to make features rows feature_names.extend(one_hot.columns.tolist()) elif variable_types.get(var) == "numerical": # Handle numerical variables with missing values @@ -1705,7 +1765,9 @@ def generate_synthetic_batches(n_samples_per_batch=150, n_features=50): return synthetic_data, batch_labels -def optimal_transport_align(embeddings, batch_labels, standardize_by_labels=False): +def optimal_transport_align( + embeddings, batch_labels, standardize_by_labels=False +): """ Align embeddings from two batches using Optimal Transport, preserving the order of samples. @@ -1723,7 +1785,9 @@ def optimal_transport_align(embeddings, batch_labels, standardize_by_labels=Fals # Identify unique batches unique_batches = np.unique(batch_labels_np) if len(unique_batches) != 2: - raise ValueError("Optimal transport supports aligning exactly two batches.") + raise ValueError( + "Optimal transport supports aligning exactly two batches." + ) # Split embeddings by batch, preserving the original indices batch1_indices = np.where(batch_labels_np == unique_batches[0])[0] @@ -1733,7 +1797,9 @@ def optimal_transport_align(embeddings, batch_labels, standardize_by_labels=Fals batch2_embeddings = embeddings.iloc[batch2_indices].to_numpy() # Compute the cost matrix (e.g., Euclidean distances) - cost_matrix = ot.dist(batch1_embeddings, batch2_embeddings, metric="euclidean") + cost_matrix = ot.dist( + batch1_embeddings, batch2_embeddings, metric="euclidean" + ) # Solve the optimal transport problem n_samples_1 = batch1_embeddings.shape[0] @@ -1805,7 +1871,9 @@ def reciprocal_pca_mnn( # Identify unique batches unique_batches = np.unique(batch_labels_np) if len(unique_batches) != 2: - raise ValueError("Reciprocal PCA supports aligning exactly two batches.") + raise ValueError( + "Reciprocal PCA supports aligning exactly two batches." + ) # Split embeddings by batch, preserving the original indices batch1_indices = np.where(batch_labels_np == unique_batches[0])[0] @@ -1835,8 +1903,12 @@ def reciprocal_pca_mnn( batch2_to_batch1 = pca1.transform(batch2_embeddings) # Use MNN to identify anchors - neighbors1 = NearestNeighbors(n_neighbors=n_neighbors).fit(batch2_to_batch1) - neighbors2 = NearestNeighbors(n_neighbors=n_neighbors).fit(batch1_to_batch2) + neighbors1 = NearestNeighbors(n_neighbors=n_neighbors).fit( + batch2_to_batch1 + ) + neighbors2 = NearestNeighbors(n_neighbors=n_neighbors).fit( + batch1_to_batch2 + ) distances1, indices1 = neighbors1.kneighbors(batch1_pca) distances2, indices2 = neighbors2.kneighbors(batch2_pca) @@ -1849,7 +1921,9 @@ def reciprocal_pca_mnn( mutual_anchors.append((i, neighbor)) if not mutual_anchors: - raise ValueError("No mutual nearest neighbors (MNN) found between the batches.") + raise ValueError( + "No mutual nearest neighbors (MNN) found between the batches." + ) # Align the datasets using anchors mutual_anchors = np.array(mutual_anchors) @@ -1887,7 +1961,9 @@ def reciprocal_pca_mnn( class CBioPortalData: - def __init__(self, study_id, base_url="https://datahub.assets.cbioportal.org"): + def __init__( + self, study_id, base_url="https://datahub.assets.cbioportal.org" + ): self.base_url = base_url self.study_id = study_id self.data_files = None @@ -1901,7 +1977,9 @@ def download_study_archive(self, force=False, timeout=60): return dest_file print(f"Downloading {url}...") - r = requests.get(url, stream=True, allow_redirects=True, timeout=timeout) + r = requests.get( + url, stream=True, allow_redirects=True, timeout=timeout + ) r.raise_for_status() # <-- key: fail fast on 404/403/etc. with open(dest_file, "wb") as f: @@ -1920,7 +1998,9 @@ def extract_archive(self, archive_path): tar.extractall() self.data_files = [ - f for f in os.listdir(base) if f.startswith("data_") and f.endswith(".txt") + f + for f in os.listdir(base) + if f.startswith("data_") and f.endswith(".txt") ] return base @@ -1932,7 +2012,9 @@ def read_data(self, files=None): for datatype, file in files.items(): print(f"Importing {file}...") file_path = os.path.join(self.study_id, file) - df = pd.read_csv(file_path, sep="\t", comment="#", low_memory=False) + df = pd.read_csv( + file_path, sep="\t", comment="#", low_memory=False + ) if "mutations" in file: print(f"Binarizing and converting {file} to matrix...") @@ -2005,7 +2087,9 @@ def split_data(self, samples=None, ratio=0.7): if samples is None: samples = self.data["clin"].index.tolist() - train_samples = list(pd.Series(samples).sample(frac=ratio, random_state=42)) + train_samples = list( + pd.Series(samples).sample(frac=ratio, random_state=42) + ) test_samples = list(set(samples) - set(train_samples)) train_data = {} @@ -2048,13 +2132,17 @@ def compute_correlation_loss(embeddings, batch_labels): ) # Normalize batch labels - batch_labels = (batch_labels - batch_labels.mean()) / (batch_labels.std() + 1e-8) + batch_labels = (batch_labels - batch_labels.mean()) / ( + batch_labels.std() + 1e-8 + ) # Reshape batch_labels to (num_samples, 1) for broadcasting batch_labels = batch_labels.unsqueeze(1) # Compute covariance (dot product of batch_labels and embeddings) - covariance = torch.matmul(batch_labels.T, embeddings) / (embeddings.shape[0] - 1) + covariance = torch.matmul(batch_labels.T, embeddings) / ( + embeddings.shape[0] - 1 + ) # Compute sum of squared correlations loss = torch.sum(torch.abs(covariance)) @@ -2168,7 +2256,8 @@ def get_device_memory_info(device_str): return { "allocated": torch.cuda.memory_allocated() / (1024**2), # MB "reserved": torch.cuda.memory_reserved() / (1024**2), # MB - "max_allocated": torch.cuda.max_memory_allocated() / (1024**2), # MB + "max_allocated": torch.cuda.max_memory_allocated() + / (1024**2), # MB "device_name": torch.cuda.get_device_name(0), "device_count": torch.cuda.device_count(), } diff --git a/tests/test_mps_device.py b/tests/test_mps_device.py index 15f118df..d9b5b801 100644 --- a/tests/test_mps_device.py +++ b/tests/test_mps_device.py @@ -11,8 +11,12 @@ def test_mps_device_detection(): if not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") - assert device_type == "mps", f"Expected device type 'mps', got '{device_type}'" - assert device_str == "mps", f"Expected device string 'mps', got '{device_str}'" + assert ( + device_type == "mps" + ), f"Expected device type 'mps', got '{device_type}'" + assert ( + device_str == "mps" + ), f"Expected device string 'mps', got '{device_str}'" def test_mps_tensor_operations(): @@ -29,7 +33,10 @@ def test_mps_tensor_operations(): y = to_device_safe(torch.randn(50, 25), device) result = torch.mm(x, y) - assert result.shape == (100, 25), f"Unexpected result shape: {result.shape}" + assert result.shape == ( + 100, + 25, + ), f"Unexpected result shape: {result.shape}" assert ( result.device.type == "mps" ), f"Result not on MPS device: {result.device.type}" @@ -49,7 +56,9 @@ def test_mps_memory_allocation(): large_tensor = to_device_safe(torch.randn(1000, 1000), device) memory_after = torch.mps.current_allocated_memory() - assert memory_after > memory_before, "Memory did not increase after allocation." + assert ( + memory_after > memory_before + ), "Memory did not increase after allocation." def test_float64_to_float32_conversion(): @@ -66,4 +75,6 @@ def test_float64_to_float32_conversion(): x_mps = to_device_safe(x_float64, device) assert x_mps.dtype == torch.float32, f"Expected float32, got {x_mps.dtype}" - assert x_mps.device.type == "mps", f"Tensor not on MPS device: {x_mps.device.type}" + assert ( + x_mps.device.type == "mps" + ), f"Tensor not on MPS device: {x_mps.device.type}" From b3f9556b5be0a8f8258a0ff22d01a9be5bc6bb92 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 21:19:27 +0200 Subject: [PATCH 06/32] remove variables which were never used --- flexynesis/__main__.py | 8 -------- flexynesis/feature_selection.py | 1 - flexynesis/generate_coexpression_network.py | 1 + flexynesis/models/crossmodal_pred.py | 1 - flexynesis/models/triplet_encoder.py | 3 +-- flexynesis/utils.py | 2 -- tests/test_mps_device.py | 2 -- 7 files changed, 2 insertions(+), 16 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 7f79f5a2..5588c558 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -1079,13 +1079,8 @@ def main(): # import data tracemalloc.start() - process = psutil.Process(os.getpid()) - t1 = time.time() train_dataset, test_dataset = data_importer.import_data() - data_import_time = time.time() - t1 - data_import_ram = process.memory_info().rss - # classical ML baselines if args.model_class == "XGBoost": try: @@ -1245,12 +1240,9 @@ def main(): ) # do a hyperparameter search training multiple models and get the best configuration - t1 = time.time() model, best_params = tuner.perform_tuning( hpo_patience=args.hpo_patience ) - hpo_time = time.time() - t1 - hpo_system_ram = process.memory_info().rss # if fine-tuning is enabled; fine tune the model on a portion of test samples if args.finetuning_samples > 0: diff --git a/flexynesis/feature_selection.py b/flexynesis/feature_selection.py index 3988d59b..f0496259 100644 --- a/flexynesis/feature_selection.py +++ b/flexynesis/feature_selection.py @@ -125,7 +125,6 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): # Topping up from redundant features if fewer than topN features are selected if topN is not None and len(selected_features) < topN: - shortfall = topN - len(selected_features) # Sort redundant features by their laplacian score, prioritizing lower scores sorted_redundant_indices = sorted( redundant_features.keys(), key=lambda x: laplacian_scores[x] diff --git a/flexynesis/generate_coexpression_network.py b/flexynesis/generate_coexpression_network.py index b9dfed8c..0fe3f6fb 100644 --- a/flexynesis/generate_coexpression_network.py +++ b/flexynesis/generate_coexpression_network.py @@ -179,6 +179,7 @@ def generate_coexpression_network( sep = "\t" if input_file.endswith(".tsv") else "," expr_df = pd.read_csv(input_file, sep=sep, index_col=0) except Exception as e: + print(f"[ERROR] Failed to load file: {e}") # Fallback: try the other separator try: sep = "," if sep == "\t" else "\t" diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index b5edd955..be4efca9 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -464,7 +464,6 @@ def decode(self, dataset): """ self.eval() x_list_input = [dataset.dat[x] for x in self.input_layers] - x_list_output = [dataset.dat[x] for x in self.output_layers] x_hat_list, z, mean, log_var, outputs = self.forward(x_list_input) X = {} for i in range(len(self.output_layers)): diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index 8fe1a407..9aa929de 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -533,11 +533,10 @@ def compute_feature_importance( aggregated_attributions = [[] for _ in range(num_class)] for batch in dataloader: # see training_step to see how elements are accessed in batches - anchor, positive, negative, y_dict = ( + anchor, positive, negative= ( batch[0], batch[1], batch[2], - batch[3], ) # Move tensors to the specified device using MPS-safe method diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 4c6d531a..6bccdc9e 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -2062,8 +2062,6 @@ def print_data_files(self): print(df.to_string(index=False)) def get_cbioportal_data(self, study_id, files=None): - archive_path = self.download_study_archive() - study_dir = self.extract_archive(archive_path) if files is None: self.print_data_files() diff --git a/tests/test_mps_device.py b/tests/test_mps_device.py index d9b5b801..b1c467cc 100644 --- a/tests/test_mps_device.py +++ b/tests/test_mps_device.py @@ -49,11 +49,9 @@ def test_mps_memory_allocation(): if device_str != "mps" or not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") - device = torch.device(device_str) # Test memory tracking memory_before = torch.mps.current_allocated_memory() - large_tensor = to_device_safe(torch.randn(1000, 1000), device) memory_after = torch.mps.current_allocated_memory() assert ( From c277a1b2599322a7ef33c1f20c60cba1829ec8b5 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 21:20:08 +0200 Subject: [PATCH 07/32] fix too many blank lines --- tests/test_mps_device.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_mps_device.py b/tests/test_mps_device.py index b1c467cc..23057298 100644 --- a/tests/test_mps_device.py +++ b/tests/test_mps_device.py @@ -49,8 +49,7 @@ def test_mps_memory_allocation(): if device_str != "mps" or not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") - - # Test memory tracking +# Test memory tracking memory_before = torch.mps.current_allocated_memory() memory_after = torch.mps.current_allocated_memory() From 685a6ed7e22aecd3857e5376476f24e361a1b59a Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 21:24:22 +0200 Subject: [PATCH 08/32] fix undefined modules --- flexynesis/data.py | 1 + flexynesis/utils.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/flexynesis/data.py b/flexynesis/data.py index 3a7bbaf4..e1e5e00a 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -1,4 +1,5 @@ import os +import tempfile from functools import reduce from itertools import chain from pathlib import Path diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 6bccdc9e..2985550f 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -79,6 +79,11 @@ from sksurv.metrics import concordance_index_censored from umap import UMAP +try: + from geomloss import SamplesLoss +except Exception: + SamplesLoss = None + def _labels_to_1d(labels): """Coerce labels (list/array/Series/Index/DataFrame[1-col]) to a 1-D numpy array (no index alignment).""" @@ -2147,7 +2152,6 @@ def compute_correlation_loss(embeddings, batch_labels): return loss -# from geomloss import SamplesLoss def compute_transport_cost(embeddings, batch_labels, blur=0.5): """ Compute a transport cost using Sinkhorn loss to align embeddings between batches. @@ -2172,6 +2176,11 @@ def compute_transport_cost(embeddings, batch_labels, blur=0.5): "Both batches must have at least one sample for transport cost computation." ) + if SamplesLoss is None: + raise ImportError( + "geomloss is required for compute_transport_cost. Install it with: pip install geomloss" + ) + # Initialize the Sinkhorn loss function loss_fn = SamplesLoss("sinkhorn", blur=blur) From d579746c7c06f6f703015978741ca218da97b8ae Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 21:24:40 +0200 Subject: [PATCH 09/32] run black again --- flexynesis/__main__.py | 153 +++--------- flexynesis/data.py | 247 +++++--------------- flexynesis/feature_selection.py | 20 +- flexynesis/generate_coexpression_network.py | 20 +- flexynesis/inference.py | 20 +- flexynesis/main.py | 49 ++-- flexynesis/models/crossmodal_pred.py | 36 +-- flexynesis/models/direct_pred.py | 35 +-- flexynesis/models/gnn_early.py | 40 +--- flexynesis/models/supervised_vae.py | 36 +-- flexynesis/models/triplet_encoder.py | 51 ++-- flexynesis/modules.py | 12 +- flexynesis/utils.py | 145 +++--------- tests/test_mps_device.py | 18 +- 14 files changed, 229 insertions(+), 653 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 5588c558..86379548 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -17,9 +17,7 @@ def print_test_installation(): ) print(" tar -xzvf dataset1.tgz") print() - print( - " # Test the installation (should finish within a minute on a typical CPU)" - ) + print(" # Test the installation (should finish within a minute on a typical CPU)") print( " flexynesis --data_path dataset1 --model_class DirectPred --target_variables Erlotinib --hpo_iter 1 --features_top_percentile 5 --data_types gex,cnv" ) @@ -41,9 +39,7 @@ def print_help(): print( " --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" ) - print( - " (Required) The kind of model class to instantiate" - ) + print(" (Required) The kind of model class to instantiate") print(" --data_types DATA_TYPES") print( " (Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'" @@ -89,13 +85,9 @@ def print_full_help(): " Use a saved .pth/.safetensors model for inference (skip training)" ) print(" --artifacts ARTIFACTS") - print( - " Path to training-time artifacts .joblib or .json" - ) + print(" Path to training-time artifacts .joblib or .json") print(" --data_path_test DATA_PATH_TEST") - print( - " Folder with test-only dataset for inference" - ) + print(" Folder with test-only dataset for inference") print(" --join_key JOIN_KEY Column name in 'clin.csv' for sample IDs") # --- existing flags (keep full list) --- @@ -107,9 +99,7 @@ def print_full_help(): print( " --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" ) - print( - " (Required) The kind of model class to instantiate" - ) + print(" (Required) The kind of model class to instantiate") print(" --gnn_conv_type {GC,GCN,SAGE}") print( " If model_class is set to GNN, choose which graph convolution type to use" @@ -184,9 +174,7 @@ def print_full_help(): print( " --outdir OUTDIR Path to the output folder to save the model outputs (default: current working directory)" ) - print( - " --prefix PREFIX Job prefix to use for output files (default: 'job')" - ) + print(" --prefix PREFIX Job prefix to use for output files (default: 'job')") print(" --log_transform {True,False}") print( " whether to apply log-transformation to input data matrices (default: False)" @@ -227,9 +215,7 @@ def print_full_help(): print( " --use_gpu (Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available." ) - print( - " --feature_importance_method {IntegratedGradients,GradientShap,Both}" - ) + print(" --feature_importance_method {IntegratedGradients,GradientShap,Both}") print( " Choose feature importance score method (default: IntegratedGradients)" ) @@ -245,9 +231,7 @@ def print_full_help(): print( " Path to user-provided gene-gene interaction network file." ) - print( - " Must have at least 3 columns: GeneA, GeneB, Score." - ) + print(" Must have at least 3 columns: GeneA, GeneB, Score.") print( " If provided, this will be used instead of STRING DB." ) @@ -805,9 +789,7 @@ def main(): config_path = model_base + "_config.json" if not os.path.exists(config_path): # Try alongside the safetensors file with standard naming - base = args.pretrained_model.replace( - ".final_model.safetensors", "" - ) + base = args.pretrained_model.replace(".final_model.safetensors", "") config_path = base + ".final_model_config.json" if not os.path.exists(config_path): raise FileNotFoundError( @@ -838,9 +820,7 @@ def main(): weights_only=True, ) except TypeError: - model = torch.load( - args.pretrained_model, map_location=device - ) + model = torch.load(args.pretrained_model, map_location=device) # Extract model class name for metrics args.model_class = model.__class__.__name__ @@ -860,9 +840,7 @@ def main(): # Convert to GNN dataset if needed if args.model_class == "GNN": - print( - "[INFO] Overlaying the dataset with network data from STRINGDB" - ) + print("[INFO] Overlaying the dataset with network data from STRINGDB") from .data import MultiOmicDatasetNW from .main import STRING @@ -870,14 +848,10 @@ def main(): string_organism = importer.artifacts.get( "string_organism", 9606 ) # default human - string_node_name = importer.artifacts.get( - "string_node_name", "HGNC" - ) + string_node_name = importer.artifacts.get("string_node_name", "HGNC") # Load STRING network obj = STRING( - os.path.join( - args.data_path_test, "_".join(["processed", args.prefix]) - ), + os.path.join(args.data_path_test, "_".join(["processed", args.prefix])), string_organism, string_node_name, ) @@ -893,9 +867,7 @@ def main(): # Move dataset to same device as model if hasattr(test_dataset, "to_device"): test_dataset.to_device(device) - print( - f"[INFO] Test dataset loaded: {len(test_dataset.samples)} samples" - ) + print(f"[INFO] Test dataset loaded: {len(test_dataset.samples)} samples") # Continue to evaluation section (skip training) # ------------- Heavy imports only when training ------------- @@ -979,9 +951,7 @@ def main(): memory_info = get_device_memory_info(device_str) print(f"[INFO] Device name: {memory_info['device_name']}") if device_str == "cuda": - print( - f"[INFO] Available CUDA devices: {memory_info['device_count']}" - ) + print(f"[INFO] Available CUDA devices: {memory_info['device_count']}") # gnn if args.model_class == "GNN": @@ -1020,9 +990,7 @@ def main(): f"Input --data_path doesn't exist at: {args.data_path}" ) if not os.path.exists(args.outdir): - raise FileNotFoundError( - f"Path to --outdir doesn't exist at: {args.outdir}" - ) + raise FileNotFoundError(f"Path to --outdir doesn't exist at: {args.outdir}") available_models = { "DirectPred": (DirectPred, "DirectPred"), @@ -1049,9 +1017,7 @@ def main(): # covariates if args.covariates: - if ( - args.model_class == "GNN" - ): # Covariates not yet supported for GNNs + if args.model_class == "GNN": # Covariates not yet supported for GNNs warnings.warn( "\n\n!!! Covariates are currently not supported for GNN models, they will be ignored. !!!\n" ) @@ -1102,9 +1068,7 @@ def main(): n_jobs=args.threads, ) metrics.to_csv( - os.path.join( - args.outdir, ".".join([args.prefix, "stats.csv"]) - ), + os.path.join(args.outdir, ".".join([args.prefix, "stats.csv"])), header=True, index=False, ) @@ -1116,9 +1080,7 @@ def main(): header=True, index=False, ) - print( - f"{args.model_class} evaluation complete. Results saved." - ) + print(f"{args.model_class} evaluation complete. Results saved.") sys.exit(0) else: raise ValueError( @@ -1139,9 +1101,7 @@ def main(): n_jobs=int(args.threads), ) metrics.to_csv( - os.path.join( - args.outdir, ".".join([args.prefix, "stats.csv"]) - ), + os.path.join(args.outdir, ".".join([args.prefix, "stats.csv"])), header=True, index=False, ) @@ -1153,9 +1113,7 @@ def main(): header=True, index=False, ) - print( - f"{args.model_class} evaluation complete. Results saved." - ) + print(f"{args.model_class} evaluation complete. Results saved.") sys.exit(0) else: raise ValueError( @@ -1165,24 +1123,16 @@ def main(): if args.model_class == "GNN" and train_dataset is not None: # Check if user provided a custom graph if args.user_graph is not None: - print( - f"[INFO] Using user-provided network from: {args.user_graph}" - ) + print(f"[INFO] Using user-provided network from: {args.user_graph}") from flexynesis.data import read_user_graph graph_df = read_user_graph(args.user_graph) - print( - f"[INFO] Loaded {len(graph_df)} interactions from user graph" - ) + print(f"[INFO] Loaded {len(graph_df)} interactions from user graph") else: # Fallback to STRING DB (default behavior) - print( - "[INFO] No user graph provided. Using STRING DB network (default)" - ) + print("[INFO] No user graph provided. Using STRING DB network (default)") obj = STRING( - os.path.join( - args.data_path, "_".join(["processed", args.prefix]) - ), + os.path.join(args.data_path, "_".join(["processed", args.prefix])), args.string_organism, args.string_node_name, ) @@ -1240,9 +1190,7 @@ def main(): ) # do a hyperparameter search training multiple models and get the best configuration - model, best_params = tuner.perform_tuning( - hpo_patience=args.hpo_patience - ) + model, best_params = tuner.perform_tuning(hpo_patience=args.hpo_patience) # if fine-tuning is enabled; fine tune the model on a portion of test samples if args.finetuning_samples > 0: @@ -1256,9 +1204,7 @@ def main(): all_indices = range(len(test_dataset)) import random as _random - finetune_indices = _random.sample( - list(all_indices), finetuneSampleN - ) + finetune_indices = _random.sample(list(all_indices), finetuneSampleN) holdout_indices = list(set(all_indices) - set(finetune_indices)) finetune_dataset = test_dataset.subset(finetune_indices) holdout_dataset = test_dataset.subset(holdout_indices) @@ -1277,16 +1223,12 @@ def main(): if train_dataset is not None: embeddings_train = model.transform(train_dataset) embeddings_train.to_csv( - os.path.join( - args.outdir, ".".join([args.prefix, "embeddings_train.csv"]) - ), + os.path.join(args.outdir, ".".join([args.prefix, "embeddings_train.csv"])), header=True, ) embeddings_test = model.transform(test_dataset) embeddings_test.to_csv( - os.path.join( - args.outdir, ".".join([args.prefix, "embeddings_test.csv"]) - ), + os.path.join(args.outdir, ".".join([args.prefix, "embeddings_test.csv"])), header=True, ) @@ -1316,10 +1258,7 @@ def main(): import pandas as pd df_imp = pd.concat( - [ - model.feature_importances[x] - for x in model.target_variables - ], + [model.feature_importances[x] for x in model.target_variables], ignore_index=True, ) df_imp["explainer"] = explainer @@ -1366,9 +1305,7 @@ def main(): args.model_class, ) predicted_labels.to_csv( - os.path.join( - args.outdir, ".".join([args.prefix, "predicted_labels.csv"]) - ), + os.path.join(args.outdir, ".".join([args.prefix, "predicted_labels.csv"])), header=True, index=False, ) @@ -1478,9 +1415,7 @@ def main(): if not args.safetensors: torch.save( model, - os.path.join( - args.outdir, ".".join([args.prefix, "final_model.pth"]) - ), + os.path.join(args.outdir, ".".join([args.prefix, "final_model.pth"])), ) else: save_file( @@ -1543,9 +1478,7 @@ def main(): "," ), # Original modalities from CLI before concatenation "target_variables": ( - args.target_variables.split(",") - if args.target_variables - else [] + args.target_variables.split(",") if args.target_variables else [] ), "feature_lists": ( data_importer.train_features @@ -1553,9 +1486,7 @@ def main(): else {} ), "transforms": ( - data_importer.scalers - if hasattr(data_importer, "scalers") - else {} + data_importer.scalers if hasattr(data_importer, "scalers") else {} ), "label_encoders": ( data_importer.label_encoders @@ -1596,9 +1527,7 @@ def main(): "string_node_name": artifacts["string_node_name"], "feature_lists": { modality: list(features) - for modality, features in artifacts[ - "feature_lists" - ].items() + for modality, features in artifacts["feature_lists"].items() }, "transforms": {}, "label_encoders": {}, @@ -1625,9 +1554,7 @@ def main(): if hasattr(scaler, "var_"): scaler_dict["var"] = scaler.var_.tolist() if hasattr(scaler, "n_features_in_"): - scaler_dict["n_features_in"] = int( - scaler.n_features_in_ - ) + scaler_dict["n_features_in"] = int(scaler.n_features_in_) if hasattr(scaler, "feature_names_in_"): scaler_dict["feature_names_in"] = ( scaler.feature_names_in_.tolist() @@ -1654,9 +1581,7 @@ def main(): elif isinstance(encoder, OrdinalEncoder): encoder_dict = { "type": "OrdinalEncoder", - "categories": [ - cat.tolist() for cat in encoder.categories_ - ], + "categories": [cat.tolist() for cat in encoder.categories_], "handle_unknown": encoder.handle_unknown, "unknown_value": encoder.unknown_value, } @@ -1668,9 +1593,7 @@ def main(): else val ) if hasattr(encoder, "n_features_in_"): - encoder_dict["n_features_in"] = int( - encoder.n_features_in_ - ) + encoder_dict["n_features_in"] = int(encoder.n_features_in_) if hasattr(encoder, "feature_names_in_"): encoder_dict["feature_names_in"] = ( encoder.feature_names_in_.tolist() diff --git a/flexynesis/data.py b/flexynesis/data.py index e1e5e00a..054a3ea4 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -147,9 +147,7 @@ def get_user_features(self): """ if self.restrict_to_features is not None: if not os.path.isfile(self.restrict_to_features): - raise FileNotFoundError( - f"File not found: {self.restrict_to_features}" - ) + raise FileNotFoundError(f"File not found: {self.restrict_to_features}") try: with open(self.restrict_to_features, "r") as fp: # Read and process the file @@ -183,19 +181,15 @@ def import_data(self): train_dat = self.subsample(train_dat, self.downsample) if self.restrict_to_features is not None: - train_dat = self.filter_by_features( - train_dat, self.restrict_to_features - ) - test_dat = self.filter_by_features( - test_dat, self.restrict_to_features - ) + train_dat = self.filter_by_features(train_dat, self.restrict_to_features) + test_dat = self.filter_by_features(test_dat, self.restrict_to_features) # check for any problems with the the input files self.validate_input_data(train_dat, test_dat) # cleanup uninformative features/samples, subset annotation data, do feature selection on training data - train_dat, train_ann, train_samples, train_features = ( - self.process_data(train_dat, split="train") + train_dat, train_ann, train_samples, train_features = self.process_data( + train_dat, split="train" ) test_dat, test_ann, test_samples, test_features = self.process_data( test_dat, split="test" @@ -212,12 +206,8 @@ def import_data(self): # Normalize the training data (for testing data, use normalisation factors # learned from training data to apply on test data (see fit = False) - train_dat = self.normalize_data( - train_dat, scaler_type="standard", fit=True - ) - test_dat = self.normalize_data( - test_dat, scaler_type="standard", fit=False - ) + train_dat = self.normalize_data(train_dat, scaler_type="standard", fit=True) + test_dat = self.normalize_data(test_dat, scaler_type="standard", fit=False) # if covariates are defined, create a covariate matrix and add to the dictionary of data matrices if self.covariates: @@ -235,12 +225,8 @@ def import_data(self): train_dat, test_dat = self.harmonize(train_dat, test_dat) # encode the variable annotations, convert data matrices and annotations pytorch datasets - training_dataset = self.get_torch_dataset( - train_dat, train_ann, train_samples - ) - testing_dataset = self.get_torch_dataset( - test_dat, test_ann, test_samples - ) + training_dataset = self.get_torch_dataset(train_dat, train_ann, train_samples) + testing_dataset = self.get_torch_dataset(test_dat, test_ann, test_samples) # for early fusion, concatenate all data matrices and feature lists if self.concatenate: @@ -253,9 +239,7 @@ def import_data(self): } training_dataset.features = { "all": list( - chain( - *[training_dataset.features[x] for x in modality_order] - ) + chain(*[training_dataset.features[x] for x in modality_order]) ) } @@ -266,9 +250,7 @@ def import_data(self): } testing_dataset.features = { "all": list( - chain( - *[testing_dataset.features[x] for x in modality_order] - ) + chain(*[testing_dataset.features[x] for x in modality_order]) ) } # Save final feature lists AFTER concatenation (for inference mode) @@ -342,11 +324,7 @@ def filter_by_features(self, dat, features): subset train/test data to only include those features """ dat_filtered = { - key: ( - df - if key == "clin" - else df.loc[df.index.intersection(features)] - ) + key: (df if key == "clin" else df.loc[df.index.intersection(features)]) for key, df in dat.items() } @@ -460,9 +438,7 @@ def cleanup_data(self, df_dict): for key in cleaned_dfs.keys(): original_samples_count = cleaned_dfs[key].shape[1] cleaned_dfs[key] = cleaned_dfs[key].loc[:, common_mask] - removed_samples_count = ( - original_samples_count - cleaned_dfs[key].shape[1] - ) + removed_samples_count = original_samples_count - cleaned_dfs[key].shape[1] print( f"[INFO] DataFrame {key} - Removed {removed_samples_count} samples ({removed_samples_count / original_samples_count * 100:.2f}%)." ) @@ -503,9 +479,7 @@ def select_features(self, dat): topN=counts[layer], correlation_threshold=self.correlation_threshold, ) - dat_filtered[layer] = ( - X_filt.T - ) # transpose after laplacian filtering again + dat_filtered[layer] = X_filt.T # transpose after laplacian filtering again # Features will be stored after concatenation feature_logs[layer] = log_df # update main feature logs with events from this function @@ -513,9 +487,7 @@ def select_features(self, dat): return dat_filtered def harmonize(self, dat1, dat2): - print( - "\n[INFO] ----------------- Harmonizing Data Sets ----------------- " - ) + print("\n[INFO] ----------------- Harmonizing Data Sets ----------------- ") # common data layers common_layers = dat1.keys() & dat2.keys() # Get common features @@ -525,9 +497,7 @@ def harmonize(self, dat1, dat2): # Subset both datasets to only include common features dat1 = {x: dat1[x].loc[common_features[x]] for x in common_layers} dat2 = {x: dat2[x].loc[common_features[x]] for x in common_layers} - print( - "\n[INFO] ----------------- Finished Harmonizing ----------------- " - ) + print("\n[INFO] ----------------- Finished Harmonizing ----------------- ") return dat1, dat2 @@ -542,17 +512,11 @@ def normalize_data(self, data, scaler_type="standard", fit=True): # while scaling methods assume features to be on the columns. if fit: if scaler_type == "standard": - self.scalers = { - x: StandardScaler().fit(data[x].T) for x in data.keys() - } + self.scalers = {x: StandardScaler().fit(data[x].T) for x in data.keys()} elif scaler_type == "min_max": - self.scalers = { - x: MinMaxScaler().fit(data[x].T) for x in data.keys() - } + self.scalers = {x: MinMaxScaler().fit(data[x].T) for x in data.keys()} else: - raise ValueError( - "Invalid scaler_type. Choose 'standard' or 'min_max'." - ) + raise ValueError("Invalid scaler_type. Choose 'standard' or 'min_max'.") normalized_data = { x: pd.DataFrame( @@ -567,9 +531,7 @@ def normalize_data(self, data, scaler_type="standard", fit=True): def get_torch_dataset(self, dat, ann, samples): features = {x: dat[x].index for x in dat.keys()} - dat = { - x: torch.from_numpy(np.array(dat[x].T)).float() for x in dat.keys() - } + dat = {x: torch.from_numpy(np.array(dat[x].T)).float() for x in dat.keys()} ann, variable_types, label_mappings = self.encode_labels(ann) @@ -603,23 +565,19 @@ def encode_column(series): # NEW: Store encoder for inference mode self.label_encoders[series.name] = self.encoders[series.name] else: - encoded_series = self.encoders[series.name].transform( - series.to_frame() - ) + encoded_series = self.encoders[series.name].transform(series.to_frame()) # also save label mappings label_mappings[series.name] = { int(code): label - for code, label in enumerate( - self.encoders[series.name].categories_[0] - ) + for code, label in enumerate(self.encoders[series.name].categories_[0]) } return encoded_series.ravel() # Select only the categorical columns - df_categorical = df.select_dtypes( - include=["object", "category"] - ).apply(encode_column) + df_categorical = df.select_dtypes(include=["object", "category"]).apply( + encode_column + ) # Combine the encoded categorical data with the numerical data df_encoded = pd.concat( @@ -632,9 +590,7 @@ def encode_column(series): variable_types.update( { col: "numerical" - for col in df.select_dtypes( - exclude=["object", "category"] - ).columns + for col in df.select_dtypes(exclude=["object", "category"]).columns } ) @@ -652,9 +608,7 @@ def check_rownames(dat, split): for file_name, df in dat.items(): if not df.index.is_unique: identifier_type = ( - "Sample labels" - if file_name == "clin" - else "Feature names" + "Sample labels" if file_name == "clin" else "Feature names" ) errors.append( f"Error in {split}/{file_name}.csv: {identifier_type} in the first column must be unique." @@ -681,9 +635,7 @@ def check_common_features(train_dat, test_dat): if file_name != "clin" and file_name in test_dat: train_features = set(train_dat[file_name].index) test_features = set(test_dat[file_name].index) - common_features = train_features.intersection( - test_features - ) + common_features = train_features.intersection(test_features) if not common_features: errors.append( f"Error: No common features found between train/{file_name}.csv and test/{file_name}.csv." @@ -707,9 +659,7 @@ def check_common_features(train_dat, test_dat): print("[INFO] Found problems with the input data:\n") for i, error in enumerate(errors, 1): print(f"[ERROR] {i}. {error}") - raise Exception( - "[ERROR] Please correct the above errors and try again." - ) + raise Exception("[ERROR] Please correct the above errors and try again.") if not warnings and not errors: print("[INFO] Data structure is valid with no errors or warnings.") @@ -788,9 +738,7 @@ def import_data(self): else: # Filter out 'covariates' from modalities_to_load # Covariates will be created from clinical data later - modalities_to_load = [ - m for m in self.modalities if m != "covariates" - ] + modalities_to_load = [m for m in self.modalities if m != "covariates"] # Load each modality (skip 'covariates' - it's in clin.csv) for modality in modalities_to_load: @@ -798,9 +746,7 @@ def import_data(self): continue # Covariates are in clin.csv, not a separate file file_path = os.path.join(self.test_data_path, f"{modality}.csv") if not os.path.exists(file_path): - raise FileNotFoundError( - f"[ERROR] Required file not found: {file_path}" - ) + raise FileNotFoundError(f"[ERROR] Required file not found: {file_path}") df = pd.read_csv(file_path, index_col=0) # Transpose if needed: data files have features as rows, samples as columns @@ -860,9 +806,7 @@ def import_data(self): covariate_vars = self.artifacts.get("covariate_vars", []) if covariate_vars: if self.verbose: - print( - f"[INFO] Creating covariate matrix for: {covariate_vars}" - ) + print(f"[INFO] Creating covariate matrix for: {covariate_vars}") variable_types_clin = get_variable_types(labels_df) covariates_df = create_covariate_matrix( covariate_vars, variable_types_clin, labels_df @@ -892,9 +836,7 @@ def import_data(self): df_reordered = pd.DataFrame( test_data[modality].numpy(), index=old_samples ).loc[common_samples] - test_data[modality] = torch.from_numpy( - df_reordered.values - ).float() + test_data[modality] = torch.from_numpy(df_reordered.values).float() samples = common_samples @@ -917,30 +859,23 @@ def import_data(self): else: # OrdinalEncoder if valid_mask.sum() > 0: encoded[valid_mask] = encoder.transform( - labels_df[col][valid_mask].values.reshape( - -1, 1 - ) + labels_df[col][valid_mask].values.reshape(-1, 1) ).ravel() ann_dict[col] = torch.from_numpy(encoded) variable_types[col] = "categorical" label_mappings[col] = { - int(c): l - for c, l in enumerate(encoder.categories_[0]) + int(c): l for c, l in enumerate(encoder.categories_[0]) } label_mappings[col][-1] = "Unknown" # For missing values else: - ann_dict[col] = torch.from_numpy( - labels_df[col].values - ).float() + ann_dict[col] = torch.from_numpy(labels_df[col].values).float() variable_types[col] = "numerical" # Create features dict # For early fusion, get features from scalers since feature_lists only has 'all' if self.modalities == ["all"]: - modalities_for_features = self.artifacts.get( - "original_modalities", [] - ) + modalities_for_features = self.artifacts.get("original_modalities", []) # Get features from scalers for each modality features = { modality: list(self.scalers[modality].feature_names_in_) @@ -948,8 +883,7 @@ def import_data(self): } else: features = { - modality: self.feature_names[modality] - for modality in self.modalities + modality: self.feature_names[modality] for modality in self.modalities } # CRITICAL: Reorder test_data dict to match self.modalities order (model expects specific order) @@ -978,28 +912,20 @@ def import_data(self): ) # Concatenate the data tensors - concatenated_data = torch.cat( - [test_data[x] for x in modality_order], dim=1 - ) + concatenated_data = torch.cat([test_data[x] for x in modality_order], dim=1) # Chain features in the same order - all_features = list( - chain(*[dataset.features[x] for x in modality_order]) - ) + all_features = list(chain(*[dataset.features[x] for x in modality_order])) # Filter to expected features from artifacts expected_all_features = self.feature_names["all"] feature_indices = [ - i - for i, f in enumerate(all_features) - if f in expected_all_features + i for i, f in enumerate(all_features) if f in expected_all_features ] # Update dataset with concatenated data dataset.dat = {"all": concatenated_data[:, feature_indices]} - dataset.features = { - "all": [all_features[i] for i in feature_indices] - } + dataset.features = {"all": [all_features[i] for i in feature_indices]} return dataset @@ -1093,17 +1019,14 @@ def get_feature_subset(self, feature_df): A pandas DataFrame that concatenates the data matrices for the specified features from all layers. """ # Convert the DataFrame to a dictionary - feature_dict = ( - feature_df.groupby("layer")["name"].apply(list).to_dict() - ) + feature_dict = feature_df.groupby("layer")["name"].apply(list).to_dict() dfs = [] for layer, features in feature_dict.items(): if layer in self.dat: # Create a dictionary to look up indices by feature name for each layer feature_index_dict = { - feature: i - for i, feature in enumerate(self.features[layer]) + feature: i for i, feature in enumerate(self.features[layer]) } # Get the indices for the requested features indices = [ @@ -1182,9 +1105,7 @@ def __getitem__(self, index): import random negative_label = random.choice(list(self.labels_set - set([label]))) - negative_index = np.random.choice( - self.label_to_indices[negative_label] - ) + negative_index = np.random.choice(self.label_to_indices[negative_label]) pos = self.dataset[positive_index][0] # positive example neg = self.dataset[negative_index][0] # negative example @@ -1216,9 +1137,7 @@ def __init__(self, multiomic_dataset, interaction_df, modality_order=None): self.multiomic_dataset = multiomic_dataset self.interaction_df = interaction_df self.modality_order = ( - modality_order - if modality_order - else sorted(multiomic_dataset.dat.keys()) + modality_order if modality_order else sorted(multiomic_dataset.dat.keys()) ) # Compute union of features in the data matrices that also appear in the network @@ -1244,10 +1163,7 @@ def __init__(self, multiomic_dataset, interaction_df, modality_order=None): def find_union_features(self): # Find the union of all features in the multiomic dataset all_omic_features = set().union( - *( - set(features) - for features in self.multiomic_dataset.features.values() - ) + *(set(features) for features in self.multiomic_dataset.features.values()) ) # Find the union of proteins involved in interactions interaction_genes = set(self.interaction_df["protein1"]).union( @@ -1299,9 +1215,7 @@ def precompute_node_features(self): ) # Fill in the available data - all_features[:, feature_positions, i] = data_matrix[ - :, valid_indices - ] + all_features[:, feature_positions, i] = data_matrix[:, valid_indices] # Precompute medians for all data types, ignoring NaN values medians = torch.nanmedian( @@ -1324,8 +1238,7 @@ def subset(self, indices): def __getitem__(self, idx): node_features_tensor = self.node_features_tensor[idx] y_dict = { - target_name: self.labels[target_name][idx] - for target_name in self.labels + target_name: self.labels[target_name][idx] for target_name in self.labels } return node_features_tensor, y_dict, self.samples[idx] @@ -1342,9 +1255,7 @@ def print_stats(self): # Calculate degree for each node degrees = torch.zeros(num_nodes, dtype=torch.long) - degrees.index_add_( - 0, self.edge_index[0], torch.ones_like(self.edge_index[0]) - ) + degrees.index_add_(0, self.edge_index[0], torch.ones_like(self.edge_index[0])) degrees.index_add_( 0, self.edge_index[1], torch.ones_like(self.edge_index[1]) ) # For undirected graphs @@ -1353,9 +1264,7 @@ def print_stats(self): non_singletons = degrees[degrees > 0] mean_edges_per_node = ( - non_singletons.float().mean().item() - if len(non_singletons) > 0 - else 0 + non_singletons.float().mean().item() if len(non_singletons) > 0 else 0 ) median_edges_per_node = ( non_singletons.median().item() if len(non_singletons) > 0 else 0 @@ -1465,10 +1374,7 @@ def get(self, idx: int): @property def raw_file_names(self) -> list[str]: - return [ - f"{self.organism}.protein.{f}.v{self.version}.txt" - for f in self.files - ] + return [f"{self.organism}.protein.{f}.v{self.version}.txt" for f in self.files] @property def processed_file_names(self) -> str: @@ -1482,9 +1388,7 @@ def _ensure_raw_in_cache(self) -> None: dest = self._cache_raw_dir / fname if dest.exists(): continue - url = self.url.format( - organism=self.organism, data=d, version=self.version - ) + url = self.url.format(organism=self.organism, data=d, version=self.version) gz_path = download_url(url, str(self._cache_raw_dir)) extract_gz(gz_path, str(self._cache_raw_dir)) os.unlink(gz_path) @@ -1546,15 +1450,11 @@ def read_user_graph(fpath, sep=None, header="infer", **pd_read_csv_kw): sniffer = csv.Sniffer() dialect = sniffer.sniff(sample, delimiters="\t,| ") sep = dialect.delimiter - print( - f"[INFO] Auto-detected separator using CSV Sniffer: {repr(sep)}" - ) + print(f"[INFO] Auto-detected separator using CSV Sniffer: {repr(sep)}") except csv.Error: # Fallback to tab if Sniffer fails sep = "\t" - print( - f"[INFO] CSV Sniffer failed, using default separator: {repr(sep)}" - ) + print(f"[INFO] CSV Sniffer failed, using default separator: {repr(sep)}") # Read the file df = pd.read_csv(fpath, sep=sep, header=header, **pd_read_csv_kw) @@ -1704,9 +1604,7 @@ def score_column_match(col, col_idx, category, total_cols): # Convert score to numeric if not already if not pd.api.types.is_numeric_dtype(result_df["combined_score"]): try: - result_df["combined_score"] = pd.to_numeric( - result_df["combined_score"] - ) + result_df["combined_score"] = pd.to_numeric(result_df["combined_score"]) except ValueError: raise ValueError( f"Score column must contain numeric values. " @@ -1721,9 +1619,7 @@ def score_column_match(col, col_idx, category, total_cols): f"[WARNING] Removed {original_len - len(result_df)} rows with missing values" ) - print( - f"[INFO] Successfully loaded user graph: {len(result_df)} interactions" - ) + print(f"[INFO] Successfully loaded user graph: {len(result_df)} interactions") print( f"[INFO] Score range: [{result_df['combined_score'].min():.4f}, {result_df['combined_score'].max():.4f}]" ) @@ -1764,19 +1660,13 @@ def read_stringdb_links(fname, top_neighbors=5): ] ) # Sort the expanded DataFrame by 'combined_score' in descending order - df_expanded_sorted = df_expanded.sort_values( - by="combined_score", ascending=False - ) + df_expanded_sorted = df_expanded.sort_values(by="combined_score", ascending=False) # Reduce to unique interactions to avoid counting duplicates df_expanded_unique = df_expanded_sorted.drop_duplicates( subset=["protein", "partner"] ) - top_interactions = df_expanded_unique.groupby("protein").head( - top_neighbors - ) - df = top_interactions.rename( - columns={"protein": "protein1", "partner": "protein2"} - ) + top_interactions = df_expanded_unique.groupby("protein").head(top_neighbors) + df = top_interactions.rename(columns={"protein": "protein1", "partner": "protein2"}) df[["protein1", "protein2"]] = df[["protein1", "protein2"]].map( lambda a: a.split(".")[-1] ) @@ -1799,10 +1689,7 @@ def read_stringdb_aliases(fname: str, node_name: str) -> dict[str, str]: protein_id_to_gene_id[data[0].split(".")[1]] = data[1] elif data[-1].endswith(source[1]): # TODO: Check here if the values are the same - if ( - protein_id_to_gene_id.get(data[0].split(".")[1], None) - is None - ): + if protein_id_to_gene_id.get(data[0].split(".")[1], None) is None: protein_id_to_gene_id[data[0].split(".")[1]] = data[1] else: continue @@ -1818,9 +1705,7 @@ def read_stringdb_graph(node_name, edges_data_path, nodes_data_path): if node_name in ("gene_name", "gene_id"): node_name_mapping = read_stringdb_aliases(nodes_data_path, node_name) else: - raise NotImplementedError( - "Node name must be either 'gene_name' or 'gene_id'." - ) + raise NotImplementedError("Node name must be either 'gene_name' or 'gene_id'.") def fn(a): try: @@ -1830,9 +1715,7 @@ def fn(a): out = pd.NA return out - graph_df[["protein1", "protein2"]] = graph_df[ - ["protein1", "protein2"] - ].map(fn) + graph_df[["protein1", "protein2"]] = graph_df[["protein1", "protein2"]].map(fn) return graph_df @@ -1848,9 +1731,7 @@ def split_by_median(tensor_dict): if tensor.dtype in {torch.float16, torch.float32, torch.float64}: # Remove NaNs and compute median tensor_no_nan = tensor[torch.isfinite(tensor)] - median_val = tensor_no_nan.sort().values[ - tensor_no_nan.numel() // 2 - ] + median_val = tensor_no_nan.sort().values[tensor_no_nan.numel() // 2] # Convert to categorical, but preserve NaNs tensor_cat = (tensor > median_val).float() diff --git a/flexynesis/feature_selection.py b/flexynesis/feature_selection.py index f0496259..2b65cb1b 100644 --- a/flexynesis/feature_selection.py +++ b/flexynesis/feature_selection.py @@ -150,9 +150,7 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): "correlated_with": X.columns[ redundant_features[idx]["correlated_with"] ], - "correlation_score": redundant_features[idx][ - "correlation_score" - ], + "correlation_score": redundant_features[idx]["correlation_score"], } for idx in redundant_features ] @@ -162,9 +160,7 @@ def remove_redundant_features(X, laplacian_scores, threshold, topN=None): return X.columns[selected_features], pd.DataFrame() -def filter_by_laplacian( - X, layer, k=5, t=None, topN=100, correlation_threshold=0.9 -): +def filter_by_laplacian(X, layer, k=5, t=None, topN=100, correlation_threshold=0.9): """ Filters features in a dataset based on Laplacian scores and removes highly correlated features, retaining only the top N features with the lowest scores and optionally considering correlation. @@ -224,9 +220,7 @@ def filter_by_laplacian( " samples ", ) - feature_log = pd.DataFrame( - {"feature": X.columns, "laplacian_score": np.nan} - ) + feature_log = pd.DataFrame({"feature": X.columns, "laplacian_score": np.nan}) # only apply filtering if topN < n_features if topN >= X.shape[1]: print( @@ -238,9 +232,7 @@ def filter_by_laplacian( # compute laplacian scores scores = laplacian_score(X.values, k, t) - feature_log = pd.DataFrame( - {"feature": X.columns, "laplacian_score": scores} - ) + feature_log = pd.DataFrame({"feature": X.columns, "laplacian_score": scores}) # Sort the features based on their Laplacian Scores sorted_indices = np.argsort(scores) @@ -277,8 +269,6 @@ def filter_by_laplacian( X_selected = X[selected_features] feature_log["selected"] = False - feature_log.loc[ - feature_log["feature"].isin(selected_features), "selected" - ] = True + feature_log.loc[feature_log["feature"].isin(selected_features), "selected"] = True return X_selected, feature_log diff --git a/flexynesis/generate_coexpression_network.py b/flexynesis/generate_coexpression_network.py index 0fe3f6fb..285399ef 100644 --- a/flexynesis/generate_coexpression_network.py +++ b/flexynesis/generate_coexpression_network.py @@ -79,9 +79,7 @@ def build_network( pbar.update(end_i - i) data = ranks elif method != "pearson": - raise ValueError( - f"Unknown method: {method}. Use 'spearman' or 'pearson'" - ) + raise ValueError(f"Unknown method: {method}. Use 'spearman' or 'pearson'") # Standardize for correlation computation print("Standardizing data...") @@ -188,17 +186,13 @@ def generate_coexpression_network( print(f"[ERROR] Failed to load file: {e2}") sys.exit(1) - print( - f" Expression matrix: {expr_df.shape[0]} genes × {expr_df.shape[1]} samples" - ) + print(f" Expression matrix: {expr_df.shape[0]} genes × {expr_df.shape[1]} samples") # Check for missing values na_count = expr_df.isna().sum().sum() if na_count > 0: genes_with_na = expr_df.isna().any(axis=1).sum() - print( - f" [WARNING] Found {na_count} missing values in {genes_with_na} genes" - ) + print(f" [WARNING] Found {na_count} missing values in {genes_with_na} genes") print(f" [INFO] Removing genes with missing data.") expr_df = expr_df.dropna() print( @@ -215,9 +209,7 @@ def generate_coexpression_network( network_df = pd.DataFrame(edges) if len(network_df) == 0: - print( - "[WARNING] No edges found! Try lowering min_correlation threshold." - ) + print("[WARNING] No edges found! Try lowering min_correlation threshold.") print("[ERROR] No edges in network! Exiting.") sys.exit(1) @@ -228,9 +220,7 @@ def generate_coexpression_network( lambda row: tuple(sorted([row["GeneA"], row["GeneB"]])), axis=1 ) original_len = len(network_df) - network_df = network_df.drop_duplicates(subset="pair").drop( - columns="pair" - ) + network_df = network_df.drop_duplicates(subset="pair").drop(columns="pair") print(f" Removed {original_len - len(network_df)} duplicate edges") # Save network (auto-detect format from extension) diff --git a/flexynesis/inference.py b/flexynesis/inference.py index d5d570e8..6245cc51 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -102,9 +102,7 @@ def _build_dataset_namespace(config, artifacts): cats = ( enc.categories_[0].tolist() if hasattr(enc, "categories_") - else ( - enc.get("categories", [[]])[0] if isinstance(enc, dict) else [] - ) + else (enc.get("categories", [[]])[0] if isinstance(enc, dict) else []) ) if cats: ann[var] = cats @@ -128,15 +126,11 @@ def _resolve_input_dims(config, artifacts): """Ensure input_dims is present in config, deriving from feature_lists if needed.""" feature_lists = artifacts.get("feature_lists", {}) layers = ( - config.get("input_layers") - or config.get("layers") - or list(feature_lists.keys()) + config.get("input_layers") or config.get("layers") or list(feature_lists.keys()) ) input_dims = config.get("input_dims") if not input_dims: - input_dims = [ - len(feature_lists[l]) for l in layers if l in feature_lists - ] + input_dims = [len(feature_lists[l]) for l in layers if l in feature_lists] config["input_dims"] = input_dims return config @@ -242,9 +236,7 @@ def _deserialize_json_artifacts(artifacts): encoder_type = encoder_dict.get("type") if encoder_type == "LabelEncoder": enc = LabelEncoder() - enc.classes_ = np.array( - encoder_dict.get("classes", []), dtype=object - ) + enc.classes_ = np.array(encoder_dict.get("classes", []), dtype=object) label_encoders[variable] = enc continue @@ -315,9 +307,7 @@ def _load_artifacts(artifacts_path): return raw -def reconstruct_model( - safetensors_path, config_path, artifacts_path, device="cpu" -): +def reconstruct_model(safetensors_path, config_path, artifacts_path, device="cpu"): """ Reconstruct a full Flexynesis model from: - safetensors_path : .safetensors weights file diff --git a/flexynesis/main.py b/flexynesis/main.py index be37e315..eeb701ff 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -96,7 +96,9 @@ def __init__( num_workers=2, ): self.dataset = dataset # dataset for model initiation - self.loader_dataset = dataset # dataset for defining data loaders (this can be model specific) + self.loader_dataset = ( + dataset # dataset for defining data loaders (this can be model specific) + ) self.model_class = model_class self.target_variables = target_variables self.device_type = device_type @@ -110,7 +112,9 @@ def __init__( self.batch_variables = batch_variables self.config_name = config_name self.n_iter = n_iter - self.plot_losses = plot_losses # Whether to show live loss plots (useful in interactive mode) + self.plot_losses = ( + plot_losses # Whether to show live loss plots (useful in interactive mode) + ) self.val_size = val_size self.use_cv = use_cv self.n_splits = cv_splits @@ -180,14 +184,10 @@ def get_batch_space(self, min_size=32, max_size=128): end = int(np.log2(max_size)) if m < end: end = m - s = Categorical( - [np.power(2, x) for x in range(st, end + 1)], name="batch_size" - ) + s = Categorical([np.power(2, x) for x in range(st, end + 1)], name="batch_size") return s - def setup_trainer( - self, params, current_step, total_steps, full_train=False - ): + def setup_trainer(self, params, current_step, total_steps, full_train=False): # Configure callbacks and trainer for the current fold mycallbacks = [] if self.plot_losses: @@ -282,12 +282,8 @@ def objective(self, params, current_step, total_steps, full_train=False): print( f"[INFO] {'training cross-validation fold' if self.use_cv else 'training validation split'} {i}" ) - train_subset = torch.utils.data.Subset( - self.loader_dataset, train_index - ) - val_subset = torch.utils.data.Subset( - self.loader_dataset, val_index - ) + train_subset = torch.utils.data.Subset(self.loader_dataset, train_index) + val_subset = torch.utils.data.Subset(self.loader_dataset, val_index) train_loader = self.DataLoader( train_subset, batch_size=int(params["batch_size"]), @@ -322,9 +318,7 @@ def objective(self, params, current_step, total_steps, full_train=False): epochs.append(early_stop_callback.stopped_epoch) else: epochs.append(int(params["epochs"])) - validation_result = trainer.validate( - model, dataloaders=val_loader - ) + validation_result = trainer.validate(model, dataloaders=val_loader) val_loss = validation_result[0]["val_loss"] validation_losses.append(val_loss) i += 1 @@ -376,13 +370,9 @@ def perform_tuning(self, hpo_patience=0): best_params = suggested_params_list best_epochs = avg_epochs best_model = model - no_improvement_count = ( - 0 # Reset the no improvement counter - ) + no_improvement_count = 0 # Reset the no improvement counter else: - no_improvement_count += ( - 1 # Increment the no improvement counter - ) + no_improvement_count += 1 # Increment the no improvement counter # Print result of each iteration pbar.set_postfix({"Iteration": i + 1, "Best Loss": best_loss}) @@ -395,10 +385,7 @@ def perform_tuning(self, hpo_patience=0): ) break # Break out of the loop best_params_dict = ( - { - param.name: value - for param, value in zip(self.space, best_params) - } + {param.name: value for param, value in zip(self.space, best_params)} if best_params else None ) @@ -559,9 +546,7 @@ def train_dataloader(self): # Override to load data for the current fold train_idx, val_idx = self.folds_data[self.current_fold] train_subset = torch.utils.data.Subset(self.dataset, train_idx) - return DataLoader( - train_subset, batch_size=self.batch_size, shuffle=True - ) + return DataLoader(train_subset, batch_size=self.batch_size, shuffle=True) def val_dataloader(self): # Override to load validation data for the current fold @@ -645,9 +630,7 @@ def run_experiments(self): ) # Find the best configuration based on validation loss - best_config = min( - val_loss_results, key=lambda x: x["average_val_loss"] - ) + best_config = min(val_loss_results, key=lambda x: x["average_val_loss"]) print( f"Best learning rate: {best_config['learning_rate']} and freeze {best_config['freeze']}", f"with average validation loss: {best_config['average_val_loss']} and average epochs: {best_config['epochs']}", diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index be4efca9..d4959582 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -46,9 +46,7 @@ def __init__( # both surv event and time variables are assumed to be numerical variables # we create only one survival variable for the pair (surv_time_var and surv_event_var) if self.surv_event_var is not None and self.surv_time_var is not None: - self.target_variables = self.target_variables + [ - self.surv_event_var - ] + self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables self.variables = ( self.target_variables + self.batch_variables @@ -58,9 +56,7 @@ def __init__( self.variable_types = dataset.variable_types self.ann = dataset.ann - self.input_layers = ( - input_layers if input_layers else list(dataset.dat.keys()) - ) + self.input_layers = input_layers if input_layers else list(dataset.dat.keys()) self.output_layers = ( output_layers if output_layers else list(dataset.dat.keys()) ) @@ -178,9 +174,7 @@ def forward(self, x_list_input): z = self.reparameterization(mean, log_var) # decode the latent space to target output layer(s) - x_hat_list = [ - self.decoders[i](z) for i in range(len(self.output_layers)) - ] + x_hat_list = [self.decoders[i](z) for i in range(len(self.output_layers))] # run the supervisor heads using the latent layer as input outputs = {} @@ -257,9 +251,7 @@ def compute_loss(self, var, y, y_hat): y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) else: - loss = torch.tensor( - 0.0, device=y_hat.device, requires_grad=True - ) + loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss def compute_total_loss(self, losses): @@ -534,9 +526,7 @@ def MMD_loss(self, latent_dim, z, xhat, x): # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): - input_data = list( - args[:-2] - ) # one or more tensors (one per omics layer) + input_data = list(args[:-2]) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] @@ -616,18 +606,12 @@ def compute_feature_importance( for batch in dataloader: dat, _, _ = batch - x_list = [ - to_device_safe(dat[x], device) for x in self.input_layers - ] - input_data = tuple( - [data.unsqueeze(0).requires_grad_() for data in x_list] - ) + x_list = [to_device_safe(dat[x], device) for x in self.input_layers] + input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif ( - method == "GradientShap" - ): # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = tuple( torch.cat( [torch.zeros_like(x) for _ in range(steps_or_samples)], @@ -690,9 +674,7 @@ def compute_feature_importance( # Process each layer within the class for layer_idx in range(num_layers): # Extract all batch tensors for this layer across all batches for the current class - layer_tensors = [ - batch_attr[layer_idx] for batch_attr in class_attr - ] + layer_tensors = [batch_attr[layer_idx] for batch_attr in class_attr] # Concatenate tensors along the batch dimension attr_concat = torch.cat(layer_tensors, dim=1) layer_attributions.append(attr_concat) diff --git a/flexynesis/models/direct_pred.py b/flexynesis/models/direct_pred.py index b649bec7..3ee56071 100644 --- a/flexynesis/models/direct_pred.py +++ b/flexynesis/models/direct_pred.py @@ -45,9 +45,7 @@ def __init__( # both surv event and time variables are assumed to be numerical variables # we create only one survival variable for the pair (surv_time_var and surv_event_var) if self.surv_event_var is not None and self.surv_time_var is not None: - self.target_variables = self.target_variables + [ - self.surv_event_var - ] + self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables self.variables = ( self.target_variables + batch_variables @@ -68,8 +66,7 @@ def __init__( self.ann = dataset.ann self.layers = list(dataset.dat.keys()) self.input_dims = [ - len(dataset.features[self.layers[i]]) - for i in range(len(self.layers)) + len(dataset.features[self.layers[i]]) for i in range(len(self.layers)) ] self.encoders = nn.ModuleList( @@ -187,9 +184,7 @@ def compute_loss(self, var, y, y_hat): y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) else: - loss = torch.tensor( - 0.0, device=y_hat.device, requires_grad=True - ) + loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss def compute_total_loss(self, losses): @@ -337,9 +332,7 @@ def predict(self, dataset): # Collect predictions for each variable for var in self.variables: - logits = ( - outputs[var].detach().cpu() - ) # Raw model outputs (logits) + logits = outputs[var].detach().cpu() # Raw model outputs (logits) if dataset.variable_types[var] == "categorical": probs = torch.softmax( @@ -379,9 +372,7 @@ def transform(self, dataset): dataset, batch_size=64, shuffle=False ) # Adjust the batch size as needed - embeddings_list = ( - [] - ) # Initialize a list to collect all batch embeddings + embeddings_list = [] # Initialize a list to collect all batch embeddings sample_names = [] # List to collect sample names # Process each batch @@ -423,9 +414,7 @@ def transform(self, dataset): # Adaptor forward function for captum integrated gradients or gradient shap def forward_target(self, *args): - input_data = list( - args[:-2] - ) # one or more tensors (one per omics layer) + input_data = list(args[:-2]) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest steps = args[ -1 @@ -500,15 +489,11 @@ def compute_feature_importance( for batch in dataloader: dat, _, _ = batch x_list = [to_device_safe(dat[x], device) for x in dat.keys()] - input_data = tuple( - [data.unsqueeze(0).requires_grad_() for data in x_list] - ) + input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif ( - method == "GradientShap" - ): # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = tuple( torch.cat( [torch.zeros_like(x) for _ in range(steps_or_samples)], @@ -567,9 +552,7 @@ def compute_feature_importance( class_attr = aggregated_attributions[class_idx] layer_attributions = [] for layer_idx in range(num_layers): - layer_tensors = [ - batch_attr[layer_idx] for batch_attr in class_attr - ] + layer_tensors = [batch_attr[layer_idx] for batch_attr in class_attr] attr_concat = torch.cat(layer_tensors, dim=1) layer_attributions.append(attr_concat) processed_attributions.append(layer_attributions) diff --git a/flexynesis/models/gnn_early.py b/flexynesis/models/gnn_early.py index 87ec49d2..62484b04 100644 --- a/flexynesis/models/gnn_early.py +++ b/flexynesis/models/gnn_early.py @@ -70,9 +70,7 @@ def __init__( # both surv event and time variables are assumed to be numerical variables # we create only one survival variable for the pair (surv_time_var and surv_event_var) if self.surv_event_var is not None and self.surv_time_var is not None: - self.target_variables = self.target_variables + [ - self.surv_event_var - ] + self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables self.variables = ( self.target_variables + self.batch_variables @@ -285,9 +283,7 @@ def compute_loss(self, var, y, y_hat): y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) else: - loss = torch.tensor( - 0.0, device=y_hat.device, requires_grad=True - ) + loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss def compute_total_loss(self, losses): @@ -359,9 +355,7 @@ def predict(self, dataset): outputs = self.forward(x, edge_index) for var in self.variables: - logits = ( - outputs[var].detach().cpu() - ) # Raw model outputs (logits) + logits = outputs[var].detach().cpu() # Raw model outputs (logits) if dataset.variable_types[var] == "categorical": probs = torch.softmax( @@ -428,9 +422,7 @@ def transform(self, dataset): # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): - input_data = list( - args[:-2] - ) # expect a single tensor (early integration) + input_data = list(args[:-2]) # expect a single tensor (early integration) target_var = args[-2] # target variable of interest steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] @@ -527,14 +519,9 @@ def bytes_to_gb(bytes): if method == "IntegratedGradients": baseline = torch.zeros_like(input_data) - elif ( - method == "GradientShap" - ): # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = torch.cat( - [ - torch.zeros_like(input_data) - for _ in range(steps_or_samples) - ], + [torch.zeros_like(input_data) for _ in range(steps_or_samples)], dim=0, ) @@ -586,15 +573,12 @@ def bytes_to_gb(bytes): for class_idx in range(len(aggregated_attributions)): class_attr = aggregated_attributions[class_idx] # Concatenate tensors along the batch dimension - attr_concat = torch.cat( - [batch_attr for batch_attr in class_attr], dim=1 - ) + attr_concat = torch.cat([batch_attr for batch_attr in class_attr], dim=1) processed_attributions.append(attr_concat) # compute absolute importance and move to cpu abs_attr = [ - torch.abs(attr_class).cpu() - for attr_class in processed_attributions + torch.abs(attr_class).cpu() for attr_class in processed_attributions ] # average over samples imp = [a.mean(dim=1) for a in abs_attr] @@ -612,9 +596,7 @@ def bytes_to_gb(bytes): # MPS and CPU don't have detailed memory tracking like CUDA df_list = [] - layers = list( - getattr(dataset, "multiomic_dataset", dataset).dat.keys() - ) + layers = list(getattr(dataset, "multiomic_dataset", dataset).dat.keys()) for i in range(num_class): features = dataset.common_features target_class_label = ( @@ -626,7 +608,9 @@ def bytes_to_gb(bytes): # Extracting node feature attributes coming from different omic layers importances_array = imp[i].squeeze().detach().numpy() if importances_array.ndim == 1: - importances = importances_array # Use the array as is if it is 1-dimensional + importances = ( + importances_array # Use the array as is if it is 1-dimensional + ) else: importances = importances_array[ :, l diff --git a/flexynesis/models/supervised_vae.py b/flexynesis/models/supervised_vae.py index bddaaacc..7ba5ac7f 100644 --- a/flexynesis/models/supervised_vae.py +++ b/flexynesis/models/supervised_vae.py @@ -58,9 +58,7 @@ def __init__( # both surv event and time variables are assumed to be numerical variables # we create only one survival variable for the pair (surv_time_var and surv_event_var) if self.surv_event_var is not None and self.surv_time_var is not None: - self.target_variables = self.target_variables + [ - self.surv_event_var - ] + self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables self.variables = ( self.target_variables + batch_variables @@ -83,9 +81,7 @@ def __init__( self.log_vars[loss_type] = nn.Parameter(torch.zeros(1)) layers = list(dataset.dat.keys()) - input_dims = [ - len(dataset.features[layers[i]]) for i in range(len(layers)) - ] + input_dims = [len(dataset.features[layers[i]]) for i in range(len(layers))] # create a list of Encoder instances for separately encoding each omics layer self.encoders = nn.ModuleList( [ @@ -255,9 +251,7 @@ def compute_loss(self, var, y, y_hat): y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) else: - loss = torch.tensor( - 0.0, device=y_hat.device, requires_grad=True - ) + loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss def compute_total_loss(self, losses): @@ -431,9 +425,7 @@ def transform(self, dataset): sample_names.extend(samples) # Collect sample names for this batch # Concatenate all batch latent representations into one array - concatenated_latents = np.concatenate( - all_latent_representations, axis=0 - ) + concatenated_latents = np.concatenate(all_latent_representations, axis=0) # Convert the array to a DataFrame z = pd.DataFrame(concatenated_latents) @@ -482,9 +474,7 @@ def predict(self, dataset): # Collect predictions for each variable for var in self.variables: - logits = ( - outputs[var].detach().cpu() - ) # Raw model outputs (logits) + logits = outputs[var].detach().cpu() # Raw model outputs (logits) if dataset.variable_types[var] == "categorical": probs = torch.softmax( @@ -560,9 +550,7 @@ def MMD_loss(self, latent_dim, z, xhat, x): # Adaptor forward function for captum integrated gradients. def forward_target(self, *args): - input_data = list( - args[:-2] - ) # one or more tensors (one per omics layer) + input_data = list(args[:-2]) # one or more tensors (one per omics layer) target_var = args[-2] # target variable of interest steps = args[-1] # number of steps for IntegratedGradients().attribute outputs_list = [] @@ -642,15 +630,11 @@ def compute_feature_importance( for batch in dataloader: dat, _, _ = batch x_list = [to_device_safe(dat[x], device) for x in dat.keys()] - input_data = tuple( - [data.unsqueeze(0).requires_grad_() for data in x_list] - ) + input_data = tuple([data.unsqueeze(0).requires_grad_() for data in x_list]) if method == "IntegratedGradients": baseline = tuple(torch.zeros_like(x) for x in input_data) - elif ( - method == "GradientShap" - ): # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = tuple( torch.cat( [torch.zeros_like(x) for _ in range(steps_or_samples)], @@ -712,9 +696,7 @@ def compute_feature_importance( # Process each layer within the class for layer_idx in range(num_layers): # Extract all batch tensors for this layer across all batches for the current class - layer_tensors = [ - batch_attr[layer_idx] for batch_attr in class_attr - ] + layer_tensors = [batch_attr[layer_idx] for batch_attr in class_attr] # Concatenate tensors along the batch dimension attr_concat = torch.cat(layer_tensors, dim=1) layer_attributions.append(attr_concat) diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index 9aa929de..064696df 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -52,9 +52,7 @@ def __init__( # both surv event and time variables are assumed to be numerical variables # we create only one survival variable for the pair (surv_time_var and surv_event_var) if self.surv_event_var is not None and self.surv_time_var is not None: - self.target_variables = self.target_variables + [ - self.surv_event_var - ] + self.target_variables = self.target_variables + [self.surv_event_var] self.batch_variables = batch_variables self.variables = ( self.target_variables + batch_variables @@ -85,8 +83,7 @@ def __init__( self.layers = list(dataset.dat.keys()) self.input_dims = [ - len(dataset.features[self.layers[i]]) - for i in range(len(self.layers)) + len(dataset.features[self.layers[i]]) for i in range(len(self.layers)) ] self.encoders = nn.ModuleList( @@ -239,9 +236,7 @@ def compute_loss(self, var, y, y_hat): y = y[valid_indices] loss = F.cross_entropy(y_hat, y.long()) else: - loss = torch.tensor( - 0.0, device=y_hat.device, requires_grad=True - ) + loss = torch.tensor(0.0, device=y_hat.device, requires_grad=True) return loss def compute_total_loss(self, losses): @@ -511,9 +506,7 @@ def compute_feature_importance( # define data loader triplet_dataset = TripletMultiOmicDataset(dataset, self.main_var) - dataloader = DataLoader( - triplet_dataset, batch_size=batch_size, shuffle=False - ) + dataloader = DataLoader(triplet_dataset, batch_size=batch_size, shuffle=False) # Choose the attribution method dynamically if method == "IntegratedGradients": @@ -533,7 +526,7 @@ def compute_feature_importance( aggregated_attributions = [[] for _ in range(num_class)] for batch in dataloader: # see training_step to see how elements are accessed in batches - anchor, positive, negative= ( + anchor, positive, negative = ( batch[0], batch[1], batch[2], @@ -541,29 +534,18 @@ def compute_feature_importance( # Move tensors to the specified device using MPS-safe method anchor = {k: to_device_safe(v, device) for k, v in anchor.items()} - positive = { - k: to_device_safe(v, device) for k, v in positive.items() - } - negative = { - k: to_device_safe(v, device) for k, v in negative.items() - } + positive = {k: to_device_safe(v, device) for k, v in positive.items()} + negative = {k: to_device_safe(v, device) for k, v in negative.items()} anchor = [data.requires_grad_() for data in list(anchor.values())] - positive = [ - data.requires_grad_() for data in list(positive.values()) - ] - negative = [ - data.requires_grad_() for data in list(negative.values()) - ] + positive = [data.requires_grad_() for data in list(positive.values())] + negative = [data.requires_grad_() for data in list(negative.values())] # concatenate multiomic layers of each list element # then stack the anchor/positive/negative # the purpose is to get a single tensor input_data = torch.stack( - [ - torch.cat(sublist, dim=1) - for sublist in [anchor, positive, negative] - ] + [torch.cat(sublist, dim=1) for sublist in [anchor, positive, negative]] ).unsqueeze(0) # layer sizes will be needed to revert the concatenated tensor @@ -573,14 +555,9 @@ def compute_feature_importance( # Define a baseline if method == "IntegratedGradients": baseline = torch.zeros_like(input_data) - elif ( - method == "GradientShap" - ): # provide multiple baselines for Gr.Shap + elif method == "GradientShap": # provide multiple baselines for Gr.Shap baseline = torch.cat( - [ - torch.zeros_like(input_data) - for _ in range(steps_or_samples) - ], + [torch.zeros_like(input_data) for _ in range(steps_or_samples)], dim=0, ) @@ -652,9 +629,7 @@ def compute_feature_importance( # Process each layer within the class for layer_idx in range(num_layers): # Extract all batch tensors for this layer across all batches for the current class - layer_tensors = [ - batch_attr[layer_idx] for batch_attr in class_attr - ] + layer_tensors = [batch_attr[layer_idx] for batch_attr in class_attr] # Concatenate tensors along the batch dimension attr_concat = torch.cat(layer_tensors, dim=2) layer_attributions.append(attr_concat) diff --git a/flexynesis/modules.py b/flexynesis/modules.py index 81f4a67b..32cdaa8a 100644 --- a/flexynesis/modules.py +++ b/flexynesis/modules.py @@ -234,9 +234,7 @@ def __init__( self.dropout = nn.Dropout(dropout_rate) # Initialize the first convolution layer separately if different input size - self.convs.append( - conv_options[conv](node_feature_count, node_embedding_dim) - ) + self.convs.append(conv_options[conv](node_feature_count, node_embedding_dim)) self.bns.append(nn.BatchNorm1d(node_embedding_dim)) # Loop to create the remaining convolution and BN layers @@ -287,9 +285,7 @@ def cox_ph_loss(outputs, durations, events): hazards = hazards.unsqueeze(0) # Make hazards 1D if it's a scalar # Calculate the risk set sum log_risk_set_sum = torch.log( - torch.cumsum( - hazards[torch.argsort(durations, descending=True)], dim=0 - ) + torch.cumsum(hazards[torch.argsort(durations, descending=True)], dim=0) ) # Get the indices that sort the durations in descending order sorted_indices = torch.argsort(durations, descending=True) @@ -301,9 +297,7 @@ def cox_ph_loss(outputs, durations, events): ) - torch.sum(log_risk_set_sum[events_sorted == 1]) total_loss = -uncensored_loss / torch.sum(events) else: - total_loss = torch.tensor( - 0.0, device=outputs.device, requires_grad=True - ) + total_loss = torch.tensor(0.0, device=outputs.device, requires_grad=True) if not torch.isfinite(total_loss): return torch.tensor(0.0, device=outputs.device, requires_grad=True) return total_loss diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 2985550f..9f60487c 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -195,9 +195,7 @@ def plot_dim_reduced( + theme_minimal() ) else: - raise ValueError( - "Invalid color_type. Choose 'categorical' or 'numerical'." - ) + raise ValueError("Invalid color_type. Choose 'categorical' or 'numerical'.") return p @@ -380,9 +378,7 @@ def plot_boxplot( verticalalignment="top", horizontalalignment="left", fontsize=12, - bbox=dict( - boxstyle="round,pad=0.3", facecolor="white", edgecolor="gray" - ), + bbox=dict(boxstyle="round,pad=0.3", facecolor="white", edgecolor="gray"), ) plt.tight_layout() @@ -572,9 +568,7 @@ def plot_roc_curves(y_true, y_probs): ggplot(all_data, aes(x="fpr", y="tpr", color="label")) + geom_line(size=1.2) + geom_abline(intercept=0, slope=1, linetype="dashed", color="gray") - + labs( - title="ROC Curve", x="False Positive Rate", y="True Positive Rate" - ) + + labs(title="ROC Curve", x="False Positive Rate", y="True Positive Rate") + theme_minimal() ) @@ -691,9 +685,7 @@ def evaluate_wrapper( if var == surv_event_var: events = dataset.ann[surv_event_var] durations = dataset.ann[surv_time_var] - metrics = evaluate_survival( - y_pred_dict[var], durations, events - ) + metrics = evaluate_survival(y_pred_dict[var], durations, events) else: ind = ~torch.isnan(dataset.ann[var]) metrics = evaluate_regressor( @@ -701,9 +693,7 @@ def evaluate_wrapper( ) else: ind = ~torch.isnan(dataset.ann[var]) - metrics = evaluate_classifier( - dataset.ann[var][ind], y_pred_dict[var][ind] - ) + metrics = evaluate_classifier(dataset.ann[var][ind], y_pred_dict[var][ind]) for metric, value in metrics.items(): metrics_list.append( @@ -754,16 +744,13 @@ def get_predicted_labels(y_pred_dict, dataset, split, method_name): for idx in range(probabilities.shape[1]) ] else: - class_labels = [ - f"class_{i}" for i in range(probabilities.shape[1]) - ] + class_labels = [f"class_{i}" for i in range(probabilities.shape[1])] # Get true labels (y_true) y_true = [ ( dataset.label_mappings[var][int(x.item())] - if var in dataset.label_mappings.keys() - and not np.isnan(x.item()) + if var in dataset.label_mappings.keys() and not np.isnan(x.item()) else np.nan ) for x in dataset.ann[var] @@ -846,9 +833,7 @@ def evaluate_baseline_performance( def prepare_data(data_object, pca_model=None, fit_pca=False): # Concatenate Data Matrices - X = np.concatenate( - [tensor for tensor in data_object.dat.values()], axis=1 - ) + X = np.concatenate([tensor for tensor in data_object.dat.values()], axis=1) y = np.array(data_object.ann[variable_name]) # Filter out samples without a valid label @@ -956,11 +941,7 @@ def prepare_data(data_object, pca_model=None, fit_pca=False): metrics_list.append( { "method": method - + ( - "Classifier" - if variable_type == "categorical" - else "Regressor" - ), + + ("Classifier" if variable_type == "categorical" else "Regressor"), "var": variable_name, "variable_type": variable_type, "metric": metric, @@ -1000,9 +981,7 @@ def evaluate_baseline_survival_performance( def prepare_data(data_object, duration_col, event_col): # Concatenate Data Matrices - X = np.concatenate( - [tensor for tensor in data_object.dat.values()], axis=1 - ) + X = np.concatenate([tensor for tensor in data_object.dat.values()], axis=1) # Prepare Survival Data (Durations and Events) durations = np.array(data_object.ann[duration_col]) @@ -1022,9 +1001,7 @@ def prepare_data(data_object, duration_col, event_col): X_train, y_train, train_indices = prepare_data( train_dataset, duration_col, event_col ) - X_test, y_test, test_indices = prepare_data( - test_dataset, duration_col, event_col - ) + X_test, y_test, test_indices = prepare_data(test_dataset, duration_col, event_col) # Initialize Random Survival Forest rsf = RandomSurvivalForest( @@ -1175,9 +1152,7 @@ def subset_assays_by_features(dataset, features_dict): # data matrix for each key in features_dict subset_dat = {} for layer in features_dict.keys(): - indices = [ - dataset.features[layer].get_loc(x) for x in features_dict[layer] - ] + indices = [dataset.features[layer].get_loc(x) for x in features_dict[layer]] subset_dat[layer] = dataset.dat[layer][:, indices] # Convert subset_dat to pandas DataFrame and prepend feature names with layer names df_list = [] @@ -1186,9 +1161,7 @@ def subset_assays_by_features(dataset, features_dict): df_temp = pd.DataFrame(data) # Rename columns to prepend with layer name - df_temp.columns = [ - f"{layer}_{feature}" for feature in features_dict[layer] - ] + df_temp.columns = [f"{layer}_{feature}" for feature in features_dict[layer]] df_list.append(df_temp) # Concatenate dataframes horizontally concatenated_df = pd.concat(df_list, axis=1) @@ -1290,9 +1263,7 @@ def recursive_binary_split_minN( continue try: - cutoff, pval = find_optimal_cutoff( - node[score], node[time], node[event] - ) + cutoff, pval = find_optimal_cutoff(node[score], node[time], node[event]) except Exception: cutoff, pval = None, 1.0 @@ -1306,10 +1277,7 @@ def recursive_binary_split_minN( right = node[node[score] > cutoff] # if either resulting child is smaller than required, reject split - if ( - len(left) < min_samples_per_group - or len(right) < min_samples_per_group - ): + if len(left) < min_samples_per_group or len(right) < min_samples_per_group: groups.update({i: next_gid for i in node.index}) next_gid += 1 continue @@ -1396,9 +1364,7 @@ def build_cox_model( event_col: str, n_splits: int = 5, random_state: int = 42, - eval_time: ( - float | None - ) = None, # single horizon, same units as duration_col + eval_time: float | None = None, # single horizon, same units as duration_col low_variance_threshold: float = 0.01, cox_penalizer=0.05, return_metrics: bool = True, @@ -1416,9 +1382,7 @@ def build_cox_model( } """ - def remove_low_variance_survival_features( - df, duration_col, event_col, threshold - ): + def remove_low_variance_survival_features(df, duration_col, event_col, threshold): events = df[event_col].astype(bool) low_var = [] for feature in df.drop(columns=[duration_col, event_col]).columns: @@ -1486,12 +1450,8 @@ def remove_low_variance_survival_features( auc_val = float(np.atleast_1d(auc_val)[0]) auc_per_fold.append(auc_val) # Aggregate CV metrics - metrics["cv_cindex_mean"] = ( - float(np.mean(c_indices)) if c_indices else None - ) - metrics["cv_auc_mean"] = ( - float(np.mean(auc_per_fold)) if auc_per_fold else None - ) + metrics["cv_cindex_mean"] = float(np.mean(c_indices)) if c_indices else None + metrics["cv_auc_mean"] = float(np.mean(auc_per_fold)) if auc_per_fold else None # Fit final model on full data for downstream use (forest plots, HRs, etc.) final_model = CoxPHFitter(penalizer=cox_penalizer) @@ -1553,9 +1513,7 @@ def louvain_clustering(X, threshold=None, k=None): cluster_labels = np.full(len(X), np.nan, dtype=float) # Fill the array with the cluster labels from the partition dictionary for node_id, cluster_label in partition.items(): - if node_id in range( - len(X) - ): # Check if the node_id is a valid index in X + if node_id in range(len(X)): # Check if the node_id is a valid index in X cluster_labels[node_id] = cluster_label else: # If node_id is not a valid index in X, it's already set to NaN @@ -1717,9 +1675,7 @@ def create_covariate_matrix(covariates, variable_types, ann): if variable_types.get(var) == "categorical": # One-hot-encode categorical variables with 0/1 encoding one_hot = pd.get_dummies(ann[var], prefix=var).astype(int) - covariate_features.append( - one_hot.T - ) # Transpose to make features rows + covariate_features.append(one_hot.T) # Transpose to make features rows feature_names.extend(one_hot.columns.tolist()) elif variable_types.get(var) == "numerical": # Handle numerical variables with missing values @@ -1770,9 +1726,7 @@ def generate_synthetic_batches(n_samples_per_batch=150, n_features=50): return synthetic_data, batch_labels -def optimal_transport_align( - embeddings, batch_labels, standardize_by_labels=False -): +def optimal_transport_align(embeddings, batch_labels, standardize_by_labels=False): """ Align embeddings from two batches using Optimal Transport, preserving the order of samples. @@ -1790,9 +1744,7 @@ def optimal_transport_align( # Identify unique batches unique_batches = np.unique(batch_labels_np) if len(unique_batches) != 2: - raise ValueError( - "Optimal transport supports aligning exactly two batches." - ) + raise ValueError("Optimal transport supports aligning exactly two batches.") # Split embeddings by batch, preserving the original indices batch1_indices = np.where(batch_labels_np == unique_batches[0])[0] @@ -1802,9 +1754,7 @@ def optimal_transport_align( batch2_embeddings = embeddings.iloc[batch2_indices].to_numpy() # Compute the cost matrix (e.g., Euclidean distances) - cost_matrix = ot.dist( - batch1_embeddings, batch2_embeddings, metric="euclidean" - ) + cost_matrix = ot.dist(batch1_embeddings, batch2_embeddings, metric="euclidean") # Solve the optimal transport problem n_samples_1 = batch1_embeddings.shape[0] @@ -1876,9 +1826,7 @@ def reciprocal_pca_mnn( # Identify unique batches unique_batches = np.unique(batch_labels_np) if len(unique_batches) != 2: - raise ValueError( - "Reciprocal PCA supports aligning exactly two batches." - ) + raise ValueError("Reciprocal PCA supports aligning exactly two batches.") # Split embeddings by batch, preserving the original indices batch1_indices = np.where(batch_labels_np == unique_batches[0])[0] @@ -1908,12 +1856,8 @@ def reciprocal_pca_mnn( batch2_to_batch1 = pca1.transform(batch2_embeddings) # Use MNN to identify anchors - neighbors1 = NearestNeighbors(n_neighbors=n_neighbors).fit( - batch2_to_batch1 - ) - neighbors2 = NearestNeighbors(n_neighbors=n_neighbors).fit( - batch1_to_batch2 - ) + neighbors1 = NearestNeighbors(n_neighbors=n_neighbors).fit(batch2_to_batch1) + neighbors2 = NearestNeighbors(n_neighbors=n_neighbors).fit(batch1_to_batch2) distances1, indices1 = neighbors1.kneighbors(batch1_pca) distances2, indices2 = neighbors2.kneighbors(batch2_pca) @@ -1926,9 +1870,7 @@ def reciprocal_pca_mnn( mutual_anchors.append((i, neighbor)) if not mutual_anchors: - raise ValueError( - "No mutual nearest neighbors (MNN) found between the batches." - ) + raise ValueError("No mutual nearest neighbors (MNN) found between the batches.") # Align the datasets using anchors mutual_anchors = np.array(mutual_anchors) @@ -1966,9 +1908,7 @@ def reciprocal_pca_mnn( class CBioPortalData: - def __init__( - self, study_id, base_url="https://datahub.assets.cbioportal.org" - ): + def __init__(self, study_id, base_url="https://datahub.assets.cbioportal.org"): self.base_url = base_url self.study_id = study_id self.data_files = None @@ -1982,9 +1922,7 @@ def download_study_archive(self, force=False, timeout=60): return dest_file print(f"Downloading {url}...") - r = requests.get( - url, stream=True, allow_redirects=True, timeout=timeout - ) + r = requests.get(url, stream=True, allow_redirects=True, timeout=timeout) r.raise_for_status() # <-- key: fail fast on 404/403/etc. with open(dest_file, "wb") as f: @@ -2003,9 +1941,7 @@ def extract_archive(self, archive_path): tar.extractall() self.data_files = [ - f - for f in os.listdir(base) - if f.startswith("data_") and f.endswith(".txt") + f for f in os.listdir(base) if f.startswith("data_") and f.endswith(".txt") ] return base @@ -2017,9 +1953,7 @@ def read_data(self, files=None): for datatype, file in files.items(): print(f"Importing {file}...") file_path = os.path.join(self.study_id, file) - df = pd.read_csv( - file_path, sep="\t", comment="#", low_memory=False - ) + df = pd.read_csv(file_path, sep="\t", comment="#", low_memory=False) if "mutations" in file: print(f"Binarizing and converting {file} to matrix...") @@ -2090,9 +2024,7 @@ def split_data(self, samples=None, ratio=0.7): if samples is None: samples = self.data["clin"].index.tolist() - train_samples = list( - pd.Series(samples).sample(frac=ratio, random_state=42) - ) + train_samples = list(pd.Series(samples).sample(frac=ratio, random_state=42)) test_samples = list(set(samples) - set(train_samples)) train_data = {} @@ -2135,17 +2067,13 @@ def compute_correlation_loss(embeddings, batch_labels): ) # Normalize batch labels - batch_labels = (batch_labels - batch_labels.mean()) / ( - batch_labels.std() + 1e-8 - ) + batch_labels = (batch_labels - batch_labels.mean()) / (batch_labels.std() + 1e-8) # Reshape batch_labels to (num_samples, 1) for broadcasting batch_labels = batch_labels.unsqueeze(1) # Compute covariance (dot product of batch_labels and embeddings) - covariance = torch.matmul(batch_labels.T, embeddings) / ( - embeddings.shape[0] - 1 - ) + covariance = torch.matmul(batch_labels.T, embeddings) / (embeddings.shape[0] - 1) # Compute sum of squared correlations loss = torch.sum(torch.abs(covariance)) @@ -2263,8 +2191,7 @@ def get_device_memory_info(device_str): return { "allocated": torch.cuda.memory_allocated() / (1024**2), # MB "reserved": torch.cuda.memory_reserved() / (1024**2), # MB - "max_allocated": torch.cuda.max_memory_allocated() - / (1024**2), # MB + "max_allocated": torch.cuda.max_memory_allocated() / (1024**2), # MB "device_name": torch.cuda.get_device_name(0), "device_count": torch.cuda.device_count(), } diff --git a/tests/test_mps_device.py b/tests/test_mps_device.py index 23057298..e24c3411 100644 --- a/tests/test_mps_device.py +++ b/tests/test_mps_device.py @@ -11,12 +11,8 @@ def test_mps_device_detection(): if not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") - assert ( - device_type == "mps" - ), f"Expected device type 'mps', got '{device_type}'" - assert ( - device_str == "mps" - ), f"Expected device string 'mps', got '{device_str}'" + assert device_type == "mps", f"Expected device type 'mps', got '{device_type}'" + assert device_str == "mps", f"Expected device string 'mps', got '{device_str}'" def test_mps_tensor_operations(): @@ -49,13 +45,11 @@ def test_mps_memory_allocation(): if device_str != "mps" or not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") -# Test memory tracking + # Test memory tracking memory_before = torch.mps.current_allocated_memory() memory_after = torch.mps.current_allocated_memory() - assert ( - memory_after > memory_before - ), "Memory did not increase after allocation." + assert memory_after > memory_before, "Memory did not increase after allocation." def test_float64_to_float32_conversion(): @@ -72,6 +66,4 @@ def test_float64_to_float32_conversion(): x_mps = to_device_safe(x_float64, device) assert x_mps.dtype == torch.float32, f"Expected float32, got {x_mps.dtype}" - assert ( - x_mps.device.type == "mps" - ), f"Tensor not on MPS device: {x_mps.device.type}" + assert x_mps.device.type == "mps", f"Tensor not on MPS device: {x_mps.device.type}" From eaaf05edd696264051cbaa5e1984e97e1a4e869d Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 21:25:23 +0200 Subject: [PATCH 10/32] run isort --- flexynesis/__main__.py | 21 ++++++---------- flexynesis/data.py | 6 ++--- flexynesis/inference.py | 7 ++---- flexynesis/main.py | 5 ++-- flexynesis/utils.py | 55 +++++++++-------------------------------- 5 files changed, 25 insertions(+), 69 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 86379548..28efc6ef 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -871,14 +871,10 @@ def main(): # Continue to evaluation section (skip training) # ------------- Heavy imports only when training ------------- - from .utils import ( - evaluate_baseline_performance, - evaluate_baseline_survival_performance, - evaluate_wrapper, - get_device_memory_info, - get_optimal_device, - get_predicted_labels, - ) + from .utils import (evaluate_baseline_performance, + evaluate_baseline_survival_performance, + evaluate_wrapper, get_device_memory_info, + get_optimal_device, get_predicted_labels) if not (args.pretrained_model and args.artifacts and args.data_path_test): import json @@ -893,7 +889,6 @@ def main(): from .data import STRING, DataImporter, MultiOmicDatasetNW from .main import FineTuner, HyperparameterTuning from .models.crossmodal_pred import CrossModalPred - # models from .models.direct_pred import DirectPred from .models.gnn_early import GNN @@ -1510,11 +1505,9 @@ def main(): elif args.safetensors: import numpy as np - from sklearn.preprocessing import ( - LabelEncoder, - OrdinalEncoder, - StandardScaler, - ) + from sklearn.preprocessing import (LabelEncoder, + OrdinalEncoder, + StandardScaler) json_ready = { "schema_version": artifacts["schema_version"], diff --git a/flexynesis/data.py b/flexynesis/data.py index 054a3ea4..c6dfdaeb 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -798,10 +798,8 @@ def import_data(self): # Create covariates matrix if needed if "covariates" in self.modalities and labels_df is not None: - from flexynesis.utils import ( - create_covariate_matrix, - get_variable_types, - ) + from flexynesis.utils import (create_covariate_matrix, + get_variable_types) covariate_vars = self.artifacts.get("covariate_vars", []) if covariate_vars: diff --git a/flexynesis/inference.py b/flexynesis/inference.py index 6245cc51..8134cdc5 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -181,11 +181,8 @@ def check_file_type(file_path): def _deserialize_json_artifacts(artifacts): import numpy as np - from sklearn.preprocessing import ( - LabelEncoder, - OrdinalEncoder, - StandardScaler, - ) + from sklearn.preprocessing import (LabelEncoder, OrdinalEncoder, + StandardScaler) # Rebuild sklearn objects expected by inference code. deserialized = dict(artifacts) diff --git a/flexynesis/main.py b/flexynesis/main.py index eeb701ff..96944db9 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -9,9 +9,8 @@ import torch import yaml from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar -from lightning.pytorch.callbacks.progress.rich_progress import ( - RichProgressBarTheme, -) +from lightning.pytorch.callbacks.progress.rich_progress import \ + RichProgressBarTheme from skopt import Optimizer from skopt.space import Categorical, Integer, Real from torch.utils.data import DataLoader, random_split diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 9f60487c..c7c1fc55 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -11,22 +11,12 @@ from scipy.stats import linregress, pearsonr from sklearn.decomposition import PCA from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor -from sklearn.feature_selection import ( - SelectFromModel, - mutual_info_classif, - mutual_info_regression, -) -from sklearn.metrics import ( - adjusted_mutual_info_score, - adjusted_rand_score, - average_precision_score, - balanced_accuracy_score, - classification_report, - cohen_kappa_score, - f1_score, - mean_squared_error, - roc_auc_score, -) +from sklearn.feature_selection import (SelectFromModel, mutual_info_classif, + mutual_info_regression) +from sklearn.metrics import (adjusted_mutual_info_score, adjusted_rand_score, + average_precision_score, balanced_accuracy_score, + classification_report, cohen_kappa_score, + f1_score, mean_squared_error, roc_auc_score) from sklearn.model_selection import GridSearchCV, KFold from sklearn.svm import SVC, SVR from sksurv.metrics import cumulative_dynamic_auc @@ -42,7 +32,6 @@ import community as community_louvain import matplotlib.pyplot as plt import networkx as nx - # imports import numpy as np import ot @@ -50,26 +39,10 @@ from lifelines import CoxPHFitter, KaplanMeierFitter from lifelines.statistics import logrank_test, multivariate_logrank_test from lifelines.utils import concordance_index -from plotnine import ( - aes, - annotate, - element_text, - geom_abline, - geom_errorbarh, - geom_line, - geom_point, - geom_smooth, - geom_step, - geom_text, - ggplot, - ggtitle, - labs, - scale_color_gradient, - scale_color_manual, - theme, - theme_bw, - theme_minimal, -) +from plotnine import (aes, annotate, element_text, geom_abline, geom_errorbarh, + geom_line, geom_point, geom_smooth, geom_step, geom_text, + ggplot, ggtitle, labs, scale_color_gradient, + scale_color_manual, theme, theme_bw, theme_minimal) from sklearn.cluster import KMeans from sklearn.decomposition import PCA from sklearn.metrics import silhouette_score @@ -509,12 +482,8 @@ def evaluate_classifier(y_true, y_probs, print_report=False): } -from sklearn.metrics import ( - average_precision_score, - precision_recall_curve, - roc_auc_score, - roc_curve, -) +from sklearn.metrics import (average_precision_score, precision_recall_curve, + roc_auc_score, roc_curve) from sklearn.preprocessing import label_binarize From d4d8ce289fff484c4388b583e4ddfe584d58e367 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 21:31:58 +0200 Subject: [PATCH 11/32] fix module level import not at top of the file --- flexynesis/data.py | 4 ---- flexynesis/main.py | 23 +++++++---------------- flexynesis/utils.py | 20 ++++++-------------- 3 files changed, 13 insertions(+), 34 deletions(-) diff --git a/flexynesis/data.py b/flexynesis/data.py index c6dfdaeb..8b97c17d 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -670,10 +670,6 @@ def check_common_features(train_dat, test_dat): Add this to flexynesis/data.py """ -import numpy as np -import pandas as pd - - class DataImporterInference: """ Data importer for inference mode. diff --git a/flexynesis/main.py b/flexynesis/main.py index 96944db9..562d3059 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -1,16 +1,18 @@ -from lightning import seed_everything - -seed_everything(42, workers=True) - +import copy +import logging import os import lightning as pl +import matplotlib.pyplot as plt import numpy as np import torch import yaml +from IPython.display import display +from lightning import Callback, seed_everything from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar from lightning.pytorch.callbacks.progress.rich_progress import \ RichProgressBarTheme +from sklearn.model_selection import KFold from skopt import Optimizer from skopt.space import Categorical, Integer, Real from torch.utils.data import DataLoader, random_split @@ -20,7 +22,7 @@ from .data import TripletMultiOmicDataset torch.set_float32_matmul_precision("medium") - +seed_everything(42, workers=True) class HyperparameterTuning: """ @@ -451,12 +453,6 @@ def load_and_convert_config(self, config_path): return search_space_user -import copy -import logging - -import numpy as np -from sklearn.model_selection import KFold -from torch.utils.data import DataLoader, random_split class FineTuner(pl.LightningModule): @@ -651,11 +647,6 @@ def run_experiments(self): final_trainer.fit(self, train_dataloaders=dl) -import matplotlib.pyplot as plt -from IPython.display import display -from lightning import Callback - - class LiveLossPlot(Callback): """ A callback for visualizing training loss in real-time during hyperparameter optimization. diff --git a/flexynesis/utils.py b/flexynesis/utils.py index c7c1fc55..7c51854f 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -1,6 +1,7 @@ import os import tarfile import warnings +from collections import deque import matplotlib.pyplot as plt import numpy as np @@ -8,7 +9,7 @@ import requests import seaborn as sns import torch -from scipy.stats import linregress, pearsonr +from scipy.stats import kruskal, linregress, mannwhitneyu, pearsonr from sklearn.decomposition import PCA from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from sklearn.feature_selection import (SelectFromModel, mutual_info_classif, @@ -16,8 +17,11 @@ from sklearn.metrics import (adjusted_mutual_info_score, adjusted_rand_score, average_precision_score, balanced_accuracy_score, classification_report, cohen_kappa_score, - f1_score, mean_squared_error, roc_auc_score) + f1_score, mean_squared_error, + precision_recall_curve, roc_auc_score, roc_curve) from sklearn.model_selection import GridSearchCV, KFold +from sklearn.neighbors import NearestNeighbors +from sklearn.preprocessing import label_binarize from sklearn.svm import SVC, SVR from sksurv.metrics import cumulative_dynamic_auc from sksurv.util import Surv @@ -294,9 +298,6 @@ def plot_scatter(true_values, predicted_values): return plot -from scipy.stats import kruskal, mannwhitneyu - - def plot_boxplot( categorical_x, numerical_y, @@ -482,9 +483,6 @@ def evaluate_classifier(y_true, y_probs, print_report=False): } -from sklearn.metrics import (average_precision_score, precision_recall_curve, - roc_auc_score, roc_curve) -from sklearn.preprocessing import label_binarize def plot_roc_curves(y_true, y_probs): @@ -1201,8 +1199,6 @@ def find_optimal_cutoff( return best_cutoff, best_p -from collections import deque - def recursive_binary_split_minN( df, @@ -1763,8 +1759,6 @@ def optimal_transport_align(embeddings, batch_labels, standardize_by_labels=Fals return aligned_embeddings_df, aligned_batch_labels -from sklearn.neighbors import NearestNeighbors - def reciprocal_pca_mnn( embeddings, @@ -1871,9 +1865,7 @@ def reciprocal_pca_mnn( return aligned_embeddings_df, aligned_batch_labels -import tarfile -import requests class CBioPortalData: From ec764681f1c008f35cf3e20a407f858b75f0b782 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 21:33:30 +0200 Subject: [PATCH 12/32] fix too many blank lines --- flexynesis/main.py | 2 -- flexynesis/utils.py | 7 ------- 2 files changed, 9 deletions(-) diff --git a/flexynesis/main.py b/flexynesis/main.py index 562d3059..60ada342 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -453,8 +453,6 @@ def load_and_convert_config(self, config_path): return search_space_user - - class FineTuner(pl.LightningModule): """ FineTuner class is designed for fine-tuning trained flexynesis models with flexible control over parameters such as diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 7c51854f..f48b141f 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -483,8 +483,6 @@ def evaluate_classifier(y_true, y_probs, print_report=False): } - - def plot_roc_curves(y_true, y_probs): """ Plot ROC curves using plotnine for binary or multiclass classification. @@ -1199,7 +1197,6 @@ def find_optimal_cutoff( return best_cutoff, best_p - def recursive_binary_split_minN( df, score="pred_risk", @@ -1759,7 +1756,6 @@ def optimal_transport_align(embeddings, batch_labels, standardize_by_labels=Fals return aligned_embeddings_df, aligned_batch_labels - def reciprocal_pca_mnn( embeddings, batch_labels, @@ -1865,9 +1861,6 @@ def reciprocal_pca_mnn( return aligned_embeddings_df, aligned_batch_labels - - - class CBioPortalData: def __init__(self, study_id, base_url="https://datahub.assets.cbioportal.org"): self.base_url = base_url From f89f6f33292824974ee421083040972b92a8ec17 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 2 Apr 2026 21:35:30 +0200 Subject: [PATCH 13/32] fix f-string is missing placeholders --- flexynesis/data.py | 2 +- flexynesis/generate_coexpression_network.py | 18 +++++++++--------- flexynesis/inference.py | 2 +- flexynesis/utils.py | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/flexynesis/data.py b/flexynesis/data.py index 8b97c17d..53a48404 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -1580,7 +1580,7 @@ def score_column_match(col, col_idx, category, total_cols): or score_s < min_threshold ): print( - f"[WARNING] Low confidence in column detection. Using first 3 columns as fallback." + "[WARNING] Low confidence in column detection. Using first 3 columns as fallback." ) col_gene_a = df.columns[0] col_gene_b = df.columns[1] diff --git a/flexynesis/generate_coexpression_network.py b/flexynesis/generate_coexpression_network.py index 285399ef..71ccb556 100644 --- a/flexynesis/generate_coexpression_network.py +++ b/flexynesis/generate_coexpression_network.py @@ -193,14 +193,14 @@ def generate_coexpression_network( if na_count > 0: genes_with_na = expr_df.isna().any(axis=1).sum() print(f" [WARNING] Found {na_count} missing values in {genes_with_na} genes") - print(f" [INFO] Removing genes with missing data.") + print(" [INFO] Removing genes with missing data.") expr_df = expr_df.dropna() print( f" [INFO] Retained {expr_df.shape[0]} genes ({genes_with_na} genes removed)" ) # Build network directly - print(f"\n[2/3] Building network...") + print("\n[2/3] Building network...") edges = build_network( expr_df, method=method, min_correlation=min_correlation, top_k=top_k ) @@ -232,7 +232,7 @@ def generate_coexpression_network( print("\n" + "=" * 70) print("Network Generation Complete!") print("=" * 70) - print(f"\nNetwork Statistics:") + print("\nNetwork Statistics:") print(f" Total edges: {len(network_df):,}") print(f" Unique genes (GeneA): {network_df['GeneA'].nunique():,}") print(f" Unique genes (GeneB): {network_df['GeneB'].nunique():,}") @@ -246,17 +246,17 @@ def generate_coexpression_network( print(f" Median score: {network_df['Score'].median():.4f}") # Show sample - print(f"\nSample edges (first 5):") + print("\nSample edges (first 5):") print(network_df.head().to_string(index=False)) print(f"\n{'=' * 70}") print("Usage with Flexynesis:") print(f"{'=' * 70}") - print(f"\nflexynesis --data_path \\") - print(f" --model_class GNN \\") - print(f" --gnn_conv_type GCN \\") - print(f" --target_variables \\") - print(f" --data_types gex,cnv \\") + print("\nflexynesis --data_path \\") + print(" --model_class GNN \\") + print(" --gnn_conv_type GCN \\") + print(" --target_variables \\") + print(" --data_types gex,cnv \\") print(f" --user_graph {output_file}") print() diff --git a/flexynesis/inference.py b/flexynesis/inference.py index 8134cdc5..361bb1b2 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -314,7 +314,7 @@ def reconstruct_model(safetensors_path, config_path, artifacts_path, device="cpu Returns a fully instantiated, weights-loaded, eval-mode model. """ - print(f"[INFO] Reconstructing model from safetensors") + print("[INFO] Reconstructing model from safetensors") print(f"[INFO] config : {config_path}") print(f"[INFO] artifacts : {artifacts_path}") print(f"[INFO] weights : {safetensors_path}") diff --git a/flexynesis/utils.py b/flexynesis/utils.py index f48b141f..aa261ea1 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -942,7 +942,7 @@ def evaluate_baseline_survival_performance( listed along with the method name and variable details. """ - print(f"[INFO] Evaluating baseline survival prediction performance") + print("[INFO] Evaluating baseline survival prediction performance") def prepare_data(data_object, duration_col, event_col): # Concatenate Data Matrices From 2c0a87a263f6412dc40c987deec9d4ab36b07e6d Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Fri, 3 Apr 2026 10:00:16 +0200 Subject: [PATCH 14/32] fix redefinition of unused --- flexynesis/__main__.py | 23 ++++++++++------- flexynesis/data.py | 4 +-- flexynesis/inference.py | 3 +-- flexynesis/main.py | 4 +-- flexynesis/utils.py | 55 ++++++++++++++++++++++++++++------------- 5 files changed, 57 insertions(+), 32 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 28efc6ef..b7bed66b 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -739,8 +739,6 @@ def main(): if args.pretrained_model and args.artifacts and args.data_path_test: import torch - from .utils import get_optimal_device - # quick existence checks if not os.path.exists(args.pretrained_model): raise FileNotFoundError( @@ -871,10 +869,14 @@ def main(): # Continue to evaluation section (skip training) # ------------- Heavy imports only when training ------------- - from .utils import (evaluate_baseline_performance, - evaluate_baseline_survival_performance, - evaluate_wrapper, get_device_memory_info, - get_optimal_device, get_predicted_labels) + from .utils import ( + evaluate_baseline_performance, + evaluate_baseline_survival_performance, + evaluate_wrapper, + get_device_memory_info, + get_optimal_device, + get_predicted_labels, + ) if not (args.pretrained_model and args.artifacts and args.data_path_test): import json @@ -889,6 +891,7 @@ def main(): from .data import STRING, DataImporter, MultiOmicDatasetNW from .main import FineTuner, HyperparameterTuning from .models.crossmodal_pred import CrossModalPred + # models from .models.direct_pred import DirectPred from .models.gnn_early import GNN @@ -1505,9 +1508,11 @@ def main(): elif args.safetensors: import numpy as np - from sklearn.preprocessing import (LabelEncoder, - OrdinalEncoder, - StandardScaler) + from sklearn.preprocessing import ( + LabelEncoder, + OrdinalEncoder, + StandardScaler, + ) json_ready = { "schema_version": artifacts["schema_version"], diff --git a/flexynesis/data.py b/flexynesis/data.py index 53a48404..ff379f4a 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -670,6 +670,7 @@ def check_common_features(train_dat, test_dat): Add this to flexynesis/data.py """ + class DataImporterInference: """ Data importer for inference mode. @@ -794,8 +795,7 @@ def import_data(self): # Create covariates matrix if needed if "covariates" in self.modalities and labels_df is not None: - from flexynesis.utils import (create_covariate_matrix, - get_variable_types) + from flexynesis.utils import create_covariate_matrix, get_variable_types covariate_vars = self.artifacts.get("covariate_vars", []) if covariate_vars: diff --git a/flexynesis/inference.py b/flexynesis/inference.py index 361bb1b2..d469df03 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -181,8 +181,7 @@ def check_file_type(file_path): def _deserialize_json_artifacts(artifacts): import numpy as np - from sklearn.preprocessing import (LabelEncoder, OrdinalEncoder, - StandardScaler) + from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, StandardScaler # Rebuild sklearn objects expected by inference code. deserialized = dict(artifacts) diff --git a/flexynesis/main.py b/flexynesis/main.py index 60ada342..f2dd603b 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -10,8 +10,7 @@ from IPython.display import display from lightning import Callback, seed_everything from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar -from lightning.pytorch.callbacks.progress.rich_progress import \ - RichProgressBarTheme +from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme from sklearn.model_selection import KFold from skopt import Optimizer from skopt.space import Categorical, Integer, Real @@ -24,6 +23,7 @@ torch.set_float32_matmul_precision("medium") seed_everything(42, workers=True) + class HyperparameterTuning: """ A class dedicated to performing hyperparameter tuning using Bayesian optimization for various types of models. diff --git a/flexynesis/utils.py b/flexynesis/utils.py index aa261ea1..efee92cf 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -12,13 +12,24 @@ from scipy.stats import kruskal, linregress, mannwhitneyu, pearsonr from sklearn.decomposition import PCA from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor -from sklearn.feature_selection import (SelectFromModel, mutual_info_classif, - mutual_info_regression) -from sklearn.metrics import (adjusted_mutual_info_score, adjusted_rand_score, - average_precision_score, balanced_accuracy_score, - classification_report, cohen_kappa_score, - f1_score, mean_squared_error, - precision_recall_curve, roc_auc_score, roc_curve) +from sklearn.feature_selection import ( + SelectFromModel, + mutual_info_classif, + mutual_info_regression, +) +from sklearn.metrics import ( + adjusted_mutual_info_score, + adjusted_rand_score, + average_precision_score, + balanced_accuracy_score, + classification_report, + cohen_kappa_score, + f1_score, + mean_squared_error, + precision_recall_curve, + roc_auc_score, + roc_curve, +) from sklearn.model_selection import GridSearchCV, KFold from sklearn.neighbors import NearestNeighbors from sklearn.preprocessing import label_binarize @@ -34,27 +45,37 @@ XGBRegressor = None import community as community_louvain -import matplotlib.pyplot as plt import networkx as nx -# imports -import numpy as np import ot -import pandas as pd from lifelines import CoxPHFitter, KaplanMeierFitter from lifelines.statistics import logrank_test, multivariate_logrank_test from lifelines.utils import concordance_index -from plotnine import (aes, annotate, element_text, geom_abline, geom_errorbarh, - geom_line, geom_point, geom_smooth, geom_step, geom_text, - ggplot, ggtitle, labs, scale_color_gradient, - scale_color_manual, theme, theme_bw, theme_minimal) +from plotnine import ( + aes, + annotate, + element_text, + geom_abline, + geom_errorbarh, + geom_line, + geom_point, + geom_smooth, + geom_step, + geom_text, + ggplot, + ggtitle, + labs, + scale_color_gradient, + scale_color_manual, + theme, + theme_bw, + theme_minimal, +) from sklearn.cluster import KMeans -from sklearn.decomposition import PCA from sklearn.metrics import silhouette_score from sklearn.metrics.pairwise import euclidean_distances from sklearn.preprocessing import StandardScaler from sksurv.ensemble import RandomSurvivalForest from sksurv.metrics import concordance_index_censored -from umap import UMAP try: from geomloss import SamplesLoss From 31197d12c845ca66cc33f0c2d20fc679f7eaf9e9 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Fri, 3 Apr 2026 10:01:39 +0200 Subject: [PATCH 15/32] fix isort --- flexynesis/__main__.py | 21 ++++++------------ flexynesis/data.py | 3 ++- flexynesis/inference.py | 3 ++- flexynesis/main.py | 3 ++- flexynesis/utils.py | 49 +++++++++-------------------------------- 5 files changed, 24 insertions(+), 55 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index b7bed66b..96bceb08 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -869,14 +869,10 @@ def main(): # Continue to evaluation section (skip training) # ------------- Heavy imports only when training ------------- - from .utils import ( - evaluate_baseline_performance, - evaluate_baseline_survival_performance, - evaluate_wrapper, - get_device_memory_info, - get_optimal_device, - get_predicted_labels, - ) + from .utils import (evaluate_baseline_performance, + evaluate_baseline_survival_performance, + evaluate_wrapper, get_device_memory_info, + get_optimal_device, get_predicted_labels) if not (args.pretrained_model and args.artifacts and args.data_path_test): import json @@ -891,7 +887,6 @@ def main(): from .data import STRING, DataImporter, MultiOmicDatasetNW from .main import FineTuner, HyperparameterTuning from .models.crossmodal_pred import CrossModalPred - # models from .models.direct_pred import DirectPred from .models.gnn_early import GNN @@ -1508,11 +1503,9 @@ def main(): elif args.safetensors: import numpy as np - from sklearn.preprocessing import ( - LabelEncoder, - OrdinalEncoder, - StandardScaler, - ) + from sklearn.preprocessing import (LabelEncoder, + OrdinalEncoder, + StandardScaler) json_ready = { "schema_version": artifacts["schema_version"], diff --git a/flexynesis/data.py b/flexynesis/data.py index ff379f4a..86425954 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -795,7 +795,8 @@ def import_data(self): # Create covariates matrix if needed if "covariates" in self.modalities and labels_df is not None: - from flexynesis.utils import create_covariate_matrix, get_variable_types + from flexynesis.utils import (create_covariate_matrix, + get_variable_types) covariate_vars = self.artifacts.get("covariate_vars", []) if covariate_vars: diff --git a/flexynesis/inference.py b/flexynesis/inference.py index d469df03..361bb1b2 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -181,7 +181,8 @@ def check_file_type(file_path): def _deserialize_json_artifacts(artifacts): import numpy as np - from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, StandardScaler + from sklearn.preprocessing import (LabelEncoder, OrdinalEncoder, + StandardScaler) # Rebuild sklearn objects expected by inference code. deserialized = dict(artifacts) diff --git a/flexynesis/main.py b/flexynesis/main.py index f2dd603b..4f37dc8a 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -10,7 +10,8 @@ from IPython.display import display from lightning import Callback, seed_everything from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar -from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme +from lightning.pytorch.callbacks.progress.rich_progress import \ + RichProgressBarTheme from sklearn.model_selection import KFold from skopt import Optimizer from skopt.space import Categorical, Integer, Real diff --git a/flexynesis/utils.py b/flexynesis/utils.py index efee92cf..318014f5 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -12,24 +12,13 @@ from scipy.stats import kruskal, linregress, mannwhitneyu, pearsonr from sklearn.decomposition import PCA from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor -from sklearn.feature_selection import ( - SelectFromModel, - mutual_info_classif, - mutual_info_regression, -) -from sklearn.metrics import ( - adjusted_mutual_info_score, - adjusted_rand_score, - average_precision_score, - balanced_accuracy_score, - classification_report, - cohen_kappa_score, - f1_score, - mean_squared_error, - precision_recall_curve, - roc_auc_score, - roc_curve, -) +from sklearn.feature_selection import (SelectFromModel, mutual_info_classif, + mutual_info_regression) +from sklearn.metrics import (adjusted_mutual_info_score, adjusted_rand_score, + average_precision_score, balanced_accuracy_score, + classification_report, cohen_kappa_score, + f1_score, mean_squared_error, + precision_recall_curve, roc_auc_score, roc_curve) from sklearn.model_selection import GridSearchCV, KFold from sklearn.neighbors import NearestNeighbors from sklearn.preprocessing import label_binarize @@ -50,26 +39,10 @@ from lifelines import CoxPHFitter, KaplanMeierFitter from lifelines.statistics import logrank_test, multivariate_logrank_test from lifelines.utils import concordance_index -from plotnine import ( - aes, - annotate, - element_text, - geom_abline, - geom_errorbarh, - geom_line, - geom_point, - geom_smooth, - geom_step, - geom_text, - ggplot, - ggtitle, - labs, - scale_color_gradient, - scale_color_manual, - theme, - theme_bw, - theme_minimal, -) +from plotnine import (aes, annotate, element_text, geom_abline, geom_errorbarh, + geom_line, geom_point, geom_smooth, geom_step, geom_text, + ggplot, ggtitle, labs, scale_color_gradient, + scale_color_manual, theme, theme_bw, theme_minimal) from sklearn.cluster import KMeans from sklearn.metrics import silhouette_score from sklearn.metrics.pairwise import euclidean_distances From 06155d767ec38c2b3f584d0e3c082a7e2e869b3b Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Fri, 3 Apr 2026 10:04:36 +0200 Subject: [PATCH 16/32] fix do not assign a lambda expression, use a def --- flexynesis/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 318014f5..d8a29424 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -100,11 +100,12 @@ def get_color_mapping(labels): colors.extend(colors) colors = colors[:n] - to_hex = lambda c: "#%02x%02x%02x" % ( - int(c[0] * 255), - int(c[1] * 255), - int(c[2] * 255), - ) + def to_hex(c): + return "#%02x%02x%02x" % ( + int(c[0] * 255), + int(c[1] * 255), + int(c[2] * 255), + ) color_hex = [to_hex(c) for c in colors] return dict(zip(unique_labels, color_hex)) From a70a745db6c6ee58f7ac979e0bb69a6e29fab6b3 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Fri, 3 Apr 2026 10:06:40 +0200 Subject: [PATCH 17/32] remove unused --- flexynesis/__main__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 96bceb08..fd9c7565 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -879,7 +879,6 @@ def main(): import tracemalloc import pandas as pd - import psutil import torch from safetensors.torch import save_file From 420aa877e718928262e43b674e4882a10076319a Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Fri, 3 Apr 2026 10:19:07 +0200 Subject: [PATCH 18/32] fix may be undefined or defined from ... --- flexynesis/models/crossmodal_pred.py | 2 +- flexynesis/models/direct_pred.py | 2 +- flexynesis/models/supervised_vae.py | 2 +- flexynesis/models/triplet_encoder.py | 2 +- flexynesis/utils.py | 1 + 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index d4959582..0a5d733e 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -9,7 +9,7 @@ from torch.nn import functional as F from torch.utils.data import DataLoader -from ..modules import * +from ..modules import MLP, Decoder, Encoder, cox_ph_loss class CrossModalPred(pl.LightningModule): diff --git a/flexynesis/models/direct_pred.py b/flexynesis/models/direct_pred.py index 3ee56071..7cbb8fa7 100644 --- a/flexynesis/models/direct_pred.py +++ b/flexynesis/models/direct_pred.py @@ -7,7 +7,7 @@ from torch.nn import functional as F from torch.utils.data import DataLoader -from ..modules import * +from ..modules import MLP, cox_ph_loss from ..utils import to_device_safe diff --git a/flexynesis/models/supervised_vae.py b/flexynesis/models/supervised_vae.py index 7ba5ac7f..b6efda6e 100644 --- a/flexynesis/models/supervised_vae.py +++ b/flexynesis/models/supervised_vae.py @@ -10,7 +10,7 @@ from torch.nn import functional as F from torch.utils.data import DataLoader -from ..modules import * +from ..modules import MLP, Decoder, Encoder, cox_ph_loss # Supervised Variational Auto-encoder that can train one or more layers of omics datasets diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index 064696df..6928c134 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader from ..data import TripletMultiOmicDataset -from ..modules import * +from ..modules import MLP, cox_ph_loss class MultiTripletNetwork(pl.LightningModule): diff --git a/flexynesis/utils.py b/flexynesis/utils.py index d8a29424..3066b9a0 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -106,6 +106,7 @@ def to_hex(c): int(c[1] * 255), int(c[2] * 255), ) + color_hex = [to_hex(c) for c in colors] return dict(zip(unique_labels, color_hex)) From 5677fbf677cc648e36bb12cfb098527b7b00324c Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Fri, 3 Apr 2026 10:27:25 +0200 Subject: [PATCH 19/32] fix remaining small fixes except length --- flexynesis/__main__.py | 2 -- flexynesis/data.py | 3 +-- flexynesis/inference.py | 6 +++++- flexynesis/main.py | 2 +- flexynesis/models/gnn_early.py | 6 +++--- 5 files changed, 10 insertions(+), 9 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index fd9c7565..72ffd468 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -1247,8 +1247,6 @@ def main(): steps_or_samples=25, method=explainer, ) - import pandas as pd - df_imp = pd.concat( [model.feature_importances[x] for x in model.target_variables], ignore_index=True, diff --git a/flexynesis/data.py b/flexynesis/data.py index 86425954..e6a51dca 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -552,7 +552,6 @@ def encode_labels(self, df): label_mappings = {} def encode_column(series): - nonlocal label_mappings # Declare as nonlocal so that we can modify it # Fill NA values with 'missing' # series = series.fillna('missing') if series.name not in self.encoders: @@ -1111,7 +1110,7 @@ def __len__(self): def get_label_indices(self, labels_array): # Filter out NaNs for a clean set of valid classes - valid_labels = [l for l in labels_array if not np.isnan(l)] + valid_labels = [label for label in labels_array if not np.isnan(label)] labels_set = set(valid_labels) label_to_indices = { diff --git a/flexynesis/inference.py b/flexynesis/inference.py index 361bb1b2..13bc3740 100644 --- a/flexynesis/inference.py +++ b/flexynesis/inference.py @@ -130,7 +130,11 @@ def _resolve_input_dims(config, artifacts): ) input_dims = config.get("input_dims") if not input_dims: - input_dims = [len(feature_lists[l]) for l in layers if l in feature_lists] + input_dims = [ + len(feature_lists[layer_name]) + for layer_name in layers + if layer_name in feature_lists + ] config["input_dims"] = input_dims return config diff --git a/flexynesis/main.py b/flexynesis/main.py index 4f37dc8a..3f1315b8 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -205,7 +205,7 @@ def setup_trainer(self, params, current_step, total_steps, full_train=False): # when training on a full dataset; no cross-validation or no validation splits; # we don't do early stopping early_stop_callback = None - if self.early_stop_patience > 0 and full_train == False: + if self.early_stop_patience > 0 and not full_train: early_stop_callback = self.init_early_stopping() mycallbacks.append(early_stop_callback) diff --git a/flexynesis/models/gnn_early.py b/flexynesis/models/gnn_early.py index 62484b04..6c17bc59 100644 --- a/flexynesis/models/gnn_early.py +++ b/flexynesis/models/gnn_early.py @@ -604,7 +604,7 @@ def bytes_to_gb(bytes): if target_var in dataset.label_mappings else "" ) - for l in range(len(layers)): + for layer_idx, layer_name in enumerate(layers): # Extracting node feature attributes coming from different omic layers importances_array = imp[i].squeeze().detach().numpy() if importances_array.ndim == 1: @@ -613,7 +613,7 @@ def bytes_to_gb(bytes): ) else: importances = importances_array[ - :, l + :, layer_idx ] # Use the original indexing for 2D arrays df_list.append( pd.DataFrame( @@ -621,7 +621,7 @@ def bytes_to_gb(bytes): "target_variable": target_var, "target_class": i, "target_class_label": target_class_label, - "layer": layers[l], + "layer": layer_name, "name": features, "importance": importances, } From 5040f06c052d5ce90f1a64794f95b2662b1cb48e Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Fri, 3 Apr 2026 19:15:06 +0200 Subject: [PATCH 20/32] add flake8 lint check in CI --- .github/workflows/lint.yml | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..b6b8fabc --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,44 @@ +name: Lint Python scripts + +on: + push: + paths: + - '**.py' + pull_request: + paths: + - '**.py' + +jobs: + lint: + name: Lint Python scripts + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ['3.10', '3.11'] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Cache pip + uses: actions/cache@v4 + id: cache-pip + with: + path: ~/.cache/pip + key: pip_cache_py_${{ matrix.python-version }}_${{ hashFiles('**/pyproject.toml') }} + + - name: Install linters + run: | + python -m pip install --upgrade pip + pip install isort flake8 + + - name: Run isort + run: isort --check-only --diff . + + - name: Run flake8 + run: flake8 . From e185a994027d084406136505574f96662b1a8169 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Fri, 3 Apr 2026 20:24:05 +0200 Subject: [PATCH 21/32] fix max line length using claude --- flexynesis/__main__.py | 336 ++++++++++++++++++--------- flexynesis/data.py | 39 +++- flexynesis/main.py | 22 +- flexynesis/models/crossmodal_pred.py | 58 +++-- flexynesis/models/direct_pred.py | 6 +- flexynesis/models/gnn_early.py | 12 +- flexynesis/models/supervised_vae.py | 3 +- flexynesis/models/triplet_encoder.py | 48 ++-- flexynesis/modules.py | 11 +- flexynesis/utils.py | 114 +++++---- 10 files changed, 427 insertions(+), 222 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 72ffd468..0ced9436 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -13,19 +13,26 @@ def print_test_installation(): print("Test Installation:") print(" # Download and extract test dataset") print( - " curl -L -o dataset1.tgz https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis-benchmark-datasets/dataset1.tgz" + " curl -L -o dataset1.tgz " + "https://bimsbstatic.mdc-berlin.de/akalin/buyar/" + "flexynesis-benchmark-datasets/dataset1.tgz" ) print(" tar -xzvf dataset1.tgz") print() print(" # Test the installation (should finish within a minute on a typical CPU)") print( - " flexynesis --data_path dataset1 --model_class DirectPred --target_variables Erlotinib --hpo_iter 1 --features_top_percentile 5 --data_types gex,cnv" + " flexynesis --data_path dataset1 --model_class DirectPred " + "--target_variables Erlotinib --hpo_iter 1 " + "--features_top_percentile 5 --data_types gex,cnv" ) def print_help(): print( - "usage: flexynesis [-h] --data_path DATA_PATH --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest} --data_types DATA_TYPES" + "usage: flexynesis [-h] --data_path DATA_PATH --model_class " + "{DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN," + "RandomForest,SVM,XGBoost,RandomSurvivalForest} --data_types " + "DATA_TYPES" ) print() print("Flexynesis model training interface") @@ -37,42 +44,63 @@ def print_help(): " (Required) Path to the folder with train/test data files" ) print( - " --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" + " --model_class {DirectPred,supervised_vae,MultiTripletNetwork," + "CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" ) print(" (Required) The kind of model class to instantiate") print(" --data_types DATA_TYPES") print( - " (Required) Which omic data matrices to work on, comma-separated: e.g. 'gex,cnv'" + " (Required) Which omic data matrices to work on, " + "comma-separated: e.g. 'gex,cnv'" ) print( - " --hpo_iter HPO_ITER Number of iterations for hyperparameter optimisation (default: 100)" + " --hpo_iter HPO_ITER Number of iterations for hyperparameter " + "optimisation (default: 100)" ) print(" --device {auto,cuda,mps,cpu}") print( - " Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)" + " Device type: 'auto' (automatic detection), 'cuda' " + "(NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)" ) print( - " --use_gpu (Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available." + " --use_gpu (Optional) DEPRECATED: Use --device instead. If " + "set, attempts to use CUDA/GPU if available." ) print() print_test_installation() print() print( - " See the documentation for more details at https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/getting_started/." + " See the documentation for more details at " + "https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/getting_started/." ) def print_full_help(): print( - "usage: flexynesis [-h] --data_path DATA_PATH --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest} " - "[--gnn_conv_type {GC,GCN,SAGE}] [--target_variables TARGET_VARIABLES] [--covariates COVARIATES] [--surv_event_var SURV_EVENT_VAR] [--surv_time_var SURV_TIME_VAR] " - "[--config_path CONFIG_PATH] [--fusion_type {early,intermediate}] [--hpo_iter HPO_ITER] [--finetuning_samples FINETUNING_SAMPLES] " - "[--variance_threshold VARIANCE_THRESHOLD] [--correlation_threshold CORRELATION_THRESHOLD] [--restrict_to_features RESTRICT_TO_FEATURES] " - "[--subsample SUBSAMPLE] [--features_min FEATURES_MIN] [--features_top_percentile FEATURES_TOP_PERCENTILE] --data_types DATA_TYPES " - "[--input_layers INPUT_LAYERS] [--output_layers OUTPUT_LAYERS] [--outdir OUTDIR] [--prefix PREFIX] [--log_transform {True,False}] " - "[--early_stop_patience EARLY_STOP_PATIENCE] [--hpo_patience HPO_PATIENCE] [--val_size VAL_SIZE] [--use_cv] [--use_loss_weighting {True,False}] " - "[--evaluate_baseline_performance] [--threads THREADS] [--num_workers NUM_WORKERS] [--device {auto,cuda,mps,cpu}] [--use_gpu] [--feature_importance_method {IntegratedGradients,GradientShap,Both}] " - "[--disable_marker_finding] [--string_organism STRING_ORGANISM] [--string_node_name {gene_name,gene_id}] [--user_graph USER_GRAPH] [--safetensors]" + "usage: flexynesis [-h] --data_path DATA_PATH --model_class " + "{DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN," + "RandomForest,SVM,XGBoost,RandomSurvivalForest} " + "[--gnn_conv_type {GC,GCN,SAGE}] [--target_variables TARGET_VARIABLES] " + "[--covariates COVARIATES] [--surv_event_var SURV_EVENT_VAR] " + "[--surv_time_var SURV_TIME_VAR] [--config_path CONFIG_PATH] " + "[--fusion_type {early,intermediate}] [--hpo_iter HPO_ITER] " + "[--finetuning_samples FINETUNING_SAMPLES] " + "[--variance_threshold VARIANCE_THRESHOLD] " + "[--correlation_threshold CORRELATION_THRESHOLD] " + "[--restrict_to_features RESTRICT_TO_FEATURES] [--subsample SUBSAMPLE] " + "[--features_min FEATURES_MIN] [--features_top_percentile " + "FEATURES_TOP_PERCENTILE] --data_types DATA_TYPES [--input_layers " + "INPUT_LAYERS] [--output_layers OUTPUT_LAYERS] [--outdir OUTDIR] " + "[--prefix PREFIX] [--log_transform {True,False}] " + "[--early_stop_patience EARLY_STOP_PATIENCE] [--hpo_patience " + "HPO_PATIENCE] [--val_size VAL_SIZE] [--use_cv] " + "[--use_loss_weighting {True,False}] " + "[--evaluate_baseline_performance] [--threads THREADS] " + "[--num_workers NUM_WORKERS] [--device {auto,cuda,mps,cpu}] [--use_gpu] " + "[--feature_importance_method {IntegratedGradients,GradientShap,Both}] " + "[--disable_marker_finding] [--string_organism STRING_ORGANISM] " + "[--string_node_name {gene_name,gene_id}] [--user_graph USER_GRAPH] " + "[--safetensors]" ) print() print("Flexynesis model training interface") @@ -82,7 +110,8 @@ def print_full_help(): # --- NEW: inference-only flags (keep in full help) --- print(" --pretrained_model PRETRAINED_MODEL") print( - " Use a saved .pth/.safetensors model for inference (skip training)" + " Use a saved .pth/.safetensors model for inference " + "(skip training)" ) print(" --artifacts ARTIFACTS") print(" Path to training-time artifacts .joblib or .json") @@ -97,32 +126,40 @@ def print_full_help(): " (Required) Path to the folder with train/test data files" ) print( - " --model_class {DirectPred,supervised_vae,MultiTripletNetwork,CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" + " --model_class {DirectPred,supervised_vae,MultiTripletNetwork," + "CrossModalPred,GNN,RandomForest,SVM,XGBoost,RandomSurvivalForest}" ) print(" (Required) The kind of model class to instantiate") print(" --gnn_conv_type {GC,GCN,SAGE}") print( - " If model_class is set to GNN, choose which graph convolution type to use" + " If model_class is set to GNN, choose which graph " + "convolution type to use" ) print(" --target_variables TARGET_VARIABLES") print( - " (Optional if survival variables are not set to None). Which variables in 'clin.csv' to use for predictions, comma-separated if multiple" + " (Optional if survival variables are not set to " + "None). Which variables in 'clin.csv' to use for predictions, " + "comma-separated if multiple" ) print(" --covariates COVARIATES") print( - " Which variables in 'clin.csv' to be used as feature covariates, comma-separated if multiple" + " Which variables in 'clin.csv' to be used as feature " + "covariates, comma-separated if multiple" ) print(" --surv_event_var SURV_EVENT_VAR") print( - " Which column in 'clin.csv' to use as event/status indicator for survival modeling" + " Which column in 'clin.csv' to use as " + "event/status indicator for survival modeling" ) print(" --surv_time_var SURV_TIME_VAR") print( - " Which column in 'clin.csv' to use as time/duration indicator for survival modeling" + " Which column in 'clin.csv' to use as time/duration " + "indicator for survival modeling" ) print(" --config_path CONFIG_PATH") print( - " Optional path to an external hyperparameter configuration file in YAML format." + " Optional path to an external hyperparameter " + "configuration file in YAML format." ) print(" --fusion_type {early,intermediate}") print( @@ -133,23 +170,29 @@ def print_full_help(): ) print(" --finetuning_samples FINETUNING_SAMPLES") print( - " Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning (default: 0)" + " Number of samples from the test dataset to use " + "for fine-tuning the model. Set to 0 to disable fine-tuning " + "(default: 0)" ) print(" --variance_threshold VARIANCE_THRESHOLD") print( - " Variance threshold (as percentile) to drop low variance features (default is 1; set to 0 for no variance filtering)" + " Variance threshold (as percentile) to drop low " + "variance features (default is 1; set to 0 for no variance filtering)" ) print(" --correlation_threshold CORRELATION_THRESHOLD") print( - " Correlation threshold to drop highly redundant features (default is 0.8; set to 1 for no redundancy filtering)" + " Correlation threshold to drop highly redundant " + "features (default is 0.8; set to 1 for no redundancy filtering)" ) print(" --restrict_to_features RESTRICT_TO_FEATURES") print( - " Restrict the analysis to the list of features provided by the user (default is None)" + " Restrict the analysis to the list of features " + "provided by the user (default is None)" ) print(" --subsample SUBSAMPLE") print( - " Downsample training set to randomly drawn N samples for training. Disabled when set to 0 (default: 0)" + " Downsample training set to randomly drawn N samples " + "for training. Disabled when set to 0 (default: 0)" ) print(" --features_min FEATURES_MIN") print( @@ -157,7 +200,9 @@ def print_full_help(): ) print(" --features_top_percentile FEATURES_TOP_PERCENTILE") print( - " Top percentile features (among the features remaining after variance filtering and data cleanup) to retain after feature selection (default: 20)" + " Top percentile features (among the features " + "remaining after variance filtering and data cleanup) to retain after " + "feature selection (default: 20)" ) print(" --data_types DATA_TYPES") print( @@ -165,63 +210,84 @@ def print_full_help(): ) print(" --input_layers INPUT_LAYERS") print( - " If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers. Comma-separated if multiple" + " If model_class is set to CrossModalPred, choose " + "which data types to use as input/encoded layers. Comma-separated if " + "multiple" ) print(" --output_layers OUTPUT_LAYERS") print( - " If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers. Comma-separated if multiple" + " If model_class is set to CrossModalPred, choose " + "which data types to use as output/decoded layers. Comma-separated if " + "multiple" ) print( - " --outdir OUTDIR Path to the output folder to save the model outputs (default: current working directory)" + " --outdir OUTDIR Path to the output folder to save the model " + "outputs (default: current working directory)" ) print(" --prefix PREFIX Job prefix to use for output files (default: 'job')") print(" --log_transform {True,False}") print( - " whether to apply log-transformation to input data matrices (default: False)" + " whether to apply log-transformation to input " + "data matrices (default: False)" ) print(" --early_stop_patience EARLY_STOP_PATIENCE") print( - " How many epochs to wait when no improvements in validation loss is observed (default 10; set to -1 to disable early stopping)" + " How many epochs to wait when no improvements in " + "validation loss is observed (default 10; set to -1 to disable early " + "stopping)" ) print(" --hpo_patience HPO_PATIENCE") print( - " How many hyperparameter optimisation iterations to wait for when no improvements are observed (default is 10; set to 0 to disable early stopping)" + " How many hyperparameter optimisation iterations " + "to wait for when no improvements are observed (default is 10; set to 0 " + "to disable early stopping)" ) print( - " --val_size VAL_SIZE Proportion of training data to be used as validation split (default: 0.2)" + " --val_size VAL_SIZE Proportion of training data to be used as " + "validation split (default: 0.2)" ) print( - " --use_cv (Optional) If set, a 5-fold cross-validation training will be done. Otherwise, a single training on 80 percent of the dataset is done." + " --use_cv (Optional) If set, a 5-fold cross-validation " + "training will be done. Otherwise, a single training on 80 percent of " + "the dataset is done." ) print(" --use_loss_weighting {True,False}") print( - " whether to apply loss-balancing using uncertainty weights method (default: True)" + " whether to apply loss-balancing using uncertainty " + "weights method (default: True)" ) print(" --evaluate_baseline_performance") print( - " whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset" + " whether to run Random Forest + SVMs to see the " + "performance of off-the-shelf tools on the same dataset" ) print( - " --threads THREADS (Optional) How many threads to use when using CPU (default is 4)" + " --threads THREADS (Optional) How many threads to use when using " + "CPU (default is 4)" ) print(" --num_workers NUM_WORKERS") print( - " (Optional) How many workers to use for model training (default is 0)" + " (Optional) How many workers to use for model " + "training (default is 0)" ) print(" --device {auto,cuda,mps,cpu}") print( - " Device type: 'auto' (automatic detection), 'cuda' (NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)" + " Device type: 'auto' (automatic detection), 'cuda' " + "(NVIDIA GPU), 'mps' (Apple Silicon), 'cpu' (default: auto)" ) print( - " --use_gpu (Optional) DEPRECATED: Use --device instead. If set, attempts to use CUDA/GPU if available." + " --use_gpu (Optional) DEPRECATED: Use --device instead. If " + "set, attempts to use CUDA/GPU if available." ) print(" --feature_importance_method {IntegratedGradients,GradientShap,Both}") print( - " Choose feature importance score method (default: IntegratedGradients)" + " Choose feature importance score method " + "(default: IntegratedGradients)" ) print(" --disable_marker_finding") print( - " (Optional) If set, marker discovery after model training is disabled." + " (Optional) If set, marker discovery after model " + "training is disabled." ) print(" --string_organism STRING_ORGANISM") print(" STRING DB organism id. (default: 9606)") @@ -237,14 +303,17 @@ def print_full_help(): ) print(" --safetensors") print( - " If set, the model will be saved in the SafeTensors format and the artifacts saved as JSON." + " If set, the model will be saved in the " + "SafeTensors format and the artifacts saved as JSON." ) print(" Default is False.") print() print_test_installation() print() print( - " See the documentation for more details at https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/getting_started/." + " See the documentation for more details at " + "https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis/site/" + "getting_started/." ) @@ -414,7 +483,9 @@ def main(): Inference only: ``` - flexynesis --pretrained_model ./outputs/best_model.pth --artifacts ./outputs/artifacts.joblib --data_path_test ./data_test --outdir ./predictions --prefix run1 + flexynesis --pretrained_model ./outputs/best_model.pth --artifacts \ + ./outputs/artifacts.joblib --data_path_test ./data_test --outdir \ + ./predictions --prefix run1 ``` """ @@ -470,31 +541,36 @@ def main(): "--target_variables", type=str, default=None, - help="(Optional if survival variables are not set to None). Which variables in 'clin.csv' to use for predictions, comma-separated if multiple", + help="(Optional if survival variables are not set to None). Which variables " + "in 'clin.csv' to use for predictions, comma-separated if multiple", ) parser.add_argument( "--covariates", type=str, default=None, - help="Which variables in 'clin.csv' to be used as feature covariates, comma-separated if multiple", + help="Which variables in 'clin.csv' to be used as feature covariates, " + "comma-separated if multiple", ) parser.add_argument( "--surv_event_var", type=str, default=None, - help="Which column in 'clin.csv' to use as event/status indicator for survival modeling", + help="Which column in 'clin.csv' to use as event/status indicator for " + "survival modeling", ) parser.add_argument( "--surv_time_var", type=str, default=None, - help="Which column in 'clin.csv' to use as time/duration indicator for survival modeling", + help="Which column in 'clin.csv' to use as time/duration indicator for " + "survival modeling", ) parser.add_argument( "--config_path", type=str, default=None, - help="Optional path to an external hyperparameter configuration file in YAML format.", + help="Optional path to an external hyperparameter configuration file in " + "YAML format.", ) parser.add_argument( "--fusion_type", @@ -513,19 +589,22 @@ def main(): "--finetuning_samples", type=int, default=0, - help="Number of samples from the test dataset to use for fine-tuning the model. Set to 0 to disable fine-tuning", + help="Number of samples from the test dataset to use for fine-tuning the " + "model. Set to 0 to disable fine-tuning", ) parser.add_argument( "--variance_threshold", type=float, default=1, - help="Variance threshold (as percentile) to drop low variance features (default is 1; set to 0 for no variance filtering)", + help="Variance threshold (as percentile) to drop low variance features " + "(default is 1; set to 0 for no variance filtering)", ) parser.add_argument( "--correlation_threshold", type=float, default=0.8, - help="Correlation threshold to drop highly redundant features (default is 0.8; set to 1 for no redundancy filtering)", + help="Correlation threshold to drop highly redundant features (default is " + "0.8; set to 1 for no redundancy filtering)", ) parser.add_argument( "--restrict_to_features", @@ -549,7 +628,8 @@ def main(): "--features_top_percentile", type=float, default=20, - help="Top percentile features (among the features remaining after variance filtering and data cleanup) to retain after feature selection", + help="Top percentile features (among the features remaining after variance " + "filtering and data cleanup) to retain after feature selection", ) parser.add_argument( "--data_types", @@ -561,13 +641,15 @@ def main(): "--input_layers", type=str, default=None, - help="If model_class is set to CrossModalPred, choose which data types to use as input/encoded layers. Comma-separated if multiple", + help="If model_class is set to CrossModalPred, choose which data types to " + "use as input/encoded layers. Comma-separated if multiple", ) parser.add_argument( "--output_layers", type=str, default=None, - help="If model_class is set to CrossModalPred, choose which data types to use as output/decoded layers. Comma-separated if multiple", + help="If model_class is set to CrossModalPred, choose which data types to " + "use as output/decoded layers. Comma-separated if multiple", ) parser.add_argument( "--outdir", @@ -592,13 +674,16 @@ def main(): "--early_stop_patience", type=int, default=10, - help="How many epochs to wait when no improvements in validation loss is observed (default 10; set to -1 to disable early stopping)", + help="How many epochs to wait when no improvements in validation loss is " + "observed (default 10; set to -1 to disable early stopping)", ) parser.add_argument( "--hpo_patience", type=int, default=20, - help="How many hyperparamater optimisation iterations to wait for when no improvements are observed (default is 10; set to 0 to disable early stopping)", + help="How many hyperparamater optimisation iterations to wait for when no " + "improvements are observed (default is 10; set to 0 to disable early " + "stopping)", ) parser.add_argument( "--val_size", @@ -609,7 +694,8 @@ def main(): parser.add_argument( "--use_cv", action="store_true", - help="(Optional) If set, the a 5-fold cross-validation training will be done. Otherwise, a single trainig on 80 percent of the dataset is done.", + help="(Optional) If set, the a 5-fold cross-validation training will be done. " + "Otherwise, a single trainig on 80 percent of the dataset is done.", ) parser.add_argument( "--use_loss_weighting", @@ -621,7 +707,8 @@ def main(): parser.add_argument( "--evaluate_baseline_performance", action="store_true", - help="whether to run Random Forest + SVMs to see the performance of off-the-shelf tools on the same dataset", + help="whether to run Random Forest + SVMs to see the performance of " + "off-the-shelf tools on the same dataset", ) parser.add_argument( "--threads", @@ -677,15 +764,16 @@ def main(): "--user_graph", type=str, default=None, - help="Path to user-provided gene-gene interaction network file. " - "Must have at least 3 columns: GeneA, GeneB, Score. " - "If provided, this will be used instead of STRING DB.", + help="Path to user-provided gene-gene interaction network file. Must have " + "at least 3 columns: GeneA, GeneB, Score. If provided, this will be " + "used instead of STRING DB.", ) # safetensors args parser.add_argument( "--safetensors", action="store_true", - help="If set, use SafeTensors + JSON artifacts for save/load (training and inference). Default is False.", + help="If set, use SafeTensors + JSON artifacts for save/load (training and " + "inference). Default is False.", ) # NEW: inference flags parser.add_argument( @@ -756,7 +844,8 @@ def main(): if args.device != "auto": device_preference = args.device print( - f"[WARN] Both --use_gpu and --device {args.device} specified. Using --device {args.device}." + f"[WARN] Both --use_gpu and --device {args.device} specified. " + f"Using --device {args.device}." ) else: # Let auto-detection find the best GPU device (CUDA or MPS) @@ -775,7 +864,8 @@ def main(): if args.safetensors and model_format != "safetensors": raise ValueError( - f"[ERROR] The file {args.pretrained_model} is not a valid safetensors file." + f"[ERROR] The file {args.pretrained_model} is not a valid safetensors " + f"file." ) # Route to safetensors reconstruction or standard torch.load @@ -868,29 +958,29 @@ def main(): print(f"[INFO] Test dataset loaded: {len(test_dataset.samples)} samples") # Continue to evaluation section (skip training) - # ------------- Heavy imports only when training ------------- - from .utils import (evaluate_baseline_performance, - evaluate_baseline_survival_performance, - evaluate_wrapper, get_device_memory_info, - get_optimal_device, get_predicted_labels) - + # ------------- Heavy imports only when training or in inference with evaluation --- if not (args.pretrained_model and args.artifacts and args.data_path_test): - import json - import tracemalloc + import json # noqa: F401 + import tracemalloc # noqa: F401 - import pandas as pd - import torch - from safetensors.torch import save_file + import pandas as pd # noqa: F401 + import torch # noqa: F401 + from safetensors.torch import save_file # noqa: F401 # data + utils - from .data import STRING, DataImporter, MultiOmicDatasetNW - from .main import FineTuner, HyperparameterTuning - from .models.crossmodal_pred import CrossModalPred + from .data import (STRING, DataImporter, # noqa: F401 + MultiOmicDatasetNW) + from .main import FineTuner, HyperparameterTuning # noqa: F401 + from .models.crossmodal_pred import CrossModalPred # noqa: F401 # models - from .models.direct_pred import DirectPred - from .models.gnn_early import GNN - from .models.supervised_vae import supervised_vae - from .models.triplet_encoder import MultiTripletNetwork + from .models.direct_pred import DirectPred # noqa: F401 + from .models.gnn_early import GNN # noqa: F401 + from .models.supervised_vae import supervised_vae # noqa: F401 + from .models.triplet_encoder import MultiTripletNetwork # noqa: F401 + from .utils import evaluate_baseline_performance # noqa: F401 + from .utils import (evaluate_baseline_survival_performance, + evaluate_wrapper, get_device_memory_info, + get_optimal_device, get_predicted_labels) # --------- Sanity checks on args --------- # 1. survival variables consistency @@ -903,15 +993,18 @@ def main(): if args.model_class not in ("supervised_vae", "CrossModalPred"): if not any([args.target_variables, args.surv_event_var]): parser.error( - "When selecting a model other than 'supervised_vae' or 'CrossModalPred', you must provide at least one of --target_variables, or survival variables (--surv_event_var and --surv_time_var)" + "When selecting a model other than 'supervised_vae' or " + "'CrossModalPred', you must provide at least one of " + "--target_variables, or survival variables (--surv_event_var and " + "--surv_time_var)" ) # 3. Check for compatibility of fusion_type with CrossModalPred if args.fusion_type == "early": if args.model_class == "CrossModalPred": parser.error( - "The 'CrossModalPred' model cannot be used with early fusion type. " - "Use --fusion_type intermediate instead." + "The 'CrossModalPred' model cannot be used with early fusion " + "type. Use --fusion_type intermediate instead." ) # 4. Handle device selection with MPS support @@ -923,10 +1016,12 @@ def main(): ) # If --device is not explicitly set (still at default auto), let auto-detection handle it if args.device != "auto": - # If both --use_gpu and explicit --device are provided, respect --device but warn + # If both --use_gpu and explicit --device are provided, respect + # --device but warn device_preference = args.device print( - f"[WARN] Both --use_gpu and --device {args.device} specified. Using --device {args.device}." + f"[WARN] Both --use_gpu and --device {args.device} specified. " + f"Using --device {args.device}." ) else: # Let auto-detection find the best GPU device (CUDA or MPS) @@ -949,7 +1044,8 @@ def main(): if args.model_class == "GNN": if not args.gnn_conv_type: warnings.warn( - "\n\n!!! When running GNN, set --gnn_conv_type (GC/GCN/SAGE). Falling back to GC !!!\n" + "\n\n!!! When running GNN, set --gnn_conv_type (GC/GCN/SAGE). " + "Falling back to GC !!!\n" ) time.sleep(3) gnn_conv_type = "GC" @@ -967,13 +1063,15 @@ def main(): input_layers = input_layers.strip().split(",") if not all(layer in datatypes for layer in input_layers): raise ValueError( - f"Input layers {input_layers} are not a valid subset of the data types: ({datatypes})." + f"Input layers {input_layers} are not a valid subset of the " + f"data types: ({datatypes})." ) if args.output_layers: output_layers = output_layers.strip().split(",") if not all(layer in datatypes for layer in output_layers): raise ValueError( - f"Output layers {output_layers} are not a valid subset of the data types: ({datatypes})." + f"Output layers {output_layers} are not a valid subset of the " + f"data types: ({datatypes})." ) # paths @@ -1011,7 +1109,8 @@ def main(): if args.covariates: if args.model_class == "GNN": # Covariates not yet supported for GNNs warnings.warn( - "\n\n!!! Covariates are currently not supported for GNN models, they will be ignored. !!!\n" + "\n\n!!! Covariates are currently not supported for GNN models, " + "they will be ignored. !!!\n" ) time.sleep(3) covariates = None @@ -1045,7 +1144,8 @@ def main(): from xgboost import XGBClassifier # noqa: F401 except Exception: raise ImportError( - "XGBoost is not available. On macOS, install the OpenMP runtime: brew install libomp" + "XGBoost is not available. On macOS, install the OpenMP runtime: " + "brew install libomp" ) if args.model_class in ["RandomForest", "SVM", "XGBoost"]: if args.target_variables: @@ -1076,13 +1176,15 @@ def main(): sys.exit(0) else: raise ValueError( - "At least one target variable is required to run RandomForest/SVM/XGBoost models. Set --target_variables" + "At least one target variable is required to run " + "RandomForest/SVM/XGBoost models. Set --target_variables" ) if args.model_class == "RandomSurvivalForest": if args.surv_event_var and args.surv_time_var: print( - f"Training {args.model_class} on survival variables: {args.surv_event_var} and {args.surv_time_var}" + f"Training {args.model_class} on survival variables: " + f"{args.surv_event_var} and {args.surv_time_var}" ) metrics, predictions = evaluate_baseline_survival_performance( train_dataset, @@ -1122,7 +1224,9 @@ def main(): print(f"[INFO] Loaded {len(graph_df)} interactions from user graph") else: # Fallback to STRING DB (default behavior) - print("[INFO] No user graph provided. Using STRING DB network (default)") + print( + "[INFO] No user graph provided. Using STRING DB network (default)" + ) obj = STRING( os.path.join(args.data_path, "_".join(["processed", args.prefix])), args.string_organism, @@ -1237,7 +1341,7 @@ def main(): for explainer in explainers: print( - "[INFO] Computing variable importance scores using explainer:", + "[INFO] Computing variable importance scores using explainer: ", explainer, ) for var in model.target_variables: @@ -1248,7 +1352,10 @@ def main(): method=explainer, ) df_imp = pd.concat( - [model.feature_importances[x] for x in model.target_variables], + [ + model.feature_importances[x] + for x in model.target_variables + ], ignore_index=True, ) df_imp["explainer"] = explainer @@ -1376,7 +1483,8 @@ def main(): if model.surv_event_var and model.surv_time_var: print( - "[INFO] Computing off-the-shelf method performance on survival variable:", + "[INFO] Computing off-the-shelf method performance on survival " + "variable:", model.surv_time_var, ) metrics_baseline_survival = evaluate_baseline_survival_performance( @@ -1463,7 +1571,8 @@ def main(): list(data_importer.train_features.keys()) if hasattr(data_importer, "train_features") else args.data_types.split(",") - ), # Use actual data structure keys (e.g. ['all'] for early fusion) + ), # Use actual data structure keys (e.g. ['all'] for early + # fusion "original_modalities": args.data_types.split( "," ), # Original modalities from CLI before concatenation @@ -1499,9 +1608,9 @@ def main(): print(f"[INFO] Wrote inference artifacts to {joblib_path}") elif args.safetensors: - import numpy as np - from sklearn.preprocessing import (LabelEncoder, - OrdinalEncoder, + import numpy as np # noqa: F401 + from sklearn.preprocessing import LabelEncoder # noqa: F401 + from sklearn.preprocessing import (OrdinalEncoder, StandardScaler) json_ready = { @@ -1528,7 +1637,8 @@ def main(): continue if not isinstance(scaler, StandardScaler): raise ValueError( - f"Unsupported scaler type for modality '{modality}': {type(scaler).__name__}." + f"Unsupported scaler type for modality '{modality}': " + f"{type(scaler).__name__}." ) scaler_dict = { "type": "StandardScaler", @@ -1569,7 +1679,9 @@ def main(): elif isinstance(encoder, OrdinalEncoder): encoder_dict = { "type": "OrdinalEncoder", - "categories": [cat.tolist() for cat in encoder.categories_], + "categories": [ + cat.tolist() for cat in encoder.categories_ + ], "handle_unknown": encoder.handle_unknown, "unknown_value": encoder.unknown_value, } diff --git a/flexynesis/data.py b/flexynesis/data.py index e6a51dca..e6f12831 100644 --- a/flexynesis/data.py +++ b/flexynesis/data.py @@ -36,7 +36,8 @@ class DataImporter: variance_threshold (float): The variance threshold for removing low-variance features. na_threshold (float): The threshold for removing features with too many NA values. string_organism (int): STRING organism (species) id (default: 9606 (human)). - string_node_name (str): The type of node names used in the graph. Available options: "gene_name", "gene_id" (default: "gene_name"). + string_node_name (str): The type of node names used in the graph. + Available options: "gene_name", "gene_id" (default: "gene_name"). Methods: import_data(): The primary method to orchestrate the data import and preprocessing workflow. It follows these steps: @@ -64,9 +65,10 @@ class DataImporter: Prepares the data for model input by cleaning, filtering, and selecting features and samples. select_features(dat): - Implements an unsupervised feature selection by ranking features by the Laplacian score, keeping the features at - the top percentile range and removing highly redundant features (optional) based on a correlation threshold, - while keeping a minimum number of top features as requested by the user. + Implements an unsupervised feature selection by ranking features by the Laplacian + score, keeping the features at the top percentile range and removing highly + redundant features (optional) based on a correlation threshold, while keeping a + minimum number of top features as requested by the user. harmonize(dat1, dat2): Aligns the feature sets of two datasets (e.g., training and testing) to have the same features. @@ -439,8 +441,12 @@ def cleanup_data(self, df_dict): original_samples_count = cleaned_dfs[key].shape[1] cleaned_dfs[key] = cleaned_dfs[key].loc[:, common_mask] removed_samples_count = original_samples_count - cleaned_dfs[key].shape[1] + removed_samples_pct = ( + removed_samples_count / original_samples_count * 100 + ) print( - f"[INFO] DataFrame {key} - Removed {removed_samples_count} samples ({removed_samples_count / original_samples_count * 100:.2f}%)." + f"[INFO] DataFrame {key} - Removed {removed_samples_count} samples " + f"({removed_samples_pct:.2f}%)." ) # update feature logs from this process @@ -621,12 +627,14 @@ def check_sample_labels(dat, split): matching_samples = clin_samples.intersection(omics_samples) if not matching_samples: errors.append( - f"Error: No matching sample labels found between {split}/clin.csv and {split}/{file_name}.csv." + f"Error: No matching sample labels found between " + f"{split}/clin.csv and {split}/{file_name}.csv." ) elif len(matching_samples) < len(clin_samples): missing_samples = clin_samples - matching_samples warnings.append( - f"Warning: Some sample labels in {split}/clin.csv are missing in {split}/{file_name}.csv: {missing_samples}" + f"Warning: Some sample labels in {split}/clin.csv are " + f"missing in {split}/{file_name}.csv: {missing_samples}" ) def check_common_features(train_dat, test_dat): @@ -928,9 +936,12 @@ class MultiOmicDataset(Dataset): """A PyTorch dataset for multiomic data. Args: - dat (dict): A dictionary with keys corresponding to different types of data and values corresponding to matrices of the same shape. All matrices must have the same number of samples (rows). + dat (dict): A dictionary with keys corresponding to different types of data and values + corresponding to matrices of the same shape. All matrices must have the same number + of samples (rows). ann (data.frame): Data frame with samples on the rows, sample annotations on the columns - features (list or np.array): A 1D array of feature names with length equal to the number of columns in each matrix. + features (list or np.array): A 1D array of feature names with length equal to the + number of columns in each matrix. samples (list or np.array): A 1D array of sample names with length equal to the number of rows in each matrix. Returns: @@ -964,7 +975,9 @@ def __getitem__(self, index): Returns: A tuple of two elements: - 1. A dictionary with keys corresponding to the different types of data in the input dictionary `dat`, and values corresponding to the data for the given sample. + 1. A dictionary with keys corresponding to the different + types of data in the input dictionary `dat`, and values + corresponding to the data for the given sample. 2. The label for the given sample. """ subset_dat = {x: self.dat[x][index] for x in self.dat.keys()} @@ -1004,10 +1017,12 @@ def subset(self, indices): ) def get_feature_subset(self, feature_df): - """Get a subset of data matrices corresponding to specified features and concatenate them into a pandas DataFrame. + """Get a subset of data matrices corresponding to specified features and + concatenate them into a pandas DataFrame. Args: - feature_df (pandas.DataFrame): A DataFrame which contains at least two columns: 'layer' and 'name'. + feature_df (pandas.DataFrame): A DataFrame which contains at least + two columns: 'layer' and 'name'. Returns: A pandas DataFrame that concatenates the data matrices for the specified features from all layers. diff --git a/flexynesis/main.py b/flexynesis/main.py index 3f1315b8..e08abd80 100644 --- a/flexynesis/main.py +++ b/flexynesis/main.py @@ -58,8 +58,8 @@ class HyperparameterTuning: get_batch_space(min_size=16, max_size=128): Determines the batch size search space based on the dataset size. - setup_trainer(params, current_step, total_steps, full_train=False): Sets up the trainer with appropriate callbacks - and configurations for either full training or validation based training. + setup_trainer(params, current_step, total_steps, full_train=False): Sets up the trainer + with appropriate callbacks and configurations for either full training or validation based training. objective(params, current_step, total_steps, full_train=False): Evaluates a set of parameters to determine the performance of the model using the specified parameters. @@ -383,7 +383,8 @@ def perform_tuning(self, hpo_patience=0): # Early stopping condition if no_improvement_count >= hpo_patience & hpo_patience > 0: print( - f"No improvement in best loss for {hpo_patience} iterations, stopping hyperparameter optimisation early." + f"No improvement in best loss for {hpo_patience} iterations, " + "stopping hyperparameter optimisation early." ) break # Break out of the loop best_params_dict = ( @@ -392,7 +393,8 @@ def perform_tuning(self, hpo_patience=0): else None ) print( - f"[INFO] current best val loss: {best_loss}; best params: {best_params_dict} since {no_improvement_count} hpo iterations" + f"[INFO] current best val loss: {best_loss}; best params: " + f"{best_params_dict} since {no_improvement_count} hpo iterations" ) # Convert best parameters from list to dictionary and include epochs @@ -608,11 +610,16 @@ def run_experiments(self): val_loss[0]["val_loss"] ) # Adjust based on your validation output format epochs.append(stopped_epoch) - # print(f"[INFO] Finetuning ... training fold: {fold}, learning rate: {lr}, val_loss: {val_loss}, freeze {config}") + # print( + # f"[INFO] Finetuning ... training fold: {fold}, learning rate: {lr}, " + # f"val_loss: {val_loss}, freeze {config}" + # ) avg_val_loss = np.mean(fold_losses) avg_epochs = int(np.mean(epochs)) print( - f"[INFO] average 5-fold cross-validation loss {avg_val_loss} for learning rate: {lr} freeze {config}, average epochs {avg_epochs}" + f"[INFO] average 5-fold cross-validation loss {avg_val_loss} " + f"for learning rate: {lr} freeze {config}, " + f"average epochs {avg_epochs}" ) val_loss_results.append( { @@ -627,7 +634,8 @@ def run_experiments(self): best_config = min(val_loss_results, key=lambda x: x["average_val_loss"]) print( f"Best learning rate: {best_config['learning_rate']} and freeze {best_config['freeze']}", - f"with average validation loss: {best_config['average_val_loss']} and average epochs: {best_config['epochs']}", + f"with average validation loss: {best_config['average_val_loss']} " + f"and average epochs: {best_config['epochs']}", ) # build a final model using the best setup on all samples diff --git a/flexynesis/models/crossmodal_pred.py b/flexynesis/models/crossmodal_pred.py index 0a5d733e..0ddef098 100644 --- a/flexynesis/models/crossmodal_pred.py +++ b/flexynesis/models/crossmodal_pred.py @@ -20,8 +20,10 @@ class CrossModalPred(pl.LightningModule): The network also can be connected to one or more MLPs for outcome variable prediction. dataset: dictionary of data matrices - input_layers: which data modalities from `dataset` to encode (use a subset of keys from `dataset`) - output_layers: which data modalities are aimed to be reconsructed via decoders (use a subset of keys from `dataset`). + input_layers: which data modalities from `dataset` to encode (use a + subset of keys from `dataset`) + output_layers: which data modalities are aimed to be reconsructed via + decoders (use a subset of keys from `dataset`). """ @@ -297,14 +299,19 @@ def training_step(self, train_batch, batch_idx, log=True): log (bool, optional): Flag to determine if logging should occur at each step. Defaults to True. Returns: - torch.Tensor: The total loss for the current training batch, combining MMD loss for latent space regularization, - reconstruction losses for each output layer, and losses from supervisor heads for specified target variables. - - This method processes the batch by encoding input features from specified layers, decoding them to reconstruct the - output layers, and calculating the Maximum Mean Discrepancy (MMD) loss for latent space regularization. It computes - the reconstruction loss for each target/output layer. Additional losses are computed for other target variables in the - dataset, particularly handling survival analysis if applicable. All losses are aggregated to compute a total loss, - which is logged and returned. + torch.Tensor: The total loss for the current training batch, + combining MMD loss for latent space regularization, + reconstruction losses for each output layer, and losses from + supervisor heads for specified target variables. + + This method processes the batch by encoding input features from + specified layers, decoding them to reconstruct the output layers, and + calculating the Maximum Mean Discrepancy (MMD) loss for latent space + regularization. It computes the reconstruction loss for each + target/output layer. Additional losses are computed for other target + variables in the dataset, particularly handling survival analysis if + applicable. All losses are aggregated to compute a total loss, which is + logged and returned. """ dat, y_dict, samples = train_batch @@ -344,23 +351,29 @@ def training_step(self, train_batch, batch_idx, log=True): def validation_step(self, val_batch, batch_idx, log=True): """ - Executes one validation step using a single batch of data, assessing the model's performance on the validation set. + Executes one validation step using a single batch of data, assessing + the model's performance on the validation set. Args: - val_batch (tuple): The batch data containing input features and target labels for validation. + val_batch (tuple): The batch data containing input features and + target labels for validation. batch_idx (int): The index of the current batch in the validation process. log (bool, optional): Indicates whether to log the validation losses during this step. Defaults to True. Returns: - torch.Tensor: The total loss for the current validation batch, calculated by combining MMD loss, reconstruction - losses, and losses from supervisor heads for specified target variables. - - In this method, the model processes input data by encoding it through specified input layers and decoding it to - targeted output layers. It computes the Maximum Mean Discrepancy (MMD) loss to measure the divergence between - the model's latent representations and a predefined distribution, along with reconstruction losses for output layers. - Additionally, it calculates losses for other target variables in the dataset, handling complex scenarios like survival - analysis where applicable. The aggregated losses are then summed up to form the total validation loss, which is logged - and returned. + torch.Tensor: The total loss for the current validation batch, + calculated by combining MMD loss, reconstruction losses, and + losses from supervisor heads for specified target variables. + + In this method, the model processes input data by encoding it through + specified input layers and decoding it to targeted output layers. It + computes the Maximum Mean Discrepancy (MMD) loss to measure the + divergence between the model's latent representations and a predefined + distribution, along with reconstruction losses for output layers. + Additionally, it calculates losses for other target variables in the + dataset, handling complex scenarios like survival analysis if applicable. + The aggregated losses are then summed up to form the total validation + loss, which is logged and returned. """ dat, y_dict, samples = val_batch @@ -546,7 +559,8 @@ def compute_feature_importance( batch_size=64, ): """ - Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. + Computes the feature importance for each variable in the dataset + using either Integrated Gradients or Gradient SHAP. Args: dataset: The dataset object containing the features and data. diff --git a/flexynesis/models/direct_pred.py b/flexynesis/models/direct_pred.py index 7cbb8fa7..d2cd12b7 100644 --- a/flexynesis/models/direct_pred.py +++ b/flexynesis/models/direct_pred.py @@ -111,7 +111,8 @@ def forward(self, x_list): x_list (list of torch.Tensor): A list of input matrices (omics layers), one for each layer. Returns: - dict: A dictionary where each key-value pair corresponds to the target variable name and its predicted output respectively. + dict: A dictionary where each key-value pair corresponds to the + target variable name and its predicted output respectively. """ embeddings_list = [] # Process each input matrix with its corresponding Encoder @@ -436,7 +437,8 @@ def compute_feature_importance( batch_size=64, ): """ - Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. + Computes feature importance for each variable in the dataset using either + Integrated Gradients or Gradient SHAP. Args: dataset: The dataset object containing the features and data. diff --git a/flexynesis/models/gnn_early.py b/flexynesis/models/gnn_early.py index 6c17bc59..858d5d3f 100644 --- a/flexynesis/models/gnn_early.py +++ b/flexynesis/models/gnn_early.py @@ -46,7 +46,8 @@ class GNN(pl.LightningModule): surv_event_var (str, optional): The variable name representing survival events. surv_time_var (str, optional): The variable name representing survival times. use_loss_weighting (bool, optional): Whether to use uncertainty weighting in loss calculation. Defaults to True. - device_type (str, optional): Specifies the computation device ('gpu' or 'cpu'). Default is None, which uses 'cpu' if 'gpu' is not available. + device_type (str, optional): Specifies the computation device ('gpu' or 'cpu'). Default is None, which uses + 'cpu' if 'gpu' is not available. gnn_conv_type (str, optional): Specifies the type of graph convolutional layer to use. """ @@ -160,7 +161,8 @@ def training_step(self, batch, batch_idx, log=True): Performs a training step including loss calculation and logging. Args: - batch (tuple): A batch of data consisting of features, target labels as a dictionary of tensors, and sample ids. + batch (tuple): A batch of data consisting of features, target labels + as a dictionary of tensors, and sample ids. Returns: float: Total loss for the batch. @@ -198,7 +200,8 @@ def validation_step(self, batch, batch_idx, log=True): Performs a validation step, computing losses for a batch of data. Args: - batch (tuple): A batch of data consisting of features, target labels as a dictionary of tensors, and sample ids. + batch (tuple): A batch of data consisting of features, target labels as a + dictionary of tensors, and sample ids. Returns: float: Total validation loss for the batch. @@ -442,7 +445,8 @@ def compute_feature_importance( batch_size=64, ): """ - Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. + Computes the feature importance for each variable in the dataset + using either Integrated Gradients or Gradient SHAP. Args: dataset: The dataset object containing the features and data. diff --git a/flexynesis/models/supervised_vae.py b/flexynesis/models/supervised_vae.py index b6efda6e..24f45981 100644 --- a/flexynesis/models/supervised_vae.py +++ b/flexynesis/models/supervised_vae.py @@ -570,7 +570,8 @@ def compute_feature_importance( batch_size=64, ): """ - Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. + Computes feature importance for each variable in the dataset using either + Integrated Gradients or Gradient SHAP. Args: dataset: The dataset object containing the features and data. diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index 6928c134..4c8a4564 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -277,17 +277,22 @@ def training_step(self, train_batch, batch_idx, log=True): Perform a training step using a single batch of data, including triplet components and target labels. Args: - train_batch (tuple): The batch containing data tuples (anchor, positive, negative) and a dictionary of labels. + train_batch (tuple): The batch containing data tuples (anchor, positive, + negative) and a dictionary of labels. batch_idx (int): The index of the current batch. - log (bool, optional): Flag to determine if logging should occur at each step. Defaults to True. + log (bool, optional): Flag to determine if logging should occur at each + step. Defaults to True. Returns: - torch.Tensor: The total loss for the current training batch, which includes triplet loss and any additional - losses from supervisor heads. - - This method computes the embedding for the anchor, positive, and negative samples and calculates the triplet loss. - Additional losses are computed for other target variables in the dataset, particularly handling survival analysis - if applicable. All losses are combined to compute a total loss, which is logged and returned. + torch.Tensor: The total loss for the current training batch, which + includes triplet loss and any additional losses from supervisor + heads. + + This method computes the embedding for the anchor, positive, and negative + samples and calculates the triplet loss. Additional losses are computed for + other target variables in the dataset, particularly handling survival + analysis if applicable. All losses are combined to compute a total loss, + which is logged and returned. """ anchor, positive, negative, y_dict = ( train_batch[0], @@ -328,17 +333,22 @@ def validation_step(self, val_batch, batch_idx, log=True): Perform a validation step using a single batch of data, including triplet components and target labels. Args: - val_batch (tuple): The batch containing data tuples (anchor, positive, negative) and a dictionary of labels. + val_batch (tuple): The batch containing data tuples (anchor, positive, + negative) and a dictionary of labels. batch_idx (int): The index of the current batch. - log (bool, optional): Flag to determine if logging should occur at each step. Defaults to True. + log (bool, optional): Flag to determine if logging should occur at each + step. Defaults to True. Returns: - torch.Tensor: The total loss for the current validation batch, which includes triplet loss and any additional - losses from supervisor heads. - - Similar to the training step, this method computes the embedding for the anchor, positive, and negative samples - and calculates the triplet loss. It computes additional losses for other target variables in the dataset, aggregates - all losses, and returns the total loss. The losses are logged if specified. + torch.Tensor: The total loss for the current validation batch, which + includes triplet loss and any additional losses from supervisor + heads. + + Similar to the training step, this method computes the embedding for the + anchor, positive, and negative samples and calculates the triplet loss. It + computes additional losses for other target variables in the dataset, + aggregates all losses, and returns the total loss. The losses are logged if + specified. """ anchor, positive, negative, y_dict = ( val_batch[0], @@ -400,7 +410,8 @@ def predict(self, dataset): dataset: The dataset to evaluate the model on. Returns: - A dictionary where each key is a target variable and the corresponding value is the predicted output for that variable. + A dictionary where each key is a target variable and the corresponding + value is the predicted output for that variable. """ self.eval() # get anchor embedding @@ -465,7 +476,8 @@ def compute_feature_importance( batch_size=64, ): """ - Computes the feature importance for each variable in the dataset using either Integrated Gradients or Gradient SHAP. + Computes the feature importance for each variable in the dataset using + either Integrated Gradients or Gradient SHAP. Args: dataset: The dataset object containing the features and data. diff --git a/flexynesis/modules.py b/flexynesis/modules.py index 32cdaa8a..c1d2722d 100644 --- a/flexynesis/modules.py +++ b/flexynesis/modules.py @@ -173,11 +173,14 @@ class flexGCN(nn.Module): output_dim (int): The size of the output vector, which is the final feature vector for the whole graph. num_convs (int, optional): Number of convolutional layers in the network. Defaults to 2. dropout_rate (float, optional): The dropout probability used for regularization. Defaults to 0.2. - conv (str, optional): Type of convolution layer to use. Supported types include 'GCN' for Graph Convolution Network, - 'GAT' for Graph Attention Network, 'SAGE' for GraphSAGE, and 'GC' for generic Graph Convolution. + conv (str, optional): Type of convolution layer to use. Supported types + include 'GCN' for Graph Convolution Network, + 'GAT' for Graph Attention Network, 'SAGE' for + GraphSAGE, and 'GC' for generic Graph Convolution. Defaults to 'GC'. - act (str, optional): Type of activation function to use. Supported types include 'relu', 'sigmoid', - 'leakyrelu', 'tanh', and 'gelu'. Defaults to 'relu'. + act (str, optional): Type of activation function to use. Supported types + include 'relu', 'sigmoid', 'leakyrelu', 'tanh', + and 'gelu'. Defaults to 'relu'. Raises: ValueError: If an unsupported activation function or convolution type is specified. diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 3066b9a0..df8784f8 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -245,7 +245,8 @@ def plot_kaplan_meier_curves(durations, events, categorical_variable): def plot_scatter(true_values, predicted_values): """ - Plots a scatterplot of true vs predicted values, with a regression line and annotated with the Pearson correlation coefficient. + Plots a scatterplot of true vs predicted values, with a regression line and + annotated with the Pearson correlation coefficient. Args: true_values (list or np.array): True values @@ -414,10 +415,12 @@ def bootstrap_metric(y_true, y_pred, indices_list, metric_fn, ci=95, **kwargs): def evaluate_classifier(y_true, y_probs, print_report=False): """ - Evaluate the performance of a classifier using multiple metrics and optionally print a detailed classification report. + Evaluate the performance of a classifier using multiple metrics and optionally + print a detailed classification report. - This function computes balanced accuracy, F1 score (weighted), Cohen's Kappa score, average AUROC score, and - weighted-average AUC-PR score for the given true labels and predicted probabilities. + This function computes balanced accuracy, F1 score (weighted), Cohen's Kappa + score, average AUROC score, and weighted-average AUC-PR score for the given + true labels and predicted probabilities. If `print_report` is set to True, it prints a detailed classification report. Args: @@ -596,21 +599,28 @@ def plot_pr_curves(y_true, y_probs): def evaluate_regressor(y_true, y_pred): """ - Evaluate the performance of a regression model using mean squared error, R-squared, and Pearson correlation coefficient. + Evaluate the performance of a regression model using mean squared error, + R-squared, and Pearson correlation coefficient. - This function computes the mean squared error (MSE) between true and predicted values as a measure of prediction accuracy. - It also performs a linear regression analysis between the true and predicted values to obtain the R-squared value, which - explains the variance ratio, and the Pearson correlation coefficient, providing insight into the linear relationship strength. + This function computes the mean squared error (MSE) between true and predicted + values as a measure of prediction accuracy. It also performs a linear regression + analysis between the true and predicted values to obtain the R-squared value, + which explains the variance ratio, and the Pearson correlation coefficient, + providing insight into the linear relationship strength. Args: - y_true (array-like): True values of the dependent variable, must be a 1D list or array. - y_pred (array-like): Predicted values as returned by a regressor, must match the dimensions of y_true. + y_true (array-like): True values of the dependent variable, must be a 1D + list or array. + y_pred (array-like): Predicted values as returned by a regressor, must + match the dimensions of y_true. Returns: dict: A dictionary containing: - 'mse': The mean squared error between the true and predicted values. - - 'r2': The R-squared value indicating the proportion of variance in the dependent variable predictable from the independent variable. - - 'pearson_corr': The Pearson correlation coefficient indicating the linear relationship strength between the true and predicted values. + - 'r2': The R-squared value indicating the proportion of variance in + the dependent variable predictable from the independent variable. + - 'pearson_corr': The Pearson correlation coefficient indicating the + linear relationship strength between the true and predicted values. """ mse = mean_squared_error(y_true, y_pred) slope, intercept, r_value, p_value, std_err = linregress(y_true, y_pred) @@ -622,22 +632,30 @@ def evaluate_wrapper( method, y_pred_dict, dataset, surv_event_var=None, surv_time_var=None ): """ - Evaluates predictions for different variables within a dataset using appropriate metrics based on the variable type. - Supports evaluation for numerical, categorical, and survival data. + Evaluates predictions for different variables within a dataset using + appropriate metrics based on the variable type. Supports evaluation for + numerical, categorical, and survival data. - This function loops through each variable in the predictions dictionary, determines the type of the variable, - and evaluates the predictions using the appropriate method: regression, classification, or survival analysis. - It compiles the metrics into a list of dictionaries, which is then converted into a pandas DataFrame. + This function loops through each variable in the predictions dictionary, + determines the type of the variable, and evaluates the predictions using the + appropriate method: regression, classification, or survival analysis. It + compiles the metrics into a list of dictionaries, which is then converted into + a pandas DataFrame. Args: method (str): Identifier for the prediction method or model used. - y_pred_dict (dict): A dictionary where keys are variable names and values are arrays of predicted values. - dataset (Dataset): A dataset object containing actual values and metadata such as variable types. - surv_event_var (str, optional): The name of the survival event variable. Required if survival analysis is performed. - surv_time_var (str, optional): The name of the survival time variable. Required if survival analysis is performed. + y_pred_dict (dict): A dictionary where keys are variable names and values + are arrays of predicted values. + dataset (Dataset): A dataset object containing actual values and metadata + such as variable types. + surv_event_var (str, optional): The name of the survival event variable. + Required if survival analysis is performed. + surv_time_var (str, optional): The name of the survival time variable. + Required if survival analysis is performed. Returns: - pd.DataFrame: A DataFrame where each row contains the method, variable name, variable type, metric name, and metric value. + pd.DataFrame: A DataFrame where each row contains the method, variable + name, variable type, metric name, and metric value. """ metrics_list = [] @@ -776,7 +794,8 @@ def evaluate_baseline_performance( n_components=100, ): """ - Evaluates the performance of machine learning models on a given variable with optional PCA for dimensionality reduction. + Evaluates the performance of machine learning models on a given variable with + optional PCA for dimensionality reduction. Args: train_dataset (Dataset): A MultiOmicDataset object containing training data and metadata. @@ -919,23 +938,29 @@ def evaluate_baseline_survival_performance( train_dataset, test_dataset, duration_col, event_col, n_folds=5, n_jobs=4 ): """ - Evaluates the baseline performance of a Random Survival Forest model on survival data using the Concordance Index. + Evaluates the baseline performance of a Random Survival Forest model on + survival data using the Concordance Index. - The function preprocesses both training and testing datasets to prepare appropriate survival data (comprising durations - and event occurrences), performs cross-validation to assess model robustness, and then calculates the Concordance Index on - the test data. It uses a Random Survival Forest (RSF) as the predictive model. + The function preprocesses both training and testing datasets to prepare + appropriate survival data (comprising durations and event occurrences), performs + cross-validation to assess model robustness, and then calculates the Concordance + Index on the test data. It uses a Random Survival Forest (RSF) as the + predictive model. Args: train_dataset (Dataset): The training dataset (a MultiOmicDataset object) containing features and survival data. test_dataset (Dataset): The testing dataset (a MultiOmicDataset object) containing features and survival data. duration_col (str): Column name in the dataset for survival time. event_col (str): Column name in the dataset for the event occurrence (1 if event occurred, 0 otherwise). - n_folds (int, optional): Number of folds for K-fold cross-validation. Defaults to 5. - n_jobs (int, optional): Number of parallel jobs to run for Random Survival Forest training. Defaults to 4. + n_folds (int, optional): Number of folds for K-fold cross-validation. + Defaults to 5. + n_jobs (int, optional): Number of parallel jobs to run for Random Survival + Forest training. Defaults to 4. Returns: - pd.DataFrame: A DataFrame containing the performance metrics of the RSF model, specifically the Concordance Index, - listed along with the method name and variable details. + pd.DataFrame: A DataFrame containing the performance metrics of the RSF + model, specifically the Concordance Index, listed along with the method + name and variable details. """ print("[INFO] Evaluating baseline survival prediction performance") @@ -1618,13 +1643,16 @@ def create_covariate_matrix(covariates, variable_types, ann): Missing values in numerical variables are imputed using the median. Args: - covariates (list of str): List of variable names that must exist in the "clin.csv". - variable_types (dict): Dictionary mapping variable names to their types ('categorical' or 'numerical'). + covariates (list of str): List of variable names that must exist in the + "clin.csv". + variable_types (dict): Dictionary mapping variable names to their types + ('categorical' or 'numerical'). ann (pd.DataFrame): Annotation DataFrame containing batch variable values. Returns: - pd.DataFrame: A covariate matrix DataFrame where categorical variables are one-hot-encoded as 0/1 and numerical variables are imputed, - with features as rows and samples as columns. + pd.DataFrame: A covariate matrix DataFrame where categorical variables are + one-hot-encoded as 0/1 and numerical variables are imputed, with + features as rows and samples as columns. """ covariate_features = [] feature_names = [] @@ -1693,8 +1721,10 @@ def optimal_transport_align(embeddings, batch_labels, standardize_by_labels=Fals - batch_labels (np.ndarray or pd.Series): Batch labels corresponding to the rows of embeddings. Returns: - - aligned_embeddings (pd.DataFrame): A DataFrame containing the aligned embeddings for all samples, with original indices preserved. - - aligned_batch_labels (pd.Series): A Series containing the corresponding batch labels for the aligned embeddings. + - aligned_embeddings (pd.DataFrame): A DataFrame containing the aligned + embeddings for all samples, with original indices preserved. + - aligned_batch_labels (pd.Series): A Series containing the corresponding + batch labels for the aligned embeddings. """ # Ensure batch labels are a NumPy array batch_labels_np = np.array(batch_labels) @@ -1772,8 +1802,10 @@ def reciprocal_pca_mnn( - random_state (int, optional): Random seed for reproducibility. Returns: - - aligned_embeddings (pd.DataFrame): A DataFrame containing the aligned embeddings for all samples, with original indices preserved. - - aligned_batch_labels (pd.Series): A Series containing the corresponding batch labels for the aligned embeddings. + - aligned_embeddings (pd.DataFrame): A DataFrame containing the aligned + embeddings for all samples, with original indices preserved. + - aligned_batch_labels (pd.Series): A Series containing the corresponding + batch labels for the aligned embeddings. """ # Ensure batch labels are a NumPy array batch_labels_np = np.array(batch_labels) @@ -1955,7 +1987,9 @@ def get_cbioportal_data(self, study_id, files=None): if files is None: self.print_data_files() print( - "\n\nPlease select a list of files to import. Example:\n get_cbioportal_data('study_id', files={'mut': 'data_mutations.txt', 'clin': 'data_clinical_patient.txt'})" + "\n\nPlease select a list of files to import. Example:\n " + "get_cbioportal_data('study_id', files={'mut': 'data_mutations.txt', " + "'clin': 'data_clinical_patient.txt'})" ) return From f686a8d276105bb39a05439f4d4742cee1f59334 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Fri, 3 Apr 2026 20:25:13 +0200 Subject: [PATCH 22/32] set max len of line to 120 in CI --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index b6b8fabc..db21baf5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -41,4 +41,4 @@ jobs: run: isort --check-only --diff . - name: Run flake8 - run: flake8 . + run: flake8 . --max-line-length=120 From db1cf9dbfb4b82d647118ceedbd59aa207701488 Mon Sep 17 00:00:00 2001 From: Amirhossein Nilchi <66441226+nilchia@users.noreply.github.com> Date: Fri, 3 Apr 2026 20:47:53 +0200 Subject: [PATCH 23/32] Update .github/workflows/lint.yml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index db21baf5..74e0155f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -15,7 +15,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.10', '3.11'] + python-version: ['3.11', '3.12'] steps: - uses: actions/checkout@v4 with: From 99b272ae6e299525b6919f96ef33936a6e747e80 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 9 Apr 2026 22:50:55 +0200 Subject: [PATCH 24/32] add removed libraries --- flexynesis/__main__.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index 0ced9436..fce06a86 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -956,14 +956,19 @@ def main(): if hasattr(test_dataset, "to_device"): test_dataset.to_device(device) print(f"[INFO] Test dataset loaded: {len(test_dataset.samples)} samples") - # Continue to evaluation section (skip training) + + # Import evaluation utilities needed in the shared evaluation section below + import pandas as pd # noqa: F401 + + from .utils import evaluate_wrapper, get_predicted_labels # noqa: F401 + + # Continue to evaluation section (skip training) # ------------- Heavy imports only when training or in inference with evaluation --- if not (args.pretrained_model and args.artifacts and args.data_path_test): import json # noqa: F401 import tracemalloc # noqa: F401 - import pandas as pd # noqa: F401 import torch # noqa: F401 from safetensors.torch import save_file # noqa: F401 @@ -979,8 +984,7 @@ def main(): from .models.triplet_encoder import MultiTripletNetwork # noqa: F401 from .utils import evaluate_baseline_performance # noqa: F401 from .utils import (evaluate_baseline_survival_performance, - evaluate_wrapper, get_device_memory_info, - get_optimal_device, get_predicted_labels) + get_device_memory_info, get_optimal_device) # --------- Sanity checks on args --------- # 1. survival variables consistency From 1498d2246b441c3ba67b9d31d9e94cacaf84f898 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 9 Apr 2026 23:04:18 +0200 Subject: [PATCH 25/32] string is imported from data.py --- flexynesis/__main__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index fce06a86..f0b3488b 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -929,8 +929,7 @@ def main(): # Convert to GNN dataset if needed if args.model_class == "GNN": print("[INFO] Overlaying the dataset with network data from STRINGDB") - from .data import MultiOmicDatasetNW - from .main import STRING + from .data import MultiOmicDatasetNW, STRING # Get STRING organism from artifacts string_organism = importer.artifacts.get( From afa6d778ea1cba9bf85ed7cf45b8d92e1e9ba1ef Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 9 Apr 2026 23:07:28 +0200 Subject: [PATCH 26/32] isort fix --- flexynesis/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flexynesis/__main__.py b/flexynesis/__main__.py index f0b3488b..ecb5fe69 100644 --- a/flexynesis/__main__.py +++ b/flexynesis/__main__.py @@ -929,7 +929,7 @@ def main(): # Convert to GNN dataset if needed if args.model_class == "GNN": print("[INFO] Overlaying the dataset with network data from STRINGDB") - from .data import MultiOmicDatasetNW, STRING + from .data import STRING, MultiOmicDatasetNW # Get STRING organism from artifacts string_organism = importer.artifacts.get( From 17edca97fd2d4210573a1144a9828c16bde7f997 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Thu, 9 Apr 2026 23:54:36 +0200 Subject: [PATCH 27/32] fix mps memory test --- tests/test_mps_device.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_mps_device.py b/tests/test_mps_device.py index e24c3411..ca431cf4 100644 --- a/tests/test_mps_device.py +++ b/tests/test_mps_device.py @@ -45,8 +45,10 @@ def test_mps_memory_allocation(): if device_str != "mps" or not torch.backends.mps.is_available(): pytest.skip("MPS device not available. Skipping test.") + device = torch.device(device_str) # Test memory tracking memory_before = torch.mps.current_allocated_memory() + large_tensor = to_device_safe(torch.randn(1000, 1000), device) # noqa: F841 memory_after = torch.mps.current_allocated_memory() assert memory_after > memory_before, "Memory did not increase after allocation." From 8c12ab05adad83ce63587bc9b67e972f1b1df83c Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Fri, 10 Apr 2026 17:02:19 +0200 Subject: [PATCH 28/32] use coad_cptac_2019 --- examples/tutorials/cbioportal.ipynb | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/tutorials/cbioportal.ipynb b/examples/tutorials/cbioportal.ipynb index a800b030..fa3885bc 100644 --- a/examples/tutorials/cbioportal.ipynb +++ b/examples/tutorials/cbioportal.ipynb @@ -13,7 +13,7 @@ "import numpy as np\n", "import pandas as pd\n", "\n", - "STUDY_ID = 'pancan_pcawg_2020'" + "STUDY_ID = 'coad_cptac_2019'" ] }, { @@ -144,7 +144,7 @@ "metadata": {}, "outputs": [], "source": [ - "counts = np.unique(list(data['clin'].CANCER_TYPE), return_counts=True) " + "counts = np.unique(list(data['clin'].CANCER_TYPE), return_counts=True)" ] }, { @@ -274,8 +274,8 @@ "metadata": {}, "outputs": [], "source": [ - "data_importer = flexynesis.data.DataImporter(path=f'{STUDY_ID}/', \n", - " data_types = ['mut', 'cna'], \n", + "data_importer = flexynesis.data.DataImporter(path=f'{STUDY_ID}/',\n", + " data_types = ['mut', 'cna'],\n", " concatenate=False, top_percentile=10, variance_threshold=0.8,\n", " min_features=500)\n", "train_dataset, test_dataset = data_importer.import_data()" @@ -324,7 +324,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "flexy", "language": "python", "name": "python3" }, @@ -338,7 +338,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0" + "version": "3.14.4" } }, "nbformat": 4, From 154bfbec1bb339e36a52b1d63c85bbe65c27f487 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Sun, 12 Apr 2026 23:15:28 +0200 Subject: [PATCH 29/32] increase cbioportal timeout and check if the files are correctly downloaded and unzipped --- examples/tutorials/cbioportal.ipynb | 18 ++++++++++++++++++ flexynesis/utils.py | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/examples/tutorials/cbioportal.ipynb b/examples/tutorials/cbioportal.ipynb index fa3885bc..f489ff80 100644 --- a/examples/tutorials/cbioportal.ipynb +++ b/examples/tutorials/cbioportal.ipynb @@ -37,6 +37,24 @@ "cbio.get_cbioportal_data(cbio.study_id)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "b08670ec", + "metadata": {}, + "outputs": [], + "source": [ + "# Check the specific files the notebook needs are present\n", + "required_files = [\"data_cna.txt\", \"data_mutations.txt\", \"data_clinical_sample.txt\"]\n", + "missing = [f for f in required_files if not os.path.exists(os.path.join(STUDY_ID, f))]\n", + "assert not missing, f\"Missing required files in {STUDY_ID}: {missing}\"\n", + "\n", + "print(f\"All required files present in ./{STUDY_ID}/:\")\n", + "for f in required_files:\n", + " size = os.path.getsize(os.path.join(STUDY_ID, f))\n", + " print(f\" {f} ({size:,} bytes)\")" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/flexynesis/utils.py b/flexynesis/utils.py index df8784f8..92ca75f0 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -1896,7 +1896,7 @@ def __init__(self, study_id, base_url="https://datahub.assets.cbioportal.org"): self.data_files = None self.data = None - def download_study_archive(self, force=False, timeout=60): + def download_study_archive(self, force=False, timeout=120): url = f"{self.base_url}/{self.study_id}.tar.gz" dest_file = f"{self.study_id}.tar.gz" From 87631a93019ebfd55b10694e10f6c07d1392ca58 Mon Sep 17 00:00:00 2001 From: Amirhossein Naghsh Nilchi Date: Mon, 13 Apr 2026 14:48:07 +0200 Subject: [PATCH 30/32] bring back the removed extract.archive --- flexynesis/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flexynesis/utils.py b/flexynesis/utils.py index 92ca75f0..c08a76cb 100644 --- a/flexynesis/utils.py +++ b/flexynesis/utils.py @@ -1983,6 +1983,8 @@ def print_data_files(self): print(df.to_string(index=False)) def get_cbioportal_data(self, study_id, files=None): + archive_path = self.download_study_archive() + self.extract_archive(archive_path) if files is None: self.print_data_files() From f33b98b12657ddc92f9efe0e7b5fd1d5ba945451 Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Sun, 19 Apr 2026 12:10:59 +0200 Subject: [PATCH 31/32] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- flexynesis/models/gnn_early.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flexynesis/models/gnn_early.py b/flexynesis/models/gnn_early.py index 858d5d3f..32b1fee1 100644 --- a/flexynesis/models/gnn_early.py +++ b/flexynesis/models/gnn_early.py @@ -462,7 +462,7 @@ def compute_feature_importance( """ def bytes_to_gb(bytes): - return bytes / 1024**2 + return bytes / 1024**3 from ..utils import create_device_from_string, to_device_safe From 4007e3571e67adcc71172e7bc7d93de3db78f5f9 Mon Sep 17 00:00:00 2001 From: Bora Uyar Date: Sun, 19 Apr 2026 12:33:21 +0200 Subject: [PATCH 32/32] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- flexynesis/models/triplet_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flexynesis/models/triplet_encoder.py b/flexynesis/models/triplet_encoder.py index 4c8a4564..4ba8b60b 100644 --- a/flexynesis/models/triplet_encoder.py +++ b/flexynesis/models/triplet_encoder.py @@ -456,8 +456,8 @@ def forward_target(self, input_data, layer_sizes, target_var, steps): # convert to dict anchor = {k: anchor[k] for k in range(len(anchor))} - positive = {k: anchor[k] for k in range(len(positive))} - negative = {k: anchor[k] for k in range(len(negative))} + positive = {k: positive[k] for k in range(len(positive))} + negative = {k: negative[k] for k in range(len(negative))} ( anchor_embedding, positive_embedding,